diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml index ac1395be7b..982524904b 100644 --- a/.github/workflows/rust.yml +++ b/.github/workflows/rust.yml @@ -69,6 +69,8 @@ jobs: path: ~/.jolt - name: Install Jolt RISC-V Rust toolchain run: cargo run install-toolchain + - name: Clear Dory URS cache + run: rm -rf ~/.cache/dory - name: Install nextest uses: taiki-e/install-action@nextest - name: Run jolt-core tests diff --git a/.gitignore b/.gitignore index 6c88a867c6..fc6d03d695 100644 --- a/.gitignore +++ b/.gitignore @@ -65,3 +65,4 @@ jolt-sdk/tests/fib_io_device_bytes.rs jolt-sdk/tests/fib_proof_bytes.rs jolt-sdk/tests/jolt_verifier_preprocessing_bytes.rs +bytecode-commitment-progress.md diff --git a/book/src/usage/guests_hosts/hosts.md b/book/src/usage/guests_hosts/hosts.md index 5c05bb9dda..5c1f2fae1f 100644 --- a/book/src/usage/guests_hosts/hosts.md +++ b/book/src/usage/guests_hosts/hosts.md @@ -5,7 +5,7 @@ Hosts are where we can invoke the Jolt prover to prove functions defined within The host imports the guest package, and will have automatically generated functions to build each of the Jolt functions. For the SHA3 example we looked at in the [guest](./guests.md) section, the `jolt::provable` procedural macro generates several functions that can be invoked from the host (shown below): - `compile_sha3(target_dir)` to compile the SHA3 guest to RISC-V -- `preprocess_prover_sha3` and `verifier_preprocessing_from_prover_sha3` to generate the prover and verifier preprocessing. Note that the preprocessing only needs to be generated once for a given guest program, and can subsequently be reused to prove multiple invocations of the guest. +- `preprocess_sha3` and `verifier_preprocessing_from_prover_sha3` to generate the prover and verifier preprocessing. Note that the preprocessing only needs to be generated once for a given guest program, and can subsequently be reused to prove multiple invocations of the guest. - `build_prover_sha3` returns a closure for the prover, which takes in the same input types as the original function and modifies the output to additionally include a proof. - `build_verifier_sha3` returns a closure for the verifier, which verifies the proof. The verifier closure's parameters comprise of the program input, the claimed output, a `bool` value claiming whether the guest panicked, and the proof. @@ -14,7 +14,7 @@ pub fn main() { let target_dir = "/tmp/jolt-guest-targets"; let mut program = guest::compile_sha3(target_dir); - let prover_preprocessing = guest::preprocess_prover_sha3(&mut program); + let prover_preprocessing = guest::preprocess_sha3(&mut program); let verifier_preprocessing = guest::verifier_preprocessing_from_prover_sha3(&prover_preprocessing); diff --git a/examples/alloc/src/main.rs b/examples/alloc/src/main.rs index 1afd790d20..8845e61aaf 100644 --- a/examples/alloc/src/main.rs +++ b/examples/alloc/src/main.rs @@ -7,12 +7,9 @@ pub fn main() { let target_dir = "/tmp/jolt-guest-targets"; let mut program = guest::compile_alloc(target_dir); - let shared_preprocessing = guest::preprocess_shared_alloc(&mut program); - let prover_preprocessing = guest::preprocess_prover_alloc(shared_preprocessing.clone()); - let verifier_preprocessing = guest::preprocess_verifier_alloc( - shared_preprocessing, - prover_preprocessing.generators.to_verifier_setup(), - ); + let prover_preprocessing = guest::preprocess_alloc(&mut program); + let verifier_preprocessing = + guest::verifier_preprocessing_from_prover_alloc(&prover_preprocessing); let prove_alloc = guest::build_prover_alloc(program, prover_preprocessing); let verify_alloc = guest::build_verifier_alloc(verifier_preprocessing); diff --git a/examples/btreemap/host/src/main.rs b/examples/btreemap/host/src/main.rs index 011f502489..5bfb3ef5b5 100644 --- a/examples/btreemap/host/src/main.rs +++ b/examples/btreemap/host/src/main.rs @@ -17,19 +17,12 @@ pub fn btreemap() { guest::compile_btreemap(target_dir) }); - let shared_preprocessing = step!("Preprocessing shared", { - guest::preprocess_shared_btreemap(&mut program) - }); - let prover_preprocessing = step!("Preprocessing prover", { - guest::preprocess_prover_btreemap(shared_preprocessing.clone()) + guest::preprocess_btreemap(&mut program) }); let verifier_preprocessing = step!("Preprocessing verifier", { - guest::preprocess_verifier_btreemap( - shared_preprocessing, - prover_preprocessing.generators.to_verifier_setup(), - ) + guest::verifier_preprocessing_from_prover_btreemap(&prover_preprocessing) }); let prove = step!("Building prover", { diff --git a/examples/collatz/src/main.rs b/examples/collatz/src/main.rs index c91450547d..1ea0415512 100644 --- a/examples/collatz/src/main.rs +++ b/examples/collatz/src/main.rs @@ -8,12 +8,9 @@ pub fn main() { let target_dir = "/tmp/jolt-guest-targets"; let mut program = guest::compile_collatz_convergence(target_dir); - let shared_preprocessing = guest::preprocess_shared_collatz_convergence(&mut program); - let prover_preprocessing = - guest::preprocess_prover_collatz_convergence(shared_preprocessing.clone()); - let verifier_setup = prover_preprocessing.generators.to_verifier_setup(); + let prover_preprocessing = guest::preprocess_collatz_convergence(&mut program); let verifier_preprocessing = - guest::preprocess_verifier_collatz_convergence(shared_preprocessing, verifier_setup); + guest::verifier_preprocessing_from_prover_collatz_convergence(&prover_preprocessing); let prove_collatz_single = guest::build_prover_collatz_convergence(program, prover_preprocessing); @@ -31,12 +28,9 @@ pub fn main() { // Prove/verify convergence for a range of numbers: let mut program = guest::compile_collatz_convergence_range(target_dir); - let shared_preprocessing = guest::preprocess_shared_collatz_convergence_range(&mut program); - let prover_preprocessing = - guest::preprocess_prover_collatz_convergence_range(shared_preprocessing.clone()); - let verifier_setup = prover_preprocessing.generators.to_verifier_setup(); + let prover_preprocessing = guest::preprocess_collatz_convergence_range(&mut program); let verifier_preprocessing = - guest::preprocess_verifier_collatz_convergence_range(shared_preprocessing, verifier_setup); + guest::verifier_preprocessing_from_prover_collatz_convergence_range(&prover_preprocessing); let prove_collatz_convergence = guest::build_prover_collatz_convergence_range(program, prover_preprocessing); diff --git a/examples/fibonacci/src/main.rs b/examples/fibonacci/src/main.rs index ac2b755cad..58bfd5e05f 100644 --- a/examples/fibonacci/src/main.rs +++ b/examples/fibonacci/src/main.rs @@ -6,16 +6,18 @@ pub fn main() { tracing_subscriber::fmt::init(); let save_to_disk = std::env::args().any(|arg| arg == "--save"); + let committed_bytecode = std::env::args().any(|arg| arg == "--committed-bytecode"); let target_dir = "/tmp/jolt-guest-targets"; let mut program = guest::compile_fib(target_dir); - let shared_preprocessing = guest::preprocess_shared_fib(&mut program); - - let prover_preprocessing = guest::preprocess_prover_fib(shared_preprocessing.clone()); - let verifier_setup = prover_preprocessing.generators.to_verifier_setup(); + let prover_preprocessing = if committed_bytecode { + guest::preprocess_committed_fib(&mut program) + } else { + guest::preprocess_fib(&mut program) + }; let verifier_preprocessing = - guest::preprocess_verifier_fib(shared_preprocessing, verifier_setup); + guest::verifier_preprocessing_from_prover_fib(&prover_preprocessing); if save_to_disk { serialize_and_print_size( @@ -26,7 +28,6 @@ pub fn main() { .expect("Could not serialize preprocessing."); } - let prove_fib = guest::build_prover_fib(program, prover_preprocessing); let verify_fib = guest::build_verifier_fib(verifier_preprocessing); let program_summary = guest::analyze_fib(10); @@ -39,8 +40,22 @@ pub fn main() { info!("Trace file written to: {trace_file}."); let now = Instant::now(); - let (output, proof, io_device) = prove_fib(50); + let (output, proof, io_device) = if committed_bytecode { + let prove_fib = guest::build_prover_committed_fib(program, prover_preprocessing); + prove_fib(50) + } else { + let prove_fib = guest::build_prover_fib(program, prover_preprocessing); + prove_fib(50) + }; info!("Prover runtime: {} s", now.elapsed().as_secs_f64()); + info!( + "bytecode mode: {}", + if committed_bytecode { + "Committed" + } else { + "Full" + } + ); if save_to_disk { serialize_and_print_size("Proof", "/tmp/fib_proof.bin", &proof) diff --git a/examples/hash-bench/src/main.rs b/examples/hash-bench/src/main.rs index 181ec912c9..8c498ab3f2 100644 --- a/examples/hash-bench/src/main.rs +++ b/examples/hash-bench/src/main.rs @@ -6,11 +6,9 @@ pub fn main() { let target_dir = "/tmp/jolt-guest-targets"; let mut program = guest::compile_hashbench(target_dir); - let shared_preprocessing = guest::preprocess_shared_hashbench(&mut program); - let prover_preprocessing = guest::preprocess_prover_hashbench(shared_preprocessing.clone()); - let verifier_setup = prover_preprocessing.generators.to_verifier_setup(); + let prover_preprocessing = guest::preprocess_hashbench(&mut program); let verifier_preprocessing = - guest::preprocess_verifier_hashbench(shared_preprocessing, verifier_setup); + guest::verifier_preprocessing_from_prover_hashbench(&prover_preprocessing); let prove_hashbench = guest::build_prover_hashbench(program, prover_preprocessing); let verify_hashbench = guest::build_verifier_hashbench(verifier_preprocessing); diff --git a/examples/malloc/src/main.rs b/examples/malloc/src/main.rs index d28e99d067..39b3b955d4 100644 --- a/examples/malloc/src/main.rs +++ b/examples/malloc/src/main.rs @@ -4,12 +4,9 @@ pub fn main() { let target_dir = "/tmp/jolt-guest-targets"; let mut program = guest::compile_alloc(target_dir); - let shared_preprocessing = guest::preprocess_shared_alloc(&mut program); - let prover_preprocessing = guest::preprocess_prover_alloc(shared_preprocessing.clone()); - let verifier_preprocessing = guest::preprocess_verifier_alloc( - shared_preprocessing, - prover_preprocessing.generators.to_verifier_setup(), - ); + let prover_preprocessing = guest::preprocess_alloc(&mut program); + let verifier_preprocessing = + guest::verifier_preprocessing_from_prover_alloc(&prover_preprocessing); let prove = guest::build_prover_alloc(program, prover_preprocessing); let verify = guest::build_verifier_alloc(verifier_preprocessing); diff --git a/examples/memory-ops/src/main.rs b/examples/memory-ops/src/main.rs index a95af60aa0..3516b6144c 100644 --- a/examples/memory-ops/src/main.rs +++ b/examples/memory-ops/src/main.rs @@ -7,12 +7,9 @@ pub fn main() { let target_dir = "/tmp/jolt-guest-targets"; let mut program = guest::compile_memory_ops(target_dir); - let shared_preprocessing = guest::preprocess_shared_memory_ops(&mut program); - let prover_preprocessing = guest::preprocess_prover_memory_ops(shared_preprocessing.clone()); - let verifier_preprocessing = guest::preprocess_verifier_memory_ops( - shared_preprocessing, - prover_preprocessing.generators.to_verifier_setup(), - ); + let prover_preprocessing = guest::preprocess_memory_ops(&mut program); + let verifier_preprocessing = + guest::verifier_preprocessing_from_prover_memory_ops(&prover_preprocessing); let prove = guest::build_prover_memory_ops(program, prover_preprocessing); let verify = guest::build_verifier_memory_ops(verifier_preprocessing); diff --git a/examples/merkle-tree/src/main.rs b/examples/merkle-tree/src/main.rs index c31353402c..4a89261071 100644 --- a/examples/merkle-tree/src/main.rs +++ b/examples/merkle-tree/src/main.rs @@ -8,12 +8,9 @@ pub fn main() { let target_dir = "/tmp/jolt-guest-targets"; let mut program = guest::compile_merkle_tree(target_dir); - let shared_preprocessing = guest::preprocess_shared_merkle_tree(&mut program); - let prover_preprocessing = guest::preprocess_prover_merkle_tree(shared_preprocessing.clone()); - let verifier_preprocessing = guest::preprocess_verifier_merkle_tree( - shared_preprocessing, - prover_preprocessing.generators.to_verifier_setup(), - ); + let prover_preprocessing = guest::preprocess_merkle_tree(&mut program); + let verifier_preprocessing = + guest::verifier_preprocessing_from_prover_merkle_tree(&prover_preprocessing); let leaf1: &[u8] = &[5u8; 32]; let leaf2 = [6u8; 32]; diff --git a/examples/muldiv/src/main.rs b/examples/muldiv/src/main.rs index 7a3680e5dc..5cc95530db 100644 --- a/examples/muldiv/src/main.rs +++ b/examples/muldiv/src/main.rs @@ -7,12 +7,9 @@ pub fn main() { let target_dir = "/tmp/jolt-guest-targets"; let mut program = guest::compile_muldiv(target_dir); - let shared_preprocessing = guest::preprocess_shared_muldiv(&mut program); - let prover_preprocessing = guest::preprocess_prover_muldiv(shared_preprocessing.clone()); - let verifier_preprocessing = guest::preprocess_verifier_muldiv( - shared_preprocessing, - prover_preprocessing.generators.to_verifier_setup(), - ); + let prover_preprocessing = guest::preprocess_muldiv(&mut program); + let verifier_preprocessing = + guest::verifier_preprocessing_from_prover_muldiv(&prover_preprocessing); let prove = guest::build_prover_muldiv(program, prover_preprocessing); let verify = guest::build_verifier_muldiv(verifier_preprocessing); diff --git a/examples/multi-function/src/main.rs b/examples/multi-function/src/main.rs index 6d9f9da9f8..c12c081bbd 100644 --- a/examples/multi-function/src/main.rs +++ b/examples/multi-function/src/main.rs @@ -8,11 +8,9 @@ pub fn main() { let target_dir = "/tmp/jolt-guest-targets"; let mut program = guest::compile_add(target_dir); - let shared_preprocessing = guest::preprocess_shared_add(&mut program); - let prover_preprocessing = guest::preprocess_prover_add(shared_preprocessing.clone()); - let verifier_setup = prover_preprocessing.generators.to_verifier_setup(); + let prover_preprocessing = guest::preprocess_add(&mut program); let verifier_preprocessing = - guest::preprocess_verifier_add(shared_preprocessing, verifier_setup); + guest::verifier_preprocessing_from_prover_add(&prover_preprocessing); let prove_add = guest::build_prover_add(program, prover_preprocessing); let verify_add = guest::build_verifier_add(verifier_preprocessing); @@ -21,12 +19,9 @@ pub fn main() { let target_dir = "/tmp/jolt-guest-targets"; let mut program = guest::compile_mul(target_dir); - let shared_preprocessing = guest::preprocess_shared_mul(&mut program); - let prover_preprocessing = guest::preprocess_prover_mul(shared_preprocessing.clone()); - let verifier_preprocessing = guest::preprocess_verifier_mul( - shared_preprocessing, - prover_preprocessing.generators.to_verifier_setup(), - ); + let prover_preprocessing = guest::preprocess_mul(&mut program); + let verifier_preprocessing = + guest::verifier_preprocessing_from_prover_mul(&prover_preprocessing); let prove_mul = guest::build_prover_mul(program, prover_preprocessing); let verify_mul = guest::build_verifier_mul(verifier_preprocessing); diff --git a/examples/overflow/src/main.rs b/examples/overflow/src/main.rs index 4a17575e70..a677dc4537 100644 --- a/examples/overflow/src/main.rs +++ b/examples/overflow/src/main.rs @@ -9,9 +9,7 @@ pub fn main() { // An overflowing stack should fail to prove. let target_dir = "/tmp/jolt-guest-targets"; let mut program = guest::compile_overflow_stack(target_dir); - let shared_preprocessing = guest::preprocess_shared_overflow_stack(&mut program); - let prover_preprocessing = - guest::preprocess_prover_overflow_stack(shared_preprocessing.clone()); + let prover_preprocessing = guest::preprocess_overflow_stack(&mut program); let prove_overflow_stack = guest::build_prover_overflow_stack(program, prover_preprocessing); let res = panic::catch_unwind(|| { @@ -23,8 +21,7 @@ pub fn main() { // now lets try to overflow the heap, should also panic let mut program = guest::compile_overflow_heap(target_dir); - let shared_preprocessing = guest::preprocess_shared_overflow_heap(&mut program); - let prover_preprocessing = guest::preprocess_prover_overflow_heap(shared_preprocessing.clone()); + let prover_preprocessing = guest::preprocess_overflow_heap(&mut program); let prove_overflow_heap = guest::build_prover_overflow_heap(program, prover_preprocessing); let res = panic::catch_unwind(|| { @@ -35,15 +32,11 @@ pub fn main() { // valid case for stack allocation, calls overflow_stack() under the hood // but with stack_size=8192 let mut program = guest::compile_allocate_stack_with_increased_size(target_dir); - - let shared_preprocessing = - guest::preprocess_shared_allocate_stack_with_increased_size(&mut program); - let prover_preprocessing = - guest::preprocess_prover_allocate_stack_with_increased_size(shared_preprocessing.clone()); - let verifier_preprocessing = guest::preprocess_verifier_allocate_stack_with_increased_size( - shared_preprocessing, - prover_preprocessing.generators.to_verifier_setup(), - ); + let prover_preprocessing = guest::preprocess_allocate_stack_with_increased_size(&mut program); + let verifier_preprocessing = + guest::verifier_preprocessing_from_prover_allocate_stack_with_increased_size( + &prover_preprocessing, + ); let prove_allocate_stack_with_increased_size = guest::build_prover_allocate_stack_with_increased_size(program, prover_preprocessing); diff --git a/examples/random/src/main.rs b/examples/random/src/main.rs index e4456db259..0379c49bd0 100644 --- a/examples/random/src/main.rs +++ b/examples/random/src/main.rs @@ -7,12 +7,9 @@ pub fn main() { let target_dir = "/tmp/jolt-guest-targets"; let mut program = guest::compile_rand(target_dir); - let shared_preprocessing = guest::preprocess_shared_rand(&mut program); - let prover_preprocessing = guest::preprocess_prover_rand(shared_preprocessing.clone()); - let verifier_preprocessing = guest::preprocess_verifier_rand( - shared_preprocessing, - prover_preprocessing.generators.to_verifier_setup(), - ); + let prover_preprocessing = guest::preprocess_rand(&mut program); + let verifier_preprocessing = + guest::verifier_preprocessing_from_prover_rand(&prover_preprocessing); let prove = guest::build_prover_rand(program, prover_preprocessing); let verify = guest::build_verifier_rand(verifier_preprocessing); diff --git a/examples/recover-ecdsa/src/main.rs b/examples/recover-ecdsa/src/main.rs index 038a5c1fa7..512a59ca22 100644 --- a/examples/recover-ecdsa/src/main.rs +++ b/examples/recover-ecdsa/src/main.rs @@ -31,12 +31,9 @@ pub fn main() { let target_dir = "/tmp/jolt-guest-targets"; let mut program = guest::compile_recover(target_dir); - let shared_preprocessing = guest::preprocess_shared_recover(&mut program); - let prover_preprocessing = guest::preprocess_prover_recover(shared_preprocessing.clone()); - let verifier_preprocessing = guest::preprocess_verifier_recover( - shared_preprocessing, - prover_preprocessing.generators.to_verifier_setup(), - ); + let prover_preprocessing = guest::preprocess_recover(&mut program); + let verifier_preprocessing = + guest::verifier_preprocessing_from_prover_recover(&prover_preprocessing); if save_to_disk { serialize_and_print_size( diff --git a/examples/secp256k1-ecdsa-verify/src/main.rs b/examples/secp256k1-ecdsa-verify/src/main.rs index dfe38f6da8..4ebc61bcec 100644 --- a/examples/secp256k1-ecdsa-verify/src/main.rs +++ b/examples/secp256k1-ecdsa-verify/src/main.rs @@ -7,12 +7,9 @@ pub fn main() { let target_dir = "/tmp/jolt-guest-targets"; let mut program = guest::compile_secp256k1_ecdsa_verify(target_dir); - let shared_preprocessing = guest::preprocess_shared_secp256k1_ecdsa_verify(&mut program); - let prover_preprocessing = - guest::preprocess_prover_secp256k1_ecdsa_verify(shared_preprocessing.clone()); - let verifier_setup = prover_preprocessing.generators.to_verifier_setup(); + let prover_preprocessing = guest::preprocess_secp256k1_ecdsa_verify(&mut program); let verifier_preprocessing = - guest::preprocess_verifier_secp256k1_ecdsa_verify(shared_preprocessing, verifier_setup); + guest::verifier_preprocessing_from_prover_secp256k1_ecdsa_verify(&prover_preprocessing); let prove_secp256k1_ecdsa_verify = guest::build_prover_secp256k1_ecdsa_verify(program, prover_preprocessing); diff --git a/examples/sha2-chain/src/main.rs b/examples/sha2-chain/src/main.rs index 94114c0414..f7f1ccbd60 100644 --- a/examples/sha2-chain/src/main.rs +++ b/examples/sha2-chain/src/main.rs @@ -7,12 +7,9 @@ pub fn main() { let target_dir = "/tmp/jolt-guest-targets"; let mut program = guest::compile_sha2_chain(target_dir); - let shared_preprocessing = guest::preprocess_shared_sha2_chain(&mut program); - let prover_preprocessing = guest::preprocess_prover_sha2_chain(shared_preprocessing.clone()); - let verifier_preprocessing = guest::preprocess_verifier_sha2_chain( - shared_preprocessing, - prover_preprocessing.generators.to_verifier_setup(), - ); + let prover_preprocessing = guest::preprocess_sha2_chain(&mut program); + let verifier_preprocessing = + guest::verifier_preprocessing_from_prover_sha2_chain(&prover_preprocessing); let prove_sha2_chain = guest::build_prover_sha2_chain(program, prover_preprocessing); let verify_sha2_chain = guest::build_verifier_sha2_chain(verifier_preprocessing); diff --git a/examples/sha2-ex/src/main.rs b/examples/sha2-ex/src/main.rs index 4bce837fb8..2d86050f25 100644 --- a/examples/sha2-ex/src/main.rs +++ b/examples/sha2-ex/src/main.rs @@ -7,12 +7,9 @@ pub fn main() { let target_dir = "/tmp/jolt-guest-targets"; let mut program = guest::compile_sha2(target_dir); - let shared_preprocessing = guest::preprocess_shared_sha2(&mut program); - let prover_preprocessing = guest::preprocess_prover_sha2(shared_preprocessing.clone()); - let verifier_preprocessing = guest::preprocess_verifier_sha2( - shared_preprocessing, - prover_preprocessing.generators.to_verifier_setup(), - ); + let prover_preprocessing = guest::preprocess_sha2(&mut program); + let verifier_preprocessing = + guest::verifier_preprocessing_from_prover_sha2(&prover_preprocessing); let prove_sha2 = guest::build_prover_sha2(program, prover_preprocessing); let verify_sha2 = guest::build_verifier_sha2(verifier_preprocessing); diff --git a/examples/sha3-chain/src/main.rs b/examples/sha3-chain/src/main.rs index 97e223467b..cae32b0148 100644 --- a/examples/sha3-chain/src/main.rs +++ b/examples/sha3-chain/src/main.rs @@ -6,12 +6,9 @@ pub fn main() { let target_dir = "/tmp/jolt-guest-targets"; let mut program = guest::compile_sha3_chain(target_dir); - let shared_preprocessing = guest::preprocess_shared_sha3_chain(&mut program); - let prover_preprocessing = guest::preprocess_prover_sha3_chain(shared_preprocessing.clone()); - let verifier_preprocessing = guest::preprocess_verifier_sha3_chain( - shared_preprocessing, - prover_preprocessing.generators.to_verifier_setup(), - ); + let prover_preprocessing = guest::preprocess_sha3_chain(&mut program); + let verifier_preprocessing = + guest::verifier_preprocessing_from_prover_sha3_chain(&prover_preprocessing); let prove_sha3_chain = guest::build_prover_sha3_chain(program, prover_preprocessing); let verify_sha3_chain = guest::build_verifier_sha3_chain(verifier_preprocessing); diff --git a/examples/sha3-ex/src/main.rs b/examples/sha3-ex/src/main.rs index 1b49530258..69467d6f4e 100644 --- a/examples/sha3-ex/src/main.rs +++ b/examples/sha3-ex/src/main.rs @@ -6,12 +6,9 @@ pub fn main() { let target_dir = "/tmp/jolt-guest-targets"; let mut program = guest::compile_sha3(target_dir); - let shared_preprocessing = guest::preprocess_shared_sha3(&mut program); - let prover_preprocessing = guest::preprocess_prover_sha3(shared_preprocessing.clone()); - let verifier_preprocessing = guest::preprocess_verifier_sha3( - shared_preprocessing, - prover_preprocessing.generators.to_verifier_setup(), - ); + let prover_preprocessing = guest::preprocess_sha3(&mut program); + let verifier_preprocessing = + guest::verifier_preprocessing_from_prover_sha3(&prover_preprocessing); let prove_sha3 = guest::build_prover_sha3(program, prover_preprocessing); let verify_sha3 = guest::build_verifier_sha3(verifier_preprocessing); diff --git a/examples/stdlib/src/main.rs b/examples/stdlib/src/main.rs index 8edd0fed21..8b84b31743 100644 --- a/examples/stdlib/src/main.rs +++ b/examples/stdlib/src/main.rs @@ -7,12 +7,9 @@ pub fn main() { let target_dir = "/tmp/jolt-guest-targets"; let mut program = guest::compile_int_to_string(target_dir); - let shared_preprocessing = guest::preprocess_shared_int_to_string(&mut program); - let prover_preprocessing = guest::preprocess_prover_int_to_string(shared_preprocessing.clone()); - let verifier_preprocessing = guest::preprocess_verifier_int_to_string( - shared_preprocessing, - prover_preprocessing.generators.to_verifier_setup(), - ); + let prover_preprocessing = guest::preprocess_int_to_string(&mut program); + let verifier_preprocessing = + guest::verifier_preprocessing_from_prover_int_to_string(&prover_preprocessing); let prove = guest::build_prover_int_to_string(program, prover_preprocessing); let verify = guest::build_verifier_int_to_string(verifier_preprocessing); @@ -24,12 +21,9 @@ pub fn main() { let mut program = guest::compile_string_concat(target_dir); - let shared_preprocessing = guest::preprocess_shared_string_concat(&mut program); - let prover_preprocessing = guest::preprocess_prover_string_concat(shared_preprocessing.clone()); - let verifier_preprocessing = guest::preprocess_verifier_string_concat( - shared_preprocessing, - prover_preprocessing.generators.to_verifier_setup(), - ); + let prover_preprocessing = guest::preprocess_string_concat(&mut program); + let verifier_preprocessing = + guest::verifier_preprocessing_from_prover_string_concat(&prover_preprocessing); let prove = guest::build_prover_string_concat(program, prover_preprocessing); let verify = guest::build_verifier_string_concat(verifier_preprocessing); diff --git a/jolt-core/benches/e2e_profiling.rs b/jolt-core/benches/e2e_profiling.rs index cf5cb3b65d..876cf2e434 100644 --- a/jolt-core/benches/e2e_profiling.rs +++ b/jolt-core/benches/e2e_profiling.rs @@ -1,3 +1,5 @@ +use std::sync::Arc; + use ark_serialize::CanonicalSerialize; use jolt_core::host; use jolt_core::zkvm::prover::JoltProverPreprocessing; @@ -201,19 +203,24 @@ fn prove_example( ) -> Vec<(tracing::Span, Box)> { let mut tasks = Vec::new(); let mut program = host::Program::new(example_name); - let (bytecode, init_memory_state, _) = program.decode(); + let (instructions, init_memory_state, _) = program.decode(); let (_lazy_trace, trace, _, program_io) = program.trace(&serialized_input, &[], &[]); let padded_trace_len = (trace.len() + 1).next_power_of_two(); drop(trace); let task = move || { + use jolt_core::zkvm::program::ProgramPreprocessing; + let program_data = Arc::new(ProgramPreprocessing::preprocess( + instructions, + init_memory_state, + )); let shared_preprocessing = JoltSharedPreprocessing::new( - bytecode, + program_data.meta(), program_io.memory_layout.clone(), - init_memory_state, padded_trace_len, ); - let preprocessing = JoltProverPreprocessing::new(shared_preprocessing.clone()); + let preprocessing = + JoltProverPreprocessing::new(shared_preprocessing.clone(), Arc::clone(&program_data)); let elf_contents_opt = program.get_elf_contents(); let elf_contents = elf_contents_opt.as_deref().expect("elf contents is None"); @@ -229,9 +236,10 @@ fn prove_example( let program_io = prover.program_io.clone(); let (jolt_proof, _) = prover.prove(); - let verifier_preprocessing = JoltVerifierPreprocessing::new( + let verifier_preprocessing = JoltVerifierPreprocessing::new_full( shared_preprocessing, preprocessing.generators.to_verifier_setup(), + Arc::clone(&preprocessing.program), ); let verifier = RV64IMACVerifier::new(&verifier_preprocessing, jolt_proof, program_io, None, None) @@ -255,7 +263,7 @@ fn prove_example_with_trace( _scale: usize, ) -> (std::time::Duration, usize, usize, usize) { let mut program = host::Program::new(example_name); - let (bytecode, init_memory_state, _) = program.decode(); + let (instructions, init_memory_state, _) = program.decode(); let (_, trace, _, program_io) = program.trace(&serialized_input, &[], &[]); assert!( @@ -263,13 +271,18 @@ fn prove_example_with_trace( "Trace is longer than expected" ); + use jolt_core::zkvm::program::ProgramPreprocessing; + let program_data = Arc::new(ProgramPreprocessing::preprocess( + instructions, + init_memory_state, + )); let shared_preprocessing = JoltSharedPreprocessing::new( - bytecode.clone(), + program_data.meta(), program_io.memory_layout.clone(), - init_memory_state, trace.len().next_power_of_two(), ); - let preprocessing = JoltProverPreprocessing::new(shared_preprocessing); + let preprocessing = + JoltProverPreprocessing::new(shared_preprocessing, Arc::clone(&program_data)); let elf_contents_opt = program.get_elf_contents(); let elf_contents = elf_contents_opt.as_deref().expect("elf contents is None"); diff --git a/jolt-core/src/field/ark.rs b/jolt-core/src/field/ark.rs index c4bb4066a2..70bf7ca691 100644 --- a/jolt-core/src/field/ark.rs +++ b/jolt-core/src/field/ark.rs @@ -1,6 +1,7 @@ use super::{FieldOps, JoltField, MulU64WithCarry}; #[cfg(feature = "challenge-254-bit")] use crate::field::challenge::Mont254BitChallenge; +#[cfg(not(feature = "challenge-254-bit"))] use crate::field::challenge::MontU128Challenge; use crate::field::MulTrunc; use crate::utils::thread::unsafe_allocate_zero_vec; diff --git a/jolt-core/src/field/tracked_ark.rs b/jolt-core/src/field/tracked_ark.rs index e52a288e1f..f634513707 100644 --- a/jolt-core/src/field/tracked_ark.rs +++ b/jolt-core/src/field/tracked_ark.rs @@ -1,6 +1,7 @@ use super::{FieldOps, JoltField}; #[cfg(feature = "challenge-254-bit")] use crate::field::challenge::Mont254BitChallenge; +#[cfg(not(feature = "challenge-254-bit"))] use crate::field::challenge::MontU128Challenge; use crate::utils::counters::{ @@ -462,12 +463,13 @@ impl TrackedFr { #[cfg(test)] mod tests { #![allow(clippy::op_ref)] + use std::ops::MulAssign; + use crate::field::tracked_ark::TrackedFr as Fr; use crate::field::{JoltField, OptimizedMul}; use crate::utils::counters::{ get_inverse_count, get_mult_count, reset_inverse_count, reset_mult_count, }; - use std::ops::MulAssign; #[test] fn test_if_trackers_are_working() { diff --git a/jolt-core/src/guest/prover.rs b/jolt-core/src/guest/prover.rs index a20023fed7..1e0fde75a3 100644 --- a/jolt-core/src/guest/prover.rs +++ b/jolt-core/src/guest/prover.rs @@ -4,10 +4,14 @@ use crate::poly::commitment::commitment_scheme::CommitmentScheme; use crate::poly::commitment::commitment_scheme::StreamingCommitmentScheme; use crate::poly::commitment::dory::DoryCommitmentScheme; use crate::transcripts::Transcript; +use crate::zkvm::program::ProgramPreprocessing; use crate::zkvm::proof_serialization::JoltProof; +use crate::zkvm::prover::JoltCpuProver; use crate::zkvm::prover::JoltProverPreprocessing; +use crate::zkvm::verifier::JoltSharedPreprocessing; use crate::zkvm::ProverDebugInfo; use common::jolt_device::MemoryLayout; +use std::sync::Arc; use tracer::JoltDevice; #[allow(clippy::type_complexity)] @@ -16,16 +20,15 @@ pub fn preprocess( guest: &Program, max_trace_length: usize, ) -> JoltProverPreprocessing { - use crate::zkvm::verifier::JoltSharedPreprocessing; - - let (bytecode, memory_init, program_size) = guest.decode(); + let (instructions, memory_init, program_size) = guest.decode(); let mut memory_config = guest.memory_config; memory_config.program_size = Some(program_size); let memory_layout = MemoryLayout::new(&memory_config); - let shared_preprocessing = - JoltSharedPreprocessing::new(bytecode, memory_layout, memory_init, max_trace_length); - JoltProverPreprocessing::new(shared_preprocessing) + + let program = Arc::new(ProgramPreprocessing::preprocess(instructions, memory_init)); + let shared = JoltSharedPreprocessing::new(program.meta(), memory_layout, max_trace_length); + JoltProverPreprocessing::new(shared, program) } #[allow(clippy::type_complexity, clippy::too_many_arguments)] @@ -44,8 +47,6 @@ pub fn prove, FS: Transc JoltDevice, Option>, ) { - use crate::zkvm::prover::JoltCpuProver; - let prover = JoltCpuProver::gen_from_elf( preprocessing, &guest.elf_contents, diff --git a/jolt-core/src/guest/verifier.rs b/jolt-core/src/guest/verifier.rs index 5c2a92904d..50b1867351 100644 --- a/jolt-core/src/guest/verifier.rs +++ b/jolt-core/src/guest/verifier.rs @@ -1,31 +1,33 @@ +use std::sync::Arc; + use crate::field::JoltField; +use crate::guest::program::Program; use crate::poly::commitment::commitment_scheme::CommitmentScheme; use crate::poly::commitment::commitment_scheme::StreamingCommitmentScheme; - -use crate::guest::program::Program; use crate::poly::commitment::dory::DoryCommitmentScheme; use crate::transcripts::Transcript; use crate::utils::errors::ProofVerifyError; +use crate::zkvm::program::ProgramPreprocessing; use crate::zkvm::proof_serialization::JoltProof; use crate::zkvm::verifier::JoltSharedPreprocessing; use crate::zkvm::verifier::JoltVerifier; use crate::zkvm::verifier::JoltVerifierPreprocessing; -use common::jolt_device::MemoryConfig; -use common::jolt_device::MemoryLayout; +use common::jolt_device::{JoltDevice, MemoryConfig, MemoryLayout}; pub fn preprocess( guest: &Program, max_trace_length: usize, verifier_setup: ::VerifierSetup, ) -> JoltVerifierPreprocessing { - let (bytecode, memory_init, program_size) = guest.decode(); + let (instructions, memory_init, program_size) = guest.decode(); let mut memory_config = guest.memory_config; memory_config.program_size = Some(program_size); let memory_layout = MemoryLayout::new(&memory_config); - let shared = - JoltSharedPreprocessing::new(bytecode, memory_layout, memory_init, max_trace_length); - JoltVerifierPreprocessing::new(shared, verifier_setup) + + let program = Arc::new(ProgramPreprocessing::preprocess(instructions, memory_init)); + let shared = JoltSharedPreprocessing::new(program.meta(), memory_layout, max_trace_length); + JoltVerifierPreprocessing::new_full(shared, verifier_setup, program) } pub fn verify, FS: Transcript>( @@ -35,7 +37,6 @@ pub fn verify, FS: Trans proof: JoltProof, preprocessing: &JoltVerifierPreprocessing, ) -> Result<(), ProofVerifyError> { - use common::jolt_device::JoltDevice; let memory_layout = &preprocessing.shared.memory_layout; let memory_config = MemoryConfig { max_untrusted_advice_size: memory_layout.max_untrusted_advice_size, diff --git a/jolt-core/src/poly/commitment/commitment_scheme.rs b/jolt-core/src/poly/commitment/commitment_scheme.rs index 6debe3b519..7e1a2faa43 100644 --- a/jolt-core/src/poly/commitment/commitment_scheme.rs +++ b/jolt-core/src/poly/commitment/commitment_scheme.rs @@ -27,7 +27,13 @@ pub trait CommitmentScheme: Clone + Sync + Send + 'static { /// A hint that helps the prover compute an opening proof. Typically some byproduct of /// the commitment computation, e.g. for Dory the Pedersen commitments to the rows can be /// used as a hint for the opening proof. - type OpeningProofHint: Sync + Send + Clone + Debug + PartialEq; + type OpeningProofHint: Sync + + Send + + Clone + + Debug + + PartialEq + + CanonicalSerialize + + CanonicalDeserialize; /// Generates the prover setup for this PCS. `max_num_vars` is the maximum number of /// variables of any polynomial that will be committed using this setup. diff --git a/jolt-core/src/poly/commitment/dory/commitment_scheme.rs b/jolt-core/src/poly/commitment/dory/commitment_scheme.rs index d9b890c6ef..9e1efaccd6 100644 --- a/jolt-core/src/poly/commitment/dory/commitment_scheme.rs +++ b/jolt-core/src/poly/commitment/dory/commitment_scheme.rs @@ -197,31 +197,41 @@ impl CommitmentScheme for DoryCommitmentScheme { coeffs: &[Self::Field], ) -> Self::OpeningProofHint { let num_rows = DoryGlobals::get_max_num_rows(); - let mut rlc_hint = vec![ArkG1(G1Projective::zero()); num_rows]; - for (coeff, mut hint) in coeffs.iter().zip(hints.into_iter()) { - hint.resize(num_rows, ArkG1(G1Projective::zero())); - let row_commitments: &mut [G1Projective] = unsafe { - std::slice::from_raw_parts_mut(hint.as_mut_ptr() as *mut G1Projective, hint.len()) - }; + // SAFETY: ArkG1 is repr(transparent) over G1Projective. + let rlc_row_commitments: &mut [G1Projective] = unsafe { + std::slice::from_raw_parts_mut( + rlc_hint.as_mut_ptr() as *mut G1Projective, + rlc_hint.len(), + ) + }; + + // Combine each hint into the accumulator without forcing all hints to length `num_rows`. + // This avoids O(num_rows) work per polynomial when the hint is much smaller (e.g. program image). + for (coeff, hint) in coeffs.iter().zip(hints.into_iter()) { + if coeff.is_zero() { + continue; + } + let len = hint.len().min(num_rows); + if len == 0 { + continue; + } - let rlc_row_commitments: &[G1Projective] = unsafe { - std::slice::from_raw_parts(rlc_hint.as_ptr() as *const G1Projective, rlc_hint.len()) + // SAFETY: ArkG1 is repr(transparent) over G1Projective. + let hint_rows: &[G1Projective] = unsafe { + std::slice::from_raw_parts(hint.as_ptr() as *const G1Projective, hint.len()) }; - let _span = trace_span!("vector_scalar_mul_add_gamma_g1_online"); + let _span = trace_span!("vector_add_scalar_mul_g1_online"); let _enter = _span.enter(); - // Scales the row commitments for the current polynomial by - // its coefficient - jolt_optimizations::vector_scalar_mul_add_gamma_g1_online( - row_commitments, + // Accumulate: rlc[i] += coeff * hint[i] for i in [0..len) + jolt_optimizations::vector_add_scalar_mul_g1_online( + &mut rlc_row_commitments[..len], + &hint_rows[..len], *coeff, - rlc_row_commitments, ); - - let _ = std::mem::replace(&mut rlc_hint, hint); } rlc_hint diff --git a/jolt-core/src/poly/commitment/dory/dory_globals.rs b/jolt-core/src/poly/commitment/dory/dory_globals.rs index c4c2ebe421..8edef2e567 100644 --- a/jolt-core/src/poly/commitment/dory/dory_globals.rs +++ b/jolt-core/src/poly/commitment/dory/dory_globals.rs @@ -151,7 +151,17 @@ static mut UNTRUSTED_ADVICE_T: OnceLock = OnceLock::new(); static mut UNTRUSTED_ADVICE_MAX_NUM_ROWS: OnceLock = OnceLock::new(); static mut UNTRUSTED_ADVICE_NUM_COLUMNS: OnceLock = OnceLock::new(); -// Context tracking: 0=Main, 1=TrustedAdvice, 2=UntrustedAdvice +// Bytecode globals +static mut BYTECODE_T: OnceLock = OnceLock::new(); +static mut BYTECODE_MAX_NUM_ROWS: OnceLock = OnceLock::new(); +static mut BYTECODE_NUM_COLUMNS: OnceLock = OnceLock::new(); + +// Program image globals (committed initial RAM image) +static mut PROGRAM_IMAGE_T: OnceLock = OnceLock::new(); +static mut PROGRAM_IMAGE_MAX_NUM_ROWS: OnceLock = OnceLock::new(); +static mut PROGRAM_IMAGE_NUM_COLUMNS: OnceLock = OnceLock::new(); + +// Context tracking: 0=Main, 1=TrustedAdvice, 2=UntrustedAdvice, 3=Bytecode, 4=ProgramImage static CURRENT_CONTEXT: AtomicU8 = AtomicU8::new(0); // Layout tracking: 0=CycleMajor, 1=AddressMajor @@ -163,6 +173,8 @@ pub enum DoryContext { Main = 0, TrustedAdvice = 1, UntrustedAdvice = 2, + Bytecode = 3, + ProgramImage = 4, } impl From for DoryContext { @@ -171,6 +183,8 @@ impl From for DoryContext { 0 => DoryContext::Main, 1 => DoryContext::TrustedAdvice, 2 => DoryContext::UntrustedAdvice, + 3 => DoryContext::Bytecode, + 4 => DoryContext::ProgramImage, _ => panic!("Invalid DoryContext value: {value}"), } } @@ -190,6 +204,260 @@ impl Drop for DoryContextGuard { pub struct DoryGlobals; impl DoryGlobals { + /// Initialize Bytecode context so its `num_columns` matches Main's `sigma_main`. + /// + /// This is required for committed-bytecode Stage 8 folding when `sigma_main > sigma_bytecode`: + /// we commit bytecode chunk polynomials using the Main matrix width (more columns, fewer rows), + /// so they embed as a top block of rows in the Main matrix when extra cycle variables are fixed to 0. + pub fn initialize_bytecode_context_for_main_sigma( + k_chunk: usize, + bytecode_len: usize, + log_k_chunk: usize, + log_t: usize, + ) -> Option<()> { + let (sigma_main, _) = Self::main_sigma_nu(log_k_chunk, log_t); + let num_columns = 1usize << sigma_main; + let total_size = k_chunk * bytecode_len; + + assert!( + total_size % num_columns == 0, + "bytecode matrix width {num_columns} must divide total_size {total_size}" + ); + let num_rows = total_size / num_columns; + + // If already initialized, ensure it matches (avoid silently ignoring OnceCell::set failures). + #[allow(static_mut_refs)] + unsafe { + if let (Some(existing_cols), Some(existing_rows), Some(existing_t)) = ( + BYTECODE_NUM_COLUMNS.get(), + BYTECODE_MAX_NUM_ROWS.get(), + BYTECODE_T.get(), + ) { + assert_eq!(*existing_cols, num_columns); + assert_eq!(*existing_rows, num_rows); + assert_eq!(*existing_t, bytecode_len); + return Some(()); + } + } + + Self::set_num_columns_for_context(num_columns, DoryContext::Bytecode); + Self::set_T_for_context(bytecode_len, DoryContext::Bytecode); + Self::set_max_num_rows_for_context(num_rows, DoryContext::Bytecode); + Some(()) + } + + /// Initialize Bytecode context with MAIN-matrix dimensions for CycleMajor Stage 8 embedding. + /// + /// This is used when committing bytecode for CycleMajor layout with T > bytecode_len. + /// The bytecode polynomial is padded to `k_chunk * max_trace_len` coefficients so that + /// its row-commitment hints match the main matrix structure exactly. + /// + /// **Key difference from `initialize_bytecode_context_for_main_sigma`:** + /// - Uses `max_trace_len` (main T) for total size, not `bytecode_len` + /// - This ensures bytecode row indices match main row indices for CycleMajor + pub fn initialize_bytecode_context_with_main_dimensions( + k_chunk: usize, + max_trace_len: usize, + log_k_chunk: usize, + ) -> Option<()> { + let log_t = max_trace_len.log_2(); + let (sigma_main, _) = Self::main_sigma_nu(log_k_chunk, log_t); + let num_columns = 1usize << sigma_main; + let total_size = k_chunk * max_trace_len; + + assert!( + total_size % num_columns == 0, + "bytecode matrix width {num_columns} must divide total_size {total_size}" + ); + let num_rows = total_size / num_columns; + + // If already initialized, ensure it matches (avoid silently ignoring OnceCell::set failures). + #[allow(static_mut_refs)] + unsafe { + if let (Some(existing_cols), Some(existing_rows), Some(existing_t)) = ( + BYTECODE_NUM_COLUMNS.get(), + BYTECODE_MAX_NUM_ROWS.get(), + BYTECODE_T.get(), + ) { + assert_eq!(*existing_cols, num_columns); + assert_eq!(*existing_rows, num_rows); + assert_eq!(*existing_t, max_trace_len); + return Some(()); + } + } + + Self::set_num_columns_for_context(num_columns, DoryContext::Bytecode); + Self::set_T_for_context(max_trace_len, DoryContext::Bytecode); + Self::set_max_num_rows_for_context(num_rows, DoryContext::Bytecode); + Some(()) + } + + /// Initialize ProgramImage context so its `num_columns` matches Main's `sigma_main`. + /// + /// This is used so that tier-1 row-commitment hints can be combined into the Main-context + /// batch opening hint in Stage 8 (mirrors the committed-bytecode strategy). + pub fn initialize_program_image_context_for_main_sigma( + padded_len_words: usize, + max_log_k_chunk: usize, + max_log_t_any: usize, + ) -> Option<()> { + let (sigma_main, _) = Self::main_sigma_nu(max_log_k_chunk, max_log_t_any); + let num_columns = 1usize << sigma_main; + let k_chunk = 1usize << max_log_k_chunk; + + if num_columns <= padded_len_words { + assert!( + padded_len_words % num_columns == 0, + "program-image matrix width {num_columns} must divide padded_len_words {padded_len_words}" + ); + // Match the Main-context K so AddressMajor trace-dense embedding (stride-by-K columns) + // uses the correct `cycles_per_row`. + let total_size = k_chunk * padded_len_words; + debug_assert!( + total_size.is_power_of_two(), + "expected K*T to be power-of-two" + ); + let num_rows = total_size / num_columns; + + // If already initialized, ensure it matches (avoid silently ignoring OnceCell::set failures). + #[allow(static_mut_refs)] + unsafe { + if let (Some(existing_cols), Some(existing_rows), Some(existing_t)) = ( + PROGRAM_IMAGE_NUM_COLUMNS.get(), + PROGRAM_IMAGE_MAX_NUM_ROWS.get(), + PROGRAM_IMAGE_T.get(), + ) { + assert_eq!(*existing_cols, num_columns); + assert_eq!(*existing_rows, num_rows); + assert_eq!(*existing_t, padded_len_words); + return Some(()); + } + } + + Self::set_num_columns_for_context(num_columns, DoryContext::ProgramImage); + Self::set_T_for_context(padded_len_words, DoryContext::ProgramImage); + Self::set_max_num_rows_for_context(num_rows, DoryContext::ProgramImage); + } else { + // Fallback: balanced dimensions for the program image itself. + Self::initialize_context(1, padded_len_words, DoryContext::ProgramImage, None); + } + Some(()) + } + + /// Initialize the **ProgramImage** context using an explicit `num_columns` (i.e. fixed sigma) + /// and an explicit `k_chunk` (Main's lane/address chunk size). + /// + /// This is used so program-image tier-1 row-commitment hints can be combined into the + /// Main-context batch opening hint in Stage 8. + /// + /// **Important**: We intentionally size the ProgramImage context so that + /// `k_from_matrix_shape() == k_chunk`. This makes the AddressMajor "trace-dense" embedding + /// (which occupies evenly-spaced columns with stride K) consistent between ProgramImage and + /// Main contexts. + /// + /// Requirements: + /// - `k_chunk` must be a power of two + /// - `num_columns` must be a power of two + /// - `padded_len_words` must be a power of two + /// - `k_chunk * padded_len_words >= num_columns` (so `num_rows >= 1`) + pub fn initialize_program_image_context_with_num_columns( + k_chunk: usize, + padded_len_words: usize, + num_columns: usize, + ) -> Option<()> { + assert!(padded_len_words.is_power_of_two()); + assert!(padded_len_words > 0); + assert!(k_chunk.is_power_of_two()); + assert!(k_chunk > 0); + assert!(num_columns.is_power_of_two()); + let total_size = k_chunk * padded_len_words; + assert!( + total_size >= num_columns, + "program-image K*T ({total_size}) must be >= num_columns ({num_columns})" + ); + assert!( + total_size % num_columns == 0, + "program-image K*T ({total_size}) must be divisible by num_columns ({num_columns})" + ); + let num_rows = total_size / num_columns; + + // If already initialized, ensure it matches (avoid silently ignoring OnceCell::set failures). + #[allow(static_mut_refs)] + unsafe { + if let (Some(existing_cols), Some(existing_rows), Some(existing_t)) = ( + PROGRAM_IMAGE_NUM_COLUMNS.get(), + PROGRAM_IMAGE_MAX_NUM_ROWS.get(), + PROGRAM_IMAGE_T.get(), + ) { + assert_eq!(*existing_cols, num_columns); + assert_eq!(*existing_rows, num_rows); + assert_eq!(*existing_t, padded_len_words); + return Some(()); + } + } + + Self::set_num_columns_for_context(num_columns, DoryContext::ProgramImage); + Self::set_T_for_context(padded_len_words, DoryContext::ProgramImage); + Self::set_max_num_rows_for_context(num_rows, DoryContext::ProgramImage); + Some(()) + } + + /// Initialize the **Main** context using an explicit `num_columns` (i.e. fixed sigma). + /// + /// This is used in `ProgramMode::Committed` so that the Main context uses the same column + /// dimension as trusted bytecode commitments, which were derived under a sigma computed from a + /// "max trace length" bound (to support batching/folding). + /// + /// # Safety / correctness notes + /// - Requires `num_columns` to be a power of two. + /// - Requires `(K * T) % num_columns == 0` so `num_rows` is integral. + /// - If the Main context was already initialized, this asserts the dimensions match to avoid + /// silently ignoring OnceLock::set failures. + pub fn initialize_main_context_with_num_columns( + K: usize, + T: usize, + num_columns: usize, + layout: Option, + ) -> Option<()> { + assert!( + num_columns.is_power_of_two(), + "num_columns must be a power of two" + ); + let total_size = K * T; + assert!( + total_size % num_columns == 0, + "main matrix width {num_columns} must divide total_size {total_size}" + ); + let num_rows = total_size / num_columns; + + // If already initialized, ensure it matches (avoid silently ignoring OnceCell::set failures). + #[allow(static_mut_refs)] + unsafe { + if let (Some(existing_cols), Some(existing_rows), Some(existing_t)) = + (NUM_COLUMNS.get(), MAX_NUM_ROWS.get(), GLOBAL_T.get()) + { + assert_eq!(*existing_cols, num_columns); + assert_eq!(*existing_rows, num_rows); + assert_eq!(*existing_t, T); + if let Some(l) = layout { + CURRENT_LAYOUT.store(l as u8, Ordering::SeqCst); + } + CURRENT_CONTEXT.store(DoryContext::Main as u8, Ordering::SeqCst); + return Some(()); + } + } + + Self::set_num_columns_for_context(num_columns, DoryContext::Main); + Self::set_T_for_context(T, DoryContext::Main); + Self::set_max_num_rows_for_context(num_rows, DoryContext::Main); + + if let Some(l) = layout { + CURRENT_LAYOUT.store(l as u8, Ordering::SeqCst); + } + CURRENT_CONTEXT.store(DoryContext::Main as u8, Ordering::SeqCst); + Some(()) + } + /// Split `total_vars` into a *balanced* pair `(sigma, nu)` where: /// - **sigma** is the number of **column** variables /// - **nu** is the number of **row** variables @@ -209,6 +477,20 @@ impl DoryGlobals { Self::balanced_sigma_nu(log_k_chunk + log_t) } + /// Returns the (sigma, nu) for the **initialized** Main context, if available. + /// + /// This is useful in committed mode where the Main context may be initialized with + /// an explicit `num_columns` override, making `(sigma, nu)` differ from the balanced + /// split implied by `log_k_chunk + log_t`. + pub fn try_get_main_sigma_nu() -> Option<(usize, usize)> { + #[allow(static_mut_refs)] + unsafe { + let num_columns = NUM_COLUMNS.get()?; + let num_rows = MAX_NUM_ROWS.get()?; + Some((num_columns.log_2(), num_rows.log_2())) + } + } + /// Computes balanced `(sigma, nu)` dimensions directly from a max advice byte budget. /// /// - `max_advice_size_bytes` is interpreted as bytes of 64-bit words. @@ -251,7 +533,6 @@ impl DoryGlobals { /// Set the Dory matrix layout directly (test-only). /// /// In production code, prefer passing the layout to `initialize_context` instead. - #[cfg(test)] pub fn set_layout(layout: DoryLayout) { CURRENT_LAYOUT.store(layout as u8, Ordering::SeqCst); } @@ -305,6 +586,12 @@ impl DoryGlobals { DoryContext::UntrustedAdvice => { let _ = UNTRUSTED_ADVICE_MAX_NUM_ROWS.set(max_num_rows); } + DoryContext::Bytecode => { + let _ = BYTECODE_MAX_NUM_ROWS.set(max_num_rows); + } + DoryContext::ProgramImage => { + let _ = PROGRAM_IMAGE_MAX_NUM_ROWS.set(max_num_rows); + } } } } @@ -321,6 +608,12 @@ impl DoryGlobals { DoryContext::UntrustedAdvice => *UNTRUSTED_ADVICE_MAX_NUM_ROWS .get() .expect("untrusted_advice max_num_rows not initialized"), + DoryContext::Bytecode => *BYTECODE_MAX_NUM_ROWS + .get() + .expect("bytecode max_num_rows not initialized"), + DoryContext::ProgramImage => *PROGRAM_IMAGE_MAX_NUM_ROWS + .get() + .expect("program_image max_num_rows not initialized"), } } } @@ -338,6 +631,12 @@ impl DoryGlobals { DoryContext::UntrustedAdvice => { let _ = UNTRUSTED_ADVICE_NUM_COLUMNS.set(num_columns); } + DoryContext::Bytecode => { + let _ = BYTECODE_NUM_COLUMNS.set(num_columns); + } + DoryContext::ProgramImage => { + let _ = PROGRAM_IMAGE_NUM_COLUMNS.set(num_columns); + } } } } @@ -354,6 +653,12 @@ impl DoryGlobals { DoryContext::UntrustedAdvice => *UNTRUSTED_ADVICE_NUM_COLUMNS .get() .expect("untrusted_advice num_columns not initialized"), + DoryContext::Bytecode => *BYTECODE_NUM_COLUMNS + .get() + .expect("bytecode num_columns not initialized"), + DoryContext::ProgramImage => *PROGRAM_IMAGE_NUM_COLUMNS + .get() + .expect("program_image num_columns not initialized"), } } } @@ -371,6 +676,12 @@ impl DoryGlobals { DoryContext::UntrustedAdvice => { let _ = UNTRUSTED_ADVICE_T.set(t); } + DoryContext::Bytecode => { + let _ = BYTECODE_T.set(t); + } + DoryContext::ProgramImage => { + let _ = PROGRAM_IMAGE_T.set(t); + } } } } @@ -387,6 +698,10 @@ impl DoryGlobals { DoryContext::UntrustedAdvice => *UNTRUSTED_ADVICE_T .get() .expect("untrusted_advice t not initialized"), + DoryContext::Bytecode => *BYTECODE_T.get().expect("bytecode t not initialized"), + DoryContext::ProgramImage => *PROGRAM_IMAGE_T + .get() + .expect("program_image t not initialized"), } } } @@ -414,7 +729,7 @@ impl DoryGlobals { /// # Arguments /// * `K` - Maximum address space size (K in OneHot polynomials) /// * `T` - Maximum trace length (cycle count) - /// * `context` - The Dory context to initialize (Main, TrustedAdvice, or UntrustedAdvice) + /// * `context` - The Dory context to initialize (Main, TrustedAdvice, UntrustedAdvice, Bytecode, ProgramImage) /// * `layout` - Optional layout for the Dory matrix. Only applies to Main context. /// If `Some(layout)`, sets the layout. If `None`, leaves the existing layout /// unchanged (defaults to `CycleMajor` after `reset()`). Ignored for advice contexts. @@ -466,6 +781,16 @@ impl DoryGlobals { let _ = UNTRUSTED_ADVICE_T.take(); let _ = UNTRUSTED_ADVICE_MAX_NUM_ROWS.take(); let _ = UNTRUSTED_ADVICE_NUM_COLUMNS.take(); + + // Reset bytecode globals + let _ = BYTECODE_T.take(); + let _ = BYTECODE_MAX_NUM_ROWS.take(); + let _ = BYTECODE_NUM_COLUMNS.take(); + + // Reset program image globals + let _ = PROGRAM_IMAGE_T.take(); + let _ = PROGRAM_IMAGE_MAX_NUM_ROWS.take(); + let _ = PROGRAM_IMAGE_NUM_COLUMNS.take(); } // Reset context to Main diff --git a/jolt-core/src/poly/commitment/dory/wrappers.rs b/jolt-core/src/poly/commitment/dory/wrappers.rs index 431387d7c2..ba784da898 100644 --- a/jolt-core/src/poly/commitment/dory/wrappers.rs +++ b/jolt-core/src/poly/commitment/dory/wrappers.rs @@ -8,10 +8,11 @@ use crate::{ multilinear_polynomial::{MultilinearPolynomial, PolynomialEvaluation}, }, transcripts::{AppendToTranscript, Transcript}, + utils::small_scalar::SmallScalar, }; use ark_bn254::Fr; use ark_ec::CurveGroup; -use ark_ff::Zero; +use ark_ff::{One, Zero}; use dory::{ error::DoryError, primitives::{ @@ -108,9 +109,6 @@ impl DoryPolynomial for MultilinearPolynomial { impl MultilinearLagrange for MultilinearPolynomial { fn vector_matrix_product(&self, left_vec: &[ArkFr], nu: usize, sigma: usize) -> Vec { - use crate::utils::small_scalar::SmallScalar; - use ark_ff::One; - let num_cols = 1usize << sigma; let num_rows = 1usize << nu; @@ -227,28 +225,53 @@ where let dory_layout = DoryGlobals::get_layout(); // Dense polynomials (all scalar variants except OneHot/RLC) are committed row-wise. - // Under AddressMajor, dense coefficients occupy evenly-spaced columns, so each row - // commitment uses `cycles_per_row` bases (one per occupied column). - let (dense_affine_bases, dense_chunk_size): (Vec<_>, usize) = match (dory_context, dory_layout) - { - (DoryContext::Main, DoryLayout::AddressMajor) => { - let cycles_per_row = DoryGlobals::address_major_cycles_per_row(); - let bases: Vec<_> = g1_slice - .par_iter() - .take(row_len) - .step_by(row_len / cycles_per_row) - .map(|g| g.0.into_affine()) - .collect(); - (bases, cycles_per_row) - } - _ => ( + // + // In `Main` + `AddressMajor`, we have two *representations* in this repo: + // - **Trace-dense**: length == T (e.g., `RdInc`, `RamInc`). These are embedded into the + // main matrix by occupying evenly-spaced columns, so each row commitment uses + // `cycles_per_row` bases (one per occupied column). + // - **Matrix-dense**: length == K*T (e.g., bytecode chunk polynomials). These occupy the + // full matrix and must use the full `row_len` bases. + let is_trace_dense = match poly { + MultilinearPolynomial::LargeScalars(p) => p.Z.len() == DoryGlobals::get_T(), + MultilinearPolynomial::BoolScalars(p) => p.coeffs.len() == DoryGlobals::get_T(), + MultilinearPolynomial::U8Scalars(p) => p.coeffs.len() == DoryGlobals::get_T(), + MultilinearPolynomial::U16Scalars(p) => p.coeffs.len() == DoryGlobals::get_T(), + MultilinearPolynomial::U32Scalars(p) => p.coeffs.len() == DoryGlobals::get_T(), + MultilinearPolynomial::U64Scalars(p) => p.coeffs.len() == DoryGlobals::get_T(), + MultilinearPolynomial::U128Scalars(p) => p.coeffs.len() == DoryGlobals::get_T(), + MultilinearPolynomial::I64Scalars(p) => p.coeffs.len() == DoryGlobals::get_T(), + MultilinearPolynomial::I128Scalars(p) => p.coeffs.len() == DoryGlobals::get_T(), + MultilinearPolynomial::S128Scalars(p) => p.coeffs.len() == DoryGlobals::get_T(), + MultilinearPolynomial::OneHot(_) | MultilinearPolynomial::RLC(_) => false, + }; + + // Treat ProgramImage like Main here when its context is sized to match Main's K. + // This enables AddressMajor "trace-dense" embedding (stride-by-K columns) for the + // committed program-image polynomial. + let is_trace_dense_addr_major = + matches!(dory_context, DoryContext::Main | DoryContext::ProgramImage) + && dory_layout == DoryLayout::AddressMajor + && is_trace_dense; + + let (dense_affine_bases, dense_chunk_size): (Vec<_>, usize) = if is_trace_dense_addr_major { + let cycles_per_row = DoryGlobals::address_major_cycles_per_row(); + let bases: Vec<_> = g1_slice + .par_iter() + .take(row_len) + .step_by(row_len / cycles_per_row) + .map(|g| g.0.into_affine()) + .collect(); + (bases, cycles_per_row) + } else { + ( g1_slice .par_iter() .take(row_len) .map(|g| g.0.into_affine()) .collect(), row_len, - ), + ) }; let result: Vec = match poly { diff --git a/jolt-core/src/poly/opening_proof.rs b/jolt-core/src/poly/opening_proof.rs index 78ad812de6..c5d89b49c5 100644 --- a/jolt-core/src/poly/opening_proof.rs +++ b/jolt-core/src/poly/opening_proof.rs @@ -152,12 +152,19 @@ pub enum SumcheckId { RegistersClaimReduction, RegistersReadWriteChecking, RegistersValEvaluation, + BytecodeReadRafAddressPhase, BytecodeReadRaf, + BooleanityAddressPhase, Booleanity, AdviceClaimReductionCyclePhase, AdviceClaimReduction, + BytecodeClaimReductionCyclePhase, + BytecodeClaimReduction, IncClaimReduction, HammingWeightClaimReduction, + /// Claim reduction binding the staged program-image (initial RAM) scalar contribution(s) + /// to the committed `CommittedPolynomial::ProgramImageInit` polynomial. + ProgramImageClaimReduction, } #[derive(Hash, PartialEq, Eq, Copy, Clone, Debug, PartialOrd, Ord, Allocative)] @@ -255,6 +262,7 @@ impl DoryOpeningState { rlc_streaming_data: Arc, mut opening_hints: HashMap, advice_polys: HashMap>, + bytecode_T: usize, ) -> (MultilinearPolynomial, PCS::OpeningProofHint) { // Accumulate gamma coefficients per polynomial let mut rlc_map = BTreeMap::new(); @@ -272,6 +280,7 @@ impl DoryOpeningState { poly_ids.clone(), &coeffs, advice_polys, + bytecode_T, )); let hints: Vec = rlc_map diff --git a/jolt-core/src/poly/rlc_polynomial.rs b/jolt-core/src/poly/rlc_polynomial.rs index 47a68c231e..a6de9ea83d 100644 --- a/jolt-core/src/poly/rlc_polynomial.rs +++ b/jolt-core/src/poly/rlc_polynomial.rs @@ -4,10 +4,12 @@ use crate::poly::multilinear_polynomial::MultilinearPolynomial; use crate::utils::accumulation::Acc6S; use crate::utils::math::{s64_from_diff_u64s, Math}; use crate::utils::thread::unsafe_allocate_zero_vec; +use crate::zkvm::bytecode::chunks::{for_each_active_lane_value, total_lanes, ActiveLaneValue}; use crate::zkvm::config::OneHotParams; use crate::zkvm::instruction::LookupQuery; +use crate::zkvm::program::ProgramPreprocessing; use crate::zkvm::ram::remap_address; -use crate::zkvm::{bytecode::BytecodePreprocessing, witness::CommittedPolynomial}; +use crate::zkvm::witness::CommittedPolynomial; use allocative::Allocative; use common::constants::XLEN; use common::jolt_device::MemoryLayout; @@ -20,10 +22,136 @@ use tracer::{instruction::Cycle, LazyTraceIterator}; #[derive(Clone, Debug)] pub struct RLCStreamingData { - pub bytecode: Arc, + pub program: Arc, pub memory_layout: MemoryLayout, } +/// Computes the bytecode chunk polynomial contribution to a vector-matrix product. +/// +/// This is a standalone version of the bytecode VMP computation that can be used +/// by external callers (e.g., GPU prover) without needing a full `StreamingRLCContext`. +/// +/// # Arguments +/// * `result` - Output buffer to accumulate contributions into +/// * `left_vec` - Left vector for the vector-matrix product (length >= num_rows) +/// * `num_columns` - Number of columns in the Dory matrix +/// * `bytecode_polys` - List of (chunk_index, coefficient) pairs for the RLC +/// * `program` - Program preprocessing data +/// * `one_hot_params` - One-hot parameters (contains k_chunk) +/// * `bytecode_T` - The T value used for bytecode coefficient indexing (from TrustedProgramCommitments) +pub fn compute_bytecode_vmp_contribution( + result: &mut [F], + left_vec: &[F], + num_columns: usize, + bytecode_polys: &[(usize, F)], + program: &ProgramPreprocessing, + one_hot_params: &OneHotParams, + bytecode_T: usize, +) { + if bytecode_polys.is_empty() { + return; + } + + let layout = DoryGlobals::get_layout(); + let k_chunk = one_hot_params.k_chunk; + let bytecode_len = program.bytecode_len(); + let bytecode_cols = num_columns; + let total = total_lanes(); + let num_chunks = total.div_ceil(k_chunk); + debug_assert!( + bytecode_cols.is_power_of_two(), + "Dory num_columns must be power-of-two (got {bytecode_cols})" + ); + let col_shift = bytecode_cols.trailing_zeros(); + let col_mask = bytecode_cols - 1; + + // Use the passed bytecode_T for coefficient indexing. + // This is the T value used when the bytecode was committed: + // - CycleMajor: max_trace_len (main-matrix dimensions) + // - AddressMajor: bytecode_len (bytecode dimensions) + let index_T = bytecode_T; + + debug_assert!( + k_chunk * bytecode_len >= bytecode_cols, + "bytecode_len*k_chunk must cover at least one full row: (k_chunk*bytecode_len)={} < num_columns={}", + k_chunk * bytecode_len, + bytecode_cols + ); + + // Build a dense coefficient table per chunk so we can invert the loops: + // iterate cycles once and only touch lanes that are nonzero for that instruction. + let mut coeff_by_chunk: Vec = unsafe_allocate_zero_vec(num_chunks); + let mut any_nonzero = false; + for (chunk_idx, coeff) in bytecode_polys.iter() { + if *chunk_idx < num_chunks && !coeff.is_zero() { + coeff_by_chunk[*chunk_idx] += *coeff; + any_nonzero = true; + } + } + if !any_nonzero { + return; + } + + // Parallelize over cycles with thread-local accumulation. + let bytecode_contrib: Vec = program.instructions[..bytecode_len] + .par_iter() + .enumerate() + .fold( + || unsafe_allocate_zero_vec(bytecode_cols), + |mut acc, (cycle, instr)| { + for_each_active_lane_value::(instr, |global_lane, lane_val| { + let chunk_idx = global_lane / k_chunk; + if chunk_idx >= num_chunks { + return; + } + let coeff = coeff_by_chunk[chunk_idx]; + if coeff.is_zero() { + return; + } + let lane = global_lane % k_chunk; + + // Use layout-conditional indexing. + let global_index = match layout { + DoryLayout::CycleMajor => lane * index_T + cycle, + DoryLayout::AddressMajor => cycle * k_chunk + lane, + }; + let row_index = global_index >> col_shift; + if row_index >= left_vec.len() { + return; + } + let left = left_vec[row_index]; + if left.is_zero() { + return; + } + let col_index = global_index & col_mask; + + let base = left * coeff; + match lane_val { + ActiveLaneValue::One => { + acc[col_index] += base; + } + ActiveLaneValue::Scalar(v) => { + acc[col_index] += base * v; + } + } + }); + acc + }, + ) + .reduce( + || unsafe_allocate_zero_vec(bytecode_cols), + |mut a, b| { + a.iter_mut().zip(b.iter()).for_each(|(x, y)| *x += *y); + a + }, + ); + + result + .par_iter_mut() + .zip(bytecode_contrib.par_iter()) + .for_each(|(r, c)| *r += *c); +} + /// Source of trace data for streaming VMV computation. #[derive(Clone, Debug)] pub enum TraceSource { @@ -56,9 +184,17 @@ impl TraceSource { pub struct StreamingRLCContext { pub dense_polys: Vec<(CommittedPolynomial, F)>, pub onehot_polys: Vec<(CommittedPolynomial, F)>, - /// Advice polynomials with their RLC coefficients. + /// Bytecode chunk polynomials with their RLC coefficients. + pub bytecode_polys: Vec<(usize, F)>, + /// The T value used for bytecode coefficient indexing (from TrustedProgramCommitments). + /// For CycleMajor: max_trace_len (main-matrix dimensions). + /// For AddressMajor: bytecode_len (bytecode dimensions). + pub bytecode_T: usize, + /// Advice polynomials with their RLC coefficients and IDs. /// These are NOT streamed from trace - they're passed in directly. - pub advice_polys: Vec<(F, MultilinearPolynomial)>, + /// Format: (poly_id, coeff, polynomial) - ID is needed to determine + /// commitment dimensions (ProgramImageInit uses Main's sigma). + pub advice_polys: Vec<(CommittedPolynomial, F, MultilinearPolynomial)>, pub trace_source: TraceSource, pub preprocessing: Arc, pub one_hot_params: OneHotParams, @@ -166,6 +302,7 @@ impl RLCPolynomial { /// * `poly_ids` - List of polynomial identifiers /// * `coefficients` - RLC coefficients for each polynomial /// * `advice_poly_map` - Map of advice polynomial IDs to their actual polynomials + /// * `bytecode_T` - The T value used for bytecode coefficient indexing (from TrustedProgramCommitments) #[tracing::instrument(skip_all)] pub fn new_streaming( one_hot_params: OneHotParams, @@ -174,11 +311,13 @@ impl RLCPolynomial { poly_ids: Vec, coefficients: &[F], mut advice_poly_map: HashMap>, + bytecode_T: usize, ) -> Self { debug_assert_eq!(poly_ids.len(), coefficients.len()); let mut dense_polys = Vec::new(); let mut onehot_polys = Vec::new(); + let mut bytecode_polys = Vec::new(); let mut advice_polys = Vec::new(); for (poly_id, coeff) in poly_ids.iter().zip(coefficients.iter()) { @@ -191,10 +330,22 @@ impl RLCPolynomial { | CommittedPolynomial::RamRa(_) => { onehot_polys.push((*poly_id, *coeff)); } - CommittedPolynomial::TrustedAdvice | CommittedPolynomial::UntrustedAdvice => { - // Advice polynomials are passed in directly (not streamed from trace) + CommittedPolynomial::BytecodeChunk(_) => { + if let CommittedPolynomial::BytecodeChunk(idx) = poly_id { + bytecode_polys.push((*idx, *coeff)); + } + } + CommittedPolynomial::TrustedAdvice + | CommittedPolynomial::UntrustedAdvice + | CommittedPolynomial::ProgramImageInit => { + // "Extra" polynomials are passed in directly (not streamed from trace). + // Today this includes advice polynomials and (in committed mode) the program-image polynomial. if advice_poly_map.contains_key(poly_id) { - advice_polys.push((*coeff, advice_poly_map.remove(poly_id).unwrap())); + advice_polys.push(( + *poly_id, + *coeff, + advice_poly_map.remove(poly_id).unwrap(), + )); } } } @@ -206,6 +357,8 @@ impl RLCPolynomial { streaming_context: Some(Arc::new(StreamingRLCContext { dense_polys, onehot_polys, + bytecode_polys, + bytecode_T, advice_polys, trace_source, preprocessing, @@ -353,9 +506,130 @@ impl RLCPolynomial { // For each advice polynomial, compute its contribution to the result ctx.advice_polys .iter() - .filter(|(_, advice_poly)| advice_poly.original_len() > 0) - .for_each(|(coeff, advice_poly)| { + .filter(|(_, _, advice_poly)| advice_poly.original_len() > 0) + .for_each(|(poly_id, coeff, advice_poly)| { let advice_len = advice_poly.original_len(); + if *poly_id == CommittedPolynomial::ProgramImageInit { + // ProgramImageInit is embedded like a trace-dense polynomial (missing lane variables). + // In AddressMajor this occupies evenly-spaced columns (stride-by-K), not a contiguous block. + match DoryGlobals::get_layout() { + DoryLayout::CycleMajor => { + // Contiguous prefix block: full columns, limited rows. + debug_assert!( + advice_len % num_columns == 0, + "ProgramImageInit len ({advice_len}) must be divisible by num_columns ({num_columns})" + ); + // Avoid O(num_columns) work when the program image is much smaller than the + // main matrix width. We only need to visit the actual program-image words; + // the padded tail is identically zero. + // + // For CycleMajor, coefficient index maps as: + // idx = row * num_columns + col + let advice_cols = num_columns; + let max_nonzero_prefix = ctx.preprocessing.program.program_image_words.len(); + let len = max_nonzero_prefix.min(advice_len); + + // Fast path for u64-backed program image (Committed mode). + if let MultilinearPolynomial::U64Scalars(poly) = advice_poly { + for (idx, &word) in poly.coeffs[..len].iter().enumerate() { + if word == 0 { + continue; + } + let row_idx = idx / advice_cols; + if row_idx >= left_vec.len() { + continue; + } + let left = left_vec[row_idx]; + if left.is_zero() { + continue; + } + let col_idx = idx % advice_cols; + result[col_idx] += left * *coeff * F::from_u64(word); + } + } else { + // Fallback: generic coefficient access (should be rare). + for idx in 0..len { + let row_idx = idx / advice_cols; + if row_idx >= left_vec.len() { + continue; + } + let left = left_vec[row_idx]; + if left.is_zero() { + continue; + } + let advice_val = advice_poly.get_coeff(idx); + if advice_val.is_zero() { + continue; + } + let col_idx = idx % advice_cols; + result[col_idx] += left * *coeff * advice_val; + } + } + } + DoryLayout::AddressMajor => { + // Strided columns: lane variables are the low bits, so selecting lane=0 + // hits columns {0, K, 2K, ...}. + let k_chunk = DoryGlobals::k_from_matrix_shape(); + let cycles_per_row = DoryGlobals::address_major_cycles_per_row(); // == num_columns / K + debug_assert_eq!( + num_columns, + k_chunk * cycles_per_row, + "Expected num_columns == K * cycles_per_row in AddressMajor" + ); + debug_assert!( + advice_len % cycles_per_row == 0, + "ProgramImageInit len ({advice_len}) must be divisible by cycles_per_row ({cycles_per_row})" + ); + // Avoid O(cycles_per_row) work when the program image is small. + // For AddressMajor trace-dense embedding, coefficient index maps as: + // idx = row * cycles_per_row + offset + // and it contributes to main column: + // col = offset * K + let max_nonzero_prefix = ctx.preprocessing.program.program_image_words.len(); + let len = max_nonzero_prefix.min(advice_len); + + if let MultilinearPolynomial::U64Scalars(poly) = advice_poly { + for (idx, &word) in poly.coeffs[..len].iter().enumerate() { + if word == 0 { + continue; + } + let row_idx = idx / cycles_per_row; + if row_idx >= left_vec.len() { + continue; + } + let left = left_vec[row_idx]; + if left.is_zero() { + continue; + } + let offset = idx % cycles_per_row; + let col_idx = offset * k_chunk; + result[col_idx] += left * *coeff * F::from_u64(word); + } + } else { + for idx in 0..len { + let row_idx = idx / cycles_per_row; + if row_idx >= left_vec.len() { + continue; + } + let left = left_vec[row_idx]; + if left.is_zero() { + continue; + } + let advice_val = advice_poly.get_coeff(idx); + if advice_val.is_zero() { + continue; + } + let offset = idx % cycles_per_row; + let col_idx = offset * k_chunk; + result[col_idx] += left * *coeff * advice_val; + } + } + } + } + return; + } + + // Other advice polynomials use balanced dimensions and embed as a top-left block. let advice_vars = advice_len.log_2(); let (sigma_a, nu_a) = DoryGlobals::balanced_sigma_nu(advice_vars); let advice_cols = 1usize << sigma_a; @@ -363,19 +637,14 @@ impl RLCPolynomial { debug_assert!( advice_cols <= num_columns, - "Advice columns (2^{{sigma_a}}={advice_cols}) must fit in main num_columns={num_columns}; \ + "Advice columns ({advice_cols}) must fit in main num_columns={num_columns}; \ guardrail in gen_from_trace should ensure sigma_main >= sigma_a." ); - // Only the top-left block contributes: rows [0..advice_rows), cols [0..advice_cols) let effective_rows = advice_rows.min(left_vec.len()); - - // Compute column contributions: for each column, sum contributions from all rows - // Note: advice_len is always advice_cols * advice_rows (advice size must be power of 2) let column_contributions: Vec = (0..advice_cols) .into_par_iter() .map(|col_idx| { - // For this column, sum contributions from all non-zero rows left_vec[..effective_rows] .iter() .enumerate() @@ -389,7 +658,6 @@ guardrail in gen_from_trace should ensure sigma_main >= sigma_a." }) .collect(); - // Add column contributions to result in parallel result[..advice_cols] .par_iter_mut() .zip(column_contributions.par_iter()) @@ -399,6 +667,27 @@ guardrail in gen_from_trace should ensure sigma_main >= sigma_a." }); } + /// Adds the bytecode chunk polynomial contribution to the vector-matrix-vector product result. + /// + /// Bytecode chunk polynomials are embedded in the top-left block by fixing the extra cycle + /// variables to 0, so we only iterate cycles in `[0, bytecode_len)`. + fn vmp_bytecode_contribution( + result: &mut [F], + left_vec: &[F], + num_columns: usize, + ctx: &StreamingRLCContext, + ) { + compute_bytecode_vmp_contribution( + result, + left_vec, + num_columns, + &ctx.bytecode_polys, + &ctx.preprocessing.program, + &ctx.one_hot_params, + ctx.bytecode_T, + ); + } + /// Streaming VMP implementation that generates rows on-demand from trace. /// Achieves O(sqrt(n)) space complexity by lazily generating the witness. /// Single pass through trace for both dense and one-hot polynomials. @@ -450,6 +739,7 @@ guardrail in gen_from_trace should ensure sigma_main >= sigma_a." let mut result = materialized.vector_matrix_product(left_vec); Self::vmp_advice_contribution(&mut result, left_vec, num_columns, ctx); + Self::vmp_bytecode_contribution(&mut result, left_vec, num_columns, ctx); result } @@ -467,7 +757,7 @@ guardrail in gen_from_trace should ensure sigma_main >= sigma_a." // Materialize dense polynomials (RdInc, RamInc) into dense_rlc for (poly_id, coeff) in ctx.dense_polys.iter() { let poly: MultilinearPolynomial = poly_id.generate_witness( - &ctx.preprocessing.bytecode, + &ctx.preprocessing.program, &ctx.preprocessing.memory_layout, trace, Some(&ctx.one_hot_params), @@ -488,7 +778,7 @@ guardrail in gen_from_trace should ensure sigma_main >= sigma_a." let mut one_hot_rlc = Vec::new(); for (poly_id, coeff) in ctx.onehot_polys.iter() { let poly = poly_id.generate_witness( - &ctx.preprocessing.bytecode, + &ctx.preprocessing.program, &ctx.preprocessing.memory_layout, trace, Some(&ctx.one_hot_params), @@ -573,6 +863,7 @@ guardrail in gen_from_trace should ensure sigma_main >= sigma_a." // Advice contribution is small and independent of the trace; add it after the streamed pass. Self::vmp_advice_contribution(&mut result, left_vec, num_columns, ctx); + Self::vmp_bytecode_contribution(&mut result, left_vec, num_columns, ctx); result } @@ -627,6 +918,7 @@ guardrail in gen_from_trace should ensure sigma_main >= sigma_a." // Advice contribution is small and independent of the trace; add it after the streamed pass. Self::vmp_advice_contribution(&mut result, left_vec, num_columns, ctx); + Self::vmp_bytecode_contribution(&mut result, left_vec, num_columns, ctx); result } } @@ -652,8 +944,8 @@ struct VmvSetup<'a, F: JoltField> { row_factors: Vec, /// Folded one-hot tables (coeff * eq_k pre-multiplied) folded_tables: FoldedOneHotTables, - /// Reference to preprocessing data - bytecode: &'a BytecodePreprocessing, + /// Reference to program preprocessing data + program: &'a ProgramPreprocessing, memory_layout: &'a MemoryLayout, /// Reference to one-hot parameters one_hot_params: &'a OneHotParams, @@ -694,7 +986,7 @@ impl<'a, F: JoltField> VmvSetup<'a, F> { ram_inc_coeff, row_factors, folded_tables, - bytecode: &ctx.preprocessing.bytecode, + program: &ctx.preprocessing.program, memory_layout: &ctx.preprocessing.memory_layout, one_hot_params, } @@ -810,7 +1102,7 @@ impl<'a, F: JoltField> VmvSetup<'a, F> { } // Bytecode RA chunks - let pc = self.bytecode.get_pc(cycle); + let pc = self.program.get_pc(cycle); for (i, table) in self.folded_tables.bytecode.iter().enumerate() { let k = self.one_hot_params.bytecode_pc_chunk(pc, i) as usize; inner_sum += *table[k].as_unreduced_ref(); diff --git a/jolt-core/src/poly/shared_ra_polys.rs b/jolt-core/src/poly/shared_ra_polys.rs index 282ed07317..808cccdc88 100644 --- a/jolt-core/src/poly/shared_ra_polys.rs +++ b/jolt-core/src/poly/shared_ra_polys.rs @@ -1,7 +1,7 @@ //! Shared utilities for RA (read-address) polynomials across all families. //! //! This module provides efficient computation of RA indices and G evaluations -//! that are shared across instruction, bytecode, and RAM polynomial families. +//! that are shared across instruction, program, and RAM polynomial families. //! //! ## Design Goals //! @@ -32,9 +32,9 @@ use crate::poly::eq_poly::EqPolynomial; use crate::poly::multilinear_polynomial::{BindingOrder, MultilinearPolynomial, PolynomialBinding}; use crate::utils::thread::drop_in_background_thread; use crate::utils::thread::unsafe_allocate_zero_vec; -use crate::zkvm::bytecode::BytecodePreprocessing; use crate::zkvm::config::OneHotParams; use crate::zkvm::instruction::LookupQuery; +use crate::zkvm::program::ProgramPreprocessing; use crate::zkvm::ram::remap_address; use common::constants::XLEN; use common::jolt_device::MemoryLayout; @@ -43,7 +43,7 @@ use tracer::instruction::Cycle; /// Maximum number of instruction RA chunks (lookup index splits into at most 32 chunks) pub const MAX_INSTRUCTION_D: usize = 32; -/// Maximum number of bytecode RA chunks (PC splits into at most 6 chunks) +/// Maximum number of program RA chunks (PC splits into at most 6 chunks) pub const MAX_BYTECODE_D: usize = 6; /// Maximum number of RAM RA chunks (address splits into at most 8 chunks) pub const MAX_RAM_D: usize = 8; @@ -79,7 +79,7 @@ pub struct RaIndices { /// Instruction RA chunk indices (always present) pub instruction: [u8; MAX_INSTRUCTION_D], /// Bytecode RA chunk indices (always present) - pub bytecode: [u8; MAX_BYTECODE_D], + pub program: [u8; MAX_BYTECODE_D], /// RAM RA chunk indices (None for non-memory cycles) pub ram: [Option; MAX_RAM_D], } @@ -108,7 +108,7 @@ impl Zero for RaIndices { fn is_zero(&self) -> bool { self.instruction.iter().all(|&x| x == 0) - && self.bytecode.iter().all(|&x| x == 0) + && self.program.iter().all(|&x| x == 0) && self.ram.iter().all(|x| x.is_none()) } } @@ -118,7 +118,7 @@ impl RaIndices { #[inline] pub fn from_cycle( cycle: &Cycle, - bytecode: &BytecodePreprocessing, + program: &ProgramPreprocessing, memory_layout: &MemoryLayout, one_hot_params: &OneHotParams, ) -> Self { @@ -150,10 +150,10 @@ impl RaIndices { } // Bytecode indices from PC - let pc = bytecode.get_pc(cycle); - let mut bytecode_arr = [0u8; MAX_BYTECODE_D]; + let pc = program.get_pc(cycle); + let mut program_arr = [0u8; MAX_BYTECODE_D]; for i in 0..one_hot_params.bytecode_d { - bytecode_arr[i] = one_hot_params.bytecode_pc_chunk(pc, i); + program_arr[i] = one_hot_params.bytecode_pc_chunk(pc, i); } // RAM indices from remapped address (None for non-memory cycles) @@ -166,13 +166,13 @@ impl RaIndices { Self { instruction, - bytecode: bytecode_arr, + program: program_arr, ram, } } /// Extract the index for polynomial `poly_idx` in the unified ordering: - /// [instruction_0..d, bytecode_0..d, ram_0..d] + /// [instruction_0..d, program_0..d, ram_0..d] #[inline] pub fn get_index(&self, poly_idx: usize, one_hot_params: &OneHotParams) -> Option { let instruction_d = one_hot_params.instruction_d; @@ -181,7 +181,7 @@ impl RaIndices { if poly_idx < instruction_d { Some(self.instruction[poly_idx]) } else if poly_idx < instruction_d + bytecode_d { - Some(self.bytecode[poly_idx - instruction_d]) + Some(self.program[poly_idx - instruction_d]) } else { self.ram[poly_idx - instruction_d - bytecode_d] } @@ -198,24 +198,17 @@ impl RaIndices { /// Uses a two-table split-eq: split `r_cycle` into MSB/LSB halves, compute `E_hi` and `E_lo`, /// then `eq(r_cycle, c) = E_hi[c_hi] * E_lo[c_lo]` where `c = (c_hi << lo_bits) | c_lo`. /// -/// Returns G in order: [instruction_0..d, bytecode_0..d, ram_0..d] +/// Returns G in order: [instruction_0..d, program_0..d, ram_0..d] /// Each inner Vec has length k_chunk. #[tracing::instrument(skip_all, name = "shared_ra_polys::compute_all_G")] pub fn compute_all_G( trace: &[Cycle], - bytecode: &BytecodePreprocessing, + program: &ProgramPreprocessing, memory_layout: &MemoryLayout, one_hot_params: &OneHotParams, r_cycle: &[F::Challenge], ) -> Vec> { - compute_all_G_impl::( - trace, - bytecode, - memory_layout, - one_hot_params, - r_cycle, - None, - ) + compute_all_G_impl::(trace, program, memory_layout, one_hot_params, r_cycle, None) } /// Compute all G evaluations AND RA indices in a single pass over the trace. @@ -228,7 +221,7 @@ pub fn compute_all_G( #[tracing::instrument(skip_all, name = "shared_ra_polys::compute_all_G_and_ra_indices")] pub fn compute_all_G_and_ra_indices( trace: &[Cycle], - bytecode: &BytecodePreprocessing, + program: &ProgramPreprocessing, memory_layout: &MemoryLayout, one_hot_params: &OneHotParams, r_cycle: &[F::Challenge], @@ -239,7 +232,7 @@ pub fn compute_all_G_and_ra_indices( let G = compute_all_G_impl::( trace, - bytecode, + program, memory_layout, one_hot_params, r_cycle, @@ -256,7 +249,7 @@ pub fn compute_all_G_and_ra_indices( #[inline(always)] fn compute_all_G_impl( trace: &[Cycle], - bytecode: &BytecodePreprocessing, + program: &ProgramPreprocessing, memory_layout: &MemoryLayout, one_hot_params: &OneHotParams, r_cycle: &[F::Challenge], @@ -320,7 +313,7 @@ fn compute_all_G_impl( (0..ram_d).map(|_| unsafe_allocate_zero_vec(K)).collect(); let mut touched_instruction: Vec = vec![FixedBitSet::with_capacity(K); instruction_d]; - let mut touched_bytecode: Vec = + let mut touched_program: Vec = vec![FixedBitSet::with_capacity(K); bytecode_d]; let mut touched_ram: Vec = vec![FixedBitSet::with_capacity(K); ram_d]; @@ -337,10 +330,10 @@ fn compute_all_G_impl( touched_instruction[i].clear(); } for i in 0..bytecode_d { - for k in touched_bytecode[i].ones() { + for k in touched_program[i].ones() { local_bytecode[i][k] = Default::default(); } - touched_bytecode[i].clear(); + touched_program[i].clear(); } for i in 0..ram_d { for k in touched_ram[i].ones() { @@ -360,7 +353,7 @@ fn compute_all_G_impl( let add = *E_lo[c_lo].as_unreduced_ref(); let ra_idx = - RaIndices::from_cycle(&trace[j], bytecode, memory_layout, one_hot_params); + RaIndices::from_cycle(&trace[j], program, memory_layout, one_hot_params); // Write ra_indices if collecting (disjoint write, each j visited once) if ra_ptr_usize != 0 { @@ -383,9 +376,9 @@ fn compute_all_G_impl( // BytecodeRa contributions (unreduced accumulation) for i in 0..bytecode_d { - let k = ra_idx.bytecode[i] as usize; - if !touched_bytecode[i].contains(k) { - touched_bytecode[i].insert(k); + let k = ra_idx.program[i] as usize; + if !touched_program[i].contains(k) { + touched_program[i].insert(k); } local_bytecode[i][k] += add; } @@ -410,7 +403,7 @@ fn compute_all_G_impl( } } for i in 0..bytecode_d { - for k in touched_bytecode[i].ones() { + for k in touched_program[i].ones() { let reduced = F::from_barrett_reduce::<5>(local_bytecode[i][k]); partial_bytecode[i][k] += e_hi * reduced; } @@ -423,7 +416,7 @@ fn compute_all_G_impl( } } - // Combine into single Vec> in order: instruction, bytecode, ram + // Combine into single Vec> in order: instruction, program, ram let mut result: Vec> = Vec::with_capacity(N); result.extend(partial_instruction); result.extend(partial_bytecode); @@ -872,12 +865,12 @@ impl SharedRaRound3 { #[tracing::instrument(skip_all, name = "shared_ra_polys::compute_ra_indices")] pub fn compute_ra_indices( trace: &[Cycle], - bytecode: &BytecodePreprocessing, + program: &ProgramPreprocessing, memory_layout: &MemoryLayout, one_hot_params: &OneHotParams, ) -> Vec { trace .par_iter() - .map(|cycle| RaIndices::from_cycle(cycle, bytecode, memory_layout, one_hot_params)) + .map(|cycle| RaIndices::from_cycle(cycle, program, memory_layout, one_hot_params)) .collect() } diff --git a/jolt-core/src/poly/split_eq_poly.rs b/jolt-core/src/poly/split_eq_poly.rs index fb3d22af71..688d2d24d0 100644 --- a/jolt-core/src/poly/split_eq_poly.rs +++ b/jolt-core/src/poly/split_eq_poly.rs @@ -500,6 +500,38 @@ impl GruenSplitEqPolynomial { UniPoly::from_coeff(s_coeffs) } + /// Compute the round polynomial `s(X) = l(X) · q(X)` given: + /// - `q_evals`: evaluations `[q(1), q(2), ..., q(deg(q)-1), q(∞)]` (length = deg(q)) + /// - `q_at_0`: evaluation `q(0)` + /// + /// This avoids requiring `s(0)+s(1)` as an input, and avoids recovering `q(0)` via division. + pub fn gruen_poly_from_evals_with_q0(&self, q_evals: &[F], q_at_0: F) -> UniPoly { + let r_round = match self.binding_order { + BindingOrder::LowToHigh => self.w[self.current_index - 1], + BindingOrder::HighToLow => self.w[self.current_index], + }; + + // Compute l(0) and l(1) for the current linear eq polynomial. + let l_at_0 = self.current_scalar * EqPolynomial::mle(&[F::zero()], &[r_round]); + let l_at_1 = self.current_scalar * EqPolynomial::mle(&[F::one()], &[r_round]); + + // Interpolate q from [q(0), q(1), ..., q(deg-1), q(∞)]. + let mut full_q_evals = q_evals.to_vec(); + full_q_evals.insert(0, q_at_0); + let q = UniPoly::from_evals_toom(&full_q_evals); + + // Multiply q(X) by l(X) = l_c0 + l_c1·X. + let l_c0 = l_at_0; + let l_c1 = l_at_1 - l_at_0; + let mut s_coeffs = vec![F::zero(); q.coeffs.len() + 1]; + for (i, q_ci) in q.coeffs.into_iter().enumerate() { + s_coeffs[i] += q_ci * l_c0; + s_coeffs[i + 1] += q_ci * l_c1; + } + + UniPoly::from_coeff(s_coeffs) + } + pub fn merge(&self) -> DensePolynomial { let evals = match self.binding_order { BindingOrder::LowToHigh => { @@ -795,8 +827,6 @@ mod tests { /// Verify that evals_cached returns [1] at index 0 (eq over 0 vars). #[test] fn evals_cached_starts_with_one() { - use crate::poly::eq_poly::EqPolynomial; - let mut rng = test_rng(); for num_vars in 1..=10 { let w: Vec<::Challenge> = diff --git a/jolt-core/src/subprotocols/booleanity.rs b/jolt-core/src/subprotocols/booleanity.rs index ed6d58a0a0..37600b27e8 100644 --- a/jolt-core/src/subprotocols/booleanity.rs +++ b/jolt-core/src/subprotocols/booleanity.rs @@ -36,7 +36,10 @@ use crate::{ OpeningAccumulator, OpeningPoint, ProverOpeningAccumulator, SumcheckId, VerifierOpeningAccumulator, BIG_ENDIAN, }, - shared_ra_polys::{compute_all_G_and_ra_indices, RaIndices, SharedRaPolynomials}, + shared_ra_polys::{ + compute_all_G, compute_all_G_and_ra_indices, compute_ra_indices, RaIndices, + SharedRaPolynomials, + }, split_eq_poly::GruenSplitEqPolynomial, unipoly::UniPoly, }, @@ -47,8 +50,8 @@ use crate::{ transcripts::Transcript, utils::{expanding_table::ExpandingTable, thread::drop_in_background_thread}, zkvm::{ - bytecode::BytecodePreprocessing, config::OneHotParams, + program::ProgramPreprocessing, witness::{CommittedPolynomial, VirtualPolynomial}, }, }; @@ -242,13 +245,13 @@ impl BooleanitySumcheckProver { pub fn initialize( params: BooleanitySumcheckParams, trace: &[Cycle], - bytecode: &BytecodePreprocessing, + program: &ProgramPreprocessing, memory_layout: &MemoryLayout, ) -> Self { // Compute G and RA indices in a single pass over the trace let (G, ra_indices) = compute_all_G_and_ra_indices::( trace, - bytecode, + program, memory_layout, ¶ms.one_hot_params, ¶ms.r_cycle, @@ -388,6 +391,53 @@ impl BooleanitySumcheckProver { gruen_poly * self.eq_r_r } + + fn ingest_address_challenge(&mut self, r_j: F::Challenge, round: usize) { + // Phase 1: Bind B and update F + self.B.bind(r_j); + self.F.update(r_j); + + // Transition to phase 2 + if round == self.params.log_k_chunk - 1 { + self.eq_r_r = self.B.get_current_scalar(); + + // Initialize SharedRaPolynomials with per-poly pre-scaled eq tables (by rho_i) + let F_table = std::mem::take(&mut self.F); + let ra_indices = std::mem::take(&mut self.ra_indices); + let base_eq = F_table.clone_values(); + let num_polys = self.params.polynomial_types.len(); + debug_assert!( + num_polys == self.gamma_powers.len(), + "gamma_powers length mismatch: got {}, expected {}", + self.gamma_powers.len(), + num_polys + ); + let tables: Vec> = (0..num_polys) + .into_par_iter() + .map(|i| { + let rho = self.gamma_powers[i]; + base_eq.iter().map(|v| rho * *v).collect() + }) + .collect(); + self.H = Some(SharedRaPolynomials::new( + tables, + ra_indices, + self.params.one_hot_params.clone(), + )); + + // Drop G arrays + let g = std::mem::take(&mut self.G); + drop_in_background_thread(g); + } + } + + fn ingest_cycle_challenge(&mut self, r_j: F::Challenge) { + // Phase 2: Bind D and H + self.D.bind(r_j); + if let Some(ref mut h) = self.H { + h.bind_in_place(r_j, BindingOrder::LowToHigh); + } + } } impl SumcheckInstanceProver for BooleanitySumcheckProver { @@ -407,48 +457,9 @@ impl SumcheckInstanceProver for BooleanitySum #[tracing::instrument(skip_all, name = "BooleanitySumcheckProver::ingest_challenge")] fn ingest_challenge(&mut self, r_j: F::Challenge, round: usize) { if round < self.params.log_k_chunk { - // Phase 1: Bind B and update F - self.B.bind(r_j); - self.F.update(r_j); - - // Transition to phase 2 - if round == self.params.log_k_chunk - 1 { - self.eq_r_r = self.B.get_current_scalar(); - - // Initialize SharedRaPolynomials with per-poly pre-scaled eq tables (by rho_i) - let F_table = std::mem::take(&mut self.F); - let ra_indices = std::mem::take(&mut self.ra_indices); - let base_eq = F_table.clone_values(); - let num_polys = self.params.polynomial_types.len(); - debug_assert!( - num_polys == self.gamma_powers.len(), - "gamma_powers length mismatch: got {}, expected {}", - self.gamma_powers.len(), - num_polys - ); - let tables: Vec> = (0..num_polys) - .into_par_iter() - .map(|i| { - let rho = self.gamma_powers[i]; - base_eq.iter().map(|v| rho * *v).collect() - }) - .collect(); - self.H = Some(SharedRaPolynomials::new( - tables, - ra_indices, - self.params.one_hot_params.clone(), - )); - - // Drop G arrays - let g = std::mem::take(&mut self.G); - drop_in_background_thread(g); - } + self.ingest_address_challenge(r_j, round); } else { - // Phase 2: Bind D and H - self.D.bind(r_j); - if let Some(ref mut h) = self.H { - h.bind_in_place(r_j, BindingOrder::LowToHigh); - } + self.ingest_cycle_challenge(r_j); } } @@ -483,6 +494,393 @@ impl SumcheckInstanceProver for BooleanitySum } } +/// Booleanity Address-Phase Sumcheck Prover. +/// +/// This prover handles only the first `log_k_chunk` rounds (address variables). +/// The cycle-phase prover is constructed separately from witness + accumulator (Option B). +#[derive(Allocative)] +pub struct BooleanityAddressSumcheckProver { + /// B: split-eq over address-chunk variables (LowToHigh). + B: GruenSplitEqPolynomial, + /// G[i][k] = Σ_j eq(r_cycle, j) · ra_i(k, j) for all RA polynomials + G: Vec>, + /// F: Expanding table for address phase + F: ExpandingTable, + /// Last round polynomial for claim computation + last_round_poly: Option>, + /// Final claim after binding all address variables + address_claim: Option, + /// Parameters (shared with cycle prover) + pub params: BooleanitySumcheckParams, +} + +impl BooleanityAddressSumcheckProver { + /// Initialize a BooleanityAddressSumcheckProver. + /// + /// Computes G polynomials and RA indices in a single pass over the trace. + #[tracing::instrument(skip_all, name = "BooleanityAddressSumcheckProver::initialize")] + pub fn initialize( + params: BooleanitySumcheckParams, + trace: &[Cycle], + program: &ProgramPreprocessing, + memory_layout: &MemoryLayout, + ) -> Self { + // Compute G in a single pass over the trace (witness-dependent). + let G = compute_all_G::( + trace, + program, + memory_layout, + ¶ms.one_hot_params, + ¶ms.r_cycle, + ); + + // Initialize split-eq polynomial for address variables + let B = GruenSplitEqPolynomial::new(¶ms.r_address, BindingOrder::LowToHigh); + + // Initialize expanding table for address phase + let k_chunk = 1 << params.log_k_chunk; + let mut F_table = ExpandingTable::new(k_chunk, BindingOrder::LowToHigh); + F_table.reset(F::one()); + + Self { + B, + G, + F: F_table, + last_round_poly: None, + address_claim: None, + params, + } + } + + fn compute_message_impl(&self, round: usize, previous_claim: F) -> UniPoly { + let m = round + 1; + let B = &self.B; + let N = self.params.polynomial_types.len(); + + // Compute quadratic coefficients via generic split-eq fold + let quadratic_coeffs: [F; DEGREE_BOUND - 1] = B + .par_fold_out_in_unreduced::<9, { DEGREE_BOUND - 1 }>(&|k_prime| { + let coeffs = (0..N) + .into_par_iter() + .map(|i| { + let G_i = &self.G[i]; + let inner_sum = G_i[k_prime << m..(k_prime + 1) << m] + .par_iter() + .enumerate() + .map(|(k, &G_k)| { + let k_m = k >> (m - 1); + let F_k = self.F[k & ((1 << (m - 1)) - 1)]; + let G_times_F = G_k * F_k; + + let eval_infty = G_times_F * F_k; + let eval_0 = if k_m == 0 { + eval_infty - G_times_F + } else { + F::zero() + }; + [eval_0, eval_infty] + }) + .fold_with( + [F::Unreduced::<5>::zero(); DEGREE_BOUND - 1], + |running, new| { + [ + running[0] + new[0].as_unreduced_ref(), + running[1] + new[1].as_unreduced_ref(), + ] + }, + ) + .reduce( + || [F::Unreduced::zero(); DEGREE_BOUND - 1], + |running, new| [running[0] + new[0], running[1] + new[1]], + ); + + let gamma_2i = self.params.gamma_powers_square[i]; + [ + gamma_2i * F::from_barrett_reduce(inner_sum[0]), + gamma_2i * F::from_barrett_reduce(inner_sum[1]), + ] + }) + .reduce( + || [F::zero(); DEGREE_BOUND - 1], + |running, new| [running[0] + new[0], running[1] + new[1]], + ); + coeffs + }); + + B.gruen_poly_deg_3(quadratic_coeffs[0], quadratic_coeffs[1], previous_claim) + } + + fn ingest_challenge_impl(&mut self, r_j: F::Challenge) { + self.B.bind(r_j); + self.F.update(r_j); + } +} + +impl SumcheckInstanceProver + for BooleanityAddressSumcheckProver +{ + fn degree(&self) -> usize { + self.params.degree() + } + + fn num_rounds(&self) -> usize { + self.params.log_k_chunk + } + + fn input_claim(&self, _accumulator: &ProverOpeningAccumulator) -> F { + self.params.input_claim(_accumulator) + } + + fn compute_message(&mut self, round: usize, previous_claim: F) -> UniPoly { + let poly = self.compute_message_impl(round, previous_claim); + self.last_round_poly = Some(poly.clone()); + poly + } + + fn ingest_challenge(&mut self, r_j: F::Challenge, round: usize) { + if let Some(poly) = self.last_round_poly.take() { + let claim = poly.evaluate(&r_j); + if round == self.params.log_k_chunk - 1 { + self.address_claim = Some(claim); + } + } + self.ingest_challenge_impl(r_j); + } + + fn cache_openings( + &self, + accumulator: &mut ProverOpeningAccumulator, + transcript: &mut T, + sumcheck_challenges: &[F::Challenge], + ) { + let mut r_address = sumcheck_challenges.to_vec(); + r_address.reverse(); + let opening_point = OpeningPoint::::new(r_address); + let address_claim = self + .address_claim + .expect("Booleanity address-phase claim missing"); + accumulator.append_virtual( + transcript, + VirtualPolynomial::BooleanityAddrClaim, + SumcheckId::BooleanityAddressPhase, + opening_point, + address_claim, + ); + } + + #[cfg(feature = "allocative")] + fn update_flamegraph(&self, flamegraph: &mut FlameGraphBuilder) { + flamegraph.visit_root(self); + } +} + +/// Booleanity Cycle-Phase Sumcheck Prover. +/// +/// This prover handles the remaining `log_t` rounds (cycle variables). +/// It is constructed from scratch via [`BooleanityCycleSumcheckProver::initialize`]. +#[derive(Allocative)] +pub struct BooleanityCycleSumcheckProver { + /// D: split-eq over time/cycle variables (LowToHigh). + D: GruenSplitEqPolynomial, + /// Shared H polynomials (RA polys bound over address, pre-scaled by gamma) + H: SharedRaPolynomials, + /// eq(r_address, r_address) from address phase + eq_r_r: F, + /// Per-polynomial powers γ^i (in the base field). + gamma_powers: Vec, + /// Per-polynomial inverse powers γ^{-i} (in the base field). + gamma_powers_inv: Vec, + /// Parameters + pub params: BooleanitySumcheckParams, +} + +impl BooleanityCycleSumcheckProver { + /// Initialize the cycle-phase prover from the Stage 6a address opening point. + /// + /// The only witness-dependent work performed here should be collecting `ra_indices` + /// (needed to materialize `SharedRaPolynomials` for the cycle phase). + #[tracing::instrument(skip_all, name = "BooleanityCycleSumcheckProver::initialize")] + pub fn initialize( + params: BooleanitySumcheckParams, + trace: &[Cycle], + program: &ProgramPreprocessing, + memory_layout: &MemoryLayout, + accumulator: &ProverOpeningAccumulator, + ) -> Self { + // Recover Stage 6a address challenges from the accumulator. + // These were stored as BIG_ENDIAN (MSB-first) by the address-phase cache_openings. + let (r_address_point, _) = accumulator.get_virtual_polynomial_opening( + VirtualPolynomial::BooleanityAddrClaim, + SumcheckId::BooleanityAddressPhase, + ); + let mut r_address_low_to_high = r_address_point.r; + r_address_low_to_high.reverse(); + + // Derive eq_r_r = eq(params.r_address, r_address_challenges) via the same binding + // progression as the address prover. + let mut B = GruenSplitEqPolynomial::new(¶ms.r_address, BindingOrder::LowToHigh); + for r_j in r_address_low_to_high.iter().cloned() { + B.bind(r_j); + } + let eq_r_r = B.get_current_scalar(); + + // Derive base eq table over k_chunk addresses from the address challenges. + let k_chunk = 1 << params.log_k_chunk; + let mut F_table = ExpandingTable::new(k_chunk, BindingOrder::LowToHigh); + F_table.reset(F::one()); + for r_j in r_address_low_to_high.iter().cloned() { + F_table.update(r_j); + } + let base_eq = F_table.clone_values(); + + // Compute RA indices from witness (unfused with G computation). + let ra_indices = compute_ra_indices(trace, program, memory_layout, ¶ms.one_hot_params); + + // Compute prover-only batching coefficients rho_i = gamma^i and inverses. + let num_polys = params.polynomial_types.len(); + let gamma_f: F = params.gamma.into(); + let mut gamma_powers = Vec::with_capacity(num_polys); + let mut gamma_powers_inv = Vec::with_capacity(num_polys); + let mut rho_i = F::one(); + for _ in 0..num_polys { + gamma_powers.push(rho_i); + gamma_powers_inv.push( + rho_i + .inverse() + .expect("gamma is nonzero, so rho_i is invertible"), + ); + rho_i *= gamma_f; + } + + // Initialize SharedRaPolynomials with per-poly pre-scaled eq tables (by rho_i). + let tables: Vec> = (0..num_polys) + .into_par_iter() + .map(|i| { + let rho = gamma_powers[i]; + base_eq.iter().map(|v| rho * *v).collect() + }) + .collect(); + let H = SharedRaPolynomials::new(tables, ra_indices, params.one_hot_params.clone()); + + // Cycle split-eq polynomial over r_cycle. + let D = GruenSplitEqPolynomial::new(¶ms.r_cycle, BindingOrder::LowToHigh); + + Self { + D, + H, + eq_r_r, + gamma_powers, + gamma_powers_inv, + params, + } + } + + fn compute_message_impl(&self, previous_claim: F) -> UniPoly { + let D = &self.D; + let H = &self.H; + let num_polys = H.num_polys(); + + // Compute quadratic coefficients via generic split-eq fold + let quadratic_coeffs: [F; DEGREE_BOUND - 1] = D + .par_fold_out_in_unreduced::<9, { DEGREE_BOUND - 1 }>(&|j_prime| { + let mut acc_c = F::Unreduced::<9>::zero(); + let mut acc_e = F::Unreduced::<9>::zero(); + for i in 0..num_polys { + let h_0 = H.get_bound_coeff(i, 2 * j_prime); + let h_1 = H.get_bound_coeff(i, 2 * j_prime + 1); + let b = h_1 - h_0; + + let rho = self.gamma_powers[i]; + acc_c += h_0.mul_unreduced::<9>(h_0 - rho); + acc_e += b.mul_unreduced::<9>(b); + } + [ + F::from_montgomery_reduce::<9>(acc_c), + F::from_montgomery_reduce::<9>(acc_e), + ] + }); + + // Adjust claim by eq_r_r scaling + let adjusted_claim = previous_claim * self.eq_r_r.inverse().unwrap(); + let gruen_poly = + D.gruen_poly_deg_3(quadratic_coeffs[0], quadratic_coeffs[1], adjusted_claim); + + gruen_poly * self.eq_r_r + } + + fn ingest_challenge_impl(&mut self, r_j: F::Challenge) { + self.D.bind(r_j); + self.H.bind_in_place(r_j, BindingOrder::LowToHigh); + } +} + +impl SumcheckInstanceProver + for BooleanityCycleSumcheckProver +{ + fn degree(&self) -> usize { + self.params.degree() + } + + fn num_rounds(&self) -> usize { + self.params.log_t + } + + fn input_claim(&self, accumulator: &ProverOpeningAccumulator) -> F { + accumulator + .get_virtual_polynomial_opening( + VirtualPolynomial::BooleanityAddrClaim, + SumcheckId::BooleanityAddressPhase, + ) + .1 + } + + #[tracing::instrument(skip_all, name = "BooleanityCycleSumcheckProver::compute_message")] + fn compute_message(&mut self, _round: usize, previous_claim: F) -> UniPoly { + self.compute_message_impl(previous_claim) + } + + #[tracing::instrument(skip_all, name = "BooleanityCycleSumcheckProver::ingest_challenge")] + fn ingest_challenge(&mut self, r_j: F::Challenge, _round: usize) { + self.ingest_challenge_impl(r_j) + } + + fn cache_openings( + &self, + accumulator: &mut ProverOpeningAccumulator, + transcript: &mut T, + sumcheck_challenges: &[F::Challenge], + ) { + let (r_address_point, _) = accumulator.get_virtual_polynomial_opening( + VirtualPolynomial::BooleanityAddrClaim, + SumcheckId::BooleanityAddressPhase, + ); + let mut r_address_le = r_address_point.r; + r_address_le.reverse(); + let mut full_challenges = r_address_le; + full_challenges.extend_from_slice(sumcheck_challenges); + let opening_point = self.params.normalize_opening_point(&full_challenges); + + // H is scaled by rho_i; unscale so cached openings match the committed polynomials. + let claims: Vec = (0..self.H.num_polys()) + .map(|i| self.H.final_sumcheck_claim(i) * self.gamma_powers_inv[i]) + .collect(); + + accumulator.append_sparse( + transcript, + self.params.polynomial_types.clone(), + SumcheckId::Booleanity, + opening_point.r[..self.params.log_k_chunk].to_vec(), + opening_point.r[self.params.log_k_chunk..].to_vec(), + claims, + ); + } + + #[cfg(feature = "allocative")] + fn update_flamegraph(&self, flamegraph: &mut FlameGraphBuilder) { + flamegraph.visit_root(self); + } +} + /// Booleanity Sumcheck Verifier. pub struct BooleanitySumcheckVerifier { params: BooleanitySumcheckParams, @@ -545,3 +943,163 @@ impl SumcheckInstanceVerifier for BooleanityS ); } } + +pub struct BooleanityAddressSumcheckVerifier { + params: BooleanitySumcheckParams, +} + +impl BooleanityAddressSumcheckVerifier { + pub fn new(params: BooleanitySumcheckParams) -> Self { + Self { params } + } + + /// Consume this verifier and return the underlying parameters (for Option B orchestration). + pub fn into_params(self) -> BooleanitySumcheckParams { + self.params + } + + pub fn into_cycle_verifier(self) -> BooleanityCycleSumcheckVerifier { + BooleanityCycleSumcheckVerifier { + params: self.params, + } + } +} + +impl SumcheckInstanceVerifier + for BooleanityAddressSumcheckVerifier +{ + fn degree(&self) -> usize { + self.params.degree() + } + + fn num_rounds(&self) -> usize { + self.params.log_k_chunk + } + + fn input_claim(&self, accumulator: &VerifierOpeningAccumulator) -> F { + self.params.input_claim(accumulator) + } + + fn expected_output_claim( + &self, + accumulator: &VerifierOpeningAccumulator, + _sumcheck_challenges: &[F::Challenge], + ) -> F { + accumulator + .get_virtual_polynomial_opening( + VirtualPolynomial::BooleanityAddrClaim, + SumcheckId::BooleanityAddressPhase, + ) + .1 + } + + fn cache_openings( + &self, + accumulator: &mut VerifierOpeningAccumulator, + transcript: &mut T, + sumcheck_challenges: &[F::Challenge], + ) { + let mut r_address = sumcheck_challenges.to_vec(); + r_address.reverse(); + accumulator.append_virtual( + transcript, + VirtualPolynomial::BooleanityAddrClaim, + SumcheckId::BooleanityAddressPhase, + OpeningPoint::::new(r_address), + ); + } +} + +pub struct BooleanityCycleSumcheckVerifier { + params: BooleanitySumcheckParams, +} + +impl BooleanityCycleSumcheckVerifier { + pub fn new(params: BooleanitySumcheckParams) -> Self { + Self { params } + } +} + +impl SumcheckInstanceVerifier + for BooleanityCycleSumcheckVerifier +{ + fn degree(&self) -> usize { + self.params.degree() + } + + fn num_rounds(&self) -> usize { + self.params.log_t + } + + fn input_claim(&self, accumulator: &VerifierOpeningAccumulator) -> F { + accumulator + .get_virtual_polynomial_opening( + VirtualPolynomial::BooleanityAddrClaim, + SumcheckId::BooleanityAddressPhase, + ) + .1 + } + + fn expected_output_claim( + &self, + accumulator: &VerifierOpeningAccumulator, + sumcheck_challenges: &[F::Challenge], + ) -> F { + let (r_address_point, _) = accumulator.get_virtual_polynomial_opening( + VirtualPolynomial::BooleanityAddrClaim, + SumcheckId::BooleanityAddressPhase, + ); + let mut r_address_le = r_address_point.r; + r_address_le.reverse(); + let mut full_challenges = r_address_le; + full_challenges.extend_from_slice(sumcheck_challenges); + + let ra_claims: Vec = self + .params + .polynomial_types + .iter() + .map(|poly_type| { + accumulator + .get_committed_polynomial_opening(*poly_type, SumcheckId::Booleanity) + .1 + }) + .collect(); + + let combined_r: Vec = self + .params + .r_address + .iter() + .cloned() + .rev() + .chain(self.params.r_cycle.iter().cloned().rev()) + .collect(); + + EqPolynomial::::mle(&full_challenges, &combined_r) + * zip(&self.params.gamma_powers_square, ra_claims) + .map(|(gamma_2i, ra)| (ra.square() - ra) * gamma_2i) + .sum::() + } + + fn cache_openings( + &self, + accumulator: &mut VerifierOpeningAccumulator, + transcript: &mut T, + sumcheck_challenges: &[F::Challenge], + ) { + let (r_address_point, _) = accumulator.get_virtual_polynomial_opening( + VirtualPolynomial::BooleanityAddrClaim, + SumcheckId::BooleanityAddressPhase, + ); + let mut r_address_le = r_address_point.r; + r_address_le.reverse(); + let mut full_challenges = r_address_le; + full_challenges.extend_from_slice(sumcheck_challenges); + let opening_point = self.params.normalize_opening_point(&full_challenges); + accumulator.append_sparse( + transcript, + self.params.polynomial_types.clone(), + SumcheckId::Booleanity, + opening_point.r, + ); + } +} diff --git a/jolt-core/src/utils/errors.rs b/jolt-core/src/utils/errors.rs index a9e8b12909..b3800e13eb 100644 --- a/jolt-core/src/utils/errors.rs +++ b/jolt-core/src/utils/errors.rs @@ -28,10 +28,14 @@ pub enum ProofVerifyError { InvalidReadWriteConfig(String), #[error("Invalid one-hot configuration: {0}")] InvalidOneHotConfig(String), + #[error("Invalid bytecode commitment configuration: {0}")] + InvalidBytecodeConfig(String), #[error("Dory proof verification failed: {0}")] DoryError(String), #[error("Sumcheck verification failed")] SumcheckVerificationError, #[error("Univariate-skip round verification failed")] UniSkipVerificationError, + #[error("Bytecode type mismatch: {0}")] + BytecodeTypeMismatch(String), } diff --git a/jolt-core/src/zkvm/bytecode/chunks.rs b/jolt-core/src/zkvm/bytecode/chunks.rs new file mode 100644 index 0000000000..f372ef95b6 --- /dev/null +++ b/jolt-core/src/zkvm/bytecode/chunks.rs @@ -0,0 +1,445 @@ +use crate::field::JoltField; +use crate::poly::commitment::dory::{DoryGlobals, DoryLayout}; +use crate::poly::multilinear_polynomial::MultilinearPolynomial; +use crate::utils::thread::unsafe_allocate_zero_vec; +use crate::zkvm::bytecode::BytecodePreprocessing; +use crate::zkvm::instruction::{ + Flags, InstructionLookup, InterleavedBitsMarker, NUM_CIRCUIT_FLAGS, NUM_INSTRUCTION_FLAGS, +}; +use crate::zkvm::lookup_table::LookupTables; +use common::constants::{REGISTER_COUNT, XLEN}; +use rayon::prelude::*; +use tracer::instruction::Instruction; + +/// Total number of "lanes" to commit bytecode fields +pub const fn total_lanes() -> usize { + 3 * (REGISTER_COUNT as usize) // rs1, rs2, rd one-hot lanes + + 2 // unexpanded_pc, imm + + NUM_CIRCUIT_FLAGS + + NUM_INSTRUCTION_FLAGS + + as strum::EnumCount>::COUNT + + 1 // raf flag +} + +/// Canonical lane layout for bytecode chunk polynomials. +/// +/// The global lane order matches [`lane_value`] and the weights in +/// `claim_reductions/bytecode.rs::compute_chunk_lane_weights`. +#[derive(Clone, Copy, Debug)] +pub struct BytecodeLaneLayout { + pub rs1_start: usize, + pub rs2_start: usize, + pub rd_start: usize, + pub unexp_pc_idx: usize, + pub imm_idx: usize, + pub circuit_start: usize, + pub instr_start: usize, + pub lookup_start: usize, + pub raf_flag_idx: usize, +} + +impl BytecodeLaneLayout { + pub const fn new() -> Self { + let reg_count = REGISTER_COUNT as usize; + let rs1_start = 0usize; + let rs2_start = rs1_start + reg_count; + let rd_start = rs2_start + reg_count; + let unexp_pc_idx = rd_start + reg_count; + let imm_idx = unexp_pc_idx + 1; + let circuit_start = imm_idx + 1; + let instr_start = circuit_start + NUM_CIRCUIT_FLAGS; + let lookup_start = instr_start + NUM_INSTRUCTION_FLAGS; + let raf_flag_idx = lookup_start + as strum::EnumCount>::COUNT; + Self { + rs1_start, + rs2_start, + rd_start, + unexp_pc_idx, + imm_idx, + circuit_start, + instr_start, + lookup_start, + raf_flag_idx, + } + } + + #[inline(always)] + #[allow(dead_code)] + pub const fn total_lanes(&self) -> usize { + self.raf_flag_idx + 1 + } + + /// True for all lanes except `unexpanded_pc` and `imm`. + #[inline(always)] + #[allow(dead_code)] + pub const fn is_boolean_lane(&self, global_lane: usize) -> bool { + global_lane != self.unexp_pc_idx && global_lane != self.imm_idx + } +} + +pub const BYTECODE_LANE_LAYOUT: BytecodeLaneLayout = BytecodeLaneLayout::new(); + +/// Active lane values for a single instruction. +/// +/// Most lanes are boolean/one-hot, so we represent them as `One` to avoid +/// unnecessary field multiplications at call sites (e.g. Dory VMV). +#[derive(Clone, Copy, Debug)] +pub enum ActiveLaneValue { + One, + Scalar(F), +} + +/// Evaluate the weighted lane sum for a single instruction: +/// \( \sum_{\ell} weights[\ell] \cdot lane\_value(\ell, instr) \), +/// without scanning all lanes (uses one-hot and boolean sparsity). +#[inline(always)] +pub fn weighted_lane_sum_for_instruction(weights: &[F], instr: &Instruction) -> F { + debug_assert_eq!(weights.len(), total_lanes()); + + let l = BYTECODE_LANE_LAYOUT; + + let normalized = instr.normalize(); + let circuit_flags = ::circuit_flags(instr); + let instr_flags = ::instruction_flags(instr); + let lookup_idx = >::lookup_table(instr) + .map(|t| LookupTables::::enum_index(&t)); + let raf_flag = !InterleavedBitsMarker::is_interleaved_operands(&circuit_flags); + + let unexpanded_pc = F::from_u64(normalized.address as u64); + let imm = F::from_i128(normalized.operands.imm); + let rs1 = normalized.operands.rs1.map(|r| r as usize); + let rs2 = normalized.operands.rs2.map(|r| r as usize); + let rd = normalized.operands.rd.map(|r| r as usize); + + let mut acc = F::zero(); + + // One-hot register lanes: select weight at the active register (or 0 if None). + if let Some(r) = rs1 { + acc += weights[l.rs1_start + r]; + } + if let Some(r) = rs2 { + acc += weights[l.rs2_start + r]; + } + if let Some(r) = rd { + acc += weights[l.rd_start + r]; + } + + // Scalar lanes. + acc += weights[l.unexp_pc_idx] * unexpanded_pc; + acc += weights[l.imm_idx] * imm; + + // Circuit flags (boolean): add weight when flag is true. + for i in 0..NUM_CIRCUIT_FLAGS { + if circuit_flags[i] { + acc += weights[l.circuit_start + i]; + } + } + + // Instruction flags (boolean): add weight when flag is true. + for i in 0..NUM_INSTRUCTION_FLAGS { + if instr_flags[i] { + acc += weights[l.instr_start + i]; + } + } + + // Lookup table selector (one-hot / zero-hot). + if let Some(t) = lookup_idx { + acc += weights[l.lookup_start + t]; + } + + // RAF flag. + if raf_flag { + acc += weights[l.raf_flag_idx]; + } + + acc +} + +/// Enumerate the non-zero lanes for a single instruction in canonical global-lane order. +/// +/// This is the sparse counterpart to [`lane_value`]: instead of scanning all lanes and +/// branching on zeros, we directly visit only lanes that are 1 (for boolean/one-hot lanes) +/// or have a non-zero scalar value (for `unexpanded_pc` and `imm`). +/// +/// This is useful for: +/// - Streaming / VMV computations where the downstream logic needs to map lanes to matrix indices +/// - Any place where per-lane work dominates and the instruction lane vector is sparse +#[inline(always)] +pub fn for_each_active_lane_value( + instr: &Instruction, + mut visit: impl FnMut(usize, ActiveLaneValue), +) { + let l = BYTECODE_LANE_LAYOUT; + + let normalized = instr.normalize(); + let circuit_flags = ::circuit_flags(instr); + let instr_flags = ::instruction_flags(instr); + let lookup_idx = >::lookup_table(instr) + .map(|t| LookupTables::::enum_index(&t)); + let raf_flag = !InterleavedBitsMarker::is_interleaved_operands(&circuit_flags); + + // One-hot register lanes. + if let Some(r) = normalized.operands.rs1 { + visit(l.rs1_start + (r as usize), ActiveLaneValue::One); + } + if let Some(r) = normalized.operands.rs2 { + visit(l.rs2_start + (r as usize), ActiveLaneValue::One); + } + if let Some(r) = normalized.operands.rd { + visit(l.rd_start + (r as usize), ActiveLaneValue::One); + } + + // Scalar lanes (skip if zero). + let unexpanded_pc = F::from_u64(normalized.address as u64); + if !unexpanded_pc.is_zero() { + visit(l.unexp_pc_idx, ActiveLaneValue::Scalar(unexpanded_pc)); + } + let imm = F::from_i128(normalized.operands.imm); + if !imm.is_zero() { + visit(l.imm_idx, ActiveLaneValue::Scalar(imm)); + } + + // Circuit flags. + for i in 0..NUM_CIRCUIT_FLAGS { + if circuit_flags[i] { + visit(l.circuit_start + i, ActiveLaneValue::One); + } + } + + // Instruction flags. + for i in 0..NUM_INSTRUCTION_FLAGS { + if instr_flags[i] { + visit(l.instr_start + i, ActiveLaneValue::One); + } + } + + // Lookup selector. + if let Some(t) = lookup_idx { + visit(l.lookup_start + t, ActiveLaneValue::One); + } + + // RAF flag. + if raf_flag { + visit(l.raf_flag_idx, ActiveLaneValue::One); + } +} + +#[allow(clippy::too_many_arguments)] +#[inline(always)] +pub fn lane_value( + global_lane: usize, + rs1: Option, + rs2: Option, + rd: Option, + unexpanded_pc: F, + imm: F, + circuit_flags: &[bool; NUM_CIRCUIT_FLAGS], + instr_flags: &[bool; NUM_INSTRUCTION_FLAGS], + lookup_idx: Option, + raf_flag: bool, +) -> F { + let reg_count = REGISTER_COUNT as usize; + let rs1_start = 0usize; + let rs2_start = rs1_start + reg_count; + let rd_start = rs2_start + reg_count; + let unexp_pc_idx = rd_start + reg_count; + let imm_idx = unexp_pc_idx + 1; + let circuit_start = imm_idx + 1; + let instr_start = circuit_start + NUM_CIRCUIT_FLAGS; + let lookup_start = instr_start + NUM_INSTRUCTION_FLAGS; + let raf_flag_idx = lookup_start + as strum::EnumCount>::COUNT; + + if global_lane < rs2_start { + // rs1 one-hot + let r = global_lane as u8; + return F::from_bool(rs1 == Some(r)); + } + if global_lane < rd_start { + // rs2 one-hot + let r = (global_lane - rs2_start) as u8; + return F::from_bool(rs2 == Some(r)); + } + if global_lane < unexp_pc_idx { + // rd one-hot + let r = (global_lane - rd_start) as u8; + return F::from_bool(rd == Some(r)); + } + if global_lane == unexp_pc_idx { + return unexpanded_pc; + } + if global_lane == imm_idx { + return imm; + } + if global_lane < instr_start { + let flag_idx = global_lane - circuit_start; + return F::from_bool(circuit_flags[flag_idx]); + } + if global_lane < lookup_start { + let flag_idx = global_lane - instr_start; + return F::from_bool(instr_flags[flag_idx]); + } + if global_lane < raf_flag_idx { + let table_idx = global_lane - lookup_start; + return F::from_bool(lookup_idx == Some(table_idx)); + } + debug_assert_eq!(global_lane, raf_flag_idx); + F::from_bool(raf_flag) +} + +/// Build bytecode chunk polynomials from a preprocessed instruction slice. +/// +/// This avoids constructing a `BytecodePreprocessing` wrapper (and its clones) when callers +/// already have the padded instruction list. +#[tracing::instrument(skip_all, name = "bytecode::build_bytecode_chunks_from_instructions")] +pub fn build_bytecode_chunks_from_instructions( + instructions: &[Instruction], + log_k_chunk: usize, +) -> Vec> { + let k_chunk = 1usize << log_k_chunk; + let bytecode_len = instructions.len(); + let total = total_lanes(); + let num_chunks = total.div_ceil(k_chunk); + + (0..num_chunks) + .into_par_iter() + .map(|chunk_idx| { + let mut coeffs = unsafe_allocate_zero_vec(k_chunk * bytecode_len); + for k in 0..bytecode_len { + let instr = &instructions[k]; + let normalized = instr.normalize(); + let circuit_flags = ::circuit_flags(instr); + let instr_flags = ::instruction_flags(instr); + let lookup_idx = >::lookup_table(instr) + .map(|t| LookupTables::::enum_index(&t)); + let raf_flag = !InterleavedBitsMarker::is_interleaved_operands(&circuit_flags); + + let unexpanded_pc = F::from_u64(normalized.address as u64); + let imm = F::from_i128(normalized.operands.imm); + let rs1 = normalized.operands.rs1; + let rs2 = normalized.operands.rs2; + let rd = normalized.operands.rd; + + for lane in 0..k_chunk { + let global_lane = chunk_idx * k_chunk + lane; + if global_lane >= total { + break; + } + let value = lane_value::( + global_lane, + rs1, + rs2, + rd, + unexpanded_pc, + imm, + &circuit_flags, + &instr_flags, + lookup_idx, + raf_flag, + ); + let idx = DoryGlobals::get_layout().address_cycle_to_index( + lane, + k, + k_chunk, + bytecode_len, + ); + coeffs[idx] = value; + } + } + MultilinearPolynomial::from(coeffs) + }) + .collect() +} + +#[tracing::instrument(skip_all, name = "bytecode::build_bytecode_chunks")] +pub fn build_bytecode_chunks( + bytecode: &BytecodePreprocessing, + log_k_chunk: usize, +) -> Vec> { + build_bytecode_chunks_from_instructions::(&bytecode.bytecode, log_k_chunk) +} + +/// Build bytecode chunk polynomials with main-matrix dimensions for CycleMajor embedding. +/// +/// This creates bytecode chunks with `k_chunk * padded_trace_len` coefficients, using +/// main-matrix indexing (`lane * T + cycle`) instead of bytecode indexing (`lane * bytecode_len + cycle`). +/// +/// **Why this is needed for CycleMajor:** +/// - In CycleMajor, coefficients are ordered as: lane 0's cycles, lane 1's cycles, ... +/// - Bytecode indexing gives: `lane * bytecode_len + cycle` +/// - Main indexing gives: `lane * T + cycle` +/// - When T > bytecode_len, these differ for lane > 0, causing row-commitment hint mismatch +/// +/// **For AddressMajor, this is NOT needed** because both use `cycle * k_chunk + lane`, +/// which gives the same index for cycle < bytecode_len. +/// +/// The bytecode values are placed at positions (lane, cycle) for cycle < bytecode_len, +/// with zeros for cycle >= bytecode_len (matching the "extra cycle vars fixed to 0" embedding). +#[tracing::instrument(skip_all, name = "bytecode::build_bytecode_chunks_for_main_matrix")] +pub fn build_bytecode_chunks_for_main_matrix( + bytecode: &BytecodePreprocessing, + log_k_chunk: usize, + padded_trace_len: usize, + layout: DoryLayout, +) -> Vec> { + debug_assert_eq!( + layout, + DoryLayout::CycleMajor, + "build_bytecode_chunks_for_main_matrix should only be used for CycleMajor layout" + ); + + let k_chunk = 1usize << log_k_chunk; + let bytecode_len = bytecode.bytecode.len(); + let total = total_lanes(); + let num_chunks = total.div_ceil(k_chunk); + + debug_assert!( + padded_trace_len >= bytecode_len, + "padded_trace_len ({padded_trace_len}) must be >= bytecode_len ({bytecode_len})" + ); + + (0..num_chunks) + .into_par_iter() + .map(|chunk_idx| { + // Use padded_trace_len for coefficient array size (main-matrix dimensions) + let mut coeffs = unsafe_allocate_zero_vec(k_chunk * padded_trace_len); + for k in 0..bytecode_len { + let instr = &bytecode.bytecode[k]; + let normalized = instr.normalize(); + let circuit_flags = ::circuit_flags(instr); + let instr_flags = ::instruction_flags(instr); + let lookup_idx = >::lookup_table(instr) + .map(|t| LookupTables::::enum_index(&t)); + let raf_flag = !InterleavedBitsMarker::is_interleaved_operands(&circuit_flags); + + let unexpanded_pc = F::from_u64(normalized.address as u64); + let imm = F::from_i128(normalized.operands.imm); + let rs1 = normalized.operands.rs1; + let rs2 = normalized.operands.rs2; + let rd = normalized.operands.rd; + + for lane in 0..k_chunk { + let global_lane = chunk_idx * k_chunk + lane; + if global_lane >= total { + break; + } + let value = lane_value::( + global_lane, + rs1, + rs2, + rd, + unexpanded_pc, + imm, + &circuit_flags, + &instr_flags, + lookup_idx, + raf_flag, + ); + // Use padded_trace_len (main T) for indexing + let idx = layout.address_cycle_to_index(lane, k, k_chunk, padded_trace_len); + coeffs[idx] = value; + } + } + MultilinearPolynomial::from(coeffs) + }) + .collect() +} diff --git a/jolt-core/src/zkvm/bytecode/mod.rs b/jolt-core/src/zkvm/bytecode/mod.rs index 82f6fb62ab..d626462d30 100644 --- a/jolt-core/src/zkvm/bytecode/mod.rs +++ b/jolt-core/src/zkvm/bytecode/mod.rs @@ -1,12 +1,220 @@ -use ark_serialize::{CanonicalDeserialize, CanonicalSerialize}; +use std::io::{Read, Write}; +use std::sync::Arc; + +use ark_serialize::{ + CanonicalDeserialize, CanonicalSerialize, Compress, SerializationError, Valid, Validate, +}; use common::constants::{ALIGNMENT_FACTOR_BYTECODE, RAM_START_ADDRESS}; use tracer::instruction::{Cycle, Instruction}; +use crate::poly::commitment::commitment_scheme::CommitmentScheme; +use crate::poly::commitment::dory::{DoryContext, DoryGlobals}; +use crate::utils::errors::ProofVerifyError; +use crate::utils::math::Math; +use crate::zkvm::bytecode::chunks::{build_bytecode_chunks, total_lanes}; +use rayon::prelude::*; + +pub(crate) mod chunks; pub mod read_raf_checking; +/// Bytecode commitments that were derived from actual bytecode. +/// +/// This type enforces at the type level that commitments came from honest +/// preprocessing of full bytecode. The canonical constructor is `derive()`, +/// which takes full bytecode and computes commitments. +/// +/// # Trust Model +/// - Create via `derive()` from full bytecode (offline preprocessing) +/// - Or deserialize from a trusted source (assumes honest origin) +/// - Pass to verifier preprocessing for succinct (online) verification +/// +/// # Security Warning +/// If you construct this type with arbitrary commitments (bypassing `derive()`), +/// verification will be unsound. Only use `derive()` or trusted deserialization. +#[derive(Clone, Debug, PartialEq, CanonicalSerialize, CanonicalDeserialize)] +pub struct TrustedBytecodeCommitments { + /// The bytecode chunk commitments. + /// Trust is enforced by the type - create via `derive()` or deserialize from trusted source. + pub commitments: Vec, + /// Number of columns used when committing bytecode chunks. + /// + /// This is chosen to match the Main-context sigma used for committed-mode Stage 8 batching. + /// The prover/verifier must use the same `num_columns` in the Main context when building the + /// joint Dory opening proof, or the batched hint/commitment combination will be inconsistent. + pub num_columns: usize, + /// log2(k_chunk) used for lane chunking. + pub log_k_chunk: u8, + /// Bytecode length (power-of-two padded). + pub bytecode_len: usize, +} + +impl TrustedBytecodeCommitments { + /// Derive commitments from full bytecode (the canonical constructor). + /// + /// This is the "offline preprocessing" step that must be done honestly. + /// Returns trusted commitments + hints for opening proofs. + #[tracing::instrument(skip_all, name = "TrustedBytecodeCommitments::derive")] + pub fn derive( + bytecode: &BytecodePreprocessing, + generators: &PCS::ProverSetup, + log_k_chunk: usize, + max_trace_len: usize, + ) -> (Self, Vec) { + let k_chunk = 1usize << log_k_chunk; + let bytecode_len = bytecode.bytecode.len(); + let num_chunks = total_lanes().div_ceil(k_chunk); + + let log_t = max_trace_len.log_2(); + let _guard = DoryGlobals::initialize_bytecode_context_for_main_sigma( + k_chunk, + bytecode_len, + log_k_chunk, + log_t, + ); + let _ctx = DoryGlobals::with_context(DoryContext::Bytecode); + let num_columns = DoryGlobals::get_num_columns(); + + let bytecode_chunks = build_bytecode_chunks::(bytecode, log_k_chunk); + debug_assert_eq!(bytecode_chunks.len(), num_chunks); + + let (commitments, hints): (Vec<_>, Vec<_>) = bytecode_chunks + .par_iter() + .map(|poly| PCS::commit(poly, generators)) + .unzip(); + + ( + Self { + commitments, + num_columns, + log_k_chunk: log_k_chunk as u8, + bytecode_len, + }, + hints, + ) + } +} + +/// Bytecode information available to the verifier. +/// +/// In `Full` mode, the verifier has access to the complete bytecode preprocessing +/// and can materialize bytecode-dependent polynomials (O(K) work). +/// +/// In `Committed` mode, the verifier only sees commitments to the bytecode polynomials, +/// enabling succinct verification via claim reductions. +/// +/// **Note**: The bytecode size K is stored in `JoltSharedPreprocessing.bytecode_size`, +/// NOT in this enum. Use `shared.bytecode_size` to get the size. +#[derive(Debug, Clone)] +pub enum VerifierBytecode { + /// Full bytecode available (Full mode) — verifier can materialize polynomials. + Full(Arc), + /// Only trusted commitments available (Committed mode) — verifier uses claim reductions. + /// Size K is in `JoltSharedPreprocessing.bytecode_size`. + Committed(TrustedBytecodeCommitments), +} + +impl VerifierBytecode { + /// Returns the full bytecode preprocessing, or an error if in Committed mode. + pub fn as_full(&self) -> Result<&Arc, ProofVerifyError> { + match self { + VerifierBytecode::Full(bp) => Ok(bp), + VerifierBytecode::Committed(_) => Err(ProofVerifyError::BytecodeTypeMismatch( + "expected Full, got Committed".to_string(), + )), + } + } + + /// Returns true if this is Full mode. + pub fn is_full(&self) -> bool { + matches!(self, VerifierBytecode::Full(_)) + } + + /// Returns true if this is Committed mode. + pub fn is_committed(&self) -> bool { + matches!(self, VerifierBytecode::Committed(_)) + } + + /// Returns the trusted commitments, or an error if in Full mode. + pub fn as_committed(&self) -> Result<&TrustedBytecodeCommitments, ProofVerifyError> { + match self { + VerifierBytecode::Committed(trusted) => Ok(trusted), + VerifierBytecode::Full(_) => Err(ProofVerifyError::BytecodeTypeMismatch( + "expected Committed, got Full".to_string(), + )), + } + } +} + +// Manual serialization for VerifierBytecode +// Format: tag (u8) followed by variant data +impl CanonicalSerialize for VerifierBytecode { + fn serialize_with_mode( + &self, + mut writer: W, + compress: Compress, + ) -> Result<(), SerializationError> { + match self { + VerifierBytecode::Full(bp) => { + 0u8.serialize_with_mode(&mut writer, compress)?; + bp.as_ref().serialize_with_mode(&mut writer, compress)?; + } + VerifierBytecode::Committed(trusted) => { + 1u8.serialize_with_mode(&mut writer, compress)?; + trusted.serialize_with_mode(&mut writer, compress)?; + } + } + Ok(()) + } + + fn serialized_size(&self, compress: Compress) -> usize { + 1 + match self { + VerifierBytecode::Full(bp) => bp.serialized_size(compress), + VerifierBytecode::Committed(trusted) => trusted.serialized_size(compress), + } + } +} + +impl Valid for VerifierBytecode { + fn check(&self) -> Result<(), SerializationError> { + match self { + VerifierBytecode::Full(bp) => bp.check(), + VerifierBytecode::Committed(trusted) => trusted.check(), + } + } +} + +impl CanonicalDeserialize for VerifierBytecode { + fn deserialize_with_mode( + mut reader: R, + compress: Compress, + validate: Validate, + ) -> Result { + let tag = u8::deserialize_with_mode(&mut reader, compress, validate)?; + match tag { + 0 => { + let bp = + BytecodePreprocessing::deserialize_with_mode(&mut reader, compress, validate)?; + Ok(VerifierBytecode::Full(Arc::new(bp))) + } + 1 => { + let trusted = TrustedBytecodeCommitments::::deserialize_with_mode( + &mut reader, + compress, + validate, + )?; + Ok(VerifierBytecode::Committed(trusted)) + } + _ => Err(SerializationError::InvalidData), + } + } +} + +/// Bytecode preprocessing data (O(K)). +/// +/// **Note**: The bytecode size K is stored in `JoltSharedPreprocessing.bytecode_size`, +/// NOT in this struct. Use `shared.bytecode_size` to get the size. #[derive(Default, Debug, Clone, CanonicalSerialize, CanonicalDeserialize)] pub struct BytecodePreprocessing { - pub code_size: usize, pub bytecode: Vec, /// Maps the memory address of each instruction in the bytecode to its "virtual" address. /// See Section 6.1 of the Jolt paper, "Reflecting the program counter". The virtual address @@ -21,20 +229,17 @@ impl BytecodePreprocessing { bytecode.insert(0, Instruction::NoOp); let pc_map = BytecodePCMapper::new(&bytecode); - let code_size = bytecode.len().next_power_of_two().max(2); + let bytecode_size = bytecode.len().next_power_of_two().max(2); // Bytecode: Pad to nearest power of 2 - bytecode.resize(code_size, Instruction::NoOp); + bytecode.resize(bytecode_size, Instruction::NoOp); - Self { - code_size, - bytecode, - pc_map, - } + Self { bytecode, pc_map } } + #[inline(always)] pub fn get_pc(&self, cycle: &Cycle) -> usize { - if matches!(cycle, tracer::instruction::Cycle::NoOp) { + if matches!(cycle, Cycle::NoOp) { return 0; } let instr = cycle.instruction().normalize(); @@ -56,13 +261,17 @@ impl BytecodePCMapper { let mut indices: Vec> = { // For read-raf tests we simulate bytecode being empty #[cfg(test)] - if bytecode.len() == 1 { - vec![None; 1] - } else { - vec![None; Self::get_index(bytecode.last().unwrap().normalize().address) + 1] + { + if bytecode.len() == 1 { + vec![None; 1] + } else { + vec![None; Self::get_index(bytecode.last().unwrap().normalize().address) + 1] + } } #[cfg(not(test))] - vec![None; Self::get_index(bytecode.last().unwrap().normalize().address) + 1] + { + vec![None; Self::get_index(bytecode.last().unwrap().normalize().address) + 1] + } }; let mut last_pc = 0; // Push the initial noop instruction @@ -89,6 +298,7 @@ impl BytecodePCMapper { Self { indices } } + #[inline(always)] pub fn get_pc(&self, address: usize, virtual_sequence_remaining: u16) -> usize { let (base_pc, max_inline_seq) = self .indices @@ -98,6 +308,7 @@ impl BytecodePCMapper { base_pc + (max_inline_seq - virtual_sequence_remaining) as usize } + #[inline(always)] pub const fn get_index(address: usize) -> usize { assert!(address >= RAM_START_ADDRESS as usize); assert!(address.is_multiple_of(ALIGNMENT_FACTOR_BYTECODE)); diff --git a/jolt-core/src/zkvm/bytecode/read_raf_checking.rs b/jolt-core/src/zkvm/bytecode/read_raf_checking.rs index 223a6feaef..3192cfec4e 100644 --- a/jolt-core/src/zkvm/bytecode/read_raf_checking.rs +++ b/jolt-core/src/zkvm/bytecode/read_raf_checking.rs @@ -24,15 +24,18 @@ use crate::{ sumcheck_verifier::{SumcheckInstanceParams, SumcheckInstanceVerifier}, }, transcripts::Transcript, - utils::{math::Math, small_scalar::SmallScalar, thread::unsafe_allocate_zero_vec}, + utils::{ + errors::ProofVerifyError, math::Math, small_scalar::SmallScalar, + thread::unsafe_allocate_zero_vec, + }, zkvm::{ - bytecode::BytecodePreprocessing, - config::OneHotParams, + config::{OneHotParams, ProgramMode}, instruction::{ CircuitFlags, Flags, InstructionFlags, InstructionLookup, InterleavedBitsMarker, NUM_CIRCUIT_FLAGS, }, lookup_table::{LookupTables, NUM_LOOKUP_TABLES}, + program::ProgramPreprocessing, witness::{CommittedPolynomial, VirtualPolynomial}, }, }; @@ -131,7 +134,7 @@ pub struct BytecodeReadRafSumcheckProver { trace: Arc>, /// Bytecode preprocessing for computing PCs. #[allocative(skip)] - bytecode_preprocessing: Arc, + program: Arc, pub params: BytecodeReadRafSumcheckParams, } @@ -140,7 +143,7 @@ impl BytecodeReadRafSumcheckProver { pub fn initialize( params: BytecodeReadRafSumcheckParams, trace: Arc>, - bytecode_preprocessing: Arc, + program: Arc, ) -> Self { let claim_per_stage = [ params.rv_claims[0] + params.gamma_powers[5] * params.raf_claim, @@ -224,7 +227,7 @@ impl BytecodeReadRafSumcheckProver { break; } - let pc = bytecode_preprocessing.get_pc(&trace[c]); + let pc = program.get_pc(&trace[c]); // Track touched PCs (avoid duplicates with a simple check) if inner[0][pc].is_zero() { @@ -299,7 +302,7 @@ impl BytecodeReadRafSumcheckProver { prev_round_polys: None, bound_val_evals: None, trace, - bytecode_preprocessing, + program, params, } } @@ -360,7 +363,7 @@ impl BytecodeReadRafSumcheckProver { .trace .par_iter() .map(|cycle| { - let pc = self.bytecode_preprocessing.get_pc(cycle); + let pc = self.program.get_pc(cycle); Some(self.params.one_hot_params.bytecode_pc_chunk(pc, i)) }) .collect(); @@ -371,17 +374,8 @@ impl BytecodeReadRafSumcheckProver { // Drop trace and preprocessing - no longer needed after this self.trace = Arc::new(Vec::new()); } -} -impl SumcheckInstanceProver - for BytecodeReadRafSumcheckProver -{ - fn get_params(&self) -> &dyn SumcheckInstanceParams { - &self.params - } - - #[tracing::instrument(skip_all, name = "BytecodeReadRafSumcheckProver::compute_message")] - fn compute_message(&mut self, round: usize, _previous_claim: F) -> UniPoly { + fn compute_message_internal(&mut self, round: usize, _previous_claim: F) -> UniPoly { if round < self.params.log_K { const DEGREE: usize = 2; @@ -394,7 +388,8 @@ impl SumcheckInstanceProver }); let int_evals = - self.params.int_poly + self.params + .int_poly .sumcheck_evals(i, DEGREE, BindingOrder::LowToHigh); // We have a separate Val polynomial for each stage @@ -408,13 +403,20 @@ impl SumcheckInstanceProver // Which matches with the input claim: // rv_1 + gamma * rv_2 + gamma^2 * rv_3 + gamma^3 * rv_4 + gamma^4 * rv_5 + gamma^5 * raf_1 + gamma^6 * raf_3 let mut val_evals = self - .params.val_polys + .params + .val_polys .iter() // Val polynomials .map(|val| val.sumcheck_evals_array::(i, BindingOrder::LowToHigh)) // Here are the RAF polynomials and their powers .zip([Some(&int_evals), None, Some(&int_evals), None, None]) - .zip([Some(self.params.gamma_powers[5]), None, Some(self.params.gamma_powers[4]), None, None]) + .zip([ + Some(self.params.gamma_powers[5]), + None, + Some(self.params.gamma_powers[4]), + None, + None, + ]) .map(|((val_evals, int_evals), gamma)| { std::array::from_fn::(|j| { val_evals[j] @@ -450,7 +452,7 @@ impl SumcheckInstanceProver agg_round_poly } else { - let degree = >::degree(self); + let degree = self.params.degree(); let out_len = self.gruen_eq_polys[0].E_out_current().len(); let in_len = self.gruen_eq_polys[0].E_in_current().len(); @@ -520,8 +522,7 @@ impl SumcheckInstanceProver } } - #[tracing::instrument(skip_all, name = "BytecodeReadRafSumcheckProver::ingest_challenge")] - fn ingest_challenge(&mut self, r_j: F::Challenge, round: usize) { + fn ingest_challenge_internal(&mut self, r_j: F::Challenge, round: usize) { if let Some(prev_round_polys) = self.prev_round_polys.take() { self.prev_round_claims = prev_round_polys.map(|poly| poly.evaluate(&r_j)); } @@ -550,6 +551,24 @@ impl SumcheckInstanceProver .for_each(|poly| poly.bind(r_j)); } } +} + +impl SumcheckInstanceProver + for BytecodeReadRafSumcheckProver +{ + fn get_params(&self) -> &dyn SumcheckInstanceParams { + &self.params + } + + #[tracing::instrument(skip_all, name = "BytecodeReadRafSumcheckProver::compute_message")] + fn compute_message(&mut self, round: usize, _previous_claim: F) -> UniPoly { + self.compute_message_internal(round, _previous_claim) + } + + #[tracing::instrument(skip_all, name = "BytecodeReadRafSumcheckProver::ingest_challenge")] + fn ingest_challenge(&mut self, r_j: F::Challenge, round: usize) { + self.ingest_challenge_internal(r_j, round) + } fn cache_openings( &self, @@ -584,13 +603,568 @@ impl SumcheckInstanceProver } } +/// Bytecode Read+RAF Address-Phase Sumcheck Prover. +/// +/// This prover handles only the first `log_K` rounds (address variables). +/// The cycle-phase prover is constructed separately from witness + accumulator (Option B). +#[derive(Allocative)] +pub struct BytecodeReadRafAddressSumcheckProver { + /// Per-stage address MLEs F_i(k) built from eq(r_cycle_stage_i, (chunk_index, j)). + F: [MultilinearPolynomial; N_STAGES], + /// Binding challenges for the first log_K variables. + r_address_prime: Vec, + /// Previous-round claims s_i(0)+s_i(1) per stage. + prev_round_claims: [F; N_STAGES], + /// Round polynomials per stage for advancing to the next claim. + prev_round_polys: Option<[UniPoly; N_STAGES]>, + /// Parameters (shared with cycle prover). + pub params: BytecodeReadRafSumcheckParams, +} + +impl BytecodeReadRafAddressSumcheckProver { + /// Initialize a BytecodeReadRafAddressSumcheckProver. + #[tracing::instrument(skip_all, name = "BytecodeReadRafAddressSumcheckProver::initialize")] + pub fn initialize( + params: BytecodeReadRafSumcheckParams, + trace: Arc>, + program: Arc, + ) -> Self { + let claim_per_stage = [ + params.rv_claims[0] + params.gamma_powers[5] * params.raf_claim, + params.rv_claims[1], + params.rv_claims[2] + params.gamma_powers[4] * params.raf_shift_claim, + params.rv_claims[3], + params.rv_claims[4], + ]; + + // Two-table split-eq optimization for computing F[stage][k] = Σ_{c: PC(c)=k} eq(r_cycle, c). + let T = trace.len(); + let K = params.K; + let log_T = params.log_T; + + let lo_bits = log_T / 2; + let hi_bits = log_T - lo_bits; + let in_len: usize = 1 << lo_bits; + let out_len: usize = 1 << hi_bits; + + let (E_hi, E_lo): ([Vec; N_STAGES], [Vec; N_STAGES]) = rayon::join( + || { + params + .r_cycles + .each_ref() + .map(|r_cycle| EqPolynomial::evals(&r_cycle[..hi_bits])) + }, + || { + params + .r_cycles + .each_ref() + .map(|r_cycle| EqPolynomial::evals(&r_cycle[hi_bits..])) + }, + ); + + let num_threads = rayon::current_num_threads(); + let chunk_size = out_len.div_ceil(num_threads); + + let F_polys: [Vec; N_STAGES] = E_hi[0] + .par_chunks(chunk_size) + .enumerate() + .map(|(chunk_idx, chunk)| { + let mut partial: [Vec; N_STAGES] = + array::from_fn(|_| unsafe_allocate_zero_vec(K)); + let mut inner: [Vec; N_STAGES] = array::from_fn(|_| unsafe_allocate_zero_vec(K)); + let mut touched = Vec::with_capacity(in_len); + + let chunk_start = chunk_idx * chunk_size; + for (local_idx, _) in chunk.iter().enumerate() { + let c_hi = chunk_start + local_idx; + let c_hi_base = c_hi * in_len; + + for &k in &touched { + for stage in 0..N_STAGES { + inner[stage][k] = F::zero(); + } + } + touched.clear(); + + for c_lo in 0..in_len { + let c = c_hi_base + c_lo; + if c >= T { + break; + } + + let pc = program.get_pc(&trace[c]); + if inner[0][pc].is_zero() { + touched.push(pc); + } + for stage in 0..N_STAGES { + inner[stage][pc] += E_lo[stage][c_lo]; + } + } + + for &k in &touched { + for stage in 0..N_STAGES { + partial[stage][k] += E_hi[stage][c_hi] * inner[stage][k]; + } + } + } + partial + }) + .reduce( + || array::from_fn(|_| unsafe_allocate_zero_vec(K)), + |mut a, b| { + for stage in 0..N_STAGES { + a[stage] + .par_iter_mut() + .zip(b[stage].par_iter()) + .for_each(|(a, b)| *a += *b); + } + a + }, + ); + + let F = F_polys.map(MultilinearPolynomial::from); + + Self { + F, + r_address_prime: Vec::with_capacity(params.log_K), + prev_round_claims: claim_per_stage, + prev_round_polys: None, + params, + } + } + + fn compute_message_impl(&mut self, _previous_claim: F) -> UniPoly { + const DEGREE: usize = 2; + + let eval_per_stage: [[F; DEGREE]; N_STAGES] = (0..self.params.val_polys[0].len() / 2) + .into_par_iter() + .map(|i| { + let ra_evals = self + .F + .each_ref() + .map(|poly| poly.sumcheck_evals_array::(i, BindingOrder::LowToHigh)); + + let int_evals = + self.params + .int_poly + .sumcheck_evals(i, DEGREE, BindingOrder::LowToHigh); + + let mut val_evals = self + .params + .val_polys + .iter() + .map(|val| val.sumcheck_evals_array::(i, BindingOrder::LowToHigh)) + .zip([Some(&int_evals), None, Some(&int_evals), None, None]) + .zip([ + Some(self.params.gamma_powers[5]), + None, + Some(self.params.gamma_powers[4]), + None, + None, + ]) + .map(|((val_evals, int_evals), gamma)| { + std::array::from_fn::(|j| { + val_evals[j] + + int_evals + .map_or(F::zero(), |int_evals| int_evals[j] * gamma.unwrap()) + }) + }); + + array::from_fn(|stage| { + let [ra_at_0, ra_at_2] = ra_evals[stage]; + let [val_at_0, val_at_2] = val_evals.next().unwrap(); + [ra_at_0 * val_at_0, ra_at_2 * val_at_2] + }) + }) + .reduce( + || [[F::zero(); DEGREE]; N_STAGES], + |a, b| array::from_fn(|i| array::from_fn(|j| a[i][j] + b[i][j])), + ); + + let mut round_polys: [_; N_STAGES] = array::from_fn(|_| UniPoly::zero()); + let mut agg_round_poly = UniPoly::zero(); + + for (stage, evals) in eval_per_stage.into_iter().enumerate() { + let [eval_at_0, eval_at_2] = evals; + let eval_at_1 = self.prev_round_claims[stage] - eval_at_0; + let round_poly = UniPoly::from_evals(&[eval_at_0, eval_at_1, eval_at_2]); + agg_round_poly += &(&round_poly * self.params.gamma_powers[stage]); + round_polys[stage] = round_poly; + } + + self.prev_round_polys = Some(round_polys); + agg_round_poly + } + + fn ingest_challenge_impl(&mut self, r_j: F::Challenge) { + if let Some(prev_round_polys) = self.prev_round_polys.take() { + self.prev_round_claims = prev_round_polys.map(|poly| poly.evaluate(&r_j)); + } + + self.params + .val_polys + .iter_mut() + .for_each(|poly| poly.bind_parallel(r_j, BindingOrder::LowToHigh)); + self.params + .int_poly + .bind_parallel(r_j, BindingOrder::LowToHigh); + self.F + .iter_mut() + .for_each(|poly| poly.bind_parallel(r_j, BindingOrder::LowToHigh)); + self.r_address_prime.push(r_j); + } +} + +impl SumcheckInstanceProver + for BytecodeReadRafAddressSumcheckProver +{ + fn degree(&self) -> usize { + self.params.degree() + } + + fn num_rounds(&self) -> usize { + self.params.log_K + } + + fn input_claim(&self, _accumulator: &ProverOpeningAccumulator) -> F { + self.params.input_claim(_accumulator) + } + + fn compute_message(&mut self, _round: usize, previous_claim: F) -> UniPoly { + self.compute_message_impl(previous_claim) + } + + fn ingest_challenge(&mut self, r_j: F::Challenge, _round: usize) { + self.ingest_challenge_impl(r_j) + } + + fn cache_openings( + &self, + accumulator: &mut ProverOpeningAccumulator, + transcript: &mut T, + sumcheck_challenges: &[F::Challenge], + ) { + let mut r_address = sumcheck_challenges.to_vec(); + r_address.reverse(); + let opening_point = OpeningPoint::::new(r_address); + let address_claim: F = self + .prev_round_claims + .iter() + .zip(self.params.gamma_powers.iter()) + .take(N_STAGES) + .map(|(claim, gamma)| *claim * *gamma) + .sum(); + accumulator.append_virtual( + transcript, + VirtualPolynomial::BytecodeReadRafAddrClaim, + SumcheckId::BytecodeReadRafAddressPhase, + opening_point.clone(), + address_claim, + ); + + // Emit Val-only claims at the Stage 6a boundary only when the staged-Val/claim-reduction + // path is enabled. + if self.params.use_staged_val_claims { + for stage in 0..N_STAGES { + let claim = self.params.val_polys[stage].final_sumcheck_claim(); + accumulator.append_virtual( + transcript, + VirtualPolynomial::BytecodeValStage(stage), + SumcheckId::BytecodeReadRafAddressPhase, + opening_point.clone(), + claim, + ); + } + } + } + + #[cfg(feature = "allocative")] + fn update_flamegraph(&self, flamegraph: &mut FlameGraphBuilder) { + flamegraph.visit_root(self); + } +} + +/// Bytecode Read+RAF Cycle-Phase Sumcheck Prover. +/// +/// This prover handles the remaining `log_T` rounds (cycle variables). +/// It is constructed from scratch via [`BytecodeReadRafCycleSumcheckProver::initialize`]. +#[derive(Allocative)] +pub struct BytecodeReadRafCycleSumcheckProver { + /// Chunked RA polynomials over address variables. + ra: Vec>, + /// Per-stage Gruen-split eq polynomials over cycle vars. + gruen_eq_polys: [GruenSplitEqPolynomial; N_STAGES], + /// Final sumcheck claims of stage Val polynomials (with RAF Int folded). + bound_val_evals: [F; N_STAGES], + /// Parameters. + pub params: BytecodeReadRafSumcheckParams, +} + +impl BytecodeReadRafCycleSumcheckProver { + /// Initialize the cycle-phase prover from Stage 6a openings (no replay). + #[tracing::instrument(skip_all, name = "BytecodeReadRafCycleSumcheckProver::initialize")] + pub fn initialize( + params: BytecodeReadRafSumcheckParams, + trace: Arc>, + program: Arc, + accumulator: &ProverOpeningAccumulator, + ) -> Self { + // Recover Stage 6a address challenges from the accumulator. + // Address-phase cache_openings stored them as BIG_ENDIAN (MSB-first). + let (r_address_point, _) = accumulator.get_virtual_polynomial_opening( + VirtualPolynomial::BytecodeReadRafAddrClaim, + SumcheckId::BytecodeReadRafAddressPhase, + ); + + // Compute bound_val_evals at r_address (Val + RAF Int folds). + let int_eval = params.int_poly.evaluate(&r_address_point.r); + let int_terms = [ + int_eval * params.gamma_powers[5], // RAF for Stage1 + F::zero(), // No RAF for Stage2 + int_eval * params.gamma_powers[4], // RAF for Stage3 + F::zero(), // No RAF for Stage4 + F::zero(), // No RAF for Stage5 + ]; + let bound_val_evals: [F; N_STAGES] = if params.use_staged_val_claims { + (0..N_STAGES) + .map(|stage| { + let val_claim = accumulator + .get_virtual_polynomial_opening( + VirtualPolynomial::BytecodeValStage(stage), + SumcheckId::BytecodeReadRafAddressPhase, + ) + .1; + val_claim + int_terms[stage] + }) + .collect::>() + .try_into() + .unwrap() + } else { + // Full mode: evaluate Val polynomials directly at r_address. + params + .val_polys + .iter() + .enumerate() + .map(|(stage, poly)| poly.evaluate(&r_address_point.r) + int_terms[stage]) + .collect::>() + .try_into() + .unwrap() + }; + + // Build RA polynomials from witness using MSB-first address challenges. + let r_address_chunks = params + .one_hot_params + .compute_r_address_chunks::(&r_address_point.r); + let ra: Vec> = r_address_chunks + .iter() + .enumerate() + .map(|(i, r_address_chunk)| { + let ra_i: Vec> = trace + .par_iter() + .map(|cycle| { + let pc = program.get_pc(cycle); + Some(params.one_hot_params.bytecode_pc_chunk(pc, i)) + }) + .collect(); + RaPolynomial::new(Arc::new(ra_i), EqPolynomial::evals(r_address_chunk)) + }) + .collect(); + + let gruen_eq_polys = params + .r_cycles + .each_ref() + .map(|r_cycle| GruenSplitEqPolynomial::new(r_cycle, BindingOrder::LowToHigh)); + + Self { + ra, + gruen_eq_polys, + bound_val_evals, + params, + } + } + + fn compute_message_impl(&mut self, _previous_claim: F) -> UniPoly { + let degree = self.params.degree(); + + let out_len = self.gruen_eq_polys[0].E_out_current().len(); + let in_len = self.gruen_eq_polys[0].E_in_current().len(); + let in_n_vars = in_len.log_2(); + + let (mut q0_per_stage, mut q_evals_per_stage): ([F; N_STAGES], [Vec; N_STAGES]) = (0 + ..out_len) + .into_par_iter() + .map(|j_hi| { + let mut ra_eval_pairs = vec![(F::zero(), F::zero()); self.ra.len()]; + let mut ra_prod_evals = vec![F::zero(); degree - 1]; + let mut q0_unreduced: [_; N_STAGES] = array::from_fn(|_| F::Unreduced::zero()); + let mut q_unreduced: [_; N_STAGES] = + array::from_fn(|_| vec![F::Unreduced::zero(); degree - 1]); + + for j_lo in 0..in_len { + let j = j_lo + (j_hi << in_n_vars); + + for (i, ra_i) in self.ra.iter().enumerate() { + let ra_i_eval_at_j_0 = ra_i.get_bound_coeff(j * 2); + let ra_i_eval_at_j_1 = ra_i.get_bound_coeff(j * 2 + 1); + ra_eval_pairs[i] = (ra_i_eval_at_j_0, ra_i_eval_at_j_1); + } + + // Product polynomial evaluations on U_d = [1, 2, ..., d-1, ∞]. + eval_linear_prod_assign(&ra_eval_pairs, &mut ra_prod_evals); + // Also compute P(0) = ∏_i ra_i(0) (needed to build q(0) directly). + let prod_at_0 = ra_eval_pairs + .iter() + .fold(F::one(), |acc, (p0, _p1)| acc * *p0); + + for stage in 0..N_STAGES { + let eq_in_eval = self.gruen_eq_polys[stage].E_in_current()[j_lo]; + q0_unreduced[stage] += eq_in_eval.mul_unreduced::<9>(prod_at_0); + for i in 0..degree - 1 { + q_unreduced[stage][i] += + eq_in_eval.mul_unreduced::<9>(ra_prod_evals[i]); + } + } + } + + let q0: [F; N_STAGES] = array::from_fn(|stage| { + let eq_out_eval = self.gruen_eq_polys[stage].E_out_current()[j_hi]; + eq_out_eval * F::from_montgomery_reduce(q0_unreduced[stage]) + }); + let q_evals: [Vec; N_STAGES] = array::from_fn(|stage| { + let eq_out_eval = self.gruen_eq_polys[stage].E_out_current()[j_hi]; + q_unreduced[stage] + .iter() + .map(|v| eq_out_eval * F::from_montgomery_reduce(*v)) + .collect() + }); + (q0, q_evals) + }) + .reduce( + || { + ( + array::from_fn(|_| F::zero()), + array::from_fn(|_| vec![F::zero(); degree - 1]), + ) + }, + |mut a, b| { + for stage in 0..N_STAGES { + a.0[stage] += b.0[stage]; + a.1[stage] + .iter_mut() + .zip(b.1[stage].iter()) + .for_each(|(x, y)| *x += *y); + } + a + }, + ); + + // Multiply by bound values (push into q). + for stage in 0..N_STAGES { + q0_per_stage[stage] *= self.bound_val_evals[stage]; + q_evals_per_stage[stage] + .iter_mut() + .for_each(|v| *v *= self.bound_val_evals[stage]); + } + + let mut agg_round_poly = UniPoly::zero(); + for stage in 0..N_STAGES { + let round_poly = self.gruen_eq_polys[stage] + .gruen_poly_from_evals_with_q0(&q_evals_per_stage[stage], q0_per_stage[stage]); + agg_round_poly += &(&round_poly * self.params.gamma_powers[stage]); + } + agg_round_poly + } + + fn ingest_challenge_impl(&mut self, r_j: F::Challenge) { + self.ra + .iter_mut() + .for_each(|ra| ra.bind_parallel(r_j, BindingOrder::LowToHigh)); + self.gruen_eq_polys + .iter_mut() + .for_each(|poly| poly.bind(r_j)); + } +} + +impl SumcheckInstanceProver + for BytecodeReadRafCycleSumcheckProver +{ + fn degree(&self) -> usize { + self.params.degree() + } + + fn num_rounds(&self) -> usize { + self.params.log_T + } + + fn input_claim(&self, accumulator: &ProverOpeningAccumulator) -> F { + accumulator + .get_virtual_polynomial_opening( + VirtualPolynomial::BytecodeReadRafAddrClaim, + SumcheckId::BytecodeReadRafAddressPhase, + ) + .1 + } + + #[tracing::instrument(skip_all, name = "BytecodeReadRafCycleSumcheckProver::compute_message")] + fn compute_message(&mut self, _round: usize, previous_claim: F) -> UniPoly { + self.compute_message_impl(previous_claim) + } + + #[tracing::instrument( + skip_all, + name = "BytecodeReadRafCycleSumcheckProver::ingest_challenge" + )] + fn ingest_challenge(&mut self, r_j: F::Challenge, _round: usize) { + self.ingest_challenge_impl(r_j) + } + + fn cache_openings( + &self, + accumulator: &mut ProverOpeningAccumulator, + transcript: &mut T, + sumcheck_challenges: &[F::Challenge], + ) { + let (r_address_point, _) = accumulator.get_virtual_polynomial_opening( + VirtualPolynomial::BytecodeReadRafAddrClaim, + SumcheckId::BytecodeReadRafAddressPhase, + ); + let mut r_address_le = r_address_point.r; + r_address_le.reverse(); + let mut full_challenges = r_address_le; + full_challenges.extend_from_slice(sumcheck_challenges); + let opening_point = self.params.normalize_opening_point(&full_challenges); + let (r_address, r_cycle) = opening_point.split_at(self.params.log_K); + + let r_address_chunks = self + .params + .one_hot_params + .compute_r_address_chunks::(&r_address.r); + + for i in 0..self.params.d { + accumulator.append_sparse( + transcript, + vec![CommittedPolynomial::BytecodeRa(i)], + SumcheckId::BytecodeReadRaf, + r_address_chunks[i].clone(), + r_cycle.clone().into(), + vec![self.ra[i].final_sumcheck_claim()], + ); + } + } + + #[cfg(feature = "allocative")] + fn update_flamegraph(&self, flamegraph: &mut FlameGraphBuilder) { + flamegraph.visit_root(self); + } +} + pub struct BytecodeReadRafSumcheckVerifier { params: BytecodeReadRafSumcheckParams, } impl BytecodeReadRafSumcheckVerifier { pub fn gen( - bytecode_preprocessing: &BytecodePreprocessing, + program: &ProgramPreprocessing, n_cycle_vars: usize, one_hot_params: &OneHotParams, opening_accumulator: &VerifierOpeningAccumulator, @@ -598,7 +1172,7 @@ impl BytecodeReadRafSumcheckVerifier { ) -> Self { Self { params: BytecodeReadRafSumcheckParams::gen( - bytecode_preprocessing, + program, n_cycle_vars, one_hot_params, opening_accumulator, @@ -695,6 +1269,252 @@ impl SumcheckInstanceVerifier } } +pub struct BytecodeReadRafAddressSumcheckVerifier { + params: BytecodeReadRafSumcheckParams, +} + +impl BytecodeReadRafAddressSumcheckVerifier { + pub fn new( + program: Option<&ProgramPreprocessing>, + n_cycle_vars: usize, + one_hot_params: &OneHotParams, + opening_accumulator: &VerifierOpeningAccumulator, + transcript: &mut impl Transcript, + program_mode: ProgramMode, + ) -> Result { + let mut params = match program_mode { + // Commitment mode: verifier MUST avoid O(K_bytecode) work here, and later stages will + // relate staged Val claims to committed bytecode. + ProgramMode::Committed => BytecodeReadRafSumcheckParams::gen_verifier( + n_cycle_vars, + one_hot_params, + opening_accumulator, + transcript, + ), + // Full mode: verifier materializes/evaluates bytecode-dependent polynomials (O(K_bytecode)). + ProgramMode::Full => BytecodeReadRafSumcheckParams::gen( + program.ok_or_else(|| { + ProofVerifyError::BytecodeTypeMismatch( + "expected Full bytecode preprocessing, got Committed".to_string(), + ) + })?, + n_cycle_vars, + one_hot_params, + opening_accumulator, + transcript, + ), + }; + params.use_staged_val_claims = program_mode == ProgramMode::Committed; + Ok(Self { params }) + } + + /// Consume this verifier and return the underlying parameters (for Option B orchestration). + pub fn into_params(self) -> BytecodeReadRafSumcheckParams { + self.params + } + + pub fn into_cycle_verifier(self) -> BytecodeReadRafCycleSumcheckVerifier { + BytecodeReadRafCycleSumcheckVerifier { + params: self.params, + } + } +} + +impl SumcheckInstanceVerifier + for BytecodeReadRafAddressSumcheckVerifier +{ + fn degree(&self) -> usize { + self.params.degree() + } + + fn num_rounds(&self) -> usize { + self.params.log_K + } + + fn input_claim(&self, accumulator: &VerifierOpeningAccumulator) -> F { + self.params.input_claim(accumulator) + } + + fn expected_output_claim( + &self, + accumulator: &VerifierOpeningAccumulator, + _sumcheck_challenges: &[F::Challenge], + ) -> F { + accumulator + .get_virtual_polynomial_opening( + VirtualPolynomial::BytecodeReadRafAddrClaim, + SumcheckId::BytecodeReadRafAddressPhase, + ) + .1 + } + + fn cache_openings( + &self, + accumulator: &mut VerifierOpeningAccumulator, + transcript: &mut T, + sumcheck_challenges: &[F::Challenge], + ) { + let mut r_address = sumcheck_challenges.to_vec(); + r_address.reverse(); + let opening_point = OpeningPoint::::new(r_address); + accumulator.append_virtual( + transcript, + VirtualPolynomial::BytecodeReadRafAddrClaim, + SumcheckId::BytecodeReadRafAddressPhase, + opening_point.clone(), + ); + + // Populate opening points for the Val-only bytecode stage claims emitted in Stage 6a, + // but only when the staged-Val/claim-reduction path is enabled. + if self.params.use_staged_val_claims { + for stage in 0..N_STAGES { + accumulator.append_virtual( + transcript, + VirtualPolynomial::BytecodeValStage(stage), + SumcheckId::BytecodeReadRafAddressPhase, + opening_point.clone(), + ); + } + } + } +} + +pub struct BytecodeReadRafCycleSumcheckVerifier { + params: BytecodeReadRafSumcheckParams, +} + +impl BytecodeReadRafCycleSumcheckVerifier { + pub fn new(params: BytecodeReadRafSumcheckParams) -> Self { + Self { params } + } +} + +impl SumcheckInstanceVerifier + for BytecodeReadRafCycleSumcheckVerifier +{ + fn degree(&self) -> usize { + self.params.degree() + } + + fn num_rounds(&self) -> usize { + self.params.log_T + } + + fn input_claim(&self, accumulator: &VerifierOpeningAccumulator) -> F { + accumulator + .get_virtual_polynomial_opening( + VirtualPolynomial::BytecodeReadRafAddrClaim, + SumcheckId::BytecodeReadRafAddressPhase, + ) + .1 + } + + fn expected_output_claim( + &self, + accumulator: &VerifierOpeningAccumulator, + sumcheck_challenges: &[F::Challenge], + ) -> F { + let (r_address_point, _) = accumulator.get_virtual_polynomial_opening( + VirtualPolynomial::BytecodeReadRafAddrClaim, + SumcheckId::BytecodeReadRafAddressPhase, + ); + let mut r_address_le = r_address_point.r; + r_address_le.reverse(); + let mut full_challenges = r_address_le; + full_challenges.extend_from_slice(sumcheck_challenges); + let opening_point = self.params.normalize_opening_point(&full_challenges); + let (r_address_prime, r_cycle_prime) = opening_point.split_at(self.params.log_K); + + let int_poly = self.params.int_poly.evaluate(&r_address_prime.r); + + let ra_claims = (0..self.params.d).map(|i| { + accumulator + .get_committed_polynomial_opening( + CommittedPolynomial::BytecodeRa(i), + SumcheckId::BytecodeReadRaf, + ) + .1 + }); + + let int_terms = [ + int_poly * self.params.gamma_powers[5], // RAF for Stage1 + F::zero(), // There's no raf for Stage2 + int_poly * self.params.gamma_powers[4], // RAF for Stage3 + F::zero(), // There's no raf for Stage4 + F::zero(), // There's no raf for Stage5 + ]; + let val = if self.params.use_staged_val_claims { + // Fast verifier path: consume Val_s(r_bc) claims emitted at the Stage 6a boundary, + // rather than re-evaluating `val_polys` (O(K_bytecode)). + (0..N_STAGES) + .zip(self.params.r_cycles.iter()) + .zip(self.params.gamma_powers.iter()) + .zip(int_terms) + .map(|(((stage, r_cycle), gamma), int_term)| { + let val_claim = accumulator + .get_virtual_polynomial_opening( + VirtualPolynomial::BytecodeValStage(stage), + SumcheckId::BytecodeReadRafAddressPhase, + ) + .1; + (val_claim + int_term) + * EqPolynomial::::mle(r_cycle, &r_cycle_prime.r) + * *gamma + }) + .sum::() + } else { + // Legacy verifier path: directly evaluate Val polynomials at r_bc (O(K_bytecode)). + self.params + .val_polys + .iter() + .zip(&self.params.r_cycles) + .zip(&self.params.gamma_powers) + .zip(int_terms) + .map(|(((val, r_cycle), gamma), int_term)| { + (val.evaluate(&r_address_prime.r) + int_term) + * EqPolynomial::::mle(r_cycle, &r_cycle_prime.r) + * *gamma + }) + .sum::() + }; + + ra_claims.fold(val, |running, ra_claim| running * ra_claim) + } + + fn cache_openings( + &self, + accumulator: &mut VerifierOpeningAccumulator, + transcript: &mut T, + sumcheck_challenges: &[F::Challenge], + ) { + let (r_address_point, _) = accumulator.get_virtual_polynomial_opening( + VirtualPolynomial::BytecodeReadRafAddrClaim, + SumcheckId::BytecodeReadRafAddressPhase, + ); + let mut r_address_le = r_address_point.r; + r_address_le.reverse(); + let mut full_challenges = r_address_le; + full_challenges.extend_from_slice(sumcheck_challenges); + let opening_point = self.params.normalize_opening_point(&full_challenges); + let (r_address, r_cycle) = opening_point.split_at(self.params.log_K); + + let r_address_chunks = self + .params + .one_hot_params + .compute_r_address_chunks::(&r_address.r); + + (0..self.params.d).for_each(|i| { + let opening_point = [&r_address_chunks[i][..], &r_cycle.r].concat(); + accumulator.append_sparse( + transcript, + vec![CommittedPolynomial::BytecodeRa(i)], + SumcheckId::BytecodeReadRaf, + opening_point, + ); + }); + } +} + #[derive(Allocative, Clone)] pub struct BytecodeReadRafSumcheckParams { /// Index `i` stores `gamma^i`. @@ -708,6 +1528,9 @@ pub struct BytecodeReadRafSumcheckParams { /// log2(K) and log2(T) used to determine round counts. pub log_K: usize, pub log_T: usize, + /// If true, Stage 6a emits `Val_s(r_bc)` as virtual openings and Stage 6b consumes them + /// (instead of verifier re-materializing/evaluating `val_polys`). + pub use_staged_val_claims: bool, /// Number of address chunks (and RA polynomials in the product). pub d: usize, /// Stage Val polynomials evaluated over address vars. @@ -719,20 +1542,62 @@ pub struct BytecodeReadRafSumcheckParams { /// Identity polynomial over address vars used to inject RAF contributions. pub int_poly: IdentityPolynomial, pub r_cycles: [Vec; N_STAGES], + /// Stage-specific batching gammas used to define Val(k) polynomials. + /// Stored so later claim reductions can reconstruct lane weights without resampling the transcript. + pub stage1_gammas: Vec, + pub stage2_gammas: Vec, + pub stage3_gammas: Vec, + pub stage4_gammas: Vec, + pub stage5_gammas: Vec, } impl BytecodeReadRafSumcheckParams { #[tracing::instrument(skip_all, name = "BytecodeReadRafSumcheckParams::gen")] pub fn gen( - bytecode_preprocessing: &BytecodePreprocessing, + program: &ProgramPreprocessing, n_cycle_vars: usize, one_hot_params: &OneHotParams, opening_accumulator: &dyn OpeningAccumulator, transcript: &mut impl Transcript, ) -> Self { - let gamma_powers = transcript.challenge_scalar_powers(7); + Self::gen_impl( + Some(program), + n_cycle_vars, + one_hot_params, + opening_accumulator, + transcript, + true, + ) + } - let bytecode = &bytecode_preprocessing.bytecode; + /// Verifier-side generator: avoids materializing Val(k) polynomials (O(K_bytecode)). + #[tracing::instrument(skip_all, name = "BytecodeReadRafSumcheckParams::gen_verifier")] + pub fn gen_verifier( + n_cycle_vars: usize, + one_hot_params: &OneHotParams, + opening_accumulator: &dyn OpeningAccumulator, + transcript: &mut impl Transcript, + ) -> Self { + Self::gen_impl( + None, + n_cycle_vars, + one_hot_params, + opening_accumulator, + transcript, + false, + ) + } + + #[allow(clippy::too_many_arguments)] + fn gen_impl( + program: Option<&ProgramPreprocessing>, + n_cycle_vars: usize, + one_hot_params: &OneHotParams, + opening_accumulator: &dyn OpeningAccumulator, + transcript: &mut impl Transcript, + compute_val_polys: bool, + ) -> Self { + let gamma_powers = transcript.challenge_scalar_powers(7); // Generate all stage-specific gamma powers upfront (order must match verifier) let stage1_gammas: Vec = transcript.challenge_scalar_powers(2 + NUM_CIRCUIT_FLAGS); @@ -749,38 +1614,46 @@ impl BytecodeReadRafSumcheckParams { let rv_claim_5 = Self::compute_rv_claim_5(opening_accumulator, &stage5_gammas); let rv_claims = [rv_claim_1, rv_claim_2, rv_claim_3, rv_claim_4, rv_claim_5]; - // Pre-compute eq_r_register for stages 4 and 5 (they use different r_register points) - let r_register_4 = opening_accumulator - .get_virtual_polynomial_opening( - VirtualPolynomial::RdWa, - SumcheckId::RegistersReadWriteChecking, - ) - .0 - .r; - let eq_r_register_4 = - EqPolynomial::::evals(&r_register_4[..(REGISTER_COUNT as usize).log_2()]); - - let r_register_5 = opening_accumulator - .get_virtual_polynomial_opening( - VirtualPolynomial::RdWa, - SumcheckId::RegistersValEvaluation, + let val_polys = if compute_val_polys { + let instructions = &program + .expect("compute_val_polys requires program preprocessing") + .instructions; + // Pre-compute eq_r_register for stages 4 and 5 (they use different r_register points) + let r_register_4 = opening_accumulator + .get_virtual_polynomial_opening( + VirtualPolynomial::RdWa, + SumcheckId::RegistersReadWriteChecking, + ) + .0 + .r; + let eq_r_register_4 = + EqPolynomial::::evals(&r_register_4[..(REGISTER_COUNT as usize).log_2()]); + + let r_register_5 = opening_accumulator + .get_virtual_polynomial_opening( + VirtualPolynomial::RdWa, + SumcheckId::RegistersValEvaluation, + ) + .0 + .r; + let eq_r_register_5 = + EqPolynomial::::evals(&r_register_5[..(REGISTER_COUNT as usize).log_2()]); + + // Fused pass: compute all val polynomials in a single parallel iteration + Self::compute_val_polys( + instructions, + &eq_r_register_4, + &eq_r_register_5, + &stage1_gammas, + &stage2_gammas, + &stage3_gammas, + &stage4_gammas, + &stage5_gammas, ) - .0 - .r; - let eq_r_register_5 = - EqPolynomial::::evals(&r_register_5[..(REGISTER_COUNT as usize).log_2()]); - - // Fused pass: compute all val polynomials in a single parallel iteration - let val_polys = Self::compute_val_polys( - bytecode, - &eq_r_register_4, - &eq_r_register_5, - &stage1_gammas, - &stage2_gammas, - &stage3_gammas, - &stage4_gammas, - &stage5_gammas, - ); + } else { + // Verifier doesn't need these (and must not iterate over bytecode). + array::from_fn(|_| MultilinearPolynomial::default()) + }; let int_poly = IdentityPolynomial::new(one_hot_params.bytecode_k.log_2()); @@ -840,12 +1713,18 @@ impl BytecodeReadRafSumcheckParams { log_K: one_hot_params.bytecode_k.log_2(), d: one_hot_params.bytecode_d, log_T: n_cycle_vars, + use_staged_val_claims: false, val_polys, rv_claims, raf_claim, raf_shift_claim, int_poly, r_cycles, + stage1_gammas, + stage2_gammas, + stage3_gammas, + stage4_gammas, + stage5_gammas, } } diff --git a/jolt-core/src/zkvm/claim_reductions/advice.rs b/jolt-core/src/zkvm/claim_reductions/advice.rs index 6ec3ddd049..8bbc7de8fa 100644 --- a/jolt-core/src/zkvm/claim_reductions/advice.rs +++ b/jolt-core/src/zkvm/claim_reductions/advice.rs @@ -138,7 +138,8 @@ impl AdviceClaimReductionParams { let log_t = trace_len.log_2(); let log_k_chunk = OneHotConfig::new(log_t).log_k_chunk as usize; - let (main_col_vars, main_row_vars) = DoryGlobals::main_sigma_nu(log_k_chunk, log_t); + let (main_col_vars, main_row_vars) = DoryGlobals::try_get_main_sigma_nu() + .unwrap_or_else(|| DoryGlobals::main_sigma_nu(log_k_chunk, log_t)); let r_val_eval = accumulator .get_advice_opening(kind, SumcheckId::RamValEvaluation) @@ -510,11 +511,8 @@ impl SumcheckInstanceProver for AdviceClaimRe fn round_offset(&self, max_num_rounds: usize) -> usize { match self.params.phase { ReductionPhase::CycleVariables => { - // Align to the *start* of Booleanity's cycle segment, so local rounds correspond - // to low Dory column bits in the unified point ordering. - let booleanity_rounds = self.params.log_k_chunk + self.params.log_t; - let booleanity_offset = max_num_rounds - booleanity_rounds; - booleanity_offset + self.params.log_k_chunk + // Stage 6b only spans cycle variables; align to the start of the cycle segment. + max_num_rounds.saturating_sub(self.params.log_t) } ReductionPhase::AddressVariables => 0, } @@ -656,11 +654,7 @@ impl SumcheckInstanceVerifier fn round_offset(&self, max_num_rounds: usize) -> usize { let params = self.params.borrow(); match params.phase { - ReductionPhase::CycleVariables => { - let booleanity_rounds = params.log_k_chunk + params.log_t; - let booleanity_offset = max_num_rounds - booleanity_rounds; - booleanity_offset + params.log_k_chunk - } + ReductionPhase::CycleVariables => max_num_rounds.saturating_sub(params.log_t), ReductionPhase::AddressVariables => 0, } } diff --git a/jolt-core/src/zkvm/claim_reductions/bytecode.rs b/jolt-core/src/zkvm/claim_reductions/bytecode.rs new file mode 100644 index 0000000000..792cc3354a --- /dev/null +++ b/jolt-core/src/zkvm/claim_reductions/bytecode.rs @@ -0,0 +1,742 @@ +//! Two-phase Bytecode claim reduction (Stage 6b cycle → Stage 7 lane/address). +//! +//! This reduction batches the 5 bytecode Val-stage claims emitted at the Stage 6a boundary: +//! `Val_s(r_bc)` for `s = 0..5` (val-only; RAF terms excluded). +//! +//! High level: +//! - Sample `η` and form `C_in = Σ_s η^s · Val_s(r_bc)`. +//! - Define a canonical set of bytecode "lanes" (448 total) and a lane weight function +//! `W_η(lane) = Σ_s η^s · w_s(lane)` derived from the same stage-specific gammas used to +//! define `Val_s`. +//! - Prove, via a two-phase sumcheck, that `C_in` equals a single linear functional of the +//! (eventual) committed bytecode chunk polynomials. +//! +//! NOTE: This module wires the reduction logic and emits openings for bytecode chunk polynomials. +//! Commitment + Stage 8 batching integration is handled separately (see `bytecode-commitment-progress.md`). + +use std::cell::RefCell; +use std::sync::atomic::{AtomicUsize, Ordering}; +use std::sync::Arc; + +use allocative::Allocative; +use itertools::Itertools; +use rayon::prelude::*; + +use crate::field::JoltField; +use crate::poly::eq_poly::EqPolynomial; +use crate::poly::multilinear_polynomial::{BindingOrder, MultilinearPolynomial, PolynomialBinding}; +use crate::poly::opening_proof::{ + OpeningAccumulator, OpeningPoint, ProverOpeningAccumulator, SumcheckId, + VerifierOpeningAccumulator, BIG_ENDIAN, LITTLE_ENDIAN, +}; +use crate::poly::unipoly::UniPoly; +use crate::subprotocols::sumcheck_prover::SumcheckInstanceProver; +use crate::subprotocols::sumcheck_verifier::{SumcheckInstanceParams, SumcheckInstanceVerifier}; +use crate::transcripts::Transcript; +use crate::utils::math::Math; +use crate::zkvm::bytecode::chunks::{ + for_each_active_lane_value, total_lanes, weighted_lane_sum_for_instruction, ActiveLaneValue, +}; +use crate::zkvm::bytecode::read_raf_checking::BytecodeReadRafSumcheckParams; +use crate::zkvm::instruction::{ + CircuitFlags, InstructionFlags, NUM_CIRCUIT_FLAGS, NUM_INSTRUCTION_FLAGS, +}; +use crate::zkvm::lookup_table::LookupTables; +use crate::zkvm::program::ProgramPreprocessing; +use crate::zkvm::witness::{CommittedPolynomial, VirtualPolynomial}; +use common::constants::{REGISTER_COUNT, XLEN}; +use strum::EnumCount; + +const DEGREE_BOUND: usize = 2; +const NUM_VAL_STAGES: usize = 5; + +/// For `DoryLayout::AddressMajor`, committed bytecode chunks are stored in "cycle-major" index order +/// (cycle*K + address), which makes `BindingOrder::LowToHigh` bind **lane** bits first. +/// +/// The claim reduction sumcheck needs to bind **cycle** bits first in Stage 6b, so we permute +/// dense coefficient vectors into the `DoryLayout::CycleMajor` order (address*T + cycle) when +/// running the reduction. This is a pure index permutation, i.e. a variable renaming, and the +/// resulting evaluations match the committed polynomial when the opening point is interpreted in +/// the unified `[lane || cycle]` order. +// NOTE: With the fused-lane cycle-phase refactor, we no longer materialize the full per-lane +// bytecode chunk polynomials inside this reduction prover. This means we also no longer need +// to permute AddressMajor <-> CycleMajor coefficient vectors here. + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Allocative)] +pub enum BytecodeReductionPhase { + CycleVariables, + LaneVariables, +} + +#[derive(Clone, Allocative)] +pub struct BytecodeClaimReductionParams { + pub phase: BytecodeReductionPhase, + pub eta: F, + pub eta_powers: [F; NUM_VAL_STAGES], + pub log_k: usize, + pub log_k_chunk: usize, + pub num_chunks: usize, + /// Bytecode address point (log_K bits, big-endian). + pub r_bc: OpeningPoint, + /// Per-chunk lane weight tables (length = k_chunk) for `W_eta`. + pub chunk_lane_weights: Vec>, + /// (little-endian) challenges used in the cycle phase. + pub cycle_var_challenges: Vec, +} + +impl BytecodeClaimReductionParams { + pub fn new( + bytecode_read_raf_params: &BytecodeReadRafSumcheckParams, + accumulator: &dyn OpeningAccumulator, + transcript: &mut impl Transcript, + ) -> Self { + let log_k = bytecode_read_raf_params.log_K; + + let eta: F = transcript.challenge_scalar(); + let mut eta_powers = [F::one(); NUM_VAL_STAGES]; + for i in 1..NUM_VAL_STAGES { + eta_powers[i] = eta_powers[i - 1] * eta; + } + + // r_bc comes from the Stage 6a BytecodeReadRaf address phase. + let (r_bc, _) = accumulator.get_virtual_polynomial_opening( + VirtualPolynomial::BytecodeReadRafAddrClaim, + SumcheckId::BytecodeReadRafAddressPhase, + ); + + let log_k_chunk = bytecode_read_raf_params.one_hot_params.log_k_chunk; + let k_chunk = 1 << log_k_chunk; + let num_chunks = total_lanes().div_ceil(k_chunk); + + let chunk_lane_weights = compute_chunk_lane_weights( + bytecode_read_raf_params, + accumulator, + &eta_powers, + num_chunks, + k_chunk, + ); + + Self { + phase: BytecodeReductionPhase::CycleVariables, + eta, + eta_powers, + log_k, + log_k_chunk, + num_chunks, + r_bc, + chunk_lane_weights, + cycle_var_challenges: vec![], + } + } +} + +impl SumcheckInstanceParams for BytecodeClaimReductionParams { + fn input_claim(&self, accumulator: &dyn OpeningAccumulator) -> F { + match self.phase { + BytecodeReductionPhase::CycleVariables => (0..NUM_VAL_STAGES) + .map(|stage| { + let (_, val_claim) = accumulator.get_virtual_polynomial_opening( + VirtualPolynomial::BytecodeValStage(stage), + SumcheckId::BytecodeReadRafAddressPhase, + ); + self.eta_powers[stage] * val_claim + }) + .sum(), + BytecodeReductionPhase::LaneVariables => { + accumulator + .get_virtual_polynomial_opening( + VirtualPolynomial::BytecodeClaimReductionIntermediate, + SumcheckId::BytecodeClaimReductionCyclePhase, + ) + .1 + } + } + } + + fn degree(&self) -> usize { + DEGREE_BOUND + } + + fn num_rounds(&self) -> usize { + match self.phase { + BytecodeReductionPhase::CycleVariables => self.log_k, + BytecodeReductionPhase::LaneVariables => self.log_k_chunk, + } + } + + fn normalize_opening_point( + &self, + challenges: &[::Challenge], + ) -> OpeningPoint { + match self.phase { + BytecodeReductionPhase::CycleVariables => { + OpeningPoint::::new(challenges.to_vec()).match_endianness() + } + BytecodeReductionPhase::LaneVariables => { + // Full point: [lane || cycle] in big-endian. + let full_le: Vec = + [self.cycle_var_challenges.as_slice(), challenges].concat(); + OpeningPoint::::new(full_le).match_endianness() + } + } + } +} + +#[derive(Allocative)] +pub struct BytecodeClaimReductionProver { + pub params: BytecodeClaimReductionParams, + /// Program instructions (padded to power-of-2). Used for a fast first round. + #[allocative(skip)] + program: Arc, + /// Cycle-only polynomial: + /// \( S(k) = \sum_{\ell} W_{\eta}(\ell) \cdot lane\_value(\ell, instr[k]) \). + /// + /// This matches the GPU implementation's "main polynomial" strategy: during the cycle-phase + /// sumcheck we only need the **lane-summed** polynomial over the cycle domain (size K), + /// rather than all 448 lane polynomials. + cycle_weighted_sum: MultilinearPolynomial, + /// Lane-only chunk polynomials after evaluating cycle vars at `r_cycle`: + /// \( B_i(\cdot, r\_cycle) \) for each chunk i. + /// + /// This is computed once at the Stage 6b → Stage 7 transition and is only + /// `num_chunks * k_chunk` field elements (≤ 448 total, padded). + lane_chunks_at_r_cycle: Vec>, + /// Eq table/polynomial over the bytecode address point `r_bc` (cycle variables only). + eq_r_bc: MultilinearPolynomial, + /// Lane-weight polynomials over the lane variables only (one per chunk). + lane_weight_polys: Vec>, + /// Batched-sumcheck scaling for trailing dummy rounds (see `round_offset`). + #[allocative(skip)] + batch_dummy_rounds: AtomicUsize, +} + +impl BytecodeClaimReductionProver { + #[tracing::instrument(skip_all, name = "BytecodeClaimReductionProver::initialize")] + pub fn initialize( + params: BytecodeClaimReductionParams, + program: Arc, + ) -> Self { + let log_k = params.log_k; + let t_size = 1 << log_k; + let k_chunk = 1 << params.log_k_chunk; + + // Eq table over the bytecode address point. + let eq_r_bc = EqPolynomial::::evals(¶ms.r_bc.r); + debug_assert_eq!(eq_r_bc.len(), t_size); + + // Keep eq table as a polynomial so we can bind it during the cycle phase. + let eq_r_bc = MultilinearPolynomial::from(eq_r_bc); + + // Lane-weight polynomials (lane vars only) used in the lane phase. + let lane_weight_polys: Vec> = params + .chunk_lane_weights + .iter() + .map(|w| MultilinearPolynomial::from(w.clone())) + .collect(); + + // Build the fused-lane cycle polynomial S(k) over the cycle domain only. + let bytecode_len = program.bytecode_len(); + debug_assert_eq!(bytecode_len, t_size); + let total = total_lanes(); + let mut lane_weights_global = vec![F::zero(); total]; + for global_lane in 0..total { + let chunk_idx = global_lane / k_chunk; + let lane = global_lane % k_chunk; + lane_weights_global[global_lane] = params.chunk_lane_weights[chunk_idx][lane]; + } + let cycle_weighted_evals: Vec = program + .instructions + .par_iter() + .map(|instr| weighted_lane_sum_for_instruction(&lane_weights_global, instr)) + .collect(); + debug_assert_eq!(cycle_weighted_evals.len(), t_size); + let cycle_weighted_sum = MultilinearPolynomial::from(cycle_weighted_evals); + + Self { + params, + program, + cycle_weighted_sum, + lane_chunks_at_r_cycle: vec![], + eq_r_bc, + lane_weight_polys, + batch_dummy_rounds: AtomicUsize::new(0), + } + } + + /// Prepare the lane-phase witness polynomials \(B_i(\cdot, r_{cycle})\). + /// + /// This is intended to be called once after the cycle-phase sumcheck has finished + /// (i.e. after all `log_K` cycle challenges are known) and before we transition + /// `params.phase` to [`BytecodeReductionPhase::LaneVariables`]. + #[tracing::instrument(skip_all, name = "BytecodeClaimReductionProver::prepare_lane_phase")] + pub fn prepare_lane_phase(&mut self) { + if !self.lane_chunks_at_r_cycle.is_empty() { + return; + } + + let log_k = self.params.log_k; + let k_chunk = 1usize << self.params.log_k_chunk; + let num_chunks = self.params.num_chunks; + let total = total_lanes(); + + assert_eq!( + self.params.cycle_var_challenges.len(), + log_k, + "prepare_lane_phase called before cycle challenges are complete (have {}, expected {})", + self.params.cycle_var_challenges.len(), + log_k + ); + + // Convert the stored LE (LSB-first) cycle challenges into BE (MSB-first) order + // for EqPolynomial::evals, which uses big-endian indexing. + let r_cycle_be: OpeningPoint = + OpeningPoint::::new(self.params.cycle_var_challenges.clone()) + .match_endianness(); + + let eq_cycle = EqPolynomial::::evals(&r_cycle_be.r); + debug_assert_eq!(eq_cycle.len(), self.program.instructions.len()); + + // b_vals[global_lane] = Σ_k eq(r_cycle, k) * lane_value(global_lane, instr[k]) + let b_vals: Vec = self + .program + .instructions + .par_iter() + .zip(eq_cycle.par_iter()) + .fold( + || vec![F::zero(); total], + |mut acc, (instr, eq_k)| { + for_each_active_lane_value::(instr, |lane, v| match v { + ActiveLaneValue::One => { + acc[lane] += *eq_k; + } + ActiveLaneValue::Scalar(s) => { + acc[lane] += *eq_k * s; + } + }); + acc + }, + ) + .reduce( + || vec![F::zero(); total], + |mut a, b| { + a.iter_mut().zip(b.iter()).for_each(|(x, y)| *x += *y); + a + }, + ); + + // Chunk b_vals into `num_chunks` lane polynomials of length k_chunk. + self.lane_chunks_at_r_cycle = (0..num_chunks) + .map(|chunk_idx| { + let mut coeffs = vec![F::zero(); k_chunk]; + for lane in 0..k_chunk { + let global_lane = chunk_idx * k_chunk + lane; + if global_lane < total { + coeffs[lane] = b_vals[global_lane]; + } + } + MultilinearPolynomial::from(coeffs) + }) + .collect(); + } + + fn compute_message_impl(&self, _round: usize, previous_claim: F) -> UniPoly { + let mut evals: [F; DEGREE_BOUND] = match self.params.phase { + BytecodeReductionPhase::CycleVariables => { + let t_size = self.eq_r_bc.len(); + debug_assert_eq!(t_size, self.cycle_weighted_sum.len()); + debug_assert!(t_size.is_power_of_two()); + + let eq_evals: &[F] = match &self.eq_r_bc { + MultilinearPolynomial::LargeScalars(p) => &p.Z, + _ => unreachable!("EqPolynomial::evals produces a dense field polynomial"), + }; + let s_evals: &[F] = match &self.cycle_weighted_sum { + MultilinearPolynomial::LargeScalars(p) => &p.Z, + _ => unreachable!("cycle_weighted_sum is a dense field polynomial"), + }; + + // Round univariate is over the current LSB of the (remaining) cycle domain. + let num_pairs = t_size / 2; + let (h0_sum, h2_sum) = (0..num_pairs) + .into_par_iter() + .map(|j| { + let k0 = 2 * j; + let k1 = k0 + 1; + let s0 = s_evals[k0]; + let s1 = s_evals[k1]; + let e0 = eq_evals[k0]; + let e1 = eq_evals[k1]; + + let h0 = s0 * e0; + let s2 = (s1 + s1) - s0; + let e2 = (e1 + e1) - e0; + let h2 = s2 * e2; + (h0, h2) + }) + .reduce( + || (F::zero(), F::zero()), + |(a0, a1), (b0, b1)| (a0 + b0, a1 + b1), + ); + + [h0_sum, h2_sum] + } + BytecodeReductionPhase::LaneVariables => { + let eq_eval = self.eq_r_bc.get_bound_coeff(0); + assert!( + !self.lane_chunks_at_r_cycle.is_empty(), + "lane-phase invoked before prepare_lane_phase()" + ); + let half = self.lane_chunks_at_r_cycle[0].len() / 2; + (0..half) + .into_par_iter() + .map(|j| { + let mut out = [F::zero(); DEGREE_BOUND]; + for (chunk_idx, b) in self.lane_chunks_at_r_cycle.iter().enumerate() { + let b_evals = + b.sumcheck_evals_array::(j, BindingOrder::LowToHigh); + let lw_evals = self.lane_weight_polys[chunk_idx] + .sumcheck_evals_array::(j, BindingOrder::LowToHigh); + out[0] += b_evals[0] * (lw_evals[0] * eq_eval); + out[1] += b_evals[1] * (lw_evals[1] * eq_eval); + } + out + }) + .reduce( + || [F::zero(); DEGREE_BOUND], + |mut acc, arr| { + acc.iter_mut().zip(arr.iter()).for_each(|(a, b)| *a += *b); + acc + }, + ) + } + }; + + // If this instance is back-loaded in a batched sumcheck (i.e., it has trailing dummy + // rounds), then `previous_claim` is scaled by 2^{dummy_rounds}. The per-round univariate + // evaluations must be scaled by the same factor to satisfy the sumcheck consistency check. + let dummy_rounds = self.batch_dummy_rounds.load(Ordering::Relaxed); + if dummy_rounds != 0 { + let scale = F::one().mul_pow_2(dummy_rounds); + for e in evals.iter_mut() { + *e *= scale; + } + } + UniPoly::from_evals_and_hint(previous_claim, &evals) + } +} + +impl SumcheckInstanceProver for BytecodeClaimReductionProver { + fn get_params(&self) -> &dyn SumcheckInstanceParams { + &self.params + } + + fn round_offset(&self, max_num_rounds: usize) -> usize { + // Bytecode claim reduction's cycle-phase rounds must align to the *start* of the + // batched cycle challenge vector so that its (log_K) point is the suffix (LSB side) + // of the full (log_T) cycle point used by other Stage 6b instances. This is required + // for Stage 8's committed-bytecode embedding when log_T > log_K. + // + // This deviates from the default "front-loaded" batching offset, so we record the number + // of trailing dummy rounds and scale univariate evaluations accordingly. + let dummy_rounds = max_num_rounds.saturating_sub(self.params.num_rounds()); + self.batch_dummy_rounds + .store(dummy_rounds, Ordering::Relaxed); + 0 + } + + #[tracing::instrument(skip_all, name = "BytecodeClaimReductionProver::compute_message")] + fn compute_message(&mut self, _round: usize, previous_claim: F) -> UniPoly { + self.compute_message_impl(_round, previous_claim) + } + + #[tracing::instrument(skip_all, name = "BytecodeClaimReductionProver::ingest_challenge")] + fn ingest_challenge(&mut self, r_j: F::Challenge, _round: usize) { + if self.params.phase == BytecodeReductionPhase::CycleVariables { + self.params.cycle_var_challenges.push(r_j); + self.eq_r_bc.bind_parallel(r_j, BindingOrder::LowToHigh); + self.cycle_weighted_sum + .bind_parallel(r_j, BindingOrder::LowToHigh); + } + if self.params.phase == BytecodeReductionPhase::LaneVariables { + self.lane_weight_polys + .iter_mut() + .for_each(|p| p.bind_parallel(r_j, BindingOrder::LowToHigh)); + self.lane_chunks_at_r_cycle + .iter_mut() + .for_each(|p| p.bind_parallel(r_j, BindingOrder::LowToHigh)); + } + } + + fn cache_openings( + &self, + accumulator: &mut ProverOpeningAccumulator, + transcript: &mut T, + sumcheck_challenges: &[F::Challenge], + ) { + match self.params.phase { + BytecodeReductionPhase::CycleVariables => { + // Cache intermediate claim for Stage 7. + let opening_point = self.params.normalize_opening_point(sumcheck_challenges); + + let eq_eval = self.eq_r_bc.get_bound_coeff(0); + let s_eval = self.cycle_weighted_sum.get_bound_coeff(0); + let sum = s_eval * eq_eval; + + accumulator.append_virtual( + transcript, + VirtualPolynomial::BytecodeClaimReductionIntermediate, + SumcheckId::BytecodeClaimReductionCyclePhase, + opening_point, + sum, + ); + } + BytecodeReductionPhase::LaneVariables => { + // Cache final openings of the bytecode chunk polynomials at the full point. + let opening_point = self.params.normalize_opening_point(sumcheck_challenges); + let (r_lane, r_cycle) = opening_point.split_at(self.params.log_k_chunk); + + let polynomial_types: Vec = (0..self.params.num_chunks) + .map(CommittedPolynomial::BytecodeChunk) + .collect(); + let claims: Vec = self + .lane_chunks_at_r_cycle + .iter() + .map(|p| p.final_sumcheck_claim()) + .collect(); + + accumulator.append_sparse( + transcript, + polynomial_types, + SumcheckId::BytecodeClaimReduction, + r_lane.r, + r_cycle.r, + claims, + ); + } + } + } + + #[cfg(feature = "allocative")] + fn update_flamegraph(&self, flamegraph: &mut allocative::FlameGraphBuilder) { + flamegraph.visit_root(self); + } +} + +pub struct BytecodeClaimReductionVerifier { + pub params: RefCell>, +} + +impl BytecodeClaimReductionVerifier { + pub fn new(params: BytecodeClaimReductionParams) -> Self { + Self { + params: RefCell::new(params), + } + } +} + +impl SumcheckInstanceVerifier + for BytecodeClaimReductionVerifier +{ + fn get_params(&self) -> &dyn SumcheckInstanceParams { + unsafe { &*self.params.as_ptr() } + } + + fn round_offset(&self, _max_num_rounds: usize) -> usize { + // Must mirror the prover: align this instance to the start of the batched challenge vector. + 0 + } + + fn expected_output_claim( + &self, + accumulator: &VerifierOpeningAccumulator, + sumcheck_challenges: &[F::Challenge], + ) -> F { + let params = self.params.borrow(); + match params.phase { + BytecodeReductionPhase::CycleVariables => { + accumulator + .get_virtual_polynomial_opening( + VirtualPolynomial::BytecodeClaimReductionIntermediate, + SumcheckId::BytecodeClaimReductionCyclePhase, + ) + .1 + } + BytecodeReductionPhase::LaneVariables => { + let opening_point = params.normalize_opening_point(sumcheck_challenges); + let (r_lane, r_cycle) = opening_point.split_at(params.log_k_chunk); + + let eq_eval = EqPolynomial::::mle(&r_cycle.r, ¶ms.r_bc.r); + + // Evaluate each chunk's lane-weight polynomial at r_lane and combine with chunk openings. + let eq_lane = EqPolynomial::::evals(&r_lane.r); + let mut sum = F::zero(); + for chunk_idx in 0..params.num_chunks { + let (_, chunk_opening) = accumulator.get_committed_polynomial_opening( + CommittedPolynomial::BytecodeChunk(chunk_idx), + SumcheckId::BytecodeClaimReduction, + ); + let w_eval: F = params.chunk_lane_weights[chunk_idx] + .iter() + .zip(eq_lane.iter()) + .map(|(w, e)| *w * *e) + .sum(); + sum += chunk_opening * w_eval; + } + + sum * eq_eval + } + } + } + + fn cache_openings( + &self, + accumulator: &mut VerifierOpeningAccumulator, + transcript: &mut T, + sumcheck_challenges: &[F::Challenge], + ) { + let mut params = self.params.borrow_mut(); + match params.phase { + BytecodeReductionPhase::CycleVariables => { + let opening_point = params.normalize_opening_point(sumcheck_challenges); + accumulator.append_virtual( + transcript, + VirtualPolynomial::BytecodeClaimReductionIntermediate, + SumcheckId::BytecodeClaimReductionCyclePhase, + opening_point, + ); + // Record LE challenges for phase 2 normalization. + params.cycle_var_challenges = sumcheck_challenges.to_vec(); + } + BytecodeReductionPhase::LaneVariables => { + let opening_point = params.normalize_opening_point(sumcheck_challenges); + let polynomial_types: Vec = (0..params.num_chunks) + .map(CommittedPolynomial::BytecodeChunk) + .collect(); + accumulator.append_sparse( + transcript, + polynomial_types, + SumcheckId::BytecodeClaimReduction, + opening_point.r, + ); + } + } + } +} + +fn compute_chunk_lane_weights( + bytecode_read_raf_params: &BytecodeReadRafSumcheckParams, + accumulator: &dyn OpeningAccumulator, + eta_powers: &[F; NUM_VAL_STAGES], + num_chunks: usize, + k_chunk: usize, +) -> Vec> { + let reg_count = REGISTER_COUNT as usize; + let total = total_lanes(); + + // Offsets (canonical lane ordering) + let rs1_start = 0usize; + let rs2_start = rs1_start + reg_count; + let rd_start = rs2_start + reg_count; + let unexp_pc_idx = rd_start + reg_count; + let imm_idx = unexp_pc_idx + 1; + let circuit_start = imm_idx + 1; + let instr_start = circuit_start + NUM_CIRCUIT_FLAGS; + let lookup_start = instr_start + NUM_INSTRUCTION_FLAGS; + let raf_flag_idx = lookup_start + LookupTables::::COUNT; + debug_assert_eq!(raf_flag_idx + 1, total); + + // Eq tables for stage4/stage5 register selection weights. + let log_reg = reg_count.log_2(); + let r_register_4 = accumulator + .get_virtual_polynomial_opening( + VirtualPolynomial::RdWa, + SumcheckId::RegistersReadWriteChecking, + ) + .0 + .r; + let eq_r_register_4 = EqPolynomial::::evals(&r_register_4[..log_reg]); + + let r_register_5 = accumulator + .get_virtual_polynomial_opening(VirtualPolynomial::RdWa, SumcheckId::RegistersValEvaluation) + .0 + .r; + let eq_r_register_5 = EqPolynomial::::evals(&r_register_5[..log_reg]); + + let mut weights = vec![F::zero(); total]; + + // Stage 1 + { + let coeff = eta_powers[0]; + let g = &bytecode_read_raf_params.stage1_gammas; + weights[unexp_pc_idx] += coeff * g[0]; + weights[imm_idx] += coeff * g[1]; + for i in 0..NUM_CIRCUIT_FLAGS { + weights[circuit_start + i] += coeff * g[2 + i]; + } + } + + // Stage 2 + { + let coeff = eta_powers[1]; + let g = &bytecode_read_raf_params.stage2_gammas; + weights[circuit_start + (CircuitFlags::Jump as usize)] += coeff * g[0]; + weights[instr_start + (InstructionFlags::Branch as usize)] += coeff * g[1]; + weights[instr_start + (InstructionFlags::IsRdNotZero as usize)] += coeff * g[2]; + weights[circuit_start + (CircuitFlags::WriteLookupOutputToRD as usize)] += coeff * g[3]; + } + + // Stage 3 + { + let coeff = eta_powers[2]; + let g = &bytecode_read_raf_params.stage3_gammas; + weights[imm_idx] += coeff * g[0]; + weights[unexp_pc_idx] += coeff * g[1]; + weights[instr_start + (InstructionFlags::LeftOperandIsRs1Value as usize)] += coeff * g[2]; + weights[instr_start + (InstructionFlags::LeftOperandIsPC as usize)] += coeff * g[3]; + weights[instr_start + (InstructionFlags::RightOperandIsRs2Value as usize)] += coeff * g[4]; + weights[instr_start + (InstructionFlags::RightOperandIsImm as usize)] += coeff * g[5]; + weights[instr_start + (InstructionFlags::IsNoop as usize)] += coeff * g[6]; + weights[circuit_start + (CircuitFlags::VirtualInstruction as usize)] += coeff * g[7]; + weights[circuit_start + (CircuitFlags::IsFirstInSequence as usize)] += coeff * g[8]; + } + + // Stage 4 + { + let coeff = eta_powers[3]; + let g = &bytecode_read_raf_params.stage4_gammas; + for r in 0..reg_count { + weights[rd_start + r] += coeff * g[0] * eq_r_register_4[r]; + weights[rs1_start + r] += coeff * g[1] * eq_r_register_4[r]; + weights[rs2_start + r] += coeff * g[2] * eq_r_register_4[r]; + } + } + + // Stage 5 + { + let coeff = eta_powers[4]; + let g = &bytecode_read_raf_params.stage5_gammas; + for r in 0..reg_count { + weights[rd_start + r] += coeff * g[0] * eq_r_register_5[r]; + } + weights[raf_flag_idx] += coeff * g[1]; + for i in 0..LookupTables::::COUNT { + weights[lookup_start + i] += coeff * g[2 + i]; + } + } + + // Chunk into k_chunk-sized blocks. + (0..num_chunks) + .map(|chunk_idx| { + (0..k_chunk) + .map(|lane| { + let global = chunk_idx * k_chunk + lane; + if global < total { + weights[global] + } else { + F::zero() + } + }) + .collect_vec() + }) + .collect_vec() +} diff --git a/jolt-core/src/zkvm/claim_reductions/hamming_weight.rs b/jolt-core/src/zkvm/claim_reductions/hamming_weight.rs index d40860f35a..8692c03b3a 100644 --- a/jolt-core/src/zkvm/claim_reductions/hamming_weight.rs +++ b/jolt-core/src/zkvm/claim_reductions/hamming_weight.rs @@ -99,6 +99,7 @@ use crate::subprotocols::{ use crate::transcripts::Transcript; use crate::zkvm::{ config::OneHotParams, + program::ProgramPreprocessing, verifier::JoltSharedPreprocessing, witness::{CommittedPolynomial, VirtualPolynomial}, }; @@ -309,13 +310,14 @@ impl HammingWeightClaimReductionProver { params: HammingWeightClaimReductionParams, trace: &[Cycle], preprocessing: &JoltSharedPreprocessing, + program: &ProgramPreprocessing, one_hot_params: &OneHotParams, ) -> Self { // Compute all G_i polynomials via streaming. // `params.r_cycle` is in BIG_ENDIAN (OpeningPoint) convention. let G_vecs = compute_all_G::( trace, - &preprocessing.bytecode, + program, &preprocessing.memory_layout, one_hot_params, ¶ms.r_cycle, diff --git a/jolt-core/src/zkvm/claim_reductions/mod.rs b/jolt-core/src/zkvm/claim_reductions/mod.rs index 5d19f993a1..697342f5d1 100644 --- a/jolt-core/src/zkvm/claim_reductions/mod.rs +++ b/jolt-core/src/zkvm/claim_reductions/mod.rs @@ -1,7 +1,9 @@ pub mod advice; +pub mod bytecode; pub mod hamming_weight; pub mod increments; pub mod instruction_lookups; +pub mod program_image; pub mod ram_ra; pub mod registers; @@ -9,6 +11,10 @@ pub use advice::{ AdviceClaimReductionParams, AdviceClaimReductionProver, AdviceClaimReductionVerifier, AdviceKind, }; +pub use bytecode::{ + BytecodeClaimReductionParams, BytecodeClaimReductionProver, BytecodeClaimReductionVerifier, + BytecodeReductionPhase, +}; pub use hamming_weight::{ HammingWeightClaimReductionParams, HammingWeightClaimReductionProver, HammingWeightClaimReductionVerifier, @@ -21,6 +27,10 @@ pub use instruction_lookups::{ InstructionLookupsClaimReductionSumcheckParams, InstructionLookupsClaimReductionSumcheckProver, InstructionLookupsClaimReductionSumcheckVerifier, }; +pub use program_image::{ + ProgramImageClaimReductionParams, ProgramImageClaimReductionProver, + ProgramImageClaimReductionVerifier, +}; pub use ram_ra::{ RaReductionParams, RamRaClaimReductionSumcheckProver, RamRaClaimReductionSumcheckVerifier, }; diff --git a/jolt-core/src/zkvm/claim_reductions/program_image.rs b/jolt-core/src/zkvm/claim_reductions/program_image.rs new file mode 100644 index 0000000000..16c232231d --- /dev/null +++ b/jolt-core/src/zkvm/claim_reductions/program_image.rs @@ -0,0 +1,468 @@ +//! Program-image (initial RAM) claim reduction. +//! +//! In committed bytecode mode, Stage 4 consumes prover-supplied scalar claims for the +//! program-image contribution to `Val_init(r_address)` without materializing the initial RAM. +//! This sumcheck binds those scalars to a trusted commitment to the program-image words polynomial. + +use allocative::Allocative; +use std::sync::atomic::{AtomicUsize, Ordering}; + +use rayon::prelude::*; + +use crate::field::JoltField; +use crate::poly::eq_poly::EqPolynomial; +use crate::poly::multilinear_polynomial::{BindingOrder, MultilinearPolynomial, PolynomialBinding}; +use crate::poly::opening_proof::{ + OpeningAccumulator, OpeningPoint, ProverOpeningAccumulator, SumcheckId, + VerifierOpeningAccumulator, BIG_ENDIAN, LITTLE_ENDIAN, +}; +use crate::poly::unipoly::UniPoly; +use crate::subprotocols::sumcheck_prover::SumcheckInstanceProver; +use crate::subprotocols::sumcheck_verifier::{SumcheckInstanceParams, SumcheckInstanceVerifier}; +use crate::transcripts::Transcript; +use crate::utils::math::Math; +use crate::zkvm::config::ReadWriteConfig; +use crate::zkvm::ram::remap_address; +use crate::zkvm::witness::{CommittedPolynomial, VirtualPolynomial}; +use tracer::JoltDevice; + +const DEGREE_BOUND: usize = 2; + +#[derive(Clone, Allocative)] +pub struct ProgramImageClaimReductionParams { + pub gamma: F, + pub single_opening: bool, + pub ram_num_vars: usize, + pub start_index: usize, + pub padded_len_words: usize, + pub m: usize, + pub r_addr_rw: Vec, + pub r_addr_raf: Option>, +} + +impl ProgramImageClaimReductionParams { + #[allow(clippy::too_many_arguments)] + pub fn new( + program_io: &JoltDevice, + ram_min_bytecode_address: u64, + padded_len_words: usize, + ram_K: usize, + trace_len: usize, + rw_config: &ReadWriteConfig, + accumulator: &dyn OpeningAccumulator, + transcript: &mut impl Transcript, + ) -> Self { + let ram_num_vars = ram_K.log_2(); + let start_index = + remap_address(ram_min_bytecode_address, &program_io.memory_layout).unwrap() as usize; + let m = padded_len_words.log_2(); + debug_assert!(padded_len_words.is_power_of_two()); + debug_assert!(padded_len_words > 0); + + // r_address_rw comes from RamVal/RamReadWriteChecking (Stage 2). + let (r_rw, _) = accumulator.get_virtual_polynomial_opening( + VirtualPolynomial::RamVal, + SumcheckId::RamReadWriteChecking, + ); + let (r_addr_rw, _) = r_rw.split_at(ram_num_vars); + + // r_address_raf comes from RamValFinal/RamOutputCheck (Stage 2), but may equal r_address_rw. + let log_t = trace_len.log_2(); + let single_opening = rw_config.needs_single_advice_opening(log_t); + let r_addr_raf = if single_opening { + None + } else { + let (r_raf, _) = accumulator.get_virtual_polynomial_opening( + VirtualPolynomial::RamValFinal, + SumcheckId::RamOutputCheck, + ); + let (r_addr_raf, _) = r_raf.split_at(ram_num_vars); + Some(r_addr_raf.r) + }; + + // Sample gamma for combining rw + raf. + let gamma: F = transcript.challenge_scalar(); + + Self { + gamma, + single_opening, + ram_num_vars, + start_index, + padded_len_words, + m, + r_addr_rw: r_addr_rw.r, + r_addr_raf, + } + } +} + +impl SumcheckInstanceParams for ProgramImageClaimReductionParams { + fn input_claim(&self, accumulator: &dyn OpeningAccumulator) -> F { + // Scalar claims were staged in Stage 4 as virtual openings. + let (_, c_rw) = accumulator.get_virtual_polynomial_opening( + VirtualPolynomial::ProgramImageInitContributionRw, + SumcheckId::RamValEvaluation, + ); + if self.single_opening { + c_rw + } else { + let (_, c_raf) = accumulator.get_virtual_polynomial_opening( + VirtualPolynomial::ProgramImageInitContributionRaf, + SumcheckId::RamValFinalEvaluation, + ); + c_rw + self.gamma * c_raf + } + } + + fn degree(&self) -> usize { + DEGREE_BOUND + } + + fn num_rounds(&self) -> usize { + self.m + } + + fn normalize_opening_point( + &self, + challenges: &[::Challenge], + ) -> OpeningPoint { + // Challenges are in little-endian round order (LSB first) when binding LowToHigh. + OpeningPoint::::new(challenges.to_vec()).match_endianness() + } +} + +#[derive(Allocative)] +pub struct ProgramImageClaimReductionProver { + pub params: ProgramImageClaimReductionParams, + program_word: MultilinearPolynomial, + eq_slice: MultilinearPolynomial, + /// Number of trailing dummy rounds in a batched Stage 6b sumcheck. + batch_dummy_rounds: AtomicUsize, +} + +fn build_eq_slice_table( + r_addr: &[F::Challenge], + start_index: usize, + len: usize, +) -> Vec { + debug_assert!(len.is_power_of_two()); + let mut out = Vec::with_capacity(len); + let mut idx = start_index; + let mut off = 0usize; + while off < len { + let remaining = len - off; + let (block_size, block_evals) = + EqPolynomial::::evals_for_max_aligned_block(r_addr, idx, remaining); + out.extend_from_slice(&block_evals); + idx += block_size; + off += block_size; + } + debug_assert_eq!(out.len(), len); + out +} + +impl ProgramImageClaimReductionProver { + #[tracing::instrument(skip_all, name = "ProgramImageClaimReductionProver::initialize")] + pub fn initialize( + params: ProgramImageClaimReductionParams, + program_image_words_padded: Vec, + ) -> Self { + debug_assert_eq!(program_image_words_padded.len(), params.padded_len_words); + debug_assert_eq!(params.padded_len_words, 1usize << params.m); + + let program_word: MultilinearPolynomial = + MultilinearPolynomial::from(program_image_words_padded); + + let eq_rw = build_eq_slice_table::( + ¶ms.r_addr_rw, + params.start_index, + params.padded_len_words, + ); + let mut eq_comb = eq_rw; + if !params.single_opening { + let r_raf = params.r_addr_raf.as_ref().expect("missing raf address"); + let eq_raf = + build_eq_slice_table::(r_raf, params.start_index, params.padded_len_words); + for (c, e) in eq_comb.iter_mut().zip(eq_raf.iter()) { + *c += params.gamma * *e; + } + } + let eq_slice: MultilinearPolynomial = MultilinearPolynomial::from(eq_comb); + + Self { + params, + program_word, + eq_slice, + batch_dummy_rounds: AtomicUsize::new(0), + } + } +} + +impl SumcheckInstanceProver + for ProgramImageClaimReductionProver +{ + fn get_params(&self) -> &dyn SumcheckInstanceParams { + &self.params + } + + fn round_offset(&self, max_num_rounds: usize) -> usize { + // Align to the *start* of the Stage 6b challenge vector so that the resulting + // big-endian opening point is the suffix (LSB side) of the full log_T cycle point. + // This is required for Stage 8 embedding when log_T > m. + let dummy_rounds = max_num_rounds.saturating_sub(self.params.num_rounds()); + self.batch_dummy_rounds + .store(dummy_rounds, Ordering::Relaxed); + 0 + } + + #[tracing::instrument(skip_all, name = "ProgramImageClaimReductionProver::compute_message")] + fn compute_message(&mut self, _round: usize, previous_claim: F) -> UniPoly { + let half = self.program_word.len() / 2; + let program_word = &self.program_word; + let eq_slice = &self.eq_slice; + let mut evals: [F; DEGREE_BOUND] = (0..half) + .into_par_iter() + .map(|j| { + let pw = + program_word.sumcheck_evals_array::(j, BindingOrder::LowToHigh); + let eq = eq_slice.sumcheck_evals_array::(j, BindingOrder::LowToHigh); + let mut out = [F::zero(); DEGREE_BOUND]; + for i in 0..DEGREE_BOUND { + out[i] = pw[i] * eq[i]; + } + out + }) + .reduce( + || [F::zero(); DEGREE_BOUND], + |mut acc, arr| { + acc.iter_mut().zip(arr.iter()).for_each(|(a, b)| *a += *b); + acc + }, + ); + // If this instance has trailing dummy rounds, `previous_claim` is scaled by 2^{dummy_rounds} + // in the batched sumcheck. Scale the per-round univariate evaluations accordingly so the + // sumcheck consistency checks pass (mirrors BytecodeClaimReduction). + let dummy_rounds = self.batch_dummy_rounds.load(Ordering::Relaxed); + if dummy_rounds != 0 { + let scale = F::one().mul_pow_2(dummy_rounds); + for e in evals.iter_mut() { + *e *= scale; + } + } + UniPoly::from_evals_and_hint(previous_claim, &evals) + } + + #[tracing::instrument(skip_all, name = "ProgramImageClaimReductionProver::ingest_challenge")] + fn ingest_challenge(&mut self, r_j: F::Challenge, _round: usize) { + self.program_word + .bind_parallel(r_j, BindingOrder::LowToHigh); + self.eq_slice.bind_parallel(r_j, BindingOrder::LowToHigh); + } + + fn cache_openings( + &self, + accumulator: &mut ProverOpeningAccumulator, + transcript: &mut T, + sumcheck_challenges: &[F::Challenge], + ) { + let opening_point = self.params.normalize_opening_point(sumcheck_challenges); + let claim = self.program_word.final_sumcheck_claim(); + accumulator.append_dense( + transcript, + CommittedPolynomial::ProgramImageInit, + SumcheckId::ProgramImageClaimReduction, + opening_point.r, + claim, + ); + } + + #[cfg(feature = "allocative")] + fn update_flamegraph(&self, flamegraph: &mut allocative::FlameGraphBuilder) { + flamegraph.visit_root(self); + } +} + +pub struct ProgramImageClaimReductionVerifier { + pub params: ProgramImageClaimReductionParams, +} + +fn eval_eq_slice_at_r_star_lsb_dp( + r_addr_be: &[F::Challenge], + start_index: usize, + m: usize, + r_star_lsb: &[F::Challenge], +) -> F { + let ell = r_addr_be.len(); + debug_assert_eq!(r_star_lsb.len(), m); + debug_assert!(m <= ell); + + // DP over carry bit, iterating LSB -> MSB across the RAM address bits. + let mut dp0 = F::one(); // carry=0 + let mut dp1 = F::zero(); // carry=1 + + for i in 0..ell { + let start_bit = ((start_index >> i) & 1) as u8; + let y_var = i < m; + let r_y: F = if y_var { + r_star_lsb[i].into() + } else { + F::zero() + }; + + let r_addr_bit: F = r_addr_be[ell - 1 - i].into(); // LSB-first mapping + let k0 = F::one() - r_addr_bit; + let k1 = r_addr_bit; + + let mut ndp0 = F::zero(); + let mut ndp1 = F::zero(); + + // Transition from carry=0 + if !dp0.is_zero() { + if y_var { + // y=0 + let sum0 = start_bit; + let k_bit0 = sum0 & 1; + let carry0 = (sum0 >> 1) & 1; + let addr_factor0 = if k_bit0 == 1 { k1 } else { k0 }; + let y_factor0 = F::one() - r_y; + if carry0 == 0 { + ndp0 += dp0 * addr_factor0 * y_factor0; + } else { + ndp1 += dp0 * addr_factor0 * y_factor0; + } + // y=1 + let sum1 = start_bit + 1; + let k_bit1 = sum1 & 1; + let carry1 = (sum1 >> 1) & 1; + let addr_factor1 = if k_bit1 == 1 { k1 } else { k0 }; + let y_factor1 = r_y; + if carry1 == 0 { + ndp0 += dp0 * addr_factor1 * y_factor1; + } else { + ndp1 += dp0 * addr_factor1 * y_factor1; + } + } else { + // y is fixed 0 + let sum0 = start_bit; + let k_bit0 = sum0 & 1; + let carry0 = (sum0 >> 1) & 1; + let addr_factor0 = if k_bit0 == 1 { k1 } else { k0 }; + if carry0 == 0 { + ndp0 += dp0 * addr_factor0; + } else { + ndp1 += dp0 * addr_factor0; + } + } + } + + // Transition from carry=1 + if !dp1.is_zero() { + if y_var { + // y=0 + let sum0 = start_bit + 1; + let k_bit0 = sum0 & 1; + let carry0 = (sum0 >> 1) & 1; + let addr_factor0 = if k_bit0 == 1 { k1 } else { k0 }; + let y_factor0 = F::one() - r_y; + if carry0 == 0 { + ndp0 += dp1 * addr_factor0 * y_factor0; + } else { + ndp1 += dp1 * addr_factor0 * y_factor0; + } + // y=1 + let sum1 = start_bit + 1 + 1; + let k_bit1 = sum1 & 1; + let carry1 = (sum1 >> 1) & 1; + let addr_factor1 = if k_bit1 == 1 { k1 } else { k0 }; + let y_factor1 = r_y; + if carry1 == 0 { + ndp0 += dp1 * addr_factor1 * y_factor1; + } else { + ndp1 += dp1 * addr_factor1 * y_factor1; + } + } else { + // y is fixed 0 + let sum0 = start_bit + 1; + let k_bit0 = sum0 & 1; + let carry0 = (sum0 >> 1) & 1; + let addr_factor0 = if k_bit0 == 1 { k1 } else { k0 }; + if carry0 == 0 { + ndp0 += dp1 * addr_factor0; + } else { + ndp1 += dp1 * addr_factor0; + } + } + } + + dp0 = ndp0; + dp1 = ndp1; + } + + // Discard carry-out paths: indices >= 2^ell are out-of-range and contribute 0. + dp0 +} + +impl SumcheckInstanceVerifier + for ProgramImageClaimReductionVerifier +{ + fn get_params(&self) -> &dyn SumcheckInstanceParams { + &self.params + } + + fn round_offset(&self, _max_num_rounds: usize) -> usize { + // Must mirror prover: align to the start of Stage 6b challenge vector. + 0 + } + + fn expected_output_claim( + &self, + accumulator: &VerifierOpeningAccumulator, + sumcheck_challenges: &[F::Challenge], + ) -> F { + let (_, pw_eval) = accumulator.get_committed_polynomial_opening( + CommittedPolynomial::ProgramImageInit, + SumcheckId::ProgramImageClaimReduction, + ); + + // sumcheck_challenges are LSB-first (binding LowToHigh), which is exactly what the DP uses. + let eq_rw = eval_eq_slice_at_r_star_lsb_dp::( + &self.params.r_addr_rw, + self.params.start_index, + self.params.m, + sumcheck_challenges, + ); + let eq_comb = if self.params.single_opening { + eq_rw + } else { + let r_raf = self + .params + .r_addr_raf + .as_ref() + .expect("missing raf address"); + let eq_raf = eval_eq_slice_at_r_star_lsb_dp::( + r_raf, + self.params.start_index, + self.params.m, + sumcheck_challenges, + ); + eq_rw + self.params.gamma * eq_raf + }; + + pw_eval * eq_comb + } + + fn cache_openings( + &self, + accumulator: &mut VerifierOpeningAccumulator, + transcript: &mut T, + sumcheck_challenges: &[F::Challenge], + ) { + let opening_point = self.params.normalize_opening_point(sumcheck_challenges); + accumulator.append_dense( + transcript, + CommittedPolynomial::ProgramImageInit, + SumcheckId::ProgramImageClaimReduction, + opening_point.r, + ); + } +} diff --git a/jolt-core/src/zkvm/config.rs b/jolt-core/src/zkvm/config.rs index c7846b1347..0121261ca8 100644 --- a/jolt-core/src/zkvm/config.rs +++ b/jolt-core/src/zkvm/config.rs @@ -1,5 +1,8 @@ use allocative::Allocative; -use ark_serialize::{CanonicalDeserialize, CanonicalSerialize}; +use ark_serialize::{ + CanonicalDeserialize, CanonicalSerialize, Compress, SerializationError, Valid, Validate, +}; +use std::io::{Read, Write}; use crate::field::JoltField; use crate::utils::math::Math; @@ -20,6 +23,62 @@ pub fn get_instruction_sumcheck_phases(log_t: usize) -> usize { } } +/// Controls whether the prover/verifier use the **full** program path (verifier may do O(K)) +/// or the **committed** program path (staged Val claims + claim reduction + folded Stage 8 +/// opening for bytecode chunk + program image commitments). +/// +/// "Program" encompasses both bytecode (instructions) and program image (initial RAM). +#[repr(u8)] +#[derive(Clone, Copy, Debug, PartialEq, Eq, Allocative)] +pub enum ProgramMode { + /// Full mode: verifier has full access to bytecode and program image. + Full = 0, + /// Committed mode: verifier only has commitments to bytecode chunks and program image. + /// Uses staged Val claims + claim reductions + folded Stage 8 joint opening. + Committed = 1, +} + +impl Default for ProgramMode { + fn default() -> Self { + Self::Full + } +} + +impl CanonicalSerialize for ProgramMode { + fn serialize_with_mode( + &self, + writer: W, + compress: Compress, + ) -> Result<(), SerializationError> { + (*self as u8).serialize_with_mode(writer, compress) + } + + fn serialized_size(&self, compress: Compress) -> usize { + (*self as u8).serialized_size(compress) + } +} + +impl Valid for ProgramMode { + fn check(&self) -> Result<(), SerializationError> { + Ok(()) + } +} + +impl CanonicalDeserialize for ProgramMode { + fn deserialize_with_mode( + reader: R, + compress: Compress, + validate: Validate, + ) -> Result { + let value = u8::deserialize_with_mode(reader, compress, validate)?; + match value { + 0 => Ok(Self::Full), + 1 => Ok(Self::Committed), + _ => Err(SerializationError::InvalidData), + } + } +} + /// Configuration for read-write checking sumchecks. /// /// Contains parameters that control phase structure for RAM and register @@ -150,6 +209,22 @@ impl OneHotConfig { } } + /// Create a OneHotConfig with an explicit log_k_chunk. + pub fn from_log_k_chunk(log_k_chunk: usize) -> Self { + debug_assert!(log_k_chunk == 4 || log_k_chunk == 8); + let log_k_chunk = log_k_chunk as u8; + let lookups_ra_virtual_log_k_chunk = if log_k_chunk == 4 { + LOG_K / 8 + } else { + LOG_K / 4 + }; + + Self { + log_k_chunk, + lookups_ra_virtual_log_k_chunk: lookups_ra_virtual_log_k_chunk as u8, + } + } + /// Validates that the one-hot configuration is valid. /// /// This is called by the verifier to ensure the prover hasn't provided diff --git a/jolt-core/src/zkvm/mod.rs b/jolt-core/src/zkvm/mod.rs index 82117f6b76..871df62084 100644 --- a/jolt-core/src/zkvm/mod.rs +++ b/jolt-core/src/zkvm/mod.rs @@ -12,6 +12,8 @@ use crate::{ use ark_bn254::Fr; use ark_serialize::{CanonicalDeserialize, CanonicalSerialize}; use eyre::Result; +#[cfg(feature = "pprof")] +use pprof::protos::Message; use proof_serialization::JoltProof; #[cfg(feature = "prover")] use prover::JoltCpuProver; @@ -26,6 +28,7 @@ pub mod config; pub mod instruction; pub mod instruction_lookups; pub mod lookup_table; +pub mod program; pub mod proof_serialization; #[cfg(feature = "prover")] pub mod prover; @@ -36,6 +39,9 @@ pub mod spartan; pub mod verifier; pub mod witness; +#[cfg(test)] +mod tests; + // Scoped CPU profiler for performance analysis. Feature-gated by "pprof". // Usage: let _guard = pprof_scope!("label"); // @@ -64,7 +70,6 @@ impl Drop for PprofGuard { let _ = std::fs::create_dir_all(dir); } if let Ok(mut f) = std::fs::File::create(&filename) { - use pprof::protos::Message; if let Ok(p) = report.pprof() { let mut buf = Vec::new(); if p.encode(&mut buf).is_ok() { diff --git a/jolt-core/src/zkvm/program.rs b/jolt-core/src/zkvm/program.rs new file mode 100644 index 0000000000..6b1c7a75c7 --- /dev/null +++ b/jolt-core/src/zkvm/program.rs @@ -0,0 +1,775 @@ +//! Unified program preprocessing module. +//! +//! This module contains all static program data derived from the ELF: +//! - **Instructions** (`instructions`, `pc_map`): Decoded RISC-V instructions for bytecode lookup tables +//! - **Program image** (`min_bytecode_address`, `program_image_words`): Initial RAM state +//! +//! Both come from the same ELF file and are conceptually "the program". + +use std::any::TypeId; +use std::collections::HashMap; +use std::io::{Read, Write}; +use std::sync::Arc; + +use ark_serialize::{ + CanonicalDeserialize, CanonicalSerialize, Compress, SerializationError, Valid, Validate, +}; +use common::constants::BYTES_PER_INSTRUCTION; +use rayon::prelude::*; +use tracer::instruction::{Cycle, Instruction}; + +use crate::poly::commitment::commitment_scheme::CommitmentScheme; +use crate::poly::commitment::dory::{ + ArkG1, ArkGT, ArkworksProverSetup, DoryCommitmentScheme, DoryContext, DoryGlobals, DoryLayout, + BN254, +}; +use crate::poly::multilinear_polynomial::MultilinearPolynomial; +use crate::utils::errors::ProofVerifyError; +use crate::utils::math::Math; +use crate::zkvm::bytecode::chunks::{ + build_bytecode_chunks, build_bytecode_chunks_for_main_matrix, for_each_active_lane_value, + total_lanes, ActiveLaneValue, +}; +pub use crate::zkvm::bytecode::BytecodePCMapper; +use crate::zkvm::bytecode::BytecodePreprocessing; +use ark_bn254::{Fr, G1Projective}; +use ark_ff::{One, Zero}; +use dory::primitives::arithmetic::PairingCurve; + +// ───────────────────────────────────────────────────────────────────────────── +// ProgramPreprocessing - Full program data (prover + full-mode verifier) +// ───────────────────────────────────────────────────────────────────────────── + +/// Full program preprocessing - includes both bytecode instructions and RAM image. +/// +/// Both come from the same ELF file: +/// - `instructions` + `pc_map`: for bytecode lookup tables +/// - `program_image_words`: for initial RAM state +/// +/// # Usage +/// - Prover always has full access to this data +/// - In Full mode, verifier also has full access +/// - In Committed mode, verifier only has `TrustedProgramCommitments` +#[derive(Debug, Clone, CanonicalSerialize, CanonicalDeserialize)] +pub struct ProgramPreprocessing { + // ─── Bytecode (instructions) ─── + /// Decoded RISC-V instructions (padded to power-of-2). + pub instructions: Vec, + /// PC mapping for instruction lookup. + pub pc_map: BytecodePCMapper, + + // ─── Program image (RAM init) ─── + /// Minimum bytecode address (word-aligned). + pub min_bytecode_address: u64, + /// Program-image words (little-endian packed u64 values). + pub program_image_words: Vec, +} + +impl Default for ProgramPreprocessing { + fn default() -> Self { + Self { + instructions: vec![Instruction::NoOp, Instruction::NoOp], + pc_map: BytecodePCMapper::default(), + min_bytecode_address: 0, + program_image_words: Vec::new(), + } + } +} + +impl ProgramPreprocessing { + /// Preprocess program from decoded ELF outputs. + /// + /// # Arguments + /// - `instructions`: Decoded RISC-V instructions from ELF + /// - `memory_init`: Raw bytes from ELF that form initial RAM + #[tracing::instrument(skip_all, name = "ProgramPreprocessing::preprocess")] + pub fn preprocess(instructions: Vec, memory_init: Vec<(u64, u8)>) -> Self { + // ─── Process instructions (from BytecodePreprocessing::preprocess) ─── + let mut bytecode = instructions; + // Prepend a single no-op instruction + bytecode.insert(0, Instruction::NoOp); + let pc_map = BytecodePCMapper::new(&bytecode); + + let bytecode_size = bytecode.len().next_power_of_two().max(2); + // Pad to nearest power of 2 + bytecode.resize(bytecode_size, Instruction::NoOp); + + // ─── Process program image (from ProgramImagePreprocessing::preprocess) ─── + let min_bytecode_address = memory_init + .iter() + .map(|(address, _)| *address) + .min() + .unwrap_or(0); + + let max_bytecode_address = memory_init + .iter() + .map(|(address, _)| *address) + .max() + .unwrap_or(0) + + (BYTES_PER_INSTRUCTION as u64 - 1); + + let num_words = max_bytecode_address.next_multiple_of(8) / 8 - min_bytecode_address / 8 + 1; + let mut program_image_words = vec![0u64; num_words as usize]; + // Convert bytes into words and populate `program_image_words` + for chunk in + memory_init.chunk_by(|(address_a, _), (address_b, _)| address_a / 8 == address_b / 8) + { + let mut word = [0u8; 8]; + for (address, byte) in chunk { + word[(address % 8) as usize] = *byte; + } + let word = u64::from_le_bytes(word); + let remapped_index = (chunk[0].0 / 8 - min_bytecode_address / 8) as usize; + program_image_words[remapped_index] = word; + } + + Self { + instructions: bytecode, + pc_map, + min_bytecode_address, + program_image_words, + } + } + + /// Bytecode length (power-of-2 padded). + pub fn bytecode_len(&self) -> usize { + self.instructions.len() + } + + /// Program image word count (unpadded). + pub fn program_image_len_words(&self) -> usize { + self.program_image_words.len() + } + + /// Program image word count (power-of-2 padded). + pub fn program_image_len_words_padded(&self) -> usize { + self.program_image_words.len().next_power_of_two().max(2) + } + + /// Extract metadata-only for shared preprocessing. + pub fn meta(&self) -> ProgramMetadata { + ProgramMetadata { + min_bytecode_address: self.min_bytecode_address, + program_image_len_words: self.program_image_words.len(), + bytecode_len: self.instructions.len(), + } + } + + /// Get PC for a given cycle (instruction lookup). + #[inline(always)] + pub fn get_pc(&self, cycle: &Cycle) -> usize { + if matches!(cycle, Cycle::NoOp) { + return 0; + } + let instr = cycle.instruction().normalize(); + self.pc_map + .get_pc(instr.address, instr.virtual_sequence_remaining.unwrap_or(0)) + } + + /// Get a BytecodePreprocessing-compatible view. + /// + /// This is for backward compatibility with code that expects BytecodePreprocessing. + pub fn as_bytecode(&self) -> crate::zkvm::bytecode::BytecodePreprocessing { + crate::zkvm::bytecode::BytecodePreprocessing { + bytecode: self.instructions.clone(), + pc_map: self.pc_map.clone(), + } + } +} + +// ───────────────────────────────────────────────────────────────────────────── +// ProgramMetadata - O(1) metadata (shared between prover and verifier) +// ───────────────────────────────────────────────────────────────────────────── + +/// Metadata-only program info (shared between prover and verifier). +/// +/// O(1) data, safe for committed mode verifier. Does NOT contain +/// the actual instructions or program image words. +#[derive(Debug, Clone, CanonicalSerialize, CanonicalDeserialize)] +pub struct ProgramMetadata { + /// Minimum bytecode address (word-aligned). + pub min_bytecode_address: u64, + /// Number of program-image words (unpadded). + pub program_image_len_words: usize, + /// Bytecode length (power-of-2 padded). + pub bytecode_len: usize, +} + +impl ProgramMetadata { + /// Create metadata from full preprocessing. + pub fn from_program(program: &ProgramPreprocessing) -> Self { + program.meta() + } + + /// Program image word count (power-of-2 padded). + pub fn program_image_len_words_padded(&self) -> usize { + self.program_image_len_words.next_power_of_two().max(2) + } +} + +// ───────────────────────────────────────────────────────────────────────────── +// TrustedProgramCommitments - Unified commitments for committed mode +// ───────────────────────────────────────────────────────────────────────────── + +/// Trusted commitments for the entire program (bytecode chunks + program image). +/// +/// Derived from full `ProgramPreprocessing` during offline preprocessing. +/// This is what the verifier receives in Committed mode. +/// +/// # Trust Model +/// - Create via `derive()` from full program (offline preprocessing) +/// - Or deserialize from a trusted source (assumes honest origin) +/// - Pass to verifier preprocessing for succinct (online) verification +/// +/// # Security Warning +/// If you construct this type with arbitrary commitments (bypassing `derive()`), +/// verification will be unsound. Only use `derive()` or trusted deserialization. +#[derive(Clone, Debug, PartialEq, CanonicalSerialize, CanonicalDeserialize)] +pub struct TrustedProgramCommitments { + // ─── Bytecode chunk commitments ─── + /// Commitments to bytecode chunk polynomials. + pub bytecode_commitments: Vec, + /// Number of columns used when committing bytecode chunks. + pub bytecode_num_columns: usize, + /// log2(k_chunk) used for lane chunking. + pub log_k_chunk: u8, + /// Bytecode length (power-of-two padded). + pub bytecode_len: usize, + /// The T value used for bytecode coefficient indexing. + /// For CycleMajor: max_trace_len (main-matrix dimensions). + /// For AddressMajor: bytecode_len (bytecode dimensions). + /// Used in Stage 8 VMP to ensure correct index mapping. + pub bytecode_T: usize, + + // ─── Program image commitment ─── + /// Commitment to the program-image polynomial. + pub program_image_commitment: PCS::Commitment, + /// Number of columns used when committing program image. + pub program_image_num_columns: usize, + /// Number of program-image words (power-of-two padded). + pub program_image_num_words: usize, +} + +/// Opening hints for `TrustedProgramCommitments`. +/// +/// These are the Dory tier-1 data needed to build opening proofs. +#[derive(Clone, CanonicalSerialize, CanonicalDeserialize)] +pub struct TrustedProgramHints { + /// Hints for bytecode chunk commitments (one per chunk). + pub bytecode_hints: Vec, + /// Hint for program image commitment. + pub program_image_hint: PCS::OpeningProofHint, +} + +impl TrustedProgramCommitments { + /// Derive all program commitments from full preprocessing. + /// + /// This is the "offline preprocessing" step that must be done honestly. + /// Returns trusted commitments + hints for opening proofs. + #[tracing::instrument(skip_all, name = "TrustedProgramCommitments::derive")] + pub fn derive( + program: &ProgramPreprocessing, + generators: &PCS::ProverSetup, + log_k_chunk: usize, + max_trace_len: usize, + ) -> (Self, TrustedProgramHints) { + // ─── Derive bytecode commitments ─── + let k_chunk = 1usize << log_k_chunk; + let bytecode_len = program.bytecode_len(); + let num_chunks = total_lanes().div_ceil(k_chunk); + let log_t = max_trace_len.log_2(); + + // Get layout before context initialization. Layout affects coefficient indexing. + let layout = DoryGlobals::get_layout(); + + // Bytecode commitments: prefer a streaming/sparse Tier-1 commitment path for Dory. + // + // This avoids materializing dense coefficient vectors of length (k_chunk * T) per chunk. + // For non-Dory PCS implementations, we fall back to the dense polynomial commit path. + let (bytecode_commitments, bytecode_hints, bytecode_num_columns, bytecode_T) = if TypeId::of::< + PCS, + >( + ) + == TypeId::of::() + { + // SAFETY: guarded by the TypeId check above. In this monomorphization, PCS is + // DoryCommitmentScheme, so ProverSetup/Commitment/Hint types match exactly. + let dory_setup: &ArkworksProverSetup = + unsafe { &*(generators as *const PCS::ProverSetup as *const ArkworksProverSetup) }; + let (commitments, hints, num_columns, bytecode_t) = + derive_bytecode_commitments_sparse_dory( + program, + dory_setup, + log_k_chunk, + max_trace_len, + layout, + ); + let commitments: Vec = unsafe { std::mem::transmute(commitments) }; + let hints: Vec = unsafe { std::mem::transmute(hints) }; + (commitments, hints, num_columns, bytecode_t) + } else { + // Layout-conditional bytecode commitment generation (dense fallback): + // - CycleMajor: Use main-matrix dimensions (k_chunk * T) for correct Stage 8 embedding + // - AddressMajor: Use bytecode dimensions (k_chunk * bytecode_len), which works correctly + // + // Note: The context guard must remain alive through the commit operation, so we + // initialize and build/commit together for each layout branch. + // + // bytecode_T: The T value used for bytecode coefficient indexing (needed for Stage 8 VMP). + match layout { + DoryLayout::CycleMajor => { + let _guard = DoryGlobals::initialize_bytecode_context_with_main_dimensions( + k_chunk, + max_trace_len, + log_k_chunk, + ); + let _ctx = DoryGlobals::with_context(DoryContext::Bytecode); + let num_columns = DoryGlobals::get_num_columns(); + + let chunks = build_bytecode_chunks_for_main_matrix_from_program::( + program, + log_k_chunk, + max_trace_len, + layout, + ); + debug_assert_eq!(chunks.len(), num_chunks); + + let (commitments, hints): (Vec<_>, Vec<_>) = chunks + .par_iter() + .map(|poly| PCS::commit(poly, generators)) + .unzip(); + (commitments, hints, num_columns, max_trace_len) + } + DoryLayout::AddressMajor => { + let _guard = DoryGlobals::initialize_bytecode_context_for_main_sigma( + k_chunk, + bytecode_len, + log_k_chunk, + log_t, + ); + let _ctx = DoryGlobals::with_context(DoryContext::Bytecode); + let num_columns = DoryGlobals::get_num_columns(); + + let chunks = + build_bytecode_chunks_from_program::(program, log_k_chunk); + debug_assert_eq!(chunks.len(), num_chunks); + + let (commitments, hints): (Vec<_>, Vec<_>) = chunks + .par_iter() + .map(|poly| PCS::commit(poly, generators)) + .unzip(); + (commitments, hints, num_columns, bytecode_len) + } + } + }; + + // ─── Derive program image commitment ─── + // Compute Main's column width (sigma_main) for Stage 8 hint compatibility. + let (sigma_main, _nu_main) = DoryGlobals::main_sigma_nu(log_k_chunk, log_t); + let main_num_columns = 1usize << sigma_main; + + // Pad to power-of-two, but ensure at least `main_num_columns` so we have ≥1 row. + // This is required for the ProgramImage matrix to be non-degenerate when using + // Main's column width. + let program_image_num_words = program + .program_image_len_words() + .next_power_of_two() + .max(1) + .max(main_num_columns); + + // Initialize ProgramImage context with Main's column width for hint compatibility. + DoryGlobals::initialize_program_image_context_with_num_columns( + k_chunk, + program_image_num_words, + main_num_columns, + ); + let _ctx2 = DoryGlobals::with_context(DoryContext::ProgramImage); + let program_image_num_columns = DoryGlobals::get_num_columns(); + + // Build program image polynomial with padded size + let program_image_mle: MultilinearPolynomial = + build_program_image_polynomial_padded(program, program_image_num_words); + let (program_image_commitment, program_image_hint) = + PCS::commit(&program_image_mle, generators); + + ( + Self { + bytecode_commitments, + bytecode_num_columns, + log_k_chunk: log_k_chunk as u8, + bytecode_len, + bytecode_T, + program_image_commitment, + program_image_num_columns, + program_image_num_words, + }, + TrustedProgramHints { + bytecode_hints, + program_image_hint, + }, + ) + } + + /// Build the program-image polynomial from full preprocessing. + /// + /// Needed for Stage 8 opening proof generation. + pub fn build_program_image_polynomial( + program: &ProgramPreprocessing, + ) -> MultilinearPolynomial { + build_program_image_polynomial::(program) + } + + /// Build the program-image polynomial with explicit padded size. + /// + /// Used in committed mode where the padded size may be larger than the program's + /// own padded size (to match Main context dimensions). + pub fn build_program_image_polynomial_padded( + program: &ProgramPreprocessing, + padded_len: usize, + ) -> MultilinearPolynomial { + build_program_image_polynomial_padded::(program, padded_len) + } +} + +/// Build program-image polynomial from ProgramPreprocessing. +fn build_program_image_polynomial( + program: &ProgramPreprocessing, +) -> MultilinearPolynomial { + let padded_len = program.program_image_len_words_padded(); + build_program_image_polynomial_padded::(program, padded_len) +} + +/// Build program-image polynomial from ProgramPreprocessing with explicit padded size. +/// +/// Implementation note: we store program-image coefficients as `u64` small scalars (U64Scalars) +/// to avoid eagerly converting the entire image to field elements. +fn build_program_image_polynomial_padded( + program: &ProgramPreprocessing, + padded_len: usize, +) -> MultilinearPolynomial { + debug_assert!(padded_len.is_power_of_two()); + debug_assert!(padded_len >= program.program_image_words.len()); + let mut coeffs = vec![0u64; padded_len]; + for (i, &word) in program.program_image_words.iter().enumerate() { + coeffs[i] = word; + } + MultilinearPolynomial::from(coeffs) +} + +/// Streaming/sparse bytecode commitments for Dory. +/// +/// Computes tier-1 row commitments directly from the instruction stream by only touching +/// nonzero lane values (via `for_each_active_lane_value`). This avoids materializing the +/// dense coefficient vectors for each bytecode chunk polynomial. +/// +/// Returns: +/// - commitments: one per bytecode chunk +/// - hints: tier-1 row commitments per chunk (Dory opening proof hint) +/// - num_columns: bytecode context matrix width +/// - bytecode_T: the T used for coefficient indexing (needed later in Stage 8 VMP) +fn derive_bytecode_commitments_sparse_dory( + program: &ProgramPreprocessing, + setup: &ArkworksProverSetup, + log_k_chunk: usize, + max_trace_len: usize, + layout: DoryLayout, +) -> (Vec, Vec>, usize, usize) { + let k_chunk = 1usize << log_k_chunk; + let bytecode_len = program.bytecode_len(); + let num_chunks = total_lanes().div_ceil(k_chunk); + let log_t = max_trace_len.log_2(); + + // Initialize Bytecode context with dimensions matching the committed-bytecode strategy. + let (num_columns, bytecode_T) = match layout { + DoryLayout::CycleMajor => { + let _guard = DoryGlobals::initialize_bytecode_context_with_main_dimensions( + k_chunk, + max_trace_len, + log_k_chunk, + ); + let _ctx = DoryGlobals::with_context(DoryContext::Bytecode); + (DoryGlobals::get_num_columns(), max_trace_len) + } + DoryLayout::AddressMajor => { + let _guard = DoryGlobals::initialize_bytecode_context_for_main_sigma( + k_chunk, + bytecode_len, + log_k_chunk, + log_t, + ); + let _ctx = DoryGlobals::with_context(DoryContext::Bytecode); + (DoryGlobals::get_num_columns(), bytecode_len) + } + }; + + let total_size = k_chunk * bytecode_T; + debug_assert!( + total_size % num_columns == 0, + "expected (k_chunk*bytecode_T) divisible by num_columns" + ); + let num_rows = total_size / num_columns; + + // Build tier-1 row commitments by streaming once over the program instructions. + // + // Parallelization strategy: + // - Parallelize over `cycle` (Rayon over instructions). + // - Each thread accumulates into per-chunk sparse maps: chunk -> (row_idx -> row_commitment). + // - Reduce by pointwise addition of the sparse maps. + // + // This avoids the previous O(num_chunks * bytecode_len) rescans of the instruction stream. + let sparse_rows_by_chunk: Vec> = program.instructions + [..bytecode_len] + .par_iter() + .enumerate() + .fold( + || vec![HashMap::::new(); num_chunks], + |mut acc, (cycle, instr)| { + for_each_active_lane_value::(instr, |global_lane, lane_val| { + let chunk_idx = global_lane / k_chunk; + if chunk_idx >= num_chunks { + return; + } + let lane = global_lane % k_chunk; + + let global_index = + layout.address_cycle_to_index(lane, cycle, k_chunk, bytecode_T); + let row_idx = global_index / num_columns; + let col_idx = global_index % num_columns; + debug_assert!(row_idx < num_rows); + + let scalar = match lane_val { + ActiveLaneValue::One => Fr::one(), + ActiveLaneValue::Scalar(v) => v, + }; + if scalar.is_zero() { + return; + } + + let base = setup.g1_vec[col_idx].0; + let entry = acc[chunk_idx] + .entry(row_idx) + .or_insert_with(G1Projective::zero); + if scalar.is_one() { + *entry += base; + } else { + *entry += base * scalar; + } + }); + acc + }, + ) + .reduce( + || vec![HashMap::::new(); num_chunks], + |mut a, b| { + for (a_map, b_map) in a.iter_mut().zip(b.into_iter()) { + for (row_idx, row_commitment) in b_map.into_iter() { + let entry = a_map.entry(row_idx).or_insert_with(G1Projective::zero); + *entry += row_commitment; + } + } + a + }, + ); + + // Materialize full row-commitment vectors (hints) and compute tier-2 commitments. + let (commitments, hints): (Vec, Vec>) = sparse_rows_by_chunk + .into_iter() + .map(|row_map| { + // Full hint vector required by Dory opening proof. + let mut row_commitments: Vec = vec![ArkG1(G1Projective::zero()); num_rows]; + + // For tier-2 commitment, we can skip identity rows (pairing with identity is neutral). + let mut nonzero_rows: Vec = Vec::with_capacity(row_map.len()); + let mut nonzero_g2: Vec<_> = Vec::with_capacity(row_map.len()); + + for (row_idx, row_commitment) in row_map.into_iter() { + let rc = ArkG1(row_commitment); + row_commitments[row_idx] = rc; + nonzero_rows.push(rc); + nonzero_g2.push(setup.g2_vec[row_idx].clone()); + } + + let tier2 = ::multi_pair_g2_setup(&nonzero_rows, &nonzero_g2); + (tier2, row_commitments) + }) + .unzip(); + + (commitments, hints, num_columns, bytecode_T) +} + +/// Build bytecode chunks from ProgramPreprocessing. +/// +/// This is a wrapper that provides the legacy `BytecodePreprocessing`-like interface. +fn build_bytecode_chunks_from_program( + program: &ProgramPreprocessing, + log_k_chunk: usize, +) -> Vec> { + // Use the existing chunk-building logic via a shim + let legacy = BytecodePreprocessing { + bytecode: program.instructions.clone(), + pc_map: program.pc_map.clone(), + }; + build_bytecode_chunks::(&legacy, log_k_chunk) +} + +/// Build bytecode chunks with main-matrix dimensions for CycleMajor Stage 8 embedding. +/// +/// Uses `padded_trace_len` for coefficient indexing so that bytecode polynomials +/// are correctly embedded in the main matrix when T > bytecode_len. +fn build_bytecode_chunks_for_main_matrix_from_program( + program: &ProgramPreprocessing, + log_k_chunk: usize, + padded_trace_len: usize, + layout: DoryLayout, +) -> Vec> { + let legacy = BytecodePreprocessing { + bytecode: program.instructions.clone(), + pc_map: program.pc_map.clone(), + }; + build_bytecode_chunks_for_main_matrix::(&legacy, log_k_chunk, padded_trace_len, layout) +} + +// ───────────────────────────────────────────────────────────────────────────── +// VerifierProgram - Verifier's view of program data +// ───────────────────────────────────────────────────────────────────────────── + +/// Verifier's view of program data. +/// +/// - `Full`: Verifier has full access to the program data (O(program_size) data). +/// - `Committed`: Verifier only has trusted commitments (O(1) data). +#[derive(Debug, Clone)] +pub enum VerifierProgram { + /// Full program data available (Full mode). + Full(Arc), + /// Only trusted commitments available (Committed mode). + Committed(TrustedProgramCommitments), +} + +impl VerifierProgram { + /// Returns the full program preprocessing, or an error if in Committed mode. + pub fn as_full(&self) -> Result<&Arc, ProofVerifyError> { + match self { + VerifierProgram::Full(p) => Ok(p), + VerifierProgram::Committed(_) => Err(ProofVerifyError::BytecodeTypeMismatch( + "expected Full, got Committed".to_string(), + )), + } + } + + /// Returns true if this is Full mode. + pub fn is_full(&self) -> bool { + matches!(self, VerifierProgram::Full(_)) + } + + /// Returns true if this is Committed mode. + pub fn is_committed(&self) -> bool { + matches!(self, VerifierProgram::Committed(_)) + } + + /// Returns the trusted commitments, or an error if in Full mode. + pub fn as_committed(&self) -> Result<&TrustedProgramCommitments, ProofVerifyError> { + match self { + VerifierProgram::Committed(trusted) => Ok(trusted), + VerifierProgram::Full(_) => Err(ProofVerifyError::BytecodeTypeMismatch( + "expected Committed, got Full".to_string(), + )), + } + } + + /// Get the program-image words (only in Full mode). + pub fn program_image_words(&self) -> Option<&[u64]> { + match self { + VerifierProgram::Full(p) => Some(&p.program_image_words), + VerifierProgram::Committed(_) => None, + } + } + + /// Get the instructions (only in Full mode). + pub fn instructions(&self) -> Option<&[Instruction]> { + match self { + VerifierProgram::Full(p) => Some(&p.instructions), + VerifierProgram::Committed(_) => None, + } + } + + /// Get the full program preprocessing (only in Full mode). + pub fn full(&self) -> Option<&Arc> { + match self { + VerifierProgram::Full(p) => Some(p), + VerifierProgram::Committed(_) => None, + } + } + + /// Get a BytecodePreprocessing-compatible view (only in Full mode). + /// + /// Returns a new BytecodePreprocessing struct for backward compatibility. + pub fn as_bytecode(&self) -> Option { + match self { + VerifierProgram::Full(p) => Some(p.as_bytecode()), + VerifierProgram::Committed(_) => None, + } + } +} + +// Manual serialization for VerifierProgram +impl CanonicalSerialize for VerifierProgram { + fn serialize_with_mode( + &self, + mut writer: W, + compress: Compress, + ) -> Result<(), SerializationError> { + match self { + VerifierProgram::Full(p) => { + 0u8.serialize_with_mode(&mut writer, compress)?; + p.as_ref().serialize_with_mode(&mut writer, compress)?; + } + VerifierProgram::Committed(trusted) => { + 1u8.serialize_with_mode(&mut writer, compress)?; + trusted.serialize_with_mode(&mut writer, compress)?; + } + } + Ok(()) + } + + fn serialized_size(&self, compress: Compress) -> usize { + 1 + match self { + VerifierProgram::Full(p) => p.serialized_size(compress), + VerifierProgram::Committed(trusted) => trusted.serialized_size(compress), + } + } +} + +impl Valid for VerifierProgram { + fn check(&self) -> Result<(), SerializationError> { + match self { + VerifierProgram::Full(p) => p.check(), + VerifierProgram::Committed(trusted) => trusted.check(), + } + } +} + +impl CanonicalDeserialize for VerifierProgram { + fn deserialize_with_mode( + mut reader: R, + compress: Compress, + validate: Validate, + ) -> Result { + let tag = u8::deserialize_with_mode(&mut reader, compress, validate)?; + match tag { + 0 => { + let p = + ProgramPreprocessing::deserialize_with_mode(&mut reader, compress, validate)?; + Ok(VerifierProgram::Full(Arc::new(p))) + } + 1 => { + let trusted = TrustedProgramCommitments::::deserialize_with_mode( + &mut reader, + compress, + validate, + )?; + Ok(VerifierProgram::Committed(trusted)) + } + _ => Err(SerializationError::InvalidData), + } + } +} diff --git a/jolt-core/src/zkvm/proof_serialization.rs b/jolt-core/src/zkvm/proof_serialization.rs index c6de7a1c60..354da4e4e2 100644 --- a/jolt-core/src/zkvm/proof_serialization.rs +++ b/jolt-core/src/zkvm/proof_serialization.rs @@ -1,5 +1,6 @@ use std::{ collections::BTreeMap, + fs::File, io::{Read, Write}, }; @@ -18,7 +19,7 @@ use crate::{ subprotocols::sumcheck::SumcheckInstanceProof, transcripts::Transcript, zkvm::{ - config::{OneHotConfig, ReadWriteConfig}, + config::{OneHotConfig, ProgramMode, ReadWriteConfig}, instruction::{CircuitFlags, InstructionFlags}, witness::{CommittedPolynomial, VirtualPolynomial}, }, @@ -38,13 +39,15 @@ pub struct JoltProof, FS: Transcr pub stage3_sumcheck_proof: SumcheckInstanceProof, pub stage4_sumcheck_proof: SumcheckInstanceProof, pub stage5_sumcheck_proof: SumcheckInstanceProof, - pub stage6_sumcheck_proof: SumcheckInstanceProof, + pub stage6a_sumcheck_proof: SumcheckInstanceProof, + pub stage6b_sumcheck_proof: SumcheckInstanceProof, pub stage7_sumcheck_proof: SumcheckInstanceProof, pub joint_opening_proof: PCS::Proof, pub untrusted_advice_commitment: Option, pub trace_length: usize, pub ram_K: usize, pub bytecode_K: usize, + pub program_mode: ProgramMode, pub rw_config: ReadWriteConfig, pub one_hot_config: OneHotConfig, pub dory_layout: DoryLayout, @@ -254,19 +257,31 @@ impl CanonicalSerialize for CommittedPolynomial { 3u8.serialize_with_mode(&mut writer, compress)?; (u8::try_from(*i).unwrap()).serialize_with_mode(writer, compress) } + Self::BytecodeChunk(i) => { + 7u8.serialize_with_mode(&mut writer, compress)?; + (u8::try_from(*i).unwrap()).serialize_with_mode(writer, compress) + } Self::RamRa(i) => { 4u8.serialize_with_mode(&mut writer, compress)?; (u8::try_from(*i).unwrap()).serialize_with_mode(writer, compress) } Self::TrustedAdvice => 5u8.serialize_with_mode(writer, compress), Self::UntrustedAdvice => 6u8.serialize_with_mode(writer, compress), + Self::ProgramImageInit => 8u8.serialize_with_mode(writer, compress), } } fn serialized_size(&self, _compress: Compress) -> usize { match self { - Self::RdInc | Self::RamInc | Self::TrustedAdvice | Self::UntrustedAdvice => 1, - Self::InstructionRa(_) | Self::BytecodeRa(_) | Self::RamRa(_) => 2, + Self::RdInc + | Self::RamInc + | Self::TrustedAdvice + | Self::UntrustedAdvice + | Self::ProgramImageInit => 1, + Self::InstructionRa(_) + | Self::BytecodeRa(_) + | Self::BytecodeChunk(_) + | Self::RamRa(_) => 2, } } } @@ -301,6 +316,11 @@ impl CanonicalDeserialize for CommittedPolynomial { } 5 => Self::TrustedAdvice, 6 => Self::UntrustedAdvice, + 7 => { + let i = u8::deserialize_with_mode(reader, compress, validate)?; + Self::BytecodeChunk(i as usize) + } + 8 => Self::ProgramImageInit, _ => return Err(SerializationError::InvalidData), }, ) @@ -367,6 +387,19 @@ impl CanonicalSerialize for VirtualPolynomial { 40u8.serialize_with_mode(&mut writer, compress)?; (u8::try_from(*flag).unwrap()).serialize_with_mode(&mut writer, compress) } + Self::BytecodeValStage(stage) => { + 41u8.serialize_with_mode(&mut writer, compress)?; + (u8::try_from(*stage).unwrap()).serialize_with_mode(&mut writer, compress) + } + Self::BytecodeReadRafAddrClaim => 42u8.serialize_with_mode(&mut writer, compress), + Self::BooleanityAddrClaim => 43u8.serialize_with_mode(&mut writer, compress), + Self::BytecodeClaimReductionIntermediate => { + 44u8.serialize_with_mode(&mut writer, compress) + } + Self::ProgramImageInitContributionRw => 45u8.serialize_with_mode(&mut writer, compress), + Self::ProgramImageInitContributionRaf => { + 46u8.serialize_with_mode(&mut writer, compress) + } } } @@ -408,11 +441,17 @@ impl CanonicalSerialize for VirtualPolynomial { | Self::RamValInit | Self::RamValFinal | Self::RamHammingWeight - | Self::UnivariateSkip => 1, + | Self::UnivariateSkip + | Self::BytecodeReadRafAddrClaim + | Self::BooleanityAddrClaim + | Self::BytecodeClaimReductionIntermediate + | Self::ProgramImageInitContributionRw + | Self::ProgramImageInitContributionRaf => 1, Self::InstructionRa(_) | Self::OpFlags(_) | Self::InstructionFlags(_) - | Self::LookupTableFlag(_) => 2, + | Self::LookupTableFlag(_) + | Self::BytecodeValStage(_) => 2, } } } @@ -488,6 +527,15 @@ impl CanonicalDeserialize for VirtualPolynomial { let flag = u8::deserialize_with_mode(&mut reader, compress, validate)?; Self::LookupTableFlag(flag as usize) } + 41 => { + let stage = u8::deserialize_with_mode(&mut reader, compress, validate)?; + Self::BytecodeValStage(stage as usize) + } + 42 => Self::BytecodeReadRafAddrClaim, + 43 => Self::BooleanityAddrClaim, + 44 => Self::BytecodeClaimReductionIntermediate, + 45 => Self::ProgramImageInitContributionRw, + 46 => Self::ProgramImageInitContributionRaf, _ => return Err(SerializationError::InvalidData), }, ) @@ -499,7 +547,6 @@ pub fn serialize_and_print_size( file_name: &str, item: &impl CanonicalSerialize, ) -> Result<(), SerializationError> { - use std::fs::File; let mut file = File::create(file_name)?; item.serialize_compressed(&mut file)?; let file_size_bytes = file.metadata()?.len(); diff --git a/jolt-core/src/zkvm/prover.rs b/jolt-core/src/zkvm/prover.rs index 4c87fcc55b..6eab2f2638 100644 --- a/jolt-core/src/zkvm/prover.rs +++ b/jolt-core/src/zkvm/prover.rs @@ -13,13 +13,8 @@ use std::{ time::Instant, }; -use crate::poly::commitment::dory::DoryContext; use ark_serialize::{CanonicalDeserialize, CanonicalSerialize}; -use crate::zkvm::config::ReadWriteConfig; -use crate::zkvm::verifier::JoltSharedPreprocessing; -use crate::zkvm::Serializable; - #[cfg(not(target_arch = "wasm32"))] use crate::utils::profiling::print_current_memory_usage; #[cfg(feature = "allocative")] @@ -30,7 +25,7 @@ use crate::{ poly::{ commitment::{ commitment_scheme::StreamingCommitmentScheme, - dory::{DoryGlobals, DoryLayout}, + dory::{DoryContext, DoryGlobals, DoryLayout}, }, multilinear_polynomial::MultilinearPolynomial, opening_proof::{ @@ -41,7 +36,10 @@ use crate::{ }, pprof_scope, subprotocols::{ - booleanity::{BooleanitySumcheckParams, BooleanitySumcheckProver}, + booleanity::{ + BooleanityAddressSumcheckProver, BooleanityCycleSumcheckProver, + BooleanitySumcheckParams, + }, sumcheck::{BatchedSumcheck, SumcheckInstanceProof}, sumcheck_prover::SumcheckInstanceProver, univariate_skip::{prove_uniskip_round, UniSkipFirstRoundProof}, @@ -49,28 +47,31 @@ use crate::{ transcripts::Transcript, utils::{math::Math, thread::drop_in_background_thread}, zkvm::{ - bytecode::read_raf_checking::BytecodeReadRafSumcheckParams, + bytecode::{chunks::total_lanes, read_raf_checking::BytecodeReadRafSumcheckParams}, claim_reductions::{ AdviceClaimReductionParams, AdviceClaimReductionProver, AdviceKind, + BytecodeClaimReductionParams, BytecodeClaimReductionProver, BytecodeReductionPhase, HammingWeightClaimReductionParams, HammingWeightClaimReductionProver, IncClaimReductionSumcheckParams, IncClaimReductionSumcheckProver, InstructionLookupsClaimReductionSumcheckParams, - InstructionLookupsClaimReductionSumcheckProver, RaReductionParams, - RamRaClaimReductionSumcheckProver, RegistersClaimReductionSumcheckParams, - RegistersClaimReductionSumcheckProver, + InstructionLookupsClaimReductionSumcheckProver, ProgramImageClaimReductionParams, + ProgramImageClaimReductionProver, RaReductionParams, RamRaClaimReductionSumcheckProver, + RegistersClaimReductionSumcheckParams, RegistersClaimReductionSumcheckProver, }, - config::OneHotParams, + config::{OneHotParams, ProgramMode, ReadWriteConfig}, instruction_lookups::{ ra_virtual::InstructionRaSumcheckParams, read_raf_checking::InstructionReadRafSumcheckParams, }, + program::{ProgramPreprocessing, TrustedProgramCommitments, TrustedProgramHints}, ram::{ hamming_booleanity::HammingBooleanitySumcheckParams, output_check::OutputSumcheckParams, - populate_memory_states, + populate_memory_states, prover_accumulate_program_image, ra_virtual::RamRaVirtualParams, raf_evaluation::RafEvaluationSumcheckParams, read_write_checking::RamReadWriteCheckingParams, + remap_address, val_evaluation::{ ValEvaluationSumcheckParams, ValEvaluationSumcheckProver as RamValEvaluationSumcheckProver, @@ -90,13 +91,17 @@ use crate::{ }, shift::ShiftSumcheckParams, }, + verifier::JoltSharedPreprocessing, witness::all_committed_polynomials, + Serializable, }, }; use crate::{ poly::commitment::commitment_scheme::CommitmentScheme, zkvm::{ - bytecode::read_raf_checking::BytecodeReadRafSumcheckProver, + bytecode::read_raf_checking::{ + BytecodeReadRafAddressSumcheckProver, BytecodeReadRafCycleSumcheckProver, + }, fiat_shamir_preamble, instruction_lookups::{ ra_virtual::InstructionRaSumcheckProver as LookupsRaSumcheckProver, @@ -128,6 +133,7 @@ use crate::{ #[cfg(feature = "allocative")] use allocative::FlameGraphBuilder; +use common::constants::ONEHOT_CHUNK_THRESHOLD_LOG_T; use common::jolt_device::MemoryConfig; use itertools::{zip_eq, Itertools}; use rayon::prelude::*; @@ -153,6 +159,13 @@ pub struct JoltCpuProver< /// The advice claim reduction sumcheck effectively spans two stages (6 and 7). /// Cache the prover state here between stages. advice_reduction_prover_untrusted: Option>, + /// The bytecode claim reduction sumcheck effectively spans two stages (6b and 7). + /// Cache the prover state here between stages. + bytecode_reduction_prover: Option>, + /// Bytecode read RAF params, cached between Stage 6a and 6b. + bytecode_read_raf_params: Option>, + /// Booleanity params, cached between Stage 6a and 6b. + booleanity_params: Option>, pub unpadded_trace_len: usize, pub padded_trace_len: usize, pub transcript: ProofTranscript, @@ -162,6 +175,8 @@ pub struct JoltCpuProver< pub final_ram_state: Vec, pub one_hot_params: OneHotParams, pub rw_config: ReadWriteConfig, + /// First-class selection of full vs committed bytecode mode. + pub program_mode: ProgramMode, } impl<'a, F: JoltField, PCS: StreamingCommitmentScheme, ProofTranscript: Transcript> JoltCpuProver<'a, F, PCS, ProofTranscript> @@ -174,6 +189,29 @@ impl<'a, F: JoltField, PCS: StreamingCommitmentScheme, ProofTranscrip trusted_advice: &[u8], trusted_advice_commitment: Option, trusted_advice_hint: Option, + ) -> Self { + Self::gen_from_elf_with_program_mode( + preprocessing, + elf_contents, + inputs, + untrusted_advice, + trusted_advice, + trusted_advice_commitment, + trusted_advice_hint, + ProgramMode::Full, + ) + } + + #[allow(clippy::too_many_arguments)] + pub fn gen_from_elf_with_program_mode( + preprocessing: &'a JoltProverPreprocessing, + elf_contents: &[u8], + inputs: &[u8], + untrusted_advice: &[u8], + trusted_advice: &[u8], + trusted_advice_commitment: Option, + trusted_advice_hint: Option, + program_mode: ProgramMode, ) -> Self { let memory_config = MemoryConfig { max_untrusted_advice_size: preprocessing.shared.memory_layout.max_untrusted_advice_size, @@ -219,7 +257,7 @@ impl<'a, F: JoltField, PCS: StreamingCommitmentScheme, ProofTranscrip trace.len(), ); - Self::gen_from_trace( + Self::gen_from_trace_with_program_mode( preprocessing, lazy_trace, trace, @@ -227,17 +265,19 @@ impl<'a, F: JoltField, PCS: StreamingCommitmentScheme, ProofTranscrip trusted_advice_commitment, trusted_advice_hint, final_memory_state, + program_mode, ) } /// Adjusts the padded trace length to ensure the main Dory matrix is large enough - /// to embed advice polynomials as the top-left block. + /// to embed "extra" (non-trace-streamed) polynomials as the top-left block. /// /// Returns the adjusted padded_trace_len that satisfies: /// - `sigma_main >= max_sigma_a` /// - `nu_main >= max_nu_a` /// - /// Panics if `max_padded_trace_length` is too small for the configured advice sizes. + /// Panics if `max_padded_trace_length` is too small for the configured sizes. + #[allow(clippy::too_many_arguments)] fn adjust_trace_length_for_advice( mut padded_trace_len: usize, max_padded_trace_length: usize, @@ -245,6 +285,8 @@ impl<'a, F: JoltField, PCS: StreamingCommitmentScheme, ProofTranscrip max_untrusted_advice_size: u64, has_trusted_advice: bool, has_untrusted_advice: bool, + has_program_image: bool, + program_image_len_words_padded: usize, ) -> usize { // Canonical advice shape policy (balanced): // - advice_vars = log2(advice_len) @@ -266,6 +308,13 @@ impl<'a, F: JoltField, PCS: StreamingCommitmentScheme, ProofTranscrip max_nu_a = max_nu_a.max(nu_a); } + if has_program_image { + let prog_vars = program_image_len_words_padded.log_2(); + let (sigma_p, nu_p) = DoryGlobals::balanced_sigma_nu(prog_vars); + max_sigma_a = max_sigma_a.max(sigma_p); + max_nu_a = max_nu_a.max(nu_p); + } + if max_sigma_a == 0 && max_nu_a == 0 { return padded_trace_len; } @@ -308,6 +357,28 @@ impl<'a, F: JoltField, PCS: StreamingCommitmentScheme, ProofTranscrip } pub fn gen_from_trace( + preprocessing: &'a JoltProverPreprocessing, + lazy_trace: LazyTraceIterator, + trace: Vec, + program_io: JoltDevice, + trusted_advice_commitment: Option, + trusted_advice_hint: Option, + final_memory_state: Memory, + ) -> Self { + Self::gen_from_trace_with_program_mode( + preprocessing, + lazy_trace, + trace, + program_io, + trusted_advice_commitment, + trusted_advice_hint, + final_memory_state, + ProgramMode::Full, + ) + } + + #[allow(clippy::too_many_arguments)] + pub fn gen_from_trace_with_program_mode( preprocessing: &'a JoltProverPreprocessing, lazy_trace: LazyTraceIterator, mut trace: Vec, @@ -315,6 +386,7 @@ impl<'a, F: JoltField, PCS: StreamingCommitmentScheme, ProofTranscrip trusted_advice_commitment: Option, trusted_advice_hint: Option, final_memory_state: Memory, + program_mode: ProgramMode, ) -> Self { // truncate trailing zeros on device outputs program_io.outputs.truncate( @@ -332,6 +404,42 @@ impl<'a, F: JoltField, PCS: StreamingCommitmentScheme, ProofTranscrip } else { (trace.len() + 1).next_power_of_two() }; + + // In Committed mode, Stage 8 folds bytecode chunk openings into the *joint* opening. + // That folding currently requires log_T >= log_K_bytecode, so we ensure the padded trace + // length is at least the (power-of-two padded) bytecode size. + // + // For CycleMajor layout, bytecode chunks are committed with bytecode_T for coefficient + // indexing. The main context's T must be >= bytecode_T for row indices to align correctly + // during Stage 8 VMP computation. + let padded_trace_len = if program_mode == ProgramMode::Committed { + let trusted = preprocessing + .program_commitments + .as_ref() + .expect("program commitments missing in committed preprocessing"); + padded_trace_len + .max(preprocessing.shared.bytecode_size()) + .max(trusted.bytecode_T) // Ensure T >= bytecode_T for CycleMajor row alignment + } else { + padded_trace_len + }; + // In Committed mode, ProgramImageClaimReduction uses `m = log2(padded_len_words)` rounds and is + // back-loaded into Stage 6b, so we require log_T >= m. A sufficient condition is T >= padded_len_words. + let (has_program_image, program_image_len_words_padded) = + if program_mode == ProgramMode::Committed { + let trusted = preprocessing + .program_commitments + .as_ref() + .expect("program commitments missing in committed preprocessing"); + (true, trusted.program_image_num_words) + } else { + (false, 0usize) + }; + let padded_trace_len = if has_program_image { + padded_trace_len.max(program_image_len_words_padded) + } else { + padded_trace_len + }; // We may need extra padding so the main Dory matrix has enough (row, col) variables // to embed advice commitments committed in their own preprocessing-only contexts. let has_trusted_advice = !program_io.trusted_advice.is_empty(); @@ -344,6 +452,8 @@ impl<'a, F: JoltField, PCS: StreamingCommitmentScheme, ProofTranscrip preprocessing.shared.memory_layout.max_untrusted_advice_size, has_trusted_advice, has_untrusted_advice, + has_program_image, + program_image_len_words_padded, ); trace.resize(padded_trace_len, Cycle::NoOp); @@ -352,7 +462,7 @@ impl<'a, F: JoltField, PCS: StreamingCommitmentScheme, ProofTranscrip let ram_K = trace .par_iter() .filter_map(|cycle| { - crate::zkvm::ram::remap_address( + remap_address( cycle.ram_access().address() as u64, &preprocessing.shared.memory_layout, ) @@ -360,12 +470,19 @@ impl<'a, F: JoltField, PCS: StreamingCommitmentScheme, ProofTranscrip .max() .unwrap_or(0) .max( - crate::zkvm::ram::remap_address( - preprocessing.shared.ram.min_bytecode_address, + remap_address( + preprocessing.program.min_bytecode_address, &preprocessing.shared.memory_layout, ) .unwrap_or(0) - + preprocessing.shared.ram.bytecode_words.len() as u64 + + { + let base = preprocessing.program.program_image_words.len() as u64; + if has_program_image { + (program_image_len_words_padded as u64).max(base) + } else { + base + } + } + 1, ) .next_power_of_two() as usize; @@ -377,7 +494,8 @@ impl<'a, F: JoltField, PCS: StreamingCommitmentScheme, ProofTranscrip let (initial_ram_state, final_ram_state) = gen_ram_memory_states::( ram_K, - &preprocessing.shared.ram, + preprocessing.program.min_bytecode_address, + &preprocessing.program.program_image_words, &program_io, &final_memory_state, ); @@ -385,8 +503,16 @@ impl<'a, F: JoltField, PCS: StreamingCommitmentScheme, ProofTranscrip let log_T = trace.len().log_2(); let ram_log_K = ram_K.log_2(); let rw_config = ReadWriteConfig::new(log_T, ram_log_K); - let one_hot_params = - OneHotParams::new(log_T, preprocessing.shared.bytecode.code_size, ram_K); + let one_hot_params = if program_mode == ProgramMode::Committed { + let committed = preprocessing + .program_commitments + .as_ref() + .expect("program commitments missing in committed mode"); + let config = OneHotConfig::from_log_k_chunk(committed.log_k_chunk as usize); + OneHotParams::from_config(&config, preprocessing.shared.bytecode_size(), ram_K) + } else { + OneHotParams::new(log_T, preprocessing.shared.bytecode_size(), ram_K) + }; Self { preprocessing, @@ -402,6 +528,9 @@ impl<'a, F: JoltField, PCS: StreamingCommitmentScheme, ProofTranscrip }, advice_reduction_prover_trusted: None, advice_reduction_prover_untrusted: None, + bytecode_reduction_prover: None, + bytecode_read_raf_params: None, + booleanity_params: None, unpadded_trace_len, padded_trace_len, transcript, @@ -411,6 +540,7 @@ impl<'a, F: JoltField, PCS: StreamingCommitmentScheme, ProofTranscrip final_ram_state, one_hot_params, rw_config, + program_mode, } } @@ -434,13 +564,63 @@ impl<'a, F: JoltField, PCS: StreamingCommitmentScheme, ProofTranscrip tracing::info!( "bytecode size: {}", - self.preprocessing.shared.bytecode.code_size + self.preprocessing.shared.bytecode_size() ); let (commitments, mut opening_proof_hints) = self.generate_and_commit_witness_polynomials(); let untrusted_advice_commitment = self.generate_and_commit_untrusted_advice(); self.generate_and_commit_trusted_advice(); + if self.program_mode == ProgramMode::Committed { + if let Some(trusted) = &self.preprocessing.program_commitments { + // Append bytecode chunk commitments + for commitment in &trusted.bytecode_commitments { + self.transcript.append_serializable(commitment); + } + // Append program image commitment + self.transcript + .append_serializable(&trusted.program_image_commitment); + #[cfg(test)] + { + // Sanity: re-commit the program image polynomial and ensure it matches the trusted commitment. + // Must use the same padded size and context as TrustedProgramCommitments::derive(). + let mle = + TrustedProgramCommitments::::build_program_image_polynomial_padded::( + &self.preprocessing.program, + trusted.program_image_num_words, + ); + // Recompute log_k_chunk and max_log_t to get Main's sigma. + let max_t_any: usize = self + .preprocessing + .shared + .max_padded_trace_length + .max(self.preprocessing.shared.bytecode_size()) + .next_power_of_two(); + let max_log_t = max_t_any.log_2(); + let log_k_chunk = if max_log_t < common::constants::ONEHOT_CHUNK_THRESHOLD_LOG_T + { + 4 + } else { + 8 + }; + // Use the explicit context initialization to match TrustedProgramCommitments::derive() + let (sigma_main, _) = DoryGlobals::main_sigma_nu(log_k_chunk, max_log_t); + let main_num_columns = 1usize << sigma_main; + DoryGlobals::initialize_program_image_context_with_num_columns( + 1usize << log_k_chunk, + trusted.program_image_num_words, + main_num_columns, + ); + let _ctx = DoryGlobals::with_context(DoryContext::ProgramImage); + let (recommit, _hint) = PCS::commit(&mle, &self.preprocessing.generators); + assert_eq!( + recommit, trusted.program_image_commitment, + "ProgramImageInit commitment mismatch vs polynomial used in proving" + ); + } + } + } + // Add advice hints for batched Stage 8 opening if let Some(hint) = self.advice.trusted_advice_hint.take() { opening_proof_hints.insert(CommittedPolynomial::TrustedAdvice, hint); @@ -448,13 +628,28 @@ impl<'a, F: JoltField, PCS: StreamingCommitmentScheme, ProofTranscrip if let Some(hint) = self.advice.untrusted_advice_hint.take() { opening_proof_hints.insert(CommittedPolynomial::UntrustedAdvice, hint); } + if self.program_mode == ProgramMode::Committed { + if let Some(hints) = self.preprocessing.program_hints.as_ref() { + for (idx, hint) in hints.bytecode_hints.iter().enumerate() { + opening_proof_hints + .insert(CommittedPolynomial::BytecodeChunk(idx), hint.clone()); + } + } + if let Some(hints) = self.preprocessing.program_hints.as_ref() { + opening_proof_hints.insert( + CommittedPolynomial::ProgramImageInit, + hints.program_image_hint.clone(), + ); + } + } let (stage1_uni_skip_first_round_proof, stage1_sumcheck_proof) = self.prove_stage1(); let (stage2_uni_skip_first_round_proof, stage2_sumcheck_proof) = self.prove_stage2(); let stage3_sumcheck_proof = self.prove_stage3(); let stage4_sumcheck_proof = self.prove_stage4(); let stage5_sumcheck_proof = self.prove_stage5(); - let stage6_sumcheck_proof = self.prove_stage6(); + let stage6a_sumcheck_proof = self.prove_stage6a(); + let stage6b_sumcheck_proof = self.prove_stage6b(); let stage7_sumcheck_proof = self.prove_stage7(); let joint_opening_proof = self.prove_stage8(opening_proof_hints); @@ -489,12 +684,14 @@ impl<'a, F: JoltField, PCS: StreamingCommitmentScheme, ProofTranscrip stage3_sumcheck_proof, stage4_sumcheck_proof, stage5_sumcheck_proof, - stage6_sumcheck_proof, + stage6a_sumcheck_proof, + stage6b_sumcheck_proof, stage7_sumcheck_proof, joint_opening_proof, trace_length: self.trace.len(), ram_K: self.one_hot_params.ram_k, bytecode_K: self.one_hot_params.bytecode_k, + program_mode: self.program_mode, rw_config: self.rw_config.clone(), one_hot_config: self.one_hot_params.to_config(), dory_layout: DoryGlobals::get_layout(), @@ -519,12 +716,26 @@ impl<'a, F: JoltField, PCS: StreamingCommitmentScheme, ProofTranscrip Vec, HashMap, ) { - let _guard = DoryGlobals::initialize_context( - 1 << self.one_hot_params.log_k_chunk, - self.padded_trace_len, - DoryContext::Main, - Some(DoryGlobals::get_layout()), - ); + let _guard = if self.program_mode == ProgramMode::Committed { + let committed = self + .preprocessing + .program_commitments + .as_ref() + .expect("program commitments missing in committed mode"); + DoryGlobals::initialize_main_context_with_num_columns( + 1 << self.one_hot_params.log_k_chunk, + self.padded_trace_len, + committed.bytecode_num_columns, + Some(DoryGlobals::get_layout()), + ) + } else { + DoryGlobals::initialize_context( + 1 << self.one_hot_params.log_k_chunk, + self.padded_trace_len, + DoryContext::Main, + Some(DoryGlobals::get_layout()), + ) + }; let polys = all_committed_polynomials(&self.one_hot_params); let T = DoryGlobals::get_T(); @@ -548,7 +759,7 @@ impl<'a, F: JoltField, PCS: StreamingCommitmentScheme, ProofTranscrip .par_iter() .map(|poly_id| { let witness: MultilinearPolynomial = poly_id.generate_witness( - &self.preprocessing.shared.bytecode, + &self.preprocessing.program, &self.preprocessing.shared.memory_layout, &trace, Some(&self.one_hot_params), @@ -588,6 +799,7 @@ impl<'a, F: JoltField, PCS: StreamingCommitmentScheme, ProofTranscrip poly.stream_witness_and_commit_rows::<_, PCS>( &self.preprocessing.generators, &self.preprocessing.shared, + &self.preprocessing.program, &chunk, &self.one_hot_params, ) @@ -702,7 +914,7 @@ impl<'a, F: JoltField, PCS: StreamingCommitmentScheme, ProofTranscrip let mut uni_skip = OuterUniSkipProver::initialize( uni_skip_params.clone(), &self.trace, - &self.preprocessing.shared.bytecode, + &self.preprocessing.program, ); let first_round_proof = prove_uniskip_round( &mut uni_skip, @@ -718,7 +930,7 @@ impl<'a, F: JoltField, PCS: StreamingCommitmentScheme, ProofTranscrip let schedule = LinearOnlySchedule::new(uni_skip_params.tau.len() - 1); let shared = OuterSharedState::new( Arc::clone(&self.trace), - &self.preprocessing.shared.bytecode, + &self.preprocessing.program, &uni_skip_params, &self.opening_accumulator, ); @@ -798,7 +1010,7 @@ impl<'a, F: JoltField, PCS: StreamingCommitmentScheme, ProofTranscrip let ram_read_write_checking = RamReadWriteCheckingProver::initialize( ram_read_write_checking_params, &self.trace, - &self.preprocessing.shared.bytecode, + &self.preprocessing.program, &self.program_io.memory_layout, &self.initial_ram_state, ); @@ -875,7 +1087,7 @@ impl<'a, F: JoltField, PCS: StreamingCommitmentScheme, ProofTranscrip let spartan_shift = ShiftSumcheckProver::initialize( spartan_shift_params, Arc::clone(&self.trace), - &self.preprocessing.shared.bytecode, + &self.preprocessing.program, ); let spartan_instruction_input = InstructionInputSumcheckProver::initialize( spartan_instruction_input_params, @@ -936,6 +1148,24 @@ impl<'a, F: JoltField, PCS: StreamingCommitmentScheme, ProofTranscrip self.rw_config .needs_single_advice_opening(self.trace.len().log_2()), ); + if self.program_mode == ProgramMode::Committed { + let trusted = self + .preprocessing + .program_commitments + .as_ref() + .expect("program commitments missing in committed mode"); + prover_accumulate_program_image::( + self.one_hot_params.ram_k, + self.preprocessing.program.min_bytecode_address, + &self.preprocessing.program.program_image_words, + &self.program_io, + trusted.program_image_num_words, + &mut self.opening_accumulator, + &mut self.transcript, + self.rw_config + .needs_single_advice_opening(self.trace.len().log_2()), + ); + } let registers_read_write_checking_params = RegistersReadWriteCheckingParams::new( self.trace.len(), @@ -955,19 +1185,19 @@ impl<'a, F: JoltField, PCS: StreamingCommitmentScheme, ProofTranscrip let registers_read_write_checking = RegistersReadWriteCheckingProver::initialize( registers_read_write_checking_params, self.trace.clone(), - &self.preprocessing.shared.bytecode, + &self.preprocessing.program, &self.program_io.memory_layout, ); let ram_val_evaluation = RamValEvaluationSumcheckProver::initialize( ram_val_evaluation_params, &self.trace, - &self.preprocessing.shared.bytecode, + &self.preprocessing.program, &self.program_io.memory_layout, ); let ram_val_final = ValFinalSumcheckProver::initialize( ram_val_final_params, &self.trace, - &self.preprocessing.shared.bytecode, + &self.preprocessing.program, &self.program_io.memory_layout, ); @@ -1024,7 +1254,7 @@ impl<'a, F: JoltField, PCS: StreamingCommitmentScheme, ProofTranscrip let registers_val_evaluation = RegistersValEvaluationSumcheckProver::initialize( registers_val_evaluation_params, &self.trace, - &self.preprocessing.shared.bytecode, + &self.preprocessing.program, &self.program_io.memory_layout, ); let ram_ra_reduction = RamRaClaimReductionSumcheckProver::initialize( @@ -1070,20 +1300,19 @@ impl<'a, F: JoltField, PCS: StreamingCommitmentScheme, ProofTranscrip } #[tracing::instrument(skip_all)] - fn prove_stage6(&mut self) -> SumcheckInstanceProof { + fn prove_stage6a(&mut self) -> SumcheckInstanceProof { #[cfg(not(target_arch = "wasm32"))] - print_current_memory_usage("Stage 6 baseline"); + print_current_memory_usage("Stage 6a baseline"); - let bytecode_read_raf_params = BytecodeReadRafSumcheckParams::gen( - &self.preprocessing.shared.bytecode, + let mut bytecode_read_raf_params = BytecodeReadRafSumcheckParams::gen( + &self.preprocessing.program, self.trace.len().log_2(), &self.one_hot_params, &self.opening_accumulator, &mut self.transcript, ); - - let ram_hamming_booleanity_params = - HammingBooleanitySumcheckParams::new(&self.opening_accumulator); + bytecode_read_raf_params.use_staged_val_claims = + self.program_mode == ProgramMode::Committed; let booleanity_params = BooleanitySumcheckParams::new( self.trace.len().log_2(), @@ -1092,6 +1321,65 @@ impl<'a, F: JoltField, PCS: StreamingCommitmentScheme, ProofTranscrip &mut self.transcript, ); + let mut bytecode_read_raf = BytecodeReadRafAddressSumcheckProver::initialize( + bytecode_read_raf_params.clone(), + Arc::clone(&self.trace), + Arc::clone(&self.preprocessing.program), + ); + let mut booleanity = BooleanityAddressSumcheckProver::initialize( + booleanity_params.clone(), + &self.trace, + &self.preprocessing.program, + &self.program_io.memory_layout, + ); + + #[cfg(feature = "allocative")] + { + print_data_structure_heap_usage( + "BytecodeReadRafAddressSumcheckProver", + &bytecode_read_raf, + ); + print_data_structure_heap_usage("BooleanityAddressSumcheckProver", &booleanity); + } + + let mut instances: Vec<&mut dyn SumcheckInstanceProver<_, _>> = + vec![&mut bytecode_read_raf, &mut booleanity]; + + #[cfg(feature = "allocative")] + write_instance_flamegraph_svg(&instances, "stage6a_start_flamechart.svg"); + tracing::info!("Stage 6a proving"); + let (sumcheck_proof, _r_stage6a) = BatchedSumcheck::prove( + instances.iter_mut().map(|v| &mut **v as _).collect(), + &mut self.opening_accumulator, + &mut self.transcript, + ); + #[cfg(feature = "allocative")] + write_instance_flamegraph_svg(&instances, "stage6a_end_flamechart.svg"); + + // Cache params for Stage 6b + self.bytecode_read_raf_params = Some(bytecode_read_raf_params); + self.booleanity_params = Some(booleanity_params); + + sumcheck_proof + } + + #[tracing::instrument(skip_all)] + fn prove_stage6b(&mut self) -> SumcheckInstanceProof { + #[cfg(not(target_arch = "wasm32"))] + print_current_memory_usage("Stage 6b baseline"); + + let bytecode_read_raf_params = self + .bytecode_read_raf_params + .take() + .expect("bytecode_read_raf_params must be set by prove_stage6a"); + let booleanity_params = self + .booleanity_params + .take() + .expect("booleanity_params must be set by prove_stage6a"); + + let ram_hamming_booleanity_params = + HammingBooleanitySumcheckParams::new(&self.opening_accumulator); + let ram_ra_virtual_params = RamRaVirtualParams::new( self.trace.len(), &self.one_hot_params, @@ -1108,7 +1396,24 @@ impl<'a, F: JoltField, PCS: StreamingCommitmentScheme, ProofTranscrip &mut self.transcript, ); - // Advice claim reduction (Phase 1 in Stage 6): trusted and untrusted are separate instances. + // Bytecode claim reduction (Phase 1 in Stage 6b): consumes Val_s(r_bc) from Stage 6a and + // caches an intermediate claim for Stage 7. + if self.program_mode == ProgramMode::Committed { + let bytecode_reduction_params = BytecodeClaimReductionParams::new( + &bytecode_read_raf_params, + &self.opening_accumulator, + &mut self.transcript, + ); + self.bytecode_reduction_prover = Some(BytecodeClaimReductionProver::initialize( + bytecode_reduction_params, + Arc::clone(&self.preprocessing.program), + )); + } else { + // Legacy mode: do not run the bytecode claim reduction. + self.bytecode_reduction_prover = None; + } + + // Advice claim reduction (Phase 1 in Stage 6b): trusted and untrusted are separate instances. if self.advice.trusted_advice_polynomial.is_some() { let trusted_advice_params = AdviceClaimReductionParams::new( AdviceKind::Trusted, @@ -1159,20 +1464,22 @@ impl<'a, F: JoltField, PCS: StreamingCommitmentScheme, ProofTranscrip }; } - let mut bytecode_read_raf = BytecodeReadRafSumcheckProver::initialize( + // Initialize Stage 6b cycle provers from Stage 6a openings + let mut bytecode_read_raf = BytecodeReadRafCycleSumcheckProver::initialize( bytecode_read_raf_params, Arc::clone(&self.trace), - Arc::clone(&self.preprocessing.shared.bytecode), + Arc::clone(&self.preprocessing.program), + &self.opening_accumulator, ); - let mut ram_hamming_booleanity = - HammingBooleanitySumcheckProver::initialize(ram_hamming_booleanity_params, &self.trace); - - let mut booleanity = BooleanitySumcheckProver::initialize( + let mut booleanity = BooleanityCycleSumcheckProver::initialize( booleanity_params, &self.trace, - &self.preprocessing.shared.bytecode, + &self.preprocessing.program, &self.program_io.memory_layout, + &self.opening_accumulator, ); + let mut ram_hamming_booleanity = + HammingBooleanitySumcheckProver::initialize(ram_hamming_booleanity_params, &self.trace); let mut ram_ra_virtual = RamRaVirtualSumcheckProver::initialize( ram_ra_virtual_params, @@ -1187,12 +1494,15 @@ impl<'a, F: JoltField, PCS: StreamingCommitmentScheme, ProofTranscrip #[cfg(feature = "allocative")] { - print_data_structure_heap_usage("BytecodeReadRafSumcheckProver", &bytecode_read_raf); - print_data_structure_heap_usage("BooleanitySumcheckProver", &booleanity); + print_data_structure_heap_usage( + "BytecodeReadRafCycleSumcheckProver", + &bytecode_read_raf, + ); print_data_structure_heap_usage( "ram HammingBooleanitySumcheckProver", &ram_hamming_booleanity, ); + print_data_structure_heap_usage("BooleanityCycleSumcheckProver", &booleanity); print_data_structure_heap_usage("RamRaSumcheckProver", &ram_ra_virtual); print_data_structure_heap_usage("LookupsRaSumcheckProver", &lookups_ra_virtual); print_data_structure_heap_usage("IncClaimReductionSumcheckProver", &inc_reduction); @@ -1212,23 +1522,62 @@ impl<'a, F: JoltField, PCS: StreamingCommitmentScheme, ProofTranscrip &mut lookups_ra_virtual, &mut inc_reduction, ]; + if let Some(bytecode) = self.bytecode_reduction_prover.as_mut() { + instances.push(bytecode); + } if let Some(advice) = self.advice_reduction_prover_trusted.as_mut() { instances.push(advice); } if let Some(advice) = self.advice_reduction_prover_untrusted.as_mut() { instances.push(advice); } + // Program-image claim reduction (Stage 6b): binds staged Stage 4 program-image scalar claims + // to the trusted commitment via a degree-2 sumcheck, caching an opening of ProgramImageInit. + let mut program_image_reduction: Option> = None; + if self.program_mode == ProgramMode::Committed { + let trusted = self + .preprocessing + .program_commitments + .as_ref() + .expect("program commitments missing in committed mode"); + let padded_len_words = trusted.program_image_num_words; + let log_t = self.trace.len().log_2(); + let m = padded_len_words.log_2(); + assert!( + m <= log_t, + "program-image claim reduction requires m=log2(padded_len_words) <= log_T (got m={m}, log_T={log_t})" + ); + let params = ProgramImageClaimReductionParams::new( + &self.program_io, + self.preprocessing.program.min_bytecode_address, + padded_len_words, + self.one_hot_params.ram_k, + self.trace.len(), + &self.rw_config, + &self.opening_accumulator, + &mut self.transcript, + ); + // Build padded coefficients for ProgramWord polynomial. + let mut coeffs = self.preprocessing.program.program_image_words.clone(); + coeffs.resize(padded_len_words, 0u64); + program_image_reduction = + Some(ProgramImageClaimReductionProver::initialize(params, coeffs)); + } + if let Some(ref mut prog) = program_image_reduction { + instances.push(prog); + } #[cfg(feature = "allocative")] - write_instance_flamegraph_svg(&instances, "stage6_start_flamechart.svg"); - tracing::info!("Stage 6 proving"); - let (sumcheck_proof, _r_stage6) = BatchedSumcheck::prove( + write_instance_flamegraph_svg(&instances, "stage6b_start_flamechart.svg"); + tracing::info!("Stage 6b proving"); + + let (sumcheck_proof, _r_stage6b) = BatchedSumcheck::prove( instances.iter_mut().map(|v| &mut **v as _).collect(), &mut self.opening_accumulator, &mut self.transcript, ); #[cfg(feature = "allocative")] - write_instance_flamegraph_svg(&instances, "stage6_end_flamechart.svg"); + write_instance_flamegraph_svg(&instances, "stage6b_end_flamechart.svg"); drop_in_background_thread(bytecode_read_raf); drop_in_background_thread(booleanity); drop_in_background_thread(ram_hamming_booleanity); @@ -1236,6 +1585,10 @@ impl<'a, F: JoltField, PCS: StreamingCommitmentScheme, ProofTranscrip drop_in_background_thread(lookups_ra_virtual); drop_in_background_thread(inc_reduction); + if let Some(prog) = program_image_reduction { + drop_in_background_thread(prog); + } + sumcheck_proof } @@ -1253,6 +1606,7 @@ impl<'a, F: JoltField, PCS: StreamingCommitmentScheme, ProofTranscrip hw_params, &self.trace, &self.preprocessing.shared, + &self.preprocessing.program, &self.one_hot_params, ); @@ -1260,10 +1614,19 @@ impl<'a, F: JoltField, PCS: StreamingCommitmentScheme, ProofTranscrip print_data_structure_heap_usage("HammingWeightClaimReductionProver", &hw_prover); // Run Stage 7 batched sumcheck (address rounds only). - // Includes HammingWeightClaimReduction plus address phase of advice reduction instances (if needed). + // Includes HammingWeightClaimReduction plus lane/address-phase reductions (if needed). let mut instances: Vec>> = vec![Box::new(hw_prover)]; + if let Some(mut bytecode_reduction_prover) = self.bytecode_reduction_prover.take() { + // Stage 6b → Stage 7 transition for bytecode claim reduction: + // - Cycle-phase sumcheck is complete, so we can materialize the lane-phase witness + // polynomials B_i(·, r_cycle) (GPU-style "export b_vals"). + bytecode_reduction_prover.prepare_lane_phase(); + bytecode_reduction_prover.params.phase = BytecodeReductionPhase::LaneVariables; + instances.push(Box::new(bytecode_reduction_prover)); + } + if let Some(mut advice_reduction_prover_trusted) = self.advice_reduction_prover_trusted.take() { @@ -1315,12 +1678,26 @@ impl<'a, F: JoltField, PCS: StreamingCommitmentScheme, ProofTranscrip ) -> PCS::Proof { tracing::info!("Stage 8 proving (Dory batch opening)"); - let _guard = DoryGlobals::initialize_context( - self.one_hot_params.k_chunk, - self.padded_trace_len, - DoryContext::Main, - Some(DoryGlobals::get_layout()), - ); + let _guard = if self.program_mode == ProgramMode::Committed { + let committed = self + .preprocessing + .program_commitments + .as_ref() + .expect("program commitments missing in committed mode"); + DoryGlobals::initialize_main_context_with_num_columns( + self.one_hot_params.k_chunk, + self.padded_trace_len, + committed.bytecode_num_columns, + Some(DoryGlobals::get_layout()), + ) + } else { + DoryGlobals::initialize_context( + self.one_hot_params.k_chunk, + self.padded_trace_len, + DoryContext::Main, + Some(DoryGlobals::get_layout()), + ) + }; // Get the unified opening point from HammingWeightClaimReduction // This contains (r_address_stage7 || r_cycle_stage6) in big-endian @@ -1435,6 +1812,65 @@ impl<'a, F: JoltField, PCS: StreamingCommitmentScheme, ProofTranscrip )); } + // Bytecode chunk polynomials: committed in Bytecode context and embedded into the + // main opening point by fixing the extra cycle variables to 0. + if self.program_mode == ProgramMode::Committed { + let (bytecode_point, _) = self.opening_accumulator.get_committed_polynomial_opening( + CommittedPolynomial::BytecodeChunk(0), + SumcheckId::BytecodeClaimReduction, + ); + let log_t = opening_point.r.len() - log_k_chunk; + let log_k = bytecode_point.r.len() - log_k_chunk; + assert!( + log_k <= log_t, + "bytecode folding requires log_T >= log_K (got log_T={log_t}, log_K={log_k})" + ); + #[cfg(test)] + { + if log_k == log_t { + assert_eq!( + bytecode_point.r, opening_point.r, + "BytecodeChunk opening point must equal unified opening point when log_K == log_T" + ); + } else { + let (r_lane_main, r_cycle_main) = opening_point.split_at(log_k_chunk); + let (r_lane_bc, r_cycle_bc) = bytecode_point.split_at(log_k_chunk); + debug_assert_eq!(r_lane_main.r, r_lane_bc.r); + debug_assert_eq!(&r_cycle_main.r[(log_t - log_k)..], r_cycle_bc.r.as_slice()); + } + } + let lagrange_factor = + compute_advice_lagrange_factor::(&opening_point.r, &bytecode_point.r); + + let num_chunks = total_lanes().div_ceil(self.one_hot_params.k_chunk); + for i in 0..num_chunks { + let (_, claim) = self.opening_accumulator.get_committed_polynomial_opening( + CommittedPolynomial::BytecodeChunk(i), + SumcheckId::BytecodeClaimReduction, + ); + polynomial_claims.push(( + CommittedPolynomial::BytecodeChunk(i), + claim * lagrange_factor, + )); + } + } + + // Program-image polynomial: opened by ProgramImageClaimReduction in Stage 6b. + // Embed into the top-left block of the main matrix (same trick as advice). + if self.program_mode == ProgramMode::Committed { + let (prog_point, prog_claim) = + self.opening_accumulator.get_committed_polynomial_opening( + CommittedPolynomial::ProgramImageInit, + SumcheckId::ProgramImageClaimReduction, + ); + let lagrange_factor = + compute_advice_lagrange_factor::(&opening_point.r, &prog_point.r); + polynomial_claims.push(( + CommittedPolynomial::ProgramImageInit, + prog_claim * lagrange_factor, + )); + } + // 2. Sample gamma and compute powers for RLC let claims: Vec = polynomial_claims.iter().map(|(_, c)| *c).collect(); self.transcript.append_scalars(&claims); @@ -1448,7 +1884,7 @@ impl<'a, F: JoltField, PCS: StreamingCommitmentScheme, ProofTranscrip }; let streaming_data = Arc::new(RLCStreamingData { - bytecode: Arc::clone(&self.preprocessing.shared.bytecode), + program: Arc::clone(&self.preprocessing.program), memory_layout: self.preprocessing.shared.memory_layout.clone(), }); @@ -1460,15 +1896,45 @@ impl<'a, F: JoltField, PCS: StreamingCommitmentScheme, ProofTranscrip if let Some(poly) = self.advice.untrusted_advice_polynomial.take() { advice_polys.insert(CommittedPolynomial::UntrustedAdvice, poly); } + if self.program_mode == ProgramMode::Committed { + let trusted = self + .preprocessing + .program_commitments + .as_ref() + .expect("program commitments missing in committed mode"); + // Use the padded size from the trusted commitments (may be larger than program's own padded size) + advice_polys.insert( + CommittedPolynomial::ProgramImageInit, + TrustedProgramCommitments::::build_program_image_polynomial_padded::( + &self.preprocessing.program, + trusted.program_image_num_words, + ), + ); + } // Build streaming RLC polynomial directly (no witness poly regeneration!) // Use materialized trace (default, single pass) instead of lazy trace + // + // bytecode_T: The T value used for bytecode coefficient indexing. + // In Committed mode, use the value stored in trusted commitments. + // In Full mode, use bytecode_len (original behavior). + let bytecode_T = if self.program_mode == ProgramMode::Committed { + let trusted = self + .preprocessing + .program_commitments + .as_ref() + .expect("program commitments missing in committed mode"); + trusted.bytecode_T + } else { + self.preprocessing.program.bytecode_len() + }; let (joint_poly, hint) = state.build_streaming_rlc::( self.one_hot_params.clone(), TraceSource::Materialized(Arc::clone(&self.trace)), streaming_data, opening_proof_hints, advice_polys, + bytecode_T, ); PCS::prove( @@ -1519,6 +1985,15 @@ fn write_instance_flamegraph_svg( pub struct JoltProverPreprocessing> { pub generators: PCS::ProverSetup, pub shared: JoltSharedPreprocessing, + /// Full program preprocessing (prover always has full access for witness computation). + pub program: Arc, + /// Trusted program commitments (only in Committed mode). + /// + /// In Full mode: None (verifier has full program). + /// In Committed mode: Some(trusted) for bytecode + program-image commitments. + pub program_commitments: Option>, + /// Opening proof hints for program commitments (only in Committed mode). + pub program_hints: Option>, } impl JoltProverPreprocessing @@ -1526,12 +2001,8 @@ where F: JoltField, PCS: CommitmentScheme, { - #[tracing::instrument(skip_all, name = "JoltProverPreprocessing::gen")] - pub fn new( - shared: JoltSharedPreprocessing, - // max_trace_length: usize, - ) -> JoltProverPreprocessing { - use common::constants::ONEHOT_CHUNK_THRESHOLD_LOG_T; + /// Setup generators based on trace length (Main context). + fn setup_generators(shared: &JoltSharedPreprocessing) -> PCS::ProverSetup { let max_T: usize = shared.max_padded_trace_length.next_power_of_two(); let max_log_T = max_T.log_2(); // Use the maximum possible log_k_chunk for generator setup @@ -1540,8 +2011,84 @@ where } else { 8 }; - let generators = PCS::setup_prover(max_log_k_chunk + max_log_T); - JoltProverPreprocessing { generators, shared } + PCS::setup_prover(max_log_k_chunk + max_log_T) + } + + /// Setup generators for Committed mode, ensuring capacity for both: + /// - Main context up to `max_padded_trace_length` + /// - Bytecode context up to `bytecode_size` + /// - ProgramImage context up to the padded program-image word length + fn setup_generators_committed( + shared: &JoltSharedPreprocessing, + program: &ProgramPreprocessing, + ) -> PCS::ProverSetup { + let prog_len_words_padded = program.program_image_len_words_padded(); + let max_t_any: usize = shared + .max_padded_trace_length + .max(shared.bytecode_size()) + .max(prog_len_words_padded) + .next_power_of_two(); + let max_log_t_any = max_t_any.log_2(); + let max_log_k_chunk = if max_log_t_any < ONEHOT_CHUNK_THRESHOLD_LOG_T { + 4 + } else { + 8 + }; + PCS::setup_prover(max_log_k_chunk + max_log_t_any) + } + + /// Create prover preprocessing in Full mode (no commitments). + /// + /// Use this when the verifier will have access to full program. + #[tracing::instrument(skip_all, name = "JoltProverPreprocessing::new")] + pub fn new( + shared: JoltSharedPreprocessing, + program: Arc, + ) -> JoltProverPreprocessing { + let generators = Self::setup_generators(&shared); + JoltProverPreprocessing { + generators, + shared, + program, + program_commitments: None, + program_hints: None, + } + } + + /// Create prover preprocessing in Committed mode (with program commitments). + /// + /// Use this when the verifier should only receive commitments (succinct verification). + /// Computes commitments + hints for all bytecode chunk polynomials and program image during preprocessing. + #[tracing::instrument(skip_all, name = "JoltProverPreprocessing::new_committed")] + pub fn new_committed( + shared: JoltSharedPreprocessing, + program: Arc, + ) -> JoltProverPreprocessing { + let generators = Self::setup_generators_committed(&shared, &program); + let max_t_any: usize = shared + .max_padded_trace_length + .max(shared.bytecode_size()) + .next_power_of_two(); + let max_log_t = max_t_any.log_2(); + let log_k_chunk = if max_log_t < common::constants::ONEHOT_CHUNK_THRESHOLD_LOG_T { + 4 + } else { + 8 + }; + let (program_commitments, program_hints) = + TrustedProgramCommitments::derive(&program, &generators, log_k_chunk, max_t_any); + JoltProverPreprocessing { + generators, + shared, + program, + program_commitments: Some(program_commitments), + program_hints: Some(program_hints), + } + } + + /// Check if this preprocessing is in Committed mode. + pub fn is_committed_mode(&self) -> bool { + self.program_commitments.is_some() } pub fn save_to_target_dir(&self, target_dir: &str) -> std::io::Result<()> { @@ -1566,891 +2113,3 @@ impl> Serializable for JoltProverPreprocessing { } - -#[cfg(test)] -mod tests { - use ark_bn254::Fr; - use serial_test::serial; - - use crate::host; - use crate::poly::commitment::dory::{DoryGlobals, DoryLayout}; - use crate::poly::{ - commitment::{ - commitment_scheme::CommitmentScheme, - dory::{DoryCommitmentScheme, DoryContext}, - }, - multilinear_polynomial::MultilinearPolynomial, - opening_proof::{OpeningAccumulator, SumcheckId}, - }; - use crate::zkvm::claim_reductions::AdviceKind; - use crate::zkvm::verifier::JoltSharedPreprocessing; - use crate::zkvm::witness::CommittedPolynomial; - use crate::zkvm::{ - prover::JoltProverPreprocessing, - ram::populate_memory_states, - verifier::{JoltVerifier, JoltVerifierPreprocessing}, - RV64IMACProver, RV64IMACVerifier, - }; - - fn commit_trusted_advice_preprocessing_only( - preprocessing: &JoltProverPreprocessing, - trusted_advice_bytes: &[u8], - ) -> ( - ::Commitment, - ::OpeningProofHint, - ) { - let max_trusted_advice_size = preprocessing.shared.memory_layout.max_trusted_advice_size; - let mut trusted_advice_words = vec![0u64; (max_trusted_advice_size as usize) / 8]; - populate_memory_states( - 0, - trusted_advice_bytes, - Some(&mut trusted_advice_words), - None, - ); - - let poly = MultilinearPolynomial::::from(trusted_advice_words); - let advice_len = poly.len().next_power_of_two().max(1); - - let _guard = - DoryGlobals::initialize_context(1, advice_len, DoryContext::TrustedAdvice, None); - let (commitment, hint) = { - let _ctx = DoryGlobals::with_context(DoryContext::TrustedAdvice); - DoryCommitmentScheme::commit(&poly, &preprocessing.generators) - }; - (commitment, hint) - } - - #[test] - #[serial] - fn fib_e2e_dory() { - DoryGlobals::reset(); - let mut program = host::Program::new("fibonacci-guest"); - let inputs = postcard::to_stdvec(&100u32).unwrap(); - let (bytecode, init_memory_state, _) = program.decode(); - let (_, _, _, io_device) = program.trace(&inputs, &[], &[]); - let shared_preprocessing = JoltSharedPreprocessing::new( - bytecode.clone(), - io_device.memory_layout.clone(), - init_memory_state, - 1 << 16, - ); - - let prover_preprocessing = JoltProverPreprocessing::new(shared_preprocessing.clone()); - let elf_contents_opt = program.get_elf_contents(); - let elf_contents = elf_contents_opt.as_deref().expect("elf contents is None"); - let prover = RV64IMACProver::gen_from_elf( - &prover_preprocessing, - elf_contents, - &inputs, - &[], - &[], - None, - None, - ); - let io_device = prover.program_io.clone(); - let (jolt_proof, debug_info) = prover.prove(); - - let verifier_preprocessing = JoltVerifierPreprocessing::new( - shared_preprocessing, - prover_preprocessing.generators.to_verifier_setup(), - ); - let verifier = RV64IMACVerifier::new( - &verifier_preprocessing, - jolt_proof, - io_device, - None, - debug_info, - ) - .expect("Failed to create verifier"); - verifier.verify().expect("Failed to verify proof"); - } - - #[test] - #[serial] - fn small_trace_e2e_dory() { - DoryGlobals::reset(); - let mut program = host::Program::new("fibonacci-guest"); - let inputs = postcard::to_stdvec(&5u32).unwrap(); - let (bytecode, init_memory_state, _) = program.decode(); - let (_, _, _, io_device) = program.trace(&inputs, &[], &[]); - - let shared_preprocessing = JoltSharedPreprocessing::new( - bytecode.clone(), - io_device.memory_layout.clone(), - init_memory_state, - 256, - ); - - let prover_preprocessing = JoltProverPreprocessing::new(shared_preprocessing.clone()); - let elf_contents_opt = program.get_elf_contents(); - let elf_contents = elf_contents_opt.as_deref().expect("elf contents is None"); - let log_chunk = 8; // Use default log_chunk for tests - let prover = RV64IMACProver::gen_from_elf( - &prover_preprocessing, - elf_contents, - &inputs, - &[], - &[], - None, - None, - ); - - assert!( - prover.padded_trace_len <= (1 << log_chunk), - "Test requires T <= chunk_size ({}), got T = {}", - 1 << log_chunk, - prover.padded_trace_len - ); - - let io_device = prover.program_io.clone(); - let (jolt_proof, debug_info) = prover.prove(); - - let verifier_preprocessing = JoltVerifierPreprocessing::new( - prover_preprocessing.shared.clone(), - prover_preprocessing.generators.to_verifier_setup(), - ); - let verifier = RV64IMACVerifier::new( - &verifier_preprocessing, - jolt_proof, - io_device, - None, - debug_info, - ) - .expect("Failed to create verifier"); - verifier.verify().expect("Failed to verify proof"); - } - - #[test] - #[serial] - fn sha3_e2e_dory() { - DoryGlobals::reset(); - // Ensure SHA3 inline library is linked and auto-registered - #[cfg(feature = "host")] - use jolt_inlines_keccak256 as _; - // SHA3 inlines are automatically registered via #[ctor::ctor] - // when the jolt-inlines-keccak256 crate is linked (see lib.rs) - - let mut program = host::Program::new("sha3-guest"); - let (bytecode, init_memory_state, _) = program.decode(); - let inputs = postcard::to_stdvec(&[5u8; 32]).unwrap(); - let (_, _, _, io_device) = program.trace(&inputs, &[], &[]); - - let shared_preprocessing = JoltSharedPreprocessing::new( - bytecode.clone(), - io_device.memory_layout.clone(), - init_memory_state, - 1 << 16, - ); - - let prover_preprocessing = JoltProverPreprocessing::new(shared_preprocessing.clone()); - let elf_contents_opt = program.get_elf_contents(); - let elf_contents = elf_contents_opt.as_deref().expect("elf contents is None"); - let prover = RV64IMACProver::gen_from_elf( - &prover_preprocessing, - elf_contents, - &inputs, - &[], - &[], - None, - None, - ); - let io_device = prover.program_io.clone(); - let (jolt_proof, debug_info) = prover.prove(); - - let verifier_preprocessing = JoltVerifierPreprocessing::new( - prover_preprocessing.shared.clone(), - prover_preprocessing.generators.to_verifier_setup(), - ); - let verifier = RV64IMACVerifier::new( - &verifier_preprocessing, - jolt_proof, - io_device.clone(), - None, - debug_info, - ) - .expect("Failed to create verifier"); - verifier.verify().expect("Failed to verify proof"); - assert_eq!( - io_device.inputs, inputs, - "Inputs mismatch: expected {:?}, got {:?}", - inputs, io_device.inputs - ); - let expected_output = &[ - 0xd0, 0x3, 0x5c, 0x96, 0x86, 0x6e, 0xe2, 0x2e, 0x81, 0xf5, 0xc4, 0xef, 0xbd, 0x88, - 0x33, 0xc1, 0x7e, 0xa1, 0x61, 0x10, 0x81, 0xfc, 0xd7, 0xa3, 0xdd, 0xce, 0xce, 0x7f, - 0x44, 0x72, 0x4, 0x66, - ]; - assert_eq!(io_device.outputs, expected_output, "Outputs mismatch",); - } - - #[test] - #[serial] - fn sha2_e2e_dory() { - DoryGlobals::reset(); - // Ensure SHA2 inline library is linked and auto-registered - #[cfg(feature = "host")] - use jolt_inlines_sha2 as _; - // SHA2 inlines are automatically registered via #[ctor::ctor] - // when the jolt-inlines-sha2 crate is linked (see lib.rs) - let mut program = host::Program::new("sha2-guest"); - let (bytecode, init_memory_state, _) = program.decode(); - let inputs = postcard::to_stdvec(&[5u8; 32]).unwrap(); - let (_, _, _, io_device) = program.trace(&inputs, &[], &[]); - - let shared_preprocessing = JoltSharedPreprocessing::new( - bytecode.clone(), - io_device.memory_layout.clone(), - init_memory_state, - 1 << 16, - ); - - let prover_preprocessing = JoltProverPreprocessing::new(shared_preprocessing.clone()); - let elf_contents_opt = program.get_elf_contents(); - let elf_contents = elf_contents_opt.as_deref().expect("elf contents is None"); - let prover = RV64IMACProver::gen_from_elf( - &prover_preprocessing, - elf_contents, - &inputs, - &[], - &[], - None, - None, - ); - let io_device = prover.program_io.clone(); - let (jolt_proof, debug_info) = prover.prove(); - - let verifier_preprocessing = JoltVerifierPreprocessing::new( - prover_preprocessing.shared.clone(), - prover_preprocessing.generators.to_verifier_setup(), - ); - let verifier = RV64IMACVerifier::new( - &verifier_preprocessing, - jolt_proof, - io_device.clone(), - None, - debug_info, - ) - .expect("Failed to create verifier"); - verifier.verify().expect("Failed to verify proof"); - let expected_output = &[ - 0x28, 0x9b, 0xdf, 0x82, 0x9b, 0x4a, 0x30, 0x26, 0x7, 0x9a, 0x3e, 0xa0, 0x89, 0x73, - 0xb1, 0x97, 0x2d, 0x12, 0x4e, 0x7e, 0xaf, 0x22, 0x33, 0xc6, 0x3, 0x14, 0x3d, 0xc6, - 0x3b, 0x50, 0xd2, 0x57, - ]; - assert_eq!( - io_device.outputs, expected_output, - "Outputs mismatch: expected {:?}, got {:?}", - expected_output, io_device.outputs - ); - } - - #[test] - #[serial] - fn sha2_e2e_dory_with_unused_advice() { - DoryGlobals::reset(); - // SHA2 guest does not consume advice, but providing both trusted and untrusted advice - // should still work correctly through the full pipeline: - // - Trusted: commit in preprocessing-only context, reduce in Stage 6, batch in Stage 8 - // - Untrusted: commit at prove time, reduce in Stage 6, batch in Stage 8 - let mut program = host::Program::new("sha2-guest"); - let (bytecode, init_memory_state, _) = program.decode(); - let inputs = postcard::to_stdvec(&[5u8; 32]).unwrap(); - let trusted_advice = postcard::to_stdvec(&[7u8; 32]).unwrap(); - let untrusted_advice = postcard::to_stdvec(&[9u8; 32]).unwrap(); - - let (_, _, _, io_device) = program.trace(&inputs, &untrusted_advice, &trusted_advice); - - let shared_preprocessing = JoltSharedPreprocessing::new( - bytecode.clone(), - io_device.memory_layout.clone(), - init_memory_state, - 1 << 16, - ); - let prover_preprocessing = JoltProverPreprocessing::new(shared_preprocessing.clone()); - let elf_contents = program.get_elf_contents().expect("elf contents is None"); - - let (trusted_commitment, trusted_hint) = - commit_trusted_advice_preprocessing_only(&prover_preprocessing, &trusted_advice); - - let prover = RV64IMACProver::gen_from_elf( - &prover_preprocessing, - &elf_contents, - &inputs, - &untrusted_advice, - &trusted_advice, - Some(trusted_commitment), - Some(trusted_hint), - ); - let io_device = prover.program_io.clone(); - let (jolt_proof, debug_info) = prover.prove(); - - let verifier_preprocessing = JoltVerifierPreprocessing::from(&prover_preprocessing); - RV64IMACVerifier::new( - &verifier_preprocessing, - jolt_proof, - io_device.clone(), - Some(trusted_commitment), - debug_info, - ) - .expect("Failed to create verifier") - .verify() - .expect("Failed to verify proof"); - - // Verify output is correct (advice should not affect sha2 output) - let expected_output = &[ - 0x28, 0x9b, 0xdf, 0x82, 0x9b, 0x4a, 0x30, 0x26, 0x7, 0x9a, 0x3e, 0xa0, 0x89, 0x73, - 0xb1, 0x97, 0x2d, 0x12, 0x4e, 0x7e, 0xaf, 0x22, 0x33, 0xc6, 0x3, 0x14, 0x3d, 0xc6, - 0x3b, 0x50, 0xd2, 0x57, - ]; - assert_eq!(io_device.outputs, expected_output); - } - - #[test] - #[serial] - fn max_advice_with_small_trace() { - DoryGlobals::reset(); - // Tests that max-sized advice (4KB = 512 words) works with a minimal trace. - // With balanced dims (sigma_a=5, nu_a=4 for 512 words), the minimum padded trace - // (256 cycles -> total_vars=12) is sufficient to embed advice. - let mut program = host::Program::new("fibonacci-guest"); - let inputs = postcard::to_stdvec(&5u32).unwrap(); - let trusted_advice = vec![7u8; 4096]; - let untrusted_advice = vec![9u8; 4096]; - - let (bytecode, init_memory_state, _) = program.decode(); - let (lazy_trace, trace, final_memory_state, io_device) = - program.trace(&inputs, &untrusted_advice, &trusted_advice); - - let shared_preprocessing = JoltSharedPreprocessing::new( - bytecode.clone(), - io_device.memory_layout.clone(), - init_memory_state, - 256, - ); - let prover_preprocessing = JoltProverPreprocessing::new(shared_preprocessing.clone()); - tracing::info!( - "preprocessing.memory_layout.max_trusted_advice_size: {}", - shared_preprocessing.memory_layout.max_trusted_advice_size - ); - - let (trusted_commitment, trusted_hint) = - commit_trusted_advice_preprocessing_only(&prover_preprocessing, &trusted_advice); - - let prover = RV64IMACProver::gen_from_trace( - &prover_preprocessing, - lazy_trace, - trace, - io_device, - Some(trusted_commitment), - Some(trusted_hint), - final_memory_state, - ); - - // Trace is tiny but advice is max-sized - assert!(prover.unpadded_trace_len < 512); - assert_eq!(prover.padded_trace_len, 256); - - let io_device = prover.program_io.clone(); - let (jolt_proof, debug_info) = prover.prove(); - - let verifier_preprocessing = JoltVerifierPreprocessing::from(&prover_preprocessing); - RV64IMACVerifier::new( - &verifier_preprocessing, - jolt_proof, - io_device, - Some(trusted_commitment), - debug_info, - ) - .expect("Failed to create verifier") - .verify() - .expect("Verification failed"); - } - - #[test] - #[serial] - fn advice_e2e_dory() { - DoryGlobals::reset(); - // Tests a guest (merkle-tree) that actually consumes both trusted and untrusted advice. - let mut program = host::Program::new("merkle-tree-guest"); - let (bytecode, init_memory_state, _) = program.decode(); - - // Merkle tree with 4 leaves: input=leaf1, trusted=[leaf2, leaf3], untrusted=leaf4 - let inputs = postcard::to_stdvec(&[5u8; 32].as_slice()).unwrap(); - let untrusted_advice = postcard::to_stdvec(&[8u8; 32]).unwrap(); - let mut trusted_advice = postcard::to_stdvec(&[6u8; 32]).unwrap(); - trusted_advice.extend(postcard::to_stdvec(&[7u8; 32]).unwrap()); - - let (_, _, _, io_device) = program.trace(&inputs, &untrusted_advice, &trusted_advice); - let shared_preprocessing = JoltSharedPreprocessing::new( - bytecode.clone(), - io_device.memory_layout.clone(), - init_memory_state, - 1 << 16, - ); - let prover_preprocessing = JoltProverPreprocessing::new(shared_preprocessing.clone()); - let elf_contents = program.get_elf_contents().expect("elf contents is None"); - - let (trusted_commitment, trusted_hint) = - commit_trusted_advice_preprocessing_only(&prover_preprocessing, &trusted_advice); - - let prover = RV64IMACProver::gen_from_elf( - &prover_preprocessing, - &elf_contents, - &inputs, - &untrusted_advice, - &trusted_advice, - Some(trusted_commitment), - Some(trusted_hint), - ); - let io_device = prover.program_io.clone(); - let (jolt_proof, debug_info) = prover.prove(); - - let verifier_preprocessing = JoltVerifierPreprocessing::from(&prover_preprocessing); - RV64IMACVerifier::new( - &verifier_preprocessing, - jolt_proof, - io_device.clone(), - Some(trusted_commitment), - debug_info, - ) - .expect("Failed to create verifier") - .verify() - .expect("Verification failed"); - - // Expected merkle root for leaves [5;32], [6;32], [7;32], [8;32] - let expected_output = &[ - 0xb4, 0x37, 0x0f, 0x3a, 0xb, 0x3d, 0x38, 0xa8, 0x7a, 0x6c, 0x4c, 0x46, 0x9, 0xe7, 0x83, - 0xb3, 0xcc, 0xb7, 0x1c, 0x30, 0x1f, 0xf8, 0x54, 0xd, 0xf7, 0xdd, 0xc8, 0x42, 0x32, - 0xbb, 0x16, 0xd7, - ]; - assert_eq!(io_device.outputs, expected_output); - } - - #[test] - #[serial] - fn advice_opening_point_derives_from_unified_point() { - DoryGlobals::reset(); - // Tests that advice opening points are correctly derived from the unified main opening - // point using Dory's balanced dimension policy. - // - // For a small trace (256 cycles), the advice row coordinates span both Stage 6 (cycle) - // and Stage 7 (address) challenges, verifying the two-phase reduction works correctly. - let mut program = host::Program::new("fibonacci-guest"); - let inputs = postcard::to_stdvec(&5u32).unwrap(); - let trusted_advice = postcard::to_stdvec(&[7u8; 32]).unwrap(); - let untrusted_advice = postcard::to_stdvec(&[9u8; 32]).unwrap(); - - let (bytecode, init_memory_state, _) = program.decode(); - let (lazy_trace, trace, final_memory_state, io_device) = - program.trace(&inputs, &untrusted_advice, &trusted_advice); - - let shared_preprocessing = JoltSharedPreprocessing::new( - bytecode.clone(), - io_device.memory_layout.clone(), - init_memory_state, - 1 << 16, - ); - let prover_preprocessing = JoltProverPreprocessing::new(shared_preprocessing.clone()); - let (trusted_commitment, trusted_hint) = - commit_trusted_advice_preprocessing_only(&prover_preprocessing, &trusted_advice); - - let prover = RV64IMACProver::gen_from_trace( - &prover_preprocessing, - lazy_trace, - trace, - io_device, - Some(trusted_commitment), - Some(trusted_hint), - final_memory_state, - ); - - assert_eq!(prover.padded_trace_len, 256, "test expects small trace"); - - let io_device = prover.program_io.clone(); - let (jolt_proof, debug_info) = prover.prove(); - let debug_info = debug_info.expect("expected debug_info in tests"); - - // Get unified opening point and derive expected advice point - let (opening_point, _) = debug_info - .opening_accumulator - .get_committed_polynomial_opening( - CommittedPolynomial::InstructionRa(0), - SumcheckId::HammingWeightClaimReduction, - ); - let mut point_dory_le = opening_point.r.clone(); - point_dory_le.reverse(); - - let total_vars = point_dory_le.len(); - let (sigma_main, _nu_main) = DoryGlobals::balanced_sigma_nu(total_vars); - let (sigma_a, nu_a) = DoryGlobals::advice_sigma_nu_from_max_bytes( - prover_preprocessing - .shared - .memory_layout - .max_trusted_advice_size as usize, - ); - - // Build expected advice point: [col_bits[0..sigma_a] || row_bits[0..nu_a]] - let mut expected_advice_le: Vec<_> = point_dory_le[0..sigma_a].to_vec(); - expected_advice_le.extend_from_slice(&point_dory_le[sigma_main..sigma_main + nu_a]); - - // Verify both advice types derive the same opening point - for (name, kind) in [ - ("trusted", AdviceKind::Trusted), - ("untrusted", AdviceKind::Untrusted), - ] { - let get_fn = debug_info - .opening_accumulator - .get_advice_opening(kind, SumcheckId::AdviceClaimReduction); - assert!( - get_fn.is_some(), - "{name} advice opening missing for AdviceClaimReductionPhase2" - ); - let (point_be, _) = get_fn.unwrap(); - let mut point_le = point_be.r.clone(); - point_le.reverse(); - assert_eq!(point_le, expected_advice_le, "{name} advice point mismatch"); - } - - // Verify end-to-end - let verifier_preprocessing = JoltVerifierPreprocessing::from(&prover_preprocessing); - RV64IMACVerifier::new( - &verifier_preprocessing, - jolt_proof, - io_device, - Some(trusted_commitment), - Some(debug_info), - ) - .expect("Failed to create verifier") - .verify() - .expect("Verification failed"); - } - - #[test] - #[serial] - fn memory_ops_e2e_dory() { - DoryGlobals::reset(); - let mut program = host::Program::new("memory-ops-guest"); - let (bytecode, init_memory_state, _) = program.decode(); - let (_, _, _, io_device) = program.trace(&[], &[], &[]); - - let shared_preprocessing = JoltSharedPreprocessing::new( - bytecode.clone(), - io_device.memory_layout.clone(), - init_memory_state, - 1 << 16, - ); - - let prover_preprocessing = JoltProverPreprocessing::new(shared_preprocessing.clone()); - let elf_contents_opt = program.get_elf_contents(); - let elf_contents = elf_contents_opt.as_deref().expect("elf contents is None"); - let prover = RV64IMACProver::gen_from_elf( - &prover_preprocessing, - elf_contents, - &[], - &[], - &[], - None, - None, - ); - let io_device = prover.program_io.clone(); - let (jolt_proof, debug_info) = prover.prove(); - - let verifier_preprocessing = JoltVerifierPreprocessing::new( - prover_preprocessing.shared.clone(), - prover_preprocessing.generators.to_verifier_setup(), - ); - let verifier = RV64IMACVerifier::new( - &verifier_preprocessing, - jolt_proof, - io_device, - None, - debug_info, - ) - .expect("Failed to create verifier"); - verifier.verify().expect("Failed to verify proof"); - } - - #[test] - #[serial] - fn btreemap_e2e_dory() { - DoryGlobals::reset(); - let mut program = host::Program::new("btreemap-guest"); - let (bytecode, init_memory_state, _) = program.decode(); - let inputs = postcard::to_stdvec(&50u32).unwrap(); - let (_, _, _, io_device) = program.trace(&inputs, &[], &[]); - - let shared_preprocessing = JoltSharedPreprocessing::new( - bytecode.clone(), - io_device.memory_layout.clone(), - init_memory_state, - 1 << 16, - ); - - let prover_preprocessing = JoltProverPreprocessing::new(shared_preprocessing.clone()); - let elf_contents_opt = program.get_elf_contents(); - let elf_contents = elf_contents_opt.as_deref().expect("elf contents is None"); - let prover = RV64IMACProver::gen_from_elf( - &prover_preprocessing, - elf_contents, - &inputs, - &[], - &[], - None, - None, - ); - let io_device = prover.program_io.clone(); - let (jolt_proof, debug_info) = prover.prove(); - - let verifier_preprocessing = JoltVerifierPreprocessing::new( - prover_preprocessing.shared.clone(), - prover_preprocessing.generators.to_verifier_setup(), - ); - let verifier = RV64IMACVerifier::new( - &verifier_preprocessing, - jolt_proof, - io_device, - None, - debug_info, - ) - .expect("Failed to create verifier"); - verifier.verify().expect("Failed to verify proof"); - } - - #[test] - #[serial] - fn muldiv_e2e_dory() { - DoryGlobals::reset(); - let mut program = host::Program::new("muldiv-guest"); - let (bytecode, init_memory_state, _) = program.decode(); - let inputs = postcard::to_stdvec(&[9u32, 5u32, 3u32]).unwrap(); - let (_, _, _, io_device) = program.trace(&inputs, &[], &[]); - - let shared_preprocessing = JoltSharedPreprocessing::new( - bytecode.clone(), - io_device.memory_layout.clone(), - init_memory_state, - 1 << 16, - ); - - let prover_preprocessing = JoltProverPreprocessing::new(shared_preprocessing.clone()); - let elf_contents_opt = program.get_elf_contents(); - let elf_contents = elf_contents_opt.as_deref().expect("elf contents is None"); - let prover = RV64IMACProver::gen_from_elf( - &prover_preprocessing, - elf_contents, - &[50], - &[], - &[], - None, - None, - ); - let io_device = prover.program_io.clone(); - let (jolt_proof, debug_info) = prover.prove(); - - let verifier_preprocessing = JoltVerifierPreprocessing::new( - prover_preprocessing.shared.clone(), - prover_preprocessing.generators.to_verifier_setup(), - ); - let verifier = RV64IMACVerifier::new( - &verifier_preprocessing, - jolt_proof, - io_device, - None, - debug_info, - ) - .expect("Failed to create verifier"); - verifier.verify().expect("Failed to verify proof"); - } - - #[test] - #[serial] - #[should_panic] - fn truncated_trace() { - let mut program = host::Program::new("fibonacci-guest"); - let (bytecode, init_memory_state, _) = program.decode(); - let inputs = postcard::to_stdvec(&9u8).unwrap(); - let (lazy_trace, mut trace, final_memory_state, mut program_io) = - program.trace(&inputs, &[], &[]); - trace.truncate(100); - program_io.outputs[0] = 0; // change the output to 0 - - let shared_preprocessing = JoltSharedPreprocessing::new( - bytecode.clone(), - program_io.memory_layout.clone(), - init_memory_state, - 1 << 16, - ); - - let prover_preprocessing = JoltProverPreprocessing::new(shared_preprocessing.clone()); - - let prover = RV64IMACProver::gen_from_trace( - &prover_preprocessing, - lazy_trace, - trace, - program_io.clone(), - None, - None, - final_memory_state, - ); - - let (proof, _) = prover.prove(); - - let verifier_preprocessing = JoltVerifierPreprocessing::new( - prover_preprocessing.shared.clone(), - prover_preprocessing.generators.to_verifier_setup(), - ); - let verifier = - RV64IMACVerifier::new(&verifier_preprocessing, proof, program_io, None, None).unwrap(); - verifier.verify().unwrap(); - } - - #[test] - #[serial] - #[should_panic] - fn malicious_trace() { - let mut program = host::Program::new("fibonacci-guest"); - let inputs = postcard::to_stdvec(&1u8).unwrap(); - let (bytecode, init_memory_state, _) = program.decode(); - let (lazy_trace, trace, final_memory_state, mut program_io) = - program.trace(&inputs, &[], &[]); - - // Since the preprocessing is done with the original memory layout, the verifier should fail - let shared_preprocessing = JoltSharedPreprocessing::new( - bytecode.clone(), - program_io.memory_layout.clone(), - init_memory_state, - 1 << 16, - ); - let prover_preprocessing = JoltProverPreprocessing::new(shared_preprocessing.clone()); - - // change memory address of output & termination bit to the same address as input - // changes here should not be able to spoof the verifier result - program_io.memory_layout.output_start = program_io.memory_layout.input_start; - program_io.memory_layout.output_end = program_io.memory_layout.input_end; - program_io.memory_layout.termination = program_io.memory_layout.input_start; - - let prover = RV64IMACProver::gen_from_trace( - &prover_preprocessing, - lazy_trace, - trace, - program_io.clone(), - None, - None, - final_memory_state, - ); - let (proof, _) = prover.prove(); - - let verifier_preprocessing = JoltVerifierPreprocessing::new( - prover_preprocessing.shared.clone(), - prover_preprocessing.generators.to_verifier_setup(), - ); - let verifier = - JoltVerifier::new(&verifier_preprocessing, proof, program_io, None, None).unwrap(); - verifier.verify().unwrap(); - } - - #[test] - #[serial] - fn fib_e2e_dory_address_major() { - DoryGlobals::reset(); - DoryGlobals::set_layout(DoryLayout::AddressMajor); - - let mut program = host::Program::new("fibonacci-guest"); - let inputs = postcard::to_stdvec(&50u32).unwrap(); - let (bytecode, init_memory_state, _) = program.decode(); - let (_, _, _, io_device) = program.trace(&inputs, &[], &[]); - - let shared_preprocessing = JoltSharedPreprocessing::new( - bytecode.clone(), - io_device.memory_layout.clone(), - init_memory_state, - 1 << 16, - ); - let prover_preprocessing = JoltProverPreprocessing::new(shared_preprocessing.clone()); - let elf_contents = program.get_elf_contents().expect("elf contents is None"); - let prover = RV64IMACProver::gen_from_elf( - &prover_preprocessing, - &elf_contents, - &inputs, - &[], - &[], - None, - None, - ); - let io_device = prover.program_io.clone(); - let (proof, debug_info) = prover.prove(); - - let verifier_preprocessing = JoltVerifierPreprocessing::new( - shared_preprocessing, - prover_preprocessing.generators.to_verifier_setup(), - ); - - // DoryGlobals is now initialized inside the verifier's verify_stage8 - RV64IMACVerifier::new(&verifier_preprocessing, proof, io_device, None, debug_info) - .expect("verifier creation failed") - .verify() - .expect("verification failed"); - } - - #[test] - #[serial] - fn advice_e2e_dory_address_major() { - DoryGlobals::reset(); - DoryGlobals::set_layout(DoryLayout::AddressMajor); - - // Tests a guest (merkle-tree) that actually consumes both trusted and untrusted advice. - let mut program = host::Program::new("merkle-tree-guest"); - let (bytecode, init_memory_state, _) = program.decode(); - - // Merkle tree with 4 leaves: input=leaf1, trusted=[leaf2, leaf3], untrusted=leaf4 - let inputs = postcard::to_stdvec(&[5u8; 32].as_slice()).unwrap(); - let untrusted_advice = postcard::to_stdvec(&[8u8; 32]).unwrap(); - let mut trusted_advice = postcard::to_stdvec(&[6u8; 32]).unwrap(); - trusted_advice.extend(postcard::to_stdvec(&[7u8; 32]).unwrap()); - - let (_, _, _, io_device) = program.trace(&inputs, &untrusted_advice, &trusted_advice); - let shared_preprocessing = JoltSharedPreprocessing::new( - bytecode.clone(), - io_device.memory_layout.clone(), - init_memory_state, - 1 << 16, - ); - let prover_preprocessing = JoltProverPreprocessing::new(shared_preprocessing.clone()); - let elf_contents = program.get_elf_contents().expect("elf contents is None"); - - let (trusted_commitment, trusted_hint) = - commit_trusted_advice_preprocessing_only(&prover_preprocessing, &trusted_advice); - - let prover = RV64IMACProver::gen_from_elf( - &prover_preprocessing, - &elf_contents, - &inputs, - &untrusted_advice, - &trusted_advice, - Some(trusted_commitment), - Some(trusted_hint), - ); - let io_device = prover.program_io.clone(); - let (jolt_proof, debug_info) = prover.prove(); - - let verifier_preprocessing = JoltVerifierPreprocessing::from(&prover_preprocessing); - RV64IMACVerifier::new( - &verifier_preprocessing, - jolt_proof, - io_device.clone(), - Some(trusted_commitment), - debug_info, - ) - .expect("Failed to create verifier") - .verify() - .expect("Verification failed"); - - // Expected merkle root for leaves [5;32], [6;32], [7;32], [8;32] - let expected_output = &[ - 0xb4, 0x37, 0x0f, 0x3a, 0xb, 0x3d, 0x38, 0xa8, 0x7a, 0x6c, 0x4c, 0x46, 0x9, 0xe7, 0x83, - 0xb3, 0xcc, 0xb7, 0x1c, 0x30, 0x1f, 0xf8, 0x54, 0xd, 0xf7, 0xdd, 0xc8, 0x42, 0x32, - 0xbb, 0x16, 0xd7, - ]; - assert_eq!(io_device.outputs, expected_output); - } -} diff --git a/jolt-core/src/zkvm/r1cs/evaluation.rs b/jolt-core/src/zkvm/r1cs/evaluation.rs index ffaac587fe..2cc009c776 100644 --- a/jolt-core/src/zkvm/r1cs/evaluation.rs +++ b/jolt-core/src/zkvm/r1cs/evaluation.rs @@ -52,8 +52,8 @@ use crate::utils::{ accumulation::{Acc5U, Acc6S, Acc6U, Acc7S, Acc7U, S128Sum, S192Sum}, math::s64_from_diff_u64s, }; -use crate::zkvm::bytecode::BytecodePreprocessing; use crate::zkvm::instruction::{CircuitFlags, NUM_CIRCUIT_FLAGS}; +use crate::zkvm::program::ProgramPreprocessing; use crate::zkvm::r1cs::inputs::ProductCycleInputs; use super::constraints::{ @@ -817,7 +817,7 @@ impl<'a, F: JoltField> R1CSEval<'a, F> { /// materializing P_i. Returns `[P_0(r_cycle), P_1(r_cycle), ...]` in input order. #[tracing::instrument(skip_all, name = "R1CSEval::compute_claimed_inputs")] pub fn compute_claimed_inputs( - bytecode_preprocessing: &BytecodePreprocessing, + program: &ProgramPreprocessing, trace: &[Cycle], r_cycle: &OpeningPoint, ) -> [F; NUM_R1CS_INPUTS] { @@ -865,7 +865,7 @@ impl<'a, F: JoltField> R1CSEval<'a, F> { for x2 in 0..eq_two_len { let e_in = eq_two[x2]; let idx = x1 * eq_two_len + x2; - let row = R1CSCycleInputs::from_trace::(bytecode_preprocessing, trace, idx); + let row = R1CSCycleInputs::from_trace::(program, trace, idx); acc_left_input.fmadd(&e_in, &row.left_input); acc_right_input.fmadd(&e_in, &row.right_input.to_i128()); diff --git a/jolt-core/src/zkvm/r1cs/inputs.rs b/jolt-core/src/zkvm/r1cs/inputs.rs index b85cd2a9e5..1b156172f1 100644 --- a/jolt-core/src/zkvm/r1cs/inputs.rs +++ b/jolt-core/src/zkvm/r1cs/inputs.rs @@ -14,10 +14,10 @@ //! (typed evaluators and claim computation). use crate::poly::opening_proof::{OpeningId, PolynomialId, SumcheckId}; -use crate::zkvm::bytecode::BytecodePreprocessing; use crate::zkvm::instruction::{ CircuitFlags, Flags, InstructionFlags, LookupQuery, NUM_CIRCUIT_FLAGS, }; +use crate::zkvm::program::ProgramPreprocessing; use crate::zkvm::witness::VirtualPolynomial; use crate::field::JoltField; @@ -266,11 +266,7 @@ pub struct R1CSCycleInputs { impl R1CSCycleInputs { /// Build directly from the execution trace and preprocessing, /// mirroring the optimized semantics used in `compute_claimed_r1cs_input_evals`. - pub fn from_trace( - bytecode_preprocessing: &BytecodePreprocessing, - trace: &[Cycle], - t: usize, - ) -> Self + pub fn from_trace(program: &ProgramPreprocessing, trace: &[Cycle], t: usize) -> Self where F: JoltField, { @@ -318,9 +314,9 @@ impl R1CSCycleInputs { }; // PCs - let pc = bytecode_preprocessing.get_pc(cycle) as u64; + let pc = program.get_pc(cycle) as u64; let next_pc = if let Some(nc) = next_cycle { - bytecode_preprocessing.get_pc(nc) as u64 + program.get_pc(nc) as u64 } else { 0u64 }; @@ -540,12 +536,12 @@ pub struct ShiftSumcheckCycleState { } impl ShiftSumcheckCycleState { - pub fn new(cycle: &Cycle, bytecode_preprocessing: &BytecodePreprocessing) -> Self { + pub fn new(cycle: &Cycle, program: &ProgramPreprocessing) -> Self { let instruction = cycle.instruction(); let circuit_flags = instruction.circuit_flags(); Self { unexpanded_pc: instruction.normalize().address as u64, - pc: bytecode_preprocessing.get_pc(cycle) as u64, + pc: program.get_pc(cycle) as u64, is_virtual: circuit_flags[CircuitFlags::VirtualInstruction], is_first_in_sequence: circuit_flags[CircuitFlags::IsFirstInSequence], is_noop: instruction.instruction_flags()[InstructionFlags::IsNoop], diff --git a/jolt-core/src/zkvm/ram/mod.rs b/jolt-core/src/zkvm/ram/mod.rs index 84c9c2ce61..c112aa6c38 100644 --- a/jolt-core/src/zkvm/ram/mod.rs +++ b/jolt-core/src/zkvm/ram/mod.rs @@ -60,14 +60,13 @@ use crate::{ utils::{accumulation::Acc6U, math::Math}, zkvm::witness::VirtualPolynomial, }; -use std::vec; - use ark_serialize::{CanonicalDeserialize, CanonicalSerialize}; use common::{ constants::{BYTES_PER_INSTRUCTION, RAM_START_ADDRESS}, jolt_device::MemoryLayout, }; use rayon::prelude::*; +use std::vec; use tracer::emulator::memory::Memory; use tracer::JoltDevice; @@ -79,13 +78,42 @@ pub mod read_write_checking; pub mod val_evaluation; pub mod val_final; +/// RAM preprocessing metadata (shared between prover and verifier). +/// +/// This struct is metadata-only and does NOT contain the full program-image words. +/// The full words are stored in `ProgramImagePreprocessing` (prover-only). #[derive(Debug, Clone, CanonicalSerialize, CanonicalDeserialize)] pub struct RAMPreprocessing { + /// Minimum bytecode address (word-aligned). pub min_bytecode_address: u64, - pub bytecode_words: Vec, + /// Number of program-image words (unpadded). + pub program_image_len_words: usize, } impl RAMPreprocessing { + /// Create metadata from a `ProgramImagePreprocessing`. + pub fn from_program_image(program_image: &ProgramImagePreprocessing) -> Self { + Self { + min_bytecode_address: program_image.min_bytecode_address, + program_image_len_words: program_image.program_image_words.len(), + } + } +} + +/// Full program-image preprocessing (prover-only and full-mode verifier). +/// +/// Contains the actual u64 words that form the initial RAM program image. +/// This is O(program_size) data that the committed-mode verifier does NOT need. +#[derive(Debug, Clone, CanonicalSerialize, CanonicalDeserialize)] +pub struct ProgramImagePreprocessing { + /// Minimum bytecode address (word-aligned). + pub min_bytecode_address: u64, + /// Program-image words (little-endian packed u64 values). + pub program_image_words: Vec, +} + +impl ProgramImagePreprocessing { + /// Preprocess memory_init bytes into packed u64 words. pub fn preprocess(memory_init: Vec<(u64, u8)>) -> Self { let min_bytecode_address = memory_init .iter() @@ -101,8 +129,8 @@ impl RAMPreprocessing { + (BYTES_PER_INSTRUCTION as u64 - 1); let num_words = max_bytecode_address.next_multiple_of(8) / 8 - min_bytecode_address / 8 + 1; - let mut bytecode_words = vec![0u64; num_words as usize]; - // Convert bytes into words and populate `bytecode_words` + let mut program_image_words = vec![0u64; num_words as usize]; + // Convert bytes into words and populate `program_image_words` for chunk in memory_init.chunk_by(|(address_a, _), (address_b, _)| address_a / 8 == address_b / 8) { @@ -112,14 +140,29 @@ impl RAMPreprocessing { } let word = u64::from_le_bytes(word); let remapped_index = (chunk[0].0 / 8 - min_bytecode_address / 8) as usize; - bytecode_words[remapped_index] = word; + program_image_words[remapped_index] = word; } Self { min_bytecode_address, - bytecode_words, + program_image_words, } } + + /// Extract metadata-only `RAMPreprocessing` from this full preprocessing. + pub fn meta(&self) -> RAMPreprocessing { + RAMPreprocessing::from_program_image(self) + } + + /// Unpadded number of words. + pub fn unpadded_len_words(&self) -> usize { + self.program_image_words.len() + } + + /// Power-of-two padded length (minimum 1). + pub fn padded_len_words_pow2(&self) -> usize { + self.program_image_words.len().next_power_of_two().max(1) + } } /// Returns Some(address) if there was read/write @@ -351,6 +394,105 @@ pub fn verifier_accumulate_advice( } } +/// Accumulates staged program-image scalar contribution claims into the prover accumulator. +/// +/// These are scalar inner products: +/// - `C_rw = Σ_j ProgramWord[j] * eq(r_address_rw, start_index + j)` +/// - `C_raf = Σ_j ProgramWord[j] * eq(r_address_raf, start_index + j)` (optional) +/// +/// They are stored as *virtual* openings (not committed openings) because they are not direct +/// openings of the committed program-image polynomial. +pub fn prover_accumulate_program_image( + ram_K: usize, + min_bytecode_address: u64, + program_image_words: &[u64], + program_io: &JoltDevice, + padded_len_words: usize, + opening_accumulator: &mut ProverOpeningAccumulator, + transcript: &mut impl Transcript, + single_opening: bool, +) { + let total_vars = ram_K.log_2(); + let bytecode_start = + remap_address(min_bytecode_address, &program_io.memory_layout).unwrap() as usize; + + // Get r_address_rw from RamVal/RamReadWriteChecking (used by ValEvaluation). + let (r_rw, _) = opening_accumulator.get_virtual_polynomial_opening( + VirtualPolynomial::RamVal, + SumcheckId::RamReadWriteChecking, + ); + let (r_address_rw, _) = r_rw.split_at(total_vars); + + // Compute C_rw using the padded program-image word vector. + let mut words = program_image_words.to_vec(); + words.resize(padded_len_words, 0u64); + let c_rw = sparse_eval_u64_block::(bytecode_start, &words, &r_address_rw.r); + + opening_accumulator.append_virtual( + transcript, + VirtualPolynomial::ProgramImageInitContributionRw, + SumcheckId::RamValEvaluation, + r_address_rw, + c_rw, + ); + + if !single_opening { + let (r_raf, _) = opening_accumulator.get_virtual_polynomial_opening( + VirtualPolynomial::RamValFinal, + SumcheckId::RamOutputCheck, + ); + let (r_address_raf, _) = r_raf.split_at(total_vars); + let c_raf = sparse_eval_u64_block::(bytecode_start, &words, &r_address_raf.r); + opening_accumulator.append_virtual( + transcript, + VirtualPolynomial::ProgramImageInitContributionRaf, + SumcheckId::RamValFinalEvaluation, + r_address_raf, + c_raf, + ); + } +} + +/// Mirrors [`prover_accumulate_program_image`], but only populates opening points and +/// appends the already-present scalar claims to the transcript. +pub fn verifier_accumulate_program_image( + ram_K: usize, + program_io: &JoltDevice, + opening_accumulator: &mut VerifierOpeningAccumulator, + transcript: &mut impl Transcript, + single_opening: bool, +) { + let total_vars = ram_K.log_2(); + // r_address_rw from RamVal/RamReadWriteChecking. + let (r_rw, _) = opening_accumulator.get_virtual_polynomial_opening( + VirtualPolynomial::RamVal, + SumcheckId::RamReadWriteChecking, + ); + let (r_address_rw, _) = r_rw.split_at(total_vars); + opening_accumulator.append_virtual( + transcript, + VirtualPolynomial::ProgramImageInitContributionRw, + SumcheckId::RamValEvaluation, + r_address_rw, + ); + + if !single_opening { + let (r_raf, _) = opening_accumulator.get_virtual_polynomial_opening( + VirtualPolynomial::RamValFinal, + SumcheckId::RamOutputCheck, + ); + let (r_address_raf, _) = r_raf.split_at(total_vars); + opening_accumulator.append_virtual( + transcript, + VirtualPolynomial::ProgramImageInitContributionRaf, + SumcheckId::RamValFinalEvaluation, + r_address_raf, + ); + } + // (program_io is unused for now; retained for symmetry and future checks) + let _ = program_io; +} + /// Calculates how advice inputs contribute to the evaluation of initial_ram_state at a given random point. /// /// ## Example with Two Commitments: @@ -437,6 +579,78 @@ fn calculate_advice_memory_evaluation( } } +/// Evaluate the public portion of the initial RAM state at a random address point `r_address` +/// without materializing the full length-`ram_K` initial memory vector. +/// +/// Public initial memory consists of: +/// - the program image (`program_image_words`) placed at `min_bytecode_address` +/// - public inputs (`program_io.inputs`) placed at `memory_layout.input_start` +/// +/// This function computes: +/// \sum_k Val_init_public[k] * eq(r_address, k) +/// but only over the (contiguous) regions that can be non-zero. +pub fn eval_initial_ram_mle( + min_bytecode_address: u64, + program_image_words: &[u64], + program_io: &JoltDevice, + r_address: &[F::Challenge], +) -> F { + // Bytecode region + let bytecode_start = + remap_address(min_bytecode_address, &program_io.memory_layout).unwrap() as usize; + let mut acc = sparse_eval_u64_block::(bytecode_start, program_image_words, r_address); + + // Inputs region (packed into u64 words in little-endian) + if !program_io.inputs.is_empty() { + let input_start = remap_address( + program_io.memory_layout.input_start, + &program_io.memory_layout, + ) + .unwrap() as usize; + let input_words: Vec = program_io + .inputs + .chunks(8) + .map(|chunk| { + let mut word = [0u8; 8]; + for (i, byte) in chunk.iter().enumerate() { + word[i] = *byte; + } + u64::from_le_bytes(word) + }) + .collect(); + acc += sparse_eval_u64_block::(input_start, &input_words, r_address); + } + + acc +} + +/// Evaluate only `program_io.inputs` as part of the initial RAM state at `r_address`. +/// +/// Excludes program image, outputs, panic, and termination bits. +/// For the full IO region, see [`eval_io_mle`]. +fn eval_inputs_mle(program_io: &JoltDevice, r_address: &[F::Challenge]) -> F { + if program_io.inputs.is_empty() { + return F::zero(); + } + let input_start = remap_address( + program_io.memory_layout.input_start, + &program_io.memory_layout, + ) + .unwrap() as usize; + let input_words: Vec = program_io + .inputs + .chunks(8) + .map(|chunk| { + let mut word = [0u8; 8]; + for (i, byte) in chunk.iter().enumerate() { + word[i] = *byte; + } + u64::from_le_bytes(word) + }) + .collect(); + sparse_eval_u64_block::(input_start, &input_words, r_address) +} + /// Evaluate a shifted slice of `u64` coefficients as a multilinear polynomial at `r`. /// /// Conceptually computes: @@ -479,54 +693,6 @@ fn sparse_eval_u64_block( acc } -/// Evaluate the public portion of the initial RAM state at a random address point `r_address` -/// without materializing the full length-`ram_K` initial memory vector. -/// -/// Public initial memory consists of: -/// - the program image (`ram_preprocessing.bytecode_words`) placed at `min_bytecode_address` -/// - public inputs (`program_io.inputs`) placed at `memory_layout.input_start` -/// -/// This function computes: -/// \sum_k Val_init_public[k] * eq(r_address, k) -/// but only over the (contiguous) regions that can be non-zero. -pub fn eval_initial_ram_mle( - ram_preprocessing: &RAMPreprocessing, - program_io: &JoltDevice, - r_address: &[F::Challenge], -) -> F { - // Bytecode region - let bytecode_start = remap_address( - ram_preprocessing.min_bytecode_address, - &program_io.memory_layout, - ) - .unwrap() as usize; - let mut acc = - sparse_eval_u64_block::(bytecode_start, &ram_preprocessing.bytecode_words, r_address); - - // Inputs region (packed into u64 words in little-endian) - if !program_io.inputs.is_empty() { - let input_start = remap_address( - program_io.memory_layout.input_start, - &program_io.memory_layout, - ) - .unwrap() as usize; - let input_words: Vec = program_io - .inputs - .chunks(8) - .map(|chunk| { - let mut word = [0u8; 8]; - for (i, byte) in chunk.iter().enumerate() { - word[i] = *byte; - } - u64::from_le_bytes(word) - }) - .collect(); - acc += sparse_eval_u64_block::(input_start, &input_words, r_address); - } - - acc -} - /// Evaluate the *public IO* polynomial at a (full-RAM) address point `r_address` without /// materializing a dense IO-region vector. /// @@ -626,7 +792,8 @@ pub fn eval_io_mle(program_io: &JoltDevice, r_address: &[F::Challe /// Returns `(initial_memory_state, final_memory_state)` pub fn gen_ram_memory_states( ram_K: usize, - ram_preprocessing: &RAMPreprocessing, + min_bytecode_address: u64, + program_image_words: &[u64], program_io: &JoltDevice, final_memory: &Memory, ) -> (Vec, Vec) { @@ -634,12 +801,9 @@ pub fn gen_ram_memory_states( let mut initial_memory_state: Vec = vec![0; K]; // Copy bytecode - let mut index = remap_address( - ram_preprocessing.min_bytecode_address, - &program_io.memory_layout, - ) - .unwrap() as usize; - for word in &ram_preprocessing.bytecode_words { + let mut index = + remap_address(min_bytecode_address, &program_io.memory_layout).unwrap() as usize; + for word in program_image_words { initial_memory_state[index] = *word; index += 1; } @@ -730,17 +894,15 @@ pub fn gen_ram_memory_states( pub fn gen_ram_initial_memory_state( ram_K: usize, - ram_preprocessing: &RAMPreprocessing, + min_bytecode_address: u64, + program_image_words: &[u64], program_io: &JoltDevice, ) -> Vec { let mut initial_memory_state = vec![0; ram_K]; // Copy bytecode - let mut index = remap_address( - ram_preprocessing.min_bytecode_address, - &program_io.memory_layout, - ) - .unwrap() as usize; - for word in &ram_preprocessing.bytecode_words { + let mut index = + remap_address(min_bytecode_address, &program_io.memory_layout).unwrap() as usize; + for word in program_image_words { initial_memory_state[index] = *word; index += 1; } @@ -798,23 +960,28 @@ mod tests { let b = (rng.next_u64() & 0xff) as u8; memory_init.push((RAM_START_ADDRESS + i, b)); } - let ram_pp = RAMPreprocessing::preprocess(memory_init); + let prog_pp = ProgramImagePreprocessing::preprocess(memory_init); // Choose ram_K large enough to cover both bytecode and inputs placements. - let bytecode_start = - remap_address(ram_pp.min_bytecode_address, &program_io.memory_layout).unwrap() as usize; + let bytecode_start = remap_address(prog_pp.min_bytecode_address, &program_io.memory_layout) + .unwrap() as usize; let input_start = remap_address( program_io.memory_layout.input_start, &program_io.memory_layout, ) .unwrap() as usize; let input_words_len = program_io.inputs.len().div_ceil(8); - let needed = (bytecode_start + ram_pp.bytecode_words.len()) + let needed = (bytecode_start + prog_pp.program_image_words.len()) .max(input_start + input_words_len) .max(1); let ram_K = needed.next_power_of_two(); - let dense = gen_ram_initial_memory_state::(ram_K, &ram_pp, &program_io); + let dense = gen_ram_initial_memory_state::( + ram_K, + prog_pp.min_bytecode_address, + &prog_pp.program_image_words, + &program_io, + ); // Random evaluation point over address vars (big-endian convention). let n_vars = ram_K.log_2(); @@ -823,7 +990,12 @@ mod tests { .collect(); let dense_eval = MultilinearPolynomial::::from(dense).evaluate(&r); - let fast_eval = eval_initial_ram_mle::(&ram_pp, &program_io, &r); + let fast_eval = eval_initial_ram_mle::( + prog_pp.min_bytecode_address, + &prog_pp.program_image_words, + &program_io, + &r, + ); assert_eq!(dense_eval, fast_eval); } diff --git a/jolt-core/src/zkvm/ram/read_write_checking.rs b/jolt-core/src/zkvm/ram/read_write_checking.rs index 82c86daa04..5d9375a4ff 100644 --- a/jolt-core/src/zkvm/ram/read_write_checking.rs +++ b/jolt-core/src/zkvm/ram/read_write_checking.rs @@ -18,8 +18,8 @@ use crate::subprotocols::sumcheck_claim::{ }; use crate::subprotocols::sumcheck_prover::SumcheckInstanceProver; use crate::subprotocols::sumcheck_verifier::{SumcheckInstanceParams, SumcheckInstanceVerifier}; -use crate::zkvm::bytecode::BytecodePreprocessing; use crate::zkvm::config::{OneHotParams, ReadWriteConfig}; +use crate::zkvm::program::ProgramPreprocessing; use crate::{ field::JoltField, poly::{ @@ -170,7 +170,7 @@ impl RamReadWriteCheckingProver { pub fn initialize( params: RamReadWriteCheckingParams, trace: &[Cycle], - bytecode_preprocessing: &BytecodePreprocessing, + program: &ProgramPreprocessing, memory_layout: &MemoryLayout, initial_ram_state: &[u64], ) -> Self { @@ -189,12 +189,7 @@ impl RamReadWriteCheckingProver { Some(MultilinearPolynomial::from(EqPolynomial::evals(&r_prime.r))), ) }; - let inc = CommittedPolynomial::RamInc.generate_witness( - bytecode_preprocessing, - memory_layout, - trace, - None, - ); + let inc = CommittedPolynomial::RamInc.generate_witness(program, memory_layout, trace, None); let val_init: Vec<_> = initial_ram_state .par_iter() .map(|x| F::from_u64(*x)) diff --git a/jolt-core/src/zkvm/ram/val_evaluation.rs b/jolt-core/src/zkvm/ram/val_evaluation.rs index 25a58ff0ad..4ec1230d9f 100644 --- a/jolt-core/src/zkvm/ram/val_evaluation.rs +++ b/jolt-core/src/zkvm/ram/val_evaluation.rs @@ -25,9 +25,9 @@ use crate::{ transcripts::Transcript, utils::math::Math, zkvm::{ - bytecode::BytecodePreprocessing, claim_reductions::AdviceKind, - config::OneHotParams, + config::{OneHotParams, ProgramMode}, + program::{ProgramMetadata, ProgramPreprocessing}, ram::remap_address, witness::{CommittedPolynomial, VirtualPolynomial}, }, @@ -93,11 +93,23 @@ impl ValEvaluationSumcheckParams { } } + /// Create params for verifier. + /// + /// # Arguments + /// - `program_meta`: RAM preprocessing metadata + /// - `program_image_words`: Program image words (only needed in Full mode, None for Committed mode) + /// - `program_io`: Program I/O device + /// - `trace_len`: Trace length + /// - `ram_K`: RAM K parameter + /// - `program_mode`: Bytecode mode (Full or Committed) + /// - `opening_accumulator`: Verifier opening accumulator pub fn new_from_verifier( - ram_preprocessing: &super::RAMPreprocessing, + program_meta: &ProgramMetadata, + program_image_words: Option<&[u64]>, program_io: &JoltDevice, trace_len: usize, ram_K: usize, + program_mode: ProgramMode, opening_accumulator: &VerifierOpeningAccumulator, ) -> Self { let (r, _) = opening_accumulator.get_virtual_polynomial_opening( @@ -134,10 +146,28 @@ impl ValEvaluationSumcheckParams { n_memory_vars, ); - // Compute the public part of val_init evaluation (bytecode + inputs) without - // materializing the full length-K initial RAM state. - let val_init_public_eval = - super::eval_initial_ram_mle::(ram_preprocessing, program_io, &r_address.r); + // Public part of val_init: + // - Full mode: compute program-image+inputs directly using provided words. + // - Committed mode: use staged scalar program-image claim + locally computed input contribution. + let val_init_public_eval = match program_mode { + ProgramMode::Full => { + let words = program_image_words.expect("Full mode requires program_image_words"); + super::eval_initial_ram_mle::( + program_meta.min_bytecode_address, + words, + program_io, + &r_address.r, + ) + } + ProgramMode::Committed => { + let (_, prog_img_claim) = opening_accumulator.get_virtual_polynomial_opening( + VirtualPolynomial::ProgramImageInitContributionRw, + SumcheckId::RamValEvaluation, + ); + let input_eval = super::eval_inputs_mle::(program_io, &r_address.r); + prog_img_claim + input_eval + } + }; // Combine all contributions: untrusted + trusted + public let init_eval = untrusted_contribution + trusted_contribution + val_init_public_eval; @@ -190,7 +220,7 @@ impl ValEvaluationSumcheckProver { pub fn initialize( params: ValEvaluationSumcheckParams, trace: &[Cycle], - bytecode_preprocessing: &BytecodePreprocessing, + program: &ProgramPreprocessing, memory_layout: &MemoryLayout, ) -> Self { // Compute the size-K table storing all eq(r_address, k) evaluations for @@ -213,12 +243,7 @@ impl ValEvaluationSumcheckProver { drop(_guard); drop(span); - let inc = CommittedPolynomial::RamInc.generate_witness( - bytecode_preprocessing, - memory_layout, - trace, - None, - ); + let inc = CommittedPolynomial::RamInc.generate_witness(program, memory_layout, trace, None); let lt = LtPolynomial::new(¶ms.r_cycle); Self { @@ -323,17 +348,21 @@ pub struct ValEvaluationSumcheckVerifier { impl ValEvaluationSumcheckVerifier { pub fn new( - ram_preprocessing: &super::RAMPreprocessing, + program_meta: &ProgramMetadata, + program_image_words: Option<&[u64]>, program_io: &JoltDevice, trace_len: usize, ram_K: usize, + program_mode: ProgramMode, opening_accumulator: &VerifierOpeningAccumulator, ) -> Self { let params = ValEvaluationSumcheckParams::new_from_verifier( - ram_preprocessing, + program_meta, + program_image_words, program_io, trace_len, ram_K, + program_mode, opening_accumulator, ); Self { params } diff --git a/jolt-core/src/zkvm/ram/val_final.rs b/jolt-core/src/zkvm/ram/val_final.rs index 31f7d20fdb..c6c43c00ec 100644 --- a/jolt-core/src/zkvm/ram/val_final.rs +++ b/jolt-core/src/zkvm/ram/val_final.rs @@ -18,9 +18,9 @@ use crate::{ transcripts::Transcript, utils::math::Math, zkvm::{ - bytecode::BytecodePreprocessing, claim_reductions::AdviceKind, - config::ReadWriteConfig, + config::{ProgramMode, ReadWriteConfig}, + program::{ProgramMetadata, ProgramPreprocessing}, ram::remap_address, witness::{CommittedPolynomial, VirtualPolynomial}, }, @@ -59,11 +59,24 @@ impl ValFinalSumcheckParams { } } + /// Create params for verifier. + /// + /// # Arguments + /// - `program_meta`: RAM preprocessing metadata + /// - `program_image_words`: Program image words (only needed in Full mode, None for Committed mode) + /// - `program_io`: Program I/O device + /// - `trace_len`: Trace length + /// - `ram_K`: RAM K parameter + /// - `program_mode`: Bytecode mode (Full or Committed) + /// - `opening_accumulator`: Verifier opening accumulator + /// - `rw_config`: Read/write configuration pub fn new_from_verifier( - ram_preprocessing: &super::RAMPreprocessing, + program_meta: &ProgramMetadata, + program_image_words: Option<&[u64]>, program_io: &JoltDevice, trace_len: usize, ram_K: usize, + program_mode: ProgramMode, opening_accumulator: &VerifierOpeningAccumulator, rw_config: &ReadWriteConfig, ) -> Self { @@ -108,10 +121,37 @@ impl ValFinalSumcheckParams { n_memory_vars, ); - // Compute the public part of val_init evaluation (bytecode + inputs) without - // materializing the full length-K initial RAM state. - let val_init_public_eval = - super::eval_initial_ram_mle::(ram_preprocessing, program_io, &r_address); + // Public part of val_init: + // - Full mode: compute program-image+inputs directly using provided words. + // - Committed mode: use staged scalar program-image claim + locally computed input contribution. + let val_init_public_eval = match program_mode { + ProgramMode::Full => { + let words = program_image_words.expect("Full mode requires program_image_words"); + super::eval_initial_ram_mle::( + program_meta.min_bytecode_address, + words, + program_io, + &r_address, + ) + } + ProgramMode::Committed => { + let (prog_poly, prog_sumcheck) = if rw_config.needs_single_advice_opening(log_T) { + ( + VirtualPolynomial::ProgramImageInitContributionRw, + SumcheckId::RamValEvaluation, + ) + } else { + ( + VirtualPolynomial::ProgramImageInitContributionRaf, + SumcheckId::RamValFinalEvaluation, + ) + }; + let (_, prog_img_claim) = + opening_accumulator.get_virtual_polynomial_opening(prog_poly, prog_sumcheck); + let input_eval = super::eval_inputs_mle::(program_io, &r_address); + prog_img_claim + input_eval + } + }; // Combine all contributions: untrusted + trusted + public let val_init_eval = @@ -162,7 +202,7 @@ impl ValFinalSumcheckProver { pub fn initialize( params: ValFinalSumcheckParams, trace: &[Cycle], - bytecode_preprocessing: &BytecodePreprocessing, + program: &ProgramPreprocessing, memory_layout: &MemoryLayout, ) -> Self { // Compute the size-K table storing all eq(r_address, k) evaluations for @@ -186,12 +226,7 @@ impl ValFinalSumcheckProver { drop(_guard); drop(span); - let inc = CommittedPolynomial::RamInc.generate_witness( - bytecode_preprocessing, - memory_layout, - trace, - None, - ); + let inc = CommittedPolynomial::RamInc.generate_witness(program, memory_layout, trace, None); // #[cfg(test)] // { @@ -304,18 +339,22 @@ pub struct ValFinalSumcheckVerifier { impl ValFinalSumcheckVerifier { pub fn new( - ram_preprocessing: &super::RAMPreprocessing, + program_meta: &ProgramMetadata, + program_image_words: Option<&[u64]>, program_io: &JoltDevice, trace_len: usize, ram_K: usize, + program_mode: ProgramMode, opening_accumulator: &VerifierOpeningAccumulator, rw_config: &ReadWriteConfig, ) -> Self { let params = ValFinalSumcheckParams::new_from_verifier( - ram_preprocessing, + program_meta, + program_image_words, program_io, trace_len, ram_K, + program_mode, opening_accumulator, rw_config, ); diff --git a/jolt-core/src/zkvm/registers/read_write_checking.rs b/jolt-core/src/zkvm/registers/read_write_checking.rs index c442274afc..764cb71ef7 100644 --- a/jolt-core/src/zkvm/registers/read_write_checking.rs +++ b/jolt-core/src/zkvm/registers/read_write_checking.rs @@ -1,17 +1,12 @@ use std::sync::Arc; use crate::poly::multilinear_polynomial::PolynomialEvaluation; -use crate::poly::opening_proof::PolynomialId; use crate::subprotocols::read_write_matrix::{ AddressMajorMatrixEntry, ReadWriteMatrixAddressMajor, ReadWriteMatrixCycleMajor, RegistersAddressMajorEntry, RegistersCycleMajorEntry, }; -use crate::subprotocols::sumcheck_claim::{ - CachedPointRef, ChallengePart, Claim, ClaimExpr, InputOutputClaims, SumcheckFrontend, - VerifierEvaluablePolynomial, -}; -use crate::zkvm::bytecode::BytecodePreprocessing; use crate::zkvm::config::ReadWriteConfig; +use crate::zkvm::program::ProgramPreprocessing; use crate::zkvm::witness::VirtualPolynomial; use crate::{ field::JoltField, @@ -19,13 +14,17 @@ use crate::{ eq_poly::EqPolynomial, multilinear_polynomial::{BindingOrder, MultilinearPolynomial, PolynomialBinding}, opening_proof::{ - OpeningAccumulator, OpeningPoint, ProverOpeningAccumulator, SumcheckId, + OpeningAccumulator, OpeningPoint, PolynomialId, ProverOpeningAccumulator, SumcheckId, VerifierOpeningAccumulator, BIG_ENDIAN, }, split_eq_poly::GruenSplitEqPolynomial, unipoly::UniPoly, }, subprotocols::{ + sumcheck_claim::{ + CachedPointRef, ChallengePart, Claim, ClaimExpr, InputOutputClaims, SumcheckFrontend, + VerifierEvaluablePolynomial, + }, sumcheck_prover::SumcheckInstanceProver, sumcheck_verifier::{SumcheckInstanceParams, SumcheckInstanceVerifier}, }, @@ -196,7 +195,7 @@ impl RegistersReadWriteCheckingProver { pub fn initialize( params: RegistersReadWriteCheckingParams, trace: Arc>, - bytecode_preprocessing: &BytecodePreprocessing, + program: &ProgramPreprocessing, memory_layout: &MemoryLayout, ) -> Self { let r_prime = ¶ms.r_cycle; @@ -214,12 +213,7 @@ impl RegistersReadWriteCheckingProver { Some(MultilinearPolynomial::from(EqPolynomial::evals(&r_prime.r))), ) }; - let inc = CommittedPolynomial::RdInc.generate_witness( - bytecode_preprocessing, - memory_layout, - &trace, - None, - ); + let inc = CommittedPolynomial::RdInc.generate_witness(program, memory_layout, &trace, None); let sparse_matrix = ReadWriteMatrixCycleMajor::<_, RegistersCycleMajorEntry>::new(&trace, params.gamma); let phase1_rounds = params.phase1_num_rounds; @@ -805,23 +799,6 @@ impl RegistersReadWriteCheckingVerifier { impl SumcheckInstanceVerifier for RegistersReadWriteCheckingVerifier { - fn input_claim(&self, accumulator: &VerifierOpeningAccumulator) -> F { - let result = self.params.input_claim(accumulator); - - #[cfg(test)] - { - let claims = Self::input_output_claims(); - let gamma_pows: Vec = - std::iter::successors(Some(F::one()), |prev| Some(*prev * self.params.gamma)) - .take(claims.claims.len()) - .collect(); - let reference_result = claims.input_claim(&gamma_pows, accumulator); - assert_eq!(result, reference_result); - } - - result - } - fn get_params(&self) -> &dyn SumcheckInstanceParams { &self.params } @@ -859,23 +836,9 @@ impl SumcheckInstanceVerifier let rs1_value_claim = rs1_ra_claim * val_claim; let rs2_value_claim = rs2_ra_claim * val_claim; - let result = EqPolynomial::mle_endian(&r_cycle, &self.params.r_cycle) + EqPolynomial::mle_endian(&r_cycle, &self.params.r_cycle) * (rd_write_value_claim - + self.params.gamma * (rs1_value_claim + self.params.gamma * rs2_value_claim)); - - #[cfg(test)] - { - let claims = Self::input_output_claims(); - let gamma_pows: Vec = - std::iter::successors(Some(F::one()), |prev| Some(*prev * self.params.gamma)) - .take(claims.claims.len()) - .collect(); - let reference_result = claims.expected_output_claim(&r_cycle, &gamma_pows, accumulator); - - assert_eq!(result, reference_result); - } - - result + + self.params.gamma * (rs1_value_claim + self.params.gamma * rs2_value_claim)) } fn cache_openings( diff --git a/jolt-core/src/zkvm/registers/val_evaluation.rs b/jolt-core/src/zkvm/registers/val_evaluation.rs index 002104552e..4fba3a4d0f 100644 --- a/jolt-core/src/zkvm/registers/val_evaluation.rs +++ b/jolt-core/src/zkvm/registers/val_evaluation.rs @@ -20,10 +20,8 @@ use crate::{ sumcheck_verifier::{SumcheckInstanceParams, SumcheckInstanceVerifier}, }, transcripts::Transcript, - zkvm::{ - bytecode::BytecodePreprocessing, - witness::{CommittedPolynomial, VirtualPolynomial}, - }, + zkvm::program::ProgramPreprocessing, + zkvm::witness::{CommittedPolynomial, VirtualPolynomial}, }; use allocative::Allocative; #[cfg(feature = "allocative")] @@ -106,15 +104,10 @@ impl ValEvaluationSumcheckProver { pub fn initialize( params: RegistersValEvaluationSumcheckParams, trace: &[Cycle], - bytecode_preprocessing: &BytecodePreprocessing, + program: &ProgramPreprocessing, memory_layout: &MemoryLayout, ) -> Self { - let inc = CommittedPolynomial::RdInc.generate_witness( - bytecode_preprocessing, - memory_layout, - trace, - None, - ); + let inc = CommittedPolynomial::RdInc.generate_witness(program, memory_layout, trace, None); let eq_r_address = EqPolynomial::evals(¶ms.r_address.r); let wa: Vec> = trace diff --git a/jolt-core/src/zkvm/spartan/outer.rs b/jolt-core/src/zkvm/spartan/outer.rs index dda52684ff..224072d933 100644 --- a/jolt-core/src/zkvm/spartan/outer.rs +++ b/jolt-core/src/zkvm/spartan/outer.rs @@ -32,7 +32,7 @@ use crate::utils::math::Math; #[cfg(feature = "allocative")] use crate::utils::profiling::print_data_structure_heap_usage; use crate::utils::thread::unsafe_allocate_zero_vec; -use crate::zkvm::bytecode::BytecodePreprocessing; +use crate::zkvm::program::ProgramPreprocessing; use crate::zkvm::r1cs::constraints::OUTER_FIRST_ROUND_POLY_DEGREE_BOUND; use crate::zkvm::r1cs::key::UniformSpartanKey; use crate::zkvm::r1cs::{ @@ -131,13 +131,9 @@ impl OuterUniSkipProver { pub fn initialize( params: OuterUniSkipParams, trace: &[Cycle], - bytecode_preprocessing: &BytecodePreprocessing, + program: &ProgramPreprocessing, ) -> Self { - let extended = Self::compute_univariate_skip_extended_evals( - bytecode_preprocessing, - trace, - ¶ms.tau, - ); + let extended = Self::compute_univariate_skip_extended_evals(program, trace, ¶ms.tau); let instance = Self { params, @@ -166,7 +162,7 @@ impl OuterUniSkipProver { /// \sum_{x_in'} eq(tau_in, (x_in', 0)) * Az(x_out, x_in', 0, y) * Bz(x_out, x_in', 0, y) /// + eq(tau_in, (x_in', 1)) * Az(x_out, x_in', 1, y) * Bz(x_out, x_in', 1, y) fn compute_univariate_skip_extended_evals( - bytecode_preprocessing: &BytecodePreprocessing, + program: &ProgramPreprocessing, trace: &[Cycle], tau: &[F::Challenge], ) -> [F; OUTER_UNIVARIATE_SKIP_DEGREE] { @@ -191,11 +187,8 @@ impl OuterUniSkipProver { let x_in_prime = x_in >> 1; let base_step_idx = (x_out << num_x_in_prime_bits) | x_in_prime; - let row_inputs = R1CSCycleInputs::from_trace::( - bytecode_preprocessing, - trace, - base_step_idx, - ); + let row_inputs = + R1CSCycleInputs::from_trace::(program, trace, base_step_idx); let eval = R1CSEval::::from_cycle_inputs(&row_inputs); let is_group1 = (x_in & 1) == 1; @@ -499,7 +492,7 @@ pub type OuterRemainingStreamingSumcheck = #[derive(Allocative)] pub struct OuterSharedState { #[allocative(skip)] - bytecode_preprocessing: BytecodePreprocessing, + program: ProgramPreprocessing, #[allocative(skip)] trace: Arc>, split_eq_poly: GruenSplitEqPolynomial, @@ -514,11 +507,11 @@ impl OuterSharedState { #[tracing::instrument(skip_all, name = "OuterSharedState::new")] pub fn new( trace: Arc>, - bytecode_preprocessing: &BytecodePreprocessing, + program: &ProgramPreprocessing, uni_skip_params: &OuterUniSkipParams, opening_accumulator: &ProverOpeningAccumulator, ) -> Self { - let bytecode_preprocessing = bytecode_preprocessing.clone(); + let program = program.clone(); let outer_params = OuterStreamingProverParams::new(uni_skip_params, opening_accumulator); let r0 = outer_params.r0_uniskip; @@ -546,7 +539,7 @@ impl OuterSharedState { Self { split_eq_poly, - bytecode_preprocessing, + program, trace, t_prime_poly: None, r_grid, @@ -572,7 +565,7 @@ impl OuterSharedState { offset: usize, scaled_w: &[[F; OUTER_UNIVARIATE_SKIP_DOMAIN_SIZE]], ) { - let preprocess = &self.bytecode_preprocessing; + let preprocess = &self.program; let trace = &self.trace; debug_assert_eq!(scaled_w.len(), klen); debug_assert_eq!(grid_az.len(), jlen); @@ -933,7 +926,7 @@ impl OuterLinearStage { let selector = (full_idx & 1) == 1; let row_inputs = R1CSCycleInputs::from_trace::( - &shared.bytecode_preprocessing, + &shared.program, &shared.trace, step_idx, ); @@ -1056,7 +1049,7 @@ impl OuterLinearStage { let time_step_idx = full_idx >> 1; let row_inputs = R1CSCycleInputs::from_trace::( - &shared.bytecode_preprocessing, + &shared.program, &shared.trace, time_step_idx, ); @@ -1087,7 +1080,7 @@ impl OuterLinearStage { let selector = (full_idx & 1) == 1; let row_inputs = R1CSCycleInputs::from_trace::( - &shared.bytecode_preprocessing, + &shared.program, &shared.trace, time_step_idx, ); @@ -1168,7 +1161,7 @@ impl OuterLinearStage { let time_step_idx = full_idx >> 1; let row_inputs = R1CSCycleInputs::from_trace::( - &shared.bytecode_preprocessing, + &shared.program, &shared.trace, time_step_idx, ); @@ -1200,7 +1193,7 @@ impl OuterLinearStage { let selector = (full_idx & 1) == 1; let row_inputs = R1CSCycleInputs::from_trace::( - &shared.bytecode_preprocessing, + &shared.program, &shared.trace, time_step_idx, ); @@ -1445,11 +1438,8 @@ impl LinearSumcheckStage for OuterLinearStage { ) { let r_cycle = OuterStreamingProverParams::get_inputs_opening_point(sumcheck_challenges); - let claimed_witness_evals = R1CSEval::compute_claimed_inputs( - &shared.bytecode_preprocessing, - &shared.trace, - &r_cycle, - ); + let claimed_witness_evals = + R1CSEval::compute_claimed_inputs(&shared.program, &shared.trace, &r_cycle); for (i, input) in ALL_R1CS_INPUTS.iter().enumerate() { accumulator.append_virtual( diff --git a/jolt-core/src/zkvm/spartan/shift.rs b/jolt-core/src/zkvm/spartan/shift.rs index fda554039b..77ac0a6f63 100644 --- a/jolt-core/src/zkvm/spartan/shift.rs +++ b/jolt-core/src/zkvm/spartan/shift.rs @@ -24,8 +24,8 @@ use crate::subprotocols::sumcheck_claim::{ use crate::subprotocols::sumcheck_prover::SumcheckInstanceProver; use crate::subprotocols::sumcheck_verifier::{SumcheckInstanceParams, SumcheckInstanceVerifier}; use crate::transcripts::Transcript; -use crate::zkvm::bytecode::BytecodePreprocessing; use crate::zkvm::instruction::{CircuitFlags, InstructionFlags}; +use crate::zkvm::program::ProgramPreprocessing; use crate::zkvm::r1cs::inputs::ShiftSumcheckCycleState; use crate::zkvm::witness::VirtualPolynomial; use rayon::prelude::*; @@ -150,10 +150,9 @@ impl ShiftSumcheckProver { pub fn initialize( params: ShiftSumcheckParams, trace: Arc>, - bytecode_preprocessing: &BytecodePreprocessing, + program: &ProgramPreprocessing, ) -> Self { - let phase = - ShiftSumcheckPhase::Phase1(Phase1State::gen(trace, bytecode_preprocessing, ¶ms)); + let phase = ShiftSumcheckPhase::Phase1(Phase1State::gen(trace, program, ¶ms)); Self { phase, params } } } @@ -184,7 +183,7 @@ impl SumcheckInstanceProver for ShiftSumcheck sumcheck_challenges.push(r_j); self.phase = ShiftSumcheckPhase::Phase2(Phase2State::gen( &state.trace, - &state.bytecode_preprocessing, + &state.program, &sumcheck_challenges, &self.params, )); @@ -469,14 +468,14 @@ struct Phase1State { #[allocative(skip)] trace: Arc>, #[allocative(skip)] - bytecode_preprocessing: BytecodePreprocessing, + program: ProgramPreprocessing, sumcheck_challenges: Vec, } impl Phase1State { fn gen( trace: Arc>, - bytecode_preprocessing: &BytecodePreprocessing, + program: &ProgramPreprocessing, params: &ShiftSumcheckParams, ) -> Self { let EqPlusOnePrefixSuffixPoly { @@ -541,7 +540,7 @@ impl Phase1State { is_virtual, is_first_in_sequence, is_noop, - } = ShiftSumcheckCycleState::new(&trace[x], bytecode_preprocessing); + } = ShiftSumcheckCycleState::new(&trace[x], program); let mut v = F::from_u64(unexpanded_pc) + params.gamma_powers[1].mul_u64(pc); @@ -591,7 +590,7 @@ impl Phase1State { Self { prefix_suffix_pairs, trace, - bytecode_preprocessing: bytecode_preprocessing.clone(), + program: program.clone(), sumcheck_challenges: Vec::new(), } } @@ -648,7 +647,7 @@ struct Phase2State { impl Phase2State { fn gen( trace: &[Cycle], - bytecode_preprocessing: &BytecodePreprocessing, + program: &ProgramPreprocessing, sumcheck_challenges: &[F::Challenge], params: &ShiftSumcheckParams, ) -> Self { @@ -722,7 +721,7 @@ impl Phase2State { is_virtual, is_first_in_sequence, is_noop, - } = ShiftSumcheckCycleState::new(cycle, bytecode_preprocessing); + } = ShiftSumcheckCycleState::new(cycle, program); let eq_eval = eq_evals[i]; unexpanded_pc_eval_unreduced += eq_eval.mul_u64_unreduced(unexpanded_pc); pc_eval_unreduced += eq_eval.mul_u64_unreduced(pc); diff --git a/jolt-core/src/zkvm/tests.rs b/jolt-core/src/zkvm/tests.rs new file mode 100644 index 0000000000..1242c4eb88 --- /dev/null +++ b/jolt-core/src/zkvm/tests.rs @@ -0,0 +1,950 @@ +//! End-to-end test infrastructure for Jolt ZKVM. +//! +//! This module provides a unified test runner that reduces boilerplate across e2e tests. +//! Tests can be configured via `E2ETestConfig` to vary: +//! - Program (fibonacci, sha2, etc.) +//! - ProgramMode (Full vs Committed) +//! - DoryLayout (CycleMajor vs AddressMajor) +//! - Trace size +//! - Advice (trusted/untrusted) + +use std::sync::Arc; + +use ark_bn254::Fr; +use serial_test::serial; + +use crate::host; +use crate::poly::commitment::commitment_scheme::CommitmentScheme; +use crate::poly::commitment::dory::{DoryCommitmentScheme, DoryContext, DoryGlobals, DoryLayout}; +use crate::poly::multilinear_polynomial::MultilinearPolynomial; +use crate::poly::opening_proof::{OpeningAccumulator, SumcheckId}; +use crate::zkvm::bytecode::chunks::total_lanes; +use crate::zkvm::claim_reductions::AdviceKind; +use crate::zkvm::config::ProgramMode; +use crate::zkvm::program::ProgramPreprocessing; +use crate::zkvm::prover::JoltProverPreprocessing; +use crate::zkvm::ram::populate_memory_states; +use crate::zkvm::verifier::{JoltSharedPreprocessing, JoltVerifier, JoltVerifierPreprocessing}; +use crate::zkvm::witness::CommittedPolynomial; +use crate::zkvm::{RV64IMACProver, RV64IMACVerifier}; + +/// Configuration for an end-to-end test. +#[derive(Clone)] +pub struct E2ETestConfig { + /// Guest program name (e.g., "fibonacci-guest", "sha2-guest") + pub program_name: &'static str, + /// Serialized inputs to pass to the guest + pub inputs: Vec, + /// Maximum padded trace length (must be power of 2) + pub max_trace_length: usize, + /// Whether to use Committed program mode (vs Full) + pub committed_program: bool, + /// Dory layout override (None = use default CycleMajor) + pub dory_layout: Option, + /// Trusted advice bytes + pub trusted_advice: Vec, + /// Untrusted advice bytes + pub untrusted_advice: Vec, + /// Expected output bytes (None = don't verify output) + pub expected_output: Option>, +} + +impl Default for E2ETestConfig { + fn default() -> Self { + Self { + program_name: "fibonacci-guest", + inputs: postcard::to_stdvec(&100u32).unwrap(), + max_trace_length: 1 << 16, + committed_program: false, + dory_layout: None, + trusted_advice: vec![], + untrusted_advice: vec![], + expected_output: None, + } + } +} + +impl E2ETestConfig { + // ======================================================================== + // Program Constructors + // ======================================================================== + + /// Create config for fibonacci with custom input. + pub fn fibonacci(n: u32) -> Self { + Self { + inputs: postcard::to_stdvec(&n).unwrap(), + ..Default::default() + } + } + + /// Create config for sha2 (with default 32-byte input). + pub fn sha2() -> Self { + Self { + program_name: "sha2-guest", + inputs: postcard::to_stdvec(&[5u8; 32]).unwrap(), + expected_output: Some(vec![ + 0x28, 0x9b, 0xdf, 0x82, 0x9b, 0x4a, 0x30, 0x26, 0x7, 0x9a, 0x3e, 0xa0, 0x89, 0x73, + 0xb1, 0x97, 0x2d, 0x12, 0x4e, 0x7e, 0xaf, 0x22, 0x33, 0xc6, 0x3, 0x14, 0x3d, 0xc6, + 0x3b, 0x50, 0xd2, 0x57, + ]), + ..Default::default() + } + } + + /// Create config for sha3 (with default 32-byte input). + pub fn sha3() -> Self { + Self { + program_name: "sha3-guest", + inputs: postcard::to_stdvec(&[5u8; 32]).unwrap(), + expected_output: Some(vec![ + 0xd0, 0x3, 0x5c, 0x96, 0x86, 0x6e, 0xe2, 0x2e, 0x81, 0xf5, 0xc4, 0xef, 0xbd, 0x88, + 0x33, 0xc1, 0x7e, 0xa1, 0x61, 0x10, 0x81, 0xfc, 0xd7, 0xa3, 0xdd, 0xce, 0xce, 0x7f, + 0x44, 0x72, 0x4, 0x66, + ]), + ..Default::default() + } + } + + /// Create config for merkle-tree guest. + /// Default: 4 leaves with input=[5;32], trusted=[6;32,7;32], untrusted=[8;32] + pub fn merkle_tree() -> Self { + let inputs = postcard::to_stdvec(&[5u8; 32].as_slice()).unwrap(); + let untrusted_advice = postcard::to_stdvec(&[8u8; 32]).unwrap(); + let mut trusted_advice = postcard::to_stdvec(&[6u8; 32]).unwrap(); + trusted_advice.extend(postcard::to_stdvec(&[7u8; 32]).unwrap()); + + Self { + program_name: "merkle-tree-guest", + inputs, + trusted_advice, + untrusted_advice, + expected_output: Some(vec![ + 0xb4, 0x37, 0x0f, 0x3a, 0xb, 0x3d, 0x38, 0xa8, 0x7a, 0x6c, 0x4c, 0x46, 0x9, 0xe7, + 0x83, 0xb3, 0xcc, 0xb7, 0x1c, 0x30, 0x1f, 0xf8, 0x54, 0xd, 0xf7, 0xdd, 0xc8, 0x42, + 0x32, 0xbb, 0x16, 0xd7, + ]), + ..Default::default() + } + } + + /// Create config for memory-ops guest (no inputs). + pub fn memory_ops() -> Self { + Self { + program_name: "memory-ops-guest", + inputs: vec![], + ..Default::default() + } + } + + /// Create config for btreemap guest. + pub fn btreemap(n: u32) -> Self { + Self { + program_name: "btreemap-guest", + inputs: postcard::to_stdvec(&n).unwrap(), + ..Default::default() + } + } + + /// Create config for muldiv guest. + pub fn muldiv(a: u32, b: u32, c: u32) -> Self { + Self { + program_name: "muldiv-guest", + inputs: postcard::to_stdvec(&[a, b, c]).unwrap(), + ..Default::default() + } + } + + // ======================================================================== + // Builder Methods + // ======================================================================== + + /// Set committed program mode. + pub fn with_committed_program(mut self) -> Self { + self.committed_program = true; + self + } + + /// Set Dory layout. + pub fn with_dory_layout(mut self, layout: DoryLayout) -> Self { + self.dory_layout = Some(layout); + self + } + + /// Set small trace (256 cycles). + pub fn with_small_trace(mut self) -> Self { + self.max_trace_length = 256; + self + } + + /// Set custom max trace length. + #[allow(dead_code)] // API for future tests + pub fn with_max_trace_length(mut self, len: usize) -> Self { + self.max_trace_length = len; + self + } + + /// Set trusted advice bytes. + pub fn with_trusted_advice(mut self, advice: Vec) -> Self { + self.trusted_advice = advice; + self + } + + /// Set untrusted advice bytes. + pub fn with_untrusted_advice(mut self, advice: Vec) -> Self { + self.untrusted_advice = advice; + self + } + + /// Set expected output for verification. + #[allow(dead_code)] // API for future tests + pub fn expecting_output(mut self, output: Vec) -> Self { + self.expected_output = Some(output); + self + } + + /// Clear expected output (don't verify). + #[allow(dead_code)] // API for future tests + pub fn without_output_check(mut self) -> Self { + self.expected_output = None; + self + } +} + +/// Run an end-to-end test with the given configuration. +/// +/// This handles all axes of variation: +/// - Program selection +/// - Bytecode mode (Full vs Committed) +/// - Dory layout (CycleMajor vs AddressMajor) +/// - Trusted/untrusted advice (computes commitment if non-empty) +/// - Maximum padded trace length +pub fn run_e2e_test(config: E2ETestConfig) { + // Setup Dory globals + DoryGlobals::reset(); + if let Some(layout) = config.dory_layout { + DoryGlobals::set_layout(layout); + } + + // Decode and trace program + let mut program = host::Program::new(config.program_name); + let (instructions, init_memory_state, _) = program.decode(); + let (_, _, _, io_device) = program.trace( + &config.inputs, + &config.untrusted_advice, + &config.trusted_advice, + ); + + // Preprocess bytecode and program image + let program_data = Arc::new(ProgramPreprocessing::preprocess( + instructions, + init_memory_state, + )); + let shared_preprocessing = JoltSharedPreprocessing::new( + program_data.meta(), + io_device.memory_layout.clone(), + config.max_trace_length, + ); + + // Create prover preprocessing (mode-dependent) + let prover_preprocessing = if config.committed_program { + JoltProverPreprocessing::new_committed( + shared_preprocessing.clone(), + Arc::clone(&program_data), + ) + } else { + JoltProverPreprocessing::new(shared_preprocessing.clone(), Arc::clone(&program_data)) + }; + + // Verify mode is correct + assert_eq!( + prover_preprocessing.is_committed_mode(), + config.committed_program, + "Prover mode mismatch" + ); + + // Compute trusted advice commitment if advice is provided + let (trusted_commitment, trusted_hint) = if !config.trusted_advice.is_empty() { + let (c, h) = + commit_trusted_advice_preprocessing_only(&prover_preprocessing, &config.trusted_advice); + (Some(c), Some(h)) + } else { + (None, None) + }; + + // Create prover and prove + let elf_contents = program.get_elf_contents().expect("elf contents is None"); + let program_mode = if config.committed_program { + ProgramMode::Committed + } else { + ProgramMode::Full + }; + let prover = RV64IMACProver::gen_from_elf_with_program_mode( + &prover_preprocessing, + &elf_contents, + &config.inputs, + &config.untrusted_advice, + &config.trusted_advice, + trusted_commitment, + trusted_hint, + program_mode, + ); + let io_device = prover.program_io.clone(); + let (jolt_proof, debug_info) = prover.prove(); + assert_eq!(jolt_proof.program_mode, program_mode); + + // Create verifier preprocessing from prover (respects mode) + let verifier_preprocessing = JoltVerifierPreprocessing::from(&prover_preprocessing); + + // Verify mode propagated correctly + assert_eq!( + verifier_preprocessing.program.is_committed(), + config.committed_program, + "Verifier mode mismatch" + ); + + // Verify + let verifier = RV64IMACVerifier::new( + &verifier_preprocessing, + jolt_proof, + io_device.clone(), + trusted_commitment, + debug_info, + ) + .expect("Failed to create verifier"); + verifier.verify().expect("Verification failed"); + + // Check expected output if specified + if let Some(expected) = config.expected_output { + assert_eq!( + io_device.outputs, expected, + "Output mismatch for program '{}'", + config.program_name + ); + } +} + +/// Helper to commit trusted advice during preprocessing. +fn commit_trusted_advice_preprocessing_only( + preprocessing: &JoltProverPreprocessing, + trusted_advice_bytes: &[u8], +) -> ( + ::Commitment, + ::OpeningProofHint, +) { + let max_trusted_advice_size = preprocessing.shared.memory_layout.max_trusted_advice_size; + let mut trusted_advice_words = vec![0u64; (max_trusted_advice_size as usize) / 8]; + populate_memory_states( + 0, + trusted_advice_bytes, + Some(&mut trusted_advice_words), + None, + ); + + let poly = MultilinearPolynomial::::from(trusted_advice_words); + let advice_len = poly.len().next_power_of_two().max(1); + + let _guard = DoryGlobals::initialize_context(1, advice_len, DoryContext::TrustedAdvice, None); + let (commitment, hint) = { + let _ctx = DoryGlobals::with_context(DoryContext::TrustedAdvice); + DoryCommitmentScheme::commit(&poly, &preprocessing.generators) + }; + (commitment, hint) +} + +#[test] +#[serial] +fn fib_e2e() { + run_e2e_test(E2ETestConfig::default()); +} + +#[test] +#[serial] +fn fib_e2e_small_trace() { + run_e2e_test(E2ETestConfig::fibonacci(5).with_small_trace()); +} + +#[test] +#[serial] +fn sha2_e2e() { + #[cfg(feature = "host")] + use jolt_inlines_sha2 as _; + run_e2e_test(E2ETestConfig::sha2()); +} + +#[test] +#[serial] +fn sha3_e2e() { + #[cfg(feature = "host")] + use jolt_inlines_keccak256 as _; + run_e2e_test(E2ETestConfig::sha3()); +} + +#[test] +#[serial] +fn sha2_with_unused_advice_e2e() { + // SHA2 guest does not consume advice, but providing both trusted and untrusted advice + // should still work correctly through the full pipeline. + #[cfg(feature = "host")] + use jolt_inlines_sha2 as _; + + run_e2e_test( + E2ETestConfig::sha2() + .with_trusted_advice(postcard::to_stdvec(&[7u8; 32]).unwrap()) + .with_untrusted_advice(postcard::to_stdvec(&[9u8; 32]).unwrap()), + ); +} + +#[test] +#[serial] +fn advice_merkle_tree_e2e() { + run_e2e_test(E2ETestConfig::merkle_tree()); +} + +#[test] +#[serial] +fn memory_ops_e2e() { + run_e2e_test(E2ETestConfig::memory_ops()); +} + +#[test] +#[serial] +fn btreemap_e2e() { + run_e2e_test(E2ETestConfig::btreemap(50)); +} + +#[test] +#[serial] +fn muldiv_e2e() { + run_e2e_test(E2ETestConfig::muldiv(9, 5, 3)); +} + +#[test] +#[serial] +fn fib_e2e_address_major() { + run_e2e_test(E2ETestConfig::default().with_dory_layout(DoryLayout::AddressMajor)); +} + +#[test] +#[serial] +fn advice_merkle_tree_e2e_address_major() { + run_e2e_test(E2ETestConfig::merkle_tree().with_dory_layout(DoryLayout::AddressMajor)); +} + +// ============================================================================ +// New Tests - Committed Program Mode +// +// These tests exercise the end-to-end committed program path (bytecode + program image). +// ============================================================================ + +#[test] +#[serial] +fn fib_e2e_committed_program() { + run_e2e_test(E2ETestConfig::default().with_committed_program()); +} + +#[test] +#[serial] +fn fib_e2e_committed_program_address_major() { + run_e2e_test( + E2ETestConfig::default() + .with_committed_program() + .with_dory_layout(DoryLayout::AddressMajor), + ); +} + +#[test] +#[serial] +fn fib_e2e_committed_small_trace() { + // Committed mode with minimal trace (256 cycles). + // Tests program image commitment when trace is smaller than bytecode. + run_e2e_test( + E2ETestConfig::fibonacci(5) + .with_small_trace() + .with_committed_program(), + ); +} + +#[test] +#[serial] +fn fib_e2e_committed_small_trace_address_major() { + run_e2e_test( + E2ETestConfig::fibonacci(5) + .with_small_trace() + .with_committed_program() + .with_dory_layout(DoryLayout::AddressMajor), + ); +} + +#[test] +#[serial] +fn sha2_e2e_committed_program() { + // Larger program with committed mode (tests program image commitment with larger ELF). + #[cfg(feature = "host")] + use jolt_inlines_sha2 as _; + run_e2e_test(E2ETestConfig::sha2().with_committed_program()); +} + +#[test] +#[serial] +fn sha2_e2e_committed_program_address_major() { + #[cfg(feature = "host")] + use jolt_inlines_sha2 as _; + run_e2e_test( + E2ETestConfig::sha2() + .with_committed_program() + .with_dory_layout(DoryLayout::AddressMajor), + ); +} + +#[test] +#[serial] +fn sha3_e2e_committed_program() { + // Another larger program for committed mode coverage. + #[cfg(feature = "host")] + use jolt_inlines_keccak256 as _; + run_e2e_test(E2ETestConfig::sha3().with_committed_program()); +} + +#[test] +#[serial] +fn merkle_tree_e2e_committed_program() { + // Committed mode with both trusted and untrusted advice. + // Tests interaction of program image commitment with advice claim reductions. + run_e2e_test(E2ETestConfig::merkle_tree().with_committed_program()); +} + +#[test] +#[serial] +fn merkle_tree_e2e_committed_program_address_major() { + run_e2e_test( + E2ETestConfig::merkle_tree() + .with_committed_program() + .with_dory_layout(DoryLayout::AddressMajor), + ); +} + +#[test] +#[serial] +fn memory_ops_e2e_committed_program() { + // Memory-ops guest exercises various load/store patterns. + // Tests committed mode with diverse memory access patterns. + run_e2e_test(E2ETestConfig::memory_ops().with_committed_program()); +} + +#[test] +#[serial] +fn btreemap_e2e_committed_program() { + // BTreeMap guest has complex heap allocations. + run_e2e_test(E2ETestConfig::btreemap(50).with_committed_program()); +} + +#[test] +#[serial] +fn muldiv_e2e_committed_program() { + // Mul/div operations in committed mode. + run_e2e_test(E2ETestConfig::muldiv(9, 5, 3).with_committed_program()); +} + +#[test] +#[serial] +fn fib_e2e_committed_large_trace() { + // Larger trace length (2^17) in committed mode. + // Tests bytecode chunking with log_k_chunk=8 (256 lanes per chunk). + run_e2e_test( + E2ETestConfig::fibonacci(1000) + .with_max_trace_length(1 << 17) + .with_committed_program(), + ); +} + +#[test] +#[serial] +fn fib_e2e_committed_large_trace_address_major() { + run_e2e_test( + E2ETestConfig::fibonacci(1000) + .with_max_trace_length(1 << 17) + .with_committed_program() + .with_dory_layout(DoryLayout::AddressMajor), + ); +} + +#[test] +#[serial] +fn sha2_committed_program_with_advice() { + // SHA2 doesn't consume advice, but providing it should still work in committed mode. + // Tests that program image + bytecode + advice claim reductions all batch correctly. + #[cfg(feature = "host")] + use jolt_inlines_sha2 as _; + run_e2e_test( + E2ETestConfig::sha2() + .with_committed_program() + .with_trusted_advice(postcard::to_stdvec(&[7u8; 32]).unwrap()) + .with_untrusted_advice(postcard::to_stdvec(&[9u8; 32]).unwrap()), + ); +} + +// ============================================================================ +// New Tests - Bytecode Lane Ordering / Chunking +// ============================================================================ + +#[test] +fn bytecode_lane_chunking_counts() { + // Canonical lane spec (see bytecode-commitment-progress.md): + // 3*REGISTER_COUNT (rs1/rs2/rd) + 2 scalars + 13 circuit flags + 7 instr flags + // + 41 lookup selector + 1 raf flag = 448 (with REGISTER_COUNT=128). + assert_eq!(total_lanes(), 448); + assert_eq!(total_lanes().div_ceil(16), 28); + assert_eq!(total_lanes().div_ceil(256), 2); +} + +// ============================================================================ +// New Tests - Program Mode Detection +// ============================================================================ + +#[test] +#[serial] +fn program_mode_detection_full() { + DoryGlobals::reset(); + let mut program = host::Program::new("fibonacci-guest"); + let (instructions, init_memory_state, _) = program.decode(); + let (_, _, _, io_device) = program.trace(&[], &[], &[]); + + let program = Arc::new(ProgramPreprocessing::preprocess( + instructions, + init_memory_state, + )); + let shared = + JoltSharedPreprocessing::new(program.meta(), io_device.memory_layout.clone(), 1 << 16); + + // Full mode + let prover_full: JoltProverPreprocessing = + JoltProverPreprocessing::new(shared.clone(), Arc::clone(&program)); + assert!(!prover_full.is_committed_mode()); + assert!(prover_full.program_commitments.is_none()); + + let verifier_full = JoltVerifierPreprocessing::from(&prover_full); + assert!(verifier_full.program.is_full()); + assert!(!verifier_full.program.is_committed()); + assert!(verifier_full.program.as_full().is_ok()); + assert!(verifier_full.program.as_committed().is_err()); +} + +#[test] +#[serial] +fn program_mode_detection_committed() { + DoryGlobals::reset(); + let mut program = host::Program::new("fibonacci-guest"); + let (instructions, init_memory_state, _) = program.decode(); + let (_, _, _, io_device) = program.trace(&[], &[], &[]); + + let program_data = Arc::new(ProgramPreprocessing::preprocess( + instructions, + init_memory_state, + )); + let shared = JoltSharedPreprocessing::new( + program_data.meta(), + io_device.memory_layout.clone(), + 1 << 16, + ); + + // Committed mode + let prover_committed: JoltProverPreprocessing = + JoltProverPreprocessing::new_committed(shared.clone(), Arc::clone(&program_data)); + assert!(prover_committed.is_committed_mode()); + assert!(prover_committed.program_commitments.is_some()); + + let verifier_committed = JoltVerifierPreprocessing::from(&prover_committed); + assert!(!verifier_committed.program.is_full()); + assert!(verifier_committed.program.is_committed()); + assert!(verifier_committed.program.as_full().is_err()); + assert!(verifier_committed.program.as_committed().is_ok()); + + // Verify committed mode doesn't carry full program data + assert!( + verifier_committed.program.program_image_words().is_none(), + "Committed mode should NOT have program image words" + ); + assert!( + verifier_committed.program.instructions().is_none(), + "Committed mode should NOT have instructions" + ); + assert!( + verifier_committed.program.full().is_none(), + "Committed mode should NOT have full preprocessing" + ); + + // But it should have commitments and metadata + let trusted = verifier_committed.program.as_committed().unwrap(); + assert!( + !trusted.bytecode_commitments.is_empty(), + "Should have bytecode commitments" + ); + assert!( + trusted.bytecode_len > 0, + "Should have bytecode length metadata" + ); + assert!( + trusted.program_image_num_words > 0, + "Should have program image num words metadata" + ); +} + +// ============================================================================ +// Internal and Security Tests +// +// These tests require access to prover internals or manipulate trace/io +// directly for security testing. They cannot use E2ETestConfig. +// ============================================================================ + +#[test] +#[serial] +fn max_advice_with_small_trace() { + DoryGlobals::reset(); + // Tests that max-sized advice (4KB = 512 words) works with a minimal trace. + // With balanced dims (sigma_a=5, nu_a=4 for 512 words), the minimum padded trace + // (256 cycles -> total_vars=12) is sufficient to embed advice. + let mut program = host::Program::new("fibonacci-guest"); + let inputs = postcard::to_stdvec(&5u32).unwrap(); + let trusted_advice = vec![7u8; 4096]; + let untrusted_advice = vec![9u8; 4096]; + + let (instructions, init_memory_state, _) = program.decode(); + let (lazy_trace, trace, final_memory_state, io_device) = + program.trace(&inputs, &untrusted_advice, &trusted_advice); + + let program = Arc::new(ProgramPreprocessing::preprocess( + instructions, + init_memory_state, + )); + let shared_preprocessing = + JoltSharedPreprocessing::new(program.meta(), io_device.memory_layout.clone(), 256); + let prover_preprocessing: JoltProverPreprocessing = + JoltProverPreprocessing::new(shared_preprocessing.clone(), Arc::clone(&program)); + tracing::info!( + "preprocessing.memory_layout.max_trusted_advice_size: {}", + shared_preprocessing.memory_layout.max_trusted_advice_size + ); + + let (trusted_commitment, trusted_hint) = + commit_trusted_advice_preprocessing_only(&prover_preprocessing, &trusted_advice); + + let prover = RV64IMACProver::gen_from_trace( + &prover_preprocessing, + lazy_trace, + trace, + io_device, + Some(trusted_commitment), + Some(trusted_hint), + final_memory_state, + ); + + // Trace is tiny but advice is max-sized + assert!(prover.unpadded_trace_len < 512); + assert_eq!(prover.padded_trace_len, 256); + + let io_device = prover.program_io.clone(); + let (jolt_proof, debug_info) = prover.prove(); + + let verifier_preprocessing = JoltVerifierPreprocessing::from(&prover_preprocessing); + RV64IMACVerifier::new( + &verifier_preprocessing, + jolt_proof, + io_device, + Some(trusted_commitment), + debug_info, + ) + .expect("Failed to create verifier") + .verify() + .expect("Verification failed"); +} + +#[test] +#[serial] +fn advice_opening_point_derives_from_unified_point() { + DoryGlobals::reset(); + // Tests that advice opening points are correctly derived from the unified main opening + // point using Dory's balanced dimension policy. + // + // For a small trace (256 cycles), the advice row coordinates span both Stage 6 (cycle) + // and Stage 7 (address) challenges, verifying the two-phase reduction works correctly. + let mut program = host::Program::new("fibonacci-guest"); + let inputs = postcard::to_stdvec(&5u32).unwrap(); + let trusted_advice = postcard::to_stdvec(&[7u8; 32]).unwrap(); + let untrusted_advice = postcard::to_stdvec(&[9u8; 32]).unwrap(); + + let (instructions, init_memory_state, _) = program.decode(); + let (lazy_trace, trace, final_memory_state, io_device) = + program.trace(&inputs, &untrusted_advice, &trusted_advice); + + let program = Arc::new(ProgramPreprocessing::preprocess( + instructions, + init_memory_state, + )); + let shared_preprocessing = + JoltSharedPreprocessing::new(program.meta(), io_device.memory_layout.clone(), 1 << 16); + let prover_preprocessing: JoltProverPreprocessing = + JoltProverPreprocessing::new(shared_preprocessing.clone(), Arc::clone(&program)); + let (trusted_commitment, trusted_hint) = + commit_trusted_advice_preprocessing_only(&prover_preprocessing, &trusted_advice); + + let prover = RV64IMACProver::gen_from_trace( + &prover_preprocessing, + lazy_trace, + trace, + io_device, + Some(trusted_commitment), + Some(trusted_hint), + final_memory_state, + ); + + assert_eq!(prover.padded_trace_len, 256, "test expects small trace"); + + let io_device = prover.program_io.clone(); + let (jolt_proof, debug_info) = prover.prove(); + let debug_info = debug_info.expect("expected debug_info in tests"); + + // Get unified opening point and derive expected advice point + let (opening_point, _) = debug_info + .opening_accumulator + .get_committed_polynomial_opening( + CommittedPolynomial::InstructionRa(0), + SumcheckId::HammingWeightClaimReduction, + ); + let mut point_dory_le = opening_point.r.clone(); + point_dory_le.reverse(); + + let total_vars = point_dory_le.len(); + let (sigma_main, _nu_main) = DoryGlobals::balanced_sigma_nu(total_vars); + let (sigma_a, nu_a) = DoryGlobals::advice_sigma_nu_from_max_bytes( + prover_preprocessing + .shared + .memory_layout + .max_trusted_advice_size as usize, + ); + + // Build expected advice point: [col_bits[0..sigma_a] || row_bits[0..nu_a]] + let mut expected_advice_le: Vec<_> = point_dory_le[0..sigma_a].to_vec(); + expected_advice_le.extend_from_slice(&point_dory_le[sigma_main..sigma_main + nu_a]); + + // Verify both advice types derive the same opening point + for (name, kind) in [ + ("trusted", AdviceKind::Trusted), + ("untrusted", AdviceKind::Untrusted), + ] { + let get_fn = debug_info + .opening_accumulator + .get_advice_opening(kind, SumcheckId::AdviceClaimReduction); + assert!( + get_fn.is_some(), + "{name} advice opening missing for AdviceClaimReductionPhase2" + ); + let (point_be, _) = get_fn.unwrap(); + let mut point_le = point_be.r.clone(); + point_le.reverse(); + assert_eq!(point_le, expected_advice_le, "{name} advice point mismatch"); + } + + // Verify end-to-end + let verifier_preprocessing = JoltVerifierPreprocessing::from(&prover_preprocessing); + RV64IMACVerifier::new( + &verifier_preprocessing, + jolt_proof, + io_device, + Some(trusted_commitment), + Some(debug_info), + ) + .expect("Failed to create verifier") + .verify() + .expect("Verification failed"); +} + +#[test] +#[serial] +#[should_panic] +fn truncated_trace() { + let mut program = host::Program::new("fibonacci-guest"); + let (instructions, init_memory_state, _) = program.decode(); + let inputs = postcard::to_stdvec(&9u8).unwrap(); + let (lazy_trace, mut trace, final_memory_state, mut program_io) = + program.trace(&inputs, &[], &[]); + trace.truncate(100); + program_io.outputs[0] = 0; // change the output to 0 + + let program = Arc::new(ProgramPreprocessing::preprocess( + instructions, + init_memory_state, + )); + let shared_preprocessing = + JoltSharedPreprocessing::new(program.meta(), program_io.memory_layout.clone(), 1 << 16); + + let prover_preprocessing: JoltProverPreprocessing = + JoltProverPreprocessing::new(shared_preprocessing.clone(), Arc::clone(&program)); + + let prover = RV64IMACProver::gen_from_trace( + &prover_preprocessing, + lazy_trace, + trace, + program_io.clone(), + None, + None, + final_memory_state, + ); + + let (proof, _) = prover.prove(); + + let verifier_preprocessing = JoltVerifierPreprocessing::new_full( + prover_preprocessing.shared.clone(), + prover_preprocessing.generators.to_verifier_setup(), + Arc::clone(&prover_preprocessing.program), + ); + let verifier = + RV64IMACVerifier::new(&verifier_preprocessing, proof, program_io, None, None).unwrap(); + verifier.verify().unwrap(); +} + +#[test] +#[serial] +#[should_panic] +fn malicious_trace() { + let mut program = host::Program::new("fibonacci-guest"); + let inputs = postcard::to_stdvec(&1u8).unwrap(); + let (instructions, init_memory_state, _) = program.decode(); + let (lazy_trace, trace, final_memory_state, mut program_io) = program.trace(&inputs, &[], &[]); + + let program = Arc::new(ProgramPreprocessing::preprocess( + instructions, + init_memory_state, + )); + + // Since the preprocessing is done with the original memory layout, the verifier should fail + let shared_preprocessing = + JoltSharedPreprocessing::new(program.meta(), program_io.memory_layout.clone(), 1 << 16); + let prover_preprocessing: JoltProverPreprocessing = + JoltProverPreprocessing::new(shared_preprocessing.clone(), Arc::clone(&program)); + + // change memory address of output & termination bit to the same address as input + // changes here should not be able to spoof the verifier result + program_io.memory_layout.output_start = program_io.memory_layout.input_start; + program_io.memory_layout.output_end = program_io.memory_layout.input_end; + program_io.memory_layout.termination = program_io.memory_layout.input_start; + + let prover = RV64IMACProver::gen_from_trace( + &prover_preprocessing, + lazy_trace, + trace, + program_io.clone(), + None, + None, + final_memory_state, + ); + let (proof, _) = prover.prove(); + + let verifier_preprocessing = JoltVerifierPreprocessing::new_full( + prover_preprocessing.shared.clone(), + prover_preprocessing.generators.to_verifier_setup(), + Arc::clone(&prover_preprocessing.program), + ); + let verifier = + JoltVerifier::new(&verifier_preprocessing, proof, program_io, None, None).unwrap(); + verifier.verify().unwrap(); +} diff --git a/jolt-core/src/zkvm/verifier.rs b/jolt-core/src/zkvm/verifier.rs index 5fa2ae78fb..9f34d818a8 100644 --- a/jolt-core/src/zkvm/verifier.rs +++ b/jolt-core/src/zkvm/verifier.rs @@ -7,22 +7,31 @@ use std::sync::Arc; use crate::poly::commitment::commitment_scheme::CommitmentScheme; use crate::poly::commitment::dory::{DoryContext, DoryGlobals}; use crate::subprotocols::sumcheck::BatchedSumcheck; -use crate::zkvm::bytecode::BytecodePreprocessing; +use crate::zkvm::bytecode::chunks::total_lanes; use crate::zkvm::claim_reductions::advice::ReductionPhase; use crate::zkvm::claim_reductions::RegistersClaimReductionSumcheckVerifier; use crate::zkvm::config::OneHotParams; +use crate::zkvm::config::ProgramMode; +use crate::zkvm::program::{ + ProgramMetadata, ProgramPreprocessing, TrustedProgramCommitments, VerifierProgram, +}; #[cfg(feature = "prover")] use crate::zkvm::prover::JoltProverPreprocessing; use crate::zkvm::ram::val_final::ValFinalSumcheckVerifier; -use crate::zkvm::ram::RAMPreprocessing; +use crate::zkvm::ram::verifier_accumulate_program_image; use crate::zkvm::witness::all_committed_polynomials; use crate::zkvm::Serializable; use crate::zkvm::{ - bytecode::read_raf_checking::BytecodeReadRafSumcheckVerifier, + bytecode::read_raf_checking::{ + BytecodeReadRafAddressSumcheckVerifier, BytecodeReadRafCycleSumcheckVerifier, + BytecodeReadRafSumcheckParams, + }, claim_reductions::{ - AdviceClaimReductionVerifier, AdviceKind, HammingWeightClaimReductionVerifier, - IncClaimReductionSumcheckVerifier, InstructionLookupsClaimReductionSumcheckVerifier, - RamRaClaimReductionSumcheckVerifier, + AdviceClaimReductionVerifier, AdviceKind, BytecodeClaimReductionParams, + BytecodeClaimReductionVerifier, BytecodeReductionPhase, + HammingWeightClaimReductionVerifier, IncClaimReductionSumcheckVerifier, + InstructionLookupsClaimReductionSumcheckVerifier, ProgramImageClaimReductionParams, + ProgramImageClaimReductionVerifier, RamRaClaimReductionSumcheckVerifier, }, fiat_shamir_preamble, instruction_lookups::{ @@ -58,7 +67,10 @@ use crate::{ }, pprof_scope, subprotocols::{ - booleanity::{BooleanitySumcheckParams, BooleanitySumcheckVerifier}, + booleanity::{ + BooleanityAddressSumcheckVerifier, BooleanityCycleSumcheckVerifier, + BooleanitySumcheckParams, + }, sumcheck_verifier::SumcheckInstanceVerifier, }, transcripts::Transcript, @@ -69,7 +81,6 @@ use anyhow::Context; use ark_serialize::{CanonicalDeserialize, CanonicalSerialize}; use common::jolt_device::MemoryLayout; use itertools::Itertools; -use tracer::instruction::Instruction; use tracer::JoltDevice; pub struct JoltVerifier< @@ -90,6 +101,13 @@ pub struct JoltVerifier< /// The advice claim reduction sumcheck effectively spans two stages (6 and 7). /// Cache the verifier state here between stages. advice_reduction_verifier_untrusted: Option>, + /// The bytecode claim reduction sumcheck effectively spans two stages (6b and 7). + /// Cache the verifier state here between stages. + bytecode_reduction_verifier: Option>, + /// Bytecode read RAF params, cached between Stage 6a and 6b. + bytecode_read_raf_params: Option>, + /// Booleanity params, cached between Stage 6a and 6b. + booleanity_params: Option>, pub spartan_key: UniformSpartanKey, pub one_hot_params: OneHotParams, } @@ -162,6 +180,31 @@ impl<'a, F: JoltField, PCS: CommitmentScheme, ProofTranscript: Transc let one_hot_params = OneHotParams::from_config(&proof.one_hot_config, proof.bytecode_K, proof.ram_K); + if proof.program_mode == ProgramMode::Committed { + let committed = preprocessing.program.as_committed()?; + if committed.log_k_chunk != proof.one_hot_config.log_k_chunk { + return Err(ProofVerifyError::InvalidBytecodeConfig(format!( + "bytecode log_k_chunk mismatch: commitments={}, proof={}", + committed.log_k_chunk, proof.one_hot_config.log_k_chunk + ))); + } + if committed.bytecode_len != preprocessing.shared.bytecode_size() { + return Err(ProofVerifyError::InvalidBytecodeConfig(format!( + "bytecode length mismatch: commitments={}, shared={}", + committed.bytecode_len, + preprocessing.shared.bytecode_size() + ))); + } + let k_chunk = 1usize << (committed.log_k_chunk as usize); + let expected_chunks = total_lanes().div_ceil(k_chunk); + if committed.bytecode_commitments.len() != expected_chunks { + return Err(ProofVerifyError::InvalidBytecodeConfig(format!( + "expected {expected_chunks} bytecode commitments, got {}", + committed.bytecode_commitments.len() + ))); + } + } + Ok(Self { trusted_advice_commitment, program_io, @@ -171,6 +214,9 @@ impl<'a, F: JoltField, PCS: CommitmentScheme, ProofTranscript: Transc opening_accumulator, advice_reduction_verifier_trusted: None, advice_reduction_verifier_untrusted: None, + bytecode_reduction_verifier: None, + bytecode_read_raf_params: None, + booleanity_params: None, spartan_key, one_hot_params, }) @@ -201,13 +247,22 @@ impl<'a, F: JoltField, PCS: CommitmentScheme, ProofTranscript: Transc self.transcript .append_serializable(trusted_advice_commitment); } + if self.proof.program_mode == ProgramMode::Committed { + let trusted = self.preprocessing.program.as_committed()?; + for commitment in &trusted.bytecode_commitments { + self.transcript.append_serializable(commitment); + } + self.transcript + .append_serializable(&trusted.program_image_commitment); + } self.verify_stage1()?; self.verify_stage2()?; self.verify_stage3()?; self.verify_stage4()?; self.verify_stage5()?; - self.verify_stage6()?; + self.verify_stage6a()?; + self.verify_stage6b()?; self.verify_stage7()?; self.verify_stage8()?; @@ -332,24 +387,41 @@ impl<'a, F: JoltField, PCS: CommitmentScheme, ProofTranscript: Transc .rw_config .needs_single_advice_opening(self.proof.trace_length.log_2()), ); + if self.proof.program_mode == ProgramMode::Committed { + verifier_accumulate_program_image::( + self.proof.ram_K, + &self.program_io, + &mut self.opening_accumulator, + &mut self.transcript, + self.proof + .rw_config + .needs_single_advice_opening(self.proof.trace_length.log_2()), + ); + } let registers_read_write_checking = RegistersReadWriteCheckingVerifier::new( self.proof.trace_length, &self.opening_accumulator, &mut self.transcript, &self.proof.rw_config, ); + // In Full mode, get the program image words from the preprocessing + let program_image_words = self.preprocessing.program.program_image_words(); let ram_val_evaluation = RamValEvaluationSumcheckVerifier::new( - &self.preprocessing.shared.ram, + &self.preprocessing.shared.program_meta, + program_image_words, &self.program_io, self.proof.trace_length, self.proof.ram_K, + self.proof.program_mode, &self.opening_accumulator, ); let ram_val_final = ValFinalSumcheckVerifier::new( - &self.preprocessing.shared.ram, + &self.preprocessing.shared.program_meta, + program_image_words, &self.program_io, self.proof.trace_length, self.proof.ram_K, + self.proof.program_mode, &self.opening_accumulator, &self.proof.rw_config, ); @@ -401,26 +473,65 @@ impl<'a, F: JoltField, PCS: CommitmentScheme, ProofTranscript: Transc Ok(()) } - fn verify_stage6(&mut self) -> Result<(), anyhow::Error> { + fn verify_stage6a(&mut self) -> Result<(), anyhow::Error> { let n_cycle_vars = self.proof.trace_length.log_2(); - let bytecode_read_raf = BytecodeReadRafSumcheckVerifier::gen( - &self.preprocessing.shared.bytecode, + let program_preprocessing = match self.proof.program_mode { + ProgramMode::Committed => { + // Ensure we have committed program commitments for committed mode. + let _ = self.preprocessing.program.as_committed()?; + None + } + ProgramMode::Full => self.preprocessing.program.full().map(|p| p.as_ref()), + }; + let bytecode_read_raf = BytecodeReadRafAddressSumcheckVerifier::new( + program_preprocessing, n_cycle_vars, &self.one_hot_params, &self.opening_accumulator, &mut self.transcript, - ); - - let ram_hamming_booleanity = - HammingBooleanitySumcheckVerifier::new(&self.opening_accumulator); + self.proof.program_mode, + )?; let booleanity_params = BooleanitySumcheckParams::new( n_cycle_vars, &self.one_hot_params, &self.opening_accumulator, &mut self.transcript, ); + let booleanity = BooleanityAddressSumcheckVerifier::new(booleanity_params); + + let instances: Vec<&dyn SumcheckInstanceVerifier> = + vec![&bytecode_read_raf, &booleanity]; - let booleanity = BooleanitySumcheckVerifier::new(booleanity_params); + let _r_stage6a = BatchedSumcheck::verify( + &self.proof.stage6a_sumcheck_proof, + instances, + &mut self.opening_accumulator, + &mut self.transcript, + ) + .context("Stage 6a")?; + + // Store params for Stage 6b + self.bytecode_read_raf_params = Some(bytecode_read_raf.into_params()); + self.booleanity_params = Some(booleanity.into_params()); + + Ok(()) + } + + fn verify_stage6b(&mut self) -> Result<(), anyhow::Error> { + // Take params cached from Stage 6a + let bytecode_read_raf_params = self + .bytecode_read_raf_params + .take() + .expect("bytecode_read_raf_params must be set by verify_stage6a"); + let booleanity_params = self + .booleanity_params + .take() + .expect("booleanity_params must be set by verify_stage6a"); + + // Initialize Stage 6b cycle verifiers from scratch (Option B). + let booleanity = BooleanityCycleSumcheckVerifier::new(booleanity_params); + let ram_hamming_booleanity = + HammingBooleanitySumcheckVerifier::new(&self.opening_accumulator); let ram_ra_virtual = RamRaVirtualSumcheckVerifier::new( self.proof.trace_length, &self.one_hot_params, @@ -438,7 +549,26 @@ impl<'a, F: JoltField, PCS: CommitmentScheme, ProofTranscript: Transc &mut self.transcript, ); - // Advice claim reduction (Phase 1 in Stage 6): trusted and untrusted are separate instances. + // Bytecode claim reduction (Phase 1 in Stage 6b): consumes Val_s(r_bc) from Stage 6a and + // caches an intermediate claim for Stage 7. + // + // IMPORTANT: This must be sampled *after* other Stage 6b params (e.g. lookup/inc gammas), + // to match the prover's transcript order. + if self.proof.program_mode == ProgramMode::Committed { + let bytecode_reduction_params = BytecodeClaimReductionParams::new( + &bytecode_read_raf_params, + &self.opening_accumulator, + &mut self.transcript, + ); + self.bytecode_reduction_verifier = Some(BytecodeClaimReductionVerifier::new( + bytecode_reduction_params, + )); + } else { + // Legacy mode: do not run the bytecode claim reduction. + self.bytecode_reduction_verifier = None; + } + + // Advice claim reduction (Phase 1 in Stage 6b): trusted and untrusted are separate instances. if self.trusted_advice_commitment.is_some() { self.advice_reduction_verifier_trusted = Some(AdviceClaimReductionVerifier::new( AdviceKind::Trusted, @@ -464,6 +594,40 @@ impl<'a, F: JoltField, PCS: CommitmentScheme, ProofTranscript: Transc )); } + // Program-image claim reduction (Stage 6b): binds staged Stage 4 scalar program-image claims + // to the trusted commitment, caching an opening of ProgramImageInit. + let program_image_reduction = if self.proof.program_mode == ProgramMode::Committed { + let trusted = self + .preprocessing + .program + .as_committed() + .expect("program commitments missing in committed mode"); + let padded_len_words = trusted.program_image_num_words; + let log_t = self.proof.trace_length.log_2(); + let m = padded_len_words.log_2(); + if m > log_t { + return Err(ProofVerifyError::InvalidBytecodeConfig(format!( + "program-image claim reduction requires m=log2(padded_len_words) <= log_T (got m={m}, log_T={log_t})" + )) + .into()); + } + let params = ProgramImageClaimReductionParams::new( + &self.program_io, + self.preprocessing.shared.min_bytecode_address(), + padded_len_words, + self.proof.ram_K, + self.proof.trace_length, + &self.proof.rw_config, + &self.opening_accumulator, + &mut self.transcript, + ); + Some(ProgramImageClaimReductionVerifier { params }) + } else { + None + }; + + let bytecode_read_raf = BytecodeReadRafCycleSumcheckVerifier::new(bytecode_read_raf_params); + let mut instances: Vec<&dyn SumcheckInstanceVerifier> = vec![ &bytecode_read_raf, &booleanity, @@ -472,20 +636,26 @@ impl<'a, F: JoltField, PCS: CommitmentScheme, ProofTranscript: Transc &lookups_ra_virtual, &inc_reduction, ]; + if let Some(ref bytecode) = self.bytecode_reduction_verifier { + instances.push(bytecode); + } if let Some(ref advice) = self.advice_reduction_verifier_trusted { instances.push(advice); } if let Some(ref advice) = self.advice_reduction_verifier_untrusted { instances.push(advice); } + if let Some(ref prog) = program_image_reduction { + instances.push(prog); + } - let _r_stage6 = BatchedSumcheck::verify( - &self.proof.stage6_sumcheck_proof, + let _r_stage6b = BatchedSumcheck::verify( + &self.proof.stage6b_sumcheck_proof, instances, &mut self.opening_accumulator, &mut self.transcript, ) - .context("Stage 6")?; + .context("Stage 6b")?; Ok(()) } @@ -502,6 +672,12 @@ impl<'a, F: JoltField, PCS: CommitmentScheme, ProofTranscript: Transc let mut instances: Vec<&dyn SumcheckInstanceVerifier> = vec![&hw_verifier]; + + if let Some(bytecode_reduction_verifier) = self.bytecode_reduction_verifier.as_mut() { + bytecode_reduction_verifier.params.borrow_mut().phase = + BytecodeReductionPhase::LaneVariables; + instances.push(bytecode_reduction_verifier); + } if let Some(advice_reduction_verifier_trusted) = self.advice_reduction_verifier_trusted.as_mut() { @@ -536,14 +712,25 @@ impl<'a, F: JoltField, PCS: CommitmentScheme, ProofTranscript: Transc /// Stage 8: Dory batch opening verification. fn verify_stage8(&mut self) -> Result<(), anyhow::Error> { - // Initialize DoryGlobals with the layout from the proof - // This ensures the verifier uses the same layout as the prover - let _guard = DoryGlobals::initialize_context( - 1 << self.one_hot_params.log_k_chunk, - self.proof.trace_length.next_power_of_two(), - DoryContext::Main, - Some(self.proof.dory_layout), - ); + // Initialize DoryGlobals with the layout from the proof. + // In committed mode, we must also match the Main-context sigma used to derive trusted + // bytecode commitments, otherwise Stage 8 batching will be inconsistent. + let _guard = if self.proof.program_mode == ProgramMode::Committed { + let committed = self.preprocessing.program.as_committed()?; + DoryGlobals::initialize_main_context_with_num_columns( + 1 << self.one_hot_params.log_k_chunk, + self.proof.trace_length.next_power_of_two(), + committed.bytecode_num_columns, + Some(self.proof.dory_layout), + ) + } else { + DoryGlobals::initialize_context( + 1 << self.one_hot_params.log_k_chunk, + self.proof.trace_length.next_power_of_two(), + DoryContext::Main, + Some(self.proof.dory_layout), + ) + }; // Get the unified opening point from HammingWeightClaimReduction // This contains (r_address_stage7 || r_cycle_stage6) in big-endian @@ -624,6 +811,67 @@ impl<'a, F: JoltField, PCS: CommitmentScheme, ProofTranscript: Transc )); } + // Bytecode chunk polynomials: committed in Bytecode context and embedded into the + // main opening point by fixing the extra cycle variables to 0. + if self.proof.program_mode == ProgramMode::Committed { + let (bytecode_point, _) = self.opening_accumulator.get_committed_polynomial_opening( + CommittedPolynomial::BytecodeChunk(0), + SumcheckId::BytecodeClaimReduction, + ); + let log_t = opening_point.r.len() - log_k_chunk; + let log_k = bytecode_point.r.len() - log_k_chunk; + if log_k > log_t { + return Err(ProofVerifyError::InvalidBytecodeConfig(format!( + "bytecode folding requires log_T >= log_K (got log_T={log_t}, log_K={log_k})" + )) + .into()); + } + #[cfg(test)] + { + if log_k == log_t { + assert_eq!( + bytecode_point.r, opening_point.r, + "BytecodeChunk opening point must equal unified opening point when log_K == log_T" + ); + } else { + let (r_lane_main, r_cycle_main) = opening_point.split_at(log_k_chunk); + let (r_lane_bc, r_cycle_bc) = bytecode_point.split_at(log_k_chunk); + debug_assert_eq!(r_lane_main.r, r_lane_bc.r); + debug_assert_eq!(&r_cycle_main.r[(log_t - log_k)..], r_cycle_bc.r.as_slice()); + } + } + let lagrange_factor = + compute_advice_lagrange_factor::(&opening_point.r, &bytecode_point.r); + + let num_chunks = total_lanes().div_ceil(self.one_hot_params.k_chunk); + for i in 0..num_chunks { + let (_, claim) = self.opening_accumulator.get_committed_polynomial_opening( + CommittedPolynomial::BytecodeChunk(i), + SumcheckId::BytecodeClaimReduction, + ); + polynomial_claims.push(( + CommittedPolynomial::BytecodeChunk(i), + claim * lagrange_factor, + )); + } + } + + // Program-image polynomial: opened by ProgramImageClaimReduction in Stage 6b. + // Embed into the top-left block of the main matrix (same trick as advice). + if self.proof.program_mode == ProgramMode::Committed { + let (prog_point, prog_claim) = + self.opening_accumulator.get_committed_polynomial_opening( + CommittedPolynomial::ProgramImageInit, + SumcheckId::ProgramImageClaimReduction, + ); + let lagrange_factor = + compute_advice_lagrange_factor::(&opening_point.r, &prog_point.r); + polynomial_claims.push(( + CommittedPolynomial::ProgramImageInit, + prog_claim * lagrange_factor, + )); + } + // 2. Sample gamma and compute powers for RLC let claims: Vec = polynomial_claims.iter().map(|(_, c)| *c).collect(); self.transcript.append_scalars(&claims); @@ -665,6 +913,27 @@ impl<'a, F: JoltField, PCS: CommitmentScheme, ProofTranscript: Transc } } + if self.proof.program_mode == ProgramMode::Committed { + let committed = self.preprocessing.program.as_committed()?; + for (idx, commitment) in committed.bytecode_commitments.iter().enumerate() { + commitments_map + .entry(CommittedPolynomial::BytecodeChunk(idx)) + .or_insert_with(|| commitment.clone()); + } + + // Add trusted program-image commitment if it's part of the batch. + if state + .polynomial_claims + .iter() + .any(|(p, _)| *p == CommittedPolynomial::ProgramImageInit) + { + commitments_map.insert( + CommittedPolynomial::ProgramImageInit, + committed.program_image_commitment.clone(), + ); + } + } + // Compute joint commitment: Σ γ_i · C_i let joint_commitment = self.compute_joint_commitment(&mut commitments_map, &state); @@ -675,7 +944,7 @@ impl<'a, F: JoltField, PCS: CommitmentScheme, ProofTranscript: Transc .map(|(gamma, claim)| *gamma * claim) .sum(); - // Verify opening + // Verify joint opening PCS::verify( &self.proof.joint_opening_proof, &self.preprocessing.generators, @@ -684,7 +953,9 @@ impl<'a, F: JoltField, PCS: CommitmentScheme, ProofTranscript: Transc &joint_claim, &joint_commitment, ) - .context("Stage 8") + .context("Stage 8 (joint)")?; + + Ok(()) } /// Compute joint commitment for the batch opening. @@ -712,89 +983,59 @@ impl<'a, F: JoltField, PCS: CommitmentScheme, ProofTranscript: Transc } } -#[derive(Debug, Clone)] +/// Shared preprocessing between prover and verifier. +/// +/// Contains O(1) metadata about the program. Does NOT contain the full program data. +/// - Full program data is in `JoltProverPreprocessing.program`. +/// - Verifier program (Full or Committed) is in `JoltVerifierPreprocessing.program`. +#[derive(Debug, Clone, CanonicalSerialize, CanonicalDeserialize)] pub struct JoltSharedPreprocessing { - pub bytecode: Arc, - pub ram: RAMPreprocessing, + /// Program metadata (bytecode size, program image info). + pub program_meta: ProgramMetadata, pub memory_layout: MemoryLayout, pub max_padded_trace_length: usize, } -impl CanonicalSerialize for JoltSharedPreprocessing { - fn serialize_with_mode( - &self, - mut writer: W, - compress: ark_serialize::Compress, - ) -> Result<(), ark_serialize::SerializationError> { - // Serialize the inner BytecodePreprocessing (not the Arc wrapper) - self.bytecode - .as_ref() - .serialize_with_mode(&mut writer, compress)?; - self.ram.serialize_with_mode(&mut writer, compress)?; - self.memory_layout - .serialize_with_mode(&mut writer, compress)?; - self.max_padded_trace_length - .serialize_with_mode(&mut writer, compress)?; - Ok(()) - } - - fn serialized_size(&self, compress: ark_serialize::Compress) -> usize { - self.bytecode.serialized_size(compress) - + self.ram.serialized_size(compress) - + self.memory_layout.serialized_size(compress) - + self.max_padded_trace_length.serialized_size(compress) - } -} - -impl CanonicalDeserialize for JoltSharedPreprocessing { - fn deserialize_with_mode( - mut reader: R, - compress: ark_serialize::Compress, - validate: ark_serialize::Validate, - ) -> Result { - let bytecode = - BytecodePreprocessing::deserialize_with_mode(&mut reader, compress, validate)?; - let ram = RAMPreprocessing::deserialize_with_mode(&mut reader, compress, validate)?; - let memory_layout = MemoryLayout::deserialize_with_mode(&mut reader, compress, validate)?; - let max_padded_trace_length = - usize::deserialize_with_mode(&mut reader, compress, validate)?; - Ok(Self { - bytecode: Arc::new(bytecode), - ram, - memory_layout, - max_padded_trace_length, - }) - } -} - -impl ark_serialize::Valid for JoltSharedPreprocessing { - fn check(&self) -> Result<(), ark_serialize::SerializationError> { - self.bytecode.check()?; - self.ram.check()?; - self.memory_layout.check() - } -} - impl JoltSharedPreprocessing { + /// Create shared preprocessing from program metadata. + /// + /// # Arguments + /// - `program_meta`: Program metadata (from `ProgramPreprocessing::meta()`) + /// - `memory_layout`: Memory layout configuration + /// - `max_padded_trace_length`: Maximum trace length for generator sizing #[tracing::instrument(skip_all, name = "JoltSharedPreprocessing::new")] pub fn new( - bytecode: Vec, + program_meta: ProgramMetadata, memory_layout: MemoryLayout, - memory_init: Vec<(u64, u8)>, max_padded_trace_length: usize, ) -> JoltSharedPreprocessing { - let bytecode = Arc::new(BytecodePreprocessing::preprocess(bytecode)); - let ram = RAMPreprocessing::preprocess(memory_init); Self { - bytecode, - ram, + program_meta, memory_layout, max_padded_trace_length, } } + + /// Bytecode size (power-of-2 padded). + /// Legacy accessor - use `program_meta.bytecode_len` directly. + pub fn bytecode_size(&self) -> usize { + self.program_meta.bytecode_len + } + + /// Minimum bytecode address. + /// Legacy accessor - use `program_meta.min_bytecode_address` directly. + pub fn min_bytecode_address(&self) -> u64 { + self.program_meta.min_bytecode_address + } + + /// Program image length (unpadded words). + /// Legacy accessor - use `program_meta.program_image_len_words` directly. + pub fn program_image_len_words(&self) -> usize { + self.program_meta.program_image_len_words + } } -#[derive(Debug, Clone, CanonicalSerialize, CanonicalDeserialize)] +#[derive(Debug, Clone)] pub struct JoltVerifierPreprocessing where F: JoltField, @@ -802,6 +1043,69 @@ where { pub generators: PCS::VerifierSetup, pub shared: JoltSharedPreprocessing, + /// Program information for verification. + /// + /// In Full mode: contains full program preprocessing (bytecode + program image). + /// In Committed mode: contains only commitments (succinct). + pub program: VerifierProgram, +} + +impl CanonicalSerialize for JoltVerifierPreprocessing +where + F: JoltField, + PCS: CommitmentScheme, +{ + fn serialize_with_mode( + &self, + mut writer: W, + compress: ark_serialize::Compress, + ) -> Result<(), ark_serialize::SerializationError> { + self.generators.serialize_with_mode(&mut writer, compress)?; + self.shared.serialize_with_mode(&mut writer, compress)?; + self.program.serialize_with_mode(&mut writer, compress)?; + Ok(()) + } + + fn serialized_size(&self, compress: ark_serialize::Compress) -> usize { + self.generators.serialized_size(compress) + + self.shared.serialized_size(compress) + + self.program.serialized_size(compress) + } +} + +impl ark_serialize::Valid for JoltVerifierPreprocessing +where + F: JoltField, + PCS: CommitmentScheme, +{ + fn check(&self) -> Result<(), ark_serialize::SerializationError> { + self.generators.check()?; + self.shared.check()?; + self.program.check() + } +} + +impl CanonicalDeserialize for JoltVerifierPreprocessing +where + F: JoltField, + PCS: CommitmentScheme, +{ + fn deserialize_with_mode( + mut reader: R, + compress: ark_serialize::Compress, + validate: ark_serialize::Validate, + ) -> Result { + let generators = + PCS::VerifierSetup::deserialize_with_mode(&mut reader, compress, validate)?; + let shared = + JoltSharedPreprocessing::deserialize_with_mode(&mut reader, compress, validate)?; + let program = VerifierProgram::deserialize_with_mode(&mut reader, compress, validate)?; + Ok(Self { + generators, + shared, + program, + }) + } } impl Serializable for JoltVerifierPreprocessing @@ -835,14 +1139,39 @@ where } impl> JoltVerifierPreprocessing { - #[tracing::instrument(skip_all, name = "JoltVerifierPreprocessing::new")] - pub fn new( + /// Create verifier preprocessing in Full mode (verifier has full program). + #[tracing::instrument(skip_all, name = "JoltVerifierPreprocessing::new_full")] + pub fn new_full( + shared: JoltSharedPreprocessing, + generators: PCS::VerifierSetup, + program: Arc, + ) -> JoltVerifierPreprocessing { + Self { + generators, + shared, + program: VerifierProgram::Full(program), + } + } + + /// Create verifier preprocessing in Committed mode with trusted commitments. + /// + /// This is the "fast path" for online verification. The `TrustedProgramCommitments` + /// type guarantees (at the type level) that these commitments were derived from + /// actual program via `TrustedProgramCommitments::derive()`. + /// + /// # Trust Model + /// The caller must ensure the commitments were honestly derived (e.g., loaded from + /// a trusted file or received from trusted preprocessing). + #[tracing::instrument(skip_all, name = "JoltVerifierPreprocessing::new_committed")] + pub fn new_committed( shared: JoltSharedPreprocessing, generators: PCS::VerifierSetup, + program_commitments: TrustedProgramCommitments, ) -> JoltVerifierPreprocessing { Self { generators, - shared: shared.clone(), + shared, + program: VerifierProgram::Committed(program_commitments), } } } @@ -853,9 +1182,16 @@ impl> From<&JoltProverPreprocessi { fn from(prover_preprocessing: &JoltProverPreprocessing) -> Self { let generators = PCS::setup_verifier(&prover_preprocessing.generators); + let shared = prover_preprocessing.shared.clone(); + // Choose VerifierProgram variant based on whether prover has program commitments + let program = match &prover_preprocessing.program_commitments { + Some(commitments) => VerifierProgram::Committed(commitments.clone()), + None => VerifierProgram::Full(Arc::clone(&prover_preprocessing.program)), + }; Self { generators, - shared: prover_preprocessing.shared.clone(), + shared, + program, } } } diff --git a/jolt-core/src/zkvm/witness.rs b/jolt-core/src/zkvm/witness.rs index efcef73652..ee68f29ce9 100644 --- a/jolt-core/src/zkvm/witness.rs +++ b/jolt-core/src/zkvm/witness.rs @@ -7,7 +7,6 @@ use rayon::prelude::*; use tracer::instruction::Cycle; use crate::poly::commitment::commitment_scheme::StreamingCommitmentScheme; -use crate::zkvm::bytecode::BytecodePreprocessing; use crate::zkvm::config::OneHotParams; use crate::zkvm::instruction::InstructionFlags; use crate::zkvm::verifier::JoltSharedPreprocessing; @@ -31,6 +30,9 @@ pub enum CommittedPolynomial { InstructionRa(usize), /// One-hot ra polynomial for the bytecode instance of Shout BytecodeRa(usize), + /// Packed bytecode commitment chunk polynomial (lane chunk i). + /// This is used by BytecodeClaimReduction; commitment + batching integration is staged separately. + BytecodeChunk(usize), /// One-hot ra/wa polynomial for the RAM instance of Twist /// Note that for RAM, ra and wa are the same polynomial because /// there is at most one load or store per cycle. @@ -41,6 +43,12 @@ pub enum CommittedPolynomial { /// Untrusted advice polynomial - committed during proving, commitment in proof. /// Length cannot exceed max_trace_length. UntrustedAdvice, + /// Program image words polynomial (initial RAM image), committed in preprocessing for + /// `ProgramMode::Committed` and opened via `ProgramImageClaimReduction`. + /// + /// This polynomial is NOT streamed from the execution trace (it is provided as an "extra" + /// polynomial to the Stage 8 streaming RLC builder, similar to advice polynomials). + ProgramImageInit, } /// Returns a list of symbols representing all committed polynomials. @@ -64,6 +72,7 @@ impl CommittedPolynomial { &self, setup: &PCS::ProverSetup, preprocessing: &JoltSharedPreprocessing, + program: &crate::zkvm::program::ProgramPreprocessing, row_cycles: &[tracer::instruction::Cycle], one_hot_params: &OneHotParams, ) -> ::ChunkState @@ -108,12 +117,15 @@ impl CommittedPolynomial { let row: Vec> = row_cycles .iter() .map(|cycle| { - let pc = preprocessing.bytecode.get_pc(cycle); + let pc = program.get_pc(cycle); Some(one_hot_params.bytecode_pc_chunk(pc, *idx) as usize) }) .collect(); PCS::process_chunk_onehot(setup, one_hot_params.k_chunk, &row) } + CommittedPolynomial::BytecodeChunk(_) => { + panic!("Bytecode chunk polynomials are not stream-committed yet") + } CommittedPolynomial::RamRa(idx) => { let row: Vec> = row_cycles .iter() @@ -127,7 +139,9 @@ impl CommittedPolynomial { .collect(); PCS::process_chunk_onehot(setup, one_hot_params.k_chunk, &row) } - CommittedPolynomial::TrustedAdvice | CommittedPolynomial::UntrustedAdvice => { + CommittedPolynomial::TrustedAdvice + | CommittedPolynomial::UntrustedAdvice + | CommittedPolynomial::ProgramImageInit => { panic!("Advice polynomials should not use streaming witness generation") } } @@ -136,7 +150,7 @@ impl CommittedPolynomial { #[tracing::instrument(skip_all, name = "CommittedPolynomial::generate_witness")] pub fn generate_witness( &self, - bytecode_preprocessing: &BytecodePreprocessing, + program: &crate::zkvm::program::ProgramPreprocessing, memory_layout: &MemoryLayout, trace: &[Cycle], one_hot_params: Option<&OneHotParams>, @@ -150,7 +164,7 @@ impl CommittedPolynomial { let addresses: Vec<_> = trace .par_iter() .map(|cycle| { - let pc = bytecode_preprocessing.get_pc(cycle); + let pc = program.get_pc(cycle); Some(one_hot_params.bytecode_pc_chunk(pc, *i)) }) .collect(); @@ -159,6 +173,9 @@ impl CommittedPolynomial { one_hot_params.k_chunk, )) } + CommittedPolynomial::BytecodeChunk(_) => { + panic!("Bytecode chunk polynomials are not supported by generate_witness yet") + } CommittedPolynomial::RamRa(i) => { let one_hot_params = one_hot_params.unwrap(); let addresses: Vec<_> = trace @@ -212,7 +229,9 @@ impl CommittedPolynomial { one_hot_params.k_chunk, )) } - CommittedPolynomial::TrustedAdvice | CommittedPolynomial::UntrustedAdvice => { + CommittedPolynomial::TrustedAdvice + | CommittedPolynomial::UntrustedAdvice + | CommittedPolynomial::ProgramImageInit => { panic!("Advice polynomials should not use generate_witness") } } @@ -271,4 +290,10 @@ pub enum VirtualPolynomial { OpFlags(CircuitFlags), InstructionFlags(InstructionFlags), LookupTableFlag(usize), + BytecodeValStage(usize), + BytecodeReadRafAddrClaim, + BooleanityAddrClaim, + BytecodeClaimReductionIntermediate, + ProgramImageInitContributionRw, + ProgramImageInitContributionRaf, } diff --git a/jolt-inlines/bigint/src/multiplication/mod.rs b/jolt-inlines/bigint/src/multiplication/mod.rs index ec327f0fad..3aac420c7b 100644 --- a/jolt-inlines/bigint/src/multiplication/mod.rs +++ b/jolt-inlines/bigint/src/multiplication/mod.rs @@ -10,7 +10,6 @@ const OUTPUT_LIMBS: usize = 2 * INPUT_LIMBS; pub mod sdk; pub use sdk::*; -#[cfg(feature = "host")] pub mod exec; #[cfg(feature = "host")] pub mod sequence_builder; diff --git a/jolt-inlines/bigint/src/multiplication/sdk.rs b/jolt-inlines/bigint/src/multiplication/sdk.rs index f927a4fb27..687735524e 100644 --- a/jolt-inlines/bigint/src/multiplication/sdk.rs +++ b/jolt-inlines/bigint/src/multiplication/sdk.rs @@ -4,6 +4,18 @@ use super::{INPUT_LIMBS, OUTPUT_LIMBS}; +#[cfg(all( + not(feature = "host"), + any(target_arch = "riscv32", target_arch = "riscv64") +))] +use super::{BIGINT256_MUL_FUNCT3, BIGINT256_MUL_FUNCT7, INLINE_OPCODE}; + +#[cfg(any( + feature = "host", + not(any(target_arch = "riscv32", target_arch = "riscv64")) +))] +use crate::multiplication::exec; + /// Performs 256-bit × 256-bit multiplication /// /// # Arguments @@ -33,9 +45,11 @@ pub fn bigint256_mul(lhs: [u64; INPUT_LIMBS], rhs: [u64; INPUT_LIMBS]) -> [u64; /// - `a` and `b` must point to at least 32 bytes of readable memory /// - `result` must point to at least 64 bytes of writable memory /// - The memory regions may overlap (result can be the same as a or b) -#[cfg(not(feature = "host"))] +#[cfg(all( + not(feature = "host"), + any(target_arch = "riscv32", target_arch = "riscv64") +))] pub unsafe fn bigint256_mul_inline(a: *const u64, b: *const u64, result: *mut u64) { - use super::{BIGINT256_MUL_FUNCT3, BIGINT256_MUL_FUNCT7, INLINE_OPCODE}; core::arch::asm!( ".insn r {opcode}, {funct3}, {funct7}, {rd}, {rs1}, {rs2}", opcode = const INLINE_OPCODE, @@ -59,10 +73,11 @@ pub unsafe fn bigint256_mul_inline(a: *const u64, b: *const u64, result: *mut u6 /// - All pointers must be valid and properly aligned for u64 access (8-byte alignment) /// - `a` and `b` must point to at least 32 bytes of readable memory /// - `result` must point to at least 64 bytes of writable memory -#[cfg(feature = "host")] +#[cfg(any( + feature = "host", + not(any(target_arch = "riscv32", target_arch = "riscv64")) +))] pub unsafe fn bigint256_mul_inline(a: *const u64, b: *const u64, result: *mut u64) { - use crate::multiplication::exec; - let a_array = *(a as *const [u64; INPUT_LIMBS]); let b_array = *(b as *const [u64; INPUT_LIMBS]); let result_array = exec::bigint_mul(a_array, b_array); diff --git a/jolt-sdk/macros/src/lib.rs b/jolt-sdk/macros/src/lib.rs index 58ab22c7ec..c5e3d189c9 100644 --- a/jolt-sdk/macros/src/lib.rs +++ b/jolt-sdk/macros/src/lib.rs @@ -66,16 +66,18 @@ impl MacroBuilder { fn build(&mut self) -> TokenStream { let memory_config_fn = self.make_memory_config_fn(); let build_prover_fn = self.make_build_prover_fn(); + let build_prover_committed_fn = self.make_build_prover_committed_fn(); let build_verifier_fn = self.make_build_verifier_fn(); let analyze_fn = self.make_analyze_function(); let trace_to_file_fn = self.make_trace_to_file_func(); let compile_fn = self.make_compile_func(); + let preprocess_fn = self.make_preprocess_func(); + let preprocess_committed_fn = self.make_preprocess_committed_func(); let preprocess_shared_fn = self.make_preprocess_shared_func(); - let preprocess_prover_fn = self.make_preprocess_prover_func(); - let preprocess_verifier_fn = self.make_preprocess_verifier_func(); let verifier_preprocess_from_prover_fn = self.make_preprocess_from_prover_func(); let commit_trusted_advice_fn = self.make_commit_trusted_advice_func(); let prove_fn = self.make_prove_func(); + let prove_committed_fn = self.make_prove_committed_func(); let attributes = parse_attributes(&self.attr); let mut execute_fn = quote! {}; @@ -96,17 +98,19 @@ impl MacroBuilder { quote! { #memory_config_fn #build_prover_fn + #build_prover_committed_fn #build_verifier_fn #execute_fn #analyze_fn #trace_to_file_fn #compile_fn + #preprocess_fn + #preprocess_committed_fn #preprocess_shared_fn - #preprocess_prover_fn - #preprocess_verifier_fn #verifier_preprocess_from_prover_fn #commit_trusted_advice_fn #prove_fn + #prove_committed_fn #main_fn } .into() @@ -192,8 +196,71 @@ impl MacroBuilder { ) -> #return_type { #imports - let program = std::sync::Arc::new(program); - let preprocessing = std::sync::Arc::new(preprocessing); + let program = Arc::new(program); + let preprocessing = Arc::new(preprocessing); + + let prove_closure = move |#inputs #commitment_param_in_closure| { + let program = (*program).clone(); + let preprocessing = (*preprocessing).clone(); + #prove_fn_name(program, preprocessing, #(#all_names),* #commitment_arg_in_call) + }; + + prove_closure + } + } + } + + fn make_build_prover_committed_fn(&self) -> TokenStream2 { + let fn_name = self.get_func_name(); + let build_prover_fn_name = + Ident::new(&format!("build_prover_committed_{fn_name}"), fn_name.span()); + let prove_output_ty = self.get_prove_output_type(); + + // Include public, trusted_advice, and untrusted_advice arguments for the prover + let ordered_func_args = self.get_all_func_args_in_order(); + let all_names: Vec<_> = ordered_func_args.iter().map(|(name, _)| name).collect(); + let all_types: Vec<_> = ordered_func_args.iter().map(|(_, ty)| ty).collect(); + + let inputs_vec: Vec<_> = self.func.sig.inputs.iter().collect(); + let inputs = quote! { #(#inputs_vec),* }; + let prove_fn_name = Ident::new(&format!("prove_committed_{fn_name}"), fn_name.span()); + let imports = self.make_imports(); + + let has_trusted_advice = !self.trusted_func_args.is_empty(); + + let commitment_param_in_closure = if has_trusted_advice { + quote! { , trusted_advice_commitment: Option<::Commitment>, + trusted_advice_hint: Option<::OpeningProofHint> } + } else { + quote! {} + }; + + let commitment_arg_in_call = if has_trusted_advice { + quote! { , trusted_advice_commitment, trusted_advice_hint } + } else { + quote! {} + }; + + let return_type = if has_trusted_advice { + quote! { + impl Fn(#(#all_types),*, Option<::Commitment>, Option<::OpeningProofHint>) -> #prove_output_ty + Sync + Send + } + } else { + quote! { + impl Fn(#(#all_types),*) -> #prove_output_ty + Sync + Send + } + }; + + quote! { + #[cfg(all(not(target_arch = "wasm32"), not(feature = "guest")))] + pub fn #build_prover_fn_name( + program: jolt::host::Program, + preprocessing: jolt::JoltProverPreprocessing, + ) -> #return_type + { + #imports + let program = Arc::new(program); + let preprocessing = Arc::new(preprocessing); let prove_closure = move |#inputs #commitment_param_in_closure| { let program = (*program).clone(); @@ -253,7 +320,7 @@ impl MacroBuilder { ) -> impl Fn(#(#input_types ,)* #output_type, bool, #commitment_param_in_signature jolt::RV64IMACProof) -> bool + Sync + Send { #imports - let preprocessing = std::sync::Arc::new(preprocessing); + let preprocessing = Arc::new(preprocessing); let verify_closure = move |#(#public_inputs,)* output, panic, #commitment_param_in_closure proof: jolt::RV64IMACProof| { let preprocessing = (*preprocessing).clone(); @@ -379,7 +446,7 @@ impl MacroBuilder { #imports let mut program = Program::new(#guest_name); - let path = std::path::PathBuf::from(target_dir); + let path = PathBuf::from(target_dir); program.set_func(#fn_name_str); #set_std #set_mem_size @@ -427,7 +494,7 @@ impl MacroBuilder { } } - fn make_preprocess_shared_func(&self) -> TokenStream2 { + fn make_preprocess_func(&self) -> TokenStream2 { let attributes = parse_attributes(&self.attr); let max_trace_length = proc_macro2::Literal::u64_unsuffixed(attributes.max_trace_length); let max_input_size = proc_macro2::Literal::u64_unsuffixed(attributes.max_input_size); @@ -441,16 +508,15 @@ impl MacroBuilder { let imports = self.make_imports(); let fn_name = self.get_func_name(); - let preprocess_shared_fn_name = - Ident::new(&format!("preprocess_shared_{fn_name}"), fn_name.span()); + let preprocess_fn_name = Ident::new(&format!("preprocess_{fn_name}"), fn_name.span()); quote! { #[cfg(all(not(target_arch = "wasm32"), not(feature = "guest")))] - pub fn #preprocess_shared_fn_name(program: &mut jolt::host::Program) - -> jolt::JoltSharedPreprocessing + pub fn #preprocess_fn_name(program: &mut jolt::host::Program) + -> jolt::JoltProverPreprocessing { #imports - let (bytecode, memory_init, program_size) = program.decode(); + let (instructions, memory_init, program_size) = program.decode(); let memory_config = MemoryConfig { max_input_size: #max_input_size, max_output_size: #max_output_size, @@ -462,55 +528,103 @@ impl MacroBuilder { }; let memory_layout = MemoryLayout::new(&memory_config); - let preprocessing = JoltSharedPreprocessing::new( - bytecode, + let program_data = Arc::new(ProgramPreprocessing::preprocess(instructions, memory_init)); + let shared = JoltSharedPreprocessing::new( + program_data.meta(), memory_layout, - memory_init, #max_trace_length, ); - - preprocessing + JoltProverPreprocessing::new(shared, program_data) } } } - fn make_preprocess_prover_func(&self) -> TokenStream2 { + fn make_preprocess_committed_func(&self) -> TokenStream2 { + let attributes = parse_attributes(&self.attr); + let max_trace_length = proc_macro2::Literal::u64_unsuffixed(attributes.max_trace_length); + let max_input_size = proc_macro2::Literal::u64_unsuffixed(attributes.max_input_size); + let max_output_size = proc_macro2::Literal::u64_unsuffixed(attributes.max_output_size); + let max_untrusted_advice_size = + proc_macro2::Literal::u64_unsuffixed(attributes.max_untrusted_advice_size); + let max_trusted_advice_size = + proc_macro2::Literal::u64_unsuffixed(attributes.max_trusted_advice_size); + let stack_size = proc_macro2::Literal::u64_unsuffixed(attributes.stack_size); + let memory_size = proc_macro2::Literal::u64_unsuffixed(attributes.memory_size); let imports = self.make_imports(); let fn_name = self.get_func_name(); - let preprocess_prover_fn_name = - Ident::new(&format!("preprocess_prover_{fn_name}"), fn_name.span()); + let preprocess_fn_name = + Ident::new(&format!("preprocess_committed_{fn_name}"), fn_name.span()); quote! { #[cfg(all(not(target_arch = "wasm32"), not(feature = "guest")))] - pub fn #preprocess_prover_fn_name(shared_preprocessing: jolt::JoltSharedPreprocessing) + pub fn #preprocess_fn_name(program: &mut jolt::host::Program) -> jolt::JoltProverPreprocessing { #imports - let prover_preprocessing = JoltProverPreprocessing::new( - shared_preprocessing, - ); - prover_preprocessing + let (instructions, memory_init, program_size) = program.decode(); + let memory_config = MemoryConfig { + max_input_size: #max_input_size, + max_output_size: #max_output_size, + max_untrusted_advice_size: #max_untrusted_advice_size, + max_trusted_advice_size: #max_trusted_advice_size, + stack_size: #stack_size, + memory_size: #memory_size, + program_size: Some(program_size), + }; + let memory_layout = MemoryLayout::new(&memory_config); + + let program_data = Arc::new(ProgramPreprocessing::preprocess(instructions, memory_init)); + let shared = JoltSharedPreprocessing::new( + program_data.meta(), + memory_layout, + #max_trace_length, + ); + JoltProverPreprocessing::new_committed(shared, program_data) } } } - fn make_preprocess_verifier_func(&self) -> TokenStream2 { + fn make_preprocess_shared_func(&self) -> TokenStream2 { + let attributes = parse_attributes(&self.attr); + let max_trace_length = proc_macro2::Literal::u64_unsuffixed(attributes.max_trace_length); + let max_input_size = proc_macro2::Literal::u64_unsuffixed(attributes.max_input_size); + let max_output_size = proc_macro2::Literal::u64_unsuffixed(attributes.max_output_size); + let max_untrusted_advice_size = + proc_macro2::Literal::u64_unsuffixed(attributes.max_untrusted_advice_size); + let max_trusted_advice_size = + proc_macro2::Literal::u64_unsuffixed(attributes.max_trusted_advice_size); + let stack_size = proc_macro2::Literal::u64_unsuffixed(attributes.stack_size); + let memory_size = proc_macro2::Literal::u64_unsuffixed(attributes.memory_size); let imports = self.make_imports(); let fn_name = self.get_func_name(); - let preprocess_verifier_fn_name = - Ident::new(&format!("preprocess_verifier_{fn_name}"), fn_name.span()); + let preprocess_shared_fn_name = + Ident::new(&format!("preprocess_shared_{fn_name}"), fn_name.span()); quote! { #[cfg(all(not(target_arch = "wasm32"), not(feature = "guest")))] - pub fn #preprocess_verifier_fn_name( - shared_preprocess: jolt::JoltSharedPreprocessing, - generators: ::VerifierSetup, - ) -> jolt::JoltVerifierPreprocessing + pub fn #preprocess_shared_fn_name(program: &mut jolt::host::Program) + -> (jolt::JoltSharedPreprocessing, std::sync::Arc) { #imports - let preprocessing = JoltVerifierPreprocessing::new(shared_preprocess, generators); - preprocessing + let (instructions, memory_init, program_size) = program.decode(); + let memory_config = MemoryConfig { + max_input_size: #max_input_size, + max_output_size: #max_output_size, + max_untrusted_advice_size: #max_untrusted_advice_size, + max_trusted_advice_size: #max_trusted_advice_size, + stack_size: #stack_size, + memory_size: #memory_size, + program_size: Some(program_size), + }; + let memory_layout = MemoryLayout::new(&memory_config); + let program_data = Arc::new(ProgramPreprocessing::preprocess(instructions, memory_init)); + let shared = JoltSharedPreprocessing::new( + program_data.meta(), + memory_layout, + #max_trace_length, + ); + (shared, program_data) } } } @@ -687,12 +801,110 @@ impl MacroBuilder { let elf_contents_opt = program.get_elf_contents(); let elf_contents = elf_contents_opt.as_deref().expect("elf contents is None"); - let prover = RV64IMACProver::gen_from_elf(&preprocessing, + let prover = RV64IMACProver::gen_from_elf_with_program_mode(&preprocessing, &elf_contents, &input_bytes, &untrusted_advice_bytes, &trusted_advice_bytes, #commitment_arg, + jolt::ProgramMode::Full, + ); + let io_device = prover.program_io.clone(); + let (jolt_proof, _) = prover.prove(); + + #handle_return + + (ret_val, jolt_proof, io_device) + } + } + } + + fn make_prove_committed_func(&self) -> TokenStream2 { + let prove_output_ty = self.get_prove_output_type(); + + let handle_return = match &self.func.sig.output { + ReturnType::Default => quote! { + let ret_val = (); + }, + ReturnType::Type(_, ty) => quote! { + let mut outputs = io_device.outputs.clone(); + outputs.resize(preprocessing.shared.memory_layout.max_output_size as usize, 0); + let ret_val = jolt::postcard::from_bytes::<#ty>(&outputs).unwrap(); + }, + }; + + let set_program_args = self.pub_func_args.iter().map(|(name, _)| { + quote! { + input_bytes.append(&mut jolt::postcard::to_stdvec(&#name).unwrap()) + } + }); + let set_program_untrusted_advice_args = self.untrusted_func_args.iter().map(|(name, _)| { + quote! { + untrusted_advice_bytes.append(&mut jolt::postcard::to_stdvec(&#name).unwrap()) + } + }); + let set_program_trusted_advice_args = self.trusted_func_args.iter().map(|(name, _)| { + quote! { + trusted_advice_bytes.append(&mut jolt::postcard::to_stdvec(&#name).unwrap()) + } + }); + + let fn_name = self.get_func_name(); + let inputs_vec: Vec<_> = self.func.sig.inputs.iter().collect(); + let inputs = quote! { #(#inputs_vec),* }; + let imports = self.make_imports(); + + let prove_fn_name = syn::Ident::new(&format!("prove_committed_{fn_name}"), fn_name.span()); + + let has_trusted_advice = !self.trusted_func_args.is_empty(); + + let commitment_param = if has_trusted_advice { + quote! { , trusted_advice_commitment: Option<::Commitment>, + trusted_advice_hint: Option<::OpeningProofHint> } + } else { + quote! {} + }; + + let commitment_arg = if has_trusted_advice { + quote! { trusted_advice_commitment, trusted_advice_hint } + } else { + quote! { None, None } + }; + + quote! { + #[cfg(all(not(target_arch = "wasm32"), not(feature = "guest")))] + #[allow(clippy::too_many_arguments)] + pub fn #prove_fn_name( + mut program: jolt::host::Program, + preprocessing: jolt::JoltProverPreprocessing, + #inputs + #commitment_param + ) -> #prove_output_ty { + #imports + + if !preprocessing.is_committed_mode() { + panic!( + "Committed bytecode proving requires committed preprocessing. \ + Use `preprocess_committed_*` / `JoltProverPreprocessing::new_committed`." + ); + } + + let mut input_bytes = vec![]; + #(#set_program_args;)* + let mut untrusted_advice_bytes = vec![]; + #(#set_program_untrusted_advice_args;)* + let mut trusted_advice_bytes = vec![]; + #(#set_program_trusted_advice_args;)* + + let elf_contents_opt = program.get_elf_contents(); + let elf_contents = elf_contents_opt.as_deref().expect("elf contents is None"); + let prover = RV64IMACProver::gen_from_elf_with_program_mode(&preprocessing, + &elf_contents, + &input_bytes, + &untrusted_advice_bytes, + &trusted_advice_bytes, + #commitment_arg, + jolt::ProgramMode::Committed, ); let io_device = prover.program_io.clone(); let (jolt_proof, _) = prover.prove(); @@ -890,6 +1102,8 @@ impl MacroBuilder { RV64IMACVerifier, RV64IMACProof, host::Program, + host::analyze::ProgramSummary, + ProgramPreprocessing, JoltProverPreprocessing, MemoryConfig, MemoryLayout, @@ -899,6 +1113,10 @@ impl MacroBuilder { JoltVerifierPreprocessing, JoltSharedPreprocessing }; + #[cfg(not(feature = "guest"))] + use std::sync::Arc; + #[cfg(not(feature = "guest"))] + use std::path::PathBuf; } } diff --git a/jolt-sdk/src/host_utils.rs b/jolt-sdk/src/host_utils.rs index af6c8192a6..6ced8c1afb 100644 --- a/jolt-sdk/src/host_utils.rs +++ b/jolt-sdk/src/host_utils.rs @@ -10,6 +10,8 @@ pub use jolt_core::ark_bn254::Fr as F; pub use jolt_core::field::JoltField; pub use jolt_core::guest; pub use jolt_core::poly::commitment::dory::DoryCommitmentScheme as PCS; +pub use jolt_core::zkvm::config::ProgramMode; +pub use jolt_core::zkvm::program::ProgramPreprocessing; pub use jolt_core::zkvm::{ proof_serialization::JoltProof, verifier::JoltSharedPreprocessing, verifier::JoltVerifierPreprocessing, RV64IMACProof, RV64IMACVerifier, Serializable, diff --git a/src/main.rs b/src/main.rs index 771806164e..84f4aded53 100644 --- a/src/main.rs +++ b/src/main.rs @@ -222,12 +222,9 @@ pub fn main() { let target_dir = "/tmp/jolt-guest-targets"; let mut program = guest::compile_fib(target_dir); - let shared_preprocessing = guest::preprocess_shared_fib(&mut program); - - let prover_preprocessing = guest::preprocess_prover_fib(shared_preprocessing.clone()); - let verifier_setup = prover_preprocessing.generators.to_verifier_setup(); + let prover_preprocessing = guest::preprocess_fib(&mut program); let verifier_preprocessing = - guest::preprocess_verifier_fib(shared_preprocessing, verifier_setup); + guest::verifier_preprocessing_from_prover_fib(&prover_preprocessing); let prove_fib = guest::build_prover_fib(program, prover_preprocessing); let verify_fib = guest::build_verifier_fib(verifier_preprocessing);