diff --git a/book/src/how/architecture/opening-proof.md b/book/src/how/architecture/opening-proof.md index 59147c8a9..3e54dd3bf 100644 --- a/book/src/how/architecture/opening-proof.md +++ b/book/src/how/architecture/opening-proof.md @@ -23,6 +23,8 @@ The claim reduction sumchecks can be found in `jolt-core/src/zkvm/claim_reductio - **Increments** (`increments.rs`): Reduces claims related to increment checks. - **Hamming weight** (`hamming_weight.rs`): Reduces hamming weight-related claims. - **Advice** (`advice.rs`): Reduces claims from advice polynomials. +- **Bytecode** (`bytecode.rs`): Reduces committed bytecode openings into the shared Stage 8 Dory geometry. +- **Program image** (`program_image.rs`): Reduces the committed initial-memory image into the same final opening geometry. ### How claim reduction sumchecks work @@ -43,6 +45,354 @@ We apply the [Multiple polynomials, same point](../optimizations/batched-opening On the verifier side, this entails taking a linear combination of commitments. Since Dory is an additively homomorphic commitment scheme, the verifier is able to do so. +### Precommitted geometry and Dory embedding + +Some committed polynomials in Stage 8 do not naturally live in the "main" Dory geometry induced by the trace-domain witness polynomials. Examples include the bytecode chunks, the program image, and trusted or untrusted advice. In the implementation these are called **precommitted** polynomials. + +The goal of Stage 8 is still the same: every committed polynomial must be opened at one common Dory point so that a single random linear combination can be opened. The subtlety is that these precommitted polynomials may have a different number of variables from the main trace-domain polynomials. + +In this section we write: + +- $T$ for the **log** trace length +- $K$ for the **log** main address space size +- $B$ for the number of **extra** variables contributed by the largest precommitted polynomial beyond the main geometry + +With that notation, the final Dory opening point has length + +$$ +D = T + K + B. +$$ + +Equivalently, Stage 8 works in a joint Dory matrix of size $2^{\nu_D} \times 2^{\sigma_D}$ where + +$$ +\sigma_D = \left\lceil \frac{D}{2} \right\rceil, \qquad \nu_D = D - \sigma_D. +$$ + +Here $\nu_D$ is the number of **row variables** and $\sigma_D$ is the number of **column variables**. This matches the implementation in `DoryGlobals::balanced_sigma_nu()` and the split used by `PrecommittedClaimReduction::project_dory_round_permutation_for_poly()`. + + +Write: + +- the main geometry size as $T + K$ +- the joint geometry size as $D = T + K + B$ +- the joint Dory matrix as $2^{\nu_D} \times 2^{\sigma_D}$ with $\nu_D + \sigma_D = D$ + +The main design constraint is that we do not want to complicate the existing main sumchecks round scheduling. So Jolt does the following: + +- precommitted reductions are forward-loaded +- main reductions are backward-loaded +- Stage 6b always has exactly $T + B$ rounds +- Stage 7 always has exactly $K$ rounds + +This way: + +- the precommitted reductions see the full challenge set needed for the joint geometry +- the main sumchecks keep their old round scheduling +- Stage 8 only has to normalize already-produced opening points into the final Dory point + +If some precommitted polynomial already has $D$ variables, we call it a **dominant precommitted polynomial**. Otherwise there is **no dominant precommitted polynomial**, and the joint point is anchored by the ordinary main openings. + +#### How Main Polynomials Sit In The Joint Matrix + +The main polynomials are embedded depending on the dory layout. As a concrete example, take $D = 5$. Since Dory uses a balanced split, this means: + +$$ +\sigma_D = 3, \qquad \nu_D = 2, +$$ + +so the joint matrix has $2^2 = 4$ rows and $2^3 = 8$ columns, for a total of $2^5 = 32$ slots. + +##### `CycleMajor` dense placement + +Take a dense polynomial with $T = 3$ variables and coefficients + +$$ +a_{000}, a_{001}, a_{010}, a_{011}, a_{100}, a_{101}, a_{110}, a_{111}. +$$ + +In `CycleMajor`, the dense polynomial is written across the top of the matrix, so only the lowest $T$ index bits vary: + +```text +Joint 4 x 8 matrix + + col000 col001 col010 col011 col100 col101 col110 col111 +row00 | a_000 | a_001 | a_010 | a_011 | a_100 | a_101 | a_110 | a_111 | +row01 | . | . | . | . | . | . | . | . | +row10 | . | . | . | . | . | . | . | . | +row11 | . | . | . | . | . | . | . | . | +``` + +so the first $2$ bits are fixed and only the last $3$ bits vary. + +##### `AddressMajor` dense placement + +Now take the same joint geometry $D = 5$, but the dense polynomial should now use the highest $T = 3$ bits. Then its coefficients are written into slots whose last $K+B = 2$ bits are zero: + +In the same $4 \times 8$ matrix this looks like: + +```text +Joint 4 x 8 matrix + + col000 col001 col010 col011 col100 col101 col110 col111 +row00 | a_000 | . | . | . | a_001 | . | . | . | +row01 | a_010 | . | . | . | a_011 | . | . | . | +row10 | a_100 | . | . | . | a_101 | . | . | . | +row11 | a_110 | . | . | . | a_111 | . | . | . | +``` + +The same idea applies to one-hot polynomials: + +- in `CycleMajor`, they use the lowest $T+K$ bits +- in `AddressMajor`, they use the highest $T+K$ bits, so the trailing $B$ bits are zero + +Therefore, the extra $B$ variables must end up on opposite sides of the final Dory opening point in the two layouts. + +#### When Address-Major Dense Stride Exceeds The Row Width + +In `AddressMajor`, dense polynomials are embedded with stride $2^{K+B}$. Sometimes that stride is larger than the number of columns of the joint matrix. This is the special branch handled in `dory/wrappers.rs`. + +Take a real example: + +- joint geometry $D = 7$, so the balanced Dory matrix is $2^3 \times 2^4 = 8 \times 16$ +- dense polynomial has $T = 2$ variables, so it has 4 coefficients +- therefore $K+B = 5$, so the stride is $2^5 = 32$ + +Since the row width is only 16, consecutive coefficients jump by two whole rows: + +```text +coeff a_00 -> slot 0 -> row 0, col 0 +coeff a_01 -> slot 32 -> row 2, col 0 +coeff a_10 -> slot 64 -> row 4, col 0 +coeff a_11 -> slot 96 -> row 6, col 0 +``` + +and the matrix picture is: + +```text +8 x 16 joint matrix + +row0 | a_00 . . . . . . . . . . . . . . . | +row1 | . . . . . . . . . . . . . . . . | +row2 | a_01 . . . . . . . . . . . . . . . | +row3 | . . . . . . . . . . . . . . . . | +row4 | a_10 . . . . . . . . . . . . . . . | +row5 | . . . . . . . . . . . . . . . . | +row6 | a_11 . . . . . . . . . . . . . . . | +row7 | . . . . . . . . . . . . . . . . | +``` + +So the logical embedding is unchanged, but it is no longer a convenient row-local chunking. That is why the implementation switches to explicit sparse row/column placement in this case. Because polynomial lengths are powers of two, the placement still stays aligned: either the stride is a multiple of the row width, so the polynomial occupies the same column range in every row it touches, or the stride divides the row width, so it stays in a fixed column but appears only in every few rows, as in the example above. + +#### Final Dory Opening Point + +In summary +- in `CycleMajor`, the main dense / one-hot geometry consumes the low bits of the final Dory point, so any extra precommitted variables must sit on the high side +- in `AddressMajor`, the main geometry consumes the high bits, so any extra precommitted variables must sit on the low side +- each block appears in reverse because we always bind polynomials during claim reduction sumchecks from low to high bits + +Now we study two cases: +If there **is** a dominant precommitted polynomial, let the raw Stage 6b challenges be + +$$ +[x_1, x_2, \dots, x_B, x_{B+1}, \dots, x_{B+T}] +$$ + +and the raw Stage 7 challenges be + +$$ +[y_1, y_2, \dots, y_K]. +$$ + +The final big-endian Dory opening point is obtained by normalizing these challenges into Dory order. + +For **AddressMajor**: + +$$ +[x_{B+T}, x_{B+T-1}, \dots, x_{B+1} \;\Vert\; y_K, y_{K-1}, \dots, y_1 \;\Vert\; x_B, x_{B-1}, \dots, x_1] +$$ + +For **CycleMajor**: + +$$ +[x_B, x_{B-1}, \dots, x_1 \;\Vert\; y_K, y_{K-1}, \dots, y_1 \;\Vert\; x_{B+T}, x_{B+T-1}, \dots, x_{B+1}] +$$ + +Each block is reversed, and the extra $B$ variables move to different sides depending on the layout. + +If there is **no dominant precommitted polynomial**, then the final point is anchored by the ordinary main openings: + +- in this case the joint geometry is just the main geometry, so $B = 0$ +- let $r_{\mathrm{inc}}$ be the Stage 6b opening point from `IncClaimReduction` +- let $r_{\mathrm{ham}}$ be the Stage 7 opening point from `HammingWeightClaimReduction` + +These are already normalized opening points. + +Then: + +For **AddressMajor**: + +$$ +r_{\mathrm{final}} = +\big[ +r_{\mathrm{inc}} +\;\Vert\; +r_{\mathrm{ham}} +\big] +$$ + +For **CycleMajor**: + +$$ +r_{\mathrm{final}} = +\big[ +r_{\mathrm{ham}} +\big]. +$$ + +This is exactly the logic implemented in `stage8_opening_point()` in `prover.rs`. + +#### Embedding Precommitted Polynomials + +The verifier already has the commitment to the precommitted polynomial. That commitment is computed under the convention that the polynomial occupies the top-left block of its balanced Dory matrix, meaning the earliest rows and earliest columns. So when we embed that polynomial into the larger joint matrix, we must preserve that same top-left placement; otherwise the verifier would be checking the Dory proof against a different geometry from the one encoded in the commitment. + +```text +Joint Dory matrix: 2^nu_D rows x 2^sigma_D columns +Smaller precommitted matrix: 2^nu_C rows x 2^sigma_C columns + + left 2^sigma_C cols remaining cols + +---------------------------+------------------+ +top 2^nu_C rows | smaller precommitted poly | not used by this | + | lives here | poly | + +---------------------------+------------------+ +remaining rows | not used by this poly | not used by this | + | | poly | + +---------------------------+------------------+ +``` + +Suppose the smaller precommitted polynomial has + +$$ +C = \nu_C + \sigma_C +$$ + +variables, while the joint point has + +$$ +D = \nu_D + \sigma_D. +$$ + +Split the joint point as + +$$ +r_{\mathrm{joint}} = +\big[ +r_{\mathrm{row}}^{\mathrm{hi}} +\;\Vert\; +r_{\mathrm{row}}^{\mathrm{lo}} +\;\Vert\; +r_{\mathrm{col}}^{\mathrm{hi}} +\;\Vert\; +r_{\mathrm{col}}^{\mathrm{lo}} +\big] +$$ + +where: + +- $r_{\mathrm{row}}^{\mathrm{hi}}$ has length $\nu_D - \nu_C$ +- $r_{\mathrm{row}}^{\mathrm{lo}}$ has length $\nu_C$ +- $r_{\mathrm{col}}^{\mathrm{hi}}$ has length $\sigma_D - \sigma_C$ +- $r_{\mathrm{col}}^{\mathrm{lo}}$ has length $\sigma_C$ + +Then the smaller polynomial is evaluated on + +$$ +r_{\mathrm{small}} = +\big[ +r_{\mathrm{row}}^{\mathrm{lo}} +\;\Vert\; +r_{\mathrm{col}}^{\mathrm{lo}} +\big]. +$$ + +The reason is that top-left embedding forces the missing high row bits and high column bits to be zero: + +```text +joint row variables : [row_hi | row_lo] +joint col variables : [col_hi | col_lo] + +top-left embedding forces: + row_hi = 0 + col_hi = 0 +``` + +So if $P$ is the smaller polynomial and $P_{\mathrm{emb}}$ is its embedding into the joint matrix, then + +$$ +P_{\mathrm{emb}}(r_{\mathrm{joint}}) += +\operatorname{eq}\!\left(r_{\mathrm{row}}^{\mathrm{hi}}, 0^{\nu_D - \nu_C}\right) +\cdot +\operatorname{eq}\!\left(r_{\mathrm{col}}^{\mathrm{hi}}, 0^{\sigma_D - \sigma_C}\right) +\cdot +P(r_{\mathrm{small}}). +$$ + +This selector is exactly why top-left embedding works inside one shared Dory proof. + +The same selector appears when a joint `RLCPolynomial` mixes a main polynomial with a smaller precommitted polynomial: + +$$ +\text{RLC coefficient} +\cdot +P(r_{\mathrm{small}}) +\cdot +\operatorname{eq}\!\left(r_{\mathrm{row}}^{\mathrm{hi}}, 0^{\nu_D - \nu_C}\right) +\cdot +\operatorname{eq}\!\left(r_{\mathrm{col}}^{\mathrm{hi}}, 0^{\sigma_D - \sigma_C}\right). +$$ + +#### Permuting Precommitted Polynomial Variables + +The precommitted sumchecks still bind variables low-to-high. But the final Dory point order is determined by the joint geometry, not by the order in which those rounds happen. + +So Jolt permutes the variables of each precommitted polynomial before running the sumcheck. This keeps the sumcheck code simple while ensuring the final claim corresponds to the original polynomial at the correct Stage 8 point. This permutation is cheap because it is only a variable-position movement, so on the coefficient table it is just a bit permutation of the $2^n$ Boolean-hypercube evaluations. + +Here is a concrete 3-variable example. Suppose the original polynomial is encoded by + +```text +point 000 001 010 011 100 101 110 111 +P(point) v0 v1 v2 v3 v4 v5 v6 v7 +``` + +Now suppose the Stage 8 geometry wants the variables in the order $(c,b,a)$ rather than $(a,b,c)$. Define + +$$ +P'(u,v,w) = P(w,v,u). +$$ + +Then the new coefficient table becomes + +```text +point 000 001 010 011 100 101 110 111 +P'(point) v0 v4 v2 v6 v1 v5 v3 v7 +``` + +because + +```text +P'(000) = P(000) +P'(001) = P(100) +P'(010) = P(010) +P'(011) = P(110) +P'(100) = P(001) +P'(101) = P(101) +P'(110) = P(011) +P'(111) = P(111) +``` + +After the sumcheck finishes, `normalize_opening_point()` converts the collected challenges back into the true opening point of the original, non-permuted polynomial. + ### `RLCPolynomial` Recall that all of the polynomials in Jolt fall into one of two categories: **one-hot** polynomials (the $\widetilde{\textsf{ra}}$ and $\widetilde{\textsf{wa}}$ arising in [Twist/Shout](../twist-shout.md)), and **dense** polynomials (we use this to mean anything that's not one-hot). diff --git a/examples/advice-demo/src/main.rs b/examples/advice-demo/src/main.rs index 58c081c00..0c9af5a35 100644 --- a/examples/advice-demo/src/main.rs +++ b/examples/advice-demo/src/main.rs @@ -4,6 +4,10 @@ use tracing::info; // Demonstration of advice tape usage in a provable computation pub fn main() { tracing_subscriber::fmt::init(); + let bytecode_chunk = std::env::args() + .skip_while(|arg| arg != "--committed-bytecode") + .nth(1) + .map(|arg| arg.parse().unwrap()); let target_dir = "/tmp/jolt-guest-targets"; @@ -13,11 +17,21 @@ pub fn main() { let b = vec![0usize, 1, 2, 3, 4, 5, 6, 7, 8, 9]; let mut program = guest::compile_advice_demo(target_dir); - let shared_preprocessing = guest::preprocess_shared_advice_demo(&mut program).unwrap(); - let prover_preprocessing = guest::preprocess_prover_advice_demo(shared_preprocessing.clone()); - let verifier_setup = prover_preprocessing.generators.to_verifier_setup(); - let verifier_preprocessing = - guest::preprocess_verifier_advice_demo(shared_preprocessing, verifier_setup, None); + let (prover_preprocessing, verifier_preprocessing) = if let Some(chunk_count) = bytecode_chunk { + let prover_preprocessing = + guest::preprocess_committed_advice_demo(&mut program, chunk_count).unwrap(); + let verifier_preprocessing = + guest::verifier_preprocessing_from_prover_advice_demo(&prover_preprocessing); + (prover_preprocessing, verifier_preprocessing) + } else { + let shared_preprocessing = guest::preprocess_shared_advice_demo(&mut program).unwrap(); + let prover_preprocessing = + guest::preprocess_prover_advice_demo(shared_preprocessing.clone()); + let verifier_setup = prover_preprocessing.generators.to_verifier_setup(); + let verifier_preprocessing = + guest::preprocess_verifier_advice_demo(shared_preprocessing, verifier_setup, None); + (prover_preprocessing, verifier_preprocessing) + }; let prove_advice_demo = guest::build_prover_advice_demo(program, prover_preprocessing); let verify_advice_demo = guest::build_verifier_advice_demo(verifier_preprocessing); diff --git a/examples/alloc/src/main.rs b/examples/alloc/src/main.rs index 71a712b4e..b7497b33b 100644 --- a/examples/alloc/src/main.rs +++ b/examples/alloc/src/main.rs @@ -3,17 +3,30 @@ use tracing::info; pub fn main() { tracing_subscriber::fmt::init(); + let bytecode_chunk = std::env::args() + .skip_while(|arg| arg != "--committed-bytecode") + .nth(1) + .map(|arg| arg.parse().unwrap()); let target_dir = "/tmp/jolt-guest-targets"; let mut program = guest::compile_alloc(target_dir); - let shared_preprocessing = guest::preprocess_shared_alloc(&mut program).unwrap(); - let prover_preprocessing = guest::preprocess_prover_alloc(shared_preprocessing.clone()); - let verifier_preprocessing = guest::preprocess_verifier_alloc( - shared_preprocessing, - prover_preprocessing.generators.to_verifier_setup(), - None, - ); + let (prover_preprocessing, verifier_preprocessing) = if let Some(chunk_count) = bytecode_chunk { + let prover_preprocessing = + guest::preprocess_committed_alloc(&mut program, chunk_count).unwrap(); + let verifier_preprocessing = + guest::verifier_preprocessing_from_prover_alloc(&prover_preprocessing); + (prover_preprocessing, verifier_preprocessing) + } else { + let shared_preprocessing = guest::preprocess_shared_alloc(&mut program).unwrap(); + let prover_preprocessing = guest::preprocess_prover_alloc(shared_preprocessing.clone()); + let verifier_preprocessing = guest::preprocess_verifier_alloc( + shared_preprocessing, + prover_preprocessing.generators.to_verifier_setup(), + None, + ); + (prover_preprocessing, verifier_preprocessing) + }; let prove_alloc = guest::build_prover_alloc(program, prover_preprocessing); let verify_alloc = guest::build_verifier_alloc(verifier_preprocessing); diff --git a/examples/backtrace/src/main.rs b/examples/backtrace/src/main.rs index 247f860b5..321e05101 100644 --- a/examples/backtrace/src/main.rs +++ b/examples/backtrace/src/main.rs @@ -15,13 +15,18 @@ fn main() { #[cfg(any(feature = "nostd", feature = "std"))] let should_panic = env_flag("JOLT_BT_TRIGGER").unwrap_or(true); #[cfg(any(feature = "nostd", feature = "std"))] + let bytecode_chunk = std::env::args() + .skip_while(|arg| arg != "--committed-bytecode") + .nth(1) + .map(|arg| arg.parse().unwrap()); + #[cfg(any(feature = "nostd", feature = "std"))] let target_dir = "/tmp/jolt-guest-targets"; #[cfg(feature = "nostd")] - run_nostd(target_dir, should_panic); + run_nostd(target_dir, should_panic, bytecode_chunk); #[cfg(feature = "std")] - run_std(target_dir, should_panic); + run_std(target_dir, should_panic, bytecode_chunk); #[cfg(not(any(feature = "nostd", feature = "std")))] { @@ -39,7 +44,7 @@ fn env_flag(key: &str) -> Option { } #[cfg(feature = "nostd")] -fn run_nostd(target_dir: &str, should_panic: bool) { +fn run_nostd(target_dir: &str, should_panic: bool, bytecode_chunk: Option) { info!("mode=nostd should_panic={}", should_panic); let trace_enabled = env_flag("JOLT_BACKTRACE").unwrap_or(false); @@ -47,15 +52,27 @@ fn run_nostd(target_dir: &str, should_panic: bool) { let mut program = guest_nostd::compile_panic_backtrace_nostd(target_dir); - let shared_preprocessing = - guest_nostd::preprocess_shared_panic_backtrace_nostd(&mut program).unwrap(); - let prover_preprocessing = - guest_nostd::preprocess_prover_panic_backtrace_nostd(shared_preprocessing.clone()); - let verifier_preprocessing = guest_nostd::preprocess_verifier_panic_backtrace_nostd( - shared_preprocessing, - prover_preprocessing.generators.to_verifier_setup(), - None, - ); + let (prover_preprocessing, verifier_preprocessing) = if let Some(chunk_count) = bytecode_chunk { + let prover_preprocessing = + guest_nostd::preprocess_committed_panic_backtrace_nostd(&mut program, chunk_count) + .unwrap(); + let verifier_preprocessing = + guest_nostd::verifier_preprocessing_from_prover_panic_backtrace_nostd( + &prover_preprocessing, + ); + (prover_preprocessing, verifier_preprocessing) + } else { + let shared_preprocessing = + guest_nostd::preprocess_shared_panic_backtrace_nostd(&mut program).unwrap(); + let prover_preprocessing = + guest_nostd::preprocess_prover_panic_backtrace_nostd(shared_preprocessing.clone()); + let verifier_preprocessing = guest_nostd::preprocess_verifier_panic_backtrace_nostd( + shared_preprocessing, + prover_preprocessing.generators.to_verifier_setup(), + None, + ); + (prover_preprocessing, verifier_preprocessing) + }; let prove = guest_nostd::build_prover_panic_backtrace_nostd(program, prover_preprocessing); let verify = guest_nostd::build_verifier_panic_backtrace_nostd(verifier_preprocessing); @@ -79,7 +96,7 @@ fn run_nostd(target_dir: &str, should_panic: bool) { } #[cfg(feature = "std")] -fn run_std(target_dir: &str, should_panic: bool) { +fn run_std(target_dir: &str, should_panic: bool, bytecode_chunk: Option) { info!("mode=std should_panic={}", should_panic); let trace_enabled = env_flag("JOLT_BACKTRACE").unwrap_or(false); @@ -87,15 +104,26 @@ fn run_std(target_dir: &str, should_panic: bool) { let mut program = guest_std::compile_panic_backtrace_std(target_dir); - let shared_preprocessing = - guest_std::preprocess_shared_panic_backtrace_std(&mut program).unwrap(); - let prover_preprocessing = - guest_std::preprocess_prover_panic_backtrace_std(shared_preprocessing.clone()); - let verifier_preprocessing = guest_std::preprocess_verifier_panic_backtrace_std( - shared_preprocessing, - prover_preprocessing.generators.to_verifier_setup(), - None, - ); + let (prover_preprocessing, verifier_preprocessing) = if let Some(chunk_count) = bytecode_chunk { + let prover_preprocessing = + guest_std::preprocess_committed_panic_backtrace_std(&mut program, chunk_count).unwrap(); + let verifier_preprocessing = + guest_std::verifier_preprocessing_from_prover_panic_backtrace_std( + &prover_preprocessing, + ); + (prover_preprocessing, verifier_preprocessing) + } else { + let shared_preprocessing = + guest_std::preprocess_shared_panic_backtrace_std(&mut program).unwrap(); + let prover_preprocessing = + guest_std::preprocess_prover_panic_backtrace_std(shared_preprocessing.clone()); + let verifier_preprocessing = guest_std::preprocess_verifier_panic_backtrace_std( + shared_preprocessing, + prover_preprocessing.generators.to_verifier_setup(), + None, + ); + (prover_preprocessing, verifier_preprocessing) + }; let prove = guest_std::build_prover_panic_backtrace_std(program, prover_preprocessing); let verify = guest_std::build_verifier_panic_backtrace_std(verifier_preprocessing); diff --git a/examples/btreemap/host/src/main.rs b/examples/btreemap/host/src/main.rs index b698f8998..37cb6a154 100644 --- a/examples/btreemap/host/src/main.rs +++ b/examples/btreemap/host/src/main.rs @@ -11,27 +11,40 @@ macro_rules! step { } pub fn btreemap() { + let bytecode_chunk = std::env::args() + .skip_while(|arg| arg != "--committed-bytecode") + .nth(1) + .map(|arg| arg.parse().unwrap()); let target_dir = "/tmp/jolt-guest-targets"; let mut program = step!("Compiling guest code", { guest::compile_btreemap(target_dir) }); - let shared_preprocessing = step!("Preprocessing shared", { - guest::preprocess_shared_btreemap(&mut program).unwrap() - }); - - let prover_preprocessing = step!("Preprocessing prover", { - guest::preprocess_prover_btreemap(shared_preprocessing.clone()) - }); - - let verifier_preprocessing = step!("Preprocessing verifier", { - guest::preprocess_verifier_btreemap( - shared_preprocessing, - prover_preprocessing.generators.to_verifier_setup(), - None, - ) - }); + let (prover_preprocessing, verifier_preprocessing) = if let Some(chunk_count) = bytecode_chunk { + let prover_preprocessing = step!("Preprocessing prover", { + guest::preprocess_committed_btreemap(&mut program, chunk_count).unwrap() + }); + let verifier_preprocessing = step!("Preprocessing verifier", { + guest::verifier_preprocessing_from_prover_btreemap(&prover_preprocessing) + }); + (prover_preprocessing, verifier_preprocessing) + } else { + let shared_preprocessing = step!("Preprocessing shared", { + guest::preprocess_shared_btreemap(&mut program).unwrap() + }); + let prover_preprocessing = step!("Preprocessing prover", { + guest::preprocess_prover_btreemap(shared_preprocessing.clone()) + }); + let verifier_preprocessing = step!("Preprocessing verifier", { + guest::preprocess_verifier_btreemap( + shared_preprocessing, + prover_preprocessing.generators.to_verifier_setup(), + None, + ) + }); + (prover_preprocessing, verifier_preprocessing) + }; let prove = step!("Building prover", { guest::build_prover_btreemap(program, prover_preprocessing) diff --git a/examples/collatz/src/main.rs b/examples/collatz/src/main.rs index 15fe35288..a61ff4244 100644 --- a/examples/collatz/src/main.rs +++ b/examples/collatz/src/main.rs @@ -3,17 +3,34 @@ use tracing::info; pub fn main() { tracing_subscriber::fmt::init(); + let bytecode_chunk = std::env::args() + .skip_while(|arg| arg != "--committed-bytecode") + .nth(1) + .map(|arg| arg.parse().unwrap()); // Prove/verify convergence for a single number: let target_dir = "/tmp/jolt-guest-targets"; let mut program = guest::compile_collatz_convergence(target_dir); - let shared_preprocessing = guest::preprocess_shared_collatz_convergence(&mut program).unwrap(); - let prover_preprocessing = - guest::preprocess_prover_collatz_convergence(shared_preprocessing.clone()); - let verifier_setup = prover_preprocessing.generators.to_verifier_setup(); - let verifier_preprocessing = - guest::preprocess_verifier_collatz_convergence(shared_preprocessing, verifier_setup, None); + let (prover_preprocessing, verifier_preprocessing) = if let Some(chunk_count) = bytecode_chunk { + let prover_preprocessing = + guest::preprocess_committed_collatz_convergence(&mut program, chunk_count).unwrap(); + let verifier_preprocessing = + guest::verifier_preprocessing_from_prover_collatz_convergence(&prover_preprocessing); + (prover_preprocessing, verifier_preprocessing) + } else { + let shared_preprocessing = + guest::preprocess_shared_collatz_convergence(&mut program).unwrap(); + let prover_preprocessing = + guest::preprocess_prover_collatz_convergence(shared_preprocessing.clone()); + let verifier_setup = prover_preprocessing.generators.to_verifier_setup(); + let verifier_preprocessing = guest::preprocess_verifier_collatz_convergence( + shared_preprocessing, + verifier_setup, + None, + ); + (prover_preprocessing, verifier_preprocessing) + }; let prove_collatz_single = guest::build_prover_collatz_convergence(program, prover_preprocessing); @@ -31,16 +48,28 @@ pub fn main() { // Prove/verify convergence for a range of numbers: let mut program = guest::compile_collatz_convergence_range(target_dir); - let shared_preprocessing = - guest::preprocess_shared_collatz_convergence_range(&mut program).unwrap(); - let prover_preprocessing = - guest::preprocess_prover_collatz_convergence_range(shared_preprocessing.clone()); - let verifier_setup = prover_preprocessing.generators.to_verifier_setup(); - let verifier_preprocessing = guest::preprocess_verifier_collatz_convergence_range( - shared_preprocessing, - verifier_setup, - None, - ); + let (prover_preprocessing, verifier_preprocessing) = if let Some(chunk_count) = bytecode_chunk { + let prover_preprocessing = + guest::preprocess_committed_collatz_convergence_range(&mut program, chunk_count) + .unwrap(); + let verifier_preprocessing = + guest::verifier_preprocessing_from_prover_collatz_convergence_range( + &prover_preprocessing, + ); + (prover_preprocessing, verifier_preprocessing) + } else { + let shared_preprocessing = + guest::preprocess_shared_collatz_convergence_range(&mut program).unwrap(); + let prover_preprocessing = + guest::preprocess_prover_collatz_convergence_range(shared_preprocessing.clone()); + let verifier_setup = prover_preprocessing.generators.to_verifier_setup(); + let verifier_preprocessing = guest::preprocess_verifier_collatz_convergence_range( + shared_preprocessing, + verifier_setup, + None, + ); + (prover_preprocessing, verifier_preprocessing) + }; let prove_collatz_convergence = guest::build_prover_collatz_convergence_range(program, prover_preprocessing); diff --git a/examples/fibonacci/src/main.rs b/examples/fibonacci/src/main.rs index 78c0b967d..72f59fc7b 100644 --- a/examples/fibonacci/src/main.rs +++ b/examples/fibonacci/src/main.rs @@ -4,18 +4,30 @@ use tracing::info; pub fn main() { tracing_subscriber::fmt::init(); + let bytecode_chunk = std::env::args() + .skip_while(|arg| arg != "--committed-bytecode") + .nth(1) + .map(|arg| arg.parse().unwrap()); let save_to_disk = std::env::args().any(|arg| arg == "--save"); let target_dir = "/tmp/jolt-guest-targets"; let mut program = guest::compile_fib(target_dir); - let shared_preprocessing = guest::preprocess_shared_fib(&mut program).unwrap(); - - let prover_preprocessing = guest::preprocess_prover_fib(shared_preprocessing.clone()); - let verifier_setup = prover_preprocessing.generators.to_verifier_setup(); - let verifier_preprocessing = - guest::preprocess_verifier_fib(shared_preprocessing, verifier_setup, None); + let (prover_preprocessing, verifier_preprocessing) = if let Some(chunk_count) = bytecode_chunk { + let prover_preprocessing = + guest::preprocess_committed_fib(&mut program, chunk_count).unwrap(); + let verifier_preprocessing = + guest::verifier_preprocessing_from_prover_fib(&prover_preprocessing); + (prover_preprocessing, verifier_preprocessing) + } else { + let shared_preprocessing = guest::preprocess_shared_fib(&mut program).unwrap(); + let prover_preprocessing = guest::preprocess_prover_fib(shared_preprocessing.clone()); + let verifier_setup = prover_preprocessing.generators.to_verifier_setup(); + let verifier_preprocessing = + guest::preprocess_verifier_fib(shared_preprocessing, verifier_setup, None); + (prover_preprocessing, verifier_preprocessing) + }; if save_to_disk { serialize_and_print_size( diff --git a/examples/hash-bench/src/main.rs b/examples/hash-bench/src/main.rs index e1f953855..3ad3e4700 100644 --- a/examples/hash-bench/src/main.rs +++ b/examples/hash-bench/src/main.rs @@ -2,15 +2,28 @@ use std::time::Instant; pub fn main() { tracing_subscriber::fmt::init(); + let bytecode_chunk = std::env::args() + .skip_while(|arg| arg != "--committed-bytecode") + .nth(1) + .map(|arg| arg.parse().unwrap()); let target_dir = "/tmp/jolt-guest-targets"; let mut program = guest::compile_hashbench(target_dir); - let shared_preprocessing = guest::preprocess_shared_hashbench(&mut program).unwrap(); - let prover_preprocessing = guest::preprocess_prover_hashbench(shared_preprocessing.clone()); - let verifier_setup = prover_preprocessing.generators.to_verifier_setup(); - let verifier_preprocessing = - guest::preprocess_verifier_hashbench(shared_preprocessing, verifier_setup, None); + let (prover_preprocessing, verifier_preprocessing) = if let Some(chunk_count) = bytecode_chunk { + let prover_preprocessing = + guest::preprocess_committed_hashbench(&mut program, chunk_count).unwrap(); + let verifier_preprocessing = + guest::verifier_preprocessing_from_prover_hashbench(&prover_preprocessing); + (prover_preprocessing, verifier_preprocessing) + } else { + let shared_preprocessing = guest::preprocess_shared_hashbench(&mut program).unwrap(); + let prover_preprocessing = guest::preprocess_prover_hashbench(shared_preprocessing.clone()); + let verifier_setup = prover_preprocessing.generators.to_verifier_setup(); + let verifier_preprocessing = + guest::preprocess_verifier_hashbench(shared_preprocessing, verifier_setup, None); + (prover_preprocessing, verifier_preprocessing) + }; let prove_hashbench = guest::build_prover_hashbench(program, prover_preprocessing); let verify_hashbench = guest::build_verifier_hashbench(verifier_preprocessing); diff --git a/examples/malloc/src/main.rs b/examples/malloc/src/main.rs index 792ce081f..e8810470e 100644 --- a/examples/malloc/src/main.rs +++ b/examples/malloc/src/main.rs @@ -1,16 +1,29 @@ use std::time::Instant; pub fn main() { + let bytecode_chunk = std::env::args() + .skip_while(|arg| arg != "--committed-bytecode") + .nth(1) + .map(|arg| arg.parse().unwrap()); let target_dir = "/tmp/jolt-guest-targets"; let mut program = guest::compile_alloc(target_dir); - let shared_preprocessing = guest::preprocess_shared_alloc(&mut program).unwrap(); - let prover_preprocessing = guest::preprocess_prover_alloc(shared_preprocessing.clone()); - let verifier_preprocessing = guest::preprocess_verifier_alloc( - shared_preprocessing, - prover_preprocessing.generators.to_verifier_setup(), - None, - ); + let (prover_preprocessing, verifier_preprocessing) = if let Some(chunk_count) = bytecode_chunk { + let prover_preprocessing = + guest::preprocess_committed_alloc(&mut program, chunk_count).unwrap(); + let verifier_preprocessing = + guest::verifier_preprocessing_from_prover_alloc(&prover_preprocessing); + (prover_preprocessing, verifier_preprocessing) + } else { + let shared_preprocessing = guest::preprocess_shared_alloc(&mut program).unwrap(); + let prover_preprocessing = guest::preprocess_prover_alloc(shared_preprocessing.clone()); + let verifier_preprocessing = guest::preprocess_verifier_alloc( + shared_preprocessing, + prover_preprocessing.generators.to_verifier_setup(), + None, + ); + (prover_preprocessing, verifier_preprocessing) + }; let prove = guest::build_prover_alloc(program, prover_preprocessing); let verify = guest::build_verifier_alloc(verifier_preprocessing); diff --git a/examples/memory-ops/src/main.rs b/examples/memory-ops/src/main.rs index 3fc4b039c..2787b12a2 100644 --- a/examples/memory-ops/src/main.rs +++ b/examples/memory-ops/src/main.rs @@ -3,17 +3,31 @@ use tracing::info; pub fn main() { tracing_subscriber::fmt::init(); + let bytecode_chunk = std::env::args() + .skip_while(|arg| arg != "--committed-bytecode") + .nth(1) + .map(|arg| arg.parse().unwrap()); let target_dir = "/tmp/jolt-guest-targets"; let mut program = guest::compile_memory_ops(target_dir); - let shared_preprocessing = guest::preprocess_shared_memory_ops(&mut program).unwrap(); - let prover_preprocessing = guest::preprocess_prover_memory_ops(shared_preprocessing.clone()); - let verifier_preprocessing = guest::preprocess_verifier_memory_ops( - shared_preprocessing, - prover_preprocessing.generators.to_verifier_setup(), - None, - ); + let (prover_preprocessing, verifier_preprocessing) = if let Some(chunk_count) = bytecode_chunk { + let prover_preprocessing = + guest::preprocess_committed_memory_ops(&mut program, chunk_count).unwrap(); + let verifier_preprocessing = + guest::verifier_preprocessing_from_prover_memory_ops(&prover_preprocessing); + (prover_preprocessing, verifier_preprocessing) + } else { + let shared_preprocessing = guest::preprocess_shared_memory_ops(&mut program).unwrap(); + let prover_preprocessing = + guest::preprocess_prover_memory_ops(shared_preprocessing.clone()); + let verifier_preprocessing = guest::preprocess_verifier_memory_ops( + shared_preprocessing, + prover_preprocessing.generators.to_verifier_setup(), + None, + ); + (prover_preprocessing, verifier_preprocessing) + }; let prove = guest::build_prover_memory_ops(program, prover_preprocessing); let verify = guest::build_verifier_memory_ops(verifier_preprocessing); diff --git a/examples/merkle-tree/src/main.rs b/examples/merkle-tree/src/main.rs index c1c9875ae..380292d29 100644 --- a/examples/merkle-tree/src/main.rs +++ b/examples/merkle-tree/src/main.rs @@ -1,20 +1,36 @@ -use jolt_sdk::{TrustedAdvice, UntrustedAdvice}; +use jolt_sdk::{DoryContext, DoryGlobals, DoryLayout, TrustedAdvice, UntrustedAdvice}; use std::time::Instant; use tracing::info; pub fn main() { tracing_subscriber::fmt::init(); + let bytecode_chunk = std::env::args() + .skip_while(|arg| arg != "--committed-bytecode") + .nth(1) + .map(|arg| arg.parse().unwrap()); + DoryGlobals::initialize_context(1, 1, DoryContext::Main, Some(DoryLayout::CycleMajor)) + .expect("failed to set Dory layout"); let target_dir = "/tmp/jolt-guest-targets"; let mut program = guest::compile_merkle_tree(target_dir); - let shared_preprocessing = guest::preprocess_shared_merkle_tree(&mut program).unwrap(); - let prover_preprocessing = guest::preprocess_prover_merkle_tree(shared_preprocessing.clone()); - let verifier_preprocessing = guest::preprocess_verifier_merkle_tree( - shared_preprocessing, - prover_preprocessing.generators.to_verifier_setup(), - None, - ); + let (prover_preprocessing, verifier_preprocessing) = if let Some(chunk_count) = bytecode_chunk { + let prover_preprocessing = + guest::preprocess_committed_merkle_tree(&mut program, chunk_count).unwrap(); + let verifier_preprocessing = + guest::verifier_preprocessing_from_prover_merkle_tree(&prover_preprocessing); + (prover_preprocessing, verifier_preprocessing) + } else { + let shared_preprocessing = guest::preprocess_shared_merkle_tree(&mut program).unwrap(); + let prover_preprocessing = + guest::preprocess_prover_merkle_tree(shared_preprocessing.clone()); + let verifier_preprocessing = guest::preprocess_verifier_merkle_tree( + shared_preprocessing, + prover_preprocessing.generators.to_verifier_setup(), + None, + ); + (prover_preprocessing, verifier_preprocessing) + }; let leaf1: &[u8] = &[5u8; 32]; let leaf2 = [6u8; 32]; diff --git a/examples/modinv/src/main.rs b/examples/modinv/src/main.rs index 4c72040ec..be700382c 100644 --- a/examples/modinv/src/main.rs +++ b/examples/modinv/src/main.rs @@ -3,6 +3,10 @@ use tracing::info; pub fn main() { tracing_subscriber::fmt::init(); + let bytecode_chunk = std::env::args() + .skip_while(|arg| arg != "--committed-bytecode") + .nth(1) + .map(|arg| arg.parse().unwrap()); let target_dir = "/tmp/jolt-guest-targets"; @@ -15,11 +19,20 @@ pub fn main() { // Compile and preprocess the advice-based version let mut program = guest::compile_modinv(target_dir); - let shared_preprocessing = guest::preprocess_shared_modinv(&mut program).unwrap(); - let prover_preprocessing = guest::preprocess_prover_modinv(shared_preprocessing.clone()); - let verifier_setup = prover_preprocessing.generators.to_verifier_setup(); - let verifier_preprocessing = - guest::preprocess_verifier_modinv(shared_preprocessing, verifier_setup, None); + let (prover_preprocessing, verifier_preprocessing) = if let Some(chunk_count) = bytecode_chunk { + let prover_preprocessing = + guest::preprocess_committed_modinv(&mut program, chunk_count).unwrap(); + let verifier_preprocessing = + guest::verifier_preprocessing_from_prover_modinv(&prover_preprocessing); + (prover_preprocessing, verifier_preprocessing) + } else { + let shared_preprocessing = guest::preprocess_shared_modinv(&mut program).unwrap(); + let prover_preprocessing = guest::preprocess_prover_modinv(shared_preprocessing.clone()); + let verifier_setup = prover_preprocessing.generators.to_verifier_setup(); + let verifier_preprocessing = + guest::preprocess_verifier_modinv(shared_preprocessing, verifier_setup, None); + (prover_preprocessing, verifier_preprocessing) + }; let prove_modinv = guest::build_prover_modinv(program, prover_preprocessing); let verify_modinv = guest::build_verifier_modinv(verifier_preprocessing); diff --git a/examples/muldiv/src/main.rs b/examples/muldiv/src/main.rs index 4dc1c6f72..674d8df66 100644 --- a/examples/muldiv/src/main.rs +++ b/examples/muldiv/src/main.rs @@ -3,17 +3,30 @@ use tracing::info; pub fn main() { tracing_subscriber::fmt::init(); + let bytecode_chunk = std::env::args() + .skip_while(|arg| arg != "--committed-bytecode") + .nth(1) + .map(|arg| arg.parse().unwrap()); let target_dir = "/tmp/jolt-guest-targets"; let mut program = guest::compile_muldiv(target_dir); - let shared_preprocessing = guest::preprocess_shared_muldiv(&mut program).unwrap(); - let prover_preprocessing = guest::preprocess_prover_muldiv(shared_preprocessing.clone()); - let verifier_preprocessing = guest::preprocess_verifier_muldiv( - shared_preprocessing, - prover_preprocessing.generators.to_verifier_setup(), - None, - ); + let (prover_preprocessing, verifier_preprocessing) = if let Some(chunk_count) = bytecode_chunk { + let prover_preprocessing = + guest::preprocess_committed_muldiv(&mut program, chunk_count).unwrap(); + let verifier_preprocessing = + guest::verifier_preprocessing_from_prover_muldiv(&prover_preprocessing); + (prover_preprocessing, verifier_preprocessing) + } else { + let shared_preprocessing = guest::preprocess_shared_muldiv(&mut program).unwrap(); + let prover_preprocessing = guest::preprocess_prover_muldiv(shared_preprocessing.clone()); + let verifier_preprocessing = guest::preprocess_verifier_muldiv( + shared_preprocessing, + prover_preprocessing.generators.to_verifier_setup(), + None, + ); + (prover_preprocessing, verifier_preprocessing) + }; let prove = guest::build_prover_muldiv(program, prover_preprocessing); let verify = guest::build_verifier_muldiv(verifier_preprocessing); diff --git a/examples/multi-function/src/main.rs b/examples/multi-function/src/main.rs index ce209ee21..6979ade1e 100644 --- a/examples/multi-function/src/main.rs +++ b/examples/multi-function/src/main.rs @@ -3,16 +3,29 @@ use tracing::info; pub fn main() { tracing_subscriber::fmt::init(); + let bytecode_chunk = std::env::args() + .skip_while(|arg| arg != "--committed-bytecode") + .nth(1) + .map(|arg| arg.parse().unwrap()); // Prove addition. let target_dir = "/tmp/jolt-guest-targets"; let mut program = guest::compile_add(target_dir); - let shared_preprocessing = guest::preprocess_shared_add(&mut program).unwrap(); - let prover_preprocessing = guest::preprocess_prover_add(shared_preprocessing.clone()); - let verifier_setup = prover_preprocessing.generators.to_verifier_setup(); - let verifier_preprocessing = - guest::preprocess_verifier_add(shared_preprocessing, verifier_setup, None); + let (prover_preprocessing, verifier_preprocessing) = if let Some(chunk_count) = bytecode_chunk { + let prover_preprocessing = + guest::preprocess_committed_add(&mut program, chunk_count).unwrap(); + let verifier_preprocessing = + guest::verifier_preprocessing_from_prover_add(&prover_preprocessing); + (prover_preprocessing, verifier_preprocessing) + } else { + let shared_preprocessing = guest::preprocess_shared_add(&mut program).unwrap(); + let prover_preprocessing = guest::preprocess_prover_add(shared_preprocessing.clone()); + let verifier_setup = prover_preprocessing.generators.to_verifier_setup(); + let verifier_preprocessing = + guest::preprocess_verifier_add(shared_preprocessing, verifier_setup, None); + (prover_preprocessing, verifier_preprocessing) + }; let prove_add = guest::build_prover_add(program, prover_preprocessing); let verify_add = guest::build_verifier_add(verifier_preprocessing); @@ -21,13 +34,22 @@ pub fn main() { let target_dir = "/tmp/jolt-guest-targets"; let mut program = guest::compile_mul(target_dir); - let shared_preprocessing = guest::preprocess_shared_mul(&mut program).unwrap(); - let prover_preprocessing = guest::preprocess_prover_mul(shared_preprocessing.clone()); - let verifier_preprocessing = guest::preprocess_verifier_mul( - shared_preprocessing, - prover_preprocessing.generators.to_verifier_setup(), - None, - ); + let (prover_preprocessing, verifier_preprocessing) = if let Some(chunk_count) = bytecode_chunk { + let prover_preprocessing = + guest::preprocess_committed_mul(&mut program, chunk_count).unwrap(); + let verifier_preprocessing = + guest::verifier_preprocessing_from_prover_mul(&prover_preprocessing); + (prover_preprocessing, verifier_preprocessing) + } else { + let shared_preprocessing = guest::preprocess_shared_mul(&mut program).unwrap(); + let prover_preprocessing = guest::preprocess_prover_mul(shared_preprocessing.clone()); + let verifier_preprocessing = guest::preprocess_verifier_mul( + shared_preprocessing, + prover_preprocessing.generators.to_verifier_setup(), + None, + ); + (prover_preprocessing, verifier_preprocessing) + }; let prove_mul = guest::build_prover_mul(program, prover_preprocessing); let verify_mul = guest::build_verifier_mul(verifier_preprocessing); diff --git a/examples/overflow/src/main.rs b/examples/overflow/src/main.rs index a2f59deb0..0a1da20c0 100644 --- a/examples/overflow/src/main.rs +++ b/examples/overflow/src/main.rs @@ -5,13 +5,20 @@ use tracing::info; pub fn main() { tracing_subscriber::fmt::init(); + let bytecode_chunk = std::env::args() + .skip_while(|arg| arg != "--committed-bytecode") + .nth(1) + .map(|arg| arg.parse().unwrap()); // An overflowing stack should fail to prove. let target_dir = "/tmp/jolt-guest-targets"; let mut program = guest::compile_overflow_stack(target_dir); - let shared_preprocessing = guest::preprocess_shared_overflow_stack(&mut program).unwrap(); - let prover_preprocessing = - guest::preprocess_prover_overflow_stack(shared_preprocessing.clone()); + let prover_preprocessing = if let Some(chunk_count) = bytecode_chunk { + guest::preprocess_committed_overflow_stack(&mut program, chunk_count).unwrap() + } else { + let shared_preprocessing = guest::preprocess_shared_overflow_stack(&mut program).unwrap(); + guest::preprocess_prover_overflow_stack(shared_preprocessing.clone()) + }; let prove_overflow_stack = guest::build_prover_overflow_stack(program, prover_preprocessing); let res = panic::catch_unwind(|| { @@ -23,8 +30,12 @@ pub fn main() { // now lets try to overflow the heap, should also panic let mut program = guest::compile_overflow_heap(target_dir); - let shared_preprocessing = guest::preprocess_shared_overflow_heap(&mut program).unwrap(); - let prover_preprocessing = guest::preprocess_prover_overflow_heap(shared_preprocessing.clone()); + let prover_preprocessing = if let Some(chunk_count) = bytecode_chunk { + guest::preprocess_committed_overflow_heap(&mut program, chunk_count).unwrap() + } else { + let shared_preprocessing = guest::preprocess_shared_overflow_heap(&mut program).unwrap(); + guest::preprocess_prover_overflow_heap(shared_preprocessing.clone()) + }; let prove_overflow_heap = guest::build_prover_overflow_heap(program, prover_preprocessing); let res = panic::catch_unwind(|| { @@ -36,15 +47,30 @@ pub fn main() { // but with stack_size=8192 let mut program = guest::compile_allocate_stack_with_increased_size(target_dir); - let shared_preprocessing = - guest::preprocess_shared_allocate_stack_with_increased_size(&mut program).unwrap(); - let prover_preprocessing = - guest::preprocess_prover_allocate_stack_with_increased_size(shared_preprocessing.clone()); - let verifier_preprocessing = guest::preprocess_verifier_allocate_stack_with_increased_size( - shared_preprocessing, - prover_preprocessing.generators.to_verifier_setup(), - None, - ); + let (prover_preprocessing, verifier_preprocessing) = if let Some(chunk_count) = bytecode_chunk { + let prover_preprocessing = guest::preprocess_committed_allocate_stack_with_increased_size( + &mut program, + chunk_count, + ) + .unwrap(); + let verifier_preprocessing = + guest::verifier_preprocessing_from_prover_allocate_stack_with_increased_size( + &prover_preprocessing, + ); + (prover_preprocessing, verifier_preprocessing) + } else { + let shared_preprocessing = + guest::preprocess_shared_allocate_stack_with_increased_size(&mut program).unwrap(); + let prover_preprocessing = guest::preprocess_prover_allocate_stack_with_increased_size( + shared_preprocessing.clone(), + ); + let verifier_preprocessing = guest::preprocess_verifier_allocate_stack_with_increased_size( + shared_preprocessing, + prover_preprocessing.generators.to_verifier_setup(), + None, + ); + (prover_preprocessing, verifier_preprocessing) + }; let prove_allocate_stack_with_increased_size = guest::build_prover_allocate_stack_with_increased_size(program, prover_preprocessing); diff --git a/examples/random/src/main.rs b/examples/random/src/main.rs index 5180976e0..88a6828e0 100644 --- a/examples/random/src/main.rs +++ b/examples/random/src/main.rs @@ -3,17 +3,30 @@ use tracing::info; pub fn main() { tracing_subscriber::fmt::init(); + let bytecode_chunk = std::env::args() + .skip_while(|arg| arg != "--committed-bytecode") + .nth(1) + .map(|arg| arg.parse().unwrap()); let target_dir = "/tmp/jolt-guest-targets"; let mut program = guest::compile_rand(target_dir); - let shared_preprocessing = guest::preprocess_shared_rand(&mut program).unwrap(); - let prover_preprocessing = guest::preprocess_prover_rand(shared_preprocessing.clone()); - let verifier_preprocessing = guest::preprocess_verifier_rand( - shared_preprocessing, - prover_preprocessing.generators.to_verifier_setup(), - None, - ); + let (prover_preprocessing, verifier_preprocessing) = if let Some(chunk_count) = bytecode_chunk { + let prover_preprocessing = + guest::preprocess_committed_rand(&mut program, chunk_count).unwrap(); + let verifier_preprocessing = + guest::verifier_preprocessing_from_prover_rand(&prover_preprocessing); + (prover_preprocessing, verifier_preprocessing) + } else { + let shared_preprocessing = guest::preprocess_shared_rand(&mut program).unwrap(); + let prover_preprocessing = guest::preprocess_prover_rand(shared_preprocessing.clone()); + let verifier_preprocessing = guest::preprocess_verifier_rand( + shared_preprocessing, + prover_preprocessing.generators.to_verifier_setup(), + None, + ); + (prover_preprocessing, verifier_preprocessing) + }; let prove = guest::build_prover_rand(program, prover_preprocessing); let verify = guest::build_verifier_rand(verifier_preprocessing); diff --git a/examples/recover-ecdsa/src/main.rs b/examples/recover-ecdsa/src/main.rs index 4357232e3..3b031081c 100644 --- a/examples/recover-ecdsa/src/main.rs +++ b/examples/recover-ecdsa/src/main.rs @@ -11,6 +11,10 @@ const SECRET_KEY: [u8; 32] = [ pub fn main() { tracing_subscriber::fmt::init(); + let bytecode_chunk = std::env::args() + .skip_while(|arg| arg != "--committed-bytecode") + .nth(1) + .map(|arg| arg.parse().unwrap()); let secp = Secp256k1::new(); @@ -31,13 +35,22 @@ pub fn main() { let target_dir = "/tmp/jolt-guest-targets"; let mut program = guest::compile_recover(target_dir); - let shared_preprocessing = guest::preprocess_shared_recover(&mut program).unwrap(); - let prover_preprocessing = guest::preprocess_prover_recover(shared_preprocessing.clone()); - let verifier_preprocessing = guest::preprocess_verifier_recover( - shared_preprocessing, - prover_preprocessing.generators.to_verifier_setup(), - None, - ); + let (prover_preprocessing, verifier_preprocessing) = if let Some(chunk_count) = bytecode_chunk { + let prover_preprocessing = + guest::preprocess_committed_recover(&mut program, chunk_count).unwrap(); + let verifier_preprocessing = + guest::verifier_preprocessing_from_prover_recover(&prover_preprocessing); + (prover_preprocessing, verifier_preprocessing) + } else { + let shared_preprocessing = guest::preprocess_shared_recover(&mut program).unwrap(); + let prover_preprocessing = guest::preprocess_prover_recover(shared_preprocessing.clone()); + let verifier_preprocessing = guest::preprocess_verifier_recover( + shared_preprocessing, + prover_preprocessing.generators.to_verifier_setup(), + None, + ); + (prover_preprocessing, verifier_preprocessing) + }; if save_to_disk { serialize_and_print_size( diff --git a/examples/recursion/src/main.rs b/examples/recursion/src/main.rs index f0419da4f..4bf1d1df4 100644 --- a/examples/recursion/src/main.rs +++ b/examples/recursion/src/main.rs @@ -1,7 +1,10 @@ use ark_serialize::CanonicalDeserialize; use ark_serialize::CanonicalSerialize; use clap::{Parser, Subcommand}; -use jolt_sdk::{JoltDevice, MemoryConfig, RV64IMACProof, Serializable}; +use jolt_sdk::{ + JoltDevice, JoltProverPreprocessing, JoltSharedPreprocessing, MemoryConfig, MemoryLayout, + ProgramPreprocessing, RV64IMACProof, Serializable, +}; use std::cmp::PartialEq; use std::path::{Path, PathBuf}; use std::time::Instant; @@ -17,10 +20,63 @@ fn get_guest_src_dir() -> PathBuf { #[derive(Parser)] #[command(author, version, about, long_about = None)] struct Cli { + #[arg(long, global = true, default_value_t = false)] + committed_bytecode: bool, + #[arg( + long, + global = true, + value_name = "COUNT", + requires = "committed_bytecode", + value_parser = parse_bytecode_chunk + )] + bytecode_chunk: Option, #[command(subcommand)] command: Option, } +#[derive(Clone, Copy)] +struct BytecodeConfig { + committed_bytecode: bool, + bytecode_chunk: Option, +} + +impl BytecodeConfig { + fn chunk_count(self) -> usize { + self.bytecode_chunk.unwrap_or(1) + } +} + +fn parse_bytecode_chunk(value: &str) -> Result { + value + .parse::() + .map_err(|_| format!("invalid bytecode chunk count `{value}`")) +} + +fn preprocess_guest_program( + guest: &jolt_sdk::guest::program::Program, + max_trace_length: usize, + bytecode_config: BytecodeConfig, +) -> JoltProverPreprocessing { + let (bytecode, memory_init, program_size, _e_entry) = guest.decode(); + + let mut memory_config = guest.memory_config; + memory_config.program_size = Some(program_size); + let memory_layout = MemoryLayout::new(&memory_config); + let program = ProgramPreprocessing::preprocess(bytecode, memory_init).unwrap(); + let shared_preprocessing = if bytecode_config.committed_bytecode { + JoltSharedPreprocessing::new_committed( + program, + memory_layout, + max_trace_length, + bytecode_config.chunk_count(), + ) + } else { + JoltSharedPreprocessing::new(program, memory_layout, max_trace_length) + }; + + JoltProverPreprocessing::new(shared_preprocessing) +} + #[derive(Subcommand)] enum Commands { /// Generate proofs for guest programs @@ -268,7 +324,12 @@ fn check_data_integrity(all_groups_data: &[u8]) -> (u32, u32) { (n, remaining_data.len() as u32) } -fn collect_guest_proofs(guest: GuestProgram, target_dir: &str, use_embed: bool) -> Vec { +fn collect_guest_proofs( + guest: GuestProgram, + target_dir: &str, + use_embed: bool, + bytecode_config: BytecodeConfig, +) -> Vec { info!("Starting collect_guest_proofs for {}", guest.name()); let max_trace_length = guest.get_max_trace_length(use_embed); @@ -293,7 +354,7 @@ fn collect_guest_proofs(guest: GuestProgram, target_dir: &str, use_embed: bool) info!("Preprocessing guest prover..."); let guest_prover_preprocessing = - jolt_sdk::guest::prover::preprocess(&guest_prog, max_trace_length).unwrap(); + preprocess_guest_program(&guest_prog, max_trace_length, bytecode_config); info!("Preprocessing guest verifier..."); let guest_verifier_preprocessing = jolt_sdk::JoltVerifierPreprocessing::from(&guest_prover_preprocessing); @@ -449,13 +510,13 @@ fn load_proof_data(guest: GuestProgram, workdir: &Path) -> Vec { proof_data } -fn generate_proofs(guest: GuestProgram, workdir: &Path) { +fn generate_proofs(guest: GuestProgram, workdir: &Path, bytecode_config: BytecodeConfig) { info!("Generating proofs for {} guest program...", guest.name()); let target_dir = "/tmp/jolt-guest-targets"; // Collect guest proofs - let all_groups_data = collect_guest_proofs(guest, target_dir, false); + let all_groups_data = collect_guest_proofs(guest, target_dir, false, bytecode_config); // Save proof data save_proof_data(guest, &all_groups_data, workdir); @@ -469,6 +530,7 @@ fn run_recursion_proof( input_bytes: Vec, memory_config: MemoryConfig, mut max_trace_length: usize, + bytecode_config: BytecodeConfig, ) { let target_dir = "/tmp/jolt-guest-targets"; @@ -486,7 +548,7 @@ fn run_recursion_proof( max_trace_length = 0; } let recursion_prover_preprocessing = - jolt_sdk::guest::prover::preprocess(&recursion, max_trace_length).unwrap(); + preprocess_guest_program(&recursion, max_trace_length, bytecode_config); let recursion_verifier_preprocessing = jolt_sdk::JoltVerifierPreprocessing::from(&recursion_prover_preprocessing); @@ -555,6 +617,7 @@ fn verify_proofs( workdir: &Path, output_dir: &Path, run_config: RunConfig, + bytecode_config: BytecodeConfig, ) { info!("Verifying proofs for {} guest program...", guest.name()); info!("Using embed mode: {use_embed}"); @@ -581,6 +644,7 @@ fn verify_proofs( input_bytes, memory_config, guest.get_max_trace_length(use_embed), + bytecode_config, ); } else { info!("Running {} recursion with input data...", guest.name()); @@ -610,6 +674,7 @@ fn verify_proofs( input_bytes, memory_config, guest.get_max_trace_length(use_embed), + bytecode_config, ); } } @@ -618,6 +683,10 @@ fn main() { tracing_subscriber::fmt::init(); let cli = Cli::parse(); + let bytecode_config = BytecodeConfig { + committed_bytecode: cli.committed_bytecode, + bytecode_chunk: cli.bytecode_chunk, + }; match &cli.command { Some(Commands::Generate { example, workdir }) => { @@ -628,7 +697,7 @@ fn main() { return; } }; - generate_proofs(guest, workdir); + generate_proofs(guest, workdir, bytecode_config); } Some(Commands::Verify { example, @@ -653,6 +722,7 @@ fn main() { workdir, &output_dir, RunConfig::Prove, + bytecode_config, ); } Some(Commands::Trace { @@ -678,7 +748,14 @@ fn main() { } else { RunConfig::Trace }; - verify_proofs(guest, embed.is_some(), workdir, &output_dir, run_config); + verify_proofs( + guest, + embed.is_some(), + workdir, + &output_dir, + run_config, + bytecode_config, + ); } None => { info!("No subcommand specified. Available commands:"); diff --git a/examples/secp256k1-ecdsa-verify/src/main.rs b/examples/secp256k1-ecdsa-verify/src/main.rs index e47e89448..9de836e14 100644 --- a/examples/secp256k1-ecdsa-verify/src/main.rs +++ b/examples/secp256k1-ecdsa-verify/src/main.rs @@ -3,20 +3,33 @@ use tracing::info; pub fn main() { tracing_subscriber::fmt::init(); + let bytecode_chunk = std::env::args() + .skip_while(|arg| arg != "--committed-bytecode") + .nth(1) + .map(|arg| arg.parse().unwrap()); let target_dir = "/tmp/jolt-guest-targets"; let mut program = guest::compile_secp256k1_ecdsa_verify(target_dir); - let shared_preprocessing = - guest::preprocess_shared_secp256k1_ecdsa_verify(&mut program).unwrap(); - let prover_preprocessing = - guest::preprocess_prover_secp256k1_ecdsa_verify(shared_preprocessing.clone()); - let verifier_setup = prover_preprocessing.generators.to_verifier_setup(); - let verifier_preprocessing = guest::preprocess_verifier_secp256k1_ecdsa_verify( - shared_preprocessing, - verifier_setup, - None, - ); + let (prover_preprocessing, verifier_preprocessing) = if let Some(chunk_count) = bytecode_chunk { + let prover_preprocessing = + guest::preprocess_committed_secp256k1_ecdsa_verify(&mut program, chunk_count).unwrap(); + let verifier_preprocessing = + guest::verifier_preprocessing_from_prover_secp256k1_ecdsa_verify(&prover_preprocessing); + (prover_preprocessing, verifier_preprocessing) + } else { + let shared_preprocessing = + guest::preprocess_shared_secp256k1_ecdsa_verify(&mut program).unwrap(); + let prover_preprocessing = + guest::preprocess_prover_secp256k1_ecdsa_verify(shared_preprocessing.clone()); + let verifier_setup = prover_preprocessing.generators.to_verifier_setup(); + let verifier_preprocessing = guest::preprocess_verifier_secp256k1_ecdsa_verify( + shared_preprocessing, + verifier_setup, + None, + ); + (prover_preprocessing, verifier_preprocessing) + }; let prove_secp256k1_ecdsa_verify = guest::build_prover_secp256k1_ecdsa_verify(program, prover_preprocessing); diff --git a/examples/sha2-chain/src/main.rs b/examples/sha2-chain/src/main.rs index 347e07ecd..35789f7c1 100644 --- a/examples/sha2-chain/src/main.rs +++ b/examples/sha2-chain/src/main.rs @@ -3,17 +3,31 @@ use tracing::info; pub fn main() { tracing_subscriber::fmt::init(); + let bytecode_chunk = std::env::args() + .skip_while(|arg| arg != "--committed-bytecode") + .nth(1) + .map(|arg| arg.parse().unwrap()); let target_dir = "/tmp/jolt-guest-targets"; let mut program = guest::compile_sha2_chain(target_dir); - let shared_preprocessing = guest::preprocess_shared_sha2_chain(&mut program).unwrap(); - let prover_preprocessing = guest::preprocess_prover_sha2_chain(shared_preprocessing.clone()); - let verifier_preprocessing = guest::preprocess_verifier_sha2_chain( - shared_preprocessing, - prover_preprocessing.generators.to_verifier_setup(), - None, - ); + let (prover_preprocessing, verifier_preprocessing) = if let Some(chunk_count) = bytecode_chunk { + let prover_preprocessing = + guest::preprocess_committed_sha2_chain(&mut program, chunk_count).unwrap(); + let verifier_preprocessing = + guest::verifier_preprocessing_from_prover_sha2_chain(&prover_preprocessing); + (prover_preprocessing, verifier_preprocessing) + } else { + let shared_preprocessing = guest::preprocess_shared_sha2_chain(&mut program).unwrap(); + let prover_preprocessing = + guest::preprocess_prover_sha2_chain(shared_preprocessing.clone()); + let verifier_preprocessing = guest::preprocess_verifier_sha2_chain( + shared_preprocessing, + prover_preprocessing.generators.to_verifier_setup(), + None, + ); + (prover_preprocessing, verifier_preprocessing) + }; let prove_sha2_chain = guest::build_prover_sha2_chain(program, prover_preprocessing); let verify_sha2_chain = guest::build_verifier_sha2_chain(verifier_preprocessing); diff --git a/examples/sha2-ex/src/main.rs b/examples/sha2-ex/src/main.rs index d13d1f4ae..cb9116ae0 100644 --- a/examples/sha2-ex/src/main.rs +++ b/examples/sha2-ex/src/main.rs @@ -3,17 +3,30 @@ use tracing::info; pub fn main() { tracing_subscriber::fmt::init(); + let bytecode_chunk = std::env::args() + .skip_while(|arg| arg != "--committed-bytecode") + .nth(1) + .map(|arg| arg.parse().unwrap()); let target_dir = "/tmp/jolt-guest-targets"; let mut program = guest::compile_sha2(target_dir); - let shared_preprocessing = guest::preprocess_shared_sha2(&mut program).unwrap(); - let prover_preprocessing = guest::preprocess_prover_sha2(shared_preprocessing.clone()); - let verifier_preprocessing = guest::preprocess_verifier_sha2( - shared_preprocessing, - prover_preprocessing.generators.to_verifier_setup(), - None, - ); + let (prover_preprocessing, verifier_preprocessing) = if let Some(chunk_count) = bytecode_chunk { + let prover_preprocessing = + guest::preprocess_committed_sha2(&mut program, chunk_count).unwrap(); + let verifier_preprocessing = + guest::verifier_preprocessing_from_prover_sha2(&prover_preprocessing); + (prover_preprocessing, verifier_preprocessing) + } else { + let shared_preprocessing = guest::preprocess_shared_sha2(&mut program).unwrap(); + let prover_preprocessing = guest::preprocess_prover_sha2(shared_preprocessing.clone()); + let verifier_preprocessing = guest::preprocess_verifier_sha2( + shared_preprocessing, + prover_preprocessing.generators.to_verifier_setup(), + None, + ); + (prover_preprocessing, verifier_preprocessing) + }; let prove_sha2 = guest::build_prover_sha2(program, prover_preprocessing); let verify_sha2 = guest::build_verifier_sha2(verifier_preprocessing); diff --git a/examples/sha3-chain/src/main.rs b/examples/sha3-chain/src/main.rs index 73e04f967..8f3fff939 100644 --- a/examples/sha3-chain/src/main.rs +++ b/examples/sha3-chain/src/main.rs @@ -3,16 +3,30 @@ use tracing::info; pub fn main() { tracing_subscriber::fmt::init(); + let bytecode_chunk = std::env::args() + .skip_while(|arg| arg != "--committed-bytecode") + .nth(1) + .map(|arg| arg.parse().unwrap()); let target_dir = "/tmp/jolt-guest-targets"; let mut program = guest::compile_sha3_chain(target_dir); - let shared_preprocessing = guest::preprocess_shared_sha3_chain(&mut program).unwrap(); - let prover_preprocessing = guest::preprocess_prover_sha3_chain(shared_preprocessing.clone()); - let verifier_preprocessing = guest::preprocess_verifier_sha3_chain( - shared_preprocessing, - prover_preprocessing.generators.to_verifier_setup(), - None, - ); + let (prover_preprocessing, verifier_preprocessing) = if let Some(chunk_count) = bytecode_chunk { + let prover_preprocessing = + guest::preprocess_committed_sha3_chain(&mut program, chunk_count).unwrap(); + let verifier_preprocessing = + guest::verifier_preprocessing_from_prover_sha3_chain(&prover_preprocessing); + (prover_preprocessing, verifier_preprocessing) + } else { + let shared_preprocessing = guest::preprocess_shared_sha3_chain(&mut program).unwrap(); + let prover_preprocessing = + guest::preprocess_prover_sha3_chain(shared_preprocessing.clone()); + let verifier_preprocessing = guest::preprocess_verifier_sha3_chain( + shared_preprocessing, + prover_preprocessing.generators.to_verifier_setup(), + None, + ); + (prover_preprocessing, verifier_preprocessing) + }; let prove_sha3_chain = guest::build_prover_sha3_chain(program, prover_preprocessing); let verify_sha3_chain = guest::build_verifier_sha3_chain(verifier_preprocessing); diff --git a/examples/sha3-ex/src/main.rs b/examples/sha3-ex/src/main.rs index ecbc9ad2b..c211d49a4 100644 --- a/examples/sha3-ex/src/main.rs +++ b/examples/sha3-ex/src/main.rs @@ -3,16 +3,29 @@ use tracing::info; pub fn main() { tracing_subscriber::fmt::init(); + let bytecode_chunk = std::env::args() + .skip_while(|arg| arg != "--committed-bytecode") + .nth(1) + .map(|arg| arg.parse().unwrap()); let target_dir = "/tmp/jolt-guest-targets"; let mut program = guest::compile_sha3(target_dir); - let shared_preprocessing = guest::preprocess_shared_sha3(&mut program).unwrap(); - let prover_preprocessing = guest::preprocess_prover_sha3(shared_preprocessing.clone()); - let verifier_preprocessing = guest::preprocess_verifier_sha3( - shared_preprocessing, - prover_preprocessing.generators.to_verifier_setup(), - None, - ); + let (prover_preprocessing, verifier_preprocessing) = if let Some(chunk_count) = bytecode_chunk { + let prover_preprocessing = + guest::preprocess_committed_sha3(&mut program, chunk_count).unwrap(); + let verifier_preprocessing = + guest::verifier_preprocessing_from_prover_sha3(&prover_preprocessing); + (prover_preprocessing, verifier_preprocessing) + } else { + let shared_preprocessing = guest::preprocess_shared_sha3(&mut program).unwrap(); + let prover_preprocessing = guest::preprocess_prover_sha3(shared_preprocessing.clone()); + let verifier_preprocessing = guest::preprocess_verifier_sha3( + shared_preprocessing, + prover_preprocessing.generators.to_verifier_setup(), + None, + ); + (prover_preprocessing, verifier_preprocessing) + }; let prove_sha3 = guest::build_prover_sha3(program, prover_preprocessing); let verify_sha3 = guest::build_verifier_sha3(verifier_preprocessing); diff --git a/examples/sig-recovery/host/src/main.rs b/examples/sig-recovery/host/src/main.rs index 6a7624eb9..7aeed25f2 100644 --- a/examples/sig-recovery/host/src/main.rs +++ b/examples/sig-recovery/host/src/main.rs @@ -17,6 +17,10 @@ fn main() { .unwrap_or_else(|_| tracing_subscriber::EnvFilter::new("info")), ) .init(); + let bytecode_chunk = std::env::args() + .skip_while(|arg| arg != "--committed-bytecode") + .nth(1) + .map(|arg| arg.parse().unwrap()); info!("sig-recovery: zkVM ECDSA Signature Recovery"); info!("=============================================\n"); @@ -42,11 +46,21 @@ fn main() { info!("\nPreprocessing..."); let start = Instant::now(); - let shared_preprocessing = guest::preprocess_shared_verify_txs(&mut program).unwrap(); - let prover_preprocessing = guest::preprocess_prover_verify_txs(shared_preprocessing.clone()); - let verifier_setup = prover_preprocessing.generators.to_verifier_setup(); - let verifier_preprocessing = - guest::preprocess_verifier_verify_txs(shared_preprocessing, verifier_setup, None); + let (prover_preprocessing, verifier_preprocessing) = if let Some(chunk_count) = bytecode_chunk { + let prover_preprocessing = + guest::preprocess_committed_verify_txs(&mut program, chunk_count).unwrap(); + let verifier_preprocessing = + guest::verifier_preprocessing_from_prover_verify_txs(&prover_preprocessing); + (prover_preprocessing, verifier_preprocessing) + } else { + let shared_preprocessing = guest::preprocess_shared_verify_txs(&mut program).unwrap(); + let prover_preprocessing = + guest::preprocess_prover_verify_txs(shared_preprocessing.clone()); + let verifier_setup = prover_preprocessing.generators.to_verifier_setup(); + let verifier_preprocessing = + guest::preprocess_verifier_verify_txs(shared_preprocessing, verifier_setup, None); + (prover_preprocessing, verifier_preprocessing) + }; info!("Preprocessing time: {:?}", start.elapsed()); let prove_verify_txs = guest::build_prover_verify_txs(program, prover_preprocessing); diff --git a/examples/stdlib/src/main.rs b/examples/stdlib/src/main.rs index a69bbc896..4d147f856 100644 --- a/examples/stdlib/src/main.rs +++ b/examples/stdlib/src/main.rs @@ -3,6 +3,10 @@ use tracing::info; pub fn main() { tracing_subscriber::fmt::init(); + let bytecode_chunk = std::env::args() + .skip_while(|arg| arg != "--committed-bytecode") + .nth(1) + .map(|arg| arg.parse().unwrap()); let target_dir = "/tmp/jolt-guest-targets"; @@ -10,13 +14,23 @@ pub fn main() { info!("=== Int to String ==="); let mut program = guest::compile_int_to_string(target_dir); - let shared_preprocessing = guest::preprocess_shared_int_to_string(&mut program).unwrap(); - let prover_preprocessing = guest::preprocess_prover_int_to_string(shared_preprocessing.clone()); - let verifier_preprocessing = guest::preprocess_verifier_int_to_string( - shared_preprocessing, - prover_preprocessing.generators.to_verifier_setup(), - None, - ); + let (prover_preprocessing, verifier_preprocessing) = if let Some(chunk_count) = bytecode_chunk { + let prover_preprocessing = + guest::preprocess_committed_int_to_string(&mut program, chunk_count).unwrap(); + let verifier_preprocessing = + guest::verifier_preprocessing_from_prover_int_to_string(&prover_preprocessing); + (prover_preprocessing, verifier_preprocessing) + } else { + let shared_preprocessing = guest::preprocess_shared_int_to_string(&mut program).unwrap(); + let prover_preprocessing = + guest::preprocess_prover_int_to_string(shared_preprocessing.clone()); + let verifier_preprocessing = guest::preprocess_verifier_int_to_string( + shared_preprocessing, + prover_preprocessing.generators.to_verifier_setup(), + None, + ); + (prover_preprocessing, verifier_preprocessing) + }; let prove = guest::build_prover_int_to_string(program, prover_preprocessing); let verify = guest::build_verifier_int_to_string(verifier_preprocessing); @@ -31,13 +45,23 @@ pub fn main() { info!("=== String Concat ==="); let mut program = guest::compile_string_concat(target_dir); - let shared_preprocessing = guest::preprocess_shared_string_concat(&mut program).unwrap(); - let prover_preprocessing = guest::preprocess_prover_string_concat(shared_preprocessing.clone()); - let verifier_preprocessing = guest::preprocess_verifier_string_concat( - shared_preprocessing, - prover_preprocessing.generators.to_verifier_setup(), - None, - ); + let (prover_preprocessing, verifier_preprocessing) = if let Some(chunk_count) = bytecode_chunk { + let prover_preprocessing = + guest::preprocess_committed_string_concat(&mut program, chunk_count).unwrap(); + let verifier_preprocessing = + guest::verifier_preprocessing_from_prover_string_concat(&prover_preprocessing); + (prover_preprocessing, verifier_preprocessing) + } else { + let shared_preprocessing = guest::preprocess_shared_string_concat(&mut program).unwrap(); + let prover_preprocessing = + guest::preprocess_prover_string_concat(shared_preprocessing.clone()); + let verifier_preprocessing = guest::preprocess_verifier_string_concat( + shared_preprocessing, + prover_preprocessing.generators.to_verifier_setup(), + None, + ); + (prover_preprocessing, verifier_preprocessing) + }; let prove = guest::build_prover_string_concat(program, prover_preprocessing); let verify = guest::build_verifier_string_concat(verifier_preprocessing); @@ -56,15 +80,26 @@ pub fn main() { info!("=== Parallel Sum of Squares (rayon) ==="); let mut program = guest::compile_parallel_sum_of_squares(target_dir); - let shared_preprocessing = - guest::preprocess_shared_parallel_sum_of_squares(&mut program).unwrap(); - let prover_preprocessing = - guest::preprocess_prover_parallel_sum_of_squares(shared_preprocessing.clone()); - let verifier_preprocessing = guest::preprocess_verifier_parallel_sum_of_squares( - shared_preprocessing, - prover_preprocessing.generators.to_verifier_setup(), - None, - ); + let (prover_preprocessing, verifier_preprocessing) = if let Some(chunk_count) = bytecode_chunk { + let prover_preprocessing = + guest::preprocess_committed_parallel_sum_of_squares(&mut program, chunk_count).unwrap(); + let verifier_preprocessing = + guest::verifier_preprocessing_from_prover_parallel_sum_of_squares( + &prover_preprocessing, + ); + (prover_preprocessing, verifier_preprocessing) + } else { + let shared_preprocessing = + guest::preprocess_shared_parallel_sum_of_squares(&mut program).unwrap(); + let prover_preprocessing = + guest::preprocess_prover_parallel_sum_of_squares(shared_preprocessing.clone()); + let verifier_preprocessing = guest::preprocess_verifier_parallel_sum_of_squares( + shared_preprocessing, + prover_preprocessing.generators.to_verifier_setup(), + None, + ); + (prover_preprocessing, verifier_preprocessing) + }; let prove = guest::build_prover_parallel_sum_of_squares(program, prover_preprocessing); let verify = guest::build_verifier_parallel_sum_of_squares(verifier_preprocessing); diff --git a/jolt-core/benches/e2e_profiling.rs b/jolt-core/benches/e2e_profiling.rs index 05df3ddd5..af2e8dfc2 100644 --- a/jolt-core/benches/e2e_profiling.rs +++ b/jolt-core/benches/e2e_profiling.rs @@ -1,5 +1,6 @@ use ark_serialize::CanonicalSerialize; use jolt_core::host; +use jolt_core::zkvm::program::ProgramPreprocessing; use jolt_core::zkvm::prover::JoltProverPreprocessing; use jolt_core::zkvm::verifier::{JoltSharedPreprocessing, JoltVerifierPreprocessing}; use jolt_core::zkvm::{RV64IMACProver, RV64IMACVerifier}; @@ -201,20 +202,18 @@ fn prove_example( ) -> Vec<(tracing::Span, Box)> { let mut tasks = Vec::new(); let mut program = host::Program::new(example_name); - let (bytecode, init_memory_state, _, e_entry) = program.decode(); + let (bytecode, init_memory_state, _, _) = program.decode(); let (_lazy_trace, trace, _, program_io) = program.trace(&serialized_input, &[], &[]); let padded_trace_len = (trace.len() + 1).next_power_of_two(); drop(trace); let task = move || { + let program_data = ProgramPreprocessing::preprocess(bytecode, init_memory_state).unwrap(); let shared_preprocessing = JoltSharedPreprocessing::new( - bytecode, + program_data, program_io.memory_layout.clone(), - init_memory_state, padded_trace_len, - e_entry, - ) - .unwrap(); + ); let preprocessing = JoltProverPreprocessing::new(shared_preprocessing); let elf_contents_opt = program.get_elf_contents(); @@ -255,7 +254,7 @@ fn prove_example_with_trace( _scale: usize, ) -> (std::time::Duration, usize, usize, usize) { let mut program = host::Program::new(example_name); - let (bytecode, init_memory_state, _, e_entry) = program.decode(); + let (bytecode, init_memory_state, _, _) = program.decode(); let (_, trace, _, program_io) = program.trace(&serialized_input, &[], &[]); assert!( @@ -263,14 +262,12 @@ fn prove_example_with_trace( "Trace is longer than expected" ); + let program_data = ProgramPreprocessing::preprocess(bytecode, init_memory_state).unwrap(); let shared_preprocessing = JoltSharedPreprocessing::new( - bytecode.clone(), + program_data, program_io.memory_layout.clone(), - init_memory_state, trace.len().next_power_of_two(), - e_entry, - ) - .unwrap(); + ); let preprocessing = JoltProverPreprocessing::new(shared_preprocessing); let elf_contents_opt = program.get_elf_contents(); diff --git a/jolt-core/src/guest/prover.rs b/jolt-core/src/guest/prover.rs index db86b9576..b6486981b 100644 --- a/jolt-core/src/guest/prover.rs +++ b/jolt-core/src/guest/prover.rs @@ -6,6 +6,7 @@ use crate::poly::commitment::commitment_scheme::{StreamingCommitmentScheme, ZkEv use crate::poly::commitment::dory::DoryCommitmentScheme; use crate::transcripts::Transcript; use crate::zkvm::bytecode::PreprocessingError; +use crate::zkvm::program::ProgramPreprocessing; use crate::zkvm::proof_serialization::JoltProof; use crate::zkvm::prover::JoltProverPreprocessing; use crate::zkvm::ProverDebugInfo; @@ -23,18 +24,14 @@ pub fn preprocess( > { use crate::zkvm::verifier::JoltSharedPreprocessing; - let (bytecode, memory_init, program_size, e_entry) = guest.decode(); + let (bytecode, memory_init, program_size, _e_entry) = guest.decode(); let mut memory_config = guest.memory_config; memory_config.program_size = Some(program_size); let memory_layout = MemoryLayout::new(&memory_config); - let shared_preprocessing = JoltSharedPreprocessing::new( - bytecode, - memory_layout, - memory_init, - max_trace_length, - e_entry, - )?; + let program = ProgramPreprocessing::preprocess(bytecode, memory_init)?; + let shared_preprocessing = + JoltSharedPreprocessing::new(program, memory_layout, max_trace_length); Ok(JoltProverPreprocessing::new(shared_preprocessing)) } diff --git a/jolt-core/src/guest/verifier.rs b/jolt-core/src/guest/verifier.rs index aa75a8f15..e2a2e07f3 100644 --- a/jolt-core/src/guest/verifier.rs +++ b/jolt-core/src/guest/verifier.rs @@ -9,6 +9,7 @@ use crate::zkvm::verifier::BlindfoldSetup; use crate::guest::program::Program; use crate::poly::commitment::dory::DoryCommitmentScheme; use crate::transcripts::Transcript; +use crate::zkvm::program::ProgramPreprocessing; use crate::zkvm::proof_serialization::JoltProof; use crate::zkvm::verifier::JoltSharedPreprocessing; use crate::zkvm::verifier::JoltVerifier; @@ -37,18 +38,17 @@ fn preprocess_shared( guest: &Program, max_trace_length: usize, ) -> Result { - let (bytecode, memory_init, program_size, e_entry) = guest.decode(); + let (bytecode, memory_init, program_size, _e_entry) = guest.decode(); let mut memory_config = guest.memory_config; memory_config.program_size = Some(program_size); let memory_layout = MemoryLayout::new(&memory_config); - JoltSharedPreprocessing::new( - bytecode, + let program = ProgramPreprocessing::preprocess(bytecode, memory_init)?; + Ok(JoltSharedPreprocessing::new( + program, memory_layout, - memory_init, max_trace_length, - e_entry, - ) + )) } pub fn verify< diff --git a/jolt-core/src/poly/commitment/dory/commitment_scheme.rs b/jolt-core/src/poly/commitment/dory/commitment_scheme.rs index e4f85acd8..eb7a27f6e 100644 --- a/jolt-core/src/poly/commitment/dory/commitment_scheme.rs +++ b/jolt-core/src/poly/commitment/dory/commitment_scheme.rs @@ -1,6 +1,6 @@ //! Dory polynomial commitment scheme implementation -use super::dory_globals::{DoryGlobals, DoryLayout}; +use super::dory_globals::DoryGlobals; use super::jolt_dory_routines::{JoltG1Routines, JoltG2Routines}; use super::wrappers::{ ark_to_jolt, jolt_to_ark, ArkDoryProof, ArkFr, ArkG1, ArkGT, ArkworksProverSetup, @@ -27,6 +27,15 @@ use rayon::prelude::*; use std::borrow::Borrow; use tracing::trace_span; +fn debug_disable_dory_setup_cache() -> bool { + std::env::var("JOLT_DEBUG_DISABLE_DORY_SETUP_CACHE") + .map(|v| { + let value = v.trim().to_ascii_lowercase(); + !matches!(value.as_str(), "" | "0" | "false" | "off") + }) + .unwrap_or(false) +} + #[derive(Clone)] pub struct DoryCommitmentScheme; @@ -38,11 +47,31 @@ impl DoryOpeningProofHint { Self(row_commitments) } + pub fn empty() -> Self { + Self(Vec::new()) + } + + pub fn rows(&self) -> &[ArkG1] { + &self.0 + } + fn into_rows(self) -> Vec { self.0 } } +#[inline] +fn canonical_setup_log_n(max_num_vars: usize) -> usize { + // Dory's generator count depends on ceil(max_log_n / 2), so odd/even pairs like + // 23 and 24 share the same generator bucket. Canonicalizing to the even bucket + // representative keeps those runs on a single URS file. + if max_num_vars.is_multiple_of(2) { + max_num_vars + } else { + max_num_vars + 1 + } +} + pub fn bind_opening_inputs( transcript: &mut ProofTranscript, opening_point: &[F::Challenge], @@ -85,13 +114,13 @@ impl CommitmentScheme for DoryCommitmentScheme { fn setup_prover(max_num_vars: usize) -> Self::ProverSetup { let _span = trace_span!("DoryCommitmentScheme::setup_prover").entered(); + let canonical_max_num_vars = canonical_setup_log_n(max_num_vars); #[cfg(test)] DoryGlobals::configure_test_cache_root(); - #[cfg(not(target_arch = "wasm32"))] - let setup = ArkworksProverSetup::new_from_urs(max_num_vars); + let setup = ArkworksProverSetup::new_from_urs(canonical_max_num_vars); #[cfg(target_arch = "wasm32")] - let setup = ArkworksProverSetup::new(max_num_vars); + let setup = ArkworksProverSetup::new(canonical_max_num_vars); // The prepared-point cache in dory-pcs is global and can only be initialized once. // In unit tests, multiple setups with different sizes are created, so initializing the @@ -166,8 +195,7 @@ impl CommitmentScheme for DoryCommitmentScheme { let sigma = num_cols.log_2(); let nu = num_rows.log_2(); - let reordered_point = reorder_opening_point_for_layout::(opening_point); - let ark_point: Vec = reordered_point + let ark_point: Vec = opening_point .iter() .rev() .map(|p| { @@ -209,10 +237,8 @@ impl CommitmentScheme for DoryCommitmentScheme { ) -> Result<(), ProofVerifyError> { let _span = trace_span!("DoryCommitmentScheme::verify").entered(); - let reordered_point = reorder_opening_point_for_layout::(opening_point); - // Dory uses the opposite endian-ness as Jolt - let ark_point: Vec = reordered_point + let ark_point: Vec = opening_point .iter() .rev() .map(|p| { @@ -389,7 +415,11 @@ impl StreamingCommitmentScheme for DoryCommitmentScheme { } let g2_bases = &setup.g2_vec[..num_rows]; - let tier_2 = ::multi_pair_g2_setup(&row_commitments, g2_bases); + let tier_2 = if debug_disable_dory_setup_cache() { + ::multi_pair(&row_commitments, g2_bases) + } else { + ::multi_pair_g2_setup(&row_commitments, g2_bases) + }; (tier_2, DoryOpeningProofHint::new(row_commitments)) } else { @@ -397,7 +427,11 @@ impl StreamingCommitmentScheme for DoryCommitmentScheme { chunks.iter().flat_map(|chunk| chunk.clone()).collect(); let g2_bases = &setup.g2_vec[..row_commitments.len()]; - let tier_2 = ::multi_pair_g2_setup(&row_commitments, g2_bases); + let tier_2 = if debug_disable_dory_setup_cache() { + ::multi_pair(&row_commitments, g2_bases) + } else { + ::multi_pair_g2_setup(&row_commitments, g2_bases) + }; (tier_2, DoryOpeningProofHint::new(row_commitments)) } @@ -443,24 +477,3 @@ where Some((g1s, h1)) } } - -/// Reorders opening_point for AddressMajor layout. -/// -/// For AddressMajor layout, reorders opening_point from [r_address, r_cycle] to [r_cycle, r_address]. -/// This ensures that after Dory's reversal and splitting: -/// - Column (right) vector gets address variables (matching AddressMajor column indexing) -/// - Row (left) vector gets cycle variables (matching AddressMajor row indexing) -/// -/// For CycleMajor layout, returns the point unchanged. -fn reorder_opening_point_for_layout( - opening_point: &[F::Challenge], -) -> Vec { - if DoryGlobals::get_layout() == DoryLayout::AddressMajor { - let log_T = DoryGlobals::get_T().log_2(); - let log_K = opening_point.len().saturating_sub(log_T); - let (r_address, r_cycle) = opening_point.split_at(log_K); - [r_cycle, r_address].concat() - } else { - opening_point.to_vec() - } -} diff --git a/jolt-core/src/poly/commitment/dory/dory_globals.rs b/jolt-core/src/poly/commitment/dory/dory_globals.rs index 8772532d7..fdf430fc6 100644 --- a/jolt-core/src/poly/commitment/dory/dory_globals.rs +++ b/jolt-core/src/poly/commitment/dory/dory_globals.rs @@ -4,7 +4,7 @@ use crate::utils::math::Math; use allocative::Allocative; use dory::backends::arkworks::{init_cache, ArkG1, ArkG2}; use std::sync::{ - atomic::{AtomicU8, Ordering}, + atomic::{AtomicU8, AtomicUsize, Ordering}, RwLock, }; #[cfg(test)] @@ -143,6 +143,7 @@ impl From for u8 { // Main polynomial globals static GLOBAL_T: RwLock> = RwLock::new(None); +static MAIN_K_CHUNK: RwLock> = RwLock::new(None); static MAX_NUM_ROWS: RwLock> = RwLock::new(None); static NUM_COLUMNS: RwLock> = RwLock::new(None); @@ -161,6 +162,8 @@ static CURRENT_CONTEXT: AtomicU8 = AtomicU8::new(0); // Layout tracking: 0=CycleMajor, 1=AddressMajor static CURRENT_LAYOUT: AtomicU8 = AtomicU8::new(0); +// Largest Main log-embedding needed for precommitted/embed calculations. +static MAIN_LOG_EMBEDDING: AtomicUsize = AtomicUsize::new(0); /// Dory commitment context - determines which set of global parameters to use #[derive(Debug, Clone, Copy, PartialEq, Eq)] @@ -256,6 +259,22 @@ impl DoryGlobals { log_t.saturating_sub(sigma_main) } + #[inline] + pub fn get_main_log_embedding() -> usize { + let stored = MAIN_LOG_EMBEDDING.load(Ordering::SeqCst); + if stored > 0 { + stored + } else { + let main_cols = Self::configured_main_num_columns(); + let main_rows = *MAX_NUM_ROWS + .read() + .unwrap() + .as_ref() + .expect("main max_num_rows not initialized"); + main_cols.log_2() + main_rows.log_2() + } + } + /// Get the current Dory context pub fn current_context() -> DoryContext { CURRENT_CONTEXT.load(Ordering::SeqCst).into() @@ -288,11 +307,84 @@ impl DoryGlobals { (Self::get_max_num_rows(), Self::get_num_columns()) } + #[inline] + pub(crate) fn main_k() -> usize { + *MAIN_K_CHUNK + .read() + .unwrap() + .as_ref() + .expect("main k not initialized") + } + + #[inline] + pub(crate) fn main_t() -> usize { + *GLOBAL_T + .read() + .unwrap() + .as_ref() + .expect("main t not initialized") + } + + #[inline] + pub(crate) fn configured_main_num_columns() -> usize { + *NUM_COLUMNS + .read() + .unwrap() + .as_ref() + .expect("main num_columns not initialized") + } + + #[inline] + fn main_embedding_extra_vars() -> usize { + let main_total_vars = Self::main_k().log_2() + Self::get_T().log_2(); + Self::get_main_log_embedding().saturating_sub(main_total_vars) + } + + /// Column stride for one-hot embeddings in the current layout/context. + pub fn one_hot_stride() -> usize { + if Self::current_context() != DoryContext::Main + || Self::get_layout() != DoryLayout::AddressMajor + { + return 1; + } + 1usize << Self::main_embedding_extra_vars() + } + + /// Column stride for dense trace-domain embeddings in the current layout/context. + pub fn dense_stride() -> usize { + if Self::current_context() != DoryContext::Main + || Self::get_layout() != DoryLayout::AddressMajor + { + return 1; + } + let dense_stride_log = Self::main_embedding_extra_vars() + Self::main_k().log_2(); + 1usize << dense_stride_log + } + + /// Returns the embedded cycle-domain size for the current Dory matrix. + pub fn get_embedded_t() -> usize { + let context = Self::current_context(); + if context != DoryContext::Main { + return Self::get_T(); + } + + let k = Self::main_k(); + let num_rows = Self::get_max_num_rows(); + let num_cols = Self::get_num_columns(); + let total = num_rows * num_cols; + debug_assert_eq!( + total % k, + 0, + "Invalid Main DoryGlobals: num_rows*num_cols must be divisible by K" + ); + total / k + } + /// Returns the "K" used to initialize the *main* Dory matrix for OneHot polynomials. - /// - /// This is derived from the identity: - /// `K * T == num_rows * num_cols` (all values are powers of two in our usage). pub fn k_from_matrix_shape() -> usize { + if Self::current_context() == DoryContext::Main { + return Self::main_k(); + } let (num_rows, num_cols) = Self::matrix_shape(); let t = Self::get_T(); debug_assert_eq!( @@ -303,22 +395,6 @@ impl DoryGlobals { (num_rows * num_cols) / t } - /// For `AddressMajor`, each Dory matrix row corresponds to this many cycles. - /// - /// Equivalent to `T / num_rows` and to `num_cols / K`. - pub fn address_major_cycles_per_row() -> usize { - let (num_rows, num_cols) = Self::matrix_shape(); - let k = Self::k_from_matrix_shape(); - debug_assert!(k > 0); - debug_assert_eq!(num_cols % k, 0, "Expected num_cols to be divisible by K"); - debug_assert_eq!( - Self::get_T() % num_rows, - 0, - "Expected T to be divisible by num_rows" - ); - num_cols / k - } - fn set_max_num_rows_for_context(max_num_rows: usize, context: DoryContext) { match context { DoryContext::Main => { @@ -365,6 +441,10 @@ impl DoryGlobals { } } + fn set_main_k(k: usize) { + *MAIN_K_CHUNK.write().unwrap() = Some(k); + } + pub fn get_num_columns() -> usize { let context = Self::current_context(); match context { @@ -430,6 +510,20 @@ impl DoryGlobals { (num_columns, num_rows, T) } + fn initialize_context_common( + K: usize, + embedded_t: usize, + stored_t: usize, + context: DoryContext, + ) -> Option<()> { + let (num_columns, num_rows, _) = Self::calculate_dimensions(K, embedded_t); + Self::set_num_columns_for_context(num_columns, context); + Self::set_T_for_context(stored_t, context); + Self::set_max_num_rows_for_context(num_rows, context); + + Some(()) + } + /// Initialize the globals for a specific Dory context /// /// # Arguments @@ -451,20 +545,30 @@ impl DoryGlobals { ) -> Option<()> { #[cfg(test)] Self::configure_test_cache_root(); - - let (num_columns, num_rows, t) = Self::calculate_dimensions(K, T); - Self::set_num_columns_for_context(num_columns, context); - Self::set_T_for_context(t, context); - Self::set_max_num_rows_for_context(num_rows, context); - - // For Main context, set layout (if provided) and ensure subsequent uses of `get_*` read from it if context == DoryContext::Main { - if let Some(l) = layout { - CURRENT_LAYOUT.store(l as u8, Ordering::SeqCst); - } - CURRENT_CONTEXT.store(DoryContext::Main as u8, Ordering::SeqCst); + return Self::initialize_main_with_log_embedding(K, T, K.log_2() + T.log_2(), layout); } + Self::initialize_context_common(K, T, T, context)?; + Some(()) + } + /// Initialize Main context with execution `T` and explicit `main_log_embedding` for + /// global precommitted geometry. + pub fn initialize_main_with_log_embedding( + K: usize, + T: usize, + matrix_total_vars: usize, + layout: Option, + ) -> Option<()> { + let log_k = K.log_2(); + let embedded_t = 1usize << matrix_total_vars.saturating_sub(log_k); + Self::initialize_context_common(K, embedded_t, T, DoryContext::Main)?; + Self::set_main_k(K); + if let Some(l) = layout { + CURRENT_LAYOUT.store(l as u8, Ordering::SeqCst); + } + CURRENT_CONTEXT.store(DoryContext::Main as u8, Ordering::SeqCst); + MAIN_LOG_EMBEDDING.store(matrix_total_vars, Ordering::SeqCst); Some(()) } @@ -475,6 +579,7 @@ impl DoryGlobals { // Reset main globals *GLOBAL_T.write().unwrap() = None; + *MAIN_K_CHUNK.write().unwrap() = None; *MAX_NUM_ROWS.write().unwrap() = None; *NUM_COLUMNS.write().unwrap() = None; @@ -492,6 +597,7 @@ impl DoryGlobals { *UNTRUSTED_ADVICE_NUM_COLUMNS.write().unwrap() = None; CURRENT_CONTEXT.store(0, Ordering::SeqCst); + MAIN_LOG_EMBEDDING.store(0, Ordering::SeqCst); } /// Initialize the prepared point cache for faster pairing operations diff --git a/jolt-core/src/poly/commitment/dory/mod.rs b/jolt-core/src/poly/commitment/dory/mod.rs index 4204949e8..11c2444f1 100644 --- a/jolt-core/src/poly/commitment/dory/mod.rs +++ b/jolt-core/src/poly/commitment/dory/mod.rs @@ -13,7 +13,7 @@ mod tests; #[cfg(feature = "zk")] pub use commitment_scheme::bind_opening_inputs_zk; -pub use commitment_scheme::{bind_opening_inputs, DoryCommitmentScheme}; +pub use commitment_scheme::{bind_opening_inputs, DoryCommitmentScheme, DoryOpeningProofHint}; pub use dory_globals::{DoryContext, DoryGlobals, DoryLayout}; pub use jolt_dory_routines::{JoltG1Routines, JoltG2Routines}; pub use wrappers::{ diff --git a/jolt-core/src/poly/commitment/dory/tests.rs b/jolt-core/src/poly/commitment/dory/tests.rs index 0237a5549..69052a653 100644 --- a/jolt-core/src/poly/commitment/dory/tests.rs +++ b/jolt-core/src/poly/commitment/dory/tests.rs @@ -8,6 +8,7 @@ mod tests { use crate::poly::dense_mlpoly::DensePolynomial; use crate::poly::multilinear_polynomial::{MultilinearPolynomial, PolynomialEvaluation}; use crate::transcripts::{Blake2bTranscript, Transcript}; + use crate::utils::math::Math; use ark_ff::biginteger::S128; use ark_std::rand::{thread_rng, Rng}; use ark_std::{UniformRand, Zero}; @@ -879,19 +880,26 @@ mod tests { let num_vars = one_hot_poly.get_num_vars(); let poly = MultilinearPolynomial::OneHot(one_hot_poly); - let opening_point: Vec<::Challenge> = (0..num_vars) + // AddressMajor Dory opening points are consumed as [cycle vars || address vars], + // while OneHotPolynomial::evaluate expects [address vars || cycle vars]. + let log_t = T.log_2(); + let log_k = num_vars - log_t; + let r_cycle: Vec<::Challenge> = (0..log_t) + .map(|_| ::Challenge::random(&mut rng)) + .collect(); + let r_address: Vec<::Challenge> = (0..log_k) .map(|_| ::Challenge::random(&mut rng)) .collect(); + let opening_point = [r_cycle.clone(), r_address.clone()].concat(); + let eval_point = [r_address, r_cycle].concat(); let prover_setup = DoryCommitmentScheme::setup_prover(num_vars); let verifier_setup = DoryCommitmentScheme::setup_verifier(&prover_setup); let (commitment, row_commitments) = DoryCommitmentScheme::commit(&poly, &prover_setup); - let evaluation = as PolynomialEvaluation>::evaluate( - &poly, - &opening_point, - ); + let evaluation = + as PolynomialEvaluation>::evaluate(&poly, &eval_point); let mut prove_transcript = Blake2bTranscript::new(b"dory_test"); bind_opening_inputs::(&mut prove_transcript, &opening_point, &evaluation); @@ -975,16 +983,26 @@ mod tests { let vmp_result = rlc_poly.vector_matrix_product(&left_vec); let mut expected = vec![Fr::zero(); num_columns]; - let cycles_per_row = DoryGlobals::address_major_cycles_per_row(); + let dense_stride = DoryGlobals::dense_stride(); + let cycles_per_row = num_columns / dense_stride; // Dense contribution for AddressMajor layout: // Dense coefficients occupy evenly-spaced columns (every K-th column). // Coefficient i maps to: row = i / cycles_per_row, col = (i % cycles_per_row) * K for (i, &coeff) in rlc_dense.iter().enumerate() { - let row = i / cycles_per_row; - let col = (i % cycles_per_row) * K; - if row < num_rows && col < num_columns { - expected[col] += left_vec[row] * coeff; + if cycles_per_row == 0 { + let scaled_index = i * dense_stride; + let row = scaled_index / num_columns; + let col = scaled_index % num_columns; + if row < num_rows && col < num_columns { + expected[col] += left_vec[row] * coeff; + } + } else { + let row = i / cycles_per_row; + let col = (i % cycles_per_row) * K; + if row < num_rows && col < num_columns { + expected[col] += left_vec[row] * coeff; + } } } diff --git a/jolt-core/src/poly/commitment/dory/wrappers.rs b/jolt-core/src/poly/commitment/dory/wrappers.rs index 18fb13d21..1ce7149ec 100644 --- a/jolt-core/src/poly/commitment/dory/wrappers.rs +++ b/jolt-core/src/poly/commitment/dory/wrappers.rs @@ -202,30 +202,109 @@ where let dory_context = DoryGlobals::current_context(); let dory_layout = DoryGlobals::get_layout(); - // Dense polynomials (all scalar variants except OneHot/RLC) are committed row-wise. - // Under AddressMajor, dense coefficients occupy evenly-spaced columns, so each row - // commitment uses `cycles_per_row` bases (one per occupied column). - let (dense_affine_bases, dense_chunk_size): (Vec<_>, usize) = match (dory_context, dory_layout) - { - (DoryContext::Main, DoryLayout::AddressMajor) => { - let cycles_per_row = DoryGlobals::address_major_cycles_per_row(); - let bases: Vec<_> = g1_slice - .par_iter() - .take(row_len) - .step_by(row_len / cycles_per_row) - .map(|g| g.0.into_affine()) - .collect(); - (bases, cycles_per_row) + let is_dense_poly = !matches!( + poly, + MultilinearPolynomial::OneHot(_) | MultilinearPolynomial::RLC(_) + ); + + let is_trace_dense_addr_major = matches!(dory_context, DoryContext::Main) + && dory_layout == DoryLayout::AddressMajor + && is_dense_poly; + debug_assert!( + !is_trace_dense_addr_major || poly.original_len() <= DoryGlobals::get_T(), + "Main+AddressMajor dense polynomial length exceeds trace T" + ); + + let (dense_affine_bases, dense_chunk_size, dense_sparse_row_terms) = + if is_trace_dense_addr_major { + let stride = DoryGlobals::dense_stride(); + let cycles_per_row = row_len / stride; + // This branch is taken when the AddressMajor trace-dense embedding stride exceeds + // the post-embedded Main row width (`row_len`), i.e. `row_len < stride`. + // + // With: + // - M = DoryGlobals::get_main_log_embedding() = total embedded Main vars + // - k = log2(main K) + // - t = log2(execution T) + // - e = embedding extra vars = M - (k + t) + // + // we have: + // - row_len = 2^sigma_main, where sigma_main = ceil(M/2) + // = 2^ceil((e + k + t)/2) + // - stride = 2^(main_embedding_extra_vars + k) = 2^(M - t) = 2^(e + k) + // + // so `cycles_per_row == 0` exactly when: + // ceil(M/2) < (M - t) <=> t < floor(M/2). + if cycles_per_row == 0 { + let dense_len = poly.original_len(); + let dense_affine_bases: Vec<_> = g1_slice + .par_iter() + .take(row_len) + .map(|g| g.0.into_affine()) + .collect(); + let num_rows = DoryGlobals::get_max_num_rows(); + let sparse_terms: Vec<(usize, usize, Fr)> = (0..dense_len) + .into_par_iter() + .filter_map(|cycle| { + let coeff = poly.get_coeff(cycle); + if coeff.is_zero() { + return None; + } + let scaled_index = cycle.saturating_mul(stride); + let row_index = scaled_index / row_len; + let col_index = scaled_index % row_len; + debug_assert!(row_index < num_rows); + Some((row_index, col_index, coeff)) + }) + .collect(); + let mut row_terms: Vec> = vec![Vec::new(); num_rows]; + for (row_index, col_index, coeff) in sparse_terms { + row_terms[row_index].push((col_index, coeff)); + } + (dense_affine_bases, 1, Some(row_terms)) + } else { + let dense_affine_bases: Vec<_> = g1_slice + .par_iter() + .take(row_len) + .step_by(stride) + .map(|g| g.0.into_affine()) + .collect(); + (dense_affine_bases, cycles_per_row, None) + } + } else { + ( + g1_slice + .par_iter() + .take(row_len) + .map(|g| g.0.into_affine()) + .collect(), + row_len, + None, + ) + }; + + if let Some(row_terms) = dense_sparse_row_terms { + let result: Vec = row_terms + .into_par_iter() + .map(|terms| { + if terms.is_empty() { + return ArkG1(ark_bn254::G1Projective::zero()); + } + let mut bases = Vec::with_capacity(terms.len()); + let mut scalars = Vec::with_capacity(terms.len()); + for (col_index, scalar) in terms { + bases.push(dense_affine_bases[col_index]); + scalars.push(scalar); + } + ArkG1(VariableBaseMSM::msm_field_elements(&bases, &scalars).unwrap()) + }) + .collect(); + // SAFETY: Vec and Vec have the same memory layout when E = BN254. + #[allow(clippy::missing_transmute_annotations)] + unsafe { + return Ok(std::mem::transmute(result)); } - _ => ( - g1_slice - .par_iter() - .take(row_len) - .map(|g| g.0.into_affine()) - .collect(), - row_len, - ), - }; + } let result: Vec = match poly { MultilinearPolynomial::LargeScalars(poly) => poly diff --git a/jolt-core/src/poly/one_hot_polynomial.rs b/jolt-core/src/poly/one_hot_polynomial.rs index 5a807446f..f4d6b107c 100644 --- a/jolt-core/src/poly/one_hot_polynomial.rs +++ b/jolt-core/src/poly/one_hot_polynomial.rs @@ -56,10 +56,11 @@ impl OneHotPolynomial { /// /// Note: the Dory matrix may be square or almost-square depending on `log2(K*T)`. pub fn num_rows(&self) -> usize { - let t = self.nonzero_indices.len(); match DoryGlobals::get_layout() { - DoryLayout::AddressMajor => t.div_ceil(DoryGlobals::address_major_cycles_per_row()), - DoryLayout::CycleMajor => (t * self.K).div_ceil(DoryGlobals::get_num_columns()), + DoryLayout::AddressMajor => DoryGlobals::get_max_num_rows(), + DoryLayout::CycleMajor => { + (DoryGlobals::get_T() * self.K).div_ceil(DoryGlobals::get_num_columns()) + } } } @@ -104,7 +105,7 @@ impl OneHotPolynomial { } pub fn from_indices(nonzero_indices: Vec>, K: usize) -> Self { - debug_assert_eq!(DoryGlobals::get_T(), nonzero_indices.len()); + debug_assert!(nonzero_indices.len() <= DoryGlobals::get_T()); assert!(K <= 1usize << u8::BITS, "K must be <= 256 for indices"); Self { @@ -120,9 +121,15 @@ impl OneHotPolynomial { bases: &[G::Affine], ) -> Vec { let layout = DoryGlobals::get_layout(); + let one_hot_stride = DoryGlobals::one_hot_stride(); let num_rows = self.num_rows(); let row_len = DoryGlobals::get_num_columns(); let t = self.nonzero_indices.len(); + let effective_t = DoryGlobals::get_T(); + debug_assert_eq!( + effective_t, t, + "one-hot polynomial length must match configured Main T" + ); debug_assert!( bases.len() >= row_len, @@ -172,11 +179,16 @@ impl OneHotPolynomial { // General path: collect column indices for each row based on layout let mut row_indices: Vec> = vec![Vec::new(); num_rows]; + let dense_stride = DoryGlobals::dense_stride(); for (cycle, k) in self.nonzero_indices.iter().enumerate() { if let Some(k) = k { - let global_index = layout.address_cycle_to_index(*k as usize, cycle, self.K, t); - let row_index = global_index / row_len; - let col_index = global_index % row_len; + let scaled_index = if layout == DoryLayout::AddressMajor { + cycle * dense_stride + (*k as usize) * one_hot_stride + } else { + layout.address_cycle_to_index(*k as usize, cycle, self.K, effective_t) + }; + let row_index = scaled_index / row_len; + let col_index = scaled_index % row_len; if row_index < num_rows { row_indices[row_index].push(col_index); } @@ -211,12 +223,13 @@ impl OneHotPolynomial { pub fn vector_matrix_product(&self, left_vec: &[F], coeff: F, result: &mut [F]) { let layout = DoryGlobals::get_layout(); let t = self.nonzero_indices.len(); + let effective_t = DoryGlobals::get_T(); let num_columns = DoryGlobals::get_num_columns(); debug_assert_eq!(result.len(), num_columns); // CycleMajor optimization for T >= row_len (typical case where T >= K) if layout == DoryLayout::CycleMajor && t >= num_columns { - let rows_per_k = t / num_columns; + let rows_per_k = effective_t / num_columns; result .par_iter_mut() .enumerate() @@ -234,11 +247,17 @@ impl OneHotPolynomial { } // General path: iterate through nonzero indices and compute contributions + let dense_stride = DoryGlobals::dense_stride(); + let one_hot_stride = DoryGlobals::one_hot_stride(); for (cycle, k) in self.nonzero_indices.iter().enumerate() { if let Some(k) = k { - let global_index = layout.address_cycle_to_index(*k as usize, cycle, self.K, t); - let row_index = global_index / num_columns; - let col_index = global_index % num_columns; + let scaled_index = if layout == DoryLayout::AddressMajor { + cycle * dense_stride + (*k as usize) * one_hot_stride + } else { + layout.address_cycle_to_index(*k as usize, cycle, self.K, effective_t) + }; + let row_index = scaled_index / num_columns; + let col_index = scaled_index % num_columns; if row_index < left_vec.len() && col_index < result.len() { result[col_index] += coeff * left_vec[row_index]; } diff --git a/jolt-core/src/poly/opening_proof.rs b/jolt-core/src/poly/opening_proof.rs index caaa32e72..2baec6de6 100644 --- a/jolt-core/src/poly/opening_proof.rs +++ b/jolt-core/src/poly/opening_proof.rs @@ -7,7 +7,10 @@ use crate::{ poly::rlc_polynomial::{RLCPolynomial, RLCStreamingData, TraceSource}, - zkvm::{claim_reductions::AdviceKind, config::OneHotParams}, + zkvm::{ + claim_reductions::{AdviceKind, PrecommittedPolynomial}, + config::OneHotParams, + }, }; use allocative::Allocative; use num_derive::FromPrimitive; @@ -151,10 +154,16 @@ pub enum SumcheckId { RegistersClaimReduction, RegistersReadWriteChecking, RegistersValEvaluation, + BytecodeReadRafAddressPhase, BytecodeReadRaf, + BooleanityAddressPhase, Booleanity, AdviceClaimReductionCyclePhase, AdviceClaimReduction, + BytecodeClaimReductionCyclePhase, + BytecodeClaimReduction, + ProgramImageClaimReductionCyclePhase, + ProgramImageClaimReduction, IncClaimReduction, HammingWeightClaimReduction, } @@ -327,7 +336,7 @@ pub struct DoryOpeningState { impl DoryOpeningState { /// Build streaming RLC polynomial from this state. /// Streams directly from trace - no witness regeneration needed. - /// Advice polynomials are passed separately (not streamed from trace). + /// Precommitted polynomials are passed separately (not streamed from trace). #[tracing::instrument(skip_all)] pub fn build_streaming_rlc>( &self, @@ -335,7 +344,7 @@ impl DoryOpeningState { trace_source: TraceSource, rlc_streaming_data: Arc, mut opening_hints: HashMap, - advice_polys: HashMap>, + precommitted_polys: HashMap>, ) -> (MultilinearPolynomial, PCS::OpeningProofHint) { // Accumulate gamma coefficients per polynomial let mut rlc_map = BTreeMap::new(); @@ -352,7 +361,7 @@ impl DoryOpeningState { trace_source, poly_ids.clone(), &coeffs, - advice_polys, + precommitted_polys, )); let hints: Vec = rlc_map @@ -621,6 +630,10 @@ where pub fn take_pending_claim_ids(&mut self) -> Vec { std::mem::take(&mut self.pending_claim_ids) } + + pub fn pending_claim_ids_debug(&self) -> &[OpeningId] { + &self.pending_claim_ids + } } impl Default for VerifierOpeningAccumulator @@ -851,39 +864,43 @@ where pub fn take_pending_claim_ids(&mut self) -> Vec { std::mem::take(&mut self.pending_claim_ids) } + + pub fn pending_claim_ids_debug(&self) -> &[OpeningId] { + &self.pending_claim_ids + } } -/// Computes the Lagrange factor for embedding a smaller "advice" polynomial into the top-left -/// block of the main Dory matrix. +/// Computes the Lagrange factor for embedding a smaller polynomial into the top-left block of +/// the main Dory matrix. /// -/// Advice polynomials have fewer variables than main polynomials. To batch them together, -/// we embed advice in the top-left corner of the larger matrix and multiply by a Lagrange +/// Embedded polynomials can have fewer variables than main polynomials. To batch them together, +/// we embed them in the top-left corner of the larger matrix and multiply by a Lagrange /// selector that is 1 on that block and 0 elsewhere: /// /// ```text -/// Lagrange factor = ∏_{r ∈ opening_point, r ∉ advice_opening_point} (1 - r) +/// Lagrange factor = ∏_{r ∈ opening_point, r ∉ embedded_opening_point} (1 - r) /// ``` /// /// # Arguments /// - `opening_point`: The unified opening point for the Dory opening proof -/// - `advice_opening_point`: The opening point for the advice polynomial +/// - `embedded_opening_point`: The opening point for the embedded polynomial /// /// # Returns /// The Lagrange factor as a field element -pub fn compute_advice_lagrange_factor( +pub fn compute_lagrange_factor( opening_point: &[F::Challenge], - advice_opening_point: &[F::Challenge], + embedded_opening_point: &[F::Challenge], ) -> F { #[cfg(test)] { - for r in advice_opening_point.iter() { + for r in embedded_opening_point.iter() { assert!(opening_point.contains(r)); } } opening_point .iter() .map(|r| { - if advice_opening_point.contains(r) { + if embedded_opening_point.contains(r) { F::one() } else { F::one() - r diff --git a/jolt-core/src/poly/rlc_polynomial.rs b/jolt-core/src/poly/rlc_polynomial.rs index b60aa40c3..a1fe5aaa3 100644 --- a/jolt-core/src/poly/rlc_polynomial.rs +++ b/jolt-core/src/poly/rlc_polynomial.rs @@ -4,10 +4,17 @@ use crate::poly::multilinear_polynomial::MultilinearPolynomial; use crate::utils::accumulation::MedAccumS; use crate::utils::math::{s64_from_diff_u64s, Math}; use crate::utils::thread::unsafe_allocate_zero_vec; +use crate::zkvm::claim_reductions::PrecommittedPolynomial; use crate::zkvm::config::OneHotParams; use crate::zkvm::instruction::LookupQuery; use crate::zkvm::ram::remap_address; -use crate::zkvm::{bytecode::BytecodePreprocessing, witness::CommittedPolynomial}; +use crate::zkvm::{ + bytecode::{ + chunks::{committed_lanes, for_each_active_lane_value, ActiveLaneValue}, + BytecodePreprocessing, + }, + witness::CommittedPolynomial, +}; use allocative::Allocative; use common::constants::XLEN; use common::jolt_device::MemoryLayout; @@ -56,9 +63,9 @@ impl TraceSource { pub struct StreamingRLCContext { pub dense_polys: Vec<(CommittedPolynomial, F)>, pub onehot_polys: Vec<(CommittedPolynomial, F)>, - /// Advice polynomials with their RLC coefficients. + /// Precommitted polynomials with their RLC coefficients. /// These are NOT streamed from trace - they're passed in directly. - pub advice_polys: Vec<(F, MultilinearPolynomial)>, + pub precommitted_polys: Vec<(F, PrecommittedPolynomial)>, pub trace_source: TraceSource, pub preprocessing: Arc, pub one_hot_params: OneHotParams, @@ -165,7 +172,7 @@ impl RLCPolynomial { /// * `trace_source` - Either materialized trace (default) or lazy trace (experimental) /// * `poly_ids` - List of polynomial identifiers /// * `coefficients` - RLC coefficients for each polynomial - /// * `advice_poly_map` - Map of advice polynomial IDs to their actual polynomials + /// * `precommitted_poly_map` - Map of precommitted polynomial IDs to their actual polynomials #[tracing::instrument(skip_all)] pub fn new_streaming( one_hot_params: OneHotParams, @@ -173,13 +180,13 @@ impl RLCPolynomial { trace_source: TraceSource, poly_ids: Vec, coefficients: &[F], - mut advice_poly_map: HashMap>, + mut precommitted_poly_map: HashMap>, ) -> Self { debug_assert_eq!(poly_ids.len(), coefficients.len()); let mut dense_polys = Vec::new(); let mut onehot_polys = Vec::new(); - let mut advice_polys = Vec::new(); + let mut precommitted_polys = Vec::new(); for (poly_id, coeff) in poly_ids.iter().zip(coefficients.iter()) { match poly_id { @@ -191,10 +198,14 @@ impl RLCPolynomial { | CommittedPolynomial::RamRa(_) => { onehot_polys.push((*poly_id, *coeff)); } - CommittedPolynomial::TrustedAdvice | CommittedPolynomial::UntrustedAdvice => { - // Advice polynomials are passed in directly (not streamed from trace) - if advice_poly_map.contains_key(poly_id) { - advice_polys.push((*coeff, advice_poly_map.remove(poly_id).unwrap())); + CommittedPolynomial::TrustedAdvice + | CommittedPolynomial::UntrustedAdvice + | CommittedPolynomial::BytecodeChunk(_) + | CommittedPolynomial::ProgramImageInit => { + // Precommitted polynomials are passed in directly (not streamed from trace). + if precommitted_poly_map.contains_key(poly_id) { + precommitted_polys + .push((*coeff, precommitted_poly_map.remove(poly_id).unwrap())); } } } @@ -206,7 +217,7 @@ impl RLCPolynomial { streaming_context: Some(Arc::new(StreamingRLCContext { dense_polys, onehot_polys, - advice_polys, + precommitted_polys, trace_source, preprocessing, one_hot_params, @@ -295,21 +306,31 @@ impl RLCPolynomial { }); } DoryLayout::AddressMajor => { - let cycles_per_row = DoryGlobals::address_major_cycles_per_row(); - dense_result - .par_iter_mut() - .step_by(num_columns / cycles_per_row) + let dense_stride = DoryGlobals::dense_stride(); + dense_result = self + .dense_rlc + .par_iter() .enumerate() - .for_each(|(offset, dot_product_result)| { - *dot_product_result = self - .dense_rlc - .par_iter() - .skip(offset) - .step_by(cycles_per_row) - .zip(left_vec.par_iter()) - .map(|(&a, &b)| -> F { a * b }) - .sum::(); - }); + .fold( + || unsafe_allocate_zero_vec(num_columns), + |mut acc, (cycle, coeff)| { + let scaled_index = cycle.saturating_mul(dense_stride); + let row_index = scaled_index / num_columns; + if row_index >= left_vec.len() { + return acc; + } + let col_index = scaled_index % num_columns; + acc[col_index] += *coeff * left_vec[row_index]; + acc + }, + ) + .reduce( + || unsafe_allocate_zero_vec(num_columns), + |mut a, b| { + a.iter_mut().zip(b.iter()).for_each(|(x, y)| *x += *y); + a + }, + ); } } dense_result @@ -328,74 +349,177 @@ impl RLCPolynomial { result } - /// Adds the advice polynomial contribution to the vector-matrix-vector product result. + /// Adds the precommitted polynomial contribution to the vector-matrix-vector product result. /// - /// In Dory's batch opening, advice polynomials are embedded as the top-left block of the + /// In Dory's batch opening, precommitted polynomials are embedded as the top-left block of the /// main matrix. This function computes their contribution to the VMV product: /// ```text - /// result[col] += left_vec[row] * (coeff * advice[row, col]) + /// result[col] += left_vec[row] * (coeff * precommitted[row, col]) /// ``` - /// for rows and columns within the advice block. + /// for rows and columns within the precommitted block. /// - /// The advice block occupies: - /// - `sigma_a = ceil(advice_vars/2)`, `nu_a = advice_vars - sigma_a` - /// - `advice` occupies rows `[0 .. 2^{nu_a})` and cols `[0 .. 2^{sigma_a})` + /// The precommitted block occupies: + /// - `sigma_a = ceil(poly_vars/2)`, `nu_a = poly_vars - sigma_a` + /// - each precommitted polynomial occupies rows `[0 .. 2^{nu_a})` and cols `[0 .. 2^{sigma_a})` /// /// # Complexity /// It uses O(m + a) space where m is the number of rows - /// and a is the advice size, so even though it is linear it is negl space overall. - fn vmp_advice_contribution( + /// and a is the precommitted size, so even though it is linear it is negl space overall. + fn vmp_precommitted_contribution( result: &mut [F], left_vec: &[F], num_columns: usize, ctx: &StreamingRLCContext, ) { - // For each advice polynomial, compute its contribution to the result - ctx.advice_polys + // For each precommitted polynomial, compute its contribution to the result + ctx.precommitted_polys .iter() - .filter(|(_, advice_poly)| advice_poly.original_len() > 0) - .for_each(|(coeff, advice_poly)| { - let advice_len = advice_poly.original_len(); - let advice_vars = advice_len.log_2(); - let (sigma_a, nu_a) = DoryGlobals::balanced_sigma_nu(advice_vars); - let advice_cols = 1usize << sigma_a; - let advice_rows = 1usize << nu_a; - - debug_assert!( - advice_cols <= num_columns, - "Advice columns (2^{{sigma_a}}={advice_cols}) must fit in main num_columns={num_columns}; \ + .filter(|(_, precommitted_poly)| precommitted_poly.original_len() > 0) + .for_each(|(coeff, precommitted_poly)| { + match precommitted_poly { + PrecommittedPolynomial::Dense(poly) => { + let precommitted_len = poly.original_len(); + let precommitted_vars = precommitted_len.log_2(); + let (sigma_a, nu_a) = DoryGlobals::balanced_sigma_nu(precommitted_vars); + let precommitted_cols = 1usize << sigma_a; + let precommitted_rows = 1usize << nu_a; + + debug_assert!( + precommitted_cols <= num_columns, + "Precommitted columns (2^{{sigma_a}}={precommitted_cols}) must fit in main num_columns={num_columns}; \ guardrail in gen_from_trace should ensure sigma_main >= sigma_a." - ); - - // Only the top-left block contributes: rows [0..advice_rows), cols [0..advice_cols) - let effective_rows = advice_rows.min(left_vec.len()); - - // Compute column contributions: for each column, sum contributions from all rows - // Note: advice_len is always advice_cols * advice_rows (advice size must be power of 2) - let column_contributions: Vec = (0..advice_cols) - .into_par_iter() - .map(|col_idx| { - // For this column, sum contributions from all non-zero rows - left_vec[..effective_rows] - .iter() - .enumerate() - .filter(|(_, &left)| !left.is_zero()) - .map(|(row_idx, &left)| { - let coeff_idx = row_idx * advice_cols + col_idx; - let advice_val = advice_poly.get_coeff(coeff_idx); - left * *coeff * advice_val + ); + + let effective_rows = precommitted_rows.min(left_vec.len()); + let column_contributions: Vec = (0..precommitted_cols) + .into_par_iter() + .map(|col_idx| { + left_vec[..effective_rows] + .iter() + .enumerate() + .filter(|(_, &left)| !left.is_zero()) + .map(|(row_idx, &left)| { + let coeff_idx = row_idx * precommitted_cols + col_idx; + let precommitted_val = poly.get_coeff(coeff_idx); + left * *coeff * precommitted_val + }) + .sum() }) - .sum() - }) - .collect(); - - // Add column contributions to result in parallel - result[..advice_cols] - .par_iter_mut() - .zip(column_contributions.par_iter()) - .for_each(|(res, &contrib)| { - *res += contrib; - }); + .collect(); + + result[..precommitted_cols] + .par_iter_mut() + .zip(column_contributions.par_iter()) + .for_each(|(res, &contrib)| { + *res += contrib; + }); + } + PrecommittedPolynomial::BytecodeChunk { + chunk_index, + chunk_cycle_len, + } => { + let precommitted_len = committed_lanes() * *chunk_cycle_len; + let precommitted_vars = precommitted_len.log_2(); + let (sigma_a, nu_a) = DoryGlobals::balanced_sigma_nu(precommitted_vars); + let precommitted_cols = 1usize << sigma_a; + let effective_rows = (1usize << nu_a).min(left_vec.len()); + let chunk_start = chunk_index * chunk_cycle_len; + let chunk_end = chunk_start + chunk_cycle_len; + let layout = DoryGlobals::get_layout(); + let column_contributions = ctx.preprocessing.bytecode.bytecode + [chunk_start..chunk_end] + .par_iter() + .enumerate() + .fold( + || unsafe_allocate_zero_vec(precommitted_cols), + |mut acc, (chunk_cycle, instr)| { + for_each_active_lane_value::(instr, |global_lane, lane_val| { + let coeff_idx = layout.address_cycle_to_index( + global_lane, + chunk_cycle, + committed_lanes(), + *chunk_cycle_len, + ); + let row_idx = coeff_idx / precommitted_cols; + if row_idx >= effective_rows { + return; + } + let left = left_vec[row_idx]; + if left.is_zero() { + return; + } + let lane_value = match lane_val { + ActiveLaneValue::One => F::one(), + ActiveLaneValue::Scalar(v) => v, + }; + let col_idx = coeff_idx % precommitted_cols; + acc[col_idx] += left * *coeff * lane_value; + }); + acc + }, + ) + .reduce( + || unsafe_allocate_zero_vec(precommitted_cols), + |mut a, b| { + a.iter_mut().zip(b.iter()).for_each(|(x, y)| *x += *y); + a + }, + ); + + result[..precommitted_cols] + .par_iter_mut() + .zip(column_contributions.par_iter()) + .for_each(|(res, &contrib)| { + *res += contrib; + }); + } + PrecommittedPolynomial::ProgramImage { + words, + padded_len, + } => { + let precommitted_vars = padded_len.log_2(); + let (sigma_a, nu_a) = DoryGlobals::balanced_sigma_nu(precommitted_vars); + let precommitted_cols = 1usize << sigma_a; + let effective_rows = (1usize << nu_a).min(left_vec.len()); + let column_contributions = words + .par_iter() + .enumerate() + .fold( + || unsafe_allocate_zero_vec(precommitted_cols), + |mut acc, (offset, &word)| { + if word == 0 { + return acc; + } + let coeff_idx = offset; + let row_idx = coeff_idx / precommitted_cols; + if row_idx >= effective_rows { + return acc; + } + let left = left_vec[row_idx]; + if left.is_zero() { + return acc; + } + let col_idx = coeff_idx % precommitted_cols; + acc[col_idx] += left * coeff.mul_u64(word); + acc + }, + ) + .reduce( + || unsafe_allocate_zero_vec(precommitted_cols), + |mut a, b| { + a.iter_mut().zip(b.iter()).for_each(|(x, y)| *x += *y); + a + }, + ); + + result[..precommitted_cols] + .par_iter_mut() + .zip(column_contributions.par_iter()) + .for_each(|(res, &contrib)| { + *res += contrib; + }); + } + } }); } @@ -415,7 +539,7 @@ guardrail in gen_from_trace should ensure sigma_main >= sigma_a." return self.address_major_vector_matrix_product(left_vec, num_columns, &ctx); } - let T = DoryGlobals::get_T(); + let T = DoryGlobals::get_embedded_t(); match &ctx.trace_source { TraceSource::Materialized(trace) => { self.materialized_vector_matrix_product(left_vec, num_columns, trace, &ctx, T) @@ -449,7 +573,7 @@ guardrail in gen_from_trace should ensure sigma_main >= sigma_a." // Use the regular vector_matrix_product on the materialized polynomial let mut result = materialized.vector_matrix_product(left_vec); - Self::vmp_advice_contribution(&mut result, left_vec, num_columns, ctx); + Self::vmp_precommitted_contribution(&mut result, left_vec, num_columns, ctx); result } @@ -516,8 +640,13 @@ guardrail in gen_from_trace should ensure sigma_main >= sigma_a." let num_rows = T / num_columns; let trace_len = trace.len(); - // Setup: precompute coefficients, row factors, and folded one-hot tables. - let setup = VmvSetup::new(ctx, left_vec, num_rows); + let main_embedding_mode = + DoryGlobals::get_layout() == DoryLayout::CycleMajor && trace_len < T; + + // When the dominant Stage-8 matrix is larger than the trace-backed prefix, one-hot + // witnesses still live on the exact trace prefix rather than the expanded matrix T. + let onehot_rows_per_k = trace_len.div_ceil(num_columns).min(num_rows); + let setup = VmvSetup::new(ctx, left_vec, num_rows, onehot_rows_per_k); // Divide rows evenly among threads using par_chunks on left_vec // Only use first num_rows elements (left_vec may be longer due to padding) @@ -538,7 +667,6 @@ guardrail in gen_from_trace should ensure sigma_main >= sigma_a." let scaled_rd_inc = row_weight * setup.rd_inc_coeff; let scaled_ram_inc = row_weight * setup.ram_inc_coeff; - let row_factor = setup.row_factors[row_idx]; // Split into valid trace range vs padding range. let valid_end = std::cmp::min(chunk_start + num_columns, trace_len); @@ -550,14 +678,33 @@ guardrail in gen_from_trace should ensure sigma_main >= sigma_a." // Process valid trace elements. for (col_idx, cycle) in row_cycles.iter().enumerate() { - setup.process_cycle( - cycle, - scaled_rd_inc, - scaled_ram_inc, - row_factor, - &mut dense_accs[col_idx], - &mut onehot_accs[col_idx], - ); + if main_embedding_mode { + setup.process_cycle_dense( + cycle, + scaled_rd_inc, + scaled_ram_inc, + &mut dense_accs[col_idx], + ); + setup.process_cycle_onehot_prefix( + cycle, + chunk_start + col_idx, + trace_len, + num_columns, + left_vec, + &ctx.onehot_polys, + &mut onehot_accs, + ); + } else { + let row_factor = setup.row_factors[row_idx]; + setup.process_cycle( + cycle, + scaled_rd_inc, + scaled_ram_inc, + row_factor, + &mut dense_accs[col_idx], + &mut onehot_accs[col_idx], + ); + } } } @@ -570,8 +717,8 @@ guardrail in gen_from_trace should ensure sigma_main >= sigma_a." let mut result = VmvSetup::::finalize(dense_accs, onehot_accs, num_columns); - // Advice contribution is small and independent of the trace; add it after the streamed pass. - Self::vmp_advice_contribution(&mut result, left_vec, num_columns, ctx); + // Precommitted contribution is small and independent of the trace; add it after the streamed pass. + Self::vmp_precommitted_contribution(&mut result, left_vec, num_columns, ctx); result } @@ -586,9 +733,13 @@ guardrail in gen_from_trace should ensure sigma_main >= sigma_a." T: usize, ) -> Vec { let num_rows = T / num_columns; + let trace_len = DoryGlobals::main_t(); + let main_embedding_mode = + DoryGlobals::get_layout() == DoryLayout::CycleMajor && trace_len < T; // Setup: precompute coefficients, row factors, and folded one-hot tables. - let setup = VmvSetup::new(ctx, left_vec, num_rows); + let onehot_rows_per_k = trace_len.div_ceil(num_columns).min(num_rows); + let setup = VmvSetup::new(ctx, left_vec, num_rows, onehot_rows_per_k); let (dense_accs, onehot_accs) = lazy_trace .pad_using(T, |_| Cycle::NoOp) @@ -601,18 +752,37 @@ guardrail in gen_from_trace should ensure sigma_main >= sigma_a." let row_weight = left_vec[row_idx]; let scaled_rd_inc = row_weight * setup.rd_inc_coeff; let scaled_ram_inc = row_weight * setup.ram_inc_coeff; - let row_factor = setup.row_factors[row_idx]; // Process columns within chunk sequentially. for (col_idx, cycle) in chunk.iter().enumerate() { - setup.process_cycle( - cycle, - scaled_rd_inc, - scaled_ram_inc, - row_factor, - &mut dense_accs[col_idx], - &mut onehot_accs[col_idx], - ); + let cycle_idx = row_idx * num_columns + col_idx; + if main_embedding_mode && cycle_idx < trace_len { + setup.process_cycle_dense( + cycle, + scaled_rd_inc, + scaled_ram_inc, + &mut dense_accs[col_idx], + ); + setup.process_cycle_onehot_prefix( + cycle, + cycle_idx, + trace_len, + num_columns, + left_vec, + &ctx.onehot_polys, + &mut onehot_accs, + ); + } else { + let row_factor = setup.row_factors[row_idx]; + setup.process_cycle( + cycle, + scaled_rd_inc, + scaled_ram_inc, + row_factor, + &mut dense_accs[col_idx], + &mut onehot_accs[col_idx], + ); + } } (dense_accs, onehot_accs) @@ -624,8 +794,8 @@ guardrail in gen_from_trace should ensure sigma_main >= sigma_a." ); let mut result = VmvSetup::::finalize(dense_accs, onehot_accs, num_columns); - // Advice contribution is small and independent of the trace; add it after the streamed pass. - Self::vmp_advice_contribution(&mut result, left_vec, num_columns, ctx); + // Precommitted contribution is small and independent of the trace; add it after the streamed pass. + Self::vmp_precommitted_contribution(&mut result, left_vec, num_columns, ctx); result } } @@ -659,19 +829,29 @@ struct VmvSetup<'a, F: JoltField> { } impl<'a, F: JoltField> VmvSetup<'a, F> { - fn new(ctx: &'a StreamingRLCContext, left_vec: &[F], num_rows: usize) -> Self { + fn new( + ctx: &'a StreamingRLCContext, + left_vec: &[F], + matrix_rows_per_k: usize, + active_onehot_rows_per_k: usize, + ) -> Self { let one_hot_params = &ctx.one_hot_params; let k_chunk = one_hot_params.k_chunk; debug_assert!( - left_vec.len() >= k_chunk * num_rows, + left_vec.len() >= k_chunk * matrix_rows_per_k, "left_vec too short for one-hot VMV: len={} need_at_least={}", left_vec.len(), - k_chunk * num_rows + k_chunk * matrix_rows_per_k ); // Compute row_factors and eq_k from left vector - let (row_factors, eq_k) = Self::compute_row_factors_and_eq_k(left_vec, num_rows, k_chunk); + let (row_factors, eq_k) = Self::compute_row_factors_and_eq_k( + left_vec, + matrix_rows_per_k, + active_onehot_rows_per_k, + k_chunk, + ); // Extract dense coefficients let mut rd_inc_coeff = F::zero(); @@ -703,16 +883,17 @@ impl<'a, F: JoltField> VmvSetup<'a, F> { #[inline] fn compute_row_factors_and_eq_k( left_vec: &[F], - rows_per_k: usize, + matrix_rows_per_k: usize, + active_onehot_rows_per_k: usize, k_chunk: usize, ) -> (Vec, Vec) { - let mut row_factors: Vec = unsafe_allocate_zero_vec(rows_per_k); + let mut row_factors: Vec = unsafe_allocate_zero_vec(matrix_rows_per_k); let mut eq_k: Vec = unsafe_allocate_zero_vec(k_chunk); for k in 0..k_chunk { - let base = k * rows_per_k; + let base = k * matrix_rows_per_k; let mut sum_k = F::zero(); - for row in 0..rows_per_k { + for row in 0..active_onehot_rows_per_k { let v = left_vec[base + row]; sum_k += v; row_factors[row] += v; @@ -723,6 +904,71 @@ impl<'a, F: JoltField> VmvSetup<'a, F> { (row_factors, eq_k) } + #[inline(always)] + fn process_cycle_dense( + &self, + cycle: &Cycle, + scaled_rd_inc: F, + scaled_ram_inc: F, + dense_acc: &mut MedAccumS, + ) { + let (_, pre_value, post_value) = cycle.rd_write().unwrap_or_default(); + let diff = s64_from_diff_u64s(post_value, pre_value); + dense_acc.fmadd(&scaled_rd_inc, &diff); + + if let tracer::instruction::RAMAccess::Write(write) = cycle.ram_access() { + let diff = s64_from_diff_u64s(write.post_value, write.pre_value); + dense_acc.fmadd(&scaled_ram_inc, &diff); + } + } + + #[allow(clippy::too_many_arguments)] + #[inline(always)] + fn process_cycle_onehot_prefix( + &self, + cycle: &Cycle, + cycle_idx: usize, + trace_len: usize, + num_columns: usize, + left_vec: &[F], + onehot_polys: &[(CommittedPolynomial, F)], + onehot_accs: &mut [F::UnreducedProductAccum], + ) { + let lookup_index = LookupQuery::::to_lookup_index(cycle); + let pc = self.bytecode.get_pc(cycle); + let remapped_address = + remap_address(cycle.ram_access().address() as u64, self.memory_layout); + + for (poly_id, coeff) in onehot_polys.iter() { + if coeff.is_zero() { + continue; + } + + let k = match poly_id { + CommittedPolynomial::InstructionRa(idx) => { + self.one_hot_params.lookup_index_chunk(lookup_index, *idx) as usize + } + CommittedPolynomial::BytecodeRa(idx) => { + self.one_hot_params.bytecode_pc_chunk(pc, *idx) as usize + } + CommittedPolynomial::RamRa(idx) => { + let Some(addr) = remapped_address else { + continue; + }; + self.one_hot_params.ram_address_chunk(addr, *idx) as usize + } + _ => unreachable!("dense polynomial found in onehot_polys"), + }; + + let global_index = k * trace_len + cycle_idx; + let row_index = global_index / num_columns; + let col_index = global_index % num_columns; + if row_index < left_vec.len() && col_index < onehot_accs.len() { + onehot_accs[col_index] += left_vec[row_index].mul_to_product_accum(*coeff); + } + } + } + /// Build per-polynomial folded one-hot tables (non-flattened). fn build_folded_tables( onehot_polys: &[(CommittedPolynomial, F)], @@ -788,15 +1034,7 @@ impl<'a, F: JoltField> VmvSetup<'a, F> { dense_acc: &mut MedAccumS, onehot_acc: &mut F::UnreducedProductAccum, ) { - // Dense polynomials: accumulate scaled_coeff * (post - pre) - let (_, pre_value, post_value) = cycle.rd_write().unwrap_or_default(); - let diff = s64_from_diff_u64s(post_value, pre_value); - dense_acc.fmadd(&scaled_rd_inc, &diff); - - if let tracer::instruction::RAMAccess::Write(write) = cycle.ram_access() { - let diff = s64_from_diff_u64s(write.post_value, write.pre_value); - dense_acc.fmadd(&scaled_ram_inc, &diff); - } + self.process_cycle_dense(cycle, scaled_rd_inc, scaled_ram_inc, dense_acc); // One-hot polynomials: accumulate using pre-folded K tables (unreduced) let mut inner_sum = F::UnreducedMulU64::default(); diff --git a/jolt-core/src/subprotocols/booleanity.rs b/jolt-core/src/subprotocols/booleanity.rs index 192fc39e4..871d360b1 100644 --- a/jolt-core/src/subprotocols/booleanity.rs +++ b/jolt-core/src/subprotocols/booleanity.rs @@ -1,12 +1,11 @@ -//! Booleanity Sumcheck +//! Booleanity Sumcheck (split into address/cycle phases) //! -//! This module implements a single booleanity sumcheck that handles all three families: -//! - Instruction RA polynomials -//! - Bytecode RA polynomials -//! - RAM RA polynomials +//! This module implements Stage 6 booleanity as two explicit sumcheck instances: +//! - Address phase (`log_k_chunk` rounds) +//! - Cycle phase (`log_t` rounds) //! -//! By combining them into a single sumcheck, all families share the same `r_address` and `r_cycle`, -//! which is required by the HammingWeightClaimReduction sumcheck in Stage 7. +//! Both phases still batch all three families together (InstructionRA, BytecodeRA, RAMRA), +//! so they share the same `r_address` and `r_cycle`, matching what Stage 7 claim reductions expect. //! //! ## Sumcheck Relation //! @@ -42,7 +41,7 @@ use crate::{ AbstractVerifierOpeningAccumulator, OpeningAccumulator, OpeningPoint, ProverOpeningAccumulator, SumcheckId, BIG_ENDIAN, }, - shared_ra_polys::{compute_all_G_and_ra_indices, RaIndices, SharedRaPolynomials}, + shared_ra_polys::{compute_all_G, compute_ra_indices, SharedRaPolynomials}, split_eq_poly::GruenSplitEqPolynomial, unipoly::UniPoly, }, @@ -51,7 +50,7 @@ use crate::{ sumcheck_verifier::{SumcheckInstanceParams, SumcheckInstanceVerifier}, }, transcripts::Transcript, - utils::{expanding_table::ExpandingTable, thread::drop_in_background_thread}, + utils::expanding_table::ExpandingTable, zkvm::{ bytecode::BytecodePreprocessing, config::OneHotParams, @@ -113,18 +112,18 @@ impl SumcheckInstanceParams for BooleanitySumcheckParams { } #[cfg(feature = "zk")] - fn input_constraint_challenge_values(&self, _: &dyn OpeningAccumulator) -> Vec { + fn input_constraint_challenge_values( + &self, + _accumulator: &dyn OpeningAccumulator, + ) -> Vec { Vec::new() } #[cfg(feature = "zk")] fn output_claim_constraint(&self) -> Option { - let n = self.polynomial_types.len(); - - let mut terms = Vec::with_capacity(2 * n); + let mut terms = Vec::with_capacity(2 * self.polynomial_types.len()); for (i, poly_type) in self.polynomial_types.iter().enumerate() { let opening = OpeningId::committed(*poly_type, SumcheckId::Booleanity); - terms.push(ProductTerm::scaled( ValueSource::Challenge(2 * i), vec![ValueSource::Opening(opening), ValueSource::Opening(opening)], @@ -134,22 +133,12 @@ impl SumcheckInstanceParams for BooleanitySumcheckParams { vec![ValueSource::Opening(opening)], )); } - Some(OutputClaimConstraint::sum_of_products(terms)) } #[cfg(feature = "zk")] fn output_constraint_challenge_values(&self, sumcheck_challenges: &[F::Challenge]) -> Vec { - let combined_r: Vec = self - .r_address - .iter() - .cloned() - .rev() - .chain(self.r_cycle.iter().cloned().rev()) - .collect(); - - let eq_eval: F = EqPolynomial::::mle(sumcheck_challenges, &combined_r); - + let eq_eval: F = EqPolynomial::::mle(sumcheck_challenges, &self.combined_r_big_endian()); let mut challenges = Vec::with_capacity(2 * self.polynomial_types.len()); for gamma_2i in &self.gamma_powers_square { let coeff = eq_eval * *gamma_2i; @@ -191,19 +180,9 @@ impl BooleanitySumcheckParams { // NOTE: `stage5_point.r` is stored in BIG_ENDIAN format (each segment was reversed by // `normalize_opening_point`). For internal eq evaluations we want LowToHigh (LE) order // because `GruenSplitEqPolynomial` is instantiated with `BindingOrder::LowToHigh`. - debug_assert!( - stage5_point.r.len() == log_k_instruction + log_t, - "InstructionReadRaf opening point length mismatch: got {}, expected {} (= log_k_instruction {} + log_t {})", - stage5_point.r.len(), - log_k_instruction + log_t, - log_k_instruction, - log_t - ); - // Address segment: BE -> LE let mut stage5_addr = stage5_point.r[..log_k_instruction].to_vec(); stage5_addr.reverse(); - // Cycle segment: BE -> LE let mut r_cycle = stage5_point.r[log_k_instruction..].to_vec(); r_cycle.reverse(); @@ -261,105 +240,93 @@ impl BooleanitySumcheckParams { one_hot_params: one_hot_params.clone(), } } + + fn combined_r_big_endian(&self) -> Vec { + self.r_address + .iter() + .cloned() + .rev() + .chain(self.r_cycle.iter().cloned().rev()) + .collect() + } +} + +fn compute_gamma_powers(gamma: F::Challenge, count: usize) -> (Vec, Vec) { + let gamma_f: F = gamma.into(); + let mut powers = Vec::with_capacity(count); + let mut powers_inv = Vec::with_capacity(count); + let mut rho_i = F::one(); + for _ in 0..count { + powers.push(rho_i); + powers_inv.push(rho_i.inverse().expect("gamma powers are nonzero")); + rho_i *= gamma_f; + } + (powers, powers_inv) } -/// Booleanity Sumcheck Prover. +/// Booleanity address-phase prover. #[derive(Allocative)] -pub struct BooleanitySumcheckProver { - /// Per-polynomial powers γ^i (in the base field). - /// Used to pre-scale the address eq tables for phase 2. - gamma_powers: Vec, - /// Per-polynomial inverse powers γ^{-i} (in the base field). - /// Used to unscale cached committed-polynomial openings. - gamma_powers_inv: Vec, +pub struct BooleanityAddressSumcheckProver { /// B: split-eq over address-chunk variables (phase 1, LowToHigh). B: GruenSplitEqPolynomial, - /// D: split-eq over time/cycle variables (phase 2, LowToHigh). - D: GruenSplitEqPolynomial, - /// G[i][k] = Σ_j eq(r_cycle, j) · ra_i(k, j) for all RA polynomials + /// G[i][k] = Σ_j eq(r_cycle, j) · ra_i(k, j) for all RA polynomials. G: Vec>, - /// Shared H polynomials for phase 2 (initialized at transition) - H: Option>, - /// F: Expanding table for phase 1 + /// F: Expanding table over address bits for phase 1. F: ExpandingTable, - /// eq(r_address, r_address) at end of phase 1 - eq_r_r: F, - /// RA indices (non-transposed, one per cycle) - ra_indices: Vec, - pub params: BooleanitySumcheckParams, + /// Most recent round polynomial, used to cache the address-phase output claim. + last_round_poly: Option>, + /// Output claim after the final address round (input claim for cycle phase). + address_claim: Option, + /// Shared booleanity parameters across both phases. + params: BooleanitySumcheckParams, + /// Address-only `SumcheckInstanceParams` wrapper. + address_params: BooleanityAddressPhaseParams, } -impl BooleanitySumcheckProver { - /// Initialize a BooleanitySumcheckProver with all three families. +impl BooleanityAddressSumcheckProver { + /// Initialize the address-phase prover. /// - /// All heavy computation is done here: - /// - Compute G polynomials and RA indices in a single pass over the trace - /// - Initialize split-eq polynomials for address (B) and cycle (D) variables - /// - Initialize expanding table for phase 1 - #[tracing::instrument(skip_all, name = "BooleanitySumcheckProver::initialize")] + /// Heavy precomputation for this phase happens here: + /// - Compute all G-polynomial slices from the trace + /// - Initialize the address split-eq polynomial (`B`) + /// - Initialize the address expanding table (`F`) pub fn initialize( params: BooleanitySumcheckParams, trace: &[Cycle], bytecode: &BytecodePreprocessing, memory_layout: &MemoryLayout, ) -> Self { - // Compute G and RA indices in a single pass over the trace - let (G, ra_indices) = compute_all_G_and_ra_indices::( + let G = compute_all_G::( trace, bytecode, memory_layout, ¶ms.one_hot_params, ¶ms.r_cycle, ); - - // Initialize split-eq polynomials for address and cycle variables let B = GruenSplitEqPolynomial::new(¶ms.r_address, BindingOrder::LowToHigh); - let D = GruenSplitEqPolynomial::new(¶ms.r_cycle, BindingOrder::LowToHigh); - - // Initialize expanding table for phase 1 let k_chunk = 1 << params.log_k_chunk; let mut F_table = ExpandingTable::new(k_chunk, BindingOrder::LowToHigh); F_table.reset(F::one()); - // Compute prover-only fields: gamma_powers (γ^i) and gamma_powers_inv (γ^{-i}) - let num_polys = params.polynomial_types.len(); - let gamma_f: F = params.gamma.into(); - let mut gamma_powers = Vec::with_capacity(num_polys); - let mut gamma_powers_inv = Vec::with_capacity(num_polys); - let mut rho_i = F::one(); - for _ in 0..num_polys { - gamma_powers.push(rho_i); - gamma_powers_inv.push( - rho_i - .inverse() - .expect("gamma_powers[i] is nonzero (gamma != 0)"), - ); - rho_i *= gamma_f; - } - Self { - gamma_powers, - gamma_powers_inv, B, - D, G, - ra_indices, - H: None, F: F_table, - eq_r_r: F::zero(), + last_round_poly: None, + address_claim: None, + address_params: BooleanityAddressPhaseParams::new(params.clone()), params, } } - fn compute_phase1_message(&self, round: usize, previous_claim: F) -> UniPoly { + fn compute_message_impl(&self, round: usize, previous_claim: F) -> UniPoly { let m = round + 1; - let B = &self.B; - let N = self.params.polynomial_types.len(); - - // Compute quadratic coefficients via generic split-eq fold - let quadratic_coeffs: [F; DEGREE_BOUND - 1] = B + let n = self.params.polynomial_types.len(); + // Compute quadratic coefficients via split-eq folding over the unbound address suffix. + let quadratic_coeffs: [F; DEGREE_BOUND - 1] = self + .B .par_fold_out_in_unreduced::<{ DEGREE_BOUND - 1 }>(&|k_prime| { - let coeffs = (0..N) + (0..n) .into_par_iter() .map(|i| { let G_i = &self.G[i]; @@ -370,7 +337,6 @@ impl BooleanitySumcheckProver { let k_m = k >> (m - 1); let F_k = self.F[k & ((1 << (m - 1)) - 1)]; let G_times_F = G_k * F_k; - let eval_infty = G_times_F * F_k; let eval_0 = if k_m == 0 { eval_infty - G_times_F @@ -402,29 +368,146 @@ impl BooleanitySumcheckProver { .reduce( || [F::zero(); DEGREE_BOUND - 1], |running, new| [running[0] + new[0], running[1] + new[1]], - ); - coeffs + ) }); - B.gruen_poly_deg_3(quadratic_coeffs[0], quadratic_coeffs[1], previous_claim) + self.B + .gruen_poly_deg_3(quadratic_coeffs[0], quadratic_coeffs[1], previous_claim) + } +} + +impl SumcheckInstanceProver + for BooleanityAddressSumcheckProver +{ + fn get_params(&self) -> &dyn SumcheckInstanceParams { + &self.address_params + } + + fn compute_message(&mut self, round: usize, previous_claim: F) -> UniPoly { + let poly = self.compute_message_impl(round, previous_claim); + self.last_round_poly = Some(poly.clone()); + poly } - fn compute_phase2_message(&self, _round: usize, previous_claim: F) -> UniPoly { - let D = &self.D; - let H = self.H.as_ref().expect("H should be initialized in phase 2"); - let num_polys = H.num_polys(); + fn ingest_challenge(&mut self, r_j: F::Challenge, round: usize) { + if let Some(poly) = self.last_round_poly.take() { + let claim = poly.evaluate(&r_j); + if round == self.params.log_k_chunk - 1 { + self.address_claim = Some(claim); + } + } + self.B.bind(r_j); + self.F.update(r_j); + } - // Compute quadratic coefficients via generic split-eq fold (handles both E_in cases). - let quadratic_coeffs: [F; DEGREE_BOUND - 1] = D + fn cache_openings( + &self, + accumulator: &mut ProverOpeningAccumulator, + sumcheck_challenges: &[F::Challenge], + ) { + // Cache the intermediate address-phase claim used as input to cycle phase. + let mut r_address = sumcheck_challenges.to_vec(); + r_address.reverse(); + accumulator.append_virtual( + VirtualPolynomial::BooleanityAddrClaim, + SumcheckId::BooleanityAddressPhase, + OpeningPoint::::new(r_address), + self.address_claim + .expect("Booleanity address-phase claim missing"), + ); + } + + #[cfg(feature = "allocative")] + fn update_flamegraph(&self, flamegraph: &mut FlameGraphBuilder) { + flamegraph.visit_root(self); + } +} + +/// Booleanity cycle-phase prover. +#[derive(Allocative)] +pub struct BooleanityCycleSumcheckProver { + /// D: split-eq over cycle variables (phase 2, LowToHigh). + D: GruenSplitEqPolynomial, + /// Shared RA polynomials, pre-scaled for batched cycle-phase accumulation. + H: SharedRaPolynomials, + /// eq(r_address, r_address), carried from address-phase binding. + eq_r_r: F, + /// Per-polynomial powers γ^i used for pre-scaling. + gamma_powers: Vec, + /// Per-polynomial inverse powers γ^{-i} used to unscale cached openings. + gamma_powers_inv: Vec, + /// Shared booleanity parameters across both phases. + params: BooleanitySumcheckParams, + /// Cycle-only `SumcheckInstanceParams` wrapper. + cycle_params: BooleanityCyclePhaseParams, +} + +impl BooleanityCycleSumcheckProver { + /// Initialize cycle-phase state from the cached address-phase opening. + pub fn initialize( + params: BooleanitySumcheckParams, + trace: &[Cycle], + bytecode: &BytecodePreprocessing, + memory_layout: &MemoryLayout, + accumulator: &ProverOpeningAccumulator, + ) -> Self { + let (r_address_point, _) = accumulator.get_virtual_polynomial_opening( + VirtualPolynomial::BooleanityAddrClaim, + SumcheckId::BooleanityAddressPhase, + ); + let mut r_address_low_to_high = r_address_point.r; + r_address_low_to_high.reverse(); + let cycle_params = + BooleanityCyclePhaseParams::new(params.clone(), r_address_low_to_high.clone()); + + let mut B = GruenSplitEqPolynomial::new(¶ms.r_address, BindingOrder::LowToHigh); + for r_j in r_address_low_to_high.iter().copied() { + B.bind(r_j); + } + let eq_r_r = B.get_current_scalar(); + + let k_chunk = 1 << params.log_k_chunk; + let mut F_table = ExpandingTable::new(k_chunk, BindingOrder::LowToHigh); + F_table.reset(F::one()); + for r_j in r_address_low_to_high.iter().copied() { + F_table.update(r_j); + } + let base_eq = F_table.clone_values(); + + let ra_indices = compute_ra_indices(trace, bytecode, memory_layout, ¶ms.one_hot_params); + let num_polys = params.polynomial_types.len(); + let (gamma_powers, gamma_powers_inv) = compute_gamma_powers(params.gamma, num_polys); + let tables: Vec> = (0..num_polys) + .into_par_iter() + .map(|i| { + let rho = gamma_powers[i]; + base_eq.iter().map(|v| rho * *v).collect() + }) + .collect(); + + Self { + D: GruenSplitEqPolynomial::new(¶ms.r_cycle, BindingOrder::LowToHigh), + H: SharedRaPolynomials::new(tables, ra_indices, params.one_hot_params.clone()), + eq_r_r, + gamma_powers, + gamma_powers_inv, + cycle_params, + params, + } + } + + fn compute_message_impl(&self, previous_claim: F) -> UniPoly { + let num_polys = self.H.num_polys(); + let quadratic_coeffs: [F; DEGREE_BOUND - 1] = self + .D .par_fold_out_in_unreduced::<{ DEGREE_BOUND - 1 }>(&|j_prime| { - // Accumulate in unreduced form to minimize per-term reductions + // Accumulate in unreduced form to minimize per-term reductions. let mut acc_c = F::UnreducedProductAccum::zero(); let mut acc_e = F::UnreducedProductAccum::zero(); for i in 0..num_polys { - let h_0 = H.get_bound_coeff(i, 2 * j_prime); - let h_1 = H.get_bound_coeff(i, 2 * j_prime + 1); + let h_0 = self.H.get_bound_coeff(i, 2 * j_prime); + let h_1 = self.H.get_bound_coeff(i, 2 * j_prime + 1); let b = h_1 - h_0; - // Phase-2 optimization: H is pre-scaled by rho_i = gamma^i, so gamma^{2i} // factors are already accounted for: // gamma^{2i}*h0*(h0-1) = (rho*h0) * (rho*h0 - rho) @@ -438,76 +521,29 @@ impl BooleanitySumcheckProver { F::reduce_product_accum(acc_e), ] }); - // previous_claim is s(0)+s(1) of the scaled polynomial; divide out eq_r_r to get inner claim let adjusted_claim = previous_claim * self.eq_r_r.inverse().unwrap(); let gruen_poly = - D.gruen_poly_deg_3(quadratic_coeffs[0], quadratic_coeffs[1], adjusted_claim); - + self.D + .gruen_poly_deg_3(quadratic_coeffs[0], quadratic_coeffs[1], adjusted_claim); gruen_poly * self.eq_r_r } } -impl SumcheckInstanceProver for BooleanitySumcheckProver { +impl SumcheckInstanceProver + for BooleanityCycleSumcheckProver +{ fn get_params(&self) -> &dyn SumcheckInstanceParams { - &self.params + &self.cycle_params } - #[tracing::instrument(skip_all, name = "BooleanitySumcheckProver::compute_message")] - fn compute_message(&mut self, round: usize, previous_claim: F) -> UniPoly { - if round < self.params.log_k_chunk { - self.compute_phase1_message(round, previous_claim) - } else { - self.compute_phase2_message(round, previous_claim) - } + fn compute_message(&mut self, _round: usize, previous_claim: F) -> UniPoly { + self.compute_message_impl(previous_claim) } - #[tracing::instrument(skip_all, name = "BooleanitySumcheckProver::ingest_challenge")] - fn ingest_challenge(&mut self, r_j: F::Challenge, round: usize) { - if round < self.params.log_k_chunk { - // Phase 1: Bind B and update F - self.B.bind(r_j); - self.F.update(r_j); - - // Transition to phase 2 - if round == self.params.log_k_chunk - 1 { - self.eq_r_r = self.B.get_current_scalar(); - - // Initialize SharedRaPolynomials with per-poly pre-scaled eq tables (by rho_i) - let F_table = std::mem::take(&mut self.F); - let ra_indices = std::mem::take(&mut self.ra_indices); - let base_eq = F_table.clone_values(); - let num_polys = self.params.polynomial_types.len(); - debug_assert!( - num_polys == self.gamma_powers.len(), - "gamma_powers length mismatch: got {}, expected {}", - self.gamma_powers.len(), - num_polys - ); - let tables: Vec> = (0..num_polys) - .into_par_iter() - .map(|i| { - let rho = self.gamma_powers[i]; - base_eq.iter().map(|v| rho * *v).collect() - }) - .collect(); - self.H = Some(SharedRaPolynomials::new( - tables, - ra_indices, - self.params.one_hot_params.clone(), - )); - - // Drop G arrays - let g = std::mem::take(&mut self.G); - drop_in_background_thread(g); - } - } else { - // Phase 2: Bind D and H - self.D.bind(r_j); - if let Some(ref mut h) = self.H { - h.bind_in_place(r_j, BindingOrder::LowToHigh); - } - } + fn ingest_challenge(&mut self, r_j: F::Challenge, _round: usize) { + self.D.bind(r_j); + self.H.bind_in_place(r_j, BindingOrder::LowToHigh); } fn cache_openings( @@ -515,15 +551,12 @@ impl SumcheckInstanceProver for BooleanitySum accumulator: &mut ProverOpeningAccumulator, sumcheck_challenges: &[F::Challenge], ) { - let opening_point = self.params.normalize_opening_point(sumcheck_challenges); - let H = self.H.as_ref().expect("H should be initialized"); + let full_challenges = self.cycle_params.full_challenges(sumcheck_challenges); + let opening_point = self.params.normalize_opening_point(&full_challenges); // H is scaled by rho_i; unscale so cached openings match the committed polynomials. - let claims: Vec = (0..H.num_polys()) - .map(|i| H.final_sumcheck_claim(i) * self.gamma_powers_inv[i]) + let claims: Vec = (0..self.H.num_polys()) + .map(|i| self.H.final_sumcheck_claim(i) * self.gamma_powers_inv[i]) .collect(); - - // All polynomials share the same opening point (r_address, r_cycle) - // Use a single SumcheckId for all accumulator.append_sparse( self.params.polynomial_types.clone(), SumcheckId::Booleanity, @@ -539,25 +572,85 @@ impl SumcheckInstanceProver for BooleanitySum } } -/// Booleanity Sumcheck Verifier. -pub struct BooleanitySumcheckVerifier { +/// Booleanity address-phase verifier. +pub struct BooleanityAddressSumcheckVerifier { params: BooleanitySumcheckParams, + address_params: BooleanityAddressPhaseParams, } -impl BooleanitySumcheckVerifier { +impl BooleanityAddressSumcheckVerifier { pub fn new(params: BooleanitySumcheckParams) -> Self { - Self { params } + Self { + address_params: BooleanityAddressPhaseParams::new(params.clone()), + params, + } + } + + pub fn into_params(self) -> BooleanitySumcheckParams { + self.params } } impl> - SumcheckInstanceVerifier for BooleanitySumcheckVerifier + SumcheckInstanceVerifier for BooleanityAddressSumcheckVerifier { fn get_params(&self) -> &dyn SumcheckInstanceParams { - &self.params + &self.address_params + } + + fn expected_output_claim(&self, accumulator: &A, _sumcheck_challenges: &[F::Challenge]) -> F { + accumulator + .get_virtual_polynomial_opening( + VirtualPolynomial::BooleanityAddrClaim, + SumcheckId::BooleanityAddressPhase, + ) + .1 + } + + fn cache_openings(&self, accumulator: &mut A, sumcheck_challenges: &[F::Challenge]) { + let mut r_address = sumcheck_challenges.to_vec(); + r_address.reverse(); + accumulator.append_virtual( + VirtualPolynomial::BooleanityAddrClaim, + SumcheckId::BooleanityAddressPhase, + OpeningPoint::::new(r_address), + ); + } +} + +/// Booleanity cycle-phase verifier. +pub struct BooleanityCycleSumcheckVerifier { + params: BooleanitySumcheckParams, + cycle_params: BooleanityCyclePhaseParams, +} + +impl BooleanityCycleSumcheckVerifier { + pub fn new( + params: BooleanitySumcheckParams, + opening_accumulator: &dyn OpeningAccumulator, + ) -> Self { + let (r_address_point, _) = opening_accumulator.get_virtual_polynomial_opening( + VirtualPolynomial::BooleanityAddrClaim, + SumcheckId::BooleanityAddressPhase, + ); + let mut r_address_low_to_high = r_address_point.r; + r_address_low_to_high.reverse(); + Self { + cycle_params: BooleanityCyclePhaseParams::new(params.clone(), r_address_low_to_high), + params, + } + } +} + +impl> + SumcheckInstanceVerifier for BooleanityCycleSumcheckVerifier +{ + fn get_params(&self) -> &dyn SumcheckInstanceParams { + &self.cycle_params } fn expected_output_claim(&self, accumulator: &A, sumcheck_challenges: &[F::Challenge]) -> F { + let full_challenges = self.cycle_params.full_challenges(sumcheck_challenges); let ra_claims: Vec = self .params .polynomial_types @@ -568,24 +661,15 @@ impl> .1 }) .collect(); - - let combined_r: Vec = self - .params - .r_address - .iter() - .cloned() - .rev() - .chain(self.params.r_cycle.iter().cloned().rev()) - .collect(); - - EqPolynomial::::mle(sumcheck_challenges, &combined_r) + EqPolynomial::::mle(&full_challenges, &self.params.combined_r_big_endian()) * zip(&self.params.gamma_powers_square, ra_claims) .map(|(gamma_2i, ra)| (ra.square() - ra) * gamma_2i) .sum::() } fn cache_openings(&self, accumulator: &mut A, sumcheck_challenges: &[F::Challenge]) { - let opening_point = self.params.normalize_opening_point(sumcheck_challenges); + let full_challenges = self.cycle_params.full_challenges(sumcheck_challenges); + let opening_point = self.params.normalize_opening_point(&full_challenges); accumulator.append_sparse( self.params.polynomial_types.clone(), SumcheckId::Booleanity, @@ -593,3 +677,144 @@ impl> ); } } + +#[derive(Allocative, Clone)] +struct BooleanityAddressPhaseParams { + inner: BooleanitySumcheckParams, +} + +impl BooleanityAddressPhaseParams { + fn new(inner: BooleanitySumcheckParams) -> Self { + Self { inner } + } +} + +impl SumcheckInstanceParams for BooleanityAddressPhaseParams { + fn degree(&self) -> usize { + as SumcheckInstanceParams>::degree(&self.inner) + } + + fn num_rounds(&self) -> usize { + self.inner.log_k_chunk + } + + fn input_claim(&self, accumulator: &dyn OpeningAccumulator) -> F { + as SumcheckInstanceParams>::input_claim( + &self.inner, + accumulator, + ) + } + + fn normalize_opening_point(&self, challenges: &[F::Challenge]) -> OpeningPoint { + let mut r = challenges.to_vec(); + r.reverse(); + OpeningPoint::new(r) + } + + #[cfg(feature = "zk")] + fn input_claim_constraint(&self) -> InputClaimConstraint { + as SumcheckInstanceParams>::input_claim_constraint( + &self.inner, + ) + } + + #[cfg(feature = "zk")] + fn input_constraint_challenge_values(&self, accumulator: &dyn OpeningAccumulator) -> Vec { + as SumcheckInstanceParams>::input_constraint_challenge_values( + &self.inner, + accumulator, + ) + } + + #[cfg(feature = "zk")] + fn output_claim_constraint(&self) -> Option { + Some(OutputClaimConstraint::direct(OpeningId::virt( + VirtualPolynomial::BooleanityAddrClaim, + SumcheckId::BooleanityAddressPhase, + ))) + } + + #[cfg(feature = "zk")] + fn output_constraint_challenge_values(&self, _sumcheck_challenges: &[F::Challenge]) -> Vec { + Vec::new() + } +} +#[derive(Allocative, Clone)] +struct BooleanityCyclePhaseParams { + inner: BooleanitySumcheckParams, + r_address_low_to_high: Vec, +} + +impl BooleanityCyclePhaseParams { + fn new(inner: BooleanitySumcheckParams, r_address_low_to_high: Vec) -> Self { + Self { + inner, + r_address_low_to_high, + } + } + + fn full_challenges(&self, cycle_challenges: &[F::Challenge]) -> Vec { + let mut full = self.r_address_low_to_high.clone(); + full.extend_from_slice(cycle_challenges); + full + } +} + +impl SumcheckInstanceParams for BooleanityCyclePhaseParams { + fn degree(&self) -> usize { + as SumcheckInstanceParams>::degree(&self.inner) + } + + fn num_rounds(&self) -> usize { + self.inner.log_t + } + + fn input_claim(&self, accumulator: &dyn OpeningAccumulator) -> F { + accumulator + .get_virtual_polynomial_opening( + VirtualPolynomial::BooleanityAddrClaim, + SumcheckId::BooleanityAddressPhase, + ) + .1 + } + + fn normalize_opening_point(&self, challenges: &[F::Challenge]) -> OpeningPoint { + let full = self.full_challenges(challenges); + as SumcheckInstanceParams>::normalize_opening_point( + &self.inner, + &full, + ) + } + + #[cfg(feature = "zk")] + fn input_claim_constraint(&self) -> InputClaimConstraint { + InputClaimConstraint::direct(OpeningId::virt( + VirtualPolynomial::BooleanityAddrClaim, + SumcheckId::BooleanityAddressPhase, + )) + } + + #[cfg(feature = "zk")] + fn input_constraint_challenge_values( + &self, + _accumulator: &dyn OpeningAccumulator, + ) -> Vec { + Vec::new() + } + + #[cfg(feature = "zk")] + fn output_claim_constraint(&self) -> Option { + as SumcheckInstanceParams>::output_claim_constraint( + &self.inner, + ) + } + + #[cfg(feature = "zk")] + fn output_constraint_challenge_values(&self, sumcheck_challenges: &[F::Challenge]) -> Vec { + let full = self.full_challenges(sumcheck_challenges); + as SumcheckInstanceParams>::output_constraint_challenge_values( + &self.inner, + &full, + ) + } +} diff --git a/jolt-core/src/subprotocols/mod.rs b/jolt-core/src/subprotocols/mod.rs index b0c476e4d..ecd470854 100644 --- a/jolt-core/src/subprotocols/mod.rs +++ b/jolt-core/src/subprotocols/mod.rs @@ -10,7 +10,3 @@ pub mod sumcheck_claim; pub mod sumcheck_prover; pub mod sumcheck_verifier; pub mod univariate_skip; - -pub use booleanity::{ - BooleanitySumcheckParams, BooleanitySumcheckProver, BooleanitySumcheckVerifier, -}; diff --git a/jolt-core/src/utils/errors.rs b/jolt-core/src/utils/errors.rs index 2d9d6640d..c6f2d8b01 100644 --- a/jolt-core/src/utils/errors.rs +++ b/jolt-core/src/utils/errors.rs @@ -28,6 +28,8 @@ pub enum ProofVerifyError { InvalidReadWriteConfig(String), #[error("Invalid one-hot configuration: {0}")] InvalidOneHotConfig(String), + #[error("Invalid bytecode commitment configuration: {0}")] + InvalidBytecodeConfig(String), #[error("Invalid ram_K: got {0}, minimum required {1}")] InvalidRamK(usize, usize), #[error("Invalid trace_length: got {0}, max allowed {1}")] @@ -44,4 +46,6 @@ pub enum ProofVerifyError { ZkFeatureRequired, #[error("BlindFold verification failed: {0}")] BlindFoldError(String), + #[error("Bytecode type mismatch: {0}")] + BytecodeTypeMismatch(String), } diff --git a/jolt-core/src/zkvm/bytecode/chunks.rs b/jolt-core/src/zkvm/bytecode/chunks.rs new file mode 100644 index 000000000..47f820bd5 --- /dev/null +++ b/jolt-core/src/zkvm/bytecode/chunks.rs @@ -0,0 +1,203 @@ +use crate::field::JoltField; +use crate::poly::commitment::dory::DoryGlobals; +use crate::poly::multilinear_polynomial::MultilinearPolynomial; +use crate::utils::thread::unsafe_allocate_zero_vec; +use crate::zkvm::instruction::{ + Flags, InstructionLookup, InterleavedBitsMarker, NUM_CIRCUIT_FLAGS, NUM_INSTRUCTION_FLAGS, +}; +use crate::zkvm::lookup_table::LookupTables; +use common::constants::{REGISTER_COUNT, XLEN}; +use tracer::instruction::Instruction; + +/// Total number of lanes encoded by committed-bytecode rows. +pub const fn total_lanes() -> usize { + 3 * (REGISTER_COUNT as usize) + + 2 + + NUM_CIRCUIT_FLAGS + + NUM_INSTRUCTION_FLAGS + + as strum::EnumCount>::COUNT + + 1 +} + +/// Fixed lane capacity for committed bytecode rows. +pub const COMMITTED_BYTECODE_LANE_CAPACITY: usize = total_lanes().next_power_of_two(); + +#[inline(always)] +pub const fn committed_lanes() -> usize { + COMMITTED_BYTECODE_LANE_CAPACITY +} + +pub const DEFAULT_COMMITTED_BYTECODE_CHUNK_COUNT: usize = 1; + +#[inline] +pub fn validate_committed_bytecode_chunk_count(chunk_count: usize) { + assert!(chunk_count > 0, "bytecode chunk count must be non-zero"); + assert!( + chunk_count.is_power_of_two(), + "bytecode chunk count must be a power of two" + ); +} + +#[inline(always)] +pub fn validate_committed_bytecode_chunking_for_len(bytecode_len: usize, chunk_count: usize) { + validate_committed_bytecode_chunk_count(chunk_count); + assert!( + bytecode_len.is_multiple_of(chunk_count), + "bytecode length ({bytecode_len}) must be divisible by chunk count ({chunk_count})" + ); +} + +#[inline(always)] +pub fn committed_bytecode_chunk_cycle_len(bytecode_len: usize, chunk_count: usize) -> usize { + validate_committed_bytecode_chunking_for_len(bytecode_len, chunk_count); + bytecode_len / chunk_count +} + +#[derive(Clone, Copy, Debug)] +pub struct BytecodeLaneLayout { + pub rs1_start: usize, + pub rs2_start: usize, + pub rd_start: usize, + pub unexp_pc_idx: usize, + pub imm_idx: usize, + pub circuit_start: usize, + pub instr_start: usize, + pub lookup_start: usize, + pub raf_flag_idx: usize, +} + +impl BytecodeLaneLayout { + pub const fn new() -> Self { + let reg_count = REGISTER_COUNT as usize; + let rs1_start = 0usize; + let rs2_start = rs1_start + reg_count; + let rd_start = rs2_start + reg_count; + let unexp_pc_idx = rd_start + reg_count; + let imm_idx = unexp_pc_idx + 1; + let circuit_start = imm_idx + 1; + let instr_start = circuit_start + NUM_CIRCUIT_FLAGS; + let lookup_start = instr_start + NUM_INSTRUCTION_FLAGS; + let raf_flag_idx = lookup_start + as strum::EnumCount>::COUNT; + Self { + rs1_start, + rs2_start, + rd_start, + unexp_pc_idx, + imm_idx, + circuit_start, + instr_start, + lookup_start, + raf_flag_idx, + } + } +} + +pub const BYTECODE_LANE_LAYOUT: BytecodeLaneLayout = BytecodeLaneLayout::new(); + +#[derive(Clone, Copy, Debug)] +pub enum ActiveLaneValue { + One, + Scalar(F), +} + +#[inline(always)] +pub fn for_each_active_lane_value( + instr: &Instruction, + mut visit: impl FnMut(usize, ActiveLaneValue), +) { + let l = BYTECODE_LANE_LAYOUT; + + let normalized = instr.normalize(); + let circuit_flags = ::circuit_flags(instr); + let instr_flags = ::instruction_flags(instr); + let lookup_idx = >::lookup_table(instr) + .map(|t| LookupTables::::enum_index(&t)); + let raf_flag = !InterleavedBitsMarker::is_interleaved_operands(&circuit_flags); + + if let Some(r) = normalized.operands.rs1 { + visit(l.rs1_start + (r as usize), ActiveLaneValue::One); + } + if let Some(r) = normalized.operands.rs2 { + visit(l.rs2_start + (r as usize), ActiveLaneValue::One); + } + if let Some(r) = normalized.operands.rd { + visit(l.rd_start + (r as usize), ActiveLaneValue::One); + } + + let unexpanded_pc = F::from_u64(normalized.address as u64); + if !unexpanded_pc.is_zero() { + visit(l.unexp_pc_idx, ActiveLaneValue::Scalar(unexpanded_pc)); + } + let imm = F::from_i128(normalized.operands.imm); + if !imm.is_zero() { + visit(l.imm_idx, ActiveLaneValue::Scalar(imm)); + } + + for i in 0..NUM_CIRCUIT_FLAGS { + if circuit_flags[i] { + visit(l.circuit_start + i, ActiveLaneValue::One); + } + } + for i in 0..NUM_INSTRUCTION_FLAGS { + if instr_flags[i] { + visit(l.instr_start + i, ActiveLaneValue::One); + } + } + if let Some(t) = lookup_idx { + visit(l.lookup_start + t, ActiveLaneValue::One); + } + if raf_flag { + visit(l.raf_flag_idx, ActiveLaneValue::One); + } +} + +#[tracing::instrument(skip_all, name = "bytecode::build_committed_bytecode_chunk_coeffs")] +pub fn build_committed_bytecode_chunk_coeffs( + instructions: &[Instruction], + chunk_count: usize, +) -> Vec> { + let bytecode_len = instructions.len(); + validate_committed_bytecode_chunking_for_len(bytecode_len, chunk_count); + + let chunk_cycle_len = committed_bytecode_chunk_cycle_len(bytecode_len, chunk_count); + let lane_capacity = committed_lanes(); + let mut chunk_coeffs: Vec> = (0..chunk_count) + .map(|_| unsafe_allocate_zero_vec(lane_capacity * chunk_cycle_len)) + .collect(); + + for (cycle, instr) in instructions.iter().enumerate() { + let cycle_chunk_idx = cycle / chunk_cycle_len; + let chunk_cycle = cycle % chunk_cycle_len; + let coeffs = &mut chunk_coeffs[cycle_chunk_idx]; + + for_each_active_lane_value::(instr, |global_lane, lane_val| { + let idx = DoryGlobals::get_layout().address_cycle_to_index( + global_lane, + chunk_cycle, + lane_capacity, + chunk_cycle_len, + ); + let lane_value = match lane_val { + ActiveLaneValue::One => F::one(), + ActiveLaneValue::Scalar(v) => v, + }; + coeffs[idx] += lane_value; + }); + } + + chunk_coeffs +} + +#[tracing::instrument( + skip_all, + name = "bytecode::build_committed_bytecode_chunk_polynomials" +)] +pub fn build_committed_bytecode_chunk_polynomials( + instructions: &[Instruction], + chunk_count: usize, +) -> Vec> { + build_committed_bytecode_chunk_coeffs::(instructions, chunk_count) + .into_iter() + .map(MultilinearPolynomial::from) + .collect() +} diff --git a/jolt-core/src/zkvm/bytecode/mod.rs b/jolt-core/src/zkvm/bytecode/mod.rs index 31eeb4078..7c302bb1e 100644 --- a/jolt-core/src/zkvm/bytecode/mod.rs +++ b/jolt-core/src/zkvm/bytecode/mod.rs @@ -1,10 +1,83 @@ use ark_serialize::{CanonicalDeserialize, CanonicalSerialize}; use common::constants::{ALIGNMENT_FACTOR_BYTECODE, RAM_START_ADDRESS}; +use rayon::prelude::*; use thiserror::Error; use tracer::instruction::{Cycle, Instruction}; +use crate::poly::commitment::commitment_scheme::CommitmentScheme; +use crate::poly::commitment::dory::{DoryContext, DoryGlobals}; +use crate::utils::math::Math; +use crate::zkvm::bytecode::chunks::{ + build_committed_bytecode_chunk_polynomials, committed_bytecode_chunk_cycle_len, + committed_lanes, validate_committed_bytecode_chunking_for_len, +}; + +pub mod chunks; pub mod read_raf_checking; +#[derive(Clone, Debug, PartialEq, CanonicalSerialize, CanonicalDeserialize)] +pub struct TrustedBytecodeCommitments { + pub commitments: Vec, + pub num_columns: usize, + pub log_k_chunk: u8, + pub bytecode_chunk_count: usize, + pub bytecode_len: usize, + pub bytecode_T: usize, +} + +#[derive(Clone, Debug)] +pub struct TrustedBytecodeHints { + pub hints: Vec, +} + +impl TrustedBytecodeCommitments { + #[tracing::instrument(skip_all, name = "TrustedBytecodeCommitments::derive")] + pub fn derive( + bytecode: &BytecodePreprocessing, + generators: &PCS::ProverSetup, + log_k_chunk: usize, + bytecode_chunk_count: usize, + ) -> (Self, TrustedBytecodeHints) { + let bytecode_len = bytecode.code_size; + validate_committed_bytecode_chunking_for_len(bytecode_len, bytecode_chunk_count); + let bytecode_T = committed_bytecode_chunk_cycle_len(bytecode_len, bytecode_chunk_count); + + let total_vars = bytecode_T.log_2() + committed_lanes().log_2(); + let (bytecode_sigma, _) = DoryGlobals::balanced_sigma_nu(total_vars); + let num_columns = 1usize << bytecode_sigma; + + let bytecode_chunk_polys = build_committed_bytecode_chunk_polynomials::( + &bytecode.bytecode, + bytecode_chunk_count, + ); + let _bytecode_guard = DoryGlobals::initialize_context( + committed_lanes(), + bytecode_T, + DoryContext::UntrustedAdvice, + None, + ); + let (commitments, hints): (Vec<_>, Vec<_>) = bytecode_chunk_polys + .par_iter() + .map(|poly| { + let _ctx = DoryGlobals::with_context(DoryContext::UntrustedAdvice); + PCS::commit(poly, generators) + }) + .unzip(); + + ( + Self { + commitments, + num_columns, + log_k_chunk: log_k_chunk as u8, + bytecode_chunk_count, + bytecode_len, + bytecode_T, + }, + TrustedBytecodeHints { hints }, + ) + } +} + #[derive(Debug, Clone, PartialEq, Eq, Error)] pub enum PreprocessingError { #[error( @@ -83,13 +156,17 @@ impl BytecodePCMapper { let mut indices: Vec> = { // For read-raf tests we simulate bytecode being empty #[cfg(test)] - if bytecode.len() == 1 { - vec![None; 1] - } else { - vec![None; Self::get_index(bytecode.last().unwrap().normalize().address) + 1] + { + if bytecode.len() == 1 { + vec![None; 1] + } else { + vec![None; Self::get_index(bytecode.last().unwrap().normalize().address) + 1] + } } #[cfg(not(test))] - vec![None; Self::get_index(bytecode.last().unwrap().normalize().address) + 1] + { + vec![None; Self::get_index(bytecode.last().unwrap().normalize().address) + 1] + } }; let mut last_pc = 0; // Push the initial noop instruction diff --git a/jolt-core/src/zkvm/bytecode/read_raf_checking.rs b/jolt-core/src/zkvm/bytecode/read_raf_checking.rs index 6422a7b55..2ecbe4937 100644 --- a/jolt-core/src/zkvm/bytecode/read_raf_checking.rs +++ b/jolt-core/src/zkvm/bytecode/read_raf_checking.rs @@ -10,6 +10,7 @@ use crate::subprotocols::blindfold::{ }; use crate::{ field::JoltField, + poly::commitment::commitment_scheme::CommitmentScheme, poly::{ eq_poly::EqPolynomial, identity_poly::IdentityPolynomial, @@ -29,16 +30,20 @@ use crate::{ sumcheck_prover::SumcheckInstanceProver, sumcheck_verifier::{SumcheckInstanceParams, SumcheckInstanceVerifier}, }, - transcripts::Transcript, - utils::{math::Math, small_scalar::SmallScalar, thread::unsafe_allocate_zero_vec}, + transcripts::{KeccakTranscript, Transcript}, + utils::{ + errors::ProofVerifyError, math::Math, small_scalar::SmallScalar, + thread::unsafe_allocate_zero_vec, + }, zkvm::{ bytecode::BytecodePreprocessing, - config::OneHotParams, + config::{OneHotParams, ProgramMode}, instruction::{ CircuitFlags, Flags, InstructionFlags, InstructionLookup, InterleavedBitsMarker, NUM_CIRCUIT_FLAGS, }, lookup_table::{LookupTables, NUM_LOOKUP_TABLES}, + program::ProgramPreprocessing, witness::{CommittedPolynomial, VirtualPolynomial}, }, }; @@ -357,6 +362,14 @@ impl BytecodeReadRafSumcheckProver { fn init_log_t_rounds(&mut self) { let int_poly = self.params.int_poly.final_sumcheck_claim(); + let staged_val_claims: [F; N_STAGES] = self + .params + .val_polys + .iter() + .map(MultilinearPolynomial::final_sumcheck_claim) + .collect::>() + .try_into() + .unwrap(); // We have a separate Val polynomial for each stage // Additionally, for stages 1 and 3 we have an Int polynomial for RAF @@ -385,7 +398,7 @@ impl BytecodeReadRafSumcheckProver { .unwrap(); self.bound_val_evals = Some(bound_val_evals); self.params.bound_val_polys = Some(bound_val_evals); - self.params.bound_int_poly = Some(int_poly); + self.params.staged_val_claims = Some(staged_val_claims); let bound_f_entry = self.f_entry_expected.final_sumcheck_claim(); self.bound_f_entry = Some(bound_f_entry); @@ -697,23 +710,233 @@ impl SumcheckInstanceProver } } +#[derive(Allocative)] +pub struct BytecodeReadRafAddressSumcheckProver { + inner: BytecodeReadRafSumcheckProver, + address_params: BytecodeReadRafAddressPhaseParams, +} + +impl BytecodeReadRafAddressSumcheckProver { + pub fn initialize( + params: BytecodeReadRafSumcheckParams, + trace: Arc>, + bytecode_preprocessing: Arc, + ) -> Self { + let address_params = BytecodeReadRafAddressPhaseParams::new(params.clone()); + Self { + inner: BytecodeReadRafSumcheckProver::initialize(params, trace, bytecode_preprocessing), + address_params, + } + } +} + +impl SumcheckInstanceProver + for BytecodeReadRafAddressSumcheckProver +{ + fn get_params(&self) -> &dyn SumcheckInstanceParams { + &self.address_params + } + + fn degree(&self) -> usize { + self.inner.params.degree() + } + + fn num_rounds(&self) -> usize { + self.inner.params.log_K + } + + fn input_claim(&self, accumulator: &ProverOpeningAccumulator) -> F { + self.inner.params.input_claim(accumulator) + } + + fn compute_message(&mut self, round: usize, previous_claim: F) -> UniPoly { + as SumcheckInstanceProver>::compute_message( + &mut self.inner, + round, + previous_claim, + ) + } + + fn ingest_challenge(&mut self, r_j: F::Challenge, round: usize) { + as SumcheckInstanceProver>::ingest_challenge( + &mut self.inner, + r_j, + round, + ); + } + + fn cache_openings( + &self, + accumulator: &mut ProverOpeningAccumulator, + sumcheck_challenges: &[F::Challenge], + ) { + let mut r_address = sumcheck_challenges.to_vec(); + r_address.reverse(); + let opening_point = OpeningPoint::::new(r_address); + let address_claim: F = self + .inner + .prev_round_claims + .iter() + .zip(self.inner.params.gamma_powers.iter()) + .take(N_STAGES) + .map(|(claim, gamma)| *claim * *gamma) + .sum::() + + self.inner.params.entry_gamma * self.inner.prev_entry_claim; + accumulator.append_virtual( + VirtualPolynomial::BytecodeReadRafAddrClaim, + SumcheckId::BytecodeReadRafAddressPhase, + opening_point.clone(), + address_claim, + ); + if self.inner.params.use_staged_val_claims { + let staged_val_claims = self + .inner + .params + .staged_val_claims + .as_ref() + .expect("staged val claims must be present in committed mode"); + for stage in 0..N_STAGES { + accumulator.append_virtual( + VirtualPolynomial::BytecodeValStage(stage), + SumcheckId::BytecodeReadRafAddressPhase, + opening_point.clone(), + staged_val_claims[stage], + ); + } + } + } + + #[cfg(feature = "allocative")] + fn update_flamegraph(&self, flamegraph: &mut FlameGraphBuilder) { + flamegraph.visit_root(self); + } +} + +#[derive(Allocative)] +pub struct BytecodeReadRafCycleSumcheckProver { + inner: BytecodeReadRafSumcheckProver, + cycle_params: BytecodeReadRafCyclePhaseParams, +} + +impl BytecodeReadRafCycleSumcheckProver { + #[tracing::instrument(skip_all, name = "BytecodeReadRafCycleSumcheckProver::initialize")] + pub fn initialize( + params: BytecodeReadRafSumcheckParams, + trace: Arc>, + bytecode_preprocessing: Arc, + accumulator: &ProverOpeningAccumulator, + ) -> Self { + let (r_address_point, _) = accumulator.get_virtual_polynomial_opening( + VirtualPolynomial::BytecodeReadRafAddrClaim, + SumcheckId::BytecodeReadRafAddressPhase, + ); + let mut r_address_low_to_high = r_address_point.r; + r_address_low_to_high.reverse(); + let cycle_params = + BytecodeReadRafCyclePhaseParams::new(params.clone(), r_address_low_to_high.clone()); + + let mut inner = + BytecodeReadRafSumcheckProver::initialize(params, trace, bytecode_preprocessing); + for (round, r_j) in r_address_low_to_high.iter().cloned().enumerate() { + let _ = as SumcheckInstanceProver< + F, + KeccakTranscript, + >>::compute_message(&mut inner, round, F::zero()); + as SumcheckInstanceProver< + F, + KeccakTranscript, + >>::ingest_challenge(&mut inner, r_j, round); + } + + Self { + inner, + cycle_params, + } + } +} + +impl SumcheckInstanceProver + for BytecodeReadRafCycleSumcheckProver +{ + fn get_params(&self) -> &dyn SumcheckInstanceParams { + &self.cycle_params + } + + fn degree(&self) -> usize { + self.inner.params.degree() + } + + fn num_rounds(&self) -> usize { + self.inner.params.log_T + } + + fn input_claim(&self, accumulator: &ProverOpeningAccumulator) -> F { + accumulator + .get_virtual_polynomial_opening( + VirtualPolynomial::BytecodeReadRafAddrClaim, + SumcheckId::BytecodeReadRafAddressPhase, + ) + .1 + } + + fn compute_message(&mut self, round: usize, previous_claim: F) -> UniPoly { + let log_k = self.inner.params.log_K; + as SumcheckInstanceProver>::compute_message( + &mut self.inner, + round + log_k, + previous_claim, + ) + } + + fn ingest_challenge(&mut self, r_j: F::Challenge, round: usize) { + let log_k = self.inner.params.log_K; + as SumcheckInstanceProver>::ingest_challenge( + &mut self.inner, + r_j, + round + log_k, + ); + } + + fn cache_openings( + &self, + accumulator: &mut ProverOpeningAccumulator, + sumcheck_challenges: &[F::Challenge], + ) { + let mut full_challenges = self.cycle_params.r_address_low_to_high.clone(); + full_challenges.extend_from_slice(sumcheck_challenges); + as SumcheckInstanceProver>::cache_openings( + &self.inner, + accumulator, + &full_challenges, + ); + } + + #[cfg(feature = "allocative")] + fn update_flamegraph(&self, flamegraph: &mut FlameGraphBuilder) { + flamegraph.visit_root(self); + } +} + pub struct BytecodeReadRafSumcheckVerifier { params: BytecodeReadRafSumcheckParams, } impl BytecodeReadRafSumcheckVerifier { - pub fn gen>( - bytecode_preprocessing: &BytecodePreprocessing, + pub fn gen( + program: &ProgramPreprocessing, n_cycle_vars: usize, one_hot_params: &OneHotParams, - opening_accumulator: &A, + use_staged_val_claims: bool, + opening_accumulator: &dyn OpeningAccumulator, transcript: &mut impl Transcript, ) -> Self { Self { params: BytecodeReadRafSumcheckParams::gen( - bytecode_preprocessing, + Some(program), n_cycle_vars, one_hot_params, + use_staged_val_claims, + None, opening_accumulator, transcript, ), @@ -754,29 +977,43 @@ impl> // Stage 5: gamma^4 * (Val_5) // Which matches with the input claim: // rv_1 + gamma * rv_2 + gamma^2 * rv_3 + gamma^3 * rv_4 + gamma^4 * rv_5 + gamma^5 * raf_1 + gamma^6 * raf_3 + let stage_val_claim = |stage: usize| { + if self.params.use_staged_val_claims { + accumulator + .get_virtual_polynomial_opening( + VirtualPolynomial::BytecodeValStage(stage), + SumcheckId::BytecodeReadRafAddressPhase, + ) + .1 + } else { + self.params.val_polys[stage].evaluate(&r_address_prime.r) + } + }; + // Always add RAF-int terms here. In committed mode, staged Stage 6a BytecodeValStage + // openings carry val-only claims and the RAF contribution is reconstructed at verification. + let int_poly_contrib_by_stage = [ + int_poly * self.params.gamma_powers[5], // RAF for Stage1 + F::zero(), // There's no raf for Stage2 + int_poly * self.params.gamma_powers[4], // RAF for Stage3 + F::zero(), // There's no raf for Stage4 + F::zero(), // There's no raf for Stage5 + ]; + let val = self .params - .val_polys + .r_cycles .iter() - .zip(&self.params.r_cycles) .zip(&self.params.gamma_powers) - .zip([ - int_poly * self.params.gamma_powers[5], // RAF for Stage1 - F::zero(), // There's no raf for Stage2 - int_poly * self.params.gamma_powers[4], // RAF for Stage3 - F::zero(), // There's no raf for Stage4 - F::zero(), // There's no raf for Stage5 - ]) - .map(|(((val, r_cycle), gamma), int_poly)| { - (val.evaluate(&r_address_prime.r) + int_poly) + .zip(int_poly_contrib_by_stage) + .enumerate() + .map(|(stage, ((r_cycle, gamma), int_poly))| { + (stage_val_claim(stage) + int_poly) * EqPolynomial::::mle(r_cycle, &r_cycle_prime.r) * gamma }) .sum::(); // Entry constraint: γ_entry · eq(r_addr, entry_bits) · eq_zero(r_cycle). - // r_address_prime.r is MSB-first (after normalize_opening_point reversal), - // so entry_bits must also be MSB-first: entry_bits[j] = (e >> (log_K-1-j)) & 1. let entry_f_at_r_addr = { let log_k = self.params.log_K; let e = self.params.entry_bytecode_index; @@ -818,8 +1055,332 @@ impl> } } +pub struct BytecodeReadRafAddressSumcheckVerifier { + params: BytecodeReadRafSumcheckParams, + address_params: BytecodeReadRafAddressPhaseParams, +} + +impl BytecodeReadRafAddressSumcheckVerifier { + pub fn new( + program: Option<&ProgramPreprocessing>, + n_cycle_vars: usize, + one_hot_params: &OneHotParams, + opening_accumulator: &dyn OpeningAccumulator, + transcript: &mut impl Transcript, + program_mode: ProgramMode, + entry_bytecode_index: usize, + ) -> Result { + let params = match program_mode { + ProgramMode::Committed => BytecodeReadRafSumcheckParams::gen( + None::<&ProgramPreprocessing>, + n_cycle_vars, + one_hot_params, + true, + Some(entry_bytecode_index), + opening_accumulator, + transcript, + ), + ProgramMode::Full => BytecodeReadRafSumcheckParams::gen( + Some(program.ok_or_else(|| { + ProofVerifyError::BytecodeTypeMismatch( + "expected Full program preprocessing, got Committed".to_string(), + ) + })?), + n_cycle_vars, + one_hot_params, + false, + None, + opening_accumulator, + transcript, + ), + }; + let address_params = BytecodeReadRafAddressPhaseParams::new(params.clone()); + Ok(Self { + params, + address_params, + }) + } + + pub fn into_params(self) -> BytecodeReadRafSumcheckParams { + self.params + } +} + +impl> + SumcheckInstanceVerifier for BytecodeReadRafAddressSumcheckVerifier +{ + fn get_params(&self) -> &dyn SumcheckInstanceParams { + &self.address_params + } + + fn degree(&self) -> usize { + self.params.degree() + } + + fn num_rounds(&self) -> usize { + self.params.log_K + } + + fn input_claim(&self, accumulator: &A) -> F { + self.params.input_claim(accumulator) + } + + fn expected_output_claim(&self, accumulator: &A, _sumcheck_challenges: &[F::Challenge]) -> F { + accumulator + .get_virtual_polynomial_opening( + VirtualPolynomial::BytecodeReadRafAddrClaim, + SumcheckId::BytecodeReadRafAddressPhase, + ) + .1 + } + + fn cache_openings(&self, accumulator: &mut A, sumcheck_challenges: &[F::Challenge]) { + let mut r_address = sumcheck_challenges.to_vec(); + r_address.reverse(); + accumulator.append_virtual( + VirtualPolynomial::BytecodeReadRafAddrClaim, + SumcheckId::BytecodeReadRafAddressPhase, + OpeningPoint::::new(r_address.clone()), + ); + if self.params.use_staged_val_claims { + for stage in 0..N_STAGES { + accumulator.append_virtual( + VirtualPolynomial::BytecodeValStage(stage), + SumcheckId::BytecodeReadRafAddressPhase, + OpeningPoint::::new(r_address.clone()), + ); + } + } + } +} + +pub struct BytecodeReadRafCycleSumcheckVerifier { + params: BytecodeReadRafSumcheckParams, + cycle_params: BytecodeReadRafCyclePhaseParams, +} + +impl BytecodeReadRafCycleSumcheckVerifier { + pub fn new( + params: BytecodeReadRafSumcheckParams, + opening_accumulator: &dyn OpeningAccumulator, + ) -> Self { + let (r_address_point, _) = opening_accumulator.get_virtual_polynomial_opening( + VirtualPolynomial::BytecodeReadRafAddrClaim, + SumcheckId::BytecodeReadRafAddressPhase, + ); + let mut r_address_low_to_high = r_address_point.r; + r_address_low_to_high.reverse(); + let cycle_params = + BytecodeReadRafCyclePhaseParams::new(params.clone(), r_address_low_to_high); + Self { + params, + cycle_params, + } + } +} + +impl> + SumcheckInstanceVerifier for BytecodeReadRafCycleSumcheckVerifier +{ + fn get_params(&self) -> &dyn SumcheckInstanceParams { + &self.cycle_params + } + + fn degree(&self) -> usize { + self.params.degree() + } + + fn num_rounds(&self) -> usize { + self.params.log_T + } + + fn input_claim(&self, accumulator: &A) -> F { + accumulator + .get_virtual_polynomial_opening( + VirtualPolynomial::BytecodeReadRafAddrClaim, + SumcheckId::BytecodeReadRafAddressPhase, + ) + .1 + } + + fn expected_output_claim(&self, accumulator: &A, sumcheck_challenges: &[F::Challenge]) -> F { + let mut full_challenges = self.cycle_params.r_address_low_to_high.clone(); + full_challenges.extend_from_slice(sumcheck_challenges); + + let inner = BytecodeReadRafSumcheckVerifier { + params: self.params.clone(), + }; + as SumcheckInstanceVerifier>::expected_output_claim( + &inner, + accumulator, + &full_challenges, + ) + } + + fn cache_openings(&self, accumulator: &mut A, sumcheck_challenges: &[F::Challenge]) { + let mut full_challenges = self.cycle_params.r_address_low_to_high.clone(); + full_challenges.extend_from_slice(sumcheck_challenges); + + let inner = BytecodeReadRafSumcheckVerifier { + params: self.params.clone(), + }; + as SumcheckInstanceVerifier>::cache_openings( + &inner, + accumulator, + &full_challenges, + ); + } +} + +#[derive(Allocative, Clone)] +struct BytecodeReadRafCyclePhaseParams { + inner: BytecodeReadRafSumcheckParams, + r_address_low_to_high: Vec, +} + +impl BytecodeReadRafCyclePhaseParams { + fn new( + inner: BytecodeReadRafSumcheckParams, + r_address_low_to_high: Vec, + ) -> Self { + Self { + inner, + r_address_low_to_high, + } + } + + fn full_challenges(&self, cycle_challenges: &[F::Challenge]) -> Vec { + let mut full = self.r_address_low_to_high.clone(); + full.extend_from_slice(cycle_challenges); + full + } +} + +impl SumcheckInstanceParams for BytecodeReadRafCyclePhaseParams { + fn degree(&self) -> usize { + as SumcheckInstanceParams>::degree(&self.inner) + } + + fn num_rounds(&self) -> usize { + self.inner.log_T + } + + fn input_claim(&self, accumulator: &dyn OpeningAccumulator) -> F { + accumulator + .get_virtual_polynomial_opening( + VirtualPolynomial::BytecodeReadRafAddrClaim, + SumcheckId::BytecodeReadRafAddressPhase, + ) + .1 + } + + fn normalize_opening_point(&self, challenges: &[F::Challenge]) -> OpeningPoint { + let full = self.full_challenges(challenges); + as SumcheckInstanceParams>::normalize_opening_point( + &self.inner, + &full, + ) + } + + #[cfg(feature = "zk")] + fn input_claim_constraint(&self) -> InputClaimConstraint { + InputClaimConstraint::direct(OpeningId::virt( + VirtualPolynomial::BytecodeReadRafAddrClaim, + SumcheckId::BytecodeReadRafAddressPhase, + )) + } + + #[cfg(feature = "zk")] + fn input_constraint_challenge_values( + &self, + _accumulator: &dyn OpeningAccumulator, + ) -> Vec { + Vec::new() + } + + #[cfg(feature = "zk")] + fn output_claim_constraint(&self) -> Option { + as SumcheckInstanceParams>::output_claim_constraint( + &self.inner, + ) + } + + #[cfg(feature = "zk")] + fn output_constraint_challenge_values(&self, sumcheck_challenges: &[F::Challenge]) -> Vec { + let full = self.full_challenges(sumcheck_challenges); + as SumcheckInstanceParams>::output_constraint_challenge_values( + &self.inner, + &full, + ) + } +} + +#[derive(Allocative, Clone)] +struct BytecodeReadRafAddressPhaseParams { + inner: BytecodeReadRafSumcheckParams, +} + +impl BytecodeReadRafAddressPhaseParams { + fn new(inner: BytecodeReadRafSumcheckParams) -> Self { + Self { inner } + } +} + +impl SumcheckInstanceParams for BytecodeReadRafAddressPhaseParams { + fn degree(&self) -> usize { + as SumcheckInstanceParams>::degree(&self.inner) + } + + fn num_rounds(&self) -> usize { + self.inner.log_K + } + + fn input_claim(&self, accumulator: &dyn OpeningAccumulator) -> F { + as SumcheckInstanceParams>::input_claim( + &self.inner, + accumulator, + ) + } + + fn normalize_opening_point(&self, challenges: &[F::Challenge]) -> OpeningPoint { + let mut r = challenges.to_vec(); + r.reverse(); + OpeningPoint::new(r) + } + + #[cfg(feature = "zk")] + fn input_claim_constraint(&self) -> InputClaimConstraint { + as SumcheckInstanceParams>::input_claim_constraint( + &self.inner, + ) + } + + #[cfg(feature = "zk")] + fn input_constraint_challenge_values(&self, accumulator: &dyn OpeningAccumulator) -> Vec { + as SumcheckInstanceParams>::input_constraint_challenge_values( + &self.inner, + accumulator, + ) + } + + #[cfg(feature = "zk")] + fn output_claim_constraint(&self) -> Option { + Some(OutputClaimConstraint::direct(OpeningId::virt( + VirtualPolynomial::BytecodeReadRafAddrClaim, + SumcheckId::BytecodeReadRafAddressPhase, + ))) + } + + #[cfg(feature = "zk")] + fn output_constraint_challenge_values(&self, _sumcheck_challenges: &[F::Challenge]) -> Vec { + Vec::new() + } +} + #[derive(Allocative, Clone)] pub struct BytecodeReadRafSumcheckParams { + /// Whether Stage 6a should stage per-stage Val claims for BytecodeClaimReduction. + pub use_staged_val_claims: bool, /// Index `i` stores `gamma^i`. pub gamma_powers: Vec, /// Stage-specific gamma powers for input_claim_constraint @@ -849,8 +1410,9 @@ pub struct BytecodeReadRafSumcheckParams { pub int_poly: IdentityPolynomial, pub r_cycles: [Vec; N_STAGES], /// Bound values after log_K rounds (set by prover for output_constraint_challenge_values) - pub bound_int_poly: Option, pub bound_val_polys: Option<[F; N_STAGES]>, + /// Val-only claims cached after Stage 6a address binding in committed mode. + pub staged_val_claims: Option<[F; N_STAGES]>, /// γ_entry = gamma_powers[7]. Weights the entry-point constraint term. pub entry_gamma: F, /// Bytecode table index of the ELF entry point. @@ -861,17 +1423,17 @@ pub struct BytecodeReadRafSumcheckParams { impl BytecodeReadRafSumcheckParams { #[tracing::instrument(skip_all, name = "BytecodeReadRafSumcheckParams::gen")] - pub fn gen( - bytecode_preprocessing: &BytecodePreprocessing, + pub fn gen( + program: Option<&ProgramPreprocessing>, n_cycle_vars: usize, one_hot_params: &OneHotParams, + use_staged_val_claims: bool, + entry_bytecode_index: Option, opening_accumulator: &dyn OpeningAccumulator, transcript: &mut impl Transcript, ) -> Self { let gamma_powers = transcript.challenge_scalar_powers(8); - let bytecode = &bytecode_preprocessing.bytecode; - // Generate all stage-specific gamma powers upfront (order must match verifier) let stage1_gammas: Vec = transcript.challenge_scalar_powers(2 + NUM_CIRCUIT_FLAGS); let stage2_gammas: Vec = transcript.challenge_scalar_powers(4); @@ -887,38 +1449,43 @@ impl BytecodeReadRafSumcheckParams { let rv_claim_5 = Self::compute_rv_claim_5(opening_accumulator, &stage5_gammas); let rv_claims = [rv_claim_1, rv_claim_2, rv_claim_3, rv_claim_4, rv_claim_5]; - // Pre-compute eq_r_register for stages 4 and 5 (they use different r_register points) - let r_register_4 = opening_accumulator - .get_virtual_polynomial_opening( - VirtualPolynomial::RdWa, - SumcheckId::RegistersReadWriteChecking, - ) - .0 - .r; - let eq_r_register_4 = - EqPolynomial::::evals(&r_register_4[..(REGISTER_COUNT as usize).log_2()]); - - let r_register_5 = opening_accumulator - .get_virtual_polynomial_opening( - VirtualPolynomial::RdWa, - SumcheckId::RegistersValEvaluation, + // Fused pass: compute all val polynomials in a single parallel iteration in Full mode. + let val_polys = if let Some(program) = program.and_then(|program| program.as_full().ok()) { + let r_register_4 = opening_accumulator + .get_virtual_polynomial_opening( + VirtualPolynomial::RdWa, + SumcheckId::RegistersReadWriteChecking, + ) + .0 + .r; + let eq_r_register_4 = + EqPolynomial::::evals(&r_register_4[..(REGISTER_COUNT as usize).log_2()]); + + let r_register_5 = opening_accumulator + .get_virtual_polynomial_opening( + VirtualPolynomial::RdWa, + SumcheckId::RegistersValEvaluation, + ) + .0 + .r; + let eq_r_register_5 = + EqPolynomial::::evals(&r_register_5[..(REGISTER_COUNT as usize).log_2()]); + + Self::compute_val_polys( + &program.bytecode.bytecode, + &eq_r_register_4, + &eq_r_register_5, + &stage1_gammas, + &stage2_gammas, + &stage3_gammas, + &stage4_gammas, + &stage5_gammas, ) - .0 - .r; - let eq_r_register_5 = - EqPolynomial::::evals(&r_register_5[..(REGISTER_COUNT as usize).log_2()]); - - // Fused pass: compute all val polynomials in a single parallel iteration - let val_polys = Self::compute_val_polys( - bytecode, - &eq_r_register_4, - &eq_r_register_5, - &stage1_gammas, - &stage2_gammas, - &stage3_gammas, - &stage4_gammas, - &stage5_gammas, - ); + } else { + array::from_fn(|_| { + MultilinearPolynomial::from(vec![F::zero(); one_hot_params.bytecode_k]) + }) + }; let int_poly = IdentityPolynomial::new(one_hot_params.bytecode_k.log_2()); @@ -927,11 +1494,10 @@ impl BytecodeReadRafSumcheckParams { let (_, raf_shift_claim) = opening_accumulator .get_virtual_polynomial_opening(VirtualPolynomial::PC, SumcheckId::SpartanShift); let entry_gamma = gamma_powers[7]; - let entry_bytecode_index = bytecode_preprocessing.entry_bytecode_index(); - // Both prover and verifier add entry_gamma unconditionally. - // The security comes from the sumcheck: if ra(entry_index, 0) != 1, the sum - // won't match input_claim and the sumcheck fails. - let input_claim: F = [ + let entry_bytecode_index = entry_bytecode_index + .or_else(|| program.map(|program| program.entry_bytecode_index())) + .unwrap_or_default(); + let mut input_claim: F = [ rv_claim_1, rv_claim_2, rv_claim_3, @@ -943,8 +1509,8 @@ impl BytecodeReadRafSumcheckParams { .iter() .zip(&gamma_powers) .map(|(claim, g)| *claim * g) - .sum::() - + entry_gamma; + .sum::(); + input_claim += entry_gamma; let (r_cycle_1, _) = opening_accumulator .get_virtual_polynomial_opening(VirtualPolynomial::Imm, SumcheckId::SpartanOuter); @@ -975,6 +1541,7 @@ impl BytecodeReadRafSumcheckParams { ]; Self { + use_staged_val_claims, gamma_powers, entry_gamma, entry_bytecode_index, @@ -995,8 +1562,8 @@ impl BytecodeReadRafSumcheckParams { raf_shift_claim, int_poly, r_cycles, - bound_int_poly: None, bound_val_polys: None, + staged_val_claims: None, bound_f_entry: None, } } @@ -1685,19 +2252,38 @@ impl SumcheckInstanceParams for BytecodeReadRafSumcheckParams Option { - let factors: Vec = (0..self.d) + let ra_factors: Vec = (0..self.d) .map(|i| { - let opening = OpeningId::committed( + ValueSource::Opening(OpeningId::committed( CommittedPolynomial::BytecodeRa(i), SumcheckId::BytecodeReadRaf, - ); - ValueSource::Opening(opening) + )) }) .collect(); - let terms = vec![ProductTerm::scaled(ValueSource::Challenge(0), factors)]; - - Some(OutputClaimConstraint::sum_of_products(terms)) + if self.use_staged_val_claims { + // In committed mode, verifier does not materialize stage Val polynomials. + // Encode output as: + // ra_prod * (Σ_stage coeff_stage * ValStage(stage) + const_term) + // where coeff_stage / const_term are public challenge values. + let mut terms = Vec::with_capacity(N_STAGES + 1); + for stage in 0..N_STAGES { + let mut factors = ra_factors.clone(); + factors.push(ValueSource::Opening(OpeningId::virt( + VirtualPolynomial::BytecodeValStage(stage), + SumcheckId::BytecodeReadRafAddressPhase, + ))); + terms.push(ProductTerm::scaled(ValueSource::Challenge(stage), factors)); + } + terms.push(ProductTerm::scaled( + ValueSource::Challenge(N_STAGES), + ra_factors, + )); + Some(OutputClaimConstraint::sum_of_products(terms)) + } else { + let terms = vec![ProductTerm::scaled(ValueSource::Challenge(0), ra_factors)]; + Some(OutputClaimConstraint::sum_of_products(terms)) + } } #[cfg(feature = "zk")] @@ -1705,7 +2291,46 @@ impl SumcheckInstanceParams for BytecodeReadRafSumcheckParams = self + .r_cycles + .iter() + .map(|r_cycle| EqPolynomial::::mle(r_cycle, &r_cycle_prime.r)) + .collect(); + + let mut coeffs: Vec = (0..N_STAGES) + .map(|stage| self.gamma_powers[stage] * eq_cycles[stage]) + .collect(); + + let int_poly_contrib_by_stage = [ + int_poly * self.gamma_powers[5], // RAF for Stage1 + F::zero(), + int_poly * self.gamma_powers[4], // RAF for Stage3 + F::zero(), + F::zero(), + ]; + let int_contrib: F = (0..N_STAGES) + .map(|stage| { + int_poly_contrib_by_stage[stage] * eq_cycles[stage] * self.gamma_powers[stage] + }) + .sum(); + + let log_k = self.log_K; + let e = self.entry_bytecode_index; + let entry_bits: Vec = (0..log_k) + .map(|i| F::from_u64(((e >> (log_k - 1 - i)) & 1) as u64)) + .collect(); + let f_entry_at_r_addr = EqPolynomial::::mle(&entry_bits, &r_address_prime.r); + let zeros: Vec = vec![F::Challenge::default(); r_cycle_prime.r.len()]; + let eq_zero_at_r_cycle = EqPolynomial::::mle(&zeros, &r_cycle_prime.r); + let entry_contrib = self.entry_gamma * f_entry_at_r_addr * eq_zero_at_r_cycle; + + coeffs.push(int_contrib + entry_contrib); + return coeffs; + } + + // Prover stores bound values before clearing polys; verifier evaluates directly. let val: F = if let Some(bound_val_polys) = &self.bound_val_polys { bound_val_polys .iter() @@ -1736,8 +2361,6 @@ impl SumcheckInstanceParams for BytecodeReadRafSumcheckParams> (log_K-1-j)) & 1. let f_entry_at_r_addr = if let Some(v) = self.bound_f_entry { v } else { diff --git a/jolt-core/src/zkvm/claim_reductions/advice.rs b/jolt-core/src/zkvm/claim_reductions/advice.rs index 4ecd095dc..3c02bd486 100644 --- a/jolt-core/src/zkvm/claim_reductions/advice.rs +++ b/jolt-core/src/zkvm/claim_reductions/advice.rs @@ -1,42 +1,11 @@ -//! Two-phase advice claim reduction (Stage 6 cycle → Stage 7 address) -//! -//! This module generalizes the previous single-phase `AdviceClaimReduction` so that trusted and -//! untrusted advice can be committed as an arbitrary Dory matrix `2^{nu_a} x 2^{sigma_a}` (balanced -//! by default), while still keeping a **single Stage 8 Dory opening** at the unified Dory point. -//! -//! For an advice matrix embedded as the **top-left block** `2^{nu_a} x 2^{sigma_a}`, the *native* -//! advice evaluation point (in Dory order, LSB-first) is: -//! - `advice_cols = col_coords[0..sigma_a]` -//! - `advice_rows = row_coords[0..nu_a]` -//! - `advice_point = [advice_cols || advice_rows]` -//! -//! In our current pipeline, `cycle` coordinates come from Stage 6 and `addr` coordinates come from -//! Stage 7. -//! - **Phase 1 (Stage 6)**: bind the cycle-derived advice coordinates and output an intermediate -//! scalar claim `C_mid`. -//! - **Phase 2 (Stage 7)**: resume from `C_mid`, bind the address-derived advice coordinates, and -//! cache the final advice opening `AdviceMLE(advice_point)` for batching into Stage 8. -//! -//! ## Dummy-gap scaling (within Stage 6) -//! With cycle-major order, there may be a gap during the cycle phase where the cycle variables -//! being bound in the batched sumcheck do not appear in the advice polynommial. -//! -//! We handle this without modifying the generic batched sumcheck by treating those intervening -//! rounds as **dummy internal rounds** (constant univariates), and maintaining a running scaling -//! factor `2^{-dummy_done}` so the per-round univariates remain consistent. -//! -//! Trusted and untrusted advice run as **separate** sumcheck instances (each may have different -//! dimensions). -//! +//! Two-phase advice claim reduction (Stage 6 cycle -> Stage 7 address). use std::cell::RefCell; -use std::cmp::{min, Ordering}; -use std::ops::Range; use crate::field::JoltField; -use crate::poly::commitment::dory::{DoryGlobals, DoryLayout}; +use crate::poly::commitment::dory::DoryGlobals; use crate::poly::eq_poly::EqPolynomial; -use crate::poly::multilinear_polynomial::{BindingOrder, MultilinearPolynomial, PolynomialBinding}; +use crate::poly::multilinear_polynomial::MultilinearPolynomial; #[cfg(feature = "zk")] use crate::poly::opening_proof::OpeningId; use crate::poly::opening_proof::{ @@ -49,13 +18,12 @@ use crate::subprotocols::blindfold::{InputClaimConstraint, OutputClaimConstraint use crate::subprotocols::sumcheck_prover::SumcheckInstanceProver; use crate::subprotocols::sumcheck_verifier::{SumcheckInstanceParams, SumcheckInstanceVerifier}; use crate::transcripts::Transcript; -use crate::utils::math::Math; -use crate::zkvm::config::OneHotConfig; +use crate::zkvm::claim_reductions::{ + permute_precommitted_polys, precommitted_eq_evals_with_scaling, precommitted_skip_round_scale, + PrecomittedParams, PrecomittedProver, PrecommittedClaimReduction, PrecommittedPhase, + PrecommittedSchedulingReference, TWO_PHASE_DEGREE_BOUND, +}; use allocative::Allocative; -use common::jolt_device::MemoryLayout; -use rayon::prelude::*; - -const DEGREE_BOUND: usize = 2; #[derive(Clone, Copy, Debug, PartialEq, Eq, Allocative)] pub enum AdviceKind { @@ -63,134 +31,69 @@ pub enum AdviceKind { Untrusted, } -#[derive(Debug, Clone, Allocative, PartialEq, Eq)] -pub enum ReductionPhase { - CycleVariables, - AddressVariables, -} - #[derive(Clone, Allocative)] pub struct AdviceClaimReductionParams { pub kind: AdviceKind, - pub phase: ReductionPhase, - pub log_k_chunk: usize, - pub log_t: usize, + pub phase: PrecommittedPhase, + pub precommitted: PrecommittedClaimReduction, pub advice_col_vars: usize, pub advice_row_vars: usize, - /// Number of column variables in the main Dory matrix - pub main_col_vars: usize, - /// Number of row variables in the main Dory matrix - pub main_row_vars: usize, - #[allocative(skip)] - pub cycle_phase_row_rounds: Range, - #[allocative(skip)] - pub cycle_phase_col_rounds: Range, pub r_val: OpeningPoint, - /// (little-endian) challenges for the cycle phase variables - pub cycle_var_challenges: Vec, -} - -fn cycle_phase_round_schedule( - log_T: usize, - log_k_chunk: usize, - main_col_vars: usize, - advice_row_vars: usize, - advice_col_vars: usize, -) -> (Range, Range) { - match DoryGlobals::get_layout() { - DoryLayout::CycleMajor => { - // Low-order cycle variables correspond to the low-order bits of the - // column index - let col_binding_rounds = 0..min(log_T, advice_col_vars); - // High-order cycle variables correspond to the low-order bits of the - // rows index - let row_binding_rounds = - min(log_T, main_col_vars)..min(log_T, main_col_vars + advice_row_vars); - (col_binding_rounds, row_binding_rounds) - } - DoryLayout::AddressMajor => { - // Low-order cycle variables correspond to the high-order bits of the - // column index - let col_binding_rounds = 0..advice_col_vars.saturating_sub(log_k_chunk); - // High-order cycle variables correspond to the bits of the row index - let row_binding_rounds = main_col_vars.saturating_sub(log_k_chunk) - ..min( - log_T, - main_col_vars.saturating_sub(log_k_chunk) + advice_row_vars, - ); - (col_binding_rounds, row_binding_rounds) - } - } } impl AdviceClaimReductionParams { pub fn new( kind: AdviceKind, - memory_layout: &MemoryLayout, - trace_len: usize, + advice_size_bytes: usize, + scheduling_reference: PrecommittedSchedulingReference, accumulator: &dyn OpeningAccumulator, ) -> Self { - let max_advice_size_bytes = match kind { - AdviceKind::Trusted => memory_layout.max_trusted_advice_size as usize, - AdviceKind::Untrusted => memory_layout.max_untrusted_advice_size as usize, - }; - - let log_t = trace_len.log_2(); - let log_k_chunk = OneHotConfig::new(log_t).log_k_chunk as usize; - let (main_col_vars, main_row_vars) = DoryGlobals::main_sigma_nu(log_k_chunk, log_t); - let r_val = accumulator .get_advice_opening(kind, SumcheckId::RamValCheck) .map(|(p, _)| p) .unwrap(); let (advice_col_vars, advice_row_vars) = - DoryGlobals::advice_sigma_nu_from_max_bytes(max_advice_size_bytes); - let (col_binding_rounds, row_binding_rounds) = cycle_phase_round_schedule( - log_t, - log_k_chunk, - main_col_vars, - advice_row_vars, - advice_col_vars, - ); + DoryGlobals::advice_sigma_nu_from_max_bytes(advice_size_bytes); + let precommitted = + PrecommittedClaimReduction::new(advice_row_vars, advice_col_vars, scheduling_reference); Self { kind, - phase: ReductionPhase::CycleVariables, + phase: PrecommittedPhase::CycleVariables, + precommitted, advice_col_vars, advice_row_vars, - log_k_chunk, - log_t, - main_col_vars, - main_row_vars, - cycle_phase_row_rounds: row_binding_rounds, - cycle_phase_col_rounds: col_binding_rounds, r_val, - cycle_var_challenges: vec![], } } - /// (Total # advice variables) - (# variables bound during cycle phase) pub fn num_address_phase_rounds(&self) -> usize { - (self.advice_col_vars + self.advice_row_vars) - - (self.cycle_phase_col_rounds.len() + self.cycle_phase_row_rounds.len()) + self.precommitted.num_address_phase_rounds() + } + + pub fn transition_to_address_phase(&mut self) { + self.phase = PrecommittedPhase::AddressVariables; + } + + pub fn round_offset(&self, max_num_rounds: usize) -> usize { + self.precommitted.round_offset( + self.phase == PrecommittedPhase::CycleVariables, + max_num_rounds, + ) } } impl SumcheckInstanceParams for AdviceClaimReductionParams { fn input_claim(&self, accumulator: &dyn OpeningAccumulator) -> F { match self.phase { - ReductionPhase::CycleVariables => { - let mut claim = F::zero(); - if let Some((_, eval)) = - accumulator.get_advice_opening(self.kind, SumcheckId::RamValCheck) - { - claim += eval; - } - claim + PrecommittedPhase::CycleVariables => { + accumulator + .get_advice_opening(self.kind, SumcheckId::RamValCheck) + .expect("RamValCheck advice opening missing") + .1 } - ReductionPhase::AddressVariables => { - // Address phase starts from the cycle phase intermediate claim. + PrecommittedPhase::AddressVariables => { accumulator .get_advice_opening(self.kind, SumcheckId::AdviceClaimReductionCyclePhase) .expect("Cycle phase intermediate claim not found") @@ -200,77 +103,36 @@ impl SumcheckInstanceParams for AdviceClaimReductionParams { } fn degree(&self) -> usize { - DEGREE_BOUND + TWO_PHASE_DEGREE_BOUND } fn num_rounds(&self) -> usize { - match self.phase { - ReductionPhase::CycleVariables => { - if !self.cycle_phase_row_rounds.is_empty() { - self.cycle_phase_row_rounds.end - self.cycle_phase_col_rounds.start - } else { - self.cycle_phase_col_rounds.len() - } - } - ReductionPhase::AddressVariables => { - let first_phase_rounds = - self.cycle_phase_row_rounds.len() + self.cycle_phase_col_rounds.len(); - // Total advice variables, minus the variables bound during the cycle phase - (self.advice_col_vars + self.advice_row_vars) - first_phase_rounds - } - } + self.precommitted + .num_rounds_for_phase(self.phase == PrecommittedPhase::CycleVariables) } - /// Rearrange the opening point so that it is big-endian with respect to the original, - /// unpermuted advice/EQ polynomials. - fn normalize_opening_point( - &self, - challenges: &[::Challenge], - ) -> OpeningPoint { - if self.phase == ReductionPhase::CycleVariables { - let advice_vars = self.advice_col_vars + self.advice_row_vars; - let mut advice_var_challenges: Vec = Vec::with_capacity(advice_vars); - advice_var_challenges - .extend_from_slice(&challenges[self.cycle_phase_col_rounds.clone()]); - advice_var_challenges - .extend_from_slice(&challenges[self.cycle_phase_row_rounds.clone()]); - return OpeningPoint::::new(advice_var_challenges).match_endianness(); - } - - match DoryGlobals::get_layout() { - DoryLayout::CycleMajor => OpeningPoint::::new( - [self.cycle_var_challenges.as_slice(), challenges].concat(), - ) - .match_endianness(), - DoryLayout::AddressMajor => OpeningPoint::::new( - [challenges, self.cycle_var_challenges.as_slice()].concat(), - ) - .match_endianness(), - } + fn normalize_opening_point(&self, challenges: &[F::Challenge]) -> OpeningPoint { + self.precommitted + .normalize_opening_point(self.phase == PrecommittedPhase::CycleVariables, challenges) } #[cfg(feature = "zk")] fn input_claim_constraint(&self) -> InputClaimConstraint { - match self.phase { - ReductionPhase::CycleVariables => { - let val_opening = match self.kind { - AdviceKind::Trusted => OpeningId::TrustedAdvice(SumcheckId::RamValCheck), - AdviceKind::Untrusted => OpeningId::UntrustedAdvice(SumcheckId::RamValCheck), - }; - InputClaimConstraint::direct(val_opening) - } - ReductionPhase::AddressVariables => { - let cycle_phase_opening = match self.kind { - AdviceKind::Trusted => { - OpeningId::TrustedAdvice(SumcheckId::AdviceClaimReductionCyclePhase) - } - AdviceKind::Untrusted => { - OpeningId::UntrustedAdvice(SumcheckId::AdviceClaimReductionCyclePhase) - } - }; - InputClaimConstraint::direct(cycle_phase_opening) - } - } + let opening = match self.phase { + PrecommittedPhase::CycleVariables => match self.kind { + AdviceKind::Trusted => OpeningId::TrustedAdvice(SumcheckId::RamValCheck), + AdviceKind::Untrusted => OpeningId::UntrustedAdvice(SumcheckId::RamValCheck), + }, + PrecommittedPhase::AddressVariables => match self.kind { + AdviceKind::Trusted => { + OpeningId::TrustedAdvice(SumcheckId::AdviceClaimReductionCyclePhase) + } + AdviceKind::Untrusted => { + OpeningId::UntrustedAdvice(SumcheckId::AdviceClaimReductionCyclePhase) + } + }, + }; + InputClaimConstraint::direct(opening) } #[cfg(feature = "zk")] @@ -281,8 +143,8 @@ impl SumcheckInstanceParams for AdviceClaimReductionParams { #[cfg(feature = "zk")] fn output_claim_constraint(&self) -> Option { match self.phase { - ReductionPhase::CycleVariables => { - let advice_opening = match self.kind { + PrecommittedPhase::CycleVariables => { + let opening = match self.kind { AdviceKind::Trusted => { OpeningId::TrustedAdvice(SumcheckId::AdviceClaimReductionCyclePhase) } @@ -290,10 +152,10 @@ impl SumcheckInstanceParams for AdviceClaimReductionParams { OpeningId::UntrustedAdvice(SumcheckId::AdviceClaimReductionCyclePhase) } }; - Some(OutputClaimConstraint::direct(advice_opening)) + Some(OutputClaimConstraint::direct(opening)) } - ReductionPhase::AddressVariables => { - let advice_opening = match self.kind { + PrecommittedPhase::AddressVariables => { + let opening = match self.kind { AdviceKind::Trusted => { OpeningId::TrustedAdvice(SumcheckId::AdviceClaimReduction) } @@ -301,11 +163,9 @@ impl SumcheckInstanceParams for AdviceClaimReductionParams { OpeningId::UntrustedAdvice(SumcheckId::AdviceClaimReduction) } }; - // output = (eq_combined * scale) * advice_claim - // Challenge(0) holds eq_combined * scale (computed in output_constraint_challenge_values) Some(OutputClaimConstraint::linear(vec![( ValueSource::Challenge(0), - ValueSource::Opening(advice_opening), + ValueSource::Opening(opening), )])) } } @@ -314,193 +174,97 @@ impl SumcheckInstanceParams for AdviceClaimReductionParams { #[cfg(feature = "zk")] fn output_constraint_challenge_values(&self, sumcheck_challenges: &[F::Challenge]) -> Vec { match self.phase { - ReductionPhase::CycleVariables => vec![], - ReductionPhase::AddressVariables => { + PrecommittedPhase::CycleVariables => vec![], + PrecommittedPhase::AddressVariables => { let opening_point = self.normalize_opening_point(sumcheck_challenges); let eq_eval = EqPolynomial::mle(&opening_point.r, &self.r_val.r); - - let gap_len = if self.cycle_phase_row_rounds.is_empty() - || self.cycle_phase_col_rounds.is_empty() - { - 0 - } else { - self.cycle_phase_row_rounds.start - self.cycle_phase_col_rounds.end - }; - let two_inv = F::from_u64(2).inverse().unwrap(); - let scale = (0..gap_len).fold(F::one(), |acc, _| acc * two_inv); - + let scale: F = precommitted_skip_round_scale(&self.precommitted); vec![eq_eval * scale] } } } } +impl PrecomittedParams for AdviceClaimReductionParams { + fn is_cycle_phase(&self) -> bool { + self.phase == PrecommittedPhase::CycleVariables + } + + fn is_cycle_phase_round(&self, round: usize) -> bool { + self.precommitted.is_cycle_phase_round(round) + } + + fn is_address_phase_round(&self, round: usize) -> bool { + self.precommitted.is_address_phase_round(round) + } + + fn cycle_alignment_rounds(&self) -> usize { + self.precommitted.cycle_alignment_rounds() + } + + fn address_alignment_rounds(&self) -> usize { + self.precommitted.address_alignment_rounds() + } + + fn record_cycle_challenge(&mut self, challenge: F::Challenge) { + self.precommitted.record_cycle_challenge(challenge); + } +} + #[derive(Allocative)] pub struct AdviceClaimReductionProver { - pub params: AdviceClaimReductionParams, - advice_poly: MultilinearPolynomial, - eq_poly: MultilinearPolynomial, - /// Maintains the running internal scaling factor 2^{-dummy_done}. - scale: F, + core: PrecomittedProver>, } impl AdviceClaimReductionProver { + pub fn params(&self) -> &AdviceClaimReductionParams { + self.core.params() + } + + pub fn transition_to_address_phase(&mut self) { + self.core.params_mut().transition_to_address_phase(); + } + pub fn initialize( params: AdviceClaimReductionParams, advice_poly: MultilinearPolynomial, ) -> Self { - let eq_evals = EqPolynomial::evals(¶ms.r_val.r); - - let main_cols = 1 << params.main_col_vars; - // Maps a (row, col) position in the Dory matrix layout to its - // implied (address, cycle). - let row_col_to_address_cycle = |row: usize, col: usize| -> (usize, usize) { - match DoryGlobals::get_layout() { - DoryLayout::CycleMajor => { - let global_index = row as u128 * main_cols + col as u128; - let address = global_index / (1 << params.log_t); - let cycle = global_index % (1 << params.log_t); - (address as usize, cycle as usize) - } - DoryLayout::AddressMajor => { - let global_index = row as u128 * main_cols + col as u128; - let address = global_index % (1 << params.log_k_chunk); - let cycle = global_index / (1 << params.log_k_chunk); - (address as usize, cycle as usize) - } - } - }; - - let advice_cols = 1 << params.advice_col_vars; - // Maps an index in the advice vector to its implied (address, cycle), based - // on the position the index maps to in the Dory matrix layout. - let advice_index_to_address_cycle = |index: usize| -> (usize, usize) { - let row = index / advice_cols; - let col = index % advice_cols; - row_col_to_address_cycle(row, col) - }; - - let mut permuted_coeffs: Vec<(usize, (u64, F))> = match advice_poly { - MultilinearPolynomial::U64Scalars(poly) => poly - .coeffs - .into_par_iter() - .zip(eq_evals.into_par_iter()) - .enumerate() - .collect(), - _ => panic!("Advice should have u64 coefficients"), + let eq_evals = + precommitted_eq_evals_with_scaling(¶ms.r_val.r, None, ¶ms.precommitted); + let (advice_poly, eq_poly): (MultilinearPolynomial, MultilinearPolynomial) = { + let MultilinearPolynomial::U64Scalars(poly) = advice_poly else { + panic!("Advice should have u64 coefficients"); + }; + let mut permuted = + permute_precommitted_polys(vec![poly.coeffs], ¶ms.precommitted).into_iter(); + let advice_poly = permuted + .next() + .expect("expected one permuted advice polynomial"); + let eq_poly = eq_evals.into(); + (advice_poly, eq_poly) }; - // Sort the advice and EQ polynomial coefficients by (address, cycle). - // By sorting this way, binding the resulting polynomials in low-to-high - // order is equivalent to binding the original polynomials' "cycle" variables - // low-to-high, then their "address" variables low-to-high. - permuted_coeffs.par_sort_by(|&(index_a, _), &(index_b, _)| { - let (address_a, cycle_a) = advice_index_to_address_cycle(index_a); - let (address_b, cycle_b) = advice_index_to_address_cycle(index_b); - match address_a.cmp(&address_b) { - Ordering::Less => Ordering::Less, - Ordering::Greater => Ordering::Greater, - Ordering::Equal => cycle_a.cmp(&cycle_b), - } - }); - - let (advice_coeffs, eq_coeffs): (Vec<_>, Vec<_>) = permuted_coeffs - .into_par_iter() - .map(|(_, coeffs)| coeffs) - .unzip(); - let advice_poly = advice_coeffs.into(); - let eq_poly = eq_coeffs.into(); Self { - params, - advice_poly, - eq_poly, - scale: F::one(), + core: PrecomittedProver::new(params, advice_poly, eq_poly), } } - - fn compute_message_unscaled(&mut self, previous_claim_unscaled: F) -> UniPoly { - let half = self.advice_poly.len() / 2; - let evals: [F; DEGREE_BOUND] = (0..half) - .into_par_iter() - .map(|j| { - let a_evals = self - .advice_poly - .sumcheck_evals_array::(j, BindingOrder::LowToHigh); - let eq_evals = self - .eq_poly - .sumcheck_evals_array::(j, BindingOrder::LowToHigh); - - let mut out = [F::zero(); DEGREE_BOUND]; - for i in 0..DEGREE_BOUND { - out[i] = a_evals[i] * eq_evals[i]; - } - out - }) - .reduce( - || [F::zero(); DEGREE_BOUND], - |mut acc, arr| { - acc.par_iter_mut() - .zip(arr.par_iter()) - .for_each(|(a, b)| *a += *b); - acc - }, - ); - UniPoly::from_evals_and_hint(previous_claim_unscaled, &evals) - } } impl SumcheckInstanceProver for AdviceClaimReductionProver { fn get_params(&self) -> &dyn SumcheckInstanceParams { - &self.params + self.core.params() + } + + fn round_offset(&self, max_num_rounds: usize) -> usize { + self.core.params().round_offset(max_num_rounds) } fn compute_message(&mut self, round: usize, previous_claim: F) -> UniPoly { - if self.params.phase == ReductionPhase::CycleVariables - && !self.params.cycle_phase_col_rounds.contains(&round) - && !self.params.cycle_phase_row_rounds.contains(&round) - { - // Current sumcheck variable does not appear in advice polynomial, so we - // can simply send a constant polynomial equal to the previous claim divided by 2 - UniPoly::from_coeff(vec![previous_claim * F::from_u64(2).inverse().unwrap()]) - } else { - // Account for (1) internal dummy rounds already traversed and - // (2) trailing dummy rounds after this instance's active window in the batched sumcheck. - let num_trailing_variables = match self.params.phase { - ReductionPhase::CycleVariables => { - self.params.log_t.saturating_sub(self.params.num_rounds()) - } - ReductionPhase::AddressVariables => self - .params - .log_k_chunk - .saturating_sub(self.params.num_rounds()), - }; - let scaling_factor = self.scale * F::one().mul_pow_2(num_trailing_variables); - let prev_unscaled = previous_claim * scaling_factor.inverse().unwrap(); - let poly_unscaled = self.compute_message_unscaled(prev_unscaled); - poly_unscaled * scaling_factor - } + self.core.compute_message(round, previous_claim) } fn ingest_challenge(&mut self, r_j: F::Challenge, round: usize) { - match self.params.phase { - ReductionPhase::CycleVariables => { - if !self.params.cycle_phase_col_rounds.contains(&round) - && !self.params.cycle_phase_row_rounds.contains(&round) - { - // Each dummy internal round halves the running claim; equivalently, we multiply the - // scaling factor by 1/2. - self.scale *= F::from_u64(2).inverse().unwrap(); - } else { - self.advice_poly.bind_parallel(r_j, BindingOrder::LowToHigh); - self.eq_poly.bind_parallel(r_j, BindingOrder::LowToHigh); - self.params.cycle_var_challenges.push(r_j); - } - } - ReductionPhase::AddressVariables => { - self.advice_poly.bind_parallel(r_j, BindingOrder::LowToHigh); - self.eq_poly.bind_parallel(r_j, BindingOrder::LowToHigh); - } - } + self.core.ingest_challenge(r_j, round); } fn cache_openings( @@ -508,43 +272,27 @@ impl SumcheckInstanceProver for AdviceClaimRe accumulator: &mut ProverOpeningAccumulator, sumcheck_challenges: &[F::Challenge], ) { - let opening_point = self.params.normalize_opening_point(sumcheck_challenges); - if self.params.phase == ReductionPhase::CycleVariables { - // Compute the intermediate claim C_mid = (2^{-gap}) * Σ_y advice(y) * eq(y), - // where y are the remaining (address-derived) advice row variables. - let len = self.advice_poly.len(); - debug_assert_eq!(len, self.eq_poly.len()); - - let mut sum = F::zero(); - for i in 0..len { - sum += self.advice_poly.get_bound_coeff(i) * self.eq_poly.get_bound_coeff(i); - } - let c_mid = sum * self.scale; + let params = self.core.params(); + let opening_point = params.normalize_opening_point(sumcheck_challenges); + if params.phase == PrecommittedPhase::CycleVariables { + let c_mid = self.core.cycle_intermediate_claim(); - match self.params.kind { + match params.kind { AdviceKind::Trusted => accumulator.append_trusted_advice( SumcheckId::AdviceClaimReductionCyclePhase, - // This is a phase-boundary intermediate reduction claim (c_mid), not an advice - // polynomial opening. Store it without an opening point so it can't be deduped - // against the final advice opening. OpeningPoint::::new(vec![]), c_mid, ), AdviceKind::Untrusted => accumulator.append_untrusted_advice( SumcheckId::AdviceClaimReductionCyclePhase, - // This is a phase-boundary intermediate reduction claim (c_mid), not an advice - // polynomial opening. Store it without an opening point so it can't be deduped - // against the final advice opening. OpeningPoint::::new(vec![]), c_mid, ), } } - // If we're done binding advice variables, cache the final advice opening - if self.advice_poly.len() == 1 { - let advice_claim = self.advice_poly.final_sumcheck_claim(); - match self.params.kind { + if let Some(advice_claim) = self.core.final_claim_if_ready() { + match params.kind { AdviceKind::Trusted => accumulator.append_trusted_advice( SumcheckId::AdviceClaimReduction, opening_point, @@ -559,19 +307,6 @@ impl SumcheckInstanceProver for AdviceClaimRe } } - fn round_offset(&self, max_num_rounds: usize) -> usize { - match self.params.phase { - ReductionPhase::CycleVariables => { - // Align to the *start* of Booleanity's cycle segment, so local rounds correspond - // to low Dory column bits in the unified point ordering. - let booleanity_rounds = self.params.log_k_chunk + self.params.log_t; - let booleanity_offset = max_num_rounds - booleanity_rounds; - booleanity_offset + self.params.log_k_chunk - } - ReductionPhase::AddressVariables => 0, - } - } - #[cfg(feature = "allocative")] fn update_flamegraph(&self, flamegraph: &mut allocative::FlameGraphBuilder) { flamegraph.visit_root(self); @@ -585,11 +320,16 @@ pub struct AdviceClaimReductionVerifier { impl AdviceClaimReductionVerifier { pub fn new( kind: AdviceKind, - memory_layout: &MemoryLayout, - trace_len: usize, + advice_size_bytes: usize, + scheduling_reference: PrecommittedSchedulingReference, accumulator: &dyn OpeningAccumulator, ) -> Self { - let params = AdviceClaimReductionParams::new(kind, memory_layout, trace_len, accumulator); + let params = AdviceClaimReductionParams::new( + kind, + advice_size_bytes, + scheduling_reference, + accumulator, + ); Self { params: RefCell::new(params), @@ -607,32 +347,20 @@ impl> fn expected_output_claim(&self, accumulator: &A, sumcheck_challenges: &[F::Challenge]) -> F { let params = self.params.borrow(); match params.phase { - ReductionPhase::CycleVariables => { + PrecommittedPhase::CycleVariables => { accumulator .get_advice_opening(params.kind, SumcheckId::AdviceClaimReductionCyclePhase) - .unwrap_or_else(|| panic!("Cycle phase intermediate claim not found",)) + .unwrap_or_else(|| panic!("Cycle phase intermediate claim not found")) .1 } - ReductionPhase::AddressVariables => { + PrecommittedPhase::AddressVariables => { let opening_point = params.normalize_opening_point(sumcheck_challenges); let advice_claim = accumulator .get_advice_opening(params.kind, SumcheckId::AdviceClaimReduction) .expect("Final advice claim not found") .1; - let eq_eval = EqPolynomial::mle(&opening_point.r, ¶ms.r_val.r); - - let gap_len = if params.cycle_phase_row_rounds.is_empty() - || params.cycle_phase_col_rounds.is_empty() - { - 0 - } else { - params.cycle_phase_row_rounds.start - params.cycle_phase_col_rounds.end - }; - let two_inv = F::from_u64(2).inverse().unwrap(); - let scale = (0..gap_len).fold(F::one(), |acc, _| acc * two_inv); - - // Account for Phase 1's internal dummy-gap traversal via constant scaling. + let scale: F = precommitted_skip_round_scale(¶ms.precommitted); advice_claim * eq_eval * scale } } @@ -640,30 +368,26 @@ impl> fn cache_openings(&self, accumulator: &mut A, sumcheck_challenges: &[F::Challenge]) { let mut params = self.params.borrow_mut(); - if params.phase == ReductionPhase::CycleVariables { + if params.phase == PrecommittedPhase::CycleVariables { let opening_point = params.normalize_opening_point(sumcheck_challenges); match params.kind { AdviceKind::Trusted => accumulator.append_trusted_advice( SumcheckId::AdviceClaimReductionCyclePhase, - // This is a phase-boundary intermediate reduction claim (c_mid), not an advice - // polynomial opening. Store it without an opening point so it can't be deduped - // against the final advice opening. OpeningPoint::::new(vec![]), ), AdviceKind::Untrusted => accumulator.append_untrusted_advice( SumcheckId::AdviceClaimReductionCyclePhase, - // This is a phase-boundary intermediate reduction claim (c_mid), not an advice - // polynomial opening. Store it without an opening point so it can't be deduped - // against the final advice opening. OpeningPoint::::new(vec![]), ), } let opening_point_le: OpeningPoint = opening_point.match_endianness(); - params.cycle_var_challenges = opening_point_le.r; + params + .precommitted + .set_cycle_var_challenges(opening_point_le.r); } if params.num_address_phase_rounds() == 0 - || params.phase == ReductionPhase::AddressVariables + || params.phase == PrecommittedPhase::AddressVariables { let opening_point = params.normalize_opening_point(sumcheck_challenges); match params.kind { @@ -677,13 +401,6 @@ impl> fn round_offset(&self, max_num_rounds: usize) -> usize { let params = self.params.borrow(); - match params.phase { - ReductionPhase::CycleVariables => { - let booleanity_rounds = params.log_k_chunk + params.log_t; - let booleanity_offset = max_num_rounds - booleanity_rounds; - booleanity_offset + params.log_k_chunk - } - ReductionPhase::AddressVariables => 0, - } + params.round_offset(max_num_rounds) } } diff --git a/jolt-core/src/zkvm/claim_reductions/bytecode.rs b/jolt-core/src/zkvm/claim_reductions/bytecode.rs new file mode 100644 index 000000000..4b46c2d01 --- /dev/null +++ b/jolt-core/src/zkvm/claim_reductions/bytecode.rs @@ -0,0 +1,705 @@ +//! Two-phase bytecode claim reduction (Stage 6b cycle -> Stage 7 address). + +use std::cell::RefCell; + +use allocative::Allocative; +use rayon::prelude::*; + +use crate::field::JoltField; +use crate::poly::commitment::dory::{DoryGlobals, DoryLayout}; +use crate::poly::eq_poly::EqPolynomial; +use crate::poly::multilinear_polynomial::{BindingOrder, MultilinearPolynomial, PolynomialBinding}; +#[cfg(feature = "zk")] +use crate::poly::opening_proof::OpeningId; +use crate::poly::opening_proof::{ + AbstractVerifierOpeningAccumulator, OpeningAccumulator, OpeningPoint, ProverOpeningAccumulator, + SumcheckId, BIG_ENDIAN, LITTLE_ENDIAN, +}; +use crate::poly::unipoly::UniPoly; +#[cfg(feature = "zk")] +use crate::subprotocols::blindfold::{InputClaimConstraint, OutputClaimConstraint, ValueSource}; +use crate::subprotocols::sumcheck_prover::SumcheckInstanceProver; +use crate::subprotocols::sumcheck_verifier::{SumcheckInstanceParams, SumcheckInstanceVerifier}; +use crate::transcripts::Transcript; +use crate::utils::math::Math; +use crate::zkvm::bytecode::chunks::committed_lanes; +use crate::zkvm::bytecode::read_raf_checking::BytecodeReadRafSumcheckParams; +use crate::zkvm::claim_reductions::{ + permute_precommitted_polys, precommitted_skip_round_scale, PrecommittedClaimReduction, + PrecommittedPhase, PrecommittedSchedulingReference, TWO_PHASE_DEGREE_BOUND, +}; +use crate::zkvm::instruction::{ + CircuitFlags, InstructionFlags, NUM_CIRCUIT_FLAGS, NUM_INSTRUCTION_FLAGS, +}; +use crate::zkvm::lookup_table::LookupTables; +use crate::zkvm::witness::{CommittedPolynomial, VirtualPolynomial}; +use common::constants::{REGISTER_COUNT, XLEN}; +use strum::EnumCount; + +const NUM_VAL_STAGES: usize = 5; + +#[derive(Clone, Allocative)] +pub struct BytecodeClaimReductionParams { + pub phase: PrecommittedPhase, + pub precommitted: PrecommittedClaimReduction, + pub eta: F, + pub eta_powers: [F; NUM_VAL_STAGES], + /// Eq weights over high bytecode address bits (one per committed chunk). + pub chunk_rbc_weights: Vec, + pub bytecode_T: usize, + pub bytecode_chunk_count: usize, + pub bytecode_col_vars: usize, + pub bytecode_row_vars: usize, + pub r_bc: OpeningPoint, + pub lane_weights: Vec, +} + +impl BytecodeClaimReductionParams { + pub fn new( + bytecode_read_raf_params: &BytecodeReadRafSumcheckParams, + full_bytecode_len: usize, + bytecode_chunk_count: usize, + scheduling_reference: PrecommittedSchedulingReference, + accumulator: &dyn OpeningAccumulator, + transcript: &mut impl Transcript, + ) -> Self { + assert!( + full_bytecode_len.is_multiple_of(bytecode_chunk_count), + "bytecode chunk count ({bytecode_chunk_count}) must divide bytecode_len ({full_bytecode_len})" + ); + let bytecode_t = (full_bytecode_len / bytecode_chunk_count).log_2(); + let bytecode_t_full = full_bytecode_len.log_2(); + + let eta: F = transcript.challenge_scalar(); + let mut eta_powers = [F::one(); NUM_VAL_STAGES]; + for i in 1..NUM_VAL_STAGES { + eta_powers[i] = eta_powers[i - 1] * eta; + } + + let (r_bc_full, _) = accumulator.get_virtual_polynomial_opening( + VirtualPolynomial::BytecodeReadRafAddrClaim, + SumcheckId::BytecodeReadRafAddressPhase, + ); + debug_assert_eq!(r_bc_full.r.len(), bytecode_t_full); + let dropped_bits = bytecode_t_full - bytecode_t; + let chunk_rbc_weights = if dropped_bits == 0 { + vec![F::one()] + } else { + EqPolynomial::::evals(&r_bc_full.r[..dropped_bits]) + }; + debug_assert_eq!(chunk_rbc_weights.len(), bytecode_chunk_count); + let r_bc = OpeningPoint::new(r_bc_full.r[dropped_bits..].to_vec()); + + let lane_weights = compute_lane_weights(bytecode_read_raf_params, accumulator, &eta_powers); + + // bytecode_K is the committed lane capacity (already next-power-of-two padded). + let bytecode_k = committed_lanes(); + let total_vars = bytecode_k.log_2() + bytecode_t; + // Bytecode uses its own balanced dimensions (independent from Main). + // In Stage 8 it is embedded as a top-left block in Joint. + let (bytecode_col_vars, bytecode_row_vars) = DoryGlobals::balanced_sigma_nu(total_vars); + let precommitted = PrecommittedClaimReduction::new( + bytecode_row_vars, + bytecode_col_vars, + scheduling_reference, + ); + // Align all precommitted scheduling/permutation to the shared reference domain. + + Self { + phase: PrecommittedPhase::CycleVariables, + precommitted, + eta, + eta_powers, + chunk_rbc_weights, + bytecode_T: bytecode_t, + bytecode_chunk_count, + bytecode_col_vars, + bytecode_row_vars, + r_bc, + lane_weights, + } + } + + pub fn num_address_phase_rounds(&self) -> usize { + self.precommitted.num_address_phase_rounds() + } +} + +impl BytecodeClaimReductionParams { + fn is_cycle_phase(&self) -> bool { + self.phase == PrecommittedPhase::CycleVariables + } + + fn is_cycle_phase_round(&self, round: usize) -> bool { + self.precommitted.is_cycle_phase_round(round) + } + + fn is_address_phase_round(&self, round: usize) -> bool { + self.precommitted.is_address_phase_round(round) + } + + fn cycle_alignment_rounds(&self) -> usize { + self.precommitted.cycle_alignment_rounds() + } + + fn address_alignment_rounds(&self) -> usize { + self.precommitted.address_alignment_rounds() + } + + pub fn transition_to_address_phase(&mut self) { + self.phase = PrecommittedPhase::AddressVariables; + } + + fn num_rounds_for_current_phase(&self) -> usize { + self.precommitted + .num_rounds_for_phase(self.is_cycle_phase()) + } + + pub fn round_offset(&self, max_num_rounds: usize) -> usize { + self.precommitted + .round_offset(self.is_cycle_phase(), max_num_rounds) + } +} + +impl SumcheckInstanceParams for BytecodeClaimReductionParams { + fn input_claim(&self, accumulator: &dyn OpeningAccumulator) -> F { + match self.phase { + PrecommittedPhase::CycleVariables => (0..NUM_VAL_STAGES) + .map(|stage| { + let (_, val_claim) = accumulator.get_virtual_polynomial_opening( + VirtualPolynomial::BytecodeValStage(stage), + SumcheckId::BytecodeReadRafAddressPhase, + ); + self.eta_powers[stage] * val_claim + }) + .sum(), + PrecommittedPhase::AddressVariables => { + accumulator + .get_virtual_polynomial_opening( + VirtualPolynomial::BytecodeClaimReductionIntermediate, + SumcheckId::BytecodeClaimReductionCyclePhase, + ) + .1 + } + } + } + + fn degree(&self) -> usize { + TWO_PHASE_DEGREE_BOUND + } + + fn num_rounds(&self) -> usize { + self.num_rounds_for_current_phase() + } + + fn normalize_opening_point(&self, challenges: &[F::Challenge]) -> OpeningPoint { + self.precommitted + .normalize_opening_point(self.is_cycle_phase(), challenges) + } + + #[cfg(feature = "zk")] + fn input_claim_constraint(&self) -> InputClaimConstraint { + match self.phase { + PrecommittedPhase::CycleVariables => { + let openings: Vec = (0..NUM_VAL_STAGES) + .map(|stage| { + OpeningId::virt( + VirtualPolynomial::BytecodeValStage(stage), + SumcheckId::BytecodeReadRafAddressPhase, + ) + }) + .collect(); + InputClaimConstraint::all_weighted_openings(&openings) + } + PrecommittedPhase::AddressVariables => InputClaimConstraint::direct(OpeningId::virt( + VirtualPolynomial::BytecodeClaimReductionIntermediate, + SumcheckId::BytecodeClaimReductionCyclePhase, + )), + } + } + + #[cfg(feature = "zk")] + fn input_constraint_challenge_values(&self, _: &dyn OpeningAccumulator) -> Vec { + match self.phase { + PrecommittedPhase::CycleVariables => self.eta_powers.to_vec(), + PrecommittedPhase::AddressVariables => Vec::new(), + } + } + + #[cfg(feature = "zk")] + fn output_claim_constraint(&self) -> Option { + match self.phase { + PrecommittedPhase::CycleVariables => { + Some(OutputClaimConstraint::direct(OpeningId::virt( + VirtualPolynomial::BytecodeClaimReductionIntermediate, + SumcheckId::BytecodeClaimReductionCyclePhase, + ))) + } + PrecommittedPhase::AddressVariables => { + let terms = (0..self.bytecode_chunk_count) + .map(|chunk_idx| { + let opening = OpeningId::committed( + CommittedPolynomial::BytecodeChunk(chunk_idx), + SumcheckId::BytecodeClaimReduction, + ); + ( + ValueSource::Challenge(chunk_idx), + ValueSource::Opening(opening), + ) + }) + .collect(); + Some(OutputClaimConstraint::linear(terms)) + } + } + } + + #[cfg(feature = "zk")] + fn output_constraint_challenge_values(&self, sumcheck_challenges: &[F::Challenge]) -> Vec { + match self.phase { + PrecommittedPhase::CycleVariables => vec![], + PrecommittedPhase::AddressVariables => { + let eq_combined = evaluate_bytecode_eq_combined(self, sumcheck_challenges); + let scale: F = precommitted_skip_round_scale(&self.precommitted); + self.chunk_rbc_weights + .iter() + .map(|w| *w * eq_combined * scale) + .collect() + } + } + } +} + +#[derive(Allocative)] +pub struct BytecodeClaimReductionProver { + params: BytecodeClaimReductionParams, + value_poly: MultilinearPolynomial, + eq_poly: MultilinearPolynomial, + scale: F, + chunk_value_polys: Vec>, +} + +impl BytecodeClaimReductionProver { + pub fn params(&self) -> &BytecodeClaimReductionParams { + &self.params + } + + pub fn transition_to_address_phase(&mut self) { + self.params.transition_to_address_phase(); + } + + pub fn initialize( + params: BytecodeClaimReductionParams, + raw_chunk_coeffs: &[Vec], + ) -> Self { + let eq_cycle = EqPolynomial::::evals(¶ms.r_bc.r); + let eq_coeffs_template: Vec = (0..raw_chunk_coeffs[0].len()) + .map(|idx| { + let (lane, cycle) = native_index_to_lane_cycle(¶ms, idx); + params.lane_weights[lane] * eq_cycle[cycle] + }) + .collect(); + + let raw_value_coeffs: Vec = (0..raw_chunk_coeffs[0].len()) + .into_par_iter() + .map(|idx| { + raw_chunk_coeffs + .iter() + .zip(params.chunk_rbc_weights.iter()) + .map(|(coeffs, weight)| coeffs[idx] * *weight) + .sum::() + }) + .collect(); + let mut coeffs_by_poly = Vec::with_capacity(2 + raw_chunk_coeffs.len()); + coeffs_by_poly.push(raw_value_coeffs); + coeffs_by_poly.push(eq_coeffs_template); + for coeffs in raw_chunk_coeffs.iter() { + coeffs_by_poly.push(coeffs.clone()); + } + let mut permuted_polys = + permute_precommitted_polys(coeffs_by_poly, ¶ms.precommitted).into_iter(); + let value_poly = permuted_polys + .next() + .expect("expected permuted bytecode value polynomial"); + let eq_poly = permuted_polys + .next() + .expect("expected permuted bytecode eq polynomial"); + let chunk_value_polys: Vec> = permuted_polys.collect(); + + Self { + params, + value_poly, + eq_poly, + scale: F::one(), + chunk_value_polys, + } + } + + fn bind_aux_polys(&mut self, r_j: F::Challenge) { + for poly in self.chunk_value_polys.iter_mut() { + poly.bind_parallel(r_j, BindingOrder::LowToHigh); + } + } + + fn compute_message_unscaled(&self, previous_claim_unscaled: F) -> UniPoly { + let half = self.value_poly.len() / 2; + let evals: [F; TWO_PHASE_DEGREE_BOUND] = (0..half) + .into_par_iter() + .map(|j| { + let value_evals = self + .value_poly + .sumcheck_evals_array::(j, BindingOrder::LowToHigh); + let eq_evals = self + .eq_poly + .sumcheck_evals_array::(j, BindingOrder::LowToHigh); + let mut out = [F::zero(); TWO_PHASE_DEGREE_BOUND]; + for i in 0..TWO_PHASE_DEGREE_BOUND { + out[i] = value_evals[i] * eq_evals[i]; + } + out + }) + .reduce( + || [F::zero(); TWO_PHASE_DEGREE_BOUND], + |mut acc, arr| { + acc.iter_mut().zip(arr.iter()).for_each(|(a, b)| *a += *b); + acc + }, + ); + UniPoly::from_evals_and_hint(previous_claim_unscaled, &evals) + } + + fn cycle_intermediate_claim(&self) -> F { + let len = self.value_poly.len(); + debug_assert_eq!(len, self.eq_poly.len()); + let mut sum = F::zero(); + for i in 0..len { + sum += self.value_poly.get_bound_coeff(i) * self.eq_poly.get_bound_coeff(i); + } + sum * self.scale + } + + fn final_claim_if_ready(&self) -> Option { + if self.value_poly.len() == 1 { + Some(self.value_poly.get_bound_coeff(0)) + } else { + None + } + } +} + +impl SumcheckInstanceProver for BytecodeClaimReductionProver { + fn get_params(&self) -> &dyn SumcheckInstanceParams { + &self.params + } + + fn round_offset(&self, max_num_rounds: usize) -> usize { + self.params.round_offset(max_num_rounds) + } + + fn compute_message(&mut self, round: usize, previous_claim: F) -> UniPoly { + let is_active_round = if self.params.is_cycle_phase() { + self.params.is_cycle_phase_round(round) + } else { + self.params.is_address_phase_round(round) + }; + if !is_active_round { + return UniPoly::from_coeff(vec![previous_claim * F::from_u64(2).inverse().unwrap()]); + } + + let trailing_cap = if self.params.is_cycle_phase() { + self.params.cycle_alignment_rounds() + } else { + self.params.address_alignment_rounds() + }; + let num_trailing_variables = + trailing_cap.saturating_sub(self.params.num_rounds_for_current_phase()); + let scaling_factor = self.scale * F::one().mul_pow_2(num_trailing_variables); + let prev_unscaled = previous_claim * scaling_factor.inverse().unwrap(); + let poly_unscaled = self.compute_message_unscaled(prev_unscaled); + poly_unscaled * scaling_factor + } + + fn ingest_challenge(&mut self, r_j: F::Challenge, round: usize) { + let is_active_round = if self.params.is_cycle_phase() { + self.params.is_cycle_phase_round(round) + } else { + self.params.is_address_phase_round(round) + }; + if !is_active_round { + self.scale *= F::from_u64(2).inverse().unwrap(); + return; + } + + self.value_poly.bind_parallel(r_j, BindingOrder::LowToHigh); + self.eq_poly.bind_parallel(r_j, BindingOrder::LowToHigh); + self.bind_aux_polys(r_j); + if self.params.is_cycle_phase() { + self.params.precommitted.record_cycle_challenge(r_j); + } + } + + fn cache_openings( + &self, + accumulator: &mut ProverOpeningAccumulator, + sumcheck_challenges: &[F::Challenge], + ) { + let params = &self.params; + let opening_point = params.normalize_opening_point(sumcheck_challenges); + + if params.phase == PrecommittedPhase::CycleVariables { + accumulator.append_virtual( + VirtualPolynomial::BytecodeClaimReductionIntermediate, + SumcheckId::BytecodeClaimReductionCyclePhase, + opening_point.clone(), + self.cycle_intermediate_claim(), + ); + } + + if let Some(bytecode_claim) = self.final_claim_if_ready() { + let chunk_claims: Vec = self + .chunk_value_polys + .iter() + .map(|poly| poly.final_sumcheck_claim()) + .collect(); + let weighted_chunk_sum = chunk_claims + .iter() + .zip(params.chunk_rbc_weights.iter()) + .map(|(claim, weight)| *claim * *weight) + .sum::(); + debug_assert_eq!(weighted_chunk_sum, bytecode_claim); + for (chunk_idx, claim) in chunk_claims.into_iter().enumerate() { + accumulator.append_dense( + CommittedPolynomial::BytecodeChunk(chunk_idx), + SumcheckId::BytecodeClaimReduction, + opening_point.r.clone(), + claim, + ); + } + } + } + + #[cfg(feature = "allocative")] + fn update_flamegraph(&self, flamegraph: &mut allocative::FlameGraphBuilder) { + flamegraph.visit_root(self); + } +} + +pub struct BytecodeClaimReductionVerifier { + pub params: RefCell>, +} + +impl BytecodeClaimReductionVerifier { + pub fn new(params: BytecodeClaimReductionParams) -> Self { + Self { + params: RefCell::new(params), + } + } +} + +impl> + SumcheckInstanceVerifier for BytecodeClaimReductionVerifier +{ + fn get_params(&self) -> &dyn SumcheckInstanceParams { + unsafe { &*self.params.as_ptr() } + } + + fn round_offset(&self, max_num_rounds: usize) -> usize { + let params = self.params.borrow(); + params.round_offset(max_num_rounds) + } + + fn expected_output_claim(&self, accumulator: &A, sumcheck_challenges: &[F::Challenge]) -> F { + let params = self.params.borrow(); + match params.phase { + PrecommittedPhase::CycleVariables => { + accumulator + .get_virtual_polynomial_opening( + VirtualPolynomial::BytecodeClaimReductionIntermediate, + SumcheckId::BytecodeClaimReductionCyclePhase, + ) + .1 + } + PrecommittedPhase::AddressVariables => { + let bytecode_opening: F = (0..params.bytecode_chunk_count) + .map(|chunk_idx| { + params.chunk_rbc_weights[chunk_idx] + * accumulator + .get_committed_polynomial_opening( + CommittedPolynomial::BytecodeChunk(chunk_idx), + SumcheckId::BytecodeClaimReduction, + ) + .1 + }) + .sum(); + let eq_combined = evaluate_bytecode_eq_combined(¶ms, sumcheck_challenges); + let scale: F = precommitted_skip_round_scale(¶ms.precommitted); + + bytecode_opening * eq_combined * scale + } + } + } + + fn cache_openings(&self, accumulator: &mut A, sumcheck_challenges: &[F::Challenge]) { + let mut params = self.params.borrow_mut(); + if params.phase == PrecommittedPhase::CycleVariables { + let opening_point = params.normalize_opening_point(sumcheck_challenges); + accumulator.append_virtual( + VirtualPolynomial::BytecodeClaimReductionIntermediate, + SumcheckId::BytecodeClaimReductionCyclePhase, + opening_point.clone(), + ); + let opening_point_le: OpeningPoint = opening_point.match_endianness(); + params + .precommitted + .set_cycle_var_challenges(opening_point_le.r); + } + + if params.num_address_phase_rounds() == 0 + || params.phase == PrecommittedPhase::AddressVariables + { + let opening_point = params.normalize_opening_point(sumcheck_challenges); + for chunk_idx in 0..params.bytecode_chunk_count { + accumulator.append_dense( + CommittedPolynomial::BytecodeChunk(chunk_idx), + SumcheckId::BytecodeClaimReduction, + opening_point.r.clone(), + ); + } + } + } +} + +fn evaluate_bytecode_eq_combined( + params: &BytecodeClaimReductionParams, + sumcheck_challenges: &[F::Challenge], +) -> F { + let opening_point = params.normalize_opening_point(sumcheck_challenges); + let lane_var_count = committed_lanes().log_2(); + + let (lane_challenges, cycle_challenges) = match DoryGlobals::get_layout() { + DoryLayout::CycleMajor => { + let (lane, cycle) = opening_point.r.split_at(lane_var_count); + (lane, cycle) + } + DoryLayout::AddressMajor => { + let (cycle, lane) = opening_point.r.split_at(params.bytecode_T); + (lane, cycle) + } + }; + + debug_assert_eq!(lane_challenges.len(), lane_var_count); + debug_assert_eq!(cycle_challenges.len(), params.r_bc.r.len()); + + let eq_cycle = EqPolynomial::mle(cycle_challenges, ¶ms.r_bc.r); + let eq_lane = EqPolynomial::::evals(lane_challenges); + let lane_weight_eval: F = params + .lane_weights + .iter() + .zip(eq_lane.iter()) + .map(|(w, eq)| *w * *eq) + .sum(); + + lane_weight_eval * eq_cycle +} + +#[inline(always)] +fn native_index_to_lane_cycle( + params: &BytecodeClaimReductionParams, + index: usize, +) -> (usize, usize) { + let bytecode_len = 1usize << params.bytecode_T; + match DoryGlobals::get_layout() { + DoryLayout::CycleMajor => (index / bytecode_len, index % bytecode_len), + DoryLayout::AddressMajor => (index % committed_lanes(), index / committed_lanes()), + } +} + +fn compute_lane_weights( + bytecode_read_raf_params: &BytecodeReadRafSumcheckParams, + accumulator: &dyn OpeningAccumulator, + eta_powers: &[F; NUM_VAL_STAGES], +) -> Vec { + let reg_count = REGISTER_COUNT as usize; + let total = crate::zkvm::bytecode::chunks::total_lanes(); + + let rs1_start = 0usize; + let rs2_start = rs1_start + reg_count; + let rd_start = rs2_start + reg_count; + let unexp_pc_idx = rd_start + reg_count; + let imm_idx = unexp_pc_idx + 1; + let circuit_start = imm_idx + 1; + let instr_start = circuit_start + NUM_CIRCUIT_FLAGS; + let lookup_start = instr_start + NUM_INSTRUCTION_FLAGS; + let raf_flag_idx = lookup_start + LookupTables::::COUNT; + debug_assert_eq!(raf_flag_idx + 1, total); + + let log_reg = reg_count.log_2(); + let r_register_4 = accumulator + .get_virtual_polynomial_opening( + VirtualPolynomial::RdWa, + SumcheckId::RegistersReadWriteChecking, + ) + .0 + .r; + let eq_r_register_4 = EqPolynomial::::evals(&r_register_4[..log_reg]); + + let r_register_5 = accumulator + .get_virtual_polynomial_opening(VirtualPolynomial::RdWa, SumcheckId::RegistersValEvaluation) + .0 + .r; + let eq_r_register_5 = EqPolynomial::::evals(&r_register_5[..log_reg]); + + let mut weights = vec![F::zero(); committed_lanes()]; + + { + let coeff = eta_powers[0]; + let g = &bytecode_read_raf_params.stage1_gammas; + weights[unexp_pc_idx] += coeff * g[0]; + weights[imm_idx] += coeff * g[1]; + for i in 0..NUM_CIRCUIT_FLAGS { + weights[circuit_start + i] += coeff * g[2 + i]; + } + } + { + let coeff = eta_powers[1]; + let g = &bytecode_read_raf_params.stage2_gammas; + weights[circuit_start + (CircuitFlags::Jump as usize)] += coeff * g[0]; + weights[instr_start + (InstructionFlags::Branch as usize)] += coeff * g[1]; + weights[circuit_start + (CircuitFlags::WriteLookupOutputToRD as usize)] += coeff * g[2]; + weights[circuit_start + (CircuitFlags::VirtualInstruction as usize)] += coeff * g[3]; + } + { + let coeff = eta_powers[2]; + let g = &bytecode_read_raf_params.stage3_gammas; + weights[imm_idx] += coeff * g[0]; + weights[unexp_pc_idx] += coeff * g[1]; + weights[instr_start + (InstructionFlags::LeftOperandIsRs1Value as usize)] += coeff * g[2]; + weights[instr_start + (InstructionFlags::LeftOperandIsPC as usize)] += coeff * g[3]; + weights[instr_start + (InstructionFlags::RightOperandIsRs2Value as usize)] += coeff * g[4]; + weights[instr_start + (InstructionFlags::RightOperandIsImm as usize)] += coeff * g[5]; + weights[instr_start + (InstructionFlags::IsNoop as usize)] += coeff * g[6]; + weights[circuit_start + (CircuitFlags::VirtualInstruction as usize)] += coeff * g[7]; + weights[circuit_start + (CircuitFlags::IsFirstInSequence as usize)] += coeff * g[8]; + } + { + let coeff = eta_powers[3]; + let g = &bytecode_read_raf_params.stage4_gammas; + for r in 0..reg_count { + weights[rd_start + r] += coeff * g[0] * eq_r_register_4[r]; + weights[rs1_start + r] += coeff * g[1] * eq_r_register_4[r]; + weights[rs2_start + r] += coeff * g[2] * eq_r_register_4[r]; + } + } + { + let coeff = eta_powers[4]; + let g = &bytecode_read_raf_params.stage5_gammas; + for r in 0..reg_count { + weights[rd_start + r] += coeff * g[0] * eq_r_register_5[r]; + } + weights[raf_flag_idx] += coeff * g[1]; + for i in 0..LookupTables::::COUNT { + weights[lookup_start + i] += coeff * g[2 + i]; + } + } + + weights +} diff --git a/jolt-core/src/zkvm/claim_reductions/hamming_weight.rs b/jolt-core/src/zkvm/claim_reductions/hamming_weight.rs index 7402827df..77c7aba30 100644 --- a/jolt-core/src/zkvm/claim_reductions/hamming_weight.rs +++ b/jolt-core/src/zkvm/claim_reductions/hamming_weight.rs @@ -81,7 +81,9 @@ use allocative::Allocative; use rayon::prelude::*; use tracer::instruction::Cycle; +use crate::curve::JoltCurve; use crate::field::JoltField; +use crate::poly::commitment::commitment_scheme::CommitmentScheme; #[cfg(feature = "zk")] use crate::poly::opening_proof::OpeningId; use crate::poly::{ @@ -105,7 +107,7 @@ use crate::subprotocols::{ use crate::transcripts::Transcript; use crate::zkvm::{ config::OneHotParams, - verifier::JoltSharedPreprocessing, + prover::JoltProverPreprocessing, witness::{CommittedPolynomial, VirtualPolynomial}, }; @@ -424,18 +426,22 @@ impl HammingWeightClaimReductionProver { /// Initialize the prover by computing all G_i polynomials. /// Returns (prover, ram_hw_claims) where ram_hw_claims contains the computed H_i for RAM polynomials. #[tracing::instrument(skip_all, name = "HammingWeightClaimReductionProver::initialize")] - pub fn initialize( + pub fn initialize( params: HammingWeightClaimReductionParams, trace: &[Cycle], - preprocessing: &JoltSharedPreprocessing, + preprocessing: &JoltProverPreprocessing, one_hot_params: &OneHotParams, - ) -> Self { + ) -> Self + where + C: JoltCurve, + PCS: CommitmentScheme, + { // Compute all G_i polynomials via streaming. // `params.r_cycle` is in BIG_ENDIAN (OpeningPoint) convention. let G_vecs = compute_all_G::( trace, - &preprocessing.bytecode, - &preprocessing.memory_layout, + &preprocessing.materialized_program().bytecode, + &preprocessing.shared.memory_layout, one_hot_params, ¶ms.r_cycle, ); diff --git a/jolt-core/src/zkvm/claim_reductions/mod.rs b/jolt-core/src/zkvm/claim_reductions/mod.rs index 3f91e7ca8..428271829 100644 --- a/jolt-core/src/zkvm/claim_reductions/mod.rs +++ b/jolt-core/src/zkvm/claim_reductions/mod.rs @@ -1,13 +1,19 @@ pub mod advice; +pub mod bytecode; pub mod hamming_weight; pub mod increments; pub mod instruction_lookups; +mod precommitted; +pub mod program_image; pub mod ram_ra; pub mod registers; pub use advice::{ AdviceClaimReductionParams, AdviceClaimReductionProver, AdviceClaimReductionVerifier, - AdviceKind, ReductionPhase, + AdviceKind, +}; +pub use bytecode::{ + BytecodeClaimReductionParams, BytecodeClaimReductionProver, BytecodeClaimReductionVerifier, }; pub use hamming_weight::{ HammingWeightClaimReductionParams, HammingWeightClaimReductionProver, @@ -21,6 +27,15 @@ pub use instruction_lookups::{ InstructionLookupsClaimReductionSumcheckParams, InstructionLookupsClaimReductionSumcheckProver, InstructionLookupsClaimReductionSumcheckVerifier, }; +pub use precommitted::{ + permute_precommitted_polys, precommitted_eq_evals_with_scaling, precommitted_skip_round_scale, + PrecomittedParams, PrecomittedProver, PrecommittedClaimReduction, PrecommittedPhase, + PrecommittedPolynomial, PrecommittedSchedulingReference, TWO_PHASE_DEGREE_BOUND, +}; +pub use program_image::{ + ProgramImageClaimReductionParams, ProgramImageClaimReductionProver, + ProgramImageClaimReductionVerifier, +}; pub use ram_ra::{ RaReductionParams, RamRaClaimReductionSumcheckProver, RamRaClaimReductionSumcheckVerifier, }; diff --git a/jolt-core/src/zkvm/claim_reductions/precommitted.rs b/jolt-core/src/zkvm/claim_reductions/precommitted.rs new file mode 100644 index 000000000..a3c401316 --- /dev/null +++ b/jolt-core/src/zkvm/claim_reductions/precommitted.rs @@ -0,0 +1,603 @@ +use allocative::Allocative; +use rayon::prelude::*; +use std::sync::Arc; + +use crate::field::JoltField; +use crate::poly::commitment::dory::{DoryGlobals, DoryLayout}; +use crate::poly::eq_poly::EqPolynomial; +use crate::poly::multilinear_polynomial::{BindingOrder, MultilinearPolynomial, PolynomialBinding}; +use crate::poly::opening_proof::{OpeningPoint, BIG_ENDIAN, LITTLE_ENDIAN}; +use crate::poly::unipoly::UniPoly; +use crate::subprotocols::sumcheck_verifier::SumcheckInstanceParams; +use crate::utils::math::Math; +use crate::zkvm::bytecode::chunks::committed_lanes; + +#[derive(Clone, Debug)] +pub enum PrecommittedPolynomial { + Dense(MultilinearPolynomial), + BytecodeChunk { + chunk_index: usize, + chunk_cycle_len: usize, + }, + ProgramImage { + words: Arc>, + padded_len: usize, + }, +} + +impl PrecommittedPolynomial { + pub(crate) fn original_len(&self) -> usize { + match self { + Self::Dense(poly) => poly.original_len(), + Self::BytecodeChunk { + chunk_cycle_len, .. + } => committed_lanes() * *chunk_cycle_len, + Self::ProgramImage { padded_len, .. } => *padded_len, + } + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Allocative)] +pub enum PrecommittedPhase { + CycleVariables, + AddressVariables, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Allocative)] +pub struct PrecommittedSchedulingReference { + pub main_total_vars: usize, + pub reference_total_vars: usize, + pub cycle_alignment_rounds: usize, + pub address_rounds: usize, + pub joint_col_vars: usize, +} + +#[derive(Debug, Clone, Allocative)] +pub struct PrecommittedClaimReduction { + pub scheduling_reference: PrecommittedSchedulingReference, + pub cycle_var_challenges: Vec, + poly_opening_round_permutation_be: Vec, + cycle_phase_rounds: Vec, + cycle_phase_total_rounds: usize, + address_phase_rounds: Vec, + address_phase_total_rounds: usize, +} + +impl PrecommittedClaimReduction { + /// Compute shared scheduling dimensions from Main and precommitted candidates. + /// + /// `reference_total_vars` is the largest total var count across Main and candidates. + pub fn scheduling_reference( + main_total_vars: usize, + candidates: &[usize], + ) -> PrecommittedSchedulingReference { + let address_rounds = DoryGlobals::main_k().log_2(); + let max_precommitted = candidates.iter().copied().max().unwrap_or(0); + let reference_total_vars = std::cmp::max(main_total_vars, max_precommitted); + let cycle_alignment_rounds = reference_total_vars.saturating_sub(address_rounds); + let (reference_sigma, _) = DoryGlobals::balanced_sigma_nu(reference_total_vars); + let joint_col_vars = std::cmp::max( + DoryGlobals::configured_main_num_columns().log_2(), + reference_sigma, + ); + PrecommittedSchedulingReference { + main_total_vars, + reference_total_vars, + cycle_alignment_rounds, + address_rounds, + joint_col_vars, + } + } + + #[inline] + pub fn new( + poly_row_vars: usize, + poly_col_vars: usize, + scheduling_reference: PrecommittedSchedulingReference, + ) -> Self { + let has_precommitted_dominance = + scheduling_reference.reference_total_vars > scheduling_reference.main_total_vars; + let dory_opening_round_permutation_be = Self::reference_dory_opening_round_permutation_be( + &scheduling_reference, + has_precommitted_dominance, + DoryGlobals::main_t().log_2(), + ); + let poly_opening_round_permutation_be = Self::project_dory_round_permutation_for_poly( + &dory_opening_round_permutation_be, + &scheduling_reference, + poly_row_vars, + poly_col_vars, + ); + let (cycle_phase_rounds, address_phase_rounds) = Self::active_rounds_from_poly_permutation( + &poly_opening_round_permutation_be, + scheduling_reference.cycle_alignment_rounds, + ); + Self { + scheduling_reference, + cycle_var_challenges: vec![], + poly_opening_round_permutation_be, + cycle_phase_rounds, + cycle_phase_total_rounds: scheduling_reference.cycle_alignment_rounds, + address_phase_rounds, + address_phase_total_rounds: scheduling_reference.address_rounds, + } + } + + fn reference_dory_opening_round_permutation_be( + reference: &PrecommittedSchedulingReference, + has_precommitted_dominance: bool, + dense_cycle_prefix_rounds: usize, + ) -> Vec { + let cycle_rounds = reference.cycle_alignment_rounds; + let address_rounds = reference.address_rounds; + let total_rounds = cycle_rounds + address_rounds; + if has_precommitted_dominance { + let address_rev = (cycle_rounds..total_rounds).rev(); + match DoryGlobals::get_layout() { + DoryLayout::CycleMajor => { + let t = dense_cycle_prefix_rounds.min(cycle_rounds); + let prefix_rev = (0..cycle_rounds.saturating_sub(t)).rev(); + let dense_rev = (cycle_rounds.saturating_sub(t)..cycle_rounds).rev(); + return prefix_rev.chain(address_rev).chain(dense_rev).collect(); + } + DoryLayout::AddressMajor => { + let t = dense_cycle_prefix_rounds.min(cycle_rounds); + let prefix_rev = (0..cycle_rounds.saturating_sub(t)).rev(); + let dense_rev = (cycle_rounds.saturating_sub(t)..cycle_rounds).rev(); + return dense_rev.chain(address_rev).chain(prefix_rev).collect(); + } + } + } + + match DoryGlobals::get_layout() { + DoryLayout::CycleMajor => (0..total_rounds).rev().collect(), + DoryLayout::AddressMajor => { + let cycle_rev = (0..cycle_rounds).rev(); + let address_rev = (cycle_rounds..total_rounds).rev(); + cycle_rev.chain(address_rev).collect() + } + } + } + + fn project_dory_round_permutation_for_poly( + dory_opening_round_permutation_be: &[usize], + reference: &PrecommittedSchedulingReference, + poly_row_vars: usize, + poly_col_vars: usize, + ) -> Vec { + let total_full = reference.reference_total_vars; + let sigma_full = reference.joint_col_vars; + let nu_full = total_full.saturating_sub(sigma_full); + assert_eq!( + dory_opening_round_permutation_be.len(), + total_full, + "reference dory round permutation length mismatch", + ); + assert!( + poly_row_vars <= nu_full && poly_col_vars <= sigma_full, + "top-left projection requires poly dims <= full dims (poly row/col vars={poly_row_vars}/{poly_col_vars}, full row/col vars={nu_full}/{sigma_full})" + ); + let row_be = &dory_opening_round_permutation_be[..nu_full]; + let col_be = &dory_opening_round_permutation_be[nu_full..nu_full + sigma_full]; + let row_tail = &row_be[nu_full - poly_row_vars..]; + let col_tail = &col_be[sigma_full - poly_col_vars..]; + [row_tail, col_tail].concat() + } + + fn active_rounds_from_poly_permutation( + poly_opening_round_permutation_be: &[usize], + cycle_alignment_rounds: usize, + ) -> (Vec, Vec) { + let mut cycle_phase_rounds = Vec::new(); + let mut address_phase_rounds = Vec::new(); + for &global_round in poly_opening_round_permutation_be.iter() { + if global_round < cycle_alignment_rounds { + cycle_phase_rounds.push(global_round); + } else { + address_phase_rounds.push(global_round - cycle_alignment_rounds); + } + } + cycle_phase_rounds.sort_unstable(); + cycle_phase_rounds.dedup(); + address_phase_rounds.sort_unstable(); + address_phase_rounds.dedup(); + (cycle_phase_rounds, address_phase_rounds) + } + + #[inline] + pub fn num_address_phase_rounds(&self) -> usize { + self.address_phase_rounds.len() + } + + #[inline] + pub fn is_cycle_phase_round(&self, round: usize) -> bool { + self.cycle_phase_rounds.contains(&round) + } + + #[inline] + pub fn is_address_phase_round(&self, round: usize) -> bool { + self.address_phase_rounds.contains(&round) + } + + #[inline] + pub fn cycle_alignment_rounds(&self) -> usize { + self.scheduling_reference.cycle_alignment_rounds + } + + #[inline] + pub fn address_alignment_rounds(&self) -> usize { + self.scheduling_reference.address_rounds + } + + #[inline] + pub fn num_rounds_for_phase(&self, is_cycle_phase: bool) -> usize { + if is_cycle_phase { + self.cycle_phase_total_rounds + } else { + self.address_phase_total_rounds + } + } + + pub fn round_offset(&self, is_cycle_phase: bool, max_num_rounds: usize) -> usize { + let _ = (is_cycle_phase, max_num_rounds); + 0 + } + + fn cycle_challenge_for_round(&self, round: usize) -> F::Challenge { + let idx = self + .cycle_phase_rounds + .iter() + .position(|&scheduled_round| scheduled_round == round) + .unwrap_or_else(|| { + panic!( + "missing recorded cycle challenge for round={} (active rounds={:?})", + round, self.cycle_phase_rounds + ) + }); + assert!( + idx < self.cycle_var_challenges.len(), + "cycle challenge vector too short: idx={} len={}", + idx, + self.cycle_var_challenges.len() + ); + self.cycle_var_challenges[idx] + } + + pub fn normalize_opening_point( + &self, + is_cycle_phase: bool, + challenges: &[F::Challenge], + ) -> OpeningPoint { + if is_cycle_phase { + let local_cycle_challenges: Vec = self + .cycle_phase_rounds + .iter() + .map(|&round| { + assert!( + round < challenges.len(), + "cycle round index out of local bounds: round={} local_len={}", + round, + challenges.len() + ); + challenges[round] + }) + .collect(); + return OpeningPoint::::new(local_cycle_challenges) + .match_endianness(); + } + + let cycle_round_limit = self.cycle_alignment_rounds(); + let opening_rounds = &self.poly_opening_round_permutation_be; + let mut opening_point_be = Vec::with_capacity(opening_rounds.len()); + for &global_round in opening_rounds.iter() { + if global_round < cycle_round_limit { + opening_point_be.push(self.cycle_challenge_for_round(global_round)); + } else { + let address_round = global_round - cycle_round_limit; + assert!( + address_round < challenges.len(), + "address round index out of local bounds: round={} local_len={}", + address_round, + challenges.len() + ); + opening_point_be.push(challenges[address_round]); + } + } + OpeningPoint::::new(opening_point_be) + } + + #[inline] + pub fn record_cycle_challenge(&mut self, challenge: F::Challenge) { + self.cycle_var_challenges.push(challenge); + } + + #[inline] + pub fn set_cycle_var_challenges(&mut self, challenges: Vec) { + self.cycle_var_challenges = challenges; + } +} + +pub fn permute_precommitted_polys( + coeffs_by_poly: Vec>, + precommitted: &PrecommittedClaimReduction, +) -> Vec> +where + MultilinearPolynomial: From>, +{ + if coeffs_by_poly.is_empty() { + return Vec::new(); + } + let coeffs_len = coeffs_by_poly[0].len(); + assert!( + coeffs_by_poly + .iter() + .all(|coeffs| coeffs.len() == coeffs_len), + "all precommitted polynomials must have equal coefficient lengths", + ); + let inverse_permutation = precommitted_sumcheck_inverse_index_permutation( + coeffs_len, + &precommitted.poly_opening_round_permutation_be, + ); + let permuted_coeffs_by_poly: Vec> = + if let Some(inverse_permutation) = inverse_permutation { + coeffs_by_poly + .into_iter() + .map(|coeffs| { + (0..coeffs_len) + .into_par_iter() + .map(|new_idx| { + let old_idx = inverse_permutation[new_idx]; + coeffs[old_idx] + }) + .collect() + }) + .collect() + } else { + coeffs_by_poly + }; + permuted_coeffs_by_poly + .into_iter() + .map(Into::into) + .collect() +} + +pub fn precommitted_eq_evals_with_scaling( + challenges_be: &[C], + scaling_factor: Option, + precommitted: &PrecommittedClaimReduction, +) -> Vec +where + C: Copy + Send + Sync + Into, + F: JoltField + std::ops::Mul + std::ops::SubAssign, +{ + let permuted_challenges = precommitted_permute_eq_challenges( + challenges_be, + &precommitted.poly_opening_round_permutation_be, + ); + if let Some(permuted_challenges) = permuted_challenges { + EqPolynomial::evals_with_scaling(&permuted_challenges, scaling_factor) + } else { + EqPolynomial::evals_with_scaling(challenges_be, scaling_factor) + } +} + +fn precommitted_permute_eq_challenges( + challenges_be: &[C], + poly_opening_round_permutation_be: &[usize], +) -> Option> { + let old_lsb_to_new_lsb = + precommitted_sumcheck_lsb_permutation(poly_opening_round_permutation_be)?; + assert_eq!( + challenges_be.len(), + old_lsb_to_new_lsb.len(), + "challenge vector length mismatch for precommitted eq permutation", + ); + let num_vars = challenges_be.len(); + let mut permuted_challenges = challenges_be.to_vec(); + for old_be in 0..num_vars { + let old_lsb = num_vars - 1 - old_be; + let new_lsb = old_lsb_to_new_lsb[old_lsb]; + let new_be = num_vars - 1 - new_lsb; + permuted_challenges[new_be] = challenges_be[old_be]; + } + Some(permuted_challenges) +} + +fn precommitted_sumcheck_lsb_permutation( + poly_opening_round_permutation_be: &[usize], +) -> Option> { + let num_vars = poly_opening_round_permutation_be.len(); + let mut be_var_by_round: Vec = (0..num_vars).collect(); + be_var_by_round.sort_unstable_by_key(|&be_idx| poly_opening_round_permutation_be[be_idx]); + + let mut old_lsb_to_new_lsb = vec![0usize; num_vars]; + for (new_lsb, be_var_idx) in be_var_by_round.into_iter().enumerate() { + let old_lsb = num_vars - 1 - be_var_idx; + old_lsb_to_new_lsb[old_lsb] = new_lsb; + } + + if old_lsb_to_new_lsb + .iter() + .enumerate() + .all(|(old_lsb, &new_lsb)| old_lsb == new_lsb) + { + return None; + } + Some(old_lsb_to_new_lsb) +} + +fn precommitted_sumcheck_inverse_index_permutation( + coeffs_len: usize, + poly_opening_round_permutation_be: &[usize], +) -> Option> { + let num_vars = poly_opening_round_permutation_be.len(); + assert_eq!( + coeffs_len, + 1usize << num_vars, + "precommitted coeff vector length mismatch: len={} expected=2^{}", + coeffs_len, + num_vars + ); + let old_lsb_to_new_lsb = + precommitted_sumcheck_lsb_permutation(poly_opening_round_permutation_be)?; + + let mut new_lsb_to_old_lsb = vec![0usize; num_vars]; + for (old_lsb, &new_lsb) in old_lsb_to_new_lsb.iter().enumerate() { + new_lsb_to_old_lsb[new_lsb] = old_lsb; + } + + let inverse_permutation: Vec = (0..coeffs_len) + .into_par_iter() + .map(|new_idx| { + let mut old_idx = 0usize; + for new_lsb in 0..num_vars { + let bit = (new_idx >> new_lsb) & 1usize; + let old_lsb = new_lsb_to_old_lsb[new_lsb]; + old_idx |= bit << old_lsb; + } + old_idx + }) + .collect(); + Some(inverse_permutation) +} + +pub const TWO_PHASE_DEGREE_BOUND: usize = 2; + +pub trait PrecomittedParams: SumcheckInstanceParams { + fn is_cycle_phase(&self) -> bool; + fn is_cycle_phase_round(&self, round: usize) -> bool; + fn is_address_phase_round(&self, round: usize) -> bool; + fn cycle_alignment_rounds(&self) -> usize; + fn address_alignment_rounds(&self) -> usize; + fn record_cycle_challenge(&mut self, challenge: F::Challenge); +} + +#[derive(Allocative)] +pub struct PrecomittedProver> { + params: P, + value_poly: MultilinearPolynomial, + eq_poly: MultilinearPolynomial, + scale: F, +} + +impl> PrecomittedProver { + pub fn new( + params: P, + value_poly: MultilinearPolynomial, + eq_poly: MultilinearPolynomial, + ) -> Self { + Self { + params, + value_poly, + eq_poly, + scale: F::one(), + } + } + + pub fn params(&self) -> &P { + &self.params + } + + pub fn params_mut(&mut self) -> &mut P { + &mut self.params + } + + fn compute_message_unscaled(&self, previous_claim_unscaled: F) -> UniPoly { + let half = self.value_poly.len() / 2; + let value_poly = &self.value_poly; + let eq_poly = &self.eq_poly; + let evals: [F; TWO_PHASE_DEGREE_BOUND] = (0..half) + .into_par_iter() + .map(|j| { + let value_evals = value_poly + .sumcheck_evals_array::(j, BindingOrder::LowToHigh); + let eq_evals = eq_poly + .sumcheck_evals_array::(j, BindingOrder::LowToHigh); + + let mut out = [F::zero(); TWO_PHASE_DEGREE_BOUND]; + for i in 0..TWO_PHASE_DEGREE_BOUND { + out[i] = value_evals[i] * eq_evals[i]; + } + out + }) + .reduce( + || [F::zero(); TWO_PHASE_DEGREE_BOUND], + |mut acc, arr| { + acc.iter_mut().zip(arr.iter()).for_each(|(a, b)| *a += *b); + acc + }, + ); + UniPoly::from_evals_and_hint(previous_claim_unscaled, &evals) + } + + pub fn compute_message(&mut self, round: usize, previous_claim: F) -> UniPoly { + let is_active_round = if self.params.is_cycle_phase() { + self.params.is_cycle_phase_round(round) + } else { + self.params.is_address_phase_round(round) + }; + if !is_active_round { + return UniPoly::from_coeff(vec![previous_claim * F::from_u64(2).inverse().unwrap()]); + } + + let trailing_cap = if self.params.is_cycle_phase() { + self.params.cycle_alignment_rounds() + } else { + self.params.address_alignment_rounds() + }; + let num_trailing_variables = trailing_cap.saturating_sub(self.params.num_rounds()); + let scaling_factor = self.scale * F::one().mul_pow_2(num_trailing_variables); + let prev_unscaled = previous_claim * scaling_factor.inverse().unwrap(); + let poly_unscaled = self.compute_message_unscaled(prev_unscaled); + poly_unscaled * scaling_factor + } + + pub fn ingest_challenge(&mut self, r_j: F::Challenge, round: usize) { + let is_active_round = if self.params.is_cycle_phase() { + self.params.is_cycle_phase_round(round) + } else { + self.params.is_address_phase_round(round) + }; + if !is_active_round { + self.scale *= F::from_u64(2).inverse().unwrap(); + return; + } + + self.value_poly.bind_parallel(r_j, BindingOrder::LowToHigh); + self.eq_poly.bind_parallel(r_j, BindingOrder::LowToHigh); + if self.params.is_cycle_phase() { + self.params.record_cycle_challenge(r_j); + } + } + + pub fn cycle_intermediate_claim(&self) -> F { + let len = self.value_poly.len(); + assert_eq!(len, self.eq_poly.len()); + + let mut sum = F::zero(); + for i in 0..len { + sum += self.value_poly.get_bound_coeff(i) * self.eq_poly.get_bound_coeff(i); + } + sum * self.scale + } + + pub fn final_claim_if_ready(&self) -> Option { + if self.value_poly.len() == 1 { + Some(self.value_poly.get_bound_coeff(0)) + } else { + None + } + } +} + +pub fn precommitted_skip_round_scale( + precommitted: &PrecommittedClaimReduction, +) -> F { + let cycle_gap_len = + precommitted.cycle_phase_total_rounds - precommitted.cycle_phase_rounds.len(); + let address_gap_len = + precommitted.address_phase_total_rounds - precommitted.address_phase_rounds.len(); + let gap_len = cycle_gap_len + address_gap_len; + let two_inv = F::from_u64(2).inverse().unwrap(); + (0..gap_len).fold(F::one(), |acc, _| acc * two_inv) +} diff --git a/jolt-core/src/zkvm/claim_reductions/program_image.rs b/jolt-core/src/zkvm/claim_reductions/program_image.rs new file mode 100644 index 000000000..9e2340b07 --- /dev/null +++ b/jolt-core/src/zkvm/claim_reductions/program_image.rs @@ -0,0 +1,541 @@ +//! Program-image (initial RAM) claim reduction. +//! +//! In committed bytecode mode, Stage 4 consumes prover-supplied scalar claims for the +//! program-image contribution to `Val_init(r_address)` without materializing the initial RAM. +//! This sumcheck binds those scalars to a trusted commitment to the program-image words polynomial. + +use allocative::Allocative; +use std::cell::RefCell; + +use crate::field::JoltField; +use crate::poly::commitment::dory::DoryGlobals; +use crate::poly::eq_poly::EqPolynomial; +use crate::poly::multilinear_polynomial::{MultilinearPolynomial, PolynomialEvaluation}; +#[cfg(feature = "zk")] +use crate::poly::opening_proof::OpeningId; +use crate::poly::opening_proof::{ + AbstractVerifierOpeningAccumulator, OpeningAccumulator, OpeningPoint, ProverOpeningAccumulator, + SumcheckId, BIG_ENDIAN, LITTLE_ENDIAN, +}; +use crate::poly::unipoly::UniPoly; +#[cfg(feature = "zk")] +use crate::subprotocols::blindfold::{InputClaimConstraint, OutputClaimConstraint, ValueSource}; +use crate::subprotocols::sumcheck_prover::SumcheckInstanceProver; +use crate::subprotocols::sumcheck_verifier::{SumcheckInstanceParams, SumcheckInstanceVerifier}; +use crate::transcripts::Transcript; +use crate::utils::math::Math; +use crate::zkvm::claim_reductions::{ + permute_precommitted_polys, precommitted_skip_round_scale, PrecomittedParams, + PrecomittedProver, PrecommittedClaimReduction, PrecommittedPhase, + PrecommittedSchedulingReference, TWO_PHASE_DEGREE_BOUND, +}; +use crate::zkvm::ram::remap_address; +use crate::zkvm::witness::{CommittedPolynomial, VirtualPolynomial}; +use tracer::JoltDevice; + +#[derive(Clone, Allocative)] +pub struct ProgramImageClaimReductionParams { + pub phase: PrecommittedPhase, + pub precommitted: PrecommittedClaimReduction, + pub prog_col_vars: usize, + pub prog_row_vars: usize, + pub ram_num_vars: usize, + pub start_index: usize, + pub padded_len_words: usize, + pub m: usize, + pub r_addr_rw: Vec, + pub shifted_eq_coeffs: Vec, +} + +impl ProgramImageClaimReductionParams { + pub fn num_address_phase_rounds(&self) -> usize { + self.precommitted.num_address_phase_rounds() + } + + #[allow(clippy::too_many_arguments)] + pub fn new( + program_io: &JoltDevice, + ram_min_bytecode_address: u64, + padded_len_words: usize, + ram_K: usize, + scheduling_reference: PrecommittedSchedulingReference, + accumulator: &dyn OpeningAccumulator, + _transcript: &mut impl Transcript, + ) -> Self { + let ram_num_vars = ram_K.log_2(); + let start_index = + remap_address(ram_min_bytecode_address, &program_io.memory_layout).unwrap() as usize; + let m = padded_len_words.log_2(); + debug_assert!(padded_len_words.is_power_of_two()); + debug_assert!(padded_len_words > 0); + let (prog_col_vars, prog_row_vars) = DoryGlobals::balanced_sigma_nu(m); + let precommitted = + PrecommittedClaimReduction::new(prog_row_vars, prog_col_vars, scheduling_reference); + + let (r_rw, _) = accumulator.get_virtual_polynomial_opening( + VirtualPolynomial::RamVal, + SumcheckId::RamReadWriteChecking, + ); + let (r_addr_rw, _) = r_rw.split_at(ram_num_vars); + let shifted_eq_coeffs = + shifted_program_image_eq_slice::(&r_addr_rw.r, start_index, padded_len_words); + + Self { + phase: PrecommittedPhase::CycleVariables, + precommitted, + prog_col_vars, + prog_row_vars, + ram_num_vars, + start_index, + padded_len_words, + m, + r_addr_rw: r_addr_rw.r, + shifted_eq_coeffs, + } + } +} + +impl ProgramImageClaimReductionParams { + fn is_cycle_phase(&self) -> bool { + self.phase == PrecommittedPhase::CycleVariables + } + + pub fn transition_to_address_phase(&mut self) { + self.phase = PrecommittedPhase::AddressVariables; + } + + pub fn round_offset(&self, max_num_rounds: usize) -> usize { + self.precommitted + .round_offset(self.is_cycle_phase(), max_num_rounds) + } +} + +impl SumcheckInstanceParams for ProgramImageClaimReductionParams { + fn input_claim(&self, accumulator: &dyn OpeningAccumulator) -> F { + match self.phase { + PrecommittedPhase::CycleVariables => { + // Scalar claims were staged in Stage 4 as virtual openings. + accumulator + .get_virtual_polynomial_opening( + VirtualPolynomial::ProgramImageInitContributionRw, + SumcheckId::RamValCheck, + ) + .1 + } + PrecommittedPhase::AddressVariables => { + accumulator + .get_committed_polynomial_opening( + CommittedPolynomial::ProgramImageInit, + SumcheckId::ProgramImageClaimReductionCyclePhase, + ) + .1 + } + } + } + + fn degree(&self) -> usize { + TWO_PHASE_DEGREE_BOUND + } + + fn num_rounds(&self) -> usize { + self.precommitted + .num_rounds_for_phase(self.is_cycle_phase()) + } + + fn normalize_opening_point(&self, challenges: &[F::Challenge]) -> OpeningPoint { + self.precommitted + .normalize_opening_point(self.is_cycle_phase(), challenges) + } + + #[cfg(feature = "zk")] + fn input_claim_constraint(&self) -> InputClaimConstraint { + match self.phase { + PrecommittedPhase::CycleVariables => InputClaimConstraint::direct(OpeningId::virt( + VirtualPolynomial::ProgramImageInitContributionRw, + SumcheckId::RamValCheck, + )), + PrecommittedPhase::AddressVariables => { + InputClaimConstraint::direct(OpeningId::committed( + CommittedPolynomial::ProgramImageInit, + SumcheckId::ProgramImageClaimReductionCyclePhase, + )) + } + } + } + + #[cfg(feature = "zk")] + fn input_constraint_challenge_values(&self, _: &dyn OpeningAccumulator) -> Vec { + Vec::new() + } + + #[cfg(feature = "zk")] + fn output_claim_constraint(&self) -> Option { + match self.phase { + PrecommittedPhase::CycleVariables => { + Some(OutputClaimConstraint::direct(OpeningId::committed( + CommittedPolynomial::ProgramImageInit, + SumcheckId::ProgramImageClaimReductionCyclePhase, + ))) + } + PrecommittedPhase::AddressVariables => Some(OutputClaimConstraint::linear(vec![( + ValueSource::Challenge(0), + ValueSource::Opening(OpeningId::committed( + CommittedPolynomial::ProgramImageInit, + SumcheckId::ProgramImageClaimReduction, + )), + )])), + } + } + + #[cfg(feature = "zk")] + fn output_constraint_challenge_values(&self, sumcheck_challenges: &[F::Challenge]) -> Vec { + match self.phase { + PrecommittedPhase::CycleVariables => vec![], + PrecommittedPhase::AddressVariables => { + let opening_point = self.normalize_opening_point(sumcheck_challenges); + let eq_combined = eval_shifted_eq_poly_at_opening_point::( + &self.r_addr_rw, + self.start_index, + &opening_point.r, + ); + debug_assert_eq!( + eq_combined, + evaluate_shifted_eq_poly::(&self.shifted_eq_coeffs, &opening_point.r), + "program_image eq_slice optimized evaluation mismatch" + ); + let scale: F = precommitted_skip_round_scale(&self.precommitted); + vec![eq_combined * scale] + } + } + } +} + +impl PrecomittedParams for ProgramImageClaimReductionParams { + fn is_cycle_phase(&self) -> bool { + self.phase == PrecommittedPhase::CycleVariables + } + + fn is_cycle_phase_round(&self, round: usize) -> bool { + self.precommitted.is_cycle_phase_round(round) + } + + fn is_address_phase_round(&self, round: usize) -> bool { + self.precommitted.is_address_phase_round(round) + } + + fn cycle_alignment_rounds(&self) -> usize { + self.precommitted.cycle_alignment_rounds() + } + + fn address_alignment_rounds(&self) -> usize { + self.precommitted.address_alignment_rounds() + } + + fn record_cycle_challenge(&mut self, challenge: F::Challenge) { + self.precommitted.record_cycle_challenge(challenge); + } +} + +#[derive(Allocative)] +pub struct ProgramImageClaimReductionProver { + core: PrecomittedProver>, +} + +fn shifted_program_image_eq_slice( + r_addr: &[F::Challenge], + start_index: usize, + padded_len_words: usize, +) -> Vec +where + F: JoltField + std::ops::Mul + std::ops::SubAssign, +{ + let mut eq_slice = Vec::with_capacity(padded_len_words); + let mut idx = start_index; + let mut remaining = padded_len_words; + + while remaining > 0 { + let (block_size, block_evals) = + EqPolynomial::::evals_for_max_aligned_block(r_addr, idx, remaining); + eq_slice.extend(block_evals); + idx += block_size; + remaining -= block_size; + } + + eq_slice +} + +fn evaluate_shifted_eq_poly(shifted_eq_coeffs: &[F], opening_point: &[C]) -> F +where + C: Copy + Send + Sync + Into + crate::field::ChallengeFieldOps, + F: JoltField + crate::field::FieldChallengeOps, +{ + MultilinearPolynomial::from(shifted_eq_coeffs.to_vec()).evaluate(opening_point) +} + +impl ProgramImageClaimReductionProver { + pub fn params(&self) -> &ProgramImageClaimReductionParams { + self.core.params() + } + + pub fn transition_to_address_phase(&mut self) { + self.core.params_mut().transition_to_address_phase(); + } + + #[tracing::instrument(skip_all, name = "ProgramImageClaimReductionProver::initialize")] + pub fn initialize( + params: ProgramImageClaimReductionParams, + program_image_words_padded: Vec, + ) -> Self { + debug_assert_eq!(program_image_words_padded.len(), params.padded_len_words); + debug_assert_eq!(params.padded_len_words, 1usize << params.m); + + let eq_slice = permute_precommitted_polys( + vec![params.shifted_eq_coeffs.clone()], + ¶ms.precommitted, + ) + .into_iter() + .next() + .expect("expected one permuted shifted eq polynomial"); + + // Permute ProgramWord and eq_slice so low-to-high binding follows the two-phase + // schedule while preserving top-left projection semantics against the joint point. + let program_word: MultilinearPolynomial = { + let mut permuted = + permute_precommitted_polys(vec![program_image_words_padded], ¶ms.precommitted) + .into_iter(); + permuted + .next() + .expect("expected one permuted program image polynomial") + }; + + Self { + core: PrecomittedProver::new(params, program_word, eq_slice), + } + } +} + +impl SumcheckInstanceProver + for ProgramImageClaimReductionProver +{ + fn get_params(&self) -> &dyn SumcheckInstanceParams { + self.core.params() + } + + fn round_offset(&self, max_num_rounds: usize) -> usize { + self.core.params().round_offset(max_num_rounds) + } + + fn compute_message(&mut self, round: usize, previous_claim: F) -> UniPoly { + self.core.compute_message(round, previous_claim) + } + + fn ingest_challenge(&mut self, r_j: F::Challenge, round: usize) { + self.core.ingest_challenge(r_j, round); + } + + fn cache_openings( + &self, + accumulator: &mut ProverOpeningAccumulator, + sumcheck_challenges: &[F::Challenge], + ) { + let params = self.core.params(); + let opening_point = params.normalize_opening_point(sumcheck_challenges); + if params.phase == PrecommittedPhase::CycleVariables { + accumulator.append_dense( + CommittedPolynomial::ProgramImageInit, + SumcheckId::ProgramImageClaimReductionCyclePhase, + // This is a phase-boundary intermediate claim, not a real program-image opening. + // Keep a sentinel point so it cannot alias with the final opening claim. + vec![], + self.core.cycle_intermediate_claim(), + ); + } + + if let Some(claim) = self.core.final_claim_if_ready() { + accumulator.append_dense( + CommittedPolynomial::ProgramImageInit, + SumcheckId::ProgramImageClaimReduction, + opening_point.r, + claim, + ); + } + } + + #[cfg(feature = "allocative")] + fn update_flamegraph(&self, flamegraph: &mut allocative::FlameGraphBuilder) { + flamegraph.visit_root(self); + } +} + +pub struct ProgramImageClaimReductionVerifier { + pub params: RefCell>, +} + +impl ProgramImageClaimReductionVerifier { + pub fn new(params: ProgramImageClaimReductionParams) -> Self { + Self { + params: RefCell::new(params), + } + } +} + +impl> + SumcheckInstanceVerifier for ProgramImageClaimReductionVerifier +{ + fn get_params(&self) -> &dyn SumcheckInstanceParams { + unsafe { &*self.params.as_ptr() } + } + + fn round_offset(&self, max_num_rounds: usize) -> usize { + let params = self.params.borrow(); + params.round_offset(max_num_rounds) + } + + fn expected_output_claim(&self, accumulator: &A, sumcheck_challenges: &[F::Challenge]) -> F { + let params = self.params.borrow(); + match params.phase { + PrecommittedPhase::CycleVariables => { + accumulator + .get_committed_polynomial_opening( + CommittedPolynomial::ProgramImageInit, + SumcheckId::ProgramImageClaimReductionCyclePhase, + ) + .1 + } + PrecommittedPhase::AddressVariables => { + let opening_point = params.normalize_opening_point(sumcheck_challenges); + debug_assert_eq!(opening_point.len(), params.m); + let pw_eval = accumulator + .get_committed_polynomial_opening( + CommittedPolynomial::ProgramImageInit, + SumcheckId::ProgramImageClaimReduction, + ) + .1; + let eq_combined = eval_shifted_eq_poly_at_opening_point::( + ¶ms.r_addr_rw, + params.start_index, + &opening_point.r, + ); + debug_assert_eq!( + eq_combined, + evaluate_shifted_eq_poly::(¶ms.shifted_eq_coeffs, &opening_point.r), + "program_image eq_slice optimized evaluation mismatch" + ); + let scale: F = precommitted_skip_round_scale(¶ms.precommitted); + pw_eval * eq_combined * scale + } + } + } + + fn cache_openings(&self, accumulator: &mut A, sumcheck_challenges: &[F::Challenge]) { + let mut params = self.params.borrow_mut(); + let opening_point = params.normalize_opening_point(sumcheck_challenges); + if params.phase == PrecommittedPhase::CycleVariables { + accumulator.append_dense( + CommittedPolynomial::ProgramImageInit, + SumcheckId::ProgramImageClaimReductionCyclePhase, + // Match prover behavior: the cycle-phase intermediate claim is not a real opening. + vec![], + ); + let opening_point_le: OpeningPoint = opening_point.match_endianness(); + params + .precommitted + .set_cycle_var_challenges(opening_point_le.r); + } + + if params.phase == PrecommittedPhase::AddressVariables + || params.num_address_phase_rounds() == 0 + { + accumulator.append_dense( + CommittedPolynomial::ProgramImageInit, + SumcheckId::ProgramImageClaimReduction, + opening_point.r, + ); + } + } +} + +fn eval_shifted_eq_poly_at_opening_point( + r_addr_be: &[F::Challenge], + start_index: usize, + opening_point_be: &[F::Challenge], +) -> F +where + F: JoltField, +{ + let ell = r_addr_be.len(); + let m = opening_point_be.len(); + debug_assert!(m <= ell); + + let challenge_for_old_lsb = |old_lsb: usize| -> F { + debug_assert!(old_lsb < m); + opening_point_be[m - 1 - old_lsb].into() + }; + + // Match the current verifier path exactly: `opening_point_be` is already arranged in the + // variable order expected by `evaluate_shifted_eq_poly`. + let mut dp0 = F::one(); + let mut dp1 = F::zero(); + + for old_lsb in 0..ell { + let start_bit = ((start_index >> old_lsb) & 1) as u8; + let r_addr_bit: F = r_addr_be[ell - 1 - old_lsb].into(); + let k0 = F::one() - r_addr_bit; + let k1 = r_addr_bit; + let y_var = old_lsb < m; + let r_y = if y_var { + challenge_for_old_lsb(old_lsb) + } else { + F::zero() + }; + + let mut next_dp0 = F::zero(); + let mut next_dp1 = F::zero(); + + let update_state = |weight: F, carry: u8, next_dp0: &mut F, next_dp1: &mut F| { + if weight.is_zero() { + return; + } + + if y_var { + let sum0 = start_bit + carry; + let k_bit0 = sum0 & 1; + let carry0 = (sum0 >> 1) & 1; + let addr_factor0 = if k_bit0 == 1 { k1 } else { k0 }; + let y_factor0 = F::one() - r_y; + if carry0 == 0 { + *next_dp0 += weight * addr_factor0 * y_factor0; + } else { + *next_dp1 += weight * addr_factor0 * y_factor0; + } + + let sum1 = start_bit + carry + 1; + let k_bit1 = sum1 & 1; + let carry1 = (sum1 >> 1) & 1; + let addr_factor1 = if k_bit1 == 1 { k1 } else { k0 }; + if carry1 == 0 { + *next_dp0 += weight * addr_factor1 * r_y; + } else { + *next_dp1 += weight * addr_factor1 * r_y; + } + } else { + let sum0 = start_bit + carry; + let k_bit0 = sum0 & 1; + let carry0 = (sum0 >> 1) & 1; + let addr_factor0 = if k_bit0 == 1 { k1 } else { k0 }; + if carry0 == 0 { + *next_dp0 += weight * addr_factor0; + } else { + *next_dp1 += weight * addr_factor0; + } + } + }; + + update_state(dp0, 0, &mut next_dp0, &mut next_dp1); + update_state(dp1, 1, &mut next_dp0, &mut next_dp1); + dp0 = next_dp0; + dp1 = next_dp1; + } + + dp0 + dp1 +} diff --git a/jolt-core/src/zkvm/config.rs b/jolt-core/src/zkvm/config.rs index 59d6b29d2..02e2c55e6 100644 --- a/jolt-core/src/zkvm/config.rs +++ b/jolt-core/src/zkvm/config.rs @@ -1,5 +1,8 @@ use allocative::Allocative; -use ark_serialize::{CanonicalDeserialize, CanonicalSerialize}; +use ark_serialize::{ + CanonicalDeserialize, CanonicalSerialize, Compress, SerializationError, Valid, Validate, +}; +use std::io::{Read, Write}; use crate::field::JoltField; use crate::utils::math::Math; @@ -20,6 +23,52 @@ pub fn get_instruction_sumcheck_phases(log_t: usize) -> usize { } } +/// Controls how bytecode and program-image data are handled by the verifier. +#[repr(u8)] +#[derive(Clone, Copy, Debug, PartialEq, Eq, Allocative, Default)] +pub enum ProgramMode { + /// Verifier has full bytecode and program image available. + #[default] + Full = 0, + /// Verifier uses commitments for bytecode/program-image openings in Stage 8. + Committed = 1, +} + +impl CanonicalSerialize for ProgramMode { + fn serialize_with_mode( + &self, + writer: W, + compress: Compress, + ) -> Result<(), SerializationError> { + (*self as u8).serialize_with_mode(writer, compress) + } + + fn serialized_size(&self, compress: Compress) -> usize { + (*self as u8).serialized_size(compress) + } +} + +impl Valid for ProgramMode { + fn check(&self) -> Result<(), SerializationError> { + Ok(()) + } +} + +impl CanonicalDeserialize for ProgramMode { + fn deserialize_with_mode( + reader: R, + compress: Compress, + validate: Validate, + ) -> Result { + let value = u8::deserialize_with_mode(reader, compress, validate)?; + match value { + 0 => Ok(Self::Full), + 1 => Ok(Self::Committed), + _ => Err(SerializationError::InvalidData), + } + } +} + /// Configuration for read-write checking sumchecks. /// /// Contains parameters that control phase structure for RAM and register @@ -91,15 +140,6 @@ impl ReadWriteConfig { } Ok(()) } - - /// Returns true if all cycle variables are bound in phase 1. - /// - /// When this returns true, the advice opening points for `RamValCheck` and - /// `RamValCheck` are identical, so we only need one advice opening. - #[inline] - pub fn needs_single_advice_opening(&self, log_T: usize) -> bool { - self.ram_rw_phase1_num_rounds as usize == log_T - } } /// Minimal configuration for one-hot encoding that gets serialized in the proof. diff --git a/jolt-core/src/zkvm/mod.rs b/jolt-core/src/zkvm/mod.rs index 42564eb9c..a8bd46700 100644 --- a/jolt-core/src/zkvm/mod.rs +++ b/jolt-core/src/zkvm/mod.rs @@ -1,6 +1,6 @@ use std::fs::File; -use crate::zkvm::config::{OneHotConfig, OneHotParams, ReadWriteConfig}; +use crate::zkvm::config::{OneHotConfig, OneHotParams, ProgramMode, ReadWriteConfig}; use crate::zkvm::witness::CommittedPolynomial; use crate::{ curve::Bn254Curve, @@ -52,6 +52,7 @@ pub mod config; pub mod instruction; pub mod instruction_lookups; pub mod lookup_table; +pub mod program; pub mod proof_serialization; #[cfg(feature = "prover")] pub mod prover; @@ -67,6 +68,8 @@ pub(crate) fn stage8_opening_ids( one_hot_params: &OneHotParams, include_trusted_advice: bool, include_untrusted_advice: bool, + program_mode: ProgramMode, + bytecode_chunk_count: usize, ) -> Vec { let mut opening_ids = Vec::new(); @@ -104,6 +107,20 @@ pub(crate) fn stage8_opening_ids( if include_untrusted_advice { opening_ids.push(OpeningId::UntrustedAdvice(SumcheckId::AdviceClaimReduction)); } + if program_mode == ProgramMode::Committed { + for i in 0..bytecode_chunk_count { + opening_ids.push(OpeningId::committed( + CommittedPolynomial::BytecodeChunk(i), + SumcheckId::BytecodeClaimReduction, + )); + } + } + if program_mode == ProgramMode::Committed { + opening_ids.push(OpeningId::committed( + CommittedPolynomial::ProgramImageInit, + SumcheckId::ProgramImageClaimReduction, + )); + } opening_ids } diff --git a/jolt-core/src/zkvm/program.rs b/jolt-core/src/zkvm/program.rs new file mode 100644 index 000000000..225277b89 --- /dev/null +++ b/jolt-core/src/zkvm/program.rs @@ -0,0 +1,495 @@ +use std::io::{Read, Write}; + +use ark_serialize::{ + CanonicalDeserialize, CanonicalSerialize, Compress, SerializationError, Valid, Validate, +}; + +use crate::poly::commitment::commitment_scheme::CommitmentScheme; +use crate::poly::commitment::dory::{DoryContext, DoryGlobals}; +use crate::poly::multilinear_polynomial::MultilinearPolynomial; +use crate::utils::errors::ProofVerifyError; +use crate::utils::math::Math; +use crate::zkvm::bytecode::{ + BytecodePreprocessing, PreprocessingError, TrustedBytecodeCommitments, TrustedBytecodeHints, +}; +use crate::zkvm::ram::RAMPreprocessing; +use common::jolt_device::MemoryLayout; +use tracer::instruction::{Cycle, Instruction}; + +#[derive(Debug, Clone, CanonicalSerialize, CanonicalDeserialize)] +pub struct FullProgramPreprocessing { + pub bytecode: BytecodePreprocessing, + pub ram: RAMPreprocessing, +} + +impl Default for FullProgramPreprocessing { + fn default() -> Self { + Self { + bytecode: BytecodePreprocessing::default(), + ram: RAMPreprocessing { + min_bytecode_address: 0, + bytecode_words: Vec::new(), + }, + } + } +} + +impl FullProgramPreprocessing { + #[tracing::instrument(skip_all, name = "ProgramPreprocessing::preprocess")] + pub fn preprocess( + instructions: Vec, + memory_init: Vec<(u64, u8)>, + ) -> Result { + let entry_address = instructions + .first() + .map(|instr| instr.normalize().address as u64) + .unwrap_or(0); + Ok(Self { + bytecode: BytecodePreprocessing::preprocess(instructions, entry_address)?, + ram: RAMPreprocessing::preprocess(memory_init), + }) + } + + pub fn bytecode_len(&self) -> usize { + self.bytecode.code_size + } + + pub fn program_image_len_words(&self) -> usize { + self.ram.bytecode_words.len() + } + + pub fn program_image_len_words_padded(&self) -> usize { + self.program_image_len_words().next_power_of_two().max(2) + } + + pub fn committed_program_image_num_words(&self, memory_layout: &MemoryLayout) -> usize { + self.meta().committed_program_image_num_words(memory_layout) + } + + pub fn meta(&self) -> ProgramMetadata { + ProgramMetadata { + entry_address: self.bytecode.entry_address, + min_bytecode_address: self.ram.min_bytecode_address, + program_image_len_words: self.program_image_len_words(), + bytecode_len: self.bytecode_len(), + } + } + + #[inline(always)] + pub fn get_pc(&self, cycle: &Cycle) -> usize { + self.bytecode.get_pc(cycle) + } + + #[inline(always)] + pub fn entry_bytecode_index(&self) -> usize { + self.bytecode.entry_bytecode_index() + } +} + +#[derive(Debug, Clone)] +pub struct CommittedProgramPreprocessing { + pub meta: ProgramMetadata, + pub bytecode_commitments: TrustedBytecodeCommitments, + pub program_commitments: TrustedProgramCommitments, + #[cfg(feature = "prover")] + pub prover_data: Option>, +} + +#[cfg(feature = "prover")] +#[derive(Debug, Clone)] +pub struct CommittedProgramProverData { + pub full: FullProgramPreprocessing, + pub bytecode_hints: TrustedBytecodeHints, + pub program_hints: TrustedProgramHints, +} + +#[derive(Debug, Clone)] +pub enum ProgramPreprocessing< + PCS: CommitmentScheme = crate::poly::commitment::dory::DoryCommitmentScheme, +> { + Full(FullProgramPreprocessing), + Committed(CommittedProgramPreprocessing), +} + +impl CanonicalSerialize for ProgramPreprocessing +where + PCS::Commitment: CanonicalSerialize, +{ + fn serialize_with_mode( + &self, + mut writer: W, + compress: Compress, + ) -> Result<(), SerializationError> { + match self { + Self::Full(full) => { + 0u8.serialize_with_mode(&mut writer, compress)?; + full.serialize_with_mode(&mut writer, compress)?; + } + Self::Committed(committed) => { + 1u8.serialize_with_mode(&mut writer, compress)?; + committed.meta.serialize_with_mode(&mut writer, compress)?; + committed + .bytecode_commitments + .serialize_with_mode(&mut writer, compress)?; + committed + .program_commitments + .serialize_with_mode(&mut writer, compress)?; + } + } + Ok(()) + } + + fn serialized_size(&self, compress: Compress) -> usize { + 1 + match self { + Self::Full(full) => full.serialized_size(compress), + Self::Committed(committed) => { + committed.meta.serialized_size(compress) + + committed.bytecode_commitments.serialized_size(compress) + + committed.program_commitments.serialized_size(compress) + } + } + } +} + +impl Valid for ProgramPreprocessing +where + PCS::Commitment: Valid, +{ + fn check(&self) -> Result<(), SerializationError> { + match self { + Self::Full(full) => full.check(), + Self::Committed(committed) => { + committed.meta.check()?; + committed.bytecode_commitments.check()?; + committed.program_commitments.check()?; + Ok(()) + } + } + } +} + +impl CanonicalDeserialize for ProgramPreprocessing +where + PCS::Commitment: CanonicalDeserialize, +{ + fn deserialize_with_mode( + mut reader: R, + compress: Compress, + validate: Validate, + ) -> Result { + let tag = u8::deserialize_with_mode(&mut reader, compress, validate)?; + match tag { + 0 => Ok(Self::Full(FullProgramPreprocessing::deserialize_with_mode( + &mut reader, + compress, + validate, + )?)), + 1 => Ok(Self::Committed(CommittedProgramPreprocessing { + meta: ProgramMetadata::deserialize_with_mode(&mut reader, compress, validate)?, + bytecode_commitments: TrustedBytecodeCommitments::::deserialize_with_mode( + &mut reader, + compress, + validate, + )?, + program_commitments: TrustedProgramCommitments::::deserialize_with_mode( + &mut reader, + compress, + validate, + )?, + #[cfg(feature = "prover")] + prover_data: None, + })), + _ => Err(SerializationError::InvalidData), + } + } +} + +impl Default for ProgramPreprocessing { + fn default() -> Self { + Self::Full(FullProgramPreprocessing::default()) + } +} + +impl ProgramPreprocessing { + #[tracing::instrument(skip_all, name = "ProgramPreprocessing::preprocess")] + pub fn preprocess( + instructions: Vec, + memory_init: Vec<(u64, u8)>, + ) -> Result { + Ok(Self::Full(FullProgramPreprocessing::preprocess( + instructions, + memory_init, + )?)) + } + + pub fn commit( + self, + memory_layout: &MemoryLayout, + generators: &PCS::ProverSetup, + bytecode_chunk_count: usize, + max_log_k_chunk: usize, + ) -> Self { + let full = match self { + Self::Full(full) => full, + Self::Committed(_committed) => { + #[cfg(feature = "prover")] + { + _committed + .prover_data + .expect("committed prover data missing during recommit") + .full + } + #[cfg(not(feature = "prover"))] + { + panic!("cannot commit already-committed verifier preprocessing") + } + } + }; + let meta = full.meta(); + #[cfg(feature = "prover")] + let (bytecode_commitments, bytecode_hints) = TrustedBytecodeCommitments::derive( + &full.bytecode, + generators, + max_log_k_chunk, + bytecode_chunk_count, + ); + #[cfg(not(feature = "prover"))] + let (bytecode_commitments, _bytecode_hints) = TrustedBytecodeCommitments::derive( + &full.bytecode, + generators, + max_log_k_chunk, + bytecode_chunk_count, + ); + #[cfg(feature = "prover")] + let (program_commitments, program_hints) = + TrustedProgramCommitments::derive(&full, memory_layout, generators); + #[cfg(not(feature = "prover"))] + let (program_commitments, _program_hints) = + TrustedProgramCommitments::derive(&full, memory_layout, generators); + Self::Committed(CommittedProgramPreprocessing { + meta, + bytecode_commitments, + program_commitments, + #[cfg(feature = "prover")] + prover_data: Some(CommittedProgramProverData { + full, + bytecode_hints, + program_hints, + }), + }) + } + + pub fn full(&self) -> Option<&FullProgramPreprocessing> { + match self { + Self::Full(full) => Some(full), + Self::Committed(_committed) => { + #[cfg(feature = "prover")] + { + _committed.prover_data.as_ref().map(|data| &data.full) + } + #[cfg(not(feature = "prover"))] + { + None + } + } + } + } + + pub fn as_full(&self) -> Result<&FullProgramPreprocessing, ProofVerifyError> { + self.full().ok_or_else(|| { + ProofVerifyError::BytecodeTypeMismatch( + "full program preprocessing unavailable in committed mode".to_string(), + ) + }) + } + + pub fn is_full(&self) -> bool { + matches!(self, Self::Full(_)) + } + + pub fn is_committed(&self) -> bool { + matches!(self, Self::Committed(_)) + } + + pub fn bytecode_commitments(&self) -> Option<&TrustedBytecodeCommitments> { + match self { + Self::Committed(committed) => Some(&committed.bytecode_commitments), + Self::Full(_) => None, + } + } + + pub fn bytecode_hints(&self) -> Option<&TrustedBytecodeHints> { + match self { + #[cfg(feature = "prover")] + Self::Committed(committed) => committed + .prover_data + .as_ref() + .map(|data| &data.bytecode_hints), + #[cfg(not(feature = "prover"))] + Self::Committed(_) => None, + Self::Full(_) => None, + } + } + + pub fn program_commitments(&self) -> Option<&TrustedProgramCommitments> { + match self { + Self::Committed(committed) => Some(&committed.program_commitments), + Self::Full(_) => None, + } + } + + pub fn program_hints(&self) -> Option<&TrustedProgramHints> { + match self { + #[cfg(feature = "prover")] + Self::Committed(committed) => committed + .prover_data + .as_ref() + .map(|data| &data.program_hints), + #[cfg(not(feature = "prover"))] + Self::Committed(_) => None, + Self::Full(_) => None, + } + } + + pub fn as_committed(&self) -> Result<&TrustedProgramCommitments, ProofVerifyError> { + self.program_commitments().ok_or_else(|| { + ProofVerifyError::BytecodeTypeMismatch("expected Committed, got Full".to_string()) + }) + } + + pub fn bytecode_len(&self) -> usize { + match self { + Self::Full(full) => full.bytecode_len(), + Self::Committed(committed) => committed.meta.bytecode_len, + } + } + + pub fn program_image_len_words(&self) -> usize { + match self { + Self::Full(full) => full.program_image_len_words(), + Self::Committed(committed) => committed.meta.program_image_len_words, + } + } + + pub fn program_image_len_words_padded(&self) -> usize { + self.program_image_len_words().next_power_of_two().max(2) + } + + pub fn committed_program_image_num_words(&self, memory_layout: &MemoryLayout) -> usize { + self.meta().committed_program_image_num_words(memory_layout) + } + + pub fn meta(&self) -> ProgramMetadata { + match self { + Self::Full(full) => full.meta(), + Self::Committed(committed) => committed.meta.clone(), + } + } + + #[inline(always)] + pub fn get_pc(&self, cycle: &Cycle) -> usize { + self.as_full() + .expect("full program preprocessing required to compute PC") + .get_pc(cycle) + } + + #[inline(always)] + pub fn entry_bytecode_index(&self) -> usize { + self.as_full() + .expect("full program preprocessing required to compute entry bytecode index") + .entry_bytecode_index() + } + + pub fn to_verifier_program(&self) -> Self { + match self { + Self::Full(full) => Self::Full(full.clone()), + Self::Committed(committed) => Self::Committed(CommittedProgramPreprocessing { + meta: committed.meta.clone(), + bytecode_commitments: committed.bytecode_commitments.clone(), + program_commitments: committed.program_commitments.clone(), + #[cfg(feature = "prover")] + prover_data: None, + }), + } + } +} + +#[derive(Debug, Clone, CanonicalSerialize, CanonicalDeserialize)] +pub struct ProgramMetadata { + pub entry_address: u64, + pub min_bytecode_address: u64, + pub program_image_len_words: usize, + pub bytecode_len: usize, +} + +impl ProgramMetadata { + pub fn program_image_len_words_padded(&self) -> usize { + self.program_image_len_words.next_power_of_two().max(2) + } + + pub fn committed_program_image_num_words(&self, _memory_layout: &MemoryLayout) -> usize { + self.program_image_len_words_padded() + } +} + +#[derive(Clone, Debug, PartialEq, CanonicalSerialize, CanonicalDeserialize)] +pub struct TrustedProgramCommitments { + pub program_image_commitment: PCS::Commitment, + pub program_image_num_columns: usize, + pub program_image_num_words: usize, +} + +#[derive(Clone, Debug)] +pub struct TrustedProgramHints { + pub program_image_hint: PCS::OpeningProofHint, +} + +impl TrustedProgramCommitments { + #[tracing::instrument(skip_all, name = "TrustedProgramCommitments::derive")] + pub fn derive( + program: &FullProgramPreprocessing, + memory_layout: &MemoryLayout, + generators: &PCS::ProverSetup, + ) -> (Self, TrustedProgramHints) { + let program_image_num_words = program.committed_program_image_num_words(memory_layout); + let (program_image_sigma, _) = + crate::poly::commitment::dory::DoryGlobals::balanced_sigma_nu( + program_image_num_words.log_2(), + ); + let program_image_num_columns = 1usize << program_image_sigma; + let program_image_poly = MultilinearPolynomial::from(build_program_image_words_padded( + program, + program_image_num_words, + )); + let _program_image_guard = DoryGlobals::initialize_context( + 1, + program_image_num_words, + DoryContext::UntrustedAdvice, + None, + ); + let (program_image_commitment, program_image_hint) = { + let _ctx = DoryGlobals::with_context(DoryContext::UntrustedAdvice); + PCS::commit(&program_image_poly, generators) + }; + + ( + Self { + program_image_commitment, + program_image_num_columns, + program_image_num_words, + }, + TrustedProgramHints { program_image_hint }, + ) + } +} + +pub(crate) fn build_program_image_words_padded( + program: &FullProgramPreprocessing, + padded_len: usize, +) -> Vec { + debug_assert!(padded_len.is_power_of_two()); + debug_assert!(padded_len >= program.ram.bytecode_words.len().max(1)); + let mut coeffs = vec![0u64; padded_len]; + coeffs[..program.ram.bytecode_words.len()].copy_from_slice(&program.ram.bytecode_words); + coeffs +} diff --git a/jolt-core/src/zkvm/proof_serialization.rs b/jolt-core/src/zkvm/proof_serialization.rs index 2dea06bd7..cf519e813 100644 --- a/jolt-core/src/zkvm/proof_serialization.rs +++ b/jolt-core/src/zkvm/proof_serialization.rs @@ -48,7 +48,8 @@ pub struct JoltProof< pub stage3_sumcheck_proof: SumcheckInstanceProof, pub stage4_sumcheck_proof: SumcheckInstanceProof, pub stage5_sumcheck_proof: SumcheckInstanceProof, - pub stage6_sumcheck_proof: SumcheckInstanceProof, + pub stage6a_sumcheck_proof: SumcheckInstanceProof, + pub stage6b_sumcheck_proof: SumcheckInstanceProof, pub stage7_sumcheck_proof: SumcheckInstanceProof, #[cfg(feature = "zk")] pub blindfold_proof: BlindFoldProof, @@ -77,7 +78,8 @@ impl, PCS: CommitmentScheme, FS: Tr && self.stage3_sumcheck_proof.is_zk() == zk_mode && self.stage4_sumcheck_proof.is_zk() == zk_mode && self.stage5_sumcheck_proof.is_zk() == zk_mode - && self.stage6_sumcheck_proof.is_zk() == zk_mode + && self.stage6a_sumcheck_proof.is_zk() == zk_mode + && self.stage6b_sumcheck_proof.is_zk() == zk_mode && self.stage7_sumcheck_proof.is_zk() == zk_mode; if !consistent { @@ -298,13 +300,25 @@ impl CanonicalSerialize for CommittedPolynomial { } Self::TrustedAdvice => 5u8.serialize_with_mode(writer, compress), Self::UntrustedAdvice => 6u8.serialize_with_mode(writer, compress), + Self::BytecodeChunk(i) => { + 7u8.serialize_with_mode(&mut writer, compress)?; + (u8::try_from(*i).unwrap()).serialize_with_mode(writer, compress) + } + Self::ProgramImageInit => 8u8.serialize_with_mode(writer, compress), } } fn serialized_size(&self, _compress: Compress) -> usize { match self { - Self::RdInc | Self::RamInc | Self::TrustedAdvice | Self::UntrustedAdvice => 1, - Self::InstructionRa(_) | Self::BytecodeRa(_) | Self::RamRa(_) => 2, + Self::RdInc + | Self::RamInc + | Self::TrustedAdvice + | Self::UntrustedAdvice + | Self::ProgramImageInit => 1, + Self::InstructionRa(_) + | Self::BytecodeRa(_) + | Self::RamRa(_) + | Self::BytecodeChunk(_) => 2, } } } @@ -339,6 +353,11 @@ impl CanonicalDeserialize for CommittedPolynomial { } 5 => Self::TrustedAdvice, 6 => Self::UntrustedAdvice, + 7 => { + let i = u8::deserialize_with_mode(reader, compress, validate)?; + Self::BytecodeChunk(i as usize) + } + 8 => Self::ProgramImageInit, _ => return Err(SerializationError::InvalidData), }, ) @@ -403,6 +422,19 @@ impl CanonicalSerialize for VirtualPolynomial { 38u8.serialize_with_mode(&mut writer, compress)?; (u8::try_from(*flag).unwrap()).serialize_with_mode(&mut writer, compress) } + Self::BytecodeReadRafAddrClaim => 39u8.serialize_with_mode(&mut writer, compress), + Self::BooleanityAddrClaim => 40u8.serialize_with_mode(&mut writer, compress), + Self::BytecodeValStage(i) => { + 41u8.serialize_with_mode(&mut writer, compress)?; + (u8::try_from(*i).unwrap()).serialize_with_mode(&mut writer, compress) + } + Self::BytecodeClaimReductionIntermediate => { + 42u8.serialize_with_mode(&mut writer, compress) + } + Self::ProgramImageInitContributionRw => 43u8.serialize_with_mode(&mut writer, compress), + Self::ProgramImageInitContributionRaf => { + 44u8.serialize_with_mode(&mut writer, compress) + } } } @@ -442,11 +474,17 @@ impl CanonicalSerialize for VirtualPolynomial { | Self::RamValInit | Self::RamValFinal | Self::RamHammingWeight - | Self::UnivariateSkip => 1, + | Self::UnivariateSkip + | Self::BytecodeReadRafAddrClaim + | Self::BooleanityAddrClaim + | Self::BytecodeClaimReductionIntermediate + | Self::ProgramImageInitContributionRw + | Self::ProgramImageInitContributionRaf => 1, Self::InstructionRa(_) | Self::OpFlags(_) | Self::InstructionFlags(_) - | Self::LookupTableFlag(_) => 2, + | Self::LookupTableFlag(_) + | Self::BytecodeValStage(_) => 2, } } } @@ -520,6 +558,15 @@ impl CanonicalDeserialize for VirtualPolynomial { let flag = u8::deserialize_with_mode(&mut reader, compress, validate)?; Self::LookupTableFlag(flag as usize) } + 39 => Self::BytecodeReadRafAddrClaim, + 40 => Self::BooleanityAddrClaim, + 41 => { + let i = u8::deserialize_with_mode(&mut reader, compress, validate)?; + Self::BytecodeValStage(i as usize) + } + 42 => Self::BytecodeClaimReductionIntermediate, + 43 => Self::ProgramImageInitContributionRw, + 44 => Self::ProgramImageInitContributionRaf, _ => return Err(SerializationError::InvalidData), }, ) diff --git a/jolt-core/src/zkvm/prover.rs b/jolt-core/src/zkvm/prover.rs index b97db2a3e..24f0f0b62 100644 --- a/jolt-core/src/zkvm/prover.rs +++ b/jolt-core/src/zkvm/prover.rs @@ -2,7 +2,6 @@ use crate::poly::opening_proof::OpeningId; #[cfg(feature = "zk")] use crate::zkvm::stage8_opening_ids; -use crate::zkvm::{claim_reductions::advice::ReductionPhase, config::OneHotConfig}; #[cfg(not(target_arch = "wasm32"))] use std::time::Instant; use std::{ @@ -20,7 +19,10 @@ use crate::poly::commitment::dory::bind_opening_inputs_zk; use crate::poly::commitment::dory::DoryContext; use ark_serialize::{CanonicalDeserialize, CanonicalSerialize}; +#[cfg(feature = "zk")] +use crate::zkvm::config::ProgramMode; use crate::zkvm::config::ReadWriteConfig; +use crate::zkvm::program::{build_program_image_words_padded, FullProgramPreprocessing}; use crate::zkvm::ram::remap_address; use crate::zkvm::verifier::JoltSharedPreprocessing; use crate::zkvm::Serializable; @@ -37,17 +39,19 @@ use crate::{ commitment_scheme::{StreamingCommitmentScheme, ZkEvalCommitment}, dory::{DoryGlobals, DoryLayout}, }, - eq_poly::EqPolynomial, multilinear_polynomial::MultilinearPolynomial, opening_proof::{ - compute_advice_lagrange_factor, DoryOpeningState, OpeningAccumulator, - ProverOpeningAccumulator, SumcheckId, + compute_lagrange_factor, DoryOpeningState, OpeningAccumulator, OpeningPoint, + ProverOpeningAccumulator, SumcheckId, BIG_ENDIAN, }, rlc_polynomial::{RLCStreamingData, TraceSource}, }, pprof_scope, subprotocols::{ - booleanity::{BooleanitySumcheckParams, BooleanitySumcheckProver}, + booleanity::{ + BooleanityAddressSumcheckProver, BooleanityCycleSumcheckProver, + BooleanitySumcheckParams, + }, streaming_schedule::LinearOnlySchedule, sumcheck::{BatchedSumcheck, SumcheckInstanceProof}, sumcheck_prover::SumcheckInstanceProver, @@ -56,15 +60,20 @@ use crate::{ transcripts::Transcript, utils::{math::Math, thread::drop_in_background_thread}, zkvm::{ - bytecode::read_raf_checking::BytecodeReadRafSumcheckParams, + bytecode::{ + chunks::{build_committed_bytecode_chunk_coeffs, committed_bytecode_chunk_cycle_len}, + read_raf_checking::BytecodeReadRafSumcheckParams, + }, claim_reductions::{ AdviceClaimReductionParams, AdviceClaimReductionProver, AdviceKind, + BytecodeClaimReductionParams, BytecodeClaimReductionProver, HammingWeightClaimReductionParams, HammingWeightClaimReductionProver, IncClaimReductionSumcheckParams, IncClaimReductionSumcheckProver, InstructionLookupsClaimReductionSumcheckParams, - InstructionLookupsClaimReductionSumcheckProver, RaReductionParams, - RamRaClaimReductionSumcheckProver, RegistersClaimReductionSumcheckParams, - RegistersClaimReductionSumcheckProver, + InstructionLookupsClaimReductionSumcheckProver, PrecommittedClaimReduction, + PrecommittedPolynomial, ProgramImageClaimReductionParams, + ProgramImageClaimReductionProver, RaReductionParams, RamRaClaimReductionSumcheckProver, + RegistersClaimReductionSumcheckParams, RegistersClaimReductionSumcheckProver, }, config::OneHotParams, instruction_lookups::{ @@ -74,7 +83,7 @@ use crate::{ ram::{ hamming_booleanity::HammingBooleanitySumcheckParams, output_check::OutputSumcheckParams, - populate_memory_states, + populate_memory_states, prover_accumulate_program_image, ra_virtual::RamRaVirtualParams, raf_evaluation::RafEvaluationSumcheckParams, read_write_checking::RamReadWriteCheckingParams, @@ -99,7 +108,9 @@ use crate::{ use crate::{ poly::commitment::commitment_scheme::CommitmentScheme, zkvm::{ - bytecode::read_raf_checking::BytecodeReadRafSumcheckProver, + bytecode::read_raf_checking::{ + BytecodeReadRafAddressSumcheckProver, BytecodeReadRafCycleSumcheckProver, + }, fiat_shamir_preamble, instruction_lookups::{ ra_virtual::InstructionRaSumcheckProver as LookupsRaSumcheckProver, @@ -182,6 +193,10 @@ pub struct JoltCpuProver< /// The advice claim reduction sumcheck effectively spans two stages (6 and 7). /// Cache the prover state here between stages. advice_reduction_prover_untrusted: Option>, + /// Bytecode claim reduction spans stages 6b and 7 in committed mode. + bytecode_reduction_prover: Option>, + /// Program-image claim reduction spans stages 6b and 7 in committed mode. + program_image_reduction_prover: Option>, pub unpadded_trace_len: usize, pub padded_trace_len: usize, pub transcript: ProofTranscript, @@ -274,81 +289,115 @@ impl< ) } - /// Adjusts the padded trace length to ensure the main Dory matrix is large enough - /// to embed advice polynomials as the top-left block. - /// - /// Returns the adjusted padded_trace_len that satisfies: - /// - `sigma_main >= max_sigma_a` - /// - `nu_main >= max_nu_a` - /// - /// Panics if `max_padded_trace_length` is too small for the configured advice sizes. - fn adjust_trace_length_for_advice( - mut padded_trace_len: usize, - max_padded_trace_length: usize, - max_trusted_advice_size: u64, - max_untrusted_advice_size: u64, - has_trusted_advice: bool, - has_untrusted_advice: bool, - ) -> usize { - // Canonical advice shape policy (balanced): - // - advice_vars = log2(advice_len) - // - sigma_a = ceil(advice_vars/2) - // - nu_a = advice_vars - sigma_a - let mut max_sigma_a = 0usize; - let mut max_nu_a = 0usize; - - if has_trusted_advice { - let (sigma_a, nu_a) = - DoryGlobals::advice_sigma_nu_from_max_bytes(max_trusted_advice_size as usize); - max_sigma_a = max_sigma_a.max(sigma_a); - max_nu_a = max_nu_a.max(nu_a); + #[inline] + fn main_total_vars(&self) -> usize { + let trace_log_t = self.trace.len().log_2(); + let log_k_chunk = self.one_hot_params.log_k_chunk; + JoltSharedPreprocessing::::max_total_vars_from_candidates( + trace_log_t + log_k_chunk, + self.preprocessing.shared.precommitted_candidate_total_vars( + self.preprocessing.is_committed_mode(), + !self.program_io.trusted_advice.is_empty(), + !self.program_io.untrusted_advice.is_empty(), + ), + ) + } + + fn stage8_opening_point(&self) -> OpeningPoint { + let native_main_vars = self.trace.len().log_2() + self.one_hot_params.log_k_chunk; + let mut opening_candidates: Vec<(String, OpeningPoint)> = Vec::new(); + if let Some((point, _)) = self + .opening_accumulator + .get_advice_opening(AdviceKind::Trusted, SumcheckId::AdviceClaimReduction) + { + opening_candidates.push(("trusted_advice".to_string(), point)); } - if has_untrusted_advice { - let (sigma_a, nu_a) = - DoryGlobals::advice_sigma_nu_from_max_bytes(max_untrusted_advice_size as usize); - max_sigma_a = max_sigma_a.max(sigma_a); - max_nu_a = max_nu_a.max(nu_a); + if let Some((point, _)) = self + .opening_accumulator + .get_advice_opening(AdviceKind::Untrusted, SumcheckId::AdviceClaimReduction) + { + opening_candidates.push(("untrusted_advice".to_string(), point)); } - - if max_sigma_a == 0 && max_nu_a == 0 { - return padded_trace_len; + if self.preprocessing.is_committed_mode() { + for chunk_idx in 0..self.preprocessing.shared.bytecode_chunk_count { + let (point, _) = self.opening_accumulator.get_committed_polynomial_opening( + CommittedPolynomial::BytecodeChunk(chunk_idx), + SumcheckId::BytecodeClaimReduction, + ); + opening_candidates.push((format!("bytecode_chunk[{chunk_idx}]"), point)); + } + } + if self.preprocessing.is_committed_mode() { + let (program_image_point, _) = + self.opening_accumulator.get_committed_polynomial_opening( + CommittedPolynomial::ProgramImageInit, + SumcheckId::ProgramImageClaimReduction, + ); + opening_candidates.push(("program_image".to_string(), program_image_point)); } - // Require main matrix dimensions to be large enough to embed advice as the top-left - // block: sigma_main >= sigma_a and nu_main >= nu_a. - // - // This loop doubles padded_trace_len until the main Dory matrix is large enough. - // Each doubling increases log_t by 1, which increases total_vars by 1 (since - // log_k_chunk stays constant for a given log_t range), increasing both sigma_main - // and nu_main by roughly 0.5 each iteration. - while { - let log_t = padded_trace_len.log_2(); - let log_k_chunk = OneHotConfig::new(log_t).log_k_chunk as usize; - let (sigma_main, nu_main) = DoryGlobals::main_sigma_nu(log_k_chunk, log_t); - sigma_main < max_sigma_a || nu_main < max_nu_a - } { - if padded_trace_len >= max_padded_trace_length { - // This is a configuration error: the preprocessing was set up with - // max_padded_trace_length too small for the configured advice sizes. - // Cannot recover at runtime - user must fix their configuration. - let log_t = padded_trace_len.log_2(); - let log_k_chunk = OneHotConfig::new(log_t).log_k_chunk as usize; - let total_vars = log_k_chunk + log_t; - let (sigma_main, nu_main) = DoryGlobals::main_sigma_nu(log_k_chunk, log_t); - panic!( - "Configuration error: trace too small to embed advice into Dory batch opening.\n\ - Current: (sigma_main={sigma_main}, nu_main={nu_main}) from total_vars={total_vars} (log_t={log_t}, log_k_chunk={log_k_chunk})\n\ - Required: (sigma_a={max_sigma_a}, nu_a={max_nu_a}) for advice embedding\n\ - Solutions:\n\ - 1. Increase max_trace_length in preprocessing (currently {max_padded_trace_length})\n\ - 2. Reduce max_trusted_advice_size or max_untrusted_advice_size\n\ - 3. Run a program with more cycles" + let max_len = opening_candidates + .iter() + .map(|(_, p)| p.r.len()) + .max() + .unwrap_or(0); + let final_point = if max_len > native_main_vars { + let dominant = opening_candidates + .iter() + .find(|(_, p)| p.r.len() == max_len) + .expect("at least one dominant precommitted candidate expected"); + for (name, point) in opening_candidates + .iter() + .filter(|(_, p)| p.r.len() == max_len) + { + assert_eq!( + point.r, dominant.1.r, + "incompatible dominant precommitted anchors: {} and {} have equal dimensionality {} but different opening points", + dominant.0, name, max_len ); } - padded_trace_len = (padded_trace_len * 2).min(max_padded_trace_length); - } + OpeningPoint::::new(dominant.1.r.clone()) + } else { + let (hamming_point, _) = self.opening_accumulator.get_committed_polynomial_opening( + CommittedPolynomial::InstructionRa(0), + SumcheckId::HammingWeightClaimReduction, + ); + let r_address_stage7 = hamming_point.r[..self.one_hot_params.log_k_chunk].to_vec(); + let r_cycle_stage6 = self + .opening_accumulator + .get_committed_polynomial_opening( + CommittedPolynomial::RamInc, + SumcheckId::IncClaimReduction, + ) + .0 + .r; + + match DoryGlobals::get_layout() { + DoryLayout::AddressMajor => OpeningPoint::::new( + [r_cycle_stage6.as_slice(), r_address_stage7.as_slice()].concat(), + ), + DoryLayout::CycleMajor => { + let native_cycle = &hamming_point.r[self.one_hot_params.log_k_chunk..]; + assert!( + r_cycle_stage6.len() >= native_cycle.len(), + "stage6 cycle challenges shorter than native cycle vars" + ); + assert!( + r_cycle_stage6[..native_cycle.len()] == *native_cycle, + "cycle-major Stage-8 expects stage6 cycle prefix to equal native cycle vars \ + (cycle_full_len={}, native_len={})", + r_cycle_stage6.len(), + native_cycle.len() + ); + let cycle_extra = &r_cycle_stage6[native_cycle.len()..]; + let cycle_extra_and_anchor = + [cycle_extra, r_address_stage7.as_slice(), native_cycle].concat(); + OpeningPoint::::new(cycle_extra_and_anchor) + } + } + }; - padded_trace_len + final_point } pub fn gen_from_trace( @@ -386,20 +435,6 @@ impl< ); } - // We may need extra padding so the main Dory matrix has enough (row, col) variables - // to embed advice commitments committed in their own preprocessing-only contexts. - let has_trusted_advice = !program_io.trusted_advice.is_empty(); - let has_untrusted_advice = !program_io.untrusted_advice.is_empty(); - - let padded_trace_len = Self::adjust_trace_length_for_advice( - padded_trace_len, - preprocessing.shared.max_padded_trace_length, - preprocessing.shared.memory_layout.max_trusted_advice_size, - preprocessing.shared.memory_layout.max_untrusted_advice_size, - has_trusted_advice, - has_untrusted_advice, - ); - trace.resize(padded_trace_len, Cycle::NoOp); // Calculate K for DoryGlobals initialization @@ -415,11 +450,11 @@ impl< .unwrap_or(0) .max( remap_address( - preprocessing.shared.ram.min_bytecode_address, + preprocessing.shared.program_meta.min_bytecode_address, &preprocessing.shared.memory_layout, ) .unwrap_or(0) - + preprocessing.shared.ram.bytecode_words.len() as u64 + + preprocessing.shared.program_meta.program_image_len_words as u64 + 1, ) .next_power_of_two() as usize; @@ -431,7 +466,7 @@ impl< let (initial_ram_state, final_ram_state) = gen_ram_memory_states::( ram_K, - &preprocessing.shared.ram, + &preprocessing.materialized_program().ram, &program_io, &final_memory_state, ); @@ -439,8 +474,7 @@ impl< let log_T = trace.len().log_2(); let ram_log_K = ram_K.log_2(); let rw_config = ReadWriteConfig::new(log_T, ram_log_K); - let one_hot_params = - OneHotParams::new(log_T, preprocessing.shared.bytecode.code_size, ram_K); + let one_hot_params = OneHotParams::new(log_T, preprocessing.shared.bytecode_size(), ram_K); #[cfg(feature = "zk")] let pedersen_generators = { @@ -462,6 +496,8 @@ impl< }, advice_reduction_prover_trusted: None, advice_reduction_prover_untrusted: None, + bytecode_reduction_prover: None, + program_image_reduction_prover: None, unpadded_trace_len, padded_trace_len, transcript, @@ -497,7 +533,7 @@ impl< &self.program_io, self.one_hot_params.ram_k, self.trace.len(), - self.preprocessing.shared.bytecode.entry_address, + self.preprocessing.shared.program_meta.entry_address, &self.rw_config, &self.one_hot_params.to_config(), DoryGlobals::get_layout(), @@ -507,7 +543,7 @@ impl< tracing::info!( "bytecode size: {}", - self.preprocessing.shared.bytecode.code_size + self.preprocessing.shared.bytecode_size() ); let (commitments, mut opening_proof_hints) = self.generate_and_commit_witness_polynomials(); @@ -521,6 +557,30 @@ impl< if let Some(hint) = self.advice.untrusted_advice_hint.take() { opening_proof_hints.insert(CommittedPolynomial::UntrustedAdvice, hint); } + if let Some(bytecode_hints) = self.preprocessing.shared.program.bytecode_hints() { + for (idx, hint) in bytecode_hints.hints.iter().cloned().enumerate() { + opening_proof_hints.insert(CommittedPolynomial::BytecodeChunk(idx), hint); + } + } + if let Some(program_hints) = self.preprocessing.shared.program.program_hints() { + opening_proof_hints.insert( + CommittedPolynomial::ProgramImageInit, + program_hints.program_image_hint.clone(), + ); + } + if let Some(bytecode_commitments) = self.preprocessing.shared.program.bytecode_commitments() + { + for commitment in &bytecode_commitments.commitments { + self.transcript + .append_serializable(b"bytecode_chunk_commit", commitment); + } + } + if let Some(program_commitments) = self.preprocessing.shared.program.program_commitments() { + self.transcript.append_serializable( + b"program_image_commitment", + &program_commitments.program_image_commitment, + ); + } let (stage1_uni_skip_first_round_proof, stage1_sumcheck_proof, r_stage1) = self.prove_stage1(); @@ -529,7 +589,10 @@ impl< let (stage3_sumcheck_proof, r_stage3) = self.prove_stage3(); let (stage4_sumcheck_proof, r_stage4) = self.prove_stage4(); let (stage5_sumcheck_proof, r_stage5) = self.prove_stage5(); - let (stage6_sumcheck_proof, r_stage6) = self.prove_stage6(); + let (stage6a_sumcheck_proof, bytecode_read_raf_params, booleanity_params) = + self.prove_stage6a(); + let (stage6b_sumcheck_proof, r_stage6) = + self.prove_stage6b(bytecode_read_raf_params, booleanity_params); let (stage7_sumcheck_proof, r_stage7) = self.prove_stage7(); let _sumcheck_challenges = [ @@ -576,7 +639,8 @@ impl< stage3_sumcheck_proof, stage4_sumcheck_proof, stage5_sumcheck_proof, - stage6_sumcheck_proof, + stage6a_sumcheck_proof, + stage6b_sumcheck_proof, stage7_sumcheck_proof, #[cfg(feature = "zk")] blindfold_proof, @@ -670,36 +734,34 @@ impl< Vec, HashMap, ) { - let _guard = DoryGlobals::initialize_context( + let main_total_vars = self.main_total_vars(); + let trace = Arc::clone(&self.trace); + let _guard = DoryGlobals::initialize_main_with_log_embedding( 1 << self.one_hot_params.log_k_chunk, - self.padded_trace_len, - DoryContext::Main, + trace.len(), + main_total_vars, Some(DoryGlobals::get_layout()), ); let polys = all_committed_polynomials(&self.one_hot_params); - let T = DoryGlobals::get_T(); + let T = DoryGlobals::get_embedded_t(); - // For AddressMajor, use non-streaming commit path since streaming assumes CycleMajor layout - let (commitments, hint_map) = if DoryGlobals::get_layout() == DoryLayout::AddressMajor { + // AddressMajor uses non-streaming commit path, and we also use non-streaming when + // Stage 6/8 embedding domain exceeds the trace domain. + let use_materialized_commit = + DoryGlobals::get_layout() == DoryLayout::AddressMajor || self.trace.len() != T; + let (commitments, hint_map) = if use_materialized_commit { tracing::debug!( - "Using non-streaming commit path for AddressMajor layout with {} polynomials", + "Using non-streaming commit path with {} polynomials", polys.len() ); - // Materialize the trace for non-streaming commit - let trace: Vec = self - .lazy_trace - .clone() - .pad_using(T, |_| Cycle::NoOp) - .collect(); - // Generate witnesses and commit using the regular (non-streaming) path let (commitments, hints): (Vec<_>, Vec<_>) = polys .par_iter() .map(|poly_id| { let witness: MultilinearPolynomial = poly_id.generate_witness( - &self.preprocessing.shared.bytecode, + &self.preprocessing.materialized_program().bytecode, &self.preprocessing.shared.memory_layout, &trace, Some(&self.one_hot_params), @@ -736,9 +798,8 @@ impl< let res: Vec<_> = polys .par_iter() .map(|poly| { - poly.stream_witness_and_commit_rows::<_, PCS>( - &self.preprocessing.generators, - &self.preprocessing.shared, + poly.stream_witness_and_commit_rows::<_, _, PCS>( + self.preprocessing, &chunk, &self.one_hot_params, ) @@ -860,14 +921,14 @@ impl< let mut uni_skip = OuterUniSkipProver::initialize( uni_skip_params.clone(), &self.trace, - &self.preprocessing.shared.bytecode, + &self.preprocessing.materialized_program().bytecode, ); let first_round_proof = self.prove_uniskip(&mut uni_skip); let schedule = LinearOnlySchedule::new(uni_skip_params.tau.len() - 1); let shared = OuterSharedState::new( Arc::clone(&self.trace), - &self.preprocessing.shared.bytecode, + &self.preprocessing.materialized_program().bytecode, &uni_skip_params, &self.opening_accumulator, ); @@ -935,7 +996,7 @@ impl< let ram_read_write_checking = RamReadWriteCheckingProver::initialize( ram_read_write_checking_params, &self.trace, - &self.preprocessing.shared.bytecode, + &self.preprocessing.materialized_program().bytecode, &self.program_io.memory_layout, &self.initial_ram_state, ); @@ -1024,7 +1085,7 @@ impl< let spartan_shift = ShiftSumcheckProver::initialize( spartan_shift_params, Arc::clone(&self.trace), - &self.preprocessing.shared.bytecode, + &self.preprocessing.materialized_program().bytecode, ); let spartan_instruction_input = InstructionInputSumcheckProver::initialize( spartan_instruction_input_params, @@ -1090,6 +1151,14 @@ impl< &self.one_hot_params, &mut self.opening_accumulator, ); + if self.preprocessing.is_committed_mode() { + prover_accumulate_program_image( + self.one_hot_params.ram_k, + &self.preprocessing.materialized_program().ram, + &self.program_io, + &mut self.opening_accumulator, + ); + } // Domain-separate the batching challenge. self.transcript.append_bytes(b"ram_val_check_gamma", &[]); let ram_val_check_gamma: F = self.transcript.challenge_scalar::(); @@ -1099,20 +1168,22 @@ impl< &self.initial_ram_state, self.trace.len(), ram_val_check_gamma, - &self.preprocessing.shared.ram, + &self.preprocessing.materialized_program().ram, &self.program_io, + &self.rw_config, + self.preprocessing.is_committed_mode(), ); let registers_read_write_checking = RegistersReadWriteCheckingProver::initialize( registers_read_write_checking_params, self.trace.clone(), - &self.preprocessing.shared.bytecode, + &self.preprocessing.materialized_program().bytecode, &self.program_io.memory_layout, ); let ram_val_check = RamValCheckSumcheckProver::initialize( ram_val_check_params, &self.trace, - &self.preprocessing.shared.bytecode, + &self.preprocessing.materialized_program().bytecode, &self.program_io.memory_layout, ); @@ -1181,7 +1252,7 @@ impl< let registers_val_evaluation = RegistersValEvaluationSumcheckProver::initialize( registers_val_evaluation_params, &self.trace, - &self.preprocessing.shared.bytecode, + &self.preprocessing.materialized_program().bytecode, &self.program_io.memory_layout, ); @@ -1215,32 +1286,83 @@ impl< } #[tracing::instrument(skip_all)] - fn prove_stage6( + fn prove_stage6a( &mut self, ) -> ( SumcheckInstanceProof, - Vec, + BytecodeReadRafSumcheckParams, + BooleanitySumcheckParams, ) { #[cfg(not(target_arch = "wasm32"))] - print_current_memory_usage("Stage 6 baseline"); + print_current_memory_usage("Stage 6a baseline"); let bytecode_read_raf_params = BytecodeReadRafSumcheckParams::gen( - &self.preprocessing.shared.bytecode, + Some(&self.preprocessing.shared.program), self.trace.len().log_2(), &self.one_hot_params, + self.preprocessing.is_committed_mode(), + None, &self.opening_accumulator, &mut self.transcript, ); - let ram_hamming_booleanity_params = - HammingBooleanitySumcheckParams::new(&self.opening_accumulator); - let booleanity_params = BooleanitySumcheckParams::new( self.trace.len().log_2(), &self.one_hot_params, &self.opening_accumulator, &mut self.transcript, ); + let mut bytecode_read_raf = BytecodeReadRafAddressSumcheckProver::initialize( + bytecode_read_raf_params.clone(), + Arc::clone(&self.trace), + Arc::new(self.preprocessing.materialized_program().bytecode.clone()), + ); + let mut booleanity = BooleanityAddressSumcheckProver::initialize( + booleanity_params.clone(), + &self.trace, + &self.preprocessing.materialized_program().bytecode, + &self.program_io.memory_layout, + ); + + #[cfg(feature = "allocative")] + { + print_data_structure_heap_usage( + "BytecodeReadRafAddressSumcheckProver", + &bytecode_read_raf, + ); + print_data_structure_heap_usage("BooleanityAddressSumcheckProver", &booleanity); + } + + let mut instances: Vec<&mut dyn SumcheckInstanceProver<_, _>> = + vec![&mut bytecode_read_raf, &mut booleanity]; + + #[cfg(feature = "allocative")] + write_instance_flamegraph_svg(&instances, "stage6a_start_flamechart.svg"); + tracing::info!("Stage 6a proving"); + + let (sumcheck_proof, _r_stage6a, _initial_claim) = + self.prove_batched_sumcheck(instances.iter_mut().map(|v| &mut **v as _).collect()); + + #[cfg(feature = "allocative")] + write_instance_flamegraph_svg(&instances, "stage6a_end_flamechart.svg"); + + (sumcheck_proof, bytecode_read_raf_params, booleanity_params) + } + + #[tracing::instrument(skip_all)] + fn prove_stage6b( + &mut self, + bytecode_read_raf_params: BytecodeReadRafSumcheckParams, + booleanity_params: BooleanitySumcheckParams, + ) -> ( + SumcheckInstanceProof, + Vec, + ) { + #[cfg(not(target_arch = "wasm32"))] + print_current_memory_usage("Stage 6b baseline"); + + let ram_hamming_booleanity_params = + HammingBooleanitySumcheckParams::new(&self.opening_accumulator); let ram_ra_virtual_params = RamRaVirtualParams::new( self.trace.len(), @@ -1258,12 +1380,24 @@ impl< &mut self.transcript, ); - // Advice claim reduction (Phase 1 in Stage 6): trusted and untrusted are separate instances. + let main_total_vars = self.trace.len().log_2() + self.one_hot_params.log_k_chunk; + let precommitted_candidates = self.preprocessing.shared.precommitted_candidate_total_vars( + self.preprocessing.is_committed_mode(), + self.advice.trusted_advice_polynomial.is_some(), + self.advice.untrusted_advice_polynomial.is_some(), + ); + let precommitted_scheduling_reference = + PrecommittedClaimReduction::::scheduling_reference( + main_total_vars, + &precommitted_candidates, + ); + + // Advice claim reduction (Phase 1 in Stage 6b): trusted and untrusted are separate instances. if self.advice.trusted_advice_polynomial.is_some() { let trusted_advice_params = AdviceClaimReductionParams::new( AdviceKind::Trusted, - &self.program_io.memory_layout, - self.trace.len(), + self.program_io.memory_layout.max_trusted_advice_size as usize, + precommitted_scheduling_reference, &self.opening_accumulator, ); // Note: We clone the advice polynomial here because Stage 8 needs the original polynomial @@ -1284,8 +1418,8 @@ impl< if self.advice.untrusted_advice_polynomial.is_some() { let untrusted_advice_params = AdviceClaimReductionParams::new( AdviceKind::Untrusted, - &self.program_io.memory_layout, - self.trace.len(), + self.program_io.memory_layout.max_untrusted_advice_size as usize, + precommitted_scheduling_reference, &self.opening_accumulator, ); // Note: We clone the advice polynomial here because Stage 8 needs the original polynomial @@ -1303,20 +1437,65 @@ impl< }; } - let mut bytecode_read_raf = BytecodeReadRafSumcheckProver::initialize( + if self.preprocessing.is_committed_mode() { + let bytecode_chunk_count = self.preprocessing.shared.bytecode_chunk_count; + let bytecode_reduction_params = BytecodeClaimReductionParams::new( + &bytecode_read_raf_params, + self.preprocessing.shared.bytecode_size(), + bytecode_chunk_count, + precommitted_scheduling_reference, + &self.opening_accumulator, + &mut self.transcript, + ); + let bytecode_chunk_coeffs = build_committed_bytecode_chunk_coeffs( + &self.preprocessing.materialized_program().bytecode.bytecode, + bytecode_chunk_count, + ); + self.bytecode_reduction_prover = Some(BytecodeClaimReductionProver::initialize( + bytecode_reduction_params, + &bytecode_chunk_coeffs, + )); + + let padded_len_words = self + .preprocessing + .shared + .program + .committed_program_image_num_words(&self.program_io.memory_layout); + let program_image_words = build_program_image_words_padded( + self.preprocessing.materialized_program(), + padded_len_words, + ); + let program_image_reduction_params = ProgramImageClaimReductionParams::new( + &self.program_io, + self.preprocessing.shared.program_meta.min_bytecode_address, + padded_len_words, + self.one_hot_params.ram_k, + precommitted_scheduling_reference, + &self.opening_accumulator, + &mut self.transcript, + ); + self.program_image_reduction_prover = + Some(ProgramImageClaimReductionProver::initialize( + program_image_reduction_params, + program_image_words, + )); + } + + let mut bytecode_read_raf = BytecodeReadRafCycleSumcheckProver::initialize( bytecode_read_raf_params, Arc::clone(&self.trace), - Arc::clone(&self.preprocessing.shared.bytecode), + Arc::new(self.preprocessing.materialized_program().bytecode.clone()), + &self.opening_accumulator, ); - let mut ram_hamming_booleanity = - HammingBooleanitySumcheckProver::initialize(ram_hamming_booleanity_params, &self.trace); - - let mut booleanity = BooleanitySumcheckProver::initialize( + let mut booleanity = BooleanityCycleSumcheckProver::initialize( booleanity_params, &self.trace, - &self.preprocessing.shared.bytecode, + &self.preprocessing.materialized_program().bytecode, &self.program_io.memory_layout, + &self.opening_accumulator, ); + let mut ram_hamming_booleanity = + HammingBooleanitySumcheckProver::initialize(ram_hamming_booleanity_params, &self.trace); let mut ram_ra_virtual = RamRaVirtualSumcheckProver::initialize( ram_ra_virtual_params, @@ -1331,8 +1510,11 @@ impl< #[cfg(feature = "allocative")] { - print_data_structure_heap_usage("BytecodeReadRafSumcheckProver", &bytecode_read_raf); - print_data_structure_heap_usage("BooleanitySumcheckProver", &booleanity); + print_data_structure_heap_usage( + "BytecodeReadRafCycleSumcheckProver", + &bytecode_read_raf, + ); + print_data_structure_heap_usage("BooleanityCycleSumcheckProver", &booleanity); print_data_structure_heap_usage( "ram HammingBooleanitySumcheckProver", &ram_hamming_booleanity, @@ -1350,6 +1532,8 @@ impl< let mut advice_trusted = self.advice_reduction_prover_trusted.take(); let mut advice_untrusted = self.advice_reduction_prover_untrusted.take(); + let mut bytecode_reduction = self.bytecode_reduction_prover.take(); + let mut program_image_reduction = self.program_image_reduction_prover.take(); let mut instances: Vec<&mut dyn SumcheckInstanceProver<_, _>> = vec![ &mut bytecode_read_raf, @@ -1365,15 +1549,21 @@ impl< if let Some(ref mut advice) = advice_untrusted { instances.push(advice); } + if let Some(ref mut reduction) = bytecode_reduction { + instances.push(reduction); + } + if let Some(ref mut reduction) = program_image_reduction { + instances.push(reduction); + } #[cfg(feature = "allocative")] - write_instance_flamegraph_svg(&instances, "stage6_start_flamechart.svg"); - tracing::info!("Stage 6 proving"); + write_instance_flamegraph_svg(&instances, "stage6b_start_flamechart.svg"); + tracing::info!("Stage 6b proving"); - let (sumcheck_proof, r_stage6, _initial_claim) = + let (sumcheck_proof, r_stage6b, _initial_claim) = self.prove_batched_sumcheck(instances.iter_mut().map(|v| &mut **v as _).collect()); #[cfg(feature = "allocative")] - write_instance_flamegraph_svg(&instances, "stage6_end_flamechart.svg"); + write_instance_flamegraph_svg(&instances, "stage6b_end_flamechart.svg"); drop_in_background_thread(bytecode_read_raf); drop_in_background_thread(booleanity); drop_in_background_thread(ram_hamming_booleanity); @@ -1383,8 +1573,10 @@ impl< self.advice_reduction_prover_trusted = advice_trusted; self.advice_reduction_prover_untrusted = advice_untrusted; + self.bytecode_reduction_prover = bytecode_reduction; + self.program_image_reduction_prover = program_image_reduction; - (sumcheck_proof, r_stage6) + (sumcheck_proof, r_stage6b) } #[tracing::instrument(skip_all)] @@ -1409,8 +1601,8 @@ impl< let zk_stages = self.blindfold_accumulator.take_stage_data(); assert_eq!( zk_stages.len(), - 7, - "Expected 7 ZK stages, got {}", + 8, + "Expected 8 ZK stages, got {}", zk_stages.len() ); @@ -1852,7 +2044,7 @@ impl< let hw_prover = HammingWeightClaimReductionProver::initialize( hw_params, &self.trace, - &self.preprocessing.shared, + self.preprocessing, &self.one_hot_params, ); @@ -1868,12 +2060,12 @@ impl< self.advice_reduction_prover_trusted.take() { if advice_reduction_prover_trusted - .params + .params() .num_address_phase_rounds() > 0 { // Transition phase - advice_reduction_prover_trusted.params.phase = ReductionPhase::AddressVariables; + advice_reduction_prover_trusted.transition_to_address_phase(); instances.push(Box::new(advice_reduction_prover_trusted)); } } @@ -1881,15 +2073,36 @@ impl< self.advice_reduction_prover_untrusted.take() { if advice_reduction_prover_untrusted - .params + .params() .num_address_phase_rounds() > 0 { // Transition phase - advice_reduction_prover_untrusted.params.phase = ReductionPhase::AddressVariables; + advice_reduction_prover_untrusted.transition_to_address_phase(); instances.push(Box::new(advice_reduction_prover_untrusted)); } } + if let Some(mut bytecode_reduction_prover) = self.bytecode_reduction_prover.take() { + if bytecode_reduction_prover + .params() + .num_address_phase_rounds() + > 0 + { + bytecode_reduction_prover.transition_to_address_phase(); + instances.push(Box::new(bytecode_reduction_prover)); + } + } + if let Some(mut program_image_reduction_prover) = self.program_image_reduction_prover.take() + { + if program_image_reduction_prover + .params() + .num_address_phase_rounds() + > 0 + { + program_image_reduction_prover.transition_to_address_phase(); + instances.push(Box::new(program_image_reduction_prover)); + } + } #[cfg(feature = "allocative")] write_boxed_instance_flamegraph_svg(&instances, "stage7_start_flamechart.svg"); @@ -1913,70 +2126,62 @@ impl< ) -> PCS::Proof { tracing::info!("Stage 8 proving (Dory batch opening)"); - let _guard = DoryGlobals::initialize_context( - self.one_hot_params.k_chunk, - self.padded_trace_len, - DoryContext::Main, - Some(DoryGlobals::get_layout()), - ); - - // Get the unified opening point from HammingWeightClaimReduction - // This contains (r_address_stage7 || r_cycle_stage6) in big-endian - let (opening_point, _) = self.opening_accumulator.get_committed_polynomial_opening( - CommittedPolynomial::InstructionRa(0), - SumcheckId::HammingWeightClaimReduction, - ); - - let log_k_chunk = self.one_hot_params.log_k_chunk; - let r_address_stage7 = &opening_point.r[..log_k_chunk]; + let opening_point = self.stage8_opening_point(); let mut polynomial_claims = Vec::new(); let mut scaling_factors = Vec::new(); + let mut precommitted_polys: HashMap> = + HashMap::new(); - // Dense polynomials: RamInc and RdInc (from IncClaimReduction in Stage 6) - // at r_cycle_stage6 only (length log_T) - let (_, ram_inc_claim) = self.opening_accumulator.get_committed_polynomial_opening( - CommittedPolynomial::RamInc, - SumcheckId::IncClaimReduction, - ); - let (_, rd_inc_claim) = self.opening_accumulator.get_committed_polynomial_opening( - CommittedPolynomial::RdInc, - SumcheckId::IncClaimReduction, - ); + let (ram_inc_point, ram_inc_claim) = + self.opening_accumulator.get_committed_polynomial_opening( + CommittedPolynomial::RamInc, + SumcheckId::IncClaimReduction, + ); + let (rd_inc_point, rd_inc_claim) = + self.opening_accumulator.get_committed_polynomial_opening( + CommittedPolynomial::RdInc, + SumcheckId::IncClaimReduction, + ); - // Dense polynomials are zero-padded in the Dory matrix, so their evaluation - // includes a factor eq(r_addr, 0) = ∏(1 − r_addr_i). - let lagrange_factor: F = EqPolynomial::zero_selector(r_address_stage7); - polynomial_claims.push((CommittedPolynomial::RamInc, ram_inc_claim * lagrange_factor)); - scaling_factors.push(lagrange_factor); - polynomial_claims.push((CommittedPolynomial::RdInc, rd_inc_claim * lagrange_factor)); - scaling_factors.push(lagrange_factor); + let ram_inc_lagrange = compute_lagrange_factor::(&opening_point.r, &ram_inc_point.r); + let rd_inc_lagrange = compute_lagrange_factor::(&opening_point.r, &rd_inc_point.r); + polynomial_claims.push(( + CommittedPolynomial::RamInc, + ram_inc_claim * ram_inc_lagrange, + )); + scaling_factors.push(ram_inc_lagrange); + polynomial_claims.push((CommittedPolynomial::RdInc, rd_inc_claim * rd_inc_lagrange)); + scaling_factors.push(rd_inc_lagrange); // Sparse polynomials: all RA polys (from HammingWeightClaimReduction) // These are at (r_address_stage7, r_cycle_stage6) for i in 0..self.one_hot_params.instruction_d { - let (_, claim) = self.opening_accumulator.get_committed_polynomial_opening( + let (ra_point, claim) = self.opening_accumulator.get_committed_polynomial_opening( CommittedPolynomial::InstructionRa(i), SumcheckId::HammingWeightClaimReduction, ); - polynomial_claims.push((CommittedPolynomial::InstructionRa(i), claim)); - scaling_factors.push(F::one()); + let lagrange = compute_lagrange_factor::(&opening_point.r, &ra_point.r); + polynomial_claims.push((CommittedPolynomial::InstructionRa(i), claim * lagrange)); + scaling_factors.push(lagrange); } for i in 0..self.one_hot_params.bytecode_d { - let (_, claim) = self.opening_accumulator.get_committed_polynomial_opening( + let (ra_point, claim) = self.opening_accumulator.get_committed_polynomial_opening( CommittedPolynomial::BytecodeRa(i), SumcheckId::HammingWeightClaimReduction, ); - polynomial_claims.push((CommittedPolynomial::BytecodeRa(i), claim)); - scaling_factors.push(F::one()); + let lagrange = compute_lagrange_factor::(&opening_point.r, &ra_point.r); + polynomial_claims.push((CommittedPolynomial::BytecodeRa(i), claim * lagrange)); + scaling_factors.push(lagrange); } for i in 0..self.one_hot_params.ram_d { - let (_, claim) = self.opening_accumulator.get_committed_polynomial_opening( + let (ra_point, claim) = self.opening_accumulator.get_committed_polynomial_opening( CommittedPolynomial::RamRa(i), SumcheckId::HammingWeightClaimReduction, ); - polynomial_claims.push((CommittedPolynomial::RamRa(i), claim)); - scaling_factors.push(F::one()); + let lagrange = compute_lagrange_factor::(&opening_point.r, &ra_point.r); + polynomial_claims.push((CommittedPolynomial::RamRa(i), claim * lagrange)); + scaling_factors.push(lagrange); } // Advice polynomials: TrustedAdvice and UntrustedAdvice (from AdviceClaimReduction in Stage 6) @@ -1991,8 +2196,7 @@ impl< .opening_accumulator .get_advice_opening(AdviceKind::Trusted, SumcheckId::AdviceClaimReduction) { - let lagrange_factor = - compute_advice_lagrange_factor::(&opening_point.r, &advice_point.r); + let lagrange_factor = compute_lagrange_factor::(&opening_point.r, &advice_point.r); polynomial_claims.push(( CommittedPolynomial::TrustedAdvice, advice_claim * lagrange_factor, @@ -2008,8 +2212,7 @@ impl< .opening_accumulator .get_advice_opening(AdviceKind::Untrusted, SumcheckId::AdviceClaimReduction) { - let lagrange_factor = - compute_advice_lagrange_factor::(&opening_point.r, &advice_point.r); + let lagrange_factor = compute_lagrange_factor::(&opening_point.r, &advice_point.r); polynomial_claims.push(( CommittedPolynomial::UntrustedAdvice, advice_claim * lagrange_factor, @@ -2021,6 +2224,71 @@ impl< } } + if self.preprocessing.is_committed_mode() { + let chunk_count = self.preprocessing.shared.bytecode_chunk_count; + let chunk_cycle_len = committed_bytecode_chunk_cycle_len( + self.preprocessing + .materialized_program() + .bytecode + .bytecode + .len(), + chunk_count, + ); + for chunk_idx in 0..chunk_count { + let (chunk_point, chunk_claim) = + self.opening_accumulator.get_committed_polynomial_opening( + CommittedPolynomial::BytecodeChunk(chunk_idx), + SumcheckId::BytecodeClaimReduction, + ); + let lagrange_factor = + compute_lagrange_factor::(&opening_point.r, &chunk_point.r); + polynomial_claims.push(( + CommittedPolynomial::BytecodeChunk(chunk_idx), + chunk_claim * lagrange_factor, + )); + scaling_factors.push(lagrange_factor); + precommitted_polys.insert( + CommittedPolynomial::BytecodeChunk(chunk_idx), + PrecommittedPolynomial::BytecodeChunk { + chunk_index: chunk_idx, + chunk_cycle_len, + }, + ); + } + } + + if self.preprocessing.is_committed_mode() { + let padded_len = self + .preprocessing + .shared + .program + .committed_program_image_num_words(&self.program_io.memory_layout); + let (program_point, program_claim) = + self.opening_accumulator.get_committed_polynomial_opening( + CommittedPolynomial::ProgramImageInit, + SumcheckId::ProgramImageClaimReduction, + ); + let lagrange_factor = compute_lagrange_factor::(&opening_point.r, &program_point.r); + polynomial_claims.push(( + CommittedPolynomial::ProgramImageInit, + program_claim * lagrange_factor, + )); + scaling_factors.push(lagrange_factor); + precommitted_polys.insert( + CommittedPolynomial::ProgramImageInit, + PrecommittedPolynomial::ProgramImage { + words: Arc::new( + self.preprocessing + .materialized_program() + .ram + .bytecode_words + .clone(), + ), + padded_len, + }, + ); + } + // 2. Sample gamma and compute powers for RLC let claims: Vec = polynomial_claims.iter().map(|(_, c)| *c).collect(); // In non-ZK mode, absorb claims before sampling gamma for Fiat-Shamir binding. @@ -2045,6 +2313,12 @@ impl< &self.one_hot_params, include_trusted_advice, include_untrusted_advice, + if self.preprocessing.is_committed_mode() { + ProgramMode::Committed + } else { + ProgramMode::Full + }, + self.preprocessing.shared.bytecode_chunk_count, ); // Build DoryOpeningState @@ -2055,17 +2329,22 @@ impl< }; let streaming_data = Arc::new(RLCStreamingData { - bytecode: Arc::clone(&self.preprocessing.shared.bytecode), + bytecode: Arc::new(self.preprocessing.materialized_program().bytecode.clone()), memory_layout: self.preprocessing.shared.memory_layout.clone(), }); - // Build advice polynomials map for RLC - let mut advice_polys = HashMap::new(); + // Add advice polynomials to precommitted polynomials map for RLC. if let Some(poly) = self.advice.trusted_advice_polynomial.take() { - advice_polys.insert(CommittedPolynomial::TrustedAdvice, poly); + precommitted_polys.insert( + CommittedPolynomial::TrustedAdvice, + PrecommittedPolynomial::Dense(poly), + ); } if let Some(poly) = self.advice.untrusted_advice_polynomial.take() { - advice_polys.insert(CommittedPolynomial::UntrustedAdvice, poly); + precommitted_polys.insert( + CommittedPolynomial::UntrustedAdvice, + PrecommittedPolynomial::Dense(poly), + ); } // Build streaming RLC polynomial directly (no witness poly regeneration!) @@ -2075,9 +2354,8 @@ impl< TraceSource::Materialized(Arc::clone(&self.trace)), streaming_data, opening_proof_hints, - advice_polys, + precommitted_polys, ); - let (proof, _y_blinding) = PCS::prove( &self.preprocessing.generators, &joint_poly, @@ -2142,42 +2420,100 @@ fn write_instance_flamegraph_svg( write_flamegraph_svg(flamegraph, path); } -#[derive(Clone, CanonicalSerialize, CanonicalDeserialize)] +#[derive(Clone)] pub struct JoltProverPreprocessing< F: JoltField, C: JoltCurve, PCS: CommitmentScheme, > { pub generators: PCS::ProverSetup, - pub shared: JoltSharedPreprocessing, + pub shared: JoltSharedPreprocessing, _curve: std::marker::PhantomData, } -impl JoltProverPreprocessing +impl CanonicalSerialize for JoltProverPreprocessing where F: JoltField, C: JoltCurve, PCS: CommitmentScheme, { - #[tracing::instrument(skip_all, name = "JoltProverPreprocessing::gen")] - pub fn new(shared: JoltSharedPreprocessing) -> Self { - use common::constants::ONEHOT_CHUNK_THRESHOLD_LOG_T; - let max_T: usize = shared.max_padded_trace_length.next_power_of_two(); - let max_log_T = max_T.log_2(); - let max_log_k_chunk = if max_log_T < ONEHOT_CHUNK_THRESHOLD_LOG_T { - 4 - } else { - 8 - }; - let generators = PCS::setup_prover(max_log_k_chunk + max_log_T); + fn serialize_with_mode( + &self, + mut writer: W, + compress: ark_serialize::Compress, + ) -> Result<(), ark_serialize::SerializationError> { + self.generators.serialize_with_mode(&mut writer, compress)?; + self.shared.serialize_with_mode(&mut writer, compress)?; + Ok(()) + } + + fn serialized_size(&self, compress: ark_serialize::Compress) -> usize { + self.generators.serialized_size(compress) + self.shared.serialized_size(compress) + } +} - JoltProverPreprocessing { +impl ark_serialize::Valid for JoltProverPreprocessing +where + F: JoltField, + C: JoltCurve, + PCS: CommitmentScheme, +{ + fn check(&self) -> Result<(), ark_serialize::SerializationError> { + self.generators.check()?; + self.shared.check()?; + Ok(()) + } +} + +impl CanonicalDeserialize for JoltProverPreprocessing +where + F: JoltField, + C: JoltCurve, + PCS: CommitmentScheme, +{ + fn deserialize_with_mode( + mut reader: R, + compress: ark_serialize::Compress, + validate: ark_serialize::Validate, + ) -> Result { + Ok(Self { + generators: PCS::ProverSetup::deserialize_with_mode(&mut reader, compress, validate)?, + shared: JoltSharedPreprocessing::deserialize_with_mode( + &mut reader, + compress, + validate, + )?, + _curve: std::marker::PhantomData, + }) + } +} + +impl JoltProverPreprocessing +where + F: JoltField, + C: JoltCurve, + PCS: CommitmentScheme, +{ + pub fn new_with_generators( + shared: JoltSharedPreprocessing, + generators: PCS::ProverSetup, + ) -> Self { + Self { generators, shared, _curve: std::marker::PhantomData, } } + #[tracing::instrument(skip_all, name = "JoltProverPreprocessing::new")] + pub fn new(shared: JoltSharedPreprocessing) -> Self { + let committed_mode = shared.program.is_committed(); + let (max_total_vars, _) = shared.compute_max_total_vars(committed_mode); + let generators = PCS::setup_prover(max_total_vars); + + Self::new_with_generators(shared, generators) + } + #[cfg(feature = "zk")] pub fn blindfold_setup(&self) -> BlindfoldSetup where @@ -2202,6 +2538,17 @@ where ) } + pub fn is_committed_mode(&self) -> bool { + self.shared.program.is_committed() + } + + pub fn materialized_program(&self) -> &FullProgramPreprocessing { + self.shared + .program + .as_full() + .expect("prover requires materialized program preprocessing") + } + pub fn save_to_target_dir(&self, target_dir: &str) -> std::io::Result<()> { let filename = Path::new(target_dir).join("jolt_prover_preprocessing.dat"); let mut file = File::create(filename.as_path())?; @@ -2249,7 +2596,9 @@ mod tests { multilinear_polynomial::MultilinearPolynomial, opening_proof::{OpeningAccumulator, SumcheckId}, }; + use crate::zkvm::bytecode::PreprocessingError; use crate::zkvm::claim_reductions::AdviceKind; + use crate::zkvm::program::ProgramPreprocessing; use crate::zkvm::verifier::JoltSharedPreprocessing; use crate::zkvm::witness::CommittedPolynomial; use crate::zkvm::{ @@ -2310,23 +2659,51 @@ mod tests { (commitment, hint) } + fn test_shared_preprocessing( + bytecode: Vec, + init_memory_state: Vec<(u64, u8)>, + memory_layout: common::jolt_device::MemoryLayout, + max_trace_len: usize, + ) -> Result<(JoltSharedPreprocessing, Arc), PreprocessingError> { + let program = ProgramPreprocessing::preprocess(bytecode, init_memory_state)?; + let shared = JoltSharedPreprocessing::new(program.clone(), memory_layout, max_trace_len); + let program = Arc::new(program); + Ok((shared, program)) + } + + fn test_shared_preprocessing_committed( + bytecode: Vec, + init_memory_state: Vec<(u64, u8)>, + memory_layout: common::jolt_device::MemoryLayout, + max_trace_len: usize, + bytecode_chunk_count: usize, + ) -> Result<(JoltSharedPreprocessing, Arc), PreprocessingError> { + let program = ProgramPreprocessing::preprocess(bytecode, init_memory_state)?; + let shared = JoltSharedPreprocessing::new_committed( + program.clone(), + memory_layout, + max_trace_len, + bytecode_chunk_count, + ); + let program = Arc::new(program); + Ok((shared, program)) + } + #[test] #[serial] fn fib_e2e_dory() { DoryGlobals::reset(); let mut program = host::Program::new("fibonacci-guest"); let inputs = postcard::to_stdvec(&100u32).unwrap(); - let (bytecode, init_memory_state, _, e_entry) = program.decode(); + let (bytecode, init_memory_state, _, _) = program.decode(); let (_, _, _, io_device) = program.trace(&inputs, &[], &[]); - let shared_preprocessing = JoltSharedPreprocessing::new( - bytecode.clone(), - io_device.memory_layout.clone(), + let (shared_preprocessing, _program_data) = test_shared_preprocessing( + bytecode, init_memory_state, + io_device.memory_layout.clone(), 1 << 16, - e_entry, ) .unwrap(); - let prover_preprocessing = JoltProverPreprocessing::new(shared_preprocessing); let elf_contents_opt = program.get_elf_contents(); let elf_contents = elf_contents_opt.as_deref().expect("elf contents is None"); @@ -2361,18 +2738,15 @@ mod tests { DoryGlobals::reset(); let mut program = host::Program::new("fibonacci-guest"); let inputs = postcard::to_stdvec(&5u32).unwrap(); - let (bytecode, init_memory_state, _, e_entry) = program.decode(); + let (bytecode, init_memory_state, _, _) = program.decode(); let (_, _, _, io_device) = program.trace(&inputs, &[], &[]); - - let shared_preprocessing = JoltSharedPreprocessing::new( - bytecode.clone(), - io_device.memory_layout.clone(), + let (shared_preprocessing, _program_data) = test_shared_preprocessing( + bytecode, init_memory_state, + io_device.memory_layout.clone(), 8192, - e_entry, ) .unwrap(); - let prover_preprocessing = JoltProverPreprocessing::new(shared_preprocessing.clone()); let elf_contents_opt = program.get_elf_contents(); let elf_contents = elf_contents_opt.as_deref().expect("elf contents is None"); @@ -2416,19 +2790,17 @@ mod tests { DoryGlobals::reset(); let mut program = host::Program::new("sha3-guest"); - let (bytecode, init_memory_state, _, e_entry) = program.decode(); + let (bytecode, init_memory_state, _, _) = program.decode(); let inputs = postcard::to_stdvec(&[5u8; 32]).unwrap(); let (_, _, _, io_device) = program.trace(&inputs, &[], &[]); - let shared_preprocessing = JoltSharedPreprocessing::new( - bytecode.clone(), - io_device.memory_layout.clone(), + let (shared_preprocessing, _program_data) = test_shared_preprocessing( + bytecode, init_memory_state, + io_device.memory_layout.clone(), 1 << 16, - e_entry, ) .unwrap(); - let prover_preprocessing = JoltProverPreprocessing::new(shared_preprocessing.clone()); let elf_contents_opt = program.get_elf_contents(); let elf_contents = elf_contents_opt.as_deref().expect("elf contents is None"); @@ -2474,19 +2846,17 @@ mod tests { DoryGlobals::reset(); let mut program = host::Program::new("sha2-guest"); - let (bytecode, init_memory_state, _, e_entry) = program.decode(); + let (bytecode, init_memory_state, _, _) = program.decode(); let inputs = postcard::to_stdvec(&[5u8; 32]).unwrap(); let (_, _, _, io_device) = program.trace(&inputs, &[], &[]); - let shared_preprocessing = JoltSharedPreprocessing::new( - bytecode.clone(), - io_device.memory_layout.clone(), + let (shared_preprocessing, _program_data) = test_shared_preprocessing( + bytecode, init_memory_state, + io_device.memory_layout.clone(), 1 << 16, - e_entry, ) .unwrap(); - let prover_preprocessing = JoltProverPreprocessing::new(shared_preprocessing.clone()); let elf_contents_opt = program.get_elf_contents(); let elf_contents = elf_contents_opt.as_deref().expect("elf contents is None"); @@ -2535,19 +2905,18 @@ mod tests { // - Trusted: commit in preprocessing-only context, reduce in Stage 6, batch in Stage 8 // - Untrusted: commit at prove time, reduce in Stage 6, batch in Stage 8 let mut program = host::Program::new("sha2-guest"); - let (bytecode, init_memory_state, _, e_entry) = program.decode(); + let (bytecode, init_memory_state, _, _) = program.decode(); let inputs = postcard::to_stdvec(&[5u8; 32]).unwrap(); let trusted_advice = postcard::to_stdvec(&[7u8; 32]).unwrap(); let untrusted_advice = postcard::to_stdvec(&[9u8; 32]).unwrap(); let (_, _, _, io_device) = program.trace(&inputs, &untrusted_advice, &trusted_advice); - let shared_preprocessing = JoltSharedPreprocessing::new( - bytecode.clone(), - io_device.memory_layout.clone(), + let (shared_preprocessing, _program_data) = test_shared_preprocessing( + bytecode, init_memory_state, + io_device.memory_layout.clone(), 1 << 16, - e_entry, ) .unwrap(); let prover_preprocessing = JoltProverPreprocessing::new(shared_preprocessing.clone()); @@ -2602,16 +2971,15 @@ mod tests { let trusted_advice = vec![7u8; 4096]; let untrusted_advice = vec![9u8; 4096]; - let (bytecode, init_memory_state, _, e_entry) = program.decode(); + let (bytecode, init_memory_state, _, _) = program.decode(); let (lazy_trace, trace, final_memory_state, io_device) = program.trace(&inputs, &untrusted_advice, &trusted_advice); - let shared_preprocessing = JoltSharedPreprocessing::new( - bytecode.clone(), - io_device.memory_layout.clone(), + let (shared_preprocessing, _program_data) = test_shared_preprocessing( + bytecode, init_memory_state, + io_device.memory_layout.clone(), 4096, - e_entry, ) .unwrap(); let prover_preprocessing = JoltProverPreprocessing::new(shared_preprocessing.clone()); @@ -2660,7 +3028,7 @@ mod tests { // Tests a guest (merkle-tree) that actually consumes both trusted and untrusted advice. let mut program = host::Program::new("merkle-tree-guest"); - let (bytecode, init_memory_state, _, e_entry) = program.decode(); + let (bytecode, init_memory_state, _, _) = program.decode(); // Merkle tree with 4 leaves: input=leaf1, trusted=[leaf2, leaf3], untrusted=leaf4 let inputs = postcard::to_stdvec(&[5u8; 32].as_slice()).unwrap(); @@ -2669,12 +3037,11 @@ mod tests { trusted_advice.extend(postcard::to_stdvec(&[7u8; 32]).unwrap()); let (_, _, _, io_device) = program.trace(&inputs, &untrusted_advice, &trusted_advice); - let shared_preprocessing = JoltSharedPreprocessing::new( - bytecode.clone(), - io_device.memory_layout.clone(), + let (shared_preprocessing, _program_data) = test_shared_preprocessing( + bytecode, init_memory_state, + io_device.memory_layout.clone(), 1 << 16, - e_entry, ) .unwrap(); let prover_preprocessing = JoltProverPreprocessing::new(shared_preprocessing.clone()); @@ -2731,16 +3098,15 @@ mod tests { let trusted_advice = postcard::to_stdvec(&[7u8; 32]).unwrap(); let untrusted_advice = postcard::to_stdvec(&[9u8; 32]).unwrap(); - let (bytecode, init_memory_state, _, e_entry) = program.decode(); + let (bytecode, init_memory_state, _, _) = program.decode(); let (lazy_trace, trace, final_memory_state, io_device) = program.trace(&inputs, &untrusted_advice, &trusted_advice); - let shared_preprocessing = JoltSharedPreprocessing::new( - bytecode.clone(), - io_device.memory_layout.clone(), + let (shared_preprocessing, _program_data) = test_shared_preprocessing( + bytecode, init_memory_state, + io_device.memory_layout.clone(), 1 << 16, - e_entry, ) .unwrap(); let prover_preprocessing = JoltProverPreprocessing::new(shared_preprocessing.clone()); @@ -2823,18 +3189,16 @@ mod tests { fn memory_ops_e2e_dory() { DoryGlobals::reset(); let mut program = host::Program::new("memory-ops-guest"); - let (bytecode, init_memory_state, _, e_entry) = program.decode(); + let (bytecode, init_memory_state, _, _) = program.decode(); let (_, _, _, io_device) = program.trace(&[], &[], &[]); - let shared_preprocessing = JoltSharedPreprocessing::new( - bytecode.clone(), - io_device.memory_layout.clone(), + let (shared_preprocessing, _program_data) = test_shared_preprocessing( + bytecode, init_memory_state, + io_device.memory_layout.clone(), 1 << 16, - e_entry, ) .unwrap(); - let prover_preprocessing = JoltProverPreprocessing::new(shared_preprocessing.clone()); let elf_contents_opt = program.get_elf_contents(); let elf_contents = elf_contents_opt.as_deref().expect("elf contents is None"); @@ -2868,19 +3232,17 @@ mod tests { fn btreemap_e2e_dory() { DoryGlobals::reset(); let mut program = host::Program::new("btreemap-guest"); - let (bytecode, init_memory_state, _, e_entry) = program.decode(); + let (bytecode, init_memory_state, _, _) = program.decode(); let inputs = postcard::to_stdvec(&50u32).unwrap(); let (_, _, _, io_device) = program.trace(&inputs, &[], &[]); - let shared_preprocessing = JoltSharedPreprocessing::new( - bytecode.clone(), - io_device.memory_layout.clone(), + let (shared_preprocessing, _program_data) = test_shared_preprocessing( + bytecode, init_memory_state, + io_device.memory_layout.clone(), 1 << 16, - e_entry, ) .unwrap(); - let prover_preprocessing = JoltProverPreprocessing::new(shared_preprocessing.clone()); let elf_contents_opt = program.get_elf_contents(); let elf_contents = elf_contents_opt.as_deref().expect("elf contents is None"); @@ -2914,19 +3276,61 @@ mod tests { fn muldiv_e2e_dory() { DoryGlobals::reset(); let mut program = host::Program::new("muldiv-guest"); - let (bytecode, init_memory_state, _, e_entry) = program.decode(); + let (bytecode, init_memory_state, _, _) = program.decode(); let inputs = postcard::to_stdvec(&[9u32, 5u32, 3u32]).unwrap(); let (_, _, _, io_device) = program.trace(&inputs, &[], &[]); - let shared_preprocessing = JoltSharedPreprocessing::new( - bytecode.clone(), - io_device.memory_layout.clone(), + let (shared_preprocessing, _program_data) = test_shared_preprocessing( + bytecode, init_memory_state, + io_device.memory_layout.clone(), 1 << 16, - e_entry, ) .unwrap(); + let prover_preprocessing = JoltProverPreprocessing::new(shared_preprocessing.clone()); + let elf_contents_opt = program.get_elf_contents(); + let elf_contents = elf_contents_opt.as_deref().expect("elf contents is None"); + let prover = RV64IMACProver::gen_from_elf( + &prover_preprocessing, + elf_contents, + &inputs, + &[], + &[], + None, + None, + None, + ); + let io_device = prover.program_io.clone(); + let (jolt_proof, debug_info) = prover.prove(); + let verifier_preprocessing = JoltVerifierPreprocessing::from(&prover_preprocessing); + let verifier = RV64IMACVerifier::new( + &verifier_preprocessing, + jolt_proof, + io_device, + None, + debug_info, + ) + .expect("Failed to create verifier"); + verifier.verify().expect("Failed to verify proof"); + } + + #[test] + #[serial] + fn muldiv_e2e_dory_committed_program_commitments() { + DoryGlobals::reset(); + let mut program = host::Program::new("muldiv-guest"); + let (bytecode, init_memory_state, _, _) = program.decode(); + let inputs = postcard::to_stdvec(&[9u32, 5u32, 3u32]).unwrap(); + let (_, _, _, io_device) = program.trace(&inputs, &[], &[]); + let (shared_preprocessing, _program_data) = test_shared_preprocessing_committed( + bytecode, + init_memory_state, + io_device.memory_layout.clone(), + 1 << 16, + 1, + ) + .unwrap(); let prover_preprocessing = JoltProverPreprocessing::new(shared_preprocessing.clone()); let elf_contents_opt = program.get_elf_contents(); let elf_contents = elf_contents_opt.as_deref().expect("elf contents is None"); @@ -2965,18 +3369,16 @@ mod tests { program.set_std(true); program.set_func("int_to_string"); let inputs = postcard::to_stdvec(&81i32).unwrap(); - let (bytecode, init_memory_state, _, e_entry) = program.decode(); + let (bytecode, init_memory_state, _, _) = program.decode(); let (_, _, _, io_device) = program.trace(&inputs, &[], &[]); - let shared_preprocessing = JoltSharedPreprocessing::new( - bytecode.clone(), - io_device.memory_layout.clone(), + let (shared_preprocessing, _program_data) = test_shared_preprocessing( + bytecode, init_memory_state, + io_device.memory_layout.clone(), 1 << 16, - e_entry, ) .unwrap(); - let prover_preprocessing = JoltProverPreprocessing::new(shared_preprocessing.clone()); let elf_contents_opt = program.get_elf_contents(); let elf_contents = elf_contents_opt.as_deref().expect("elf contents is None"); @@ -3025,7 +3427,6 @@ mod tests { }; use crate::subprotocols::sumcheck::SumcheckInstanceProof; use crate::transcripts::{KeccakTranscript, Transcript}; - use crate::zkvm::verifier::JoltSharedPreprocessing; /// Helper to process a single stage's sumcheck proof. /// Returns a list of (RoundWitness, degree) for each round. /// For ZK proofs, creates synthetic witnesses with correct degrees to test R1CS structure. @@ -3119,16 +3520,14 @@ mod tests { // Run muldiv prover to get a real proof let mut program = host::Program::new("muldiv-guest"); - let (bytecode, init_memory_state, _, e_entry) = program.decode(); + let (bytecode, init_memory_state, _, _) = program.decode(); let inputs = postcard::to_stdvec(&[9u32, 5u32, 3u32]).unwrap(); let (_, _, _, io_device) = program.trace(&inputs, &[], &[]); - - let shared_preprocessing = JoltSharedPreprocessing::new( - bytecode.clone(), - io_device.memory_layout.clone(), + let (shared_preprocessing, _program_data) = test_shared_preprocessing( + bytecode, init_memory_state, + io_device.memory_layout.clone(), 1 << 16, - e_entry, ) .unwrap(); let preprocessing = JoltProverPreprocessing::new(shared_preprocessing); @@ -3160,7 +3559,7 @@ mod tests { ("Stage 5 (Value+Lookup)", &jolt_proof.stage5_sumcheck_proof), ( "Stage 6 (OneHot+Hamming)", - &jolt_proof.stage6_sumcheck_proof, + &jolt_proof.stage6b_sumcheck_proof, ), ( "Stage 7 (HammingWeight+ClaimReduction)", @@ -3240,22 +3639,20 @@ mod tests { #[should_panic] fn truncated_trace() { let mut program = host::Program::new("fibonacci-guest"); - let (bytecode, init_memory_state, _, e_entry) = program.decode(); + let (bytecode, init_memory_state, _, _) = program.decode(); let inputs = postcard::to_stdvec(&9u8).unwrap(); let (lazy_trace, mut trace, final_memory_state, mut program_io) = program.trace(&inputs, &[], &[]); trace.truncate(100); program_io.outputs[0] = 0; // change the output to 0 - let shared_preprocessing = JoltSharedPreprocessing::new( - bytecode.clone(), - program_io.memory_layout.clone(), + let (shared_preprocessing, _program_data) = test_shared_preprocessing( + bytecode, init_memory_state, + program_io.memory_layout.clone(), 1 << 16, - e_entry, ) .unwrap(); - let prover_preprocessing = JoltProverPreprocessing::new(shared_preprocessing.clone()); let prover = RV64IMACProver::gen_from_trace( @@ -3282,17 +3679,16 @@ mod tests { fn malicious_trace() { let mut program = host::Program::new("fibonacci-guest"); let inputs = postcard::to_stdvec(&1u8).unwrap(); - let (bytecode, init_memory_state, _, e_entry) = program.decode(); + let (bytecode, init_memory_state, _, _) = program.decode(); let (lazy_trace, trace, final_memory_state, mut program_io) = program.trace(&inputs, &[], &[]); // Since the preprocessing is done with the original memory layout, the verifier should fail - let shared_preprocessing = JoltSharedPreprocessing::new( - bytecode.clone(), - program_io.memory_layout.clone(), + let (shared_preprocessing, _program_data) = test_shared_preprocessing( + bytecode, init_memory_state, + program_io.memory_layout.clone(), 1 << 16, - e_entry, ) .unwrap(); let prover_preprocessing = JoltProverPreprocessing::new(shared_preprocessing.clone()); @@ -3334,15 +3730,13 @@ mod tests { let inputs = postcard::to_stdvec(&9u8).unwrap(); let (bytecode, init_memory_state, _, e_entry) = program.decode(); let (lazy_trace, trace, final_memory_state, program_io) = program.trace(&inputs, &[], &[]); - + let original_program = + Arc::new(ProgramPreprocessing::preprocess(bytecode, init_memory_state).unwrap()); let shared = JoltSharedPreprocessing::new( - bytecode.clone(), + (*original_program).clone(), program_io.memory_layout.clone(), - init_memory_state, 1 << 16, - e_entry, - ) - .unwrap(); + ); let prover_preprocessing = JoltProverPreprocessing::new(shared.clone()); let prover = RV64IMACProver::gen_from_trace( &prover_preprocessing, @@ -3355,14 +3749,13 @@ mod tests { ); let (proof, _) = prover.prove(); - let original_entry_index = shared.bytecode.entry_bytecode_index(); + let original_entry_index = original_program.entry_bytecode_index(); // Tamper: give verifier a wrong entry_address so it computes a different // entry_bytecode_index and thus a different input_claim expectation. let mut tampered_shared = shared.clone(); - let mut tampered_bytecode = (*tampered_shared.bytecode).clone(); - tampered_bytecode.entry_address = e_entry.wrapping_add(4); - tampered_shared.bytecode = Arc::new(tampered_bytecode); - let tampered_entry_index = tampered_shared.bytecode.entry_bytecode_index(); + tampered_shared.program_meta.entry_address = e_entry.wrapping_add(4); + let tampered_entry_index = tampered_shared.program_meta.entry_address as usize + / common::constants::BYTES_PER_INSTRUCTION; assert_ne!( original_entry_index, tampered_entry_index, "tamper did not change entry_bytecode_index — test scenario is invalid" @@ -3506,15 +3899,14 @@ mod tests { let mut program = host::Program::new("fibonacci-guest"); let inputs = postcard::to_stdvec(&50u32).unwrap(); - let (bytecode, init_memory_state, _, e_entry) = program.decode(); + let (bytecode, init_memory_state, _, _) = program.decode(); let (_, _, _, io_device) = program.trace(&inputs, &[], &[]); - let shared_preprocessing = JoltSharedPreprocessing::new( - bytecode.clone(), - io_device.memory_layout.clone(), + let (shared_preprocessing, _program_data) = test_shared_preprocessing( + bytecode, init_memory_state, + io_device.memory_layout.clone(), 1 << 16, - e_entry, ) .unwrap(); let prover_preprocessing = JoltProverPreprocessing::new(shared_preprocessing); @@ -3549,7 +3941,7 @@ mod tests { // Tests a guest (merkle-tree) that actually consumes both trusted and untrusted advice. let mut program = host::Program::new("merkle-tree-guest"); - let (bytecode, init_memory_state, _, e_entry) = program.decode(); + let (bytecode, init_memory_state, _, _) = program.decode(); // Merkle tree with 4 leaves: input=leaf1, trusted=[leaf2, leaf3], untrusted=leaf4 let inputs = postcard::to_stdvec(&[5u8; 32].as_slice()).unwrap(); @@ -3558,12 +3950,11 @@ mod tests { trusted_advice.extend(postcard::to_stdvec(&[7u8; 32]).unwrap()); let (_, _, _, io_device) = program.trace(&inputs, &untrusted_advice, &trusted_advice); - let shared_preprocessing = JoltSharedPreprocessing::new( - bytecode.clone(), - io_device.memory_layout.clone(), + let (shared_preprocessing, _program_data) = test_shared_preprocessing( + bytecode, init_memory_state, + io_device.memory_layout.clone(), 1 << 16, - e_entry, ) .unwrap(); let prover_preprocessing = JoltProverPreprocessing::new(shared_preprocessing.clone()); diff --git a/jolt-core/src/zkvm/ram/mod.rs b/jolt-core/src/zkvm/ram/mod.rs index d67a59312..e1664988c 100644 --- a/jolt-core/src/zkvm/ram/mod.rs +++ b/jolt-core/src/zkvm/ram/mod.rs @@ -329,6 +329,61 @@ pub fn verifier_accumulate_advice( + ram_K: usize, + ram_preprocessing: &RAMPreprocessing, + program_io: &JoltDevice, + opening_accumulator: &mut ProverOpeningAccumulator, +) { + let total_vars = ram_K.log_2(); + let bytecode_start = remap_address( + ram_preprocessing.min_bytecode_address, + &program_io.memory_layout, + ) + .unwrap() as usize; + + let (r_rw, _) = opening_accumulator.get_virtual_polynomial_opening( + VirtualPolynomial::RamVal, + SumcheckId::RamReadWriteChecking, + ); + let (r_address_rw, _) = r_rw.split_at(total_vars); + let c_rw = sparse_eval_block::( + bytecode_start, + &ram_preprocessing.bytecode_words, + &r_address_rw.r, + ); + opening_accumulator.append_virtual( + VirtualPolynomial::ProgramImageInitContributionRw, + SumcheckId::RamValCheck, + r_address_rw, + c_rw, + ); +} + +/// Mirrors [`prover_accumulate_program_image`] on verifier side by caching opening points. +pub fn verifier_accumulate_program_image( + ram_K: usize, + opening_accumulator: &mut impl AbstractVerifierOpeningAccumulator, +) { + let total_vars = ram_K.log_2(); + let (r_rw, _) = opening_accumulator.get_virtual_polynomial_opening( + VirtualPolynomial::RamVal, + SumcheckId::RamReadWriteChecking, + ); + let (r_address_rw, _) = r_rw.split_at(total_vars); + opening_accumulator.append_virtual( + VirtualPolynomial::ProgramImageInitContributionRw, + SumcheckId::RamValCheck, + r_address_rw, + ); +} + /// Calculates how advice inputs contribute to the evaluation of initial_ram_state at a given random point. /// /// ## Example with Two Commitments: @@ -485,6 +540,34 @@ pub fn reconstruct_full_eval( eval } +/// Evaluate just the public input words at a random RAM address point. +/// +/// Inputs are packed into little-endian `u64` words and placed at +/// `memory_layout.input_start`. +pub fn eval_inputs_mle(program_io: &JoltDevice, r_address: &[F::Challenge]) -> F { + if program_io.inputs.is_empty() { + return F::zero(); + } + + let input_start = remap_address( + program_io.memory_layout.input_start, + &program_io.memory_layout, + ) + .unwrap() as usize; + let input_words: Vec = program_io + .inputs + .chunks(8) + .map(|chunk| { + let mut word = [0u8; 8]; + for (i, byte) in chunk.iter().enumerate() { + word[i] = *byte; + } + u64::from_le_bytes(word) + }) + .collect(); + sparse_eval_block::(input_start, &input_words, r_address) +} + /// Trait for coefficient types usable in sparse MLE evaluation. /// /// `u64` uses Barrett-reduced accumulation (`MedAccumU`) for performance. @@ -595,25 +678,7 @@ pub fn eval_initial_ram_mle( let mut acc = sparse_eval_block::(bytecode_start, &ram_preprocessing.bytecode_words, r_address); - if !program_io.inputs.is_empty() { - let input_start = remap_address( - program_io.memory_layout.input_start, - &program_io.memory_layout, - ) - .unwrap() as usize; - let input_words: Vec = program_io - .inputs - .chunks(8) - .map(|chunk| { - let mut word = [0u8; 8]; - for (i, byte) in chunk.iter().enumerate() { - word[i] = *byte; - } - u64::from_le_bytes(word) - }) - .collect(); - acc += sparse_eval_block::(input_start, &input_words, r_address); - } + acc += eval_inputs_mle::(program_io, r_address); acc } diff --git a/jolt-core/src/zkvm/ram/val_check.rs b/jolt-core/src/zkvm/ram/val_check.rs index fec20a587..6e326cde8 100644 --- a/jolt-core/src/zkvm/ram/val_check.rs +++ b/jolt-core/src/zkvm/ram/val_check.rs @@ -81,12 +81,14 @@ pub struct RamValCheckSumcheckParams { /// Val_init(r_address) evaluation to subtract on both LHS terms. pub init_eval: F, - /// Public-only portion of init_eval (bytecode + inputs), used by BlindFold constraint. + /// Public constant portion of init_eval used by BlindFold constraints. + /// In committed-program mode this is inputs-only; program image is an opening. #[cfg(feature = "zk")] pub init_eval_public: F, /// Advice contributions decomposed for BlindFold: each is (-selector, opening_id). #[cfg(feature = "zk")] pub advice_contributions: Vec<(F, OpeningId)>, + pub include_program_image_claims: bool, } impl RamValCheckSumcheckParams { @@ -98,6 +100,8 @@ impl RamValCheckSumcheckParams { gamma: F, ram_preprocessing: &super::RAMPreprocessing, program_io: &JoltDevice, + _rw_config: &ReadWriteConfig, + include_program_image_claims: bool, ) -> Self { let K = one_hot_params.ram_k; @@ -124,8 +128,11 @@ impl RamValCheckSumcheckParams { let init_eval = val_init.evaluate(&r_address.r); #[cfg(feature = "zk")] - let init_eval_public = - super::eval_initial_ram_mle::(ram_preprocessing, program_io, &r_address.r); + let init_eval_public = if include_program_image_claims { + super::eval_inputs_mle::(program_io, &r_address.r) + } else { + super::eval_initial_ram_mle::(ram_preprocessing, program_io, &r_address.r) + }; #[cfg(feature = "zk")] let advice_contributions = super::compute_advice_init_contributions( @@ -151,6 +158,7 @@ impl RamValCheckSumcheckParams { init_eval_public, #[cfg(feature = "zk")] advice_contributions, + include_program_image_claims, } } @@ -163,6 +171,7 @@ impl RamValCheckSumcheckParams { _rw_config: &ReadWriteConfig, gamma: F, opening_accumulator: &dyn OpeningAccumulator, + include_program_image_claims: bool, ) -> Self { // (r_address, r_cycle) from RamVal/RamReadWriteChecking. let (r, _) = opening_accumulator.get_virtual_polynomial_opening( @@ -183,8 +192,22 @@ impl RamValCheckSumcheckParams { } let n_memory_vars = ram_K.log_2(); - let init_eval_public = - super::eval_initial_ram_mle::(ram_preprocessing, program_io, &r_address.r); + let init_eval_public_base = if include_program_image_claims { + super::eval_inputs_mle::(program_io, &r_address.r) + } else { + super::eval_initial_ram_mle::(ram_preprocessing, program_io, &r_address.r) + }; + let program_image_contribution = if include_program_image_claims { + opening_accumulator + .get_virtual_polynomial_opening( + VirtualPolynomial::ProgramImageInitContributionRw, + SumcheckId::RamValCheck, + ) + .1 + } else { + F::zero() + }; + let init_eval_public = init_eval_public_base + program_image_contribution; let advice_contributions = super::compute_advice_init_contributions( opening_accumulator, &program_io.memory_layout, @@ -199,7 +222,9 @@ impl RamValCheckSumcheckParams { ); #[cfg(not(feature = "zk"))] - let _ = (init_eval_public, advice_contributions); + let _ = (init_eval_public_base, advice_contributions); + #[cfg(feature = "zk")] + let init_eval_public = init_eval_public_base; Self { T: trace_len, @@ -212,6 +237,7 @@ impl RamValCheckSumcheckParams { init_eval_public, #[cfg(feature = "zk")] advice_contributions, + include_program_image_claims, } } } @@ -248,12 +274,17 @@ impl SumcheckInstanceParams for RamValCheckSumcheckParams { fn input_claim_constraint(&self) -> InputClaimConstraint { // input_claim = (val_rw - init_eval) + γ*(val_final - init_eval) // = val_rw + γ*val_final - (1+γ)*init_eval - // = val_rw + γ*val_final - (1+γ)*(init_eval_public + Σ(sel_i * advice_i)) + // where: + // - in full-program mode: + // init_eval = init_eval_public + Σ(sel_i * advice_i) + // - in committed-program mode: + // init_eval = init_eval_public + program_image_claim + Σ(sel_i * advice_i) // // Challenge layout: // Challenge(0) = γ // Challenge(1) = -(1+γ)*init_eval_public - // Challenge(2..) = -(1+γ)*selector_i (one per advice contribution) + // Challenge(2) = -(1+γ) (program-image claim; committed mode only) + // Challenge(next..) = -(1+γ)*selector_i (one per advice contribution) let val_rw = OpeningId::virt(VirtualPolynomial::RamVal, SumcheckId::RamReadWriteChecking); let val_final = OpeningId::virt(VirtualPolynomial::RamValFinal, SumcheckId::RamOutputCheck); @@ -265,11 +296,23 @@ impl SumcheckInstanceParams for RamValCheckSumcheckParams { ), ProductTerm::single(ValueSource::Challenge(1)), ]; - for (i, (_, advice_opening_id)) in self.advice_contributions.iter().enumerate() { + let mut challenge_idx = 2; + if self.include_program_image_claims { terms.push(ProductTerm::product(vec![ - ValueSource::Challenge(i + 2), + ValueSource::Challenge(challenge_idx), + ValueSource::Opening(OpeningId::virt( + VirtualPolynomial::ProgramImageInitContributionRw, + SumcheckId::RamValCheck, + )), + ])); + challenge_idx += 1; + } + for (_, advice_opening_id) in self.advice_contributions.iter() { + terms.push(ProductTerm::product(vec![ + ValueSource::Challenge(challenge_idx), ValueSource::Opening(*advice_opening_id), ])); + challenge_idx += 1; } InputClaimConstraint::sum_of_products(terms) } @@ -278,6 +321,9 @@ impl SumcheckInstanceParams for RamValCheckSumcheckParams { fn input_constraint_challenge_values(&self, _: &dyn OpeningAccumulator) -> Vec { let one_plus_gamma = F::one() + self.gamma; let mut values = vec![self.gamma, -one_plus_gamma * self.init_eval_public]; + if self.include_program_image_claims { + values.push(-one_plus_gamma); + } for (neg_selector, _) in &self.advice_contributions { // neg_selector is already negative (-selector_i), scale by (1+γ) values.push(one_plus_gamma * *neg_selector); @@ -474,6 +520,7 @@ impl RamValCheckSumcheckVerifier { rw_config: &ReadWriteConfig, gamma: F, opening_accumulator: &dyn OpeningAccumulator, + include_program_image_claims: bool, ) -> Self { let params = RamValCheckSumcheckParams::new_from_verifier( initial_ram_state, @@ -484,6 +531,7 @@ impl RamValCheckSumcheckVerifier { rw_config, gamma, opening_accumulator, + include_program_image_claims, ); Self { params } } diff --git a/jolt-core/src/zkvm/transpilable_verifier.rs b/jolt-core/src/zkvm/transpilable_verifier.rs index 201a75cee..fc3a91c41 100644 --- a/jolt-core/src/zkvm/transpilable_verifier.rs +++ b/jolt-core/src/zkvm/transpilable_verifier.rs @@ -41,20 +41,27 @@ use crate::curve::JoltCurve; use crate::poly::commitment::commitment_scheme::CommitmentScheme; +use crate::poly::commitment::dory::DoryGlobals; #[cfg(not(feature = "zk"))] use crate::poly::opening_proof::{OpeningPoint, BIG_ENDIAN}; use crate::subprotocols::sumcheck::{BatchedSumcheck, ClearSumcheckProof, SumcheckInstanceProof}; use crate::zkvm::claim_reductions::{ - AdviceClaimReductionVerifier, AdviceKind, HammingWeightClaimReductionVerifier, ReductionPhase, + AdviceClaimReductionVerifier, AdviceKind, BytecodeClaimReductionParams, + BytecodeClaimReductionVerifier, HammingWeightClaimReductionVerifier, PrecommittedPhase, + ProgramImageClaimReductionParams, ProgramImageClaimReductionVerifier, RegistersClaimReductionSumcheckVerifier, }; use crate::zkvm::config::OneHotParams; use crate::zkvm::{ - bytecode::read_raf_checking::BytecodeReadRafSumcheckVerifier, + bytecode::read_raf_checking::{ + BytecodeReadRafAddressSumcheckVerifier, BytecodeReadRafCycleSumcheckVerifier, + BytecodeReadRafSumcheckParams, + }, claim_reductions::{ IncClaimReductionSumcheckVerifier, InstructionLookupsClaimReductionSumcheckVerifier, RamRaClaimReductionSumcheckVerifier, }, + config::ProgramMode, fiat_shamir_preamble, instruction_lookups::{ ra_virtual::RaSumcheckVerifier as LookupsRaSumcheckVerifier, @@ -67,7 +74,7 @@ use crate::zkvm::{ output_check::OutputSumcheckVerifier, ra_virtual::RamRaVirtualSumcheckVerifier, raf_evaluation::RafEvaluationSumcheckVerifier as RamRafEvaluationSumcheckVerifier, read_write_checking::RamReadWriteCheckingVerifier, val_check::RamValCheckSumcheckVerifier, - verifier_accumulate_advice, + verifier_accumulate_advice, verifier_accumulate_program_image, }, registers::{ read_write_checking::RegistersReadWriteCheckingVerifier, @@ -78,7 +85,7 @@ use crate::zkvm::{ product::ProductVirtualRemainderVerifier, shift::ShiftSumcheckVerifier, verify_stage1_uni_skip, verify_stage2_uni_skip, }, - verifier::JoltVerifierPreprocessing, + verifier::{JoltSharedPreprocessing, JoltVerifierPreprocessing}, ProverDebugInfo, }; use crate::{ @@ -86,7 +93,10 @@ use crate::{ poly::opening_proof::{AbstractVerifierOpeningAccumulator, VerifierOpeningAccumulator}, pprof_scope, subprotocols::{ - booleanity::{BooleanitySumcheckParams, BooleanitySumcheckVerifier}, + booleanity::{ + BooleanityAddressSumcheckVerifier, BooleanityCycleSumcheckVerifier, + BooleanitySumcheckParams, + }, sumcheck_verifier::SumcheckInstanceVerifier, }, transcripts::Transcript, @@ -128,12 +138,10 @@ pub struct TranspilableVerifier< pub opening_accumulator: A, pub spartan_key: UniformSpartanKey, pub one_hot_params: OneHotParams, - /// The advice claim reduction sumcheck effectively spans two stages (6 and 7). - /// Cache the verifier state here between stages. advice_reduction_verifier_trusted: Option>, - /// The advice claim reduction sumcheck effectively spans two stages (6 and 7). - /// Cache the verifier state here between stages. advice_reduction_verifier_untrusted: Option>, + bytecode_reduction_verifier: Option>, + program_image_reduction_verifier: Option>, } impl< @@ -216,10 +224,11 @@ impl< .validate() .map_err(ProofVerifyError::InvalidOneHotConfig)?; - let min_ram_K = compute_min_ram_K( - &preprocessing.shared.ram, - &preprocessing.shared.memory_layout, - ); + let ram_preprocessing = crate::zkvm::ram::RAMPreprocessing { + min_bytecode_address: preprocessing.shared.program_meta.min_bytecode_address, + bytecode_words: vec![0; preprocessing.shared.program_meta.program_image_len_words], + }; + let min_ram_K = compute_min_ram_K(&ram_preprocessing, &preprocessing.shared.memory_layout); if !proof.ram_K.is_power_of_two() || proof.ram_K < min_ram_K { return Err(ProofVerifyError::InvalidRamK(proof.ram_K, min_ram_K)); } @@ -230,7 +239,7 @@ impl< .map_err(ProofVerifyError::InvalidReadWriteConfig)?; // Construct full params from the validated config - let bytecode_K = preprocessing.shared.bytecode.code_size; + let bytecode_K = preprocessing.shared.bytecode_size(); let one_hot_params = OneHotParams::from_config(&proof.one_hot_config, bytecode_K, proof.ram_K); @@ -245,6 +254,8 @@ impl< one_hot_params, advice_reduction_verifier_trusted: None, advice_reduction_verifier_untrusted: None, + bytecode_reduction_verifier: None, + program_image_reduction_verifier: None, }) } @@ -261,7 +272,7 @@ impl< opening_accumulator: A, ) -> Self { let spartan_key = UniformSpartanKey::new(proof.trace_length.next_power_of_two()); - let bytecode_K = preprocessing.shared.bytecode.code_size; + let bytecode_K = preprocessing.shared.bytecode_size(); let one_hot_params = OneHotParams::from_config(&proof.one_hot_config, bytecode_K, proof.ram_K); @@ -276,9 +287,25 @@ impl< one_hot_params, advice_reduction_verifier_trusted: None, advice_reduction_verifier_untrusted: None, + bytecode_reduction_verifier: None, + program_image_reduction_verifier: None, } } + #[inline] + fn main_total_vars(&self) -> usize { + let trace_log_t = self.proof.trace_length.log_2(); + let log_k_chunk = self.one_hot_params.log_k_chunk; + JoltSharedPreprocessing::::max_total_vars_from_candidates( + trace_log_t + log_k_chunk, + self.preprocessing.shared.precommitted_candidate_total_vars( + self.preprocessing.shared.program.is_committed(), + self.trusted_advice_commitment.is_some(), + self.proof.untrusted_advice_commitment.is_some(), + ), + ) + } + /// Verify the Jolt proof (stages 1-7). /// /// Note: Stage 8 (PCS verification) is not included because it uses @@ -293,7 +320,7 @@ impl< &self.program_io, self.proof.ram_K, self.proof.trace_length, - self.preprocessing.shared.bytecode.entry_address, + self.preprocessing.shared.program_meta.entry_address, &self.proof.rw_config, &self.proof.one_hot_config, self.proof.dory_layout, @@ -471,23 +498,51 @@ impl< self.trusted_advice_commitment.is_some(), &mut self.opening_accumulator, ); + if self.preprocessing.shared.program.is_committed() { + verifier_accumulate_program_image::(self.proof.ram_K, &mut self.opening_accumulator); + } // Domain-separate the batching challenge. self.transcript.append_bytes(b"ram_val_check_gamma", &[]); let ram_val_check_gamma: F = self.transcript.challenge_scalar::(); - let initial_ram_state = crate::zkvm::ram::gen_ram_initial_memory_state::( - self.proof.ram_K, - &self.preprocessing.shared.ram, - &self.program_io, - ); + let initial_ram_state = if self.preprocessing.shared.program.is_full() { + crate::zkvm::ram::gen_ram_initial_memory_state::( + self.proof.ram_K, + &self.preprocessing.shared.program.as_full().unwrap().ram, + &self.program_io, + ) + } else { + vec![0u64; self.proof.ram_K] + }; + let ram_preprocessing = if self.preprocessing.shared.program.is_full() { + self.preprocessing + .shared + .program + .as_full() + .unwrap() + .ram + .clone() + } else { + crate::zkvm::ram::RAMPreprocessing { + min_bytecode_address: self.preprocessing.shared.program_meta.min_bytecode_address, + bytecode_words: vec![ + 0; + self.preprocessing + .shared + .program_meta + .program_image_len_words + ], + } + }; let ram_val_check = RamValCheckSumcheckVerifier::new( &initial_ram_state, &self.program_io, - &self.preprocessing.shared.ram, + &ram_preprocessing, self.proof.trace_length, self.proof.ram_K, &self.proof.rw_config, ram_val_check_gamma, &self.opening_accumulator, + self.preprocessing.shared.program.is_committed(), ); let instances: Vec<&dyn SumcheckInstanceVerifier> = @@ -538,25 +593,85 @@ impl< } fn verify_stage6(&mut self) -> Result<(), ProofVerifyError> { + let _ = DoryGlobals::initialize_main_with_log_embedding( + self.one_hot_params.k_chunk, + self.proof.trace_length, + self.main_total_vars(), + Some(self.proof.dory_layout), + ); + let (bytecode_read_raf_params, booleanity_params) = self.verify_stage6a()?; + self.verify_stage6b(bytecode_read_raf_params, booleanity_params)?; + Ok(()) + } + + fn verify_stage6a( + &mut self, + ) -> Result< + ( + BytecodeReadRafSumcheckParams, + BooleanitySumcheckParams, + ), + ProofVerifyError, + > { let n_cycle_vars = self.proof.trace_length.log_2(); - let bytecode_read_raf = BytecodeReadRafSumcheckVerifier::gen( - &self.preprocessing.shared.bytecode, + let program_mode = if self.preprocessing.shared.program.is_committed() { + ProgramMode::Committed + } else { + ProgramMode::Full + }; + let entry_bytecode_index = self + .preprocessing + .shared + .program_meta + .entry_address + .saturating_sub(self.preprocessing.shared.program_meta.min_bytecode_address) + as usize + / common::constants::BYTES_PER_INSTRUCTION + + 1; + let bytecode_read_raf = BytecodeReadRafAddressSumcheckVerifier::new::( + Some(&self.preprocessing.shared.program), n_cycle_vars, &self.one_hot_params, &self.opening_accumulator, &mut self.transcript, - ); - - let ram_hamming_booleanity = - HammingBooleanitySumcheckVerifier::new(&self.opening_accumulator); + program_mode, + entry_bytecode_index, + )?; let booleanity_params = BooleanitySumcheckParams::new( n_cycle_vars, &self.one_hot_params, &self.opening_accumulator, &mut self.transcript, ); + let booleanity = BooleanityAddressSumcheckVerifier::new(booleanity_params.clone()); + + let instances: Vec<&dyn SumcheckInstanceVerifier> = + vec![&bytecode_read_raf, &booleanity]; + + let _r_stage6a = BatchedSumcheck::verify_standard::( + extract_clear_proof(&self.proof.stage6a_sumcheck_proof), + instances, + &mut self.opening_accumulator, + &mut self.transcript, + )?; - let booleanity = BooleanitySumcheckVerifier::new(booleanity_params); + Ok((bytecode_read_raf.into_params(), booleanity_params)) + } + + fn verify_stage6b( + &mut self, + bytecode_read_raf_params: BytecodeReadRafSumcheckParams, + booleanity_params: BooleanitySumcheckParams, + ) -> Result<(), ProofVerifyError> { + let bytecode_reduction_seed_params = bytecode_read_raf_params.clone(); + let bytecode_read_raf = BytecodeReadRafCycleSumcheckVerifier::new( + bytecode_read_raf_params, + &self.opening_accumulator, + ); + let ram_hamming_booleanity = + HammingBooleanitySumcheckVerifier::new(&self.opening_accumulator); + let booleanity = + BooleanityCycleSumcheckVerifier::new(booleanity_params, &self.opening_accumulator); let ram_ra_virtual = RamRaVirtualSumcheckVerifier::new( self.proof.trace_length, &self.one_hot_params, @@ -574,21 +689,65 @@ impl< &mut self.transcript, ); - // Advice claim reduction (Phase 1 in Stage 6): trusted and untrusted are separate instances. + let main_total_vars = self.proof.trace_length.log_2() + self.one_hot_params.log_k_chunk; + let precommitted_candidates = self.preprocessing.shared.precommitted_candidate_total_vars( + self.preprocessing.shared.program.is_committed(), + self.trusted_advice_commitment.is_some(), + self.proof.untrusted_advice_commitment.is_some(), + ); + let precommitted_scheduling_reference = + crate::zkvm::claim_reductions::PrecommittedClaimReduction::::scheduling_reference( + main_total_vars, + &precommitted_candidates, + ); + if self.trusted_advice_commitment.is_some() { self.advice_reduction_verifier_trusted = Some(AdviceClaimReductionVerifier::new( AdviceKind::Trusted, - &self.program_io.memory_layout, - self.proof.trace_length, + self.program_io.memory_layout.max_trusted_advice_size as usize, + precommitted_scheduling_reference, &self.opening_accumulator, )); } if self.proof.untrusted_advice_commitment.is_some() { self.advice_reduction_verifier_untrusted = Some(AdviceClaimReductionVerifier::new( AdviceKind::Untrusted, - &self.program_io.memory_layout, - self.proof.trace_length, + self.program_io.memory_layout.max_untrusted_advice_size as usize, + precommitted_scheduling_reference, + &self.opening_accumulator, + )); + } + + if self.preprocessing.shared.program.is_committed() { + let bytecode_chunk_count = self.preprocessing.shared.bytecode_chunk_count; + let bytecode_reduction_params = BytecodeClaimReductionParams::new( + &bytecode_reduction_seed_params, + self.preprocessing.shared.bytecode_size(), + bytecode_chunk_count, + precommitted_scheduling_reference, + &self.opening_accumulator, + &mut self.transcript, + ); + self.bytecode_reduction_verifier = Some(BytecodeClaimReductionVerifier::new( + bytecode_reduction_params, + )); + + let padded_len_words = self + .preprocessing + .shared + .program_meta + .committed_program_image_num_words(&self.program_io.memory_layout); + let program_image_reduction_params = ProgramImageClaimReductionParams::new( + &self.program_io, + self.preprocessing.shared.program_meta.min_bytecode_address, + padded_len_words, + self.proof.ram_K, + precommitted_scheduling_reference, &self.opening_accumulator, + &mut self.transcript, + ); + self.program_image_reduction_verifier = Some(ProgramImageClaimReductionVerifier::new( + program_image_reduction_params, )); } @@ -606,9 +765,15 @@ impl< if let Some(ref advice) = self.advice_reduction_verifier_untrusted { instances.push(advice); } + if let Some(ref reduction) = self.bytecode_reduction_verifier { + instances.push(reduction); + } + if let Some(ref reduction) = self.program_image_reduction_verifier { + instances.push(reduction); + } - let _r_stage6 = BatchedSumcheck::verify_standard::( - extract_clear_proof(&self.proof.stage6_sumcheck_proof), + let _r_stage6b = BatchedSumcheck::verify_standard::( + extract_clear_proof(&self.proof.stage6b_sumcheck_proof), instances, &mut self.opening_accumulator, &mut self.transcript, @@ -638,7 +803,7 @@ impl< { let mut params = advice_reduction_verifier_trusted.params.borrow_mut(); if params.num_address_phase_rounds() > 0 { - params.phase = ReductionPhase::AddressVariables; + params.phase = PrecommittedPhase::AddressVariables; instances.push(advice_reduction_verifier_trusted); } } @@ -647,7 +812,7 @@ impl< { let mut params = advice_reduction_verifier_untrusted.params.borrow_mut(); if params.num_address_phase_rounds() > 0 { - params.phase = ReductionPhase::AddressVariables; + params.phase = PrecommittedPhase::AddressVariables; instances.push(advice_reduction_verifier_untrusted); } } diff --git a/jolt-core/src/zkvm/verifier.rs b/jolt-core/src/zkvm/verifier.rs index 1f95abda6..5a4e67518 100644 --- a/jolt-core/src/zkvm/verifier.rs +++ b/jolt-core/src/zkvm/verifier.rs @@ -1,14 +1,8 @@ -use std::collections::HashMap; -use std::fs::File; -use std::io::{Read, Write}; -use std::path::Path; -use std::sync::Arc; - use crate::curve::JoltCurve; use crate::poly::commitment::commitment_scheme::{CommitmentScheme, ZkEvalCommitment}; #[cfg(feature = "zk")] use crate::poly::commitment::dory::bind_opening_inputs_zk; -use crate::poly::commitment::dory::{bind_opening_inputs, DoryContext, DoryGlobals}; +use crate::poly::commitment::dory::{bind_opening_inputs, DoryGlobals, DoryLayout}; use crate::poly::commitment::pedersen::PedersenGenerators; #[cfg(feature = "zk")] use crate::poly::lagrange_poly::LagrangeHelper; @@ -27,10 +21,11 @@ use crate::subprotocols::sumcheck::SumcheckInstanceProof; use crate::subprotocols::sumcheck_verifier::SumcheckInstanceParams; #[cfg(feature = "zk")] use crate::subprotocols::univariate_skip::UniSkipFirstRoundProofVariant; -use crate::zkvm::bytecode::{BytecodePreprocessing, PreprocessingError}; -use crate::zkvm::claim_reductions::advice::ReductionPhase; +use crate::zkvm::bytecode::chunks::DEFAULT_COMMITTED_BYTECODE_CHUNK_COUNT; +use crate::zkvm::bytecode::chunks::{committed_lanes, validate_committed_bytecode_chunk_count}; use crate::zkvm::claim_reductions::RegistersClaimReductionSumcheckVerifier; -use crate::zkvm::config::OneHotParams; +use crate::zkvm::config::{OneHotParams, ProgramMode}; +use crate::zkvm::program::{ProgramMetadata, ProgramPreprocessing}; #[cfg(feature = "prover")] use crate::zkvm::prover::JoltProverPreprocessing; #[cfg(feature = "zk")] @@ -38,15 +33,19 @@ use crate::zkvm::r1cs::constraints::{ OUTER_FIRST_ROUND_POLY_NUM_COEFFS, OUTER_UNIVARIATE_SKIP_DOMAIN_SIZE, PRODUCT_VIRTUAL_FIRST_ROUND_POLY_NUM_COEFFS, PRODUCT_VIRTUAL_UNIVARIATE_SKIP_DOMAIN_SIZE, }; -use crate::zkvm::ram::RAMPreprocessing; use crate::zkvm::witness::all_committed_polynomials; use crate::zkvm::Serializable; use crate::zkvm::{ - bytecode::read_raf_checking::BytecodeReadRafSumcheckVerifier, + bytecode::read_raf_checking::{ + BytecodeReadRafAddressSumcheckVerifier, BytecodeReadRafCycleSumcheckVerifier, + BytecodeReadRafSumcheckParams, + }, claim_reductions::{ - AdviceClaimReductionVerifier, AdviceKind, HammingWeightClaimReductionVerifier, + AdviceClaimReductionVerifier, AdviceKind, BytecodeClaimReductionParams, + BytecodeClaimReductionVerifier, HammingWeightClaimReductionVerifier, IncClaimReductionSumcheckVerifier, InstructionLookupsClaimReductionSumcheckVerifier, - RamRaClaimReductionSumcheckVerifier, + PrecommittedClaimReduction, ProgramImageClaimReductionParams, + ProgramImageClaimReductionVerifier, RamRaClaimReductionSumcheckVerifier, }, fiat_shamir_preamble, instruction_lookups::{ @@ -60,7 +59,7 @@ use crate::zkvm::{ output_check::OutputSumcheckVerifier, ra_virtual::RamRaVirtualSumcheckVerifier, raf_evaluation::RafEvaluationSumcheckVerifier as RamRafEvaluationSumcheckVerifier, read_write_checking::RamReadWriteCheckingVerifier, val_check::RamValCheckSumcheckVerifier, - verifier_accumulate_advice, + verifier_accumulate_advice, verifier_accumulate_program_image, }, registers::{ read_write_checking::RegistersReadWriteCheckingVerifier, @@ -75,22 +74,27 @@ use crate::zkvm::{ }; use crate::{ field::JoltField, - poly::{ - eq_poly::EqPolynomial, - opening_proof::{ - compute_advice_lagrange_factor, DoryOpeningState, OpeningAccumulator, OpeningId, - SumcheckId, VerifierOpeningAccumulator, - }, + poly::opening_proof::{ + compute_lagrange_factor, DoryOpeningState, OpeningAccumulator, OpeningId, OpeningPoint, + SumcheckId, VerifierOpeningAccumulator, BIG_ENDIAN, }, pprof_scope, subprotocols::{ - booleanity::{BooleanitySumcheckParams, BooleanitySumcheckVerifier}, + booleanity::{ + BooleanityAddressSumcheckVerifier, BooleanityCycleSumcheckVerifier, + BooleanitySumcheckParams, + }, sumcheck_verifier::SumcheckInstanceVerifier, }, transcripts::Transcript, utils::{errors::ProofVerifyError, math::Math}, zkvm::witness::CommittedPolynomial, }; +use common::constants::BYTES_PER_INSTRUCTION; +use std::collections::HashMap; +use std::fs::File; +use std::io::{Read, Write}; +use std::path::Path; #[cfg(feature = "zk")] struct StageVerifyResult { @@ -110,6 +114,12 @@ struct StageVerifyResult { challenges: Vec, } +type Stage6aVerifyResult = ( + BytecodeReadRafSumcheckParams, + BooleanitySumcheckParams, + StageVerifyResult, +); + #[cfg(feature = "zk")] impl StageVerifyResult { fn new( @@ -207,7 +217,6 @@ fn scale_batching_coefficients< } use ark_serialize::{CanonicalDeserialize, CanonicalSerialize}; use common::jolt_device::MemoryLayout; -use tracer::instruction::Instruction; use tracer::JoltDevice; pub struct JoltVerifier< @@ -229,6 +238,10 @@ pub struct JoltVerifier< /// The advice claim reduction sumcheck effectively spans two stages (6 and 7). /// Cache the verifier state here between stages. advice_reduction_verifier_untrusted: Option>, + /// Bytecode claim reduction spans stages 6b and 7 in committed mode. + bytecode_reduction_verifier: Option>, + /// Program-image claim reduction spans stages 6b and 7 in committed mode. + program_image_reduction_verifier: Option>, pub spartan_key: UniformSpartanKey, pub one_hot_params: OneHotParams, } @@ -248,6 +261,118 @@ impl< ProofTranscript: Transcript, > JoltVerifier<'a, F, C, PCS, ProofTranscript> { + #[inline] + fn main_total_vars(&self) -> usize { + let trace_log_t = self.proof.trace_length.log_2(); + let log_k_chunk = self.one_hot_params.log_k_chunk; + JoltSharedPreprocessing::::max_total_vars_from_candidates( + trace_log_t + log_k_chunk, + self.preprocessing.shared.precommitted_candidate_total_vars( + self.preprocessing.shared.program.is_committed(), + self.trusted_advice_commitment.is_some(), + self.proof.untrusted_advice_commitment.is_some(), + ), + ) + } + + fn stage8_opening_point(&self) -> Result, ProofVerifyError> { + let native_main_vars = self.proof.trace_length.log_2() + self.one_hot_params.log_k_chunk; + let mut opening_candidates: Vec<(&str, OpeningPoint)> = Vec::new(); + if let Some((point, _)) = self + .opening_accumulator + .get_advice_opening(AdviceKind::Trusted, SumcheckId::AdviceClaimReduction) + { + opening_candidates.push(("trusted_advice", point)); + } + if let Some((point, _)) = self + .opening_accumulator + .get_advice_opening(AdviceKind::Untrusted, SumcheckId::AdviceClaimReduction) + { + opening_candidates.push(("untrusted_advice", point)); + } + if self.preprocessing.shared.program.is_committed() { + for chunk_idx in 0..self.preprocessing.shared.bytecode_chunk_count { + let (point, _) = self.opening_accumulator.get_committed_polynomial_opening( + CommittedPolynomial::BytecodeChunk(chunk_idx), + SumcheckId::BytecodeClaimReduction, + ); + opening_candidates.push(("bytecode_chunk", point)); + } + } + if self.preprocessing.shared.program.is_committed() { + let (program_image_point, _) = + self.opening_accumulator.get_committed_polynomial_opening( + CommittedPolynomial::ProgramImageInit, + SumcheckId::ProgramImageClaimReduction, + ); + opening_candidates.push(("program_image", program_image_point)); + } + + let max_len = opening_candidates + .iter() + .map(|(_, p)| p.r.len()) + .max() + .unwrap_or(0); + if max_len > native_main_vars { + let dominant = opening_candidates + .iter() + .find(|(_, p)| p.r.len() == max_len) + .expect("at least one dominant precommitted candidate expected"); + for (name, point) in opening_candidates + .iter() + .filter(|(_, p)| p.r.len() == max_len) + { + if point.r != dominant.1.r { + return Err(ProofVerifyError::DoryError(format!( + "incompatible dominant precommitted anchors: {} and {} have equal dimensionality {} but different opening points", + dominant.0, name, max_len + ))); + } + } + Ok(OpeningPoint::::new(dominant.1.r.clone())) + } else { + let (hamming_point, _) = self.opening_accumulator.get_committed_polynomial_opening( + CommittedPolynomial::InstructionRa(0), + SumcheckId::HammingWeightClaimReduction, + ); + let r_address_stage7 = hamming_point.r[..self.one_hot_params.log_k_chunk].to_vec(); + let r_cycle_stage6 = self + .opening_accumulator + .get_committed_polynomial_opening( + CommittedPolynomial::RamInc, + SumcheckId::IncClaimReduction, + ) + .0 + .r; + + match self.proof.dory_layout { + DoryLayout::AddressMajor => Ok(OpeningPoint::::new( + [r_cycle_stage6.as_slice(), r_address_stage7.as_slice()].concat(), + )), + DoryLayout::CycleMajor => { + let native_cycle = &hamming_point.r[self.one_hot_params.log_k_chunk..]; + if r_cycle_stage6.len() < native_cycle.len() { + return Err(ProofVerifyError::DoryError( + "stage6 cycle challenges shorter than native cycle vars".to_string(), + )); + } + if r_cycle_stage6[..native_cycle.len()] != *native_cycle { + return Err(ProofVerifyError::DoryError(format!( + "cycle-major Stage-8 expects stage6 cycle prefix to equal native cycle vars \ + (cycle_full_len={}, native_len={})", + r_cycle_stage6.len(), + native_cycle.len() + ))); + } + let cycle_extra = &r_cycle_stage6[native_cycle.len()..]; + let cycle_extra_and_anchor = + [cycle_extra, r_address_stage7.as_slice(), native_cycle].concat(); + Ok(OpeningPoint::::new(cycle_extra_and_anchor)) + } + } + } + } + pub fn new( preprocessing: &'a JoltVerifierPreprocessing, proof: JoltProof, @@ -331,10 +456,11 @@ impl< .validate() .map_err(ProofVerifyError::InvalidOneHotConfig)?; - let min_ram_K = compute_min_ram_K( - &preprocessing.shared.ram, - &preprocessing.shared.memory_layout, - ); + let ram_preprocessing = crate::zkvm::ram::RAMPreprocessing { + min_bytecode_address: preprocessing.shared.program_meta.min_bytecode_address, + bytecode_words: vec![0; preprocessing.shared.program_meta.program_image_len_words], + }; + let min_ram_K = compute_min_ram_K(&ram_preprocessing, &preprocessing.shared.memory_layout); if !proof.ram_K.is_power_of_two() || proof.ram_K < min_ram_K { return Err(ProofVerifyError::InvalidRamK(proof.ram_K, min_ram_K)); } @@ -345,7 +471,7 @@ impl< .map_err(ProofVerifyError::InvalidReadWriteConfig)?; // Construct full params from the validated config. - let bytecode_K = preprocessing.shared.bytecode.code_size; + let bytecode_K = preprocessing.shared.bytecode_size(); let one_hot_params = OneHotParams::from_config(&proof.one_hot_config, bytecode_K, proof.ram_K); @@ -358,6 +484,8 @@ impl< opening_accumulator, advice_reduction_verifier_trusted: None, advice_reduction_verifier_untrusted: None, + bytecode_reduction_verifier: None, + program_image_reduction_verifier: None, spartan_key, one_hot_params, }) @@ -397,7 +525,7 @@ impl< &self.program_io, self.proof.ram_K, self.proof.trace_length, - self.preprocessing.shared.bytecode.entry_address, + self.preprocessing.shared.program_meta.entry_address, &self.proof.rw_config, &self.proof.one_hot_config, self.proof.dory_layout, @@ -420,6 +548,19 @@ impl< self.transcript .append_serializable(b"trusted_advice", trusted_advice_commitment); } + if let Some(trusted_bytecode) = self.preprocessing.shared.program.bytecode_commitments() { + for commitment in &trusted_bytecode.commitments { + self.transcript + .append_serializable(b"bytecode_chunk_commit", commitment); + } + } + if self.preprocessing.shared.program.is_committed() { + let trusted = self.preprocessing.shared.program.as_committed()?; + self.transcript.append_serializable( + b"program_image_commitment", + &trusted.program_image_commitment, + ); + } let (stage1_result, uniskip_challenge1) = self .verify_stage1() @@ -436,7 +577,7 @@ impl< let stage5_result = self .verify_stage5() .inspect_err(|e| tracing::error!("Stage 5: {e}"))?; - let stage6_result = self + let (stage6a_result, stage6b_result) = self .verify_stage6() .inspect_err(|e| tracing::error!("Stage 6: {e}"))?; let stage7_result = self @@ -455,7 +596,8 @@ impl< stage3_result.challenges.clone(), stage4_result.challenges.clone(), stage5_result.challenges.clone(), - stage6_result.challenges.clone(), + stage6a_result.challenges.clone(), + stage6b_result.challenges.clone(), stage7_result.challenges.clone(), ]; let uniskip_challenges = [uniskip_challenge1, uniskip_challenge2]; @@ -466,7 +608,8 @@ impl< stage3_result.batched_output_constraint, stage4_result.batched_output_constraint, stage5_result.batched_output_constraint, - stage6_result.batched_output_constraint, + stage6a_result.batched_output_constraint, + stage6b_result.batched_output_constraint, stage7_result.batched_output_constraint, ]; @@ -476,7 +619,8 @@ impl< stage3_result.batched_input_constraint.clone(), stage4_result.batched_input_constraint.clone(), stage5_result.batched_input_constraint.clone(), - stage6_result.batched_input_constraint.clone(), + stage6a_result.batched_input_constraint.clone(), + stage6b_result.batched_input_constraint.clone(), stage7_result.batched_input_constraint.clone(), ]; @@ -490,17 +634,19 @@ impl< stage3_result.input_constraint_challenge_values.clone(), stage4_result.input_constraint_challenge_values.clone(), stage5_result.input_constraint_challenge_values.clone(), - stage6_result.input_constraint_challenge_values.clone(), + stage6a_result.input_constraint_challenge_values.clone(), + stage6b_result.input_constraint_challenge_values.clone(), stage7_result.input_constraint_challenge_values.clone(), ]; - let output_constraint_challenge_values: [Vec; 7] = [ + let output_constraint_challenge_values: [Vec; 8] = [ stage1_result.output_constraint_challenge_values.clone(), stage2_result.output_constraint_challenge_values.clone(), stage3_result.output_constraint_challenge_values.clone(), stage4_result.output_constraint_challenge_values.clone(), stage5_result.output_constraint_challenge_values.clone(), - stage6_result.output_constraint_challenge_values.clone(), + stage6a_result.output_constraint_challenge_values.clone(), + stage6b_result.output_constraint_challenge_values.clone(), stage7_result.output_constraint_challenge_values.clone(), ]; @@ -510,7 +656,8 @@ impl< oc_blocks.extend(stage3_result.oc_block_ids); oc_blocks.extend(stage4_result.oc_block_ids); oc_blocks.extend(stage5_result.oc_block_ids); - oc_blocks.extend(stage6_result.oc_block_ids); + oc_blocks.extend(stage6a_result.oc_block_ids); + oc_blocks.extend(stage6b_result.oc_block_ids); oc_blocks.extend(stage7_result.oc_block_ids); self.verify_blindfold( @@ -829,23 +976,45 @@ impl< self.trusted_advice_commitment.is_some(), &mut self.opening_accumulator, ); + if self.preprocessing.shared.program.is_committed() { + verifier_accumulate_program_image::(self.proof.ram_K, &mut self.opening_accumulator); + } // Domain-separate the batching challenge. self.transcript.append_bytes(b"ram_val_check_gamma", &[]); let ram_val_check_gamma: F = self.transcript.challenge_scalar::(); - let initial_ram_state = crate::zkvm::ram::gen_ram_initial_memory_state::( - self.proof.ram_K, - &self.preprocessing.shared.ram, - &self.program_io, - ); + let initial_ram_state = if self.preprocessing.shared.program.is_full() { + crate::zkvm::ram::gen_ram_initial_memory_state::( + self.proof.ram_K, + &self.preprocessing.shared.program.as_full()?.ram, + &self.program_io, + ) + } else { + vec![0u64; self.proof.ram_K] + }; + let ram_preprocessing = if self.preprocessing.shared.program.is_full() { + self.preprocessing.shared.program.as_full()?.ram.clone() + } else { + crate::zkvm::ram::RAMPreprocessing { + min_bytecode_address: self.preprocessing.shared.program_meta.min_bytecode_address, + bytecode_words: vec![ + 0; + self.preprocessing + .shared + .program_meta + .program_image_len_words + ], + } + }; let ram_val_check = RamValCheckSumcheckVerifier::new( &initial_ram_state, &self.program_io, - &self.preprocessing.shared.ram, + &ram_preprocessing, self.proof.trace_length, self.proof.ram_K, &self.proof.rw_config, ram_val_check_gamma, &self.opening_accumulator, + self.preprocessing.shared.program.is_committed(), ); let instances: Vec< @@ -972,26 +1141,125 @@ impl< } #[cfg_attr(not(feature = "zk"), allow(unused_variables))] - fn verify_stage6(&mut self) -> Result, ProofVerifyError> { + fn verify_stage6( + &mut self, + ) -> Result<(StageVerifyResult, StageVerifyResult), ProofVerifyError> { + let _ = DoryGlobals::initialize_main_with_log_embedding( + self.one_hot_params.k_chunk, + self.proof.trace_length, + self.main_total_vars(), + Some(self.proof.dory_layout), + ); + let (bytecode_read_raf_params, booleanity_params, stage6a_result) = + self.verify_stage6a()?; + let stage6b_result = self.verify_stage6b(bytecode_read_raf_params, booleanity_params)?; + Ok((stage6a_result, stage6b_result)) + } + + fn verify_stage6a(&mut self) -> Result, ProofVerifyError> { let n_cycle_vars = self.proof.trace_length.log_2(); - let bytecode_read_raf = BytecodeReadRafSumcheckVerifier::gen( - &self.preprocessing.shared.bytecode, + let program_preprocessing = Some(&self.preprocessing.shared.program); + let entry_bytecode_index = self + .preprocessing + .shared + .program_meta + .entry_address + .saturating_sub(self.preprocessing.shared.program_meta.min_bytecode_address) + as usize + / BYTES_PER_INSTRUCTION + + 1; + let bytecode_read_raf = BytecodeReadRafAddressSumcheckVerifier::new( + program_preprocessing, n_cycle_vars, &self.one_hot_params, &self.opening_accumulator, &mut self.transcript, - ); - - let ram_hamming_booleanity = - HammingBooleanitySumcheckVerifier::new(&self.opening_accumulator); - let booleanity_params = BooleanitySumcheckParams::new( + if self.preprocessing.shared.program.is_committed() { + ProgramMode::Committed + } else { + ProgramMode::Full + }, + entry_bytecode_index, + )?; + let booleanity = BooleanityAddressSumcheckVerifier::new(BooleanitySumcheckParams::new( n_cycle_vars, &self.one_hot_params, &self.opening_accumulator, &mut self.transcript, - ); + )); + let instances: Vec< + &dyn SumcheckInstanceVerifier>, + > = vec![&bytecode_read_raf, &booleanity]; + let (_batching_coefficients, r_stage6a) = BatchedSumcheck::verify( + &self.proof.stage6a_sumcheck_proof, + instances.clone(), + &mut self.opening_accumulator, + &mut self.transcript, + ) + .inspect_err(|err| tracing::error!("Stage 6a: {err}"))?; + #[cfg(feature = "zk")] + { + let regular_oc_ids = self.opening_accumulator.take_pending_claim_ids(); + let batched_output_constraint = batch_output_constraints(&instances); + let batched_input_constraint = batch_input_constraints(&instances); + let max_num_rounds = instances.iter().map(|i| i.num_rounds()).max().unwrap(); + let mut output_constraint_challenge_values: Vec = _batching_coefficients.clone(); + let mut input_constraint_challenge_values: Vec = + scale_batching_coefficients(&_batching_coefficients, &instances); + for instance in &instances { + let num_rounds = instance.num_rounds(); + let offset = instance.round_offset(max_num_rounds); + let r_slice = &r_stage6a[offset..offset + num_rounds]; + output_constraint_challenge_values.extend( + instance + .get_params() + .output_constraint_challenge_values(r_slice), + ); + input_constraint_challenge_values.extend( + instance + .get_params() + .input_constraint_challenge_values(&self.opening_accumulator), + ); + } + let stage_result = StageVerifyResult::new( + r_stage6a, + batched_output_constraint, + output_constraint_challenge_values, + batched_input_constraint, + input_constraint_challenge_values, + vec![regular_oc_ids], + ); + Ok(( + bytecode_read_raf.into_params(), + booleanity.into_params(), + stage_result, + )) + } + #[cfg(not(feature = "zk"))] + Ok(( + bytecode_read_raf.into_params(), + booleanity.into_params(), + StageVerifyResult { + challenges: r_stage6a, + }, + )) + } - let booleanity = BooleanitySumcheckVerifier::new(booleanity_params); + #[cfg_attr(not(feature = "zk"), allow(unused_variables))] + fn verify_stage6b( + &mut self, + bytecode_read_raf_params: BytecodeReadRafSumcheckParams, + booleanity_params: BooleanitySumcheckParams, + ) -> Result, ProofVerifyError> { + let bytecode_reduction_seed_params = bytecode_read_raf_params.clone(); + let bytecode_read_raf = BytecodeReadRafCycleSumcheckVerifier::new( + bytecode_read_raf_params, + &self.opening_accumulator, + ); + let ram_hamming_booleanity = + HammingBooleanitySumcheckVerifier::new(&self.opening_accumulator); + let booleanity = + BooleanityCycleSumcheckVerifier::new(booleanity_params, &self.opening_accumulator); let ram_ra_virtual = RamRaVirtualSumcheckVerifier::new( self.proof.trace_length, &self.one_hot_params, @@ -1009,23 +1277,67 @@ impl< &mut self.transcript, ); - // Advice claim reduction (Phase 1 in Stage 6): trusted and untrusted are separate instances. + let main_total_vars = self.proof.trace_length.log_2() + self.one_hot_params.log_k_chunk; + let precommitted_candidates = self.preprocessing.shared.precommitted_candidate_total_vars( + self.preprocessing.shared.program.is_committed(), + self.trusted_advice_commitment.is_some(), + self.proof.untrusted_advice_commitment.is_some(), + ); + let precommitted_scheduling_reference = + PrecommittedClaimReduction::::scheduling_reference( + main_total_vars, + &precommitted_candidates, + ); + + // Advice claim reduction (Phase 1 in Stage 6b): trusted and untrusted are separate instances. if self.trusted_advice_commitment.is_some() { self.advice_reduction_verifier_trusted = Some(AdviceClaimReductionVerifier::new( AdviceKind::Trusted, - &self.program_io.memory_layout, - self.proof.trace_length, + self.program_io.memory_layout.max_trusted_advice_size as usize, + precommitted_scheduling_reference, &self.opening_accumulator, )); } if self.proof.untrusted_advice_commitment.is_some() { self.advice_reduction_verifier_untrusted = Some(AdviceClaimReductionVerifier::new( AdviceKind::Untrusted, - &self.program_io.memory_layout, - self.proof.trace_length, + self.program_io.memory_layout.max_untrusted_advice_size as usize, + precommitted_scheduling_reference, &self.opening_accumulator, )); } + if self.preprocessing.shared.program.is_committed() { + let bytecode_chunk_count = self.preprocessing.shared.bytecode_chunk_count; + let bytecode_reduction_params = BytecodeClaimReductionParams::new( + &bytecode_reduction_seed_params, + self.preprocessing.shared.bytecode_size(), + bytecode_chunk_count, + precommitted_scheduling_reference, + &self.opening_accumulator, + &mut self.transcript, + ); + self.bytecode_reduction_verifier = Some(BytecodeClaimReductionVerifier::new( + bytecode_reduction_params, + )); + + let padded_len_words = self + .preprocessing + .shared + .program_meta + .committed_program_image_num_words(&self.program_io.memory_layout); + let program_image_reduction_params = ProgramImageClaimReductionParams::new( + &self.program_io, + self.preprocessing.shared.program_meta.min_bytecode_address, + padded_len_words, + self.proof.ram_K, + precommitted_scheduling_reference, + &self.opening_accumulator, + &mut self.transcript, + ); + self.program_image_reduction_verifier = Some(ProgramImageClaimReductionVerifier::new( + program_image_reduction_params, + )); + } let mut instances: Vec< &dyn SumcheckInstanceVerifier>, @@ -1043,13 +1355,20 @@ impl< if let Some(ref advice) = self.advice_reduction_verifier_untrusted { instances.push(advice); } + if let Some(ref reduction) = self.bytecode_reduction_verifier { + instances.push(reduction); + } + if let Some(ref reduction) = self.program_image_reduction_verifier { + instances.push(reduction); + } - let (batching_coefficients, r_stage6) = BatchedSumcheck::verify( - &self.proof.stage6_sumcheck_proof, + let (batching_coefficients, r_stage6b) = BatchedSumcheck::verify( + &self.proof.stage6b_sumcheck_proof, instances.clone(), &mut self.opening_accumulator, &mut self.transcript, - )?; + ) + .inspect_err(|err| tracing::error!("Stage 6b: {err}"))?; #[cfg(feature = "zk")] { @@ -1063,7 +1382,7 @@ impl< for instance in &instances { let num_rounds = instance.num_rounds(); let offset = instance.round_offset(max_num_rounds); - let r_slice = &r_stage6[offset..offset + num_rounds]; + let r_slice = &r_stage6b[offset..offset + num_rounds]; output_constraint_challenge_values.extend( instance .get_params() @@ -1076,7 +1395,7 @@ impl< ); } Ok(StageVerifyResult::new( - r_stage6, + r_stage6b, batched_output_constraint, output_constraint_challenge_values, batched_input_constraint, @@ -1086,7 +1405,7 @@ impl< } #[cfg(not(feature = "zk"))] Ok(StageVerifyResult { - challenges: r_stage6, + challenges: r_stage6b, }) } @@ -1094,12 +1413,12 @@ impl< #[allow(clippy::too_many_arguments)] fn verify_blindfold( &mut self, - sumcheck_challenges: &[Vec; 7], + sumcheck_challenges: &[Vec; 8], uniskip_challenges: [F::Challenge; 2], - stage_output_constraints: &[Option; 7], - output_constraint_challenge_values: &[Vec; 7], - stage_input_constraints: &[InputClaimConstraint; 7], - input_constraint_challenge_values: &[Vec; 7], + stage_output_constraints: &[Option; 8], + output_constraint_challenge_values: &[Vec; 8], + stage_input_constraints: &[InputClaimConstraint; 8], + input_constraint_challenge_values: &[Vec; 8], // For stages 0-1: batched input constraint for regular rounds (different from uni-skip) stage1_batched_input: &InputClaimConstraint, stage2_batched_input: &InputClaimConstraint, @@ -1116,7 +1435,8 @@ impl< &self.proof.stage3_sumcheck_proof, &self.proof.stage4_sumcheck_proof, &self.proof.stage5_sumcheck_proof, - &self.proof.stage6_sumcheck_proof, + &self.proof.stage6a_sumcheck_proof, + &self.proof.stage6b_sumcheck_proof, &self.proof.stage7_sumcheck_proof, ]; @@ -1133,7 +1453,7 @@ impl< let mut stage_configs = Vec::new(); // Track which stage_config index corresponds to uni-skip and regular first rounds let mut uniskip_indices: Vec = Vec::new(); // Only 2 elements for stages 0-1 - let mut regular_first_round_indices: Vec = Vec::new(); // 7 elements for all stages + let mut regular_first_round_indices: Vec = Vec::new(); // 8 elements for all stages let mut last_round_indices: Vec = Vec::new(); for (stage_idx, proof) in stage_proofs.iter().enumerate() { @@ -1215,7 +1535,7 @@ impl< Some(ClaimBindingConfig::with_constraint(constraint.clone())); } - // Add initial_input configurations for regular first rounds (all 7 stages) + // Add initial_input configurations for regular first rounds (all 8 stages) // These use the batched input constraints from the stage results let regular_constraints = [ stage1_batched_input.clone(), // Stage 0 regular @@ -1223,8 +1543,9 @@ impl< stage_input_constraints[2].clone(), // Stage 2 stage_input_constraints[3].clone(), // Stage 3 stage_input_constraints[4].clone(), // Stage 4 - stage_input_constraints[5].clone(), // Stage 5 - stage_input_constraints[6].clone(), // Stage 6 + stage_input_constraints[5].clone(), // Stage 5 (6a) + stage_input_constraints[6].clone(), // Stage 6 (6b) + stage_input_constraints[7].clone(), // Stage 7 ]; for (i, constraint) in regular_constraints.iter().enumerate() { let idx = regular_first_round_indices[i]; @@ -1252,7 +1573,7 @@ impl< } } - let all_input_challenge_values: [&[F]; 9] = [ + let all_input_challenge_values: [&[F]; 10] = [ &input_constraint_challenge_values[0], stage1_batched_input_values, &input_constraint_challenge_values[1], @@ -1262,6 +1583,7 @@ impl< &input_constraint_challenge_values[4], &input_constraint_challenge_values[5], &input_constraint_challenge_values[6], + &input_constraint_challenge_values[7], ]; let mut baked_input_challenges: Vec = Vec::new(); for expected_values in all_input_challenge_values.iter() { @@ -1365,7 +1687,7 @@ impl< let mut params = advice_reduction_verifier_trusted.params.borrow_mut(); if params.num_address_phase_rounds() > 0 { // Transition phase - params.phase = ReductionPhase::AddressVariables; + params.transition_to_address_phase(); instances.push(advice_reduction_verifier_trusted); } } @@ -1375,10 +1697,26 @@ impl< let mut params = advice_reduction_verifier_untrusted.params.borrow_mut(); if params.num_address_phase_rounds() > 0 { // Transition phase - params.phase = ReductionPhase::AddressVariables; + params.transition_to_address_phase(); instances.push(advice_reduction_verifier_untrusted); } } + if let Some(bytecode_reduction_verifier) = self.bytecode_reduction_verifier.as_mut() { + let mut params = bytecode_reduction_verifier.params.borrow_mut(); + if params.num_address_phase_rounds() > 0 { + params.transition_to_address_phase(); + instances.push(bytecode_reduction_verifier); + } + } + if let Some(program_image_reduction_verifier) = + self.program_image_reduction_verifier.as_mut() + { + let mut params = program_image_reduction_verifier.params.borrow_mut(); + if params.num_address_phase_rounds() > 0 { + params.transition_to_address_phase(); + instances.push(program_image_reduction_verifier); + } + } let (batching_coefficients, r_stage7) = BatchedSumcheck::verify( &self.proof.stage7_sumcheck_proof, @@ -1426,72 +1764,61 @@ impl< }) } - /// Stage 8: Dory batch opening verification. fn verify_stage8(&mut self) -> Result, ProofVerifyError> { - // Initialize DoryGlobals with the layout from the proof - // This ensures the verifier uses the same layout as the prover - let _guard = DoryGlobals::initialize_context( - 1 << self.one_hot_params.log_k_chunk, - self.proof.trace_length.next_power_of_two(), - DoryContext::Main, - Some(self.proof.dory_layout), - ); - - // Get the unified opening point from HammingWeightClaimReduction - // This contains (r_address_stage7 || r_cycle_stage6) in big-endian - let (opening_point, _) = self.opening_accumulator.get_committed_polynomial_opening( - CommittedPolynomial::InstructionRa(0), - SumcheckId::HammingWeightClaimReduction, - ); - let log_k_chunk = self.one_hot_params.log_k_chunk; - let r_address_stage7 = &opening_point.r[..log_k_chunk]; + let opening_point = self.stage8_opening_point()?; // 1. Collect all (polynomial, claim) pairs let mut polynomial_claims = Vec::new(); let mut scaling_factors = Vec::new(); // Dense polynomials: RamInc and RdInc (from IncClaimReduction in Stage 6) - let (_, ram_inc_claim) = self.opening_accumulator.get_committed_polynomial_opening( + let (ram_inc_point, ram_inc_claim) = + self.opening_accumulator.get_committed_polynomial_opening( + CommittedPolynomial::RamInc, + SumcheckId::IncClaimReduction, + ); + let (rd_inc_point, rd_inc_claim) = + self.opening_accumulator.get_committed_polynomial_opening( + CommittedPolynomial::RdInc, + SumcheckId::IncClaimReduction, + ); + let ram_inc_lagrange = compute_lagrange_factor::(&opening_point.r, &ram_inc_point.r); + let rd_inc_lagrange = compute_lagrange_factor::(&opening_point.r, &rd_inc_point.r); + polynomial_claims.push(( CommittedPolynomial::RamInc, - SumcheckId::IncClaimReduction, - ); - let (_, rd_inc_claim) = self.opening_accumulator.get_committed_polynomial_opening( - CommittedPolynomial::RdInc, - SumcheckId::IncClaimReduction, - ); - - // Dense polynomials are zero-padded in the Dory matrix, so their evaluation - // includes a factor eq(r_addr, 0) = ∏(1 − r_addr_i). - let lagrange_factor: F = EqPolynomial::zero_selector(r_address_stage7); - polynomial_claims.push((CommittedPolynomial::RamInc, ram_inc_claim * lagrange_factor)); - scaling_factors.push(lagrange_factor); - polynomial_claims.push((CommittedPolynomial::RdInc, rd_inc_claim * lagrange_factor)); - scaling_factors.push(lagrange_factor); + ram_inc_claim * ram_inc_lagrange, + )); + scaling_factors.push(ram_inc_lagrange); + polynomial_claims.push((CommittedPolynomial::RdInc, rd_inc_claim * rd_inc_lagrange)); + scaling_factors.push(rd_inc_lagrange); // Sparse polynomials: all RA polys (from HammingWeightClaimReduction) for i in 0..self.one_hot_params.instruction_d { - let (_, claim) = self.opening_accumulator.get_committed_polynomial_opening( + let (ra_point, claim) = self.opening_accumulator.get_committed_polynomial_opening( CommittedPolynomial::InstructionRa(i), SumcheckId::HammingWeightClaimReduction, ); - polynomial_claims.push((CommittedPolynomial::InstructionRa(i), claim)); - scaling_factors.push(F::one()); + let lagrange = compute_lagrange_factor::(&opening_point.r, &ra_point.r); + polynomial_claims.push((CommittedPolynomial::InstructionRa(i), claim * lagrange)); + scaling_factors.push(lagrange); } for i in 0..self.one_hot_params.bytecode_d { - let (_, claim) = self.opening_accumulator.get_committed_polynomial_opening( + let (ra_point, claim) = self.opening_accumulator.get_committed_polynomial_opening( CommittedPolynomial::BytecodeRa(i), SumcheckId::HammingWeightClaimReduction, ); - polynomial_claims.push((CommittedPolynomial::BytecodeRa(i), claim)); - scaling_factors.push(F::one()); + let lagrange = compute_lagrange_factor::(&opening_point.r, &ra_point.r); + polynomial_claims.push((CommittedPolynomial::BytecodeRa(i), claim * lagrange)); + scaling_factors.push(lagrange); } for i in 0..self.one_hot_params.ram_d { - let (_, claim) = self.opening_accumulator.get_committed_polynomial_opening( + let (ra_point, claim) = self.opening_accumulator.get_committed_polynomial_opening( CommittedPolynomial::RamRa(i), SumcheckId::HammingWeightClaimReduction, ); - polynomial_claims.push((CommittedPolynomial::RamRa(i), claim)); - scaling_factors.push(F::one()); + let lagrange = compute_lagrange_factor::(&opening_point.r, &ra_point.r); + polynomial_claims.push((CommittedPolynomial::RamRa(i), claim * lagrange)); + scaling_factors.push(lagrange); } // Advice polynomials: TrustedAdvice and UntrustedAdvice (from AdviceClaimReduction in Stage 6) @@ -1504,8 +1831,7 @@ impl< .opening_accumulator .get_advice_opening(AdviceKind::Trusted, SumcheckId::AdviceClaimReduction) { - let lagrange_factor = - compute_advice_lagrange_factor::(&opening_point.r, &advice_point.r); + let lagrange_factor = compute_lagrange_factor::(&opening_point.r, &advice_point.r); polynomial_claims.push(( CommittedPolynomial::TrustedAdvice, advice_claim * lagrange_factor, @@ -1518,8 +1844,7 @@ impl< .opening_accumulator .get_advice_opening(AdviceKind::Untrusted, SumcheckId::AdviceClaimReduction) { - let lagrange_factor = - compute_advice_lagrange_factor::(&opening_point.r, &advice_point.r); + let lagrange_factor = compute_lagrange_factor::(&opening_point.r, &advice_point.r); polynomial_claims.push(( CommittedPolynomial::UntrustedAdvice, advice_claim * lagrange_factor, @@ -1528,6 +1853,37 @@ impl< include_untrusted_advice = true; } + if self.preprocessing.shared.program.is_committed() { + let chunk_count = self.preprocessing.shared.bytecode_chunk_count; + for chunk_idx in 0..chunk_count { + let (chunk_point, chunk_claim) = + self.opening_accumulator.get_committed_polynomial_opening( + CommittedPolynomial::BytecodeChunk(chunk_idx), + SumcheckId::BytecodeClaimReduction, + ); + let lagrange_factor = + compute_lagrange_factor::(&opening_point.r, &chunk_point.r); + polynomial_claims.push(( + CommittedPolynomial::BytecodeChunk(chunk_idx), + chunk_claim * lagrange_factor, + )); + scaling_factors.push(lagrange_factor); + } + } + if self.preprocessing.shared.program.is_committed() { + let (program_point, program_claim) = + self.opening_accumulator.get_committed_polynomial_opening( + CommittedPolynomial::ProgramImageInit, + SumcheckId::ProgramImageClaimReduction, + ); + let lagrange_factor = compute_lagrange_factor::(&opening_point.r, &program_point.r); + polynomial_claims.push(( + CommittedPolynomial::ProgramImageInit, + program_claim * lagrange_factor, + )); + scaling_factors.push(lagrange_factor); + } + // 2. Sample gamma and compute powers for RLC let claims: Vec = polynomial_claims.iter().map(|(_, c)| *c).collect(); // In non-ZK mode, absorb claims before sampling gamma for Fiat-Shamir binding. @@ -1547,6 +1903,12 @@ impl< &self.one_hot_params, include_trusted_advice, include_untrusted_advice, + if self.preprocessing.shared.program.is_committed() { + ProgramMode::Committed + } else { + ProgramMode::Full + }, + self.preprocessing.shared.bytecode_chunk_count, ); let joint_claim: F = gamma_powers .iter() @@ -1596,6 +1958,32 @@ impl< commitments_map.insert(CommittedPolynomial::UntrustedAdvice, commitment.clone()); } } + if let Some(trusted_bytecode) = self.preprocessing.shared.program.bytecode_commitments() { + for (chunk_idx, commitment) in trusted_bytecode.commitments.iter().enumerate() { + if state + .polynomial_claims + .iter() + .any(|(p, _)| *p == CommittedPolynomial::BytecodeChunk(chunk_idx)) + { + commitments_map.insert( + CommittedPolynomial::BytecodeChunk(chunk_idx), + commitment.clone(), + ); + } + } + } + if let Ok(trusted_program) = self.preprocessing.shared.program.as_committed() { + if state + .polynomial_claims + .iter() + .any(|(p, _)| *p == CommittedPolynomial::ProgramImageInit) + { + commitments_map.insert( + CommittedPolynomial::ProgramImageInit, + trusted_program.program_image_commitment.clone(), + ); + } + } let joint_commitment = self.compute_joint_commitment(&mut commitments_map, &state)?; @@ -1671,14 +2059,20 @@ impl< } #[derive(Debug, Clone)] -pub struct JoltSharedPreprocessing { - pub bytecode: Arc, - pub ram: RAMPreprocessing, +pub struct JoltSharedPreprocessing< + PCS: CommitmentScheme = crate::poly::commitment::dory::DoryCommitmentScheme, +> { + pub program: ProgramPreprocessing, + pub program_meta: ProgramMetadata, pub memory_layout: MemoryLayout, pub max_padded_trace_length: usize, + pub bytecode_chunk_count: usize, } -impl JoltSharedPreprocessing { +impl JoltSharedPreprocessing +where + PCS::Commitment: CanonicalSerialize, +{ /// Blake2b-256 digest of the serialized preprocessing, used to bind /// the program identity to the Fiat-Shamir transcript. pub fn digest(&self) -> [u8; 32] { @@ -1691,78 +2085,202 @@ impl JoltSharedPreprocessing { } } -impl CanonicalSerialize for JoltSharedPreprocessing { +impl CanonicalSerialize for JoltSharedPreprocessing +where + PCS::Commitment: CanonicalSerialize, +{ fn serialize_with_mode( &self, mut writer: W, compress: ark_serialize::Compress, ) -> Result<(), ark_serialize::SerializationError> { - self.bytecode - .as_ref() + self.program.serialize_with_mode(&mut writer, compress)?; + self.program_meta .serialize_with_mode(&mut writer, compress)?; - self.ram.serialize_with_mode(&mut writer, compress)?; self.memory_layout .serialize_with_mode(&mut writer, compress)?; self.max_padded_trace_length .serialize_with_mode(&mut writer, compress)?; + self.bytecode_chunk_count + .serialize_with_mode(&mut writer, compress)?; Ok(()) } fn serialized_size(&self, compress: ark_serialize::Compress) -> usize { - self.bytecode.serialized_size(compress) - + self.ram.serialized_size(compress) + self.program.serialized_size(compress) + + self.program_meta.serialized_size(compress) + self.memory_layout.serialized_size(compress) + self.max_padded_trace_length.serialized_size(compress) + + self.bytecode_chunk_count.serialized_size(compress) } } -impl CanonicalDeserialize for JoltSharedPreprocessing { +impl CanonicalDeserialize for JoltSharedPreprocessing +where + PCS::Commitment: CanonicalDeserialize, +{ fn deserialize_with_mode( mut reader: R, compress: ark_serialize::Compress, validate: ark_serialize::Validate, ) -> Result { - let bytecode = - BytecodePreprocessing::deserialize_with_mode(&mut reader, compress, validate)?; - let ram = RAMPreprocessing::deserialize_with_mode(&mut reader, compress, validate)?; + let program = ProgramPreprocessing::deserialize_with_mode(&mut reader, compress, validate)?; + let program_meta = ProgramMetadata::deserialize_with_mode(&mut reader, compress, validate)?; let memory_layout = MemoryLayout::deserialize_with_mode(&mut reader, compress, validate)?; let max_padded_trace_length = usize::deserialize_with_mode(&mut reader, compress, validate)?; + let bytecode_chunk_count = usize::deserialize_with_mode(&mut reader, compress, validate)?; Ok(Self { - bytecode: Arc::new(bytecode), - ram, + program, + program_meta, memory_layout, max_padded_trace_length, + bytecode_chunk_count, }) } } -impl ark_serialize::Valid for JoltSharedPreprocessing { +impl ark_serialize::Valid for JoltSharedPreprocessing +where + PCS::Commitment: ark_serialize::Valid, +{ fn check(&self) -> Result<(), ark_serialize::SerializationError> { - self.bytecode.check()?; - self.ram.check()?; + self.program.check()?; + self.program_meta.check()?; self.memory_layout.check()?; Ok(()) } } -impl JoltSharedPreprocessing { +impl JoltSharedPreprocessing { #[tracing::instrument(skip_all, name = "JoltSharedPreprocessing::new")] pub fn new( - bytecode: Vec, + program: ProgramPreprocessing, memory_layout: MemoryLayout, - memory_init: Vec<(u64, u8)>, max_padded_trace_length: usize, - entry_address: u64, - ) -> Result { - let bytecode = Arc::new(BytecodePreprocessing::preprocess(bytecode, entry_address)?); - let ram = RAMPreprocessing::preprocess(memory_init); - Ok(Self { - bytecode, - ram, + ) -> JoltSharedPreprocessing { + Self { + program_meta: program.meta(), + program, memory_layout, max_padded_trace_length, - }) + bytecode_chunk_count: DEFAULT_COMMITTED_BYTECODE_CHUNK_COUNT, + } + } + + #[tracing::instrument(skip_all, name = "JoltSharedPreprocessing::new_committed")] + pub fn new_committed( + program: ProgramPreprocessing, + memory_layout: MemoryLayout, + max_padded_trace_length: usize, + bytecode_chunk_count: usize, + ) -> JoltSharedPreprocessing { + validate_committed_bytecode_chunk_count(bytecode_chunk_count); + assert!( + program.bytecode_len().is_multiple_of(bytecode_chunk_count), + "bytecode chunk count ({bytecode_chunk_count}) must divide bytecode size ({})", + program.bytecode_len() + ); + let mut shared = Self { + program_meta: program.meta(), + program, + memory_layout, + max_padded_trace_length, + bytecode_chunk_count, + }; + let (max_total_vars, max_log_k_chunk) = shared.compute_max_total_vars(true); + let generators = PCS::setup_prover(max_total_vars); + shared.program = shared.program.commit( + &shared.memory_layout, + &generators, + shared.bytecode_chunk_count, + max_log_k_chunk, + ); + shared.program_meta = shared.program.meta(); + shared + } + + pub fn is_committed_mode(&self) -> bool { + self.program.is_committed() + } + + pub fn bytecode_size(&self) -> usize { + self.program_meta.bytecode_len + } + + #[inline] + pub fn committed_program_image_num_words(&self) -> usize { + self.program_meta + .committed_program_image_num_words(&self.memory_layout) + } + + #[inline] + pub(crate) fn precommitted_candidate_total_vars( + &self, + include_committed: bool, + include_trusted_advice: bool, + include_untrusted_advice: bool, + ) -> Vec { + let mut candidates = Vec::with_capacity( + include_committed as usize * 2 + + include_trusted_advice as usize + + include_untrusted_advice as usize, + ); + + if include_trusted_advice { + let (trusted_sigma, trusted_nu) = DoryGlobals::advice_sigma_nu_from_max_bytes( + self.memory_layout.max_trusted_advice_size as usize, + ); + candidates.push(trusted_sigma + trusted_nu); + } + + if include_untrusted_advice { + let (untrusted_sigma, untrusted_nu) = DoryGlobals::advice_sigma_nu_from_max_bytes( + self.memory_layout.max_untrusted_advice_size as usize, + ); + candidates.push(untrusted_sigma + untrusted_nu); + } + + if include_committed { + let chunk_cycle_log_t = (self.bytecode_size() / self.bytecode_chunk_count) + .next_power_of_two() + .log_2(); + candidates.push(committed_lanes().log_2() + chunk_cycle_log_t); + candidates.push(self.committed_program_image_num_words().log_2()); + } + + candidates + } + + #[inline] + pub(crate) fn max_total_vars_from_candidates( + main_total_vars: usize, + candidates: impl IntoIterator, + ) -> usize { + let mut max_total_vars = main_total_vars; + for total_vars in candidates { + max_total_vars = max_total_vars.max(total_vars); + } + max_total_vars + } + + #[inline] + pub(crate) fn compute_max_total_vars(&self, include_committed: bool) -> (usize, usize) { + use common::constants::ONEHOT_CHUNK_THRESHOLD_LOG_T; + let max_t_any = self.max_padded_trace_length.next_power_of_two(); + let max_log_t = max_t_any.log_2(); + let max_log_k_chunk = if max_log_t < ONEHOT_CHUNK_THRESHOLD_LOG_T { + 4 + } else { + 8 + }; + + let max_total_vars = Self::max_total_vars_from_candidates( + max_log_k_chunk + max_log_t, + self.precommitted_candidate_total_vars(include_committed, true, true), + ); + + (max_total_vars, max_log_k_chunk) } } @@ -1783,23 +2301,99 @@ impl From> for PedersenGenerators { } } -#[derive(Debug, Clone, CanonicalSerialize, CanonicalDeserialize)] +#[derive(Debug, Clone)] pub struct JoltVerifierPreprocessing where F: JoltField, C: JoltCurve, PCS: CommitmentScheme, { + _curve: std::marker::PhantomData, pub generators: PCS::VerifierSetup, - pub shared: JoltSharedPreprocessing, + pub shared: JoltSharedPreprocessing, pub blindfold_setup: Option>, } +impl CanonicalSerialize for JoltVerifierPreprocessing +where + F: JoltField, + C: JoltCurve, + PCS: CommitmentScheme, + PCS::VerifierSetup: CanonicalSerialize, + PCS::Commitment: CanonicalSerialize, +{ + fn serialize_with_mode( + &self, + mut writer: W, + compress: ark_serialize::Compress, + ) -> Result<(), ark_serialize::SerializationError> { + self.generators.serialize_with_mode(&mut writer, compress)?; + self.shared.serialize_with_mode(&mut writer, compress)?; + self.blindfold_setup + .serialize_with_mode(&mut writer, compress)?; + Ok(()) + } + + fn serialized_size(&self, compress: ark_serialize::Compress) -> usize { + self.generators.serialized_size(compress) + + self.shared.serialized_size(compress) + + self.blindfold_setup.serialized_size(compress) + } +} + +impl CanonicalDeserialize for JoltVerifierPreprocessing +where + F: JoltField, + C: JoltCurve, + PCS: CommitmentScheme, + PCS::VerifierSetup: CanonicalDeserialize, + PCS::Commitment: CanonicalDeserialize, +{ + fn deserialize_with_mode( + mut reader: R, + compress: ark_serialize::Compress, + validate: ark_serialize::Validate, + ) -> Result { + Ok(Self { + _curve: std::marker::PhantomData, + generators: PCS::VerifierSetup::deserialize_with_mode(&mut reader, compress, validate)?, + shared: JoltSharedPreprocessing::deserialize_with_mode( + &mut reader, + compress, + validate, + )?, + blindfold_setup: Option::>::deserialize_with_mode( + &mut reader, + compress, + validate, + )?, + }) + } +} + +impl ark_serialize::Valid for JoltVerifierPreprocessing +where + F: JoltField, + C: JoltCurve, + PCS: CommitmentScheme, + PCS::VerifierSetup: ark_serialize::Valid, + PCS::Commitment: ark_serialize::Valid, +{ + fn check(&self) -> Result<(), ark_serialize::SerializationError> { + self.generators.check()?; + self.shared.check()?; + self.blindfold_setup.check()?; + Ok(()) + } +} + impl Serializable for JoltVerifierPreprocessing where F: JoltField, C: JoltCurve, PCS: CommitmentScheme, + PCS::VerifierSetup: CanonicalSerialize + CanonicalDeserialize, + PCS::Commitment: CanonicalSerialize + CanonicalDeserialize, { } @@ -1808,6 +2402,8 @@ where F: JoltField, C: JoltCurve, PCS: CommitmentScheme, + PCS::VerifierSetup: CanonicalSerialize + CanonicalDeserialize, + PCS::Commitment: CanonicalSerialize + CanonicalDeserialize, { pub fn save_to_target_dir(&self, target_dir: &str) -> std::io::Result<()> { let filename = Path::new(target_dir).join("jolt_verifier_preprocessing.dat"); @@ -1832,11 +2428,13 @@ impl, PCS: CommitmentScheme> { #[tracing::instrument(skip_all, name = "JoltVerifierPreprocessing::new")] pub fn new( - shared: JoltSharedPreprocessing, + mut shared: JoltSharedPreprocessing, generators: PCS::VerifierSetup, blindfold_setup: Option>, ) -> Self { + shared.program = shared.program.to_verifier_program(); Self { + _curve: std::marker::PhantomData, generators, shared, blindfold_setup, diff --git a/jolt-core/src/zkvm/witness.rs b/jolt-core/src/zkvm/witness.rs index 3746454b1..e1de122c0 100644 --- a/jolt-core/src/zkvm/witness.rs +++ b/jolt-core/src/zkvm/witness.rs @@ -6,11 +6,11 @@ use common::jolt_device::MemoryLayout; use rayon::prelude::*; use tracer::instruction::Cycle; +use crate::curve::JoltCurve; use crate::poly::commitment::commitment_scheme::StreamingCommitmentScheme; -use crate::zkvm::bytecode::BytecodePreprocessing; -use crate::zkvm::config::OneHotParams; +use crate::zkvm::config::{OneHotParams, ProgramMode}; use crate::zkvm::instruction::InstructionFlags; -use crate::zkvm::verifier::JoltSharedPreprocessing; +use crate::zkvm::prover::JoltProverPreprocessing; use crate::{ field::JoltField, poly::{multilinear_polynomial::MultilinearPolynomial, one_hot_polynomial::OneHotPolynomial}, @@ -31,6 +31,8 @@ pub enum CommittedPolynomial { InstructionRa(usize), /// One-hot ra polynomial for the bytecode instance of Shout BytecodeRa(usize), + /// Dense committed bytecode chunk polynomial for committed program mode. + BytecodeChunk(usize), /// One-hot ra/wa polynomial for the RAM instance of Twist /// Note that for RAM, ra and wa are the same polynomial because /// there is at most one load or store per cycle. @@ -41,6 +43,8 @@ pub enum CommittedPolynomial { /// Untrusted advice polynomial - committed during proving, commitment in proof. /// Length cannot exceed max_trace_length. UntrustedAdvice, + /// Program image (initial RAM image) polynomial for committed program mode. + ProgramImageInit, } /// Returns a list of symbols representing all committed polynomials. @@ -58,17 +62,33 @@ pub fn all_committed_polynomials(one_hot_params: &OneHotParams) -> Vec Vec { + let mut polynomials = all_committed_polynomials(one_hot_params); + if program_mode == ProgramMode::Committed { + for i in 0..bytecode_chunk_count { + polynomials.push(CommittedPolynomial::BytecodeChunk(i)); + } + polynomials.push(CommittedPolynomial::ProgramImageInit); + } + polynomials +} + impl CommittedPolynomial { /// Generate witness data and compute tier 1 commitment for a single row - pub fn stream_witness_and_commit_rows( + pub fn stream_witness_and_commit_rows( &self, - setup: &PCS::ProverSetup, - preprocessing: &JoltSharedPreprocessing, + preprocessing: &JoltProverPreprocessing, row_cycles: &[tracer::instruction::Cycle], one_hot_params: &OneHotParams, ) -> ::ChunkState where F: JoltField, + C: JoltCurve, PCS: StreamingCommitmentScheme, { match self { @@ -80,7 +100,7 @@ impl CommittedPolynomial { post_value as i128 - pre_value as i128 }) .collect(); - PCS::process_chunk(setup, &row) + PCS::process_chunk(&preprocessing.generators, &row) } CommittedPolynomial::RamInc => { let row: Vec = row_cycles @@ -92,7 +112,7 @@ impl CommittedPolynomial { _ => 0, }) .collect(); - PCS::process_chunk(setup, &row) + PCS::process_chunk(&preprocessing.generators, &row) } CommittedPolynomial::InstructionRa(idx) => { let row: Vec> = row_cycles @@ -102,17 +122,17 @@ impl CommittedPolynomial { Some(one_hot_params.lookup_index_chunk(lookup_index, *idx) as usize) }) .collect(); - PCS::process_chunk_onehot(setup, one_hot_params.k_chunk, &row) + PCS::process_chunk_onehot(&preprocessing.generators, one_hot_params.k_chunk, &row) } CommittedPolynomial::BytecodeRa(idx) => { let row: Vec> = row_cycles .iter() .map(|cycle| { - let pc = preprocessing.bytecode.get_pc(cycle); + let pc = preprocessing.materialized_program().get_pc(cycle); Some(one_hot_params.bytecode_pc_chunk(pc, *idx) as usize) }) .collect(); - PCS::process_chunk_onehot(setup, one_hot_params.k_chunk, &row) + PCS::process_chunk_onehot(&preprocessing.generators, one_hot_params.k_chunk, &row) } CommittedPolynomial::RamRa(idx) => { let row: Vec> = row_cycles @@ -120,15 +140,18 @@ impl CommittedPolynomial { .map(|cycle| { remap_address( cycle.ram_access().address() as u64, - &preprocessing.memory_layout, + &preprocessing.shared.memory_layout, ) .map(|address| one_hot_params.ram_address_chunk(address, *idx) as usize) }) .collect(); - PCS::process_chunk_onehot(setup, one_hot_params.k_chunk, &row) + PCS::process_chunk_onehot(&preprocessing.generators, one_hot_params.k_chunk, &row) } - CommittedPolynomial::TrustedAdvice | CommittedPolynomial::UntrustedAdvice => { - panic!("Advice polynomials should not use streaming witness generation") + CommittedPolynomial::TrustedAdvice + | CommittedPolynomial::UntrustedAdvice + | CommittedPolynomial::ProgramImageInit + | CommittedPolynomial::BytecodeChunk(_) => { + panic!("Precommitted polynomials should not use streaming witness generation") } } } @@ -136,7 +159,7 @@ impl CommittedPolynomial { #[tracing::instrument(skip_all, name = "CommittedPolynomial::generate_witness")] pub fn generate_witness( &self, - bytecode_preprocessing: &BytecodePreprocessing, + bytecode_preprocessing: &crate::zkvm::bytecode::BytecodePreprocessing, memory_layout: &MemoryLayout, trace: &[Cycle], one_hot_params: Option<&OneHotParams>, @@ -212,8 +235,11 @@ impl CommittedPolynomial { one_hot_params.k_chunk, )) } - CommittedPolynomial::TrustedAdvice | CommittedPolynomial::UntrustedAdvice => { - panic!("Advice polynomials should not use generate_witness") + CommittedPolynomial::TrustedAdvice + | CommittedPolynomial::UntrustedAdvice + | CommittedPolynomial::ProgramImageInit + | CommittedPolynomial::BytecodeChunk(_) => { + panic!("Precommitted polynomials should not use generate_witness") } } } @@ -269,4 +295,10 @@ pub enum VirtualPolynomial { OpFlags(CircuitFlags), InstructionFlags(InstructionFlags), LookupTableFlag(usize), + BytecodeValStage(usize), + BytecodeReadRafAddrClaim, + BooleanityAddrClaim, + BytecodeClaimReductionIntermediate, + ProgramImageInitContributionRw, + ProgramImageInitContributionRaf, } diff --git a/jolt-sdk/macros/src/lib.rs b/jolt-sdk/macros/src/lib.rs index 4d5281604..1f3383781 100644 --- a/jolt-sdk/macros/src/lib.rs +++ b/jolt-sdk/macros/src/lib.rs @@ -77,8 +77,11 @@ impl MacroBuilder { let trace_to_file_fn = self.make_trace_to_file_func(); let compile_fn = self.make_compile_func(); let preprocess_shared_fn = self.make_preprocess_shared_func(); + let preprocess_shared_committed_fn = self.make_preprocess_shared_committed_func(); let preprocess_prover_fn = self.make_preprocess_prover_func(); + let preprocess_committed_prover_fn = self.make_preprocess_committed_prover_func(); let preprocess_verifier_fn = self.make_preprocess_verifier_func(); + let preprocess_committed_verifier_fn = self.make_preprocess_committed_verifier_func(); let verifier_preprocess_from_prover_fn = self.make_preprocess_from_prover_func(); let commit_trusted_advice_fn = self.make_commit_trusted_advice_func(); let prove_fn = self.make_prove_func(); @@ -111,8 +114,11 @@ impl MacroBuilder { #trace_to_file_fn #compile_fn #preprocess_shared_fn + #preprocess_shared_committed_fn #preprocess_prover_fn + #preprocess_committed_prover_fn #preprocess_verifier_fn + #preprocess_committed_verifier_fn #verifier_preprocess_from_prover_fn #commit_trusted_advice_fn #prove_fn @@ -469,7 +475,7 @@ impl MacroBuilder { { #imports - let (bytecode, memory_init, program_size, e_entry) = program.decode(); + let (bytecode, memory_init, program_size, _e_entry) = program.decode(); let memory_config = MemoryConfig { max_input_size: #max_input_size, max_output_size: #max_output_size, @@ -481,15 +487,64 @@ impl MacroBuilder { }; let memory_layout = MemoryLayout::new(&memory_config); - let preprocessing = JoltSharedPreprocessing::new( - bytecode, + let program_data = jolt::ProgramPreprocessing::preprocess(bytecode, memory_init)?; + Ok(JoltSharedPreprocessing::new( + program_data, memory_layout, - memory_init, #max_trace_length, - e_entry, - )?; + )) + } + } + } + + fn make_preprocess_shared_committed_func(&self) -> TokenStream2 { + let imports = self.make_imports(); + let attributes = parse_attributes(&self.attr); + let max_input_size = proc_macro2::Literal::u64_unsuffixed(attributes.max_input_size); + let max_output_size = proc_macro2::Literal::u64_unsuffixed(attributes.max_output_size); + let max_untrusted_advice_size = + proc_macro2::Literal::u64_unsuffixed(attributes.max_untrusted_advice_size); + let max_trusted_advice_size = + proc_macro2::Literal::u64_unsuffixed(attributes.max_trusted_advice_size); + let stack_size = proc_macro2::Literal::u64_unsuffixed(attributes.stack_size); + let heap_size = proc_macro2::Literal::u64_unsuffixed(attributes.heap_size); + let max_trace_length = proc_macro2::Literal::u64_unsuffixed(attributes.max_trace_length); + + let fn_name = self.get_func_name(); + let preprocess_shared_committed_fn_name = Ident::new( + &format!("preprocess_shared_committed_{fn_name}"), + fn_name.span(), + ); + quote! { + #[cfg(all(not(target_arch = "wasm32"), not(feature = "guest")))] + pub fn #preprocess_shared_committed_fn_name( + program: &mut jolt::host::Program, + bytecode_chunk_count: usize, + ) -> Result + { + #imports + + let (bytecode, memory_init, program_size, _e_entry) = program.decode(); + let memory_config = MemoryConfig { + max_input_size: #max_input_size, + max_output_size: #max_output_size, + max_untrusted_advice_size: #max_untrusted_advice_size, + max_trusted_advice_size: #max_trusted_advice_size, + stack_size: #stack_size, + heap_size: #heap_size, + program_size: Some(program_size), + }; + let memory_layout = MemoryLayout::new(&memory_config); - Ok(preprocessing) + let program_data = jolt::ProgramPreprocessing::preprocess(bytecode, memory_init)?; + let shared_preprocessing = JoltSharedPreprocessing::new_committed( + program_data, + memory_layout, + #max_trace_length, + bytecode_chunk_count, + ); + + Ok(shared_preprocessing) } } } @@ -502,19 +557,48 @@ impl MacroBuilder { Ident::new(&format!("preprocess_prover_{fn_name}"), fn_name.span()); quote! { #[cfg(all(not(target_arch = "wasm32"), not(feature = "guest")))] - pub fn #preprocess_prover_fn_name(shared_preprocessing: jolt::JoltSharedPreprocessing) + pub fn #preprocess_prover_fn_name( + shared_preprocessing: jolt::JoltSharedPreprocessing + ) -> jolt::JoltProverPreprocessing { #imports - let prover_preprocessing = JoltProverPreprocessing::new( - shared_preprocessing, - ); + let prover_preprocessing = JoltProverPreprocessing::new(shared_preprocessing); prover_preprocessing } } } + fn make_preprocess_committed_prover_func(&self) -> TokenStream2 { + let imports = self.make_imports(); + + let fn_name = self.get_func_name(); + let preprocess_committed_fn_name = + Ident::new(&format!("preprocess_committed_{fn_name}"), fn_name.span()); + let preprocess_shared_committed_fn_name = Ident::new( + &format!("preprocess_shared_committed_{fn_name}"), + fn_name.span(), + ); + quote! { + #[cfg(all(not(target_arch = "wasm32"), not(feature = "guest")))] + pub fn #preprocess_committed_fn_name( + program: &mut jolt::host::Program, + bytecode_chunk_count: usize, + ) + -> Result< + jolt::JoltProverPreprocessing, + jolt::PreprocessingError, + > + { + #imports + let shared_preprocessing = + #preprocess_shared_committed_fn_name(program, bytecode_chunk_count)?; + Ok(JoltProverPreprocessing::new(shared_preprocessing)) + } + } + } + fn make_preprocess_verifier_func(&self) -> TokenStream2 { let fn_name = self.get_func_name(); let preprocess_verifier_fn_name = @@ -527,6 +611,30 @@ impl MacroBuilder { generators: ::VerifierSetup, blindfold_setup: Option>, ) -> jolt::JoltVerifierPreprocessing + { + jolt::JoltVerifierPreprocessing::new( + shared_preprocess, + generators, + blindfold_setup, + ) + } + } + } + + fn make_preprocess_committed_verifier_func(&self) -> TokenStream2 { + let fn_name = self.get_func_name(); + let preprocess_committed_verifier_fn_name = Ident::new( + &format!("preprocess_committed_verifier_{fn_name}"), + fn_name.span(), + ); + + quote! { + #[cfg(all(not(target_arch = "wasm32"), not(feature = "guest")))] + pub fn #preprocess_committed_verifier_fn_name( + shared_preprocess: jolt::JoltSharedPreprocessing, + generators: ::VerifierSetup, + blindfold_setup: Option>, + ) -> jolt::JoltVerifierPreprocessing { jolt::JoltVerifierPreprocessing::new(shared_preprocess, generators, blindfold_setup) } diff --git a/jolt-sdk/src/host_utils.rs b/jolt-sdk/src/host_utils.rs index f7d09ba2c..cf12822a5 100644 --- a/jolt-sdk/src/host_utils.rs +++ b/jolt-sdk/src/host_utils.rs @@ -13,15 +13,19 @@ pub use jolt_core::field::JoltField; pub use jolt_core::guest; pub use jolt_core::poly::commitment::dory::DoryCommitmentScheme as PCS; pub use jolt_core::zkvm::{ - bytecode::PreprocessingError, proof_serialization::JoltProof, - verifier::JoltSharedPreprocessing, verifier::JoltVerifierPreprocessing, RV64IMACProof, - RV64IMACVerifier, Serializable, + bytecode::{PreprocessingError, TrustedBytecodeCommitments}, + program::ProgramPreprocessing, + program::TrustedProgramCommitments, + proof_serialization::JoltProof, + verifier::JoltSharedPreprocessing, + verifier::JoltVerifierPreprocessing, + RV64IMACProof, RV64IMACVerifier, Serializable, }; pub use jolt_core::AdviceTape; // Re-exports needed by the provable macro pub use jolt_core::poly::commitment::commitment_scheme::CommitmentScheme; -pub use jolt_core::poly::commitment::dory::{DoryContext, DoryGlobals}; +pub use jolt_core::poly::commitment::dory::{DoryContext, DoryGlobals, DoryLayout}; pub use jolt_core::poly::multilinear_polynomial::MultilinearPolynomial; pub use jolt_core::zkvm::ram::populate_memory_states; pub use jolt_core::zkvm::verifier::BlindfoldSetup; diff --git a/src/build_wasm.rs b/src/build_wasm.rs index 4ad5d9e7c..2a84cb604 100644 --- a/src/build_wasm.rs +++ b/src/build_wasm.rs @@ -14,6 +14,7 @@ use jolt_core::{ host::Program, poly::commitment::dory::DoryCommitmentScheme, zkvm::{ + program::ProgramPreprocessing, prover::JoltProverPreprocessing, verifier::{JoltSharedPreprocessing, JoltVerifierPreprocessing}, Serializable, @@ -28,16 +29,16 @@ struct FunctionAttributes { } fn preprocess_and_save(func_name: &str, attributes: &Attributes, is_std: bool) -> Result<()> { - let mut program = Program::new("guest"); + let mut host_program = Program::new("guest"); - program.set_func(func_name); - program.set_std(is_std); - program.set_heap_size(attributes.heap_size); - program.set_stack_size(attributes.stack_size); - program.set_max_input_size(attributes.max_input_size); - program.set_max_output_size(attributes.max_output_size); + host_program.set_func(func_name); + host_program.set_std(is_std); + host_program.set_heap_size(attributes.heap_size); + host_program.set_stack_size(attributes.stack_size); + host_program.set_max_input_size(attributes.max_input_size); + host_program.set_max_output_size(attributes.max_output_size); - let (bytecode, memory_init, program_size, e_entry) = program.decode(); + let (bytecode, memory_init, program_size, _e_entry) = host_program.decode(); let memory_config = MemoryConfig { max_input_size: attributes.max_input_size, @@ -50,13 +51,12 @@ fn preprocess_and_save(func_name: &str, attributes: &Attributes, is_std: bool) - }; let memory_layout = MemoryLayout::new(&memory_config); + let preprocessed_program = ProgramPreprocessing::preprocess(bytecode, memory_init)?; let shared = JoltSharedPreprocessing::new( - bytecode, + preprocessed_program, memory_layout, - memory_init, attributes.max_trace_length as usize, - e_entry, - )?; + ); let prover_preprocessing = JoltProverPreprocessing::::new(shared); @@ -71,7 +71,7 @@ fn preprocess_and_save(func_name: &str, attributes: &Attributes, is_std: bool) - let mut file = File::create(verifier_path)?; file.write_all(&verifier_bytes)?; - let elf_bytes = program + let elf_bytes = host_program .get_elf_contents() .expect("ELF not found after decode"); let elf_path = target_dir.join(format!("{func_name}.elf")); diff --git a/tracer/src/utils/virtual_registers.rs b/tracer/src/utils/virtual_registers.rs index e93bb6b9f..05392c09b 100644 --- a/tracer/src/utils/virtual_registers.rs +++ b/tracer/src/utils/virtual_registers.rs @@ -131,7 +131,7 @@ impl VirtualRegisterAllocator { /// Allocate virtual register that can be used in the inline sequence of /// an instruction. Skips reserved registers (32-39) and uses registers 40-47. - pub(crate) fn allocate(&self) -> VirtualRegisterGuard { + pub fn allocate(&self) -> VirtualRegisterGuard { for (i, allocated) in self .allocated .lock() diff --git a/transpiler/src/main.rs b/transpiler/src/main.rs index 8cd17fe82..fe3dcebce 100644 --- a/transpiler/src/main.rs +++ b/transpiler/src/main.rs @@ -155,17 +155,6 @@ fn main() { real_preprocessing.shared.memory_layout ); - // Convert to symbolic preprocessing: replace Dory generators with AstVerifierSetup stub. - // The `shared` field (memory layout, bytecode info) is reused as-is. - // AstCommitmentScheme satisfies the CommitmentScheme trait but performs no cryptographic - // operations. PCS verification is skipped in stages 1-6. - let symbolic_preprocessing: JoltVerifierPreprocessing = - JoltVerifierPreprocessing { - generators: transpiler::symbolic_traits::ast_commitment_scheme::AstVerifierSetup, - shared: real_preprocessing.shared.clone(), - blindfold_setup: None, - }; - // ========================================================================= // Step 2: Convert proof to symbolic representation // ========================================================================= @@ -208,7 +197,7 @@ fn main() { let transcript: SelectedAstTranscript = Transcript::new(b"Jolt"); // ========================================================================= - // Step 2b: Symbolize IO device + // Step 2b: Symbolize IO device and preprocessing // ========================================================================= // Make inputs/outputs/panic into witness variables instead of constants. // This sets up two override mechanisms: @@ -220,16 +209,25 @@ fn main() { transpiler::symbolize::symbolize_io_device(&io_device, &mut var_alloc); println!(" IO input words: {}", eval_input_words.len()); - // Set PENDING_INITIAL_RAM: bytecode as constants, inputs as symbolic + // Set PENDING_INITIAL_RAM: bytecode as constants, inputs as symbolic. + // In committed mode the verifier gets the bytecode contribution from a + // claim-reduction sumcheck, so PENDING_INITIAL_RAM only needs input words. { use jolt_core::zkvm::ram::{set_pending_initial_ram, PendingInitialRamValues}; - let bytecode_words: Vec = real_preprocessing - .shared - .ram - .bytecode_words - .iter() - .map(|&w| MleAst::from_u64(w)) - .collect(); + let bytecode_words: Vec = if real_preprocessing.shared.program.is_full() { + real_preprocessing + .shared + .program + .as_full() + .unwrap() + .ram + .bytecode_words + .iter() + .map(|&w| MleAst::from_u64(w)) + .collect() + } else { + Vec::new() + }; set_pending_initial_ram(PendingInitialRamValues { bytecode_words, input_words: eval_input_words, @@ -240,6 +238,88 @@ fn main() { var_alloc.next_idx() ); + // Convert to symbolic preprocessing: replace Dory generators with AstVerifierSetup stub. + // AstCommitmentScheme satisfies the CommitmentScheme trait but performs no cryptographic + // operations. PCS verification is skipped in stages 1-7. + // + // For Full mode: ProgramPreprocessing::Full is PCS-independent, just wrap it. + // For Committed mode: symbolize the trusted bytecode/program commitments so the + // transpiled circuit can include them as witness inputs. + println!("\n=== Converting Preprocessing ==="); + let symbolic_shared = { + use jolt_core::zkvm::bytecode::TrustedBytecodeCommitments; + use jolt_core::zkvm::program::{ + CommittedProgramPreprocessing, ProgramPreprocessing, TrustedProgramCommitments, + }; + + let symbolic_program: ProgramPreprocessing = match &real_preprocessing + .shared + .program + { + ProgramPreprocessing::Full(full) => { + println!(" Mode: Full"); + ProgramPreprocessing::Full(full.clone()) + } + ProgramPreprocessing::Committed(committed) => { + println!(" Mode: Committed (symbolizing trusted commitments)"); + let bytecode_commitments = TrustedBytecodeCommitments { + commitments: committed + .bytecode_commitments + .commitments + .iter() + .enumerate() + .map(|(i, c)| { + let chunks = + var_alloc.alloc_commitment(c, &format!("trusted_bytecode_{i}")); + AstCommitment::new(chunks) + }) + .collect(), + num_columns: committed.bytecode_commitments.num_columns, + log_k_chunk: committed.bytecode_commitments.log_k_chunk, + bytecode_chunk_count: committed.bytecode_commitments.bytecode_chunk_count, + bytecode_len: committed.bytecode_commitments.bytecode_len, + bytecode_T: committed.bytecode_commitments.bytecode_T, + }; + let program_commitments = TrustedProgramCommitments { + program_image_commitment: { + let chunks = var_alloc.alloc_commitment( + &committed.program_commitments.program_image_commitment, + "trusted_program_image", + ); + AstCommitment::new(chunks) + }, + program_image_num_columns: committed + .program_commitments + .program_image_num_columns, + program_image_num_words: committed.program_commitments.program_image_num_words, + }; + ProgramPreprocessing::Committed(CommittedProgramPreprocessing { + meta: committed.meta.clone(), + bytecode_commitments, + program_commitments, + prover_data: None, + }) + } + }; + + // Construct directly — don't use new_committed() because it calls + // PCS::setup_prover + program.commit, which panics for AstCommitmentScheme. + // The symbolic program already has the commitments symbolized above. + jolt_core::zkvm::verifier::JoltSharedPreprocessing:: { + program_meta: symbolic_program.meta(), + program: symbolic_program, + memory_layout: real_preprocessing.shared.memory_layout.clone(), + max_padded_trace_length: real_preprocessing.shared.max_padded_trace_length, + bytecode_chunk_count: real_preprocessing.shared.bytecode_chunk_count, + } + }; + let symbolic_preprocessing: JoltVerifierPreprocessing = + JoltVerifierPreprocessing::new( + symbolic_shared, + transpiler::symbolic_traits::ast_commitment_scheme::AstVerifierSetup, + None, + ); + // ========================================================================= // Step 3: Set up symbolic verifier // ========================================================================= diff --git a/transpiler/src/symbolic_proof.rs b/transpiler/src/symbolic_proof.rs index 75311d49a..54798fd79 100644 --- a/transpiler/src/symbolic_proof.rs +++ b/transpiler/src/symbolic_proof.rs @@ -355,11 +355,18 @@ pub fn symbolize_proof( "stage5_sumcheck", ); - // === Symbolize stage 6 sumcheck proof === - let stage6_sumcheck = symbolize_sumcheck_variant::( - &real_proof.stage6_sumcheck_proof, + // === Symbolize stage 6a sumcheck proof === + let stage6a_sumcheck = symbolize_sumcheck_variant::( + &real_proof.stage6a_sumcheck_proof, &mut alloc, - "stage6_sumcheck", + "stage6a_sumcheck", + ); + + // === Symbolize stage 6b sumcheck proof === + let stage6b_sumcheck = symbolize_sumcheck_variant::( + &real_proof.stage6b_sumcheck_proof, + &mut alloc, + "stage6b_sumcheck", ); // === Symbolize stage 7 sumcheck proof === @@ -390,7 +397,8 @@ pub fn symbolize_proof( stage3_sumcheck_proof: stage3_sumcheck, stage4_sumcheck_proof: stage4_sumcheck, stage5_sumcheck_proof: stage5_sumcheck, - stage6_sumcheck_proof: stage6_sumcheck, + stage6a_sumcheck_proof: stage6a_sumcheck, + stage6b_sumcheck_proof: stage6b_sumcheck, stage7_sumcheck_proof: stage7_sumcheck, joint_opening_proof: AstProof::default(), untrusted_advice_commitment,