From 9c5fef1e364cd09d58eeabfc9f58ed9aa51d7e6e Mon Sep 17 00:00:00 2001 From: Quang Dao Date: Mon, 2 Mar 2026 13:31:34 -0800 Subject: [PATCH 1/3] Move multiply product check into Spartan outer sumcheck with optimized extended-integer arithmetic MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Remove the `Product` virtual polynomial and `RightLookupEqProductIfMul` R1CS constraint from the product-virtualization stage. Instead, prove the multiply instruction identity `is_mul * (rl - li*ri) = 0` directly in the Spartan outer sumcheck as an additive degree-3 (inside eq) term. Key changes: - Outer sumcheck now proves: eq(τ,x) * [Az·Bz + is_mul*(rl - li*ri)] = 0 - First product round uses fused materialization + eval: computes t(2) and t(∞) from trace using S64/S128/S160/S192 extended-integer arithmetic and FMAdd/FMAdd deferred-reduction accumulation, eliminating all field conversions (F::from_i128, F::from_u128) from the inner loop - Subsequent rounds use compact MultilinearPolynomial storage (bool/u64/u128/i128) that binds to field elements on challenge ingestion - Rebalance R1CS groups: move RamAddrEqZeroIfNotLoadStore to second group - Merge old Stage 3 into Stage 2; reorder instances (RAM first, Spartan second) - Add gruen_poly_deg_4 for degree-4 round polynomials - Add FMAdd<F, u64> for Acc7S and FMAdd<i32, u64> for S192Sum Made-with: Cursor --- jolt-core/src/poly/split_eq_poly.rs | 59 +++ .../src/subprotocols/streaming_sumcheck.rs | 8 +- jolt-core/src/utils/accumulation.rs | 22 + .../src/zkvm/bytecode/read_raf_checking.rs | 131 +++--- jolt-core/src/zkvm/claim_reductions/advice.rs | 16 +- .../zkvm/claim_reductions/hamming_weight.rs | 26 +- .../src/zkvm/claim_reductions/increments.rs | 60 +-- .../claim_reductions/instruction_lookups.rs | 144 +------ jolt-core/src/zkvm/proof_serialization.rs | 122 +++--- jolt-core/src/zkvm/prover.rs | 244 +++++------ jolt-core/src/zkvm/r1cs/constraints.rs | 27 +- jolt-core/src/zkvm/r1cs/evaluation.rs | 332 
++++++--------- jolt-core/src/zkvm/r1cs/inputs.rs | 132 ++---- jolt-core/src/zkvm/ram/ra_virtual.rs | 6 +- .../src/zkvm/spartan/instruction_input.rs | 85 ++-- jolt-core/src/zkvm/spartan/outer.rs | 397 ++++++++++++++++-- jolt-core/src/zkvm/spartan/product.rs | 88 ++-- jolt-core/src/zkvm/verifier.rs | 122 +++--- jolt-core/src/zkvm/witness.rs | 1 - 19 files changed, 1031 insertions(+), 991 deletions(-) diff --git a/jolt-core/src/poly/split_eq_poly.rs b/jolt-core/src/poly/split_eq_poly.rs index fb3d22af71..ddea34ca25 100644 --- a/jolt-core/src/poly/split_eq_poly.rs +++ b/jolt-core/src/poly/split_eq_poly.rs @@ -410,6 +410,65 @@ impl GruenSplitEqPolynomial { ]) } + /// Compute the quartic polynomial s(X) = l(X) * q(X), where l(X) is the + /// current (linear) eq polynomial and q(X) = c + dX + eX^2 + fX^3, given: + /// - c, the constant term of q (i.e. q(0)) + /// - q(2), the evaluation of q at 2 + /// - f, the leading (cubic) coefficient of q + /// - the previous round claim, s(0) + s(1) + pub fn gruen_poly_deg_4( + &self, + q_constant: F, + q_at_2: F, + q_cubic_coeff: F, + s_0_plus_s_1: F, + ) -> UniPoly { + // l(X) = a + bX (linear eq polynomial) + // q(X) = c + dX + eX^2 + fX^3 (cubic inner polynomial) + // s(X) = l(X)*q(X) is degree 4, needs evaluations at {0, 1, 2, 3, 4}. 
+ + let eq_eval_1 = self.current_scalar + * match self.binding_order { + BindingOrder::LowToHigh => self.w[self.current_index - 1], + BindingOrder::HighToLow => self.w[self.current_index], + }; + let eq_eval_0 = self.current_scalar - eq_eval_1; + let eq_m = eq_eval_1 - eq_eval_0; + let eq_eval_2 = eq_eval_1 + eq_m; + let eq_eval_3 = eq_eval_2 + eq_m; + let eq_eval_4 = eq_eval_3 + eq_m; + + let c = q_constant; + let f = q_cubic_coeff; + + // Recover q(1) from the sumcheck identity: s(0) + s(1) = claim + let quartic_eval_0 = eq_eval_0 * c; + let quartic_eval_1 = s_0_plus_s_1 - quartic_eval_0; + let q_1 = quartic_eval_1 / eq_eval_1; + + // q(0) = c, q(1) = c+d+e+f, q(2) = c+2d+4e+8f + // Forward differences: Δ0 = q(1)-q(0) = d+e+f, Δ1 = q(2)-q(1) + // Second differences: Δ²0 = Δ1-Δ0 = 2e+6f + // Third differences: Δ³ = 6f (constant for a cubic) + let delta_0 = q_1 - c; + let delta_1 = q_at_2 - q_1; + let delta2_0 = delta_1 - delta_0; + let f6 = f + f + f + f + f + f; // 6f + // q(3) = q(2) + Δ1 + Δ²0 + Δ³ = q(2) + (Δ1 + Δ²0 + 6f) + let delta_2 = delta_1 + delta2_0 + f6; + let q_3 = q_at_2 + delta_2; + // q(4) = q(3) + Δ2 + Δ²0 + 2*Δ³ = q(3) + (delta_2 + delta2_0 + 6f + 6f) + let q_4 = q_3 + delta_2 + delta2_0 + f6 + f6; + + UniPoly::from_evals(&[ + quartic_eval_0, + quartic_eval_1, + eq_eval_2 * q_at_2, + eq_eval_3 * q_3, + eq_eval_4 * q_4, + ]) + } + /// Compute the quadratic polynomial s(X) = l(X) * q(X), where l(X) is the /// current (linear) Dao-Thaler eq polynomial and q(X) = c + dx /// - c, the constant term of q diff --git a/jolt-core/src/subprotocols/streaming_sumcheck.rs b/jolt-core/src/subprotocols/streaming_sumcheck.rs index 660520def8..927bf84cef 100644 --- a/jolt-core/src/subprotocols/streaming_sumcheck.rs +++ b/jolt-core/src/subprotocols/streaming_sumcheck.rs @@ -15,7 +15,7 @@ pub trait StreamingSumcheckWindow: Sized + MaybeAllocative + Send fn initialize(shared: &mut Self::Shared, window_size: usize) -> Self; fn compute_message( - &self, + &mut self, 
shared: &Self::Shared, window_size: usize, previous_claim: F, @@ -37,7 +37,7 @@ pub trait LinearSumcheckStage: Sized + MaybeAllocative + Send + Sy fn next_window(&mut self, shared: &mut Self::Shared, window_size: usize); fn compute_message( - &self, + &mut self, shared: &Self::Shared, window_size: usize, previous_claim: F, @@ -155,9 +155,9 @@ where } } - if let Some(streaming) = &mut self.streaming { + if let Some(streaming) = self.streaming.as_mut() { streaming.compute_message(&self.shared, num_unbound_vars, previous_claim) - } else if let Some(linear) = &mut self.linear { + } else if let Some(linear) = self.linear.as_mut() { linear.compute_message(&self.shared, num_unbound_vars, previous_claim) } else { unreachable!() diff --git a/jolt-core/src/utils/accumulation.rs b/jolt-core/src/utils/accumulation.rs index 97cddbcc73..0b333a5be5 100644 --- a/jolt-core/src/utils/accumulation.rs +++ b/jolt-core/src/utils/accumulation.rs @@ -569,6 +569,16 @@ impl FMAdd for Acc7S { } } +impl FMAdd for Acc7S { + #[inline(always)] + fn fmadd(&mut self, field: &F, other: &u64) { + if *other == 0 { + return; + } + self.pos += (*field).mul_u64_unreduced(*other); + } +} + impl FMAdd for Acc7S { #[inline(always)] fn fmadd(&mut self, field: &F, other: &S64) { @@ -913,6 +923,18 @@ pub struct S192Sum { } // Accumulate c (i32) * term (S64) into an S192 running sum +impl FMAdd for S192Sum { + #[inline(always)] + fn fmadd(&mut self, c: &i32, term: &u64) { + if *term == 0 { + return; + } + let c_s64 = S64::from(*c as i64); + let v_s64 = S64::from(*term); + self.sum += c_s64.mul_trunc::<1, 3>(&v_s64); + } +} + impl FMAdd for S192Sum { #[inline(always)] fn fmadd(&mut self, c: &i32, term: &S64) { diff --git a/jolt-core/src/zkvm/bytecode/read_raf_checking.rs b/jolt-core/src/zkvm/bytecode/read_raf_checking.rs index 5acb686b0b..4e3523543e 100644 --- a/jolt-core/src/zkvm/bytecode/read_raf_checking.rs +++ b/jolt-core/src/zkvm/bytecode/read_raf_checking.rs @@ -51,7 +51,8 @@ const N_STAGES: usize = 5; 
/// Bytecode instruction: multi-stage Read + RAF sumcheck (N_STAGES = 5). /// /// Stages virtualize different claim families (Stage1: Spartan outer; Stage2: product-virtualized -/// flags; Stage3: Shift; Stage4: Registers RW; Stage5: Registers val-eval + Instruction lookups). +/// flags + instruction input flags; Stage3: Shift flags only; Stage4: Registers RW; +/// Stage5: Registers val-eval + Instruction lookups). /// /// The input claim is a γ-weighted RLC of stage rv_claims plus RAF contributions folded into /// stages 1 and 3 via the identity polynomial. Address vars are bound in `d` chunks; cycle vars @@ -72,11 +73,11 @@ const N_STAGES: usize = 5; /// - Int(k) = 1 for all k (evaluation of the IdentityPolynomial over address variables). /// - Define per-stage Val_s(k) (address-only) as implemented by `compute_val_*`: /// * Stage1: Val_1(k) = unexpanded_pc(k) + β_1·imm(k) + Σ_t β_1^{2+t}·circuit_flag_t(k). -/// * Stage2: Val_2(k) = 1_{jump}(k) + β_2·1_{branch}(k) + β_2^2·rd_addr(k) + β_2^3·1_{write_lookup_to_rd}(k) -/// + β_2^4·1_{VirtualInstruction}(k). -/// * Stage3: Val_3(k) = imm(k) + β_3·unexpanded_pc(k) + β_3^2·1_{L_is_rs1}(k) + β_3^3·1_{L_is_pc}(k) -/// + β_3^4·1_{R_is_rs2}(k) + β_3^5·1_{R_is_imm}(k) + β_3^6·1_{IsNoop}(k) -/// + β_3^7·1_{VirtualInstruction}(k) + β_3^8·1_{IsFirstInSequence}(k). +/// * Stage2: Val_2(k) = 1_{jump}(k) + β_2·1_{branch}(k) + β_2^2·1_{rd≠0}(k) + β_2^3·1_{write_lookup_to_rd}(k) +/// + β_2^4·1_{VirtualInstruction}(k) + β_2^5·imm(k) + β_2^6·unexpanded_pc(k) +/// + β_2^7·1_{L_is_rs1}(k) + β_2^8·1_{L_is_pc}(k) + β_2^9·1_{R_is_rs2}(k) + β_2^10·1_{R_is_imm}(k). +/// * Stage3: Val_3(k) = unexpanded_pc(k) + β_3·1_{IsNoop}(k) + β_3^2·1_{VirtualInstruction}(k) +/// + β_3^3·1_{IsFirstInSequence}(k). /// * Stage4: Val_4(k) = 1_{rd=r}(k) + β_4·1_{rs1=r}(k) + β_4^2·1_{rs2=r}(k), where r is fixed by opening. /// * Stage5: Val_5(k) = 1_{rd=r}(k) + β_5·1_{¬interleaved}(k) + Σ_i β_5^{2+i}·1_{table=i}(k). 
/// @@ -737,8 +738,8 @@ impl BytecodeReadRafSumcheckParams { // Generate all stage-specific gamma powers upfront (order must match verifier) let stage1_gammas: Vec = transcript.challenge_scalar_powers(2 + NUM_CIRCUIT_FLAGS); - let stage2_gammas: Vec = transcript.challenge_scalar_powers(5); - let stage3_gammas: Vec = transcript.challenge_scalar_powers(9); + let stage2_gammas: Vec = transcript.challenge_scalar_powers(11); + let stage3_gammas: Vec = transcript.challenge_scalar_powers(4); let stage4_gammas: Vec = transcript.challenge_scalar_powers(3); let stage5_gammas: Vec = transcript.challenge_scalar_powers(2 + NUM_LOOKUP_TABLES); @@ -903,15 +904,7 @@ impl BytecodeReadRafSumcheckParams { *o0 = lc; } - // Stage 2 (product virtualization, de-duplicated factors) - // Val(k) = jump_flag(k) + γ·branch_flag(k) - // + γ²·is_rd_not_zero_flag(k) + γ³·write_lookup_output_to_rd_flag(k) - // where jump_flag(k) = 1 if instruction k is a jump, 0 otherwise; - // branch_flag(k) = 1 if instruction k is a branch, 0 otherwise; - // is_rd_not_zero_flag(k) = 1 if instruction k has rd != 0; - // write_lookup_output_to_rd_flag(k) = 1 if instruction k writes lookup output to rd. - // virtual_instruction(k) = 1 if instruction k is a virtual instruction. - // This Val matches the fused product sumcheck. + // Stage 2 (product virtualization + instruction input flags) { let mut lc = F::zero(); if circuit_flags[CircuitFlags::Jump] { @@ -929,38 +922,34 @@ impl BytecodeReadRafSumcheckParams { if circuit_flags[CircuitFlags::VirtualInstruction] { lc += stage2_gammas[4]; } - *o1 = lc; - } - - // Stage 3 (Shift sumcheck) - // Val(k) = imm(k) + γ·unexpanded_pc(k) - // + γ²·left_operand_is_rs1_value(k) + γ³·left_operand_is_pc(k) - // + γ⁴·right_operand_is_rs2_value(k) + γ⁵·right_operand_is_imm(k) - // + γ⁶·is_noop(k) + γ⁷·virtual_instruction(k) + γ⁸·is_first_in_sequence(k) - // This virtualizes claims output by the ShiftSumcheck. 
- { - let mut lc = F::from_i128(instr.operands.imm); - lc += stage3_gammas[1].mul_u64(instr.address as u64); + lc += instr.operands.imm.field_mul(stage2_gammas[5]); + lc += stage2_gammas[6].mul_u64(instr.address as u64); if instr_flags[InstructionFlags::LeftOperandIsRs1Value] { - lc += stage3_gammas[2]; + lc += stage2_gammas[7]; } if instr_flags[InstructionFlags::LeftOperandIsPC] { - lc += stage3_gammas[3]; + lc += stage2_gammas[8]; } if instr_flags[InstructionFlags::RightOperandIsRs2Value] { - lc += stage3_gammas[4]; + lc += stage2_gammas[9]; } if instr_flags[InstructionFlags::RightOperandIsImm] { - lc += stage3_gammas[5]; + lc += stage2_gammas[10]; } + *o1 = lc; + } + + // Stage 3 (Shift flags only — InstructionInput flags moved to Stage 2) + { + let mut lc = F::from_u64(instr.address as u64); if instr_flags[InstructionFlags::IsNoop] { - lc += stage3_gammas[6]; + lc += stage3_gammas[1]; } if circuit_flags[CircuitFlags::VirtualInstruction] { - lc += stage3_gammas[7]; + lc += stage3_gammas[2]; } if circuit_flags[CircuitFlags::IsFirstInSequence] { - lc += stage3_gammas[8]; + lc += stage3_gammas[3]; } *o2 = lc; } @@ -1069,45 +1058,14 @@ impl BytecodeReadRafSumcheckParams { VirtualPolynomial::OpFlags(CircuitFlags::VirtualInstruction), SumcheckId::SpartanProductVirtualization, ); - - [ - jump_claim, - branch_claim, - rd_wa_claim, - write_lookup_output_to_rd_flag_claim, - virtual_instruction_claim, - ] - .into_iter() - .zip_eq(gamma_powers) - .map(|(claim, gamma)| claim * gamma) - .sum() - } - - fn compute_rv_claim_3( - opening_accumulator: &dyn OpeningAccumulator, - gamma_powers: &[F], - ) -> F { let (_, imm_claim) = opening_accumulator.get_virtual_polynomial_opening( VirtualPolynomial::Imm, SumcheckId::InstructionInputVirtualization, ); - let (_, spartan_shift_unexpanded_pc_claim) = opening_accumulator - .get_virtual_polynomial_opening( - VirtualPolynomial::UnexpandedPC, - SumcheckId::SpartanShift, - ); - let (_, instruction_input_unexpanded_pc_claim) = 
opening_accumulator - .get_virtual_polynomial_opening( - VirtualPolynomial::UnexpandedPC, - SumcheckId::InstructionInputVirtualization, - ); - - assert_eq!( - spartan_shift_unexpanded_pc_claim, - instruction_input_unexpanded_pc_claim + let (_, unexpanded_pc_claim) = opening_accumulator.get_virtual_polynomial_opening( + VirtualPolynomial::UnexpandedPC, + SumcheckId::InstructionInputVirtualization, ); - - let unexpanded_pc_claim = spartan_shift_unexpanded_pc_claim; let (_, left_is_rs1_claim) = opening_accumulator.get_virtual_polynomial_opening( VirtualPolynomial::InstructionFlags(InstructionFlags::LeftOperandIsRs1Value), SumcheckId::InstructionInputVirtualization, @@ -1124,6 +1082,34 @@ impl BytecodeReadRafSumcheckParams { VirtualPolynomial::InstructionFlags(InstructionFlags::RightOperandIsImm), SumcheckId::InstructionInputVirtualization, ); + + [ + jump_claim, + branch_claim, + rd_wa_claim, + write_lookup_output_to_rd_flag_claim, + virtual_instruction_claim, + imm_claim, + unexpanded_pc_claim, + left_is_rs1_claim, + left_is_pc_claim, + right_is_rs2_claim, + right_is_imm_claim, + ] + .into_iter() + .zip_eq(gamma_powers) + .map(|(claim, gamma)| claim * gamma) + .sum() + } + + fn compute_rv_claim_3( + opening_accumulator: &dyn OpeningAccumulator, + gamma_powers: &[F], + ) -> F { + let (_, unexpanded_pc_claim) = opening_accumulator.get_virtual_polynomial_opening( + VirtualPolynomial::UnexpandedPC, + SumcheckId::SpartanShift, + ); let (_, is_noop_claim) = opening_accumulator.get_virtual_polynomial_opening( VirtualPolynomial::InstructionFlags(InstructionFlags::IsNoop), SumcheckId::SpartanShift, @@ -1138,12 +1124,7 @@ impl BytecodeReadRafSumcheckParams { ); [ - imm_claim, unexpanded_pc_claim, - left_is_rs1_claim, - left_is_pc_claim, - right_is_rs2_claim, - right_is_imm_claim, is_noop_claim, is_virtual_claim, is_first_in_sequence_claim, diff --git a/jolt-core/src/zkvm/claim_reductions/advice.rs b/jolt-core/src/zkvm/claim_reductions/advice.rs index 09c1caa0e1..76553fceeb 
100644 --- a/jolt-core/src/zkvm/claim_reductions/advice.rs +++ b/jolt-core/src/zkvm/claim_reductions/advice.rs @@ -1,8 +1,8 @@ -//! Two-phase advice claim reduction (Stage 6 cycle → Stage 7 address) +//! Two-phase advice claim reduction (Stage 5 cycle → Stage 6 address) //! //! This module generalizes the previous single-phase `AdviceClaimReduction` so that trusted and //! untrusted advice can be committed as an arbitrary Dory matrix `2^{nu_a} x 2^{sigma_a}` (balanced -//! by default), while still keeping a **single Stage 8 Dory opening** at the unified Dory point. +//! by default), while still keeping a **single Stage 7 Dory opening** at the unified Dory point. //! //! For an advice matrix embedded as the **top-left block** `2^{nu_a} x 2^{sigma_a}`, the *native* //! advice evaluation point (in Dory order, LSB-first) is: @@ -10,14 +10,14 @@ //! - `advice_rows = row_coords[0..nu_a]` //! - `advice_point = [advice_cols || advice_rows]` //! -//! In our current pipeline, `cycle` coordinates come from Stage 6 and `addr` coordinates come from -//! Stage 7. -//! - **Phase 1 (Stage 6)**: bind the cycle-derived advice coordinates and output an intermediate +//! In our current pipeline, `cycle` coordinates come from Stage 5 and `addr` coordinates come from +//! Stage 6. +//! - **Phase 1 (Stage 5)**: bind the cycle-derived advice coordinates and output an intermediate //! scalar claim `C_mid`. -//! - **Phase 2 (Stage 7)**: resume from `C_mid`, bind the address-derived advice coordinates, and -//! cache the final advice opening `AdviceMLE(advice_point)` for batching into Stage 8. +//! - **Phase 2 (Stage 6)**: resume from `C_mid`, bind the address-derived advice coordinates, and +//! cache the final advice opening `AdviceMLE(advice_point)` for batching into Stage 7. //! -//! ## Dummy-gap scaling (within Stage 6) +//! ## Dummy-gap scaling (within Stage 5) //! With cycle-major order, there may be a gap during the cycle phase where the cycle variables //! 
being bound in the batched sumcheck do not appear in the advice polynommial. //! diff --git a/jolt-core/src/zkvm/claim_reductions/hamming_weight.rs b/jolt-core/src/zkvm/claim_reductions/hamming_weight.rs index d40860f35a..bfce46998b 100644 --- a/jolt-core/src/zkvm/claim_reductions/hamming_weight.rs +++ b/jolt-core/src/zkvm/claim_reductions/hamming_weight.rs @@ -11,17 +11,17 @@ //! //! ## Background //! -//! After Stage 6, each ra_i one-hot polynomial has TWO claims at different address points -//! but the SAME cycle point (r_cycle_stage6): +//! After Stage 5, each ra_i one-hot polynomial has TWO claims at different address points +//! but the SAME cycle point (r_cycle_stage5): //! -//! 1. **Booleanity claim**: `ra_i(r_addr_bool, r_cycle_stage6)` -//! - From `BooleanitySumcheck` in Stage 6 +//! 1. **Booleanity claim**: `ra_i(r_addr_bool, r_cycle_stage5)` +//! - From `BooleanitySumcheck` in Stage 5 //! - r_addr_bool is shared across all ra_i and across families (instruction/bytecode/ram) //! -//! 2. **Virtualization claim**: `ra_i(r_addr_virt_i, r_cycle_stage6)` -//! - For BytecodeRa: from `BytecodeReadRaf` in Stage 6 -//! - For InstructionRa: from `InstructionRaVirtualization` in Stage 6 -//! - For RamRa: from `RamRaVirtualization` in Stage 6 +//! 2. **Virtualization claim**: `ra_i(r_addr_virt_i, r_cycle_stage5)` +//! - For BytecodeRa: from `BytecodeReadRaf` in Stage 5 +//! - For InstructionRa: from `InstructionRaVirtualization` in Stage 5 +//! - For RamRa: from `RamRaVirtualization` in Stage 5 //! - r_addr_virt_i is DIFFERENT per ra_i (each chunk has its own r_address) //! //! The HammingWeight sumcheck would normally run separately, producing its own @@ -66,9 +66,9 @@ //! //! ## After This Sumcheck //! -//! Let ρ be the challenges from this sumcheck (r_address_stage7). Each ra_i has a SINGLE opening: +//! Let ρ be the challenges from this sumcheck (r_address_stage6). Each ra_i has a SINGLE opening: //! -//! `ra_i(ρ, r_cycle_stage6)` +//! 
`ra_i(ρ, r_cycle_stage5)` //! //! The verifier computes expected claims using the single opening G_i(ρ): //! - HammingWeight: G_i(ρ) @@ -115,7 +115,7 @@ const DEGREE_BOUND: usize = 2; /// - Booleanity: proves Σ_k eq(r_addr_bool, k)·G_i(k) = claim_bool_i /// - Virtualization: proves Σ_k eq(r_addr_virt_i, k)·G_i(k) = claim_virt_i /// -/// After this sumcheck, each ra_i has a single opening at (ρ, r_cycle_stage6). +/// After this sumcheck, each ra_i has a single opening at (ρ, r_cycle_stage5). #[derive(Allocative, Clone)] pub struct HammingWeightClaimReductionParams { /// γ^0, γ^1, ..., γ^{3N-1} for batching (3 claims per ra polynomial) @@ -141,7 +141,7 @@ pub struct HammingWeightClaimReductionParams { } impl HammingWeightClaimReductionParams { - /// Create params by fetching claims from Stage 6 and sampling batching challenge. + /// Create params by fetching claims from Stage 5 and sampling batching challenge. /// /// Fetches: /// - HammingWeight claims (from HammingBooleanity virtual polynomial) @@ -201,7 +201,7 @@ impl HammingWeightClaimReductionParams { let mut claims_bool = Vec::with_capacity(N); let mut claims_virt = Vec::with_capacity(N); - // RAM HammingWeight factor: now in Stage 6, so shares r_cycle_stage6 + // RAM HammingWeight factor: now in Stage 5, so shares r_cycle_stage5 let ram_hw_factor = accumulator .get_virtual_polynomial_opening( VirtualPolynomial::RamHammingWeight, diff --git a/jolt-core/src/zkvm/claim_reductions/increments.rs b/jolt-core/src/zkvm/claim_reductions/increments.rs index 1f3d0ec950..a8bb476c43 100644 --- a/jolt-core/src/zkvm/claim_reductions/increments.rs +++ b/jolt-core/src/zkvm/claim_reductions/increments.rs @@ -16,8 +16,8 @@ //! So effectively RamInc has **2 distinct opening points**. //! //! 2. **RdInc**: Claims are emitted from: -//! - `RegistersReadWriteChecking` (Stage 4): opened at `s_cycle_stage4` -//! - `RegistersValEvaluation` (Stage 5): opened at `s_cycle_stage5` +//! 
- `RegistersReadWriteChecking` (Stage 3): opened at `s_cycle_stage3` +//! - `RegistersValEvaluation` (Stage 4): opened at `s_cycle_stage4` //! //! So RdInc has **2 distinct opening points**. //! @@ -26,20 +26,20 @@ //! Let: //! - v_1 = RamInc(r_cycle_stage2) from RamReadWriteChecking //! - v_2 = RamInc(r_cycle_stage4) from RamValCheck -//! - w_1 = RdInc(s_cycle_stage4) from RegistersReadWriteChecking -//! - w_2 = RdInc(s_cycle_stage5) from RegistersValEvaluation +//! - w_1 = RdInc(s_cycle_stage3) from RegistersReadWriteChecking +//! - w_2 = RdInc(s_cycle_stage4) from RegistersValEvaluation //! //! Input claim: //! v_1 + γ·v_2 + γ²·w_1 + γ³·w_2 //! //! Sumcheck proves (over log T rounds): //! Σ_j RamInc(j) · [eq(r_cycle_stage2, j) + γ·eq(r_cycle_stage4, j)] -//! + γ² · Σ_j RdInc(j) · [eq(s_cycle_stage4, j) + γ·eq(s_cycle_stage5, j)] +//! + γ² · Σ_j RdInc(j) · [eq(s_cycle_stage3, j) + γ·eq(s_cycle_stage4, j)] //! = v_1 + γ·v_2 + γ²·w_1 + γ³·w_2 //! //! After log T rounds with sumcheck challenges ρ, the final claim is: //! RamInc(ρ) · [eq(r_cycle_stage2, ρ) + γ·eq(r_cycle_stage4, ρ)] -//! + γ² · RdInc(ρ) · [eq(s_cycle_stage4, ρ) + γ·eq(s_cycle_stage5, ρ)] +//! + γ² · RdInc(ρ) · [eq(s_cycle_stage3, ρ) + γ·eq(s_cycle_stage4, ρ)] //! //! The verifier computes the eq terms and recovers two openings at the SAME point ρ: //! 
- RamInc(ρ) @@ -78,8 +78,8 @@ pub struct IncClaimReductionSumcheckParams { pub n_cycle_vars: usize, pub r_cycle_stage2: OpeningPoint, // RamInc from RamReadWriteChecking pub r_cycle_stage4: OpeningPoint, // RamInc from RamValCheck - pub s_cycle_stage4: OpeningPoint, // RdInc from RegistersReadWriteChecking - pub s_cycle_stage5: OpeningPoint, // RdInc from RegistersValEvaluation + pub s_cycle_stage3: OpeningPoint, // RdInc from RegistersReadWriteChecking + pub s_cycle_stage4: OpeningPoint, // RdInc from RegistersValEvaluation } impl IncClaimReductionSumcheckParams { @@ -100,11 +100,11 @@ impl IncClaimReductionSumcheckParams { let (r_cycle_stage4, _) = accumulator .get_committed_polynomial_opening(CommittedPolynomial::RamInc, SumcheckId::RamValCheck); - let (s_cycle_stage4, _) = accumulator.get_committed_polynomial_opening( + let (s_cycle_stage3, _) = accumulator.get_committed_polynomial_opening( CommittedPolynomial::RdInc, SumcheckId::RegistersReadWriteChecking, ); - let (s_cycle_stage5, _) = accumulator.get_committed_polynomial_opening( + let (s_cycle_stage4, _) = accumulator.get_committed_polynomial_opening( CommittedPolynomial::RdInc, SumcheckId::RegistersValEvaluation, ); @@ -114,8 +114,8 @@ impl IncClaimReductionSumcheckParams { n_cycle_vars: trace_len.log_2(), r_cycle_stage2, r_cycle_stage4, + s_cycle_stage3, s_cycle_stage4, - s_cycle_stage5, } } } @@ -266,8 +266,8 @@ struct IncClaimReductionPhase1State { // P_ram[0] = eq(r_cycle_stage2_lo, ·) // P_ram[1] = eq(r_cycle_stage4_lo, ·) P_ram: [MultilinearPolynomial; 2], - // P_rd[0] = eq(s_cycle_stage4_lo, ·) - // P_rd[1] = eq(s_cycle_stage5_lo, ·) + // P_rd[0] = eq(s_cycle_stage3_lo, ·) + // P_rd[1] = eq(s_cycle_stage4_lo, ·) P_rd: [MultilinearPolynomial; 2], // Q buffers: suffix-weighted polynomial evaluations @@ -294,20 +294,20 @@ impl IncClaimReductionPhase1State { // Big-endian: hi is first half, lo is second half let (r2_hi, r2_lo) = params.r_cycle_stage2.split_at(suffix_n_vars); let (r4_hi, r4_lo) = 
params.r_cycle_stage4.split_at(suffix_n_vars); + let (s3_hi, s3_lo) = params.s_cycle_stage3.split_at(suffix_n_vars); let (s4_hi, s4_lo) = params.s_cycle_stage4.split_at(suffix_n_vars); - let (s5_hi, s5_lo) = params.s_cycle_stage5.split_at(suffix_n_vars); // P buffers: prefix eq evaluations let P_ram_0 = EqPolynomial::evals(&r2_lo.r); let P_ram_1 = EqPolynomial::evals(&r4_lo.r); - let P_rd_0 = EqPolynomial::evals(&s4_lo.r); - let P_rd_1 = EqPolynomial::evals(&s5_lo.r); + let P_rd_0 = EqPolynomial::evals(&s3_lo.r); + let P_rd_1 = EqPolynomial::evals(&s4_lo.r); // Suffix eq evaluations (for computing Q) let eq_r2_hi = EqPolynomial::evals(&r2_hi.r); let eq_r4_hi = EqPolynomial::evals(&r4_hi.r); + let eq_s3_hi = EqPolynomial::evals(&s3_hi.r); let eq_s4_hi = EqPolynomial::evals(&s4_hi.r); - let eq_s5_hi = EqPolynomial::evals(&s5_hi.r); // Q buffers: sum over suffix indices let mut Q_ram_0 = unsafe_allocate_zero_vec(prefix_len); @@ -353,8 +353,8 @@ impl IncClaimReductionPhase1State { acc_ram_0.fmadd(&eq_r2_hi[x_hi], &ram_inc); acc_ram_1.fmadd(&eq_r4_hi[x_hi], &ram_inc); - acc_rd_0.fmadd(&eq_s4_hi[x_hi], &rd_inc); - acc_rd_1.fmadd(&eq_s5_hi[x_hi], &rd_inc); + acc_rd_0.fmadd(&eq_s3_hi[x_hi], &rd_inc); + acc_rd_1.fmadd(&eq_s4_hi[x_hi], &rd_inc); } q_ram_0[i] = acc_ram_0.barrett_reduce(); @@ -455,7 +455,7 @@ struct IncClaimReductionPhase2State { rd_inc: MultilinearPolynomial, // Combined eq polynomials eq_ram: MultilinearPolynomial, // eq(r_stage2, ·) + γ·eq(r_stage4, ·) - eq_rd: MultilinearPolynomial, // eq(s_stage4, ·) + γ·eq(s_stage5, ·) + eq_rd: MultilinearPolynomial, // eq(s_stage3, ·) + γ·eq(s_stage4, ·) } impl IncClaimReductionPhase2State { @@ -477,21 +477,21 @@ impl IncClaimReductionPhase2State { // Compute eq evaluations for prefix bound let (_, r2_lo) = params.r_cycle_stage2.split_at(n_vars - prefix_n_vars); let (_, r4_lo) = params.r_cycle_stage4.split_at(n_vars - prefix_n_vars); + let (_, s3_lo) = params.s_cycle_stage3.split_at(n_vars - prefix_n_vars); let (_, 
s4_lo) = params.s_cycle_stage4.split_at(n_vars - prefix_n_vars); - let (_, s5_lo) = params.s_cycle_stage5.split_at(n_vars - prefix_n_vars); let eq_r2_prefix = EqPolynomial::mle_endian(&r_prefix, &r2_lo); let eq_r4_prefix = EqPolynomial::mle_endian(&r_prefix, &r4_lo); + let eq_s3_prefix = EqPolynomial::mle_endian(&r_prefix, &s3_lo); let eq_s4_prefix = EqPolynomial::mle_endian(&r_prefix, &s4_lo); - let eq_s5_prefix = EqPolynomial::mle_endian(&r_prefix, &s5_lo); // Suffix eq evaluations scaled by prefix contributions let (r2_hi, _) = params.r_cycle_stage2.split_at(n_vars - prefix_n_vars); let (r4_hi, _) = params.r_cycle_stage4.split_at(n_vars - prefix_n_vars); + let (s3_hi, _) = params.s_cycle_stage3.split_at(n_vars - prefix_n_vars); let (s4_hi, _) = params.s_cycle_stage4.split_at(n_vars - prefix_n_vars); - let (s5_hi, _) = params.s_cycle_stage5.split_at(n_vars - prefix_n_vars); - // Combined eq polynomials: eq_ram = eq_r2 + γ·eq_r4, eq_rd = eq_s4 + γ·eq_s5 + // Combined eq polynomials: eq_ram = eq_r2 + γ·eq_r4, eq_rd = eq_s3 + γ·eq_s4 let (eq_ram, eq_rd) = rayon::join( || { let (eq_r2, eq_r4) = rayon::join( @@ -505,13 +505,13 @@ impl IncClaimReductionPhase2State { .collect::>() }, || { - let (eq_s4, eq_s5) = rayon::join( + let (eq_s3, eq_s4) = rayon::join( + || EqPolynomial::evals_serial(&s3_hi.r, Some(eq_s3_prefix)), || EqPolynomial::evals_serial(&s4_hi.r, Some(eq_s4_prefix)), - || EqPolynomial::evals_serial(&s5_hi.r, Some(eq_s5_prefix)), ); - eq_s4 + eq_s3 .par_iter() - .zip(eq_s5.par_iter()) + .zip(eq_s4.par_iter()) .map(|(e4, e5)| *e4 + gamma * e5) .collect::>() }, @@ -655,11 +655,11 @@ impl SumcheckInstanceVerifier // Compute eq evaluations at final point let eq_r2 = EqPolynomial::mle(&opening_point.r, &self.params.r_cycle_stage2.r); let eq_r4 = EqPolynomial::mle(&opening_point.r, &self.params.r_cycle_stage4.r); + let eq_s3 = EqPolynomial::mle(&opening_point.r, &self.params.s_cycle_stage3.r); let eq_s4 = EqPolynomial::mle(&opening_point.r, 
&self.params.s_cycle_stage4.r); - let eq_s5 = EqPolynomial::mle(&opening_point.r, &self.params.s_cycle_stage5.r); let eq_ram_combined = eq_r2 + gamma * eq_r4; - let eq_rd_combined = eq_s4 + gamma * eq_s5; + let eq_rd_combined = eq_s3 + gamma * eq_s4; // Fetch final claims from accumulator let (_, ram_inc_claim) = accumulator.get_committed_polynomial_opening( diff --git a/jolt-core/src/zkvm/claim_reductions/instruction_lookups.rs b/jolt-core/src/zkvm/claim_reductions/instruction_lookups.rs index d1b6f6f154..df04832d7e 100644 --- a/jolt-core/src/zkvm/claim_reductions/instruction_lookups.rs +++ b/jolt-core/src/zkvm/claim_reductions/instruction_lookups.rs @@ -1,10 +1,6 @@ use std::array; use std::sync::Arc; -use allocative::Allocative; -use ark_std::Zero; -use common::constants::XLEN; - use crate::field::JoltField; use crate::poly::eq_poly::EqPolynomial; use crate::poly::multilinear_polynomial::PolynomialBinding; @@ -21,6 +17,9 @@ use crate::utils::math::Math; use crate::utils::thread::unsafe_allocate_zero_vec; use crate::zkvm::instruction::LookupQuery; use crate::zkvm::witness::VirtualPolynomial; +use allocative::Allocative; +use ark_std::Zero; +use common::constants::XLEN; use rayon::prelude::*; use tracer::instruction::Cycle; @@ -31,8 +30,6 @@ const DEGREE_BOUND: usize = 2; pub struct InstructionLookupsClaimReductionSumcheckParams { pub gamma: F, pub gamma_sqr: F, - pub gamma_cub: F, - pub gamma_quart: F, pub n_cycle_vars: usize, pub r_spartan: OpeningPoint, } @@ -45,8 +42,6 @@ impl InstructionLookupsClaimReductionSumcheckParams { ) -> Self { let gamma = transcript.challenge_scalar::(); let gamma_sqr = gamma.square(); - let gamma_cub = gamma_sqr * gamma; - let gamma_quart = gamma_sqr.square(); let (r_spartan, _) = accumulator.get_virtual_polynomial_opening( VirtualPolynomial::LookupOutput, SumcheckId::SpartanOuter, @@ -54,8 +49,6 @@ impl InstructionLookupsClaimReductionSumcheckParams { Self { gamma, gamma_sqr, - gamma_cub, - gamma_quart, n_cycle_vars: 
trace_len.log_2(), r_spartan, } @@ -76,20 +69,8 @@ impl SumcheckInstanceParams for InstructionLookupsClaimReductio VirtualPolynomial::RightLookupOperand, SumcheckId::SpartanOuter, ); - let (_, left_instruction_input_claim) = accumulator.get_virtual_polynomial_opening( - VirtualPolynomial::LeftInstructionInput, - SumcheckId::SpartanOuter, - ); - let (_, right_instruction_input_claim) = accumulator.get_virtual_polynomial_opening( - VirtualPolynomial::RightInstructionInput, - SumcheckId::SpartanOuter, - ); - lookup_output_claim - + self.gamma * left_operand_claim - + self.gamma_sqr * right_operand_claim - + self.gamma_cub * left_instruction_input_claim - + self.gamma_quart * right_instruction_input_claim + lookup_output_claim + self.gamma * left_operand_claim + self.gamma_sqr * right_operand_claim } fn degree(&self) -> usize { @@ -197,9 +178,6 @@ impl SumcheckInstanceProver let lookup_output_claim = state.lookup_output_poly.final_sumcheck_claim(); let left_lookup_operand_claim = state.left_lookup_operand_poly.final_sumcheck_claim(); let right_lookup_operand_claim = state.right_lookup_operand_poly.final_sumcheck_claim(); - let left_instruction_input_claim = state.left_instruction_input_poly.final_sumcheck_claim(); - let right_instruction_input_claim = - state.right_instruction_input_poly.final_sumcheck_claim(); accumulator.append_virtual( transcript, @@ -219,22 +197,8 @@ impl SumcheckInstanceProver transcript, VirtualPolynomial::RightLookupOperand, SumcheckId::InstructionClaimReduction, - opening_point.clone(), - right_lookup_operand_claim, - ); - accumulator.append_virtual( - transcript, - VirtualPolynomial::LeftInstructionInput, - SumcheckId::InstructionClaimReduction, - opening_point.clone(), - left_instruction_input_claim, - ); - accumulator.append_virtual( - transcript, - VirtualPolynomial::RightInstructionInput, - SumcheckId::InstructionClaimReduction, opening_point, - right_instruction_input_claim, + right_lookup_operand_claim, ); } @@ -273,8 +237,6 @@ impl 
InstructionLookupsPhase1State { let gamma = params.gamma; let gamma_sqr = params.gamma_sqr; - let gamma_cub = params.gamma_cub; - let gamma_quart = params.gamma_quart; const BLOCK_SIZE: usize = 32; Q.par_chunks_mut(BLOCK_SIZE) @@ -283,17 +245,12 @@ impl InstructionLookupsPhase1State { let mut q_lookup_output = [F::Unreduced::<6>::zero(); BLOCK_SIZE]; let mut q_left_lookup_operand = [F::Unreduced::<6>::zero(); BLOCK_SIZE]; let mut q_right_lookup_operand = [F::Unreduced::<7>::zero(); BLOCK_SIZE]; - let mut q_left_instruction_input = [F::Unreduced::<6>::zero(); BLOCK_SIZE]; - let mut q_right_instruction_input_pos = [F::Unreduced::<6>::zero(); BLOCK_SIZE]; - let mut q_right_instruction_input_neg = [F::Unreduced::<6>::zero(); BLOCK_SIZE]; for x_hi in 0..(1 << suffix_n_vars) { for i in 0..q_chunk.len() { let x_lo = chunk_i * BLOCK_SIZE + i; let x = x_lo + (x_hi << prefix_n_vars); let cycle = &trace[x]; - let (left_instruction_input, right_instruction_input) = - LookupQuery::::to_instruction_inputs(cycle); let (left_lookup, right_lookup) = LookupQuery::::to_lookup_operands(cycle); let lookup_output = LookupQuery::::to_lookup_output(cycle); @@ -304,29 +261,13 @@ impl InstructionLookupsPhase1State { eq_suffix_evals[x_hi].mul_u64_unreduced(left_lookup); q_right_lookup_operand[i] += eq_suffix_evals[x_hi].mul_u128_unreduced(right_lookup); - - q_left_instruction_input[i] += - eq_suffix_evals[x_hi].mul_u64_unreduced(left_instruction_input); - - let abs: u64 = right_instruction_input.unsigned_abs() as u64; - let term = eq_suffix_evals[x_hi].mul_u64_unreduced(abs); - if right_instruction_input >= 0 { - q_right_instruction_input_pos[i] += term; - } else { - q_right_instruction_input_neg[i] += term; - } } } for (i, q) in q_chunk.iter_mut().enumerate() { - let right_instruction_input = - F::from_barrett_reduce(q_right_instruction_input_pos[i]) - - F::from_barrett_reduce(q_right_instruction_input_neg[i]); *q = F::from_barrett_reduce(q_lookup_output[i]) + gamma * 
F::from_barrett_reduce(q_left_lookup_operand[i]) + gamma_sqr * F::from_barrett_reduce(q_right_lookup_operand[i]); - *q += gamma_cub * F::from_barrett_reduce(q_left_instruction_input[i]) - + gamma_quart * right_instruction_input; } }); @@ -372,8 +313,6 @@ struct InstructionLookupsPhase2State { lookup_output_poly: MultilinearPolynomial, left_lookup_operand_poly: MultilinearPolynomial, right_lookup_operand_poly: MultilinearPolynomial, - left_instruction_input_poly: MultilinearPolynomial, - right_instruction_input_poly: MultilinearPolynomial, eq_poly: MultilinearPolynomial, } @@ -391,14 +330,10 @@ impl InstructionLookupsPhase2State { let mut lookup_output_poly = unsafe_allocate_zero_vec(1 << n_remaining_rounds); let mut left_lookup_operand_poly = unsafe_allocate_zero_vec(1 << n_remaining_rounds); let mut right_lookup_operand_poly = unsafe_allocate_zero_vec(1 << n_remaining_rounds); - let mut left_instruction_input_poly = unsafe_allocate_zero_vec(1 << n_remaining_rounds); - let mut right_instruction_input_poly = unsafe_allocate_zero_vec(1 << n_remaining_rounds); ( &mut lookup_output_poly, &mut left_lookup_operand_poly, &mut right_lookup_operand_poly, - &mut left_instruction_input_poly, - &mut right_instruction_input_poly, trace.par_chunks(eq_evals.len()), ) .into_par_iter() @@ -407,20 +342,13 @@ impl InstructionLookupsPhase2State { lookup_output_eval, left_lookup_operand_eval, right_lookup_operand_eval, - left_instruction_input_eval, - right_instruction_input_eval, trace_chunk, )| { let mut lookup_output_eval_unreduced = F::Unreduced::<6>::zero(); let mut left_lookup_operand_eval_unreduced = F::Unreduced::<6>::zero(); let mut right_lookup_operand_eval_unreduced = F::Unreduced::<7>::zero(); - let mut left_instruction_input_eval_unreduced = F::Unreduced::<6>::zero(); - let mut right_instruction_input_pos_unreduced = F::Unreduced::<6>::zero(); - let mut right_instruction_input_neg_unreduced = F::Unreduced::<6>::zero(); for (i, cycle) in trace_chunk.iter().enumerate() { - 
let (left_instruction_input, right_instruction_input) = - LookupQuery::::to_instruction_inputs(cycle); let (left_lookup, right_lookup) = LookupQuery::::to_lookup_operands(cycle); let lookup_output = LookupQuery::::to_lookup_output(cycle); @@ -431,17 +359,6 @@ impl InstructionLookupsPhase2State { eq_evals[i].mul_u64_unreduced(left_lookup); right_lookup_operand_eval_unreduced += eq_evals[i].mul_u128_unreduced(right_lookup); - - left_instruction_input_eval_unreduced += - eq_evals[i].mul_u64_unreduced(left_instruction_input); - - let abs: u64 = right_instruction_input.unsigned_abs() as u64; - let term = eq_evals[i].mul_u64_unreduced(abs); - if right_instruction_input >= 0 { - right_instruction_input_pos_unreduced += term; - } else { - right_instruction_input_neg_unreduced += term; - } } *lookup_output_eval = F::from_barrett_reduce(lookup_output_eval_unreduced); @@ -449,11 +366,6 @@ impl InstructionLookupsPhase2State { F::from_barrett_reduce(left_lookup_operand_eval_unreduced); *right_lookup_operand_eval = F::from_barrett_reduce(right_lookup_operand_eval_unreduced); - *left_instruction_input_eval = - F::from_barrett_reduce(left_instruction_input_eval_unreduced); - *right_instruction_input_eval = - F::from_barrett_reduce(right_instruction_input_pos_unreduced) - - F::from_barrett_reduce(right_instruction_input_neg_unreduced); }, ); @@ -465,8 +377,6 @@ impl InstructionLookupsPhase2State { lookup_output_poly: lookup_output_poly.into(), left_lookup_operand_poly: left_lookup_operand_poly.into(), right_lookup_operand_poly: right_lookup_operand_poly.into(), - left_instruction_input_poly: left_instruction_input_poly.into(), - right_instruction_input_poly: right_instruction_input_poly.into(), eq_poly: eq_suffix_evals.into(), } } @@ -488,12 +398,6 @@ impl InstructionLookupsPhase2State { let right_lookup_operand_evals = self .right_lookup_operand_poly .sumcheck_evals_array::(j, BindingOrder::LowToHigh); - let left_instruction_input_evals = self - .left_instruction_input_poly - 
.sumcheck_evals_array::(j, BindingOrder::LowToHigh); - let right_instruction_input_evals = self - .right_instruction_input_poly - .sumcheck_evals_array::(j, BindingOrder::LowToHigh); let eq_evals = self .eq_poly .sumcheck_evals_array::(j, BindingOrder::LowToHigh); @@ -502,9 +406,7 @@ impl InstructionLookupsPhase2State { + eq_evals[i] * (lookup_output_evals[i] + params.gamma * left_lookup_operand_evals[i] - + params.gamma_sqr * right_lookup_operand_evals[i] - + params.gamma_cub * left_instruction_input_evals[i] - + params.gamma_quart * right_instruction_input_evals[i]) + + params.gamma_sqr * right_lookup_operand_evals[i]) }); } UniPoly::from_evals_and_hint(previous_claim, &evals) @@ -517,10 +419,6 @@ impl InstructionLookupsPhase2State { .bind_parallel(r_j, BindingOrder::LowToHigh); self.right_lookup_operand_poly .bind_parallel(r_j, BindingOrder::LowToHigh); - self.left_instruction_input_poly - .bind_parallel(r_j, BindingOrder::LowToHigh); - self.right_instruction_input_poly - .bind_parallel(r_j, BindingOrder::LowToHigh); self.eq_poly.bind_parallel(r_j, BindingOrder::LowToHigh); } } @@ -532,16 +430,12 @@ impl InstructionLookupsPhase2State { /// LookupOutput(j) /// + gamma * LeftLookupOperand(j) /// + gamma^2 * RightLookupOperand(j) -/// + gamma^3 * LeftInstructionInput(j) -/// + gamma^4 * RightInstructionInput(j) /// ) /// ``` /// /// where `r_spartan` is the randomness from the log(T) rounds of Spartan outer sumcheck (stage 1). /// -/// The purpose of this sumcheck is to aggregate instruction lookup claims into a single claim. It runs in -/// parallel with the Spartan product sumcheck. This optimization eliminates the need for a separate opening -/// of [`VirtualPolynomial::LookupOutput`] at `r_spartan`, leaving only the opening at `r_product` required. +/// Aggregates the three instruction lookup claims into a single claim. 
pub struct InstructionLookupsClaimReductionSumcheckVerifier { params: InstructionLookupsClaimReductionSumcheckParams, } @@ -590,21 +484,11 @@ impl SumcheckInstanceVerifier VirtualPolynomial::RightLookupOperand, SumcheckId::InstructionClaimReduction, ); - let (_, left_instruction_input_claim) = accumulator.get_virtual_polynomial_opening( - VirtualPolynomial::LeftInstructionInput, - SumcheckId::InstructionClaimReduction, - ); - let (_, right_instruction_input_claim) = accumulator.get_virtual_polynomial_opening( - VirtualPolynomial::RightInstructionInput, - SumcheckId::InstructionClaimReduction, - ); EqPolynomial::mle(&opening_point.r, &r_spartan.r) * (lookup_output_claim + self.params.gamma * left_lookup_operand_claim - + self.params.gamma_sqr * right_lookup_operand_claim - + self.params.gamma_cub * left_instruction_input_claim - + self.params.gamma_quart * right_instruction_input_claim) + + self.params.gamma_sqr * right_lookup_operand_claim) } fn cache_openings( @@ -632,18 +516,6 @@ impl SumcheckInstanceVerifier transcript, VirtualPolynomial::RightLookupOperand, SumcheckId::InstructionClaimReduction, - opening_point.clone(), - ); - accumulator.append_virtual( - transcript, - VirtualPolynomial::LeftInstructionInput, - SumcheckId::InstructionClaimReduction, - opening_point.clone(), - ); - accumulator.append_virtual( - transcript, - VirtualPolynomial::RightInstructionInput, - SumcheckId::InstructionClaimReduction, opening_point, ); } diff --git a/jolt-core/src/zkvm/proof_serialization.rs b/jolt-core/src/zkvm/proof_serialization.rs index 174a256135..1f098d0bd3 100644 --- a/jolt-core/src/zkvm/proof_serialization.rs +++ b/jolt-core/src/zkvm/proof_serialization.rs @@ -1,5 +1,6 @@ use std::{ collections::BTreeMap, + fs::File, io::{Read, Write}, }; @@ -39,7 +40,6 @@ pub struct JoltProof, FS: Transcr pub stage4_sumcheck_proof: SumcheckInstanceProof, pub stage5_sumcheck_proof: SumcheckInstanceProof, pub stage6_sumcheck_proof: SumcheckInstanceProof, - pub stage7_sumcheck_proof: 
SumcheckInstanceProof, pub joint_opening_proof: PCS::Proof, pub untrusted_advice_commitment: Option, pub trace_length: usize, @@ -324,46 +324,45 @@ impl CanonicalSerialize for VirtualPolynomial { Self::RightLookupOperand => 8u8.serialize_with_mode(&mut writer, compress), Self::LeftInstructionInput => 9u8.serialize_with_mode(&mut writer, compress), Self::RightInstructionInput => 10u8.serialize_with_mode(&mut writer, compress), - Self::Product => 11u8.serialize_with_mode(&mut writer, compress), - Self::ShouldJump => 12u8.serialize_with_mode(&mut writer, compress), - Self::ShouldBranch => 13u8.serialize_with_mode(&mut writer, compress), - Self::WritePCtoRD => 14u8.serialize_with_mode(&mut writer, compress), - Self::WriteLookupOutputToRD => 15u8.serialize_with_mode(&mut writer, compress), - Self::Rd => 16u8.serialize_with_mode(&mut writer, compress), - Self::Imm => 17u8.serialize_with_mode(&mut writer, compress), - Self::Rs1Value => 18u8.serialize_with_mode(&mut writer, compress), - Self::Rs2Value => 19u8.serialize_with_mode(&mut writer, compress), - Self::RdWriteValue => 20u8.serialize_with_mode(&mut writer, compress), - Self::Rs1Ra => 21u8.serialize_with_mode(&mut writer, compress), - Self::Rs2Ra => 22u8.serialize_with_mode(&mut writer, compress), - Self::RdWa => 23u8.serialize_with_mode(&mut writer, compress), - Self::LookupOutput => 24u8.serialize_with_mode(&mut writer, compress), - Self::InstructionRaf => 25u8.serialize_with_mode(&mut writer, compress), - Self::InstructionRafFlag => 26u8.serialize_with_mode(&mut writer, compress), + Self::ShouldJump => 11u8.serialize_with_mode(&mut writer, compress), + Self::ShouldBranch => 12u8.serialize_with_mode(&mut writer, compress), + Self::WritePCtoRD => 13u8.serialize_with_mode(&mut writer, compress), + Self::WriteLookupOutputToRD => 14u8.serialize_with_mode(&mut writer, compress), + Self::Rd => 15u8.serialize_with_mode(&mut writer, compress), + Self::Imm => 16u8.serialize_with_mode(&mut writer, compress), + Self::Rs1Value 
=> 17u8.serialize_with_mode(&mut writer, compress), + Self::Rs2Value => 18u8.serialize_with_mode(&mut writer, compress), + Self::RdWriteValue => 19u8.serialize_with_mode(&mut writer, compress), + Self::Rs1Ra => 20u8.serialize_with_mode(&mut writer, compress), + Self::Rs2Ra => 21u8.serialize_with_mode(&mut writer, compress), + Self::RdWa => 22u8.serialize_with_mode(&mut writer, compress), + Self::LookupOutput => 23u8.serialize_with_mode(&mut writer, compress), + Self::InstructionRaf => 24u8.serialize_with_mode(&mut writer, compress), + Self::InstructionRafFlag => 25u8.serialize_with_mode(&mut writer, compress), Self::InstructionRa(i) => { - 27u8.serialize_with_mode(&mut writer, compress)?; + 26u8.serialize_with_mode(&mut writer, compress)?; (u8::try_from(*i).unwrap()).serialize_with_mode(&mut writer, compress) } - Self::RegistersVal => 28u8.serialize_with_mode(&mut writer, compress), - Self::RamAddress => 29u8.serialize_with_mode(&mut writer, compress), - Self::RamRa => 30u8.serialize_with_mode(&mut writer, compress), - Self::RamReadValue => 31u8.serialize_with_mode(&mut writer, compress), - Self::RamWriteValue => 32u8.serialize_with_mode(&mut writer, compress), - Self::RamVal => 33u8.serialize_with_mode(&mut writer, compress), - Self::RamValInit => 34u8.serialize_with_mode(&mut writer, compress), - Self::RamValFinal => 35u8.serialize_with_mode(&mut writer, compress), - Self::RamHammingWeight => 36u8.serialize_with_mode(&mut writer, compress), - Self::UnivariateSkip => 37u8.serialize_with_mode(&mut writer, compress), + Self::RegistersVal => 27u8.serialize_with_mode(&mut writer, compress), + Self::RamAddress => 28u8.serialize_with_mode(&mut writer, compress), + Self::RamRa => 29u8.serialize_with_mode(&mut writer, compress), + Self::RamReadValue => 30u8.serialize_with_mode(&mut writer, compress), + Self::RamWriteValue => 31u8.serialize_with_mode(&mut writer, compress), + Self::RamVal => 32u8.serialize_with_mode(&mut writer, compress), + Self::RamValInit => 
33u8.serialize_with_mode(&mut writer, compress), + Self::RamValFinal => 34u8.serialize_with_mode(&mut writer, compress), + Self::RamHammingWeight => 35u8.serialize_with_mode(&mut writer, compress), + Self::UnivariateSkip => 36u8.serialize_with_mode(&mut writer, compress), Self::OpFlags(flags) => { - 38u8.serialize_with_mode(&mut writer, compress)?; + 37u8.serialize_with_mode(&mut writer, compress)?; (u8::try_from(*flags as usize).unwrap()).serialize_with_mode(&mut writer, compress) } Self::InstructionFlags(flags) => { - 39u8.serialize_with_mode(&mut writer, compress)?; + 38u8.serialize_with_mode(&mut writer, compress)?; (u8::try_from(*flags as usize).unwrap()).serialize_with_mode(&mut writer, compress) } Self::LookupTableFlag(flag) => { - 40u8.serialize_with_mode(&mut writer, compress)?; + 39u8.serialize_with_mode(&mut writer, compress)?; (u8::try_from(*flag).unwrap()).serialize_with_mode(&mut writer, compress) } } @@ -382,7 +381,6 @@ impl CanonicalSerialize for VirtualPolynomial { | Self::RightLookupOperand | Self::LeftInstructionInput | Self::RightInstructionInput - | Self::Product | Self::ShouldJump | Self::ShouldBranch | Self::WritePCtoRD @@ -441,49 +439,48 @@ impl CanonicalDeserialize for VirtualPolynomial { 8 => Self::RightLookupOperand, 9 => Self::LeftInstructionInput, 10 => Self::RightInstructionInput, - 11 => Self::Product, - 12 => Self::ShouldJump, - 13 => Self::ShouldBranch, - 14 => Self::WritePCtoRD, - 15 => Self::WriteLookupOutputToRD, - 16 => Self::Rd, - 17 => Self::Imm, - 18 => Self::Rs1Value, - 19 => Self::Rs2Value, - 20 => Self::RdWriteValue, - 21 => Self::Rs1Ra, - 22 => Self::Rs2Ra, - 23 => Self::RdWa, - 24 => Self::LookupOutput, - 25 => Self::InstructionRaf, - 26 => Self::InstructionRafFlag, - 27 => { + 11 => Self::ShouldJump, + 12 => Self::ShouldBranch, + 13 => Self::WritePCtoRD, + 14 => Self::WriteLookupOutputToRD, + 15 => Self::Rd, + 16 => Self::Imm, + 17 => Self::Rs1Value, + 18 => Self::Rs2Value, + 19 => Self::RdWriteValue, + 20 => 
Self::Rs1Ra, + 21 => Self::Rs2Ra, + 22 => Self::RdWa, + 23 => Self::LookupOutput, + 24 => Self::InstructionRaf, + 25 => Self::InstructionRafFlag, + 26 => { let i = u8::deserialize_with_mode(&mut reader, compress, validate)?; Self::InstructionRa(i as usize) } - 28 => Self::RegistersVal, - 29 => Self::RamAddress, - 30 => Self::RamRa, - 31 => Self::RamReadValue, - 32 => Self::RamWriteValue, - 33 => Self::RamVal, - 34 => Self::RamValInit, - 35 => Self::RamValFinal, - 36 => Self::RamHammingWeight, - 37 => Self::UnivariateSkip, - 38 => { + 27 => Self::RegistersVal, + 28 => Self::RamAddress, + 29 => Self::RamRa, + 30 => Self::RamReadValue, + 31 => Self::RamWriteValue, + 32 => Self::RamVal, + 33 => Self::RamValInit, + 34 => Self::RamValFinal, + 35 => Self::RamHammingWeight, + 36 => Self::UnivariateSkip, + 37 => { let discriminant = u8::deserialize_with_mode(&mut reader, compress, validate)?; let flags = CircuitFlags::from_repr(discriminant) .ok_or(SerializationError::InvalidData)?; Self::OpFlags(flags) } - 39 => { + 38 => { let discriminant = u8::deserialize_with_mode(&mut reader, compress, validate)?; let flags = InstructionFlags::from_repr(discriminant) .ok_or(SerializationError::InvalidData)?; Self::InstructionFlags(flags) } - 40 => { + 39 => { let flag = u8::deserialize_with_mode(&mut reader, compress, validate)?; Self::LookupTableFlag(flag as usize) } @@ -498,7 +495,6 @@ pub fn serialize_and_print_size( file_name: &str, item: &impl CanonicalSerialize, ) -> Result<(), SerializationError> { - use std::fs::File; let mut file = File::create(file_name)?; item.serialize_compressed(&mut file)?; let file_size_bytes = file.metadata()?.len(); diff --git a/jolt-core/src/zkvm/prover.rs b/jolt-core/src/zkvm/prover.rs index 041eb37774..efe64a6409 100644 --- a/jolt-core/src/zkvm/prover.rs +++ b/jolt-core/src/zkvm/prover.rs @@ -1,5 +1,7 @@ #[cfg(test)] use crate::poly::multilinear_polynomial::PolynomialEvaluation; +use common::constants::ONEHOT_CHUNK_THRESHOLD_LOG_T; + use crate::{ 
subprotocols::streaming_schedule::LinearOnlySchedule, zkvm::{claim_reductions::advice::ReductionPhase, config::OneHotConfig}, @@ -17,6 +19,7 @@ use crate::poly::commitment::dory::DoryContext; use ark_serialize::{CanonicalDeserialize, CanonicalSerialize}; use crate::zkvm::config::ReadWriteConfig; +use crate::zkvm::ram::remap_address; use crate::zkvm::verifier::JoltSharedPreprocessing; use crate::zkvm::Serializable; @@ -143,10 +146,10 @@ pub struct JoltCpuProver< pub lazy_trace: LazyTraceIterator, pub trace: Arc>, pub advice: JoltAdvice, - /// The advice claim reduction sumcheck effectively spans two stages (6 and 7). + /// The advice claim reduction sumcheck effectively spans two stages (5 and 6). /// Cache the prover state here between stages. advice_reduction_prover_trusted: Option>, - /// The advice claim reduction sumcheck effectively spans two stages (6 and 7). + /// The advice claim reduction sumcheck effectively spans two stages (5 and 6). /// Cache the prover state here between stages. 
advice_reduction_prover_untrusted: Option>, pub unpadded_trace_len: usize, @@ -360,7 +363,7 @@ impl<'a, F: JoltField, PCS: StreamingCommitmentScheme, ProofTranscrip let ram_K = trace .par_iter() .filter_map(|cycle| { - crate::zkvm::ram::remap_address( + remap_address( cycle.ram_access().address() as u64, &preprocessing.shared.memory_layout, ) @@ -368,7 +371,7 @@ impl<'a, F: JoltField, PCS: StreamingCommitmentScheme, ProofTranscrip .max() .unwrap_or(0) .max( - crate::zkvm::ram::remap_address( + remap_address( preprocessing.shared.ram.min_bytecode_address, &preprocessing.shared.memory_layout, ) @@ -449,7 +452,7 @@ impl<'a, F: JoltField, PCS: StreamingCommitmentScheme, ProofTranscrip let untrusted_advice_commitment = self.generate_and_commit_untrusted_advice(); self.generate_and_commit_trusted_advice(); - // Add advice hints for batched Stage 8 opening + // Add advice hints for batched Stage 7 opening if let Some(hint) = self.advice.trusted_advice_hint.take() { opening_proof_hints.insert(CommittedPolynomial::TrustedAdvice, hint); } @@ -463,9 +466,8 @@ impl<'a, F: JoltField, PCS: StreamingCommitmentScheme, ProofTranscrip let stage4_sumcheck_proof = self.prove_stage4(); let stage5_sumcheck_proof = self.prove_stage5(); let stage6_sumcheck_proof = self.prove_stage6(); - let stage7_sumcheck_proof = self.prove_stage7(); - let joint_opening_proof = self.prove_stage8(opening_proof_hints); + let joint_opening_proof = self.prove_stage7(opening_proof_hints); #[cfg(test)] assert!( @@ -498,7 +500,6 @@ impl<'a, F: JoltField, PCS: StreamingCommitmentScheme, ProofTranscrip stage4_sumcheck_proof, stage5_sumcheck_proof, stage6_sumcheck_proof, - stage7_sumcheck_proof, joint_opening_proof, trace_length: self.trace.len(), ram_K: self.one_hot_params.ram_k, @@ -773,17 +774,6 @@ impl<'a, F: JoltField, PCS: StreamingCommitmentScheme, ProofTranscrip self.trace.len(), &self.rw_config, ); - let spartan_product_virtual_remainder_params = ProductVirtualRemainderParams::new( - self.trace.len(), - 
uni_skip_params, - &self.opening_accumulator, - ); - let instruction_claim_reduction_params = - InstructionLookupsClaimReductionSumcheckParams::new( - self.trace.len(), - &self.opening_accumulator, - &mut self.transcript, - ); let ram_raf_evaluation_params = RafEvaluationSumcheckParams::new( &self.program_io.memory_layout, &self.one_hot_params, @@ -798,6 +788,17 @@ impl<'a, F: JoltField, PCS: StreamingCommitmentScheme, ProofTranscrip self.trace.len(), &self.rw_config, ); + let spartan_product_virtual_remainder_params = ProductVirtualRemainderParams::new( + self.trace.len(), + uni_skip_params, + &self.opening_accumulator, + ); + let instruction_claim_reduction_params = + InstructionLookupsClaimReductionSumcheckParams::new( + self.trace.len(), + &self.opening_accumulator, + &mut self.transcript, + ); let ram_read_write_checking = RamReadWriteCheckingProver::initialize( ram_read_write_checking_params, &self.trace, @@ -805,15 +806,6 @@ impl<'a, F: JoltField, PCS: StreamingCommitmentScheme, ProofTranscrip &self.program_io.memory_layout, &self.initial_ram_state, ); - let spartan_product_virtual_remainder = ProductVirtualRemainderProver::initialize( - spartan_product_virtual_remainder_params, - Arc::clone(&self.trace), - ); - let instruction_claim_reduction = - InstructionLookupsClaimReductionSumcheckProver::initialize( - instruction_claim_reduction_params, - Arc::clone(&self.trace), - ); let ram_raf_evaluation = RamRafEvaluationSumcheckProver::initialize( ram_raf_evaluation_params, &self.trace, @@ -825,56 +817,16 @@ impl<'a, F: JoltField, PCS: StreamingCommitmentScheme, ProofTranscrip &self.final_ram_state, &self.program_io.memory_layout, ); - - #[cfg(feature = "allocative")] - { - print_data_structure_heap_usage("RamReadWriteCheckingProver", &ram_read_write_checking); - print_data_structure_heap_usage( - "ProductVirtualRemainderProver", - &spartan_product_virtual_remainder, - ); - print_data_structure_heap_usage( - "InstructionLookupsClaimReductionSumcheckProver", - 
&instruction_claim_reduction, - ); - print_data_structure_heap_usage("RamRafEvaluationSumcheckProver", &ram_raf_evaluation); - print_data_structure_heap_usage("OutputSumcheckProver", &ram_output_check); - } - - let mut instances: Vec>> = vec![ - Box::new(ram_read_write_checking), - Box::new(spartan_product_virtual_remainder), - Box::new(instruction_claim_reduction), - Box::new(ram_raf_evaluation), - Box::new(ram_output_check), - ]; - - #[cfg(feature = "allocative")] - write_boxed_instance_flamegraph_svg(&instances, "stage2_start_flamechart.svg"); - tracing::info!("Stage 2 proving"); - let (sumcheck_proof, _r_stage2) = BatchedSumcheck::prove( - instances.iter_mut().map(|v| &mut **v as _).collect(), - &mut self.opening_accumulator, - &mut self.transcript, + let spartan_product_virtual_remainder = ProductVirtualRemainderProver::initialize( + spartan_product_virtual_remainder_params, + Arc::clone(&self.trace), ); - #[cfg(feature = "allocative")] - write_boxed_instance_flamegraph_svg(&instances, "stage2_end_flamechart.svg"); - drop_in_background_thread(instances); - - (first_round_proof, sumcheck_proof) - } - - #[tracing::instrument(skip_all)] - fn prove_stage3(&mut self) -> SumcheckInstanceProof { - #[cfg(not(target_arch = "wasm32"))] - print_current_memory_usage("Stage 3 baseline"); + let instruction_claim_reduction = + InstructionLookupsClaimReductionSumcheckProver::initialize( + instruction_claim_reduction_params, + Arc::clone(&self.trace), + ); - // Initialization params - let spartan_shift_params = ShiftSumcheckParams::new( - self.trace.len().log_2(), - &self.opening_accumulator, - &mut self.transcript, - ); let spartan_instruction_input_params = InstructionInputParams::new(&self.opening_accumulator, &mut self.transcript); let spartan_registers_claim_reduction_params = RegistersClaimReductionSumcheckParams::new( @@ -882,13 +834,6 @@ impl<'a, F: JoltField, PCS: StreamingCommitmentScheme, ProofTranscrip &self.opening_accumulator, &mut self.transcript, ); - - // 
Initialize - let spartan_shift = ShiftSumcheckProver::initialize( - spartan_shift_params, - Arc::clone(&self.trace), - &self.preprocessing.shared.bytecode, - ); let spartan_instruction_input = InstructionInputSumcheckProver::initialize( spartan_instruction_input_params, &self.trace, @@ -901,7 +846,17 @@ impl<'a, F: JoltField, PCS: StreamingCommitmentScheme, ProofTranscrip #[cfg(feature = "allocative")] { - print_data_structure_heap_usage("ShiftSumcheckProver", &spartan_shift); + print_data_structure_heap_usage("RamReadWriteCheckingProver", &ram_read_write_checking); + print_data_structure_heap_usage("RamRafEvaluationSumcheckProver", &ram_raf_evaluation); + print_data_structure_heap_usage("OutputSumcheckProver", &ram_output_check); + print_data_structure_heap_usage( + "ProductVirtualRemainderProver", + &spartan_product_virtual_remainder, + ); + print_data_structure_heap_usage( + "InstructionLookupsClaimReductionSumcheckProver", + &instruction_claim_reduction, + ); print_data_structure_heap_usage( "InstructionInputSumcheckProver", &spartan_instruction_input, @@ -913,30 +868,40 @@ impl<'a, F: JoltField, PCS: StreamingCommitmentScheme, ProofTranscrip } let mut instances: Vec>> = vec![ - Box::new(spartan_shift), + Box::new(ram_read_write_checking), + Box::new(ram_raf_evaluation), + Box::new(ram_output_check), + Box::new(spartan_product_virtual_remainder), + Box::new(instruction_claim_reduction), Box::new(spartan_instruction_input), Box::new(spartan_registers_claim_reduction), ]; #[cfg(feature = "allocative")] - write_boxed_instance_flamegraph_svg(&instances, "stage3_start_flamechart.svg"); - tracing::info!("Stage 3 proving"); - let (sumcheck_proof, _r_stage3) = BatchedSumcheck::prove( + write_boxed_instance_flamegraph_svg(&instances, "stage2_start_flamechart.svg"); + tracing::info!("Stage 2 proving"); + let (sumcheck_proof, _r_stage2) = BatchedSumcheck::prove( instances.iter_mut().map(|v| &mut **v as _).collect(), &mut self.opening_accumulator, &mut self.transcript, ); 
#[cfg(feature = "allocative")] - write_boxed_instance_flamegraph_svg(&instances, "stage3_end_flamechart.svg"); + write_boxed_instance_flamegraph_svg(&instances, "stage2_end_flamechart.svg"); drop_in_background_thread(instances); - sumcheck_proof + (first_round_proof, sumcheck_proof) } #[tracing::instrument(skip_all)] - fn prove_stage4(&mut self) -> SumcheckInstanceProof { + fn prove_stage3(&mut self) -> SumcheckInstanceProof { #[cfg(not(target_arch = "wasm32"))] - print_current_memory_usage("Stage 4 baseline"); + print_current_memory_usage("Stage 3 baseline"); + + let spartan_shift_params = ShiftSumcheckParams::new( + self.trace.len().log_2(), + &self.opening_accumulator, + &mut self.transcript, + ); let registers_read_write_checking_params = RegistersReadWriteCheckingParams::new( self.trace.len(), @@ -952,7 +917,6 @@ impl<'a, F: JoltField, PCS: StreamingCommitmentScheme, ProofTranscrip &mut self.opening_accumulator, &mut self.transcript, ); - // Domain-separate the batching challenge. self.transcript.append_bytes(b"ram_val_check_gamma", &[]); let ram_val_check_gamma: F = self.transcript.challenge_scalar::(); let ram_val_check_params = RamValCheckSumcheckParams::new_from_prover( @@ -963,6 +927,11 @@ impl<'a, F: JoltField, PCS: StreamingCommitmentScheme, ProofTranscrip ram_val_check_gamma, ); + let spartan_shift = ShiftSumcheckProver::initialize( + spartan_shift_params, + Arc::clone(&self.trace), + &self.preprocessing.shared.bytecode, + ); let registers_read_write_checking = RegistersReadWriteCheckingProver::initialize( registers_read_write_checking_params, self.trace.clone(), @@ -978,6 +947,7 @@ impl<'a, F: JoltField, PCS: StreamingCommitmentScheme, ProofTranscrip #[cfg(feature = "allocative")] { + print_data_structure_heap_usage("ShiftSumcheckProver", &spartan_shift); print_data_structure_heap_usage( "RegistersReadWriteCheckingProver", &registers_read_write_checking, @@ -986,29 +956,30 @@ impl<'a, F: JoltField, PCS: StreamingCommitmentScheme, ProofTranscrip } let mut
instances: Vec>> = vec![ + Box::new(spartan_shift), Box::new(registers_read_write_checking), Box::new(ram_val_check), ]; #[cfg(feature = "allocative")] - write_boxed_instance_flamegraph_svg(&instances, "stage4_start_flamechart.svg"); - tracing::info!("Stage 4 proving"); - let (sumcheck_proof, _r_stage4) = BatchedSumcheck::prove( + write_boxed_instance_flamegraph_svg(&instances, "stage3_start_flamechart.svg"); + tracing::info!("Stage 3 proving"); + let (sumcheck_proof, _r_stage3) = BatchedSumcheck::prove( instances.iter_mut().map(|v| &mut **v as _).collect(), &mut self.opening_accumulator, &mut self.transcript, ); #[cfg(feature = "allocative")] - write_boxed_instance_flamegraph_svg(&instances, "stage4_end_flamechart.svg"); + write_boxed_instance_flamegraph_svg(&instances, "stage3_end_flamechart.svg"); drop_in_background_thread(instances); sumcheck_proof } #[tracing::instrument(skip_all)] - fn prove_stage5(&mut self) -> SumcheckInstanceProof { + fn prove_stage4(&mut self) -> SumcheckInstanceProof { #[cfg(not(target_arch = "wasm32"))] - print_current_memory_usage("Stage 5 baseline"); + print_current_memory_usage("Stage 4 baseline"); // Initialization params (same order as batch) let lookups_read_raf_params = InstructionReadRafSumcheckParams::new( self.trace.len().log_2(), @@ -1059,24 +1030,24 @@ impl<'a, F: JoltField, PCS: StreamingCommitmentScheme, ProofTranscrip ]; #[cfg(feature = "allocative")] - write_boxed_instance_flamegraph_svg(&instances, "stage5_start_flamechart.svg"); - tracing::info!("Stage 5 proving"); - let (sumcheck_proof, _r_stage5) = BatchedSumcheck::prove( + write_boxed_instance_flamegraph_svg(&instances, "stage4_start_flamechart.svg"); + tracing::info!("Stage 4 proving"); + let (sumcheck_proof, _r_stage4) = BatchedSumcheck::prove( instances.iter_mut().map(|v| &mut **v as _).collect(), &mut self.opening_accumulator, &mut self.transcript, ); #[cfg(feature = "allocative")] - write_boxed_instance_flamegraph_svg(&instances, "stage5_end_flamechart.svg"); + 
write_boxed_instance_flamegraph_svg(&instances, "stage4_end_flamechart.svg"); drop_in_background_thread(instances); sumcheck_proof } #[tracing::instrument(skip_all)] - fn prove_stage6(&mut self) -> SumcheckInstanceProof { + fn prove_stage5(&mut self) -> SumcheckInstanceProof { #[cfg(not(target_arch = "wasm32"))] - print_current_memory_usage("Stage 6 baseline"); + print_current_memory_usage("Stage 5 baseline"); let bytecode_read_raf_params = BytecodeReadRafSumcheckParams::gen( &self.preprocessing.shared.bytecode, @@ -1112,7 +1083,7 @@ impl<'a, F: JoltField, PCS: StreamingCommitmentScheme, ProofTranscrip &mut self.transcript, ); - // Advice claim reduction (Phase 1 in Stage 6): trusted and untrusted are separate instances. + // Advice claim reduction (Phase 1 in Stage 5): trusted and untrusted are separate instances. if self.advice.trusted_advice_polynomial.is_some() { let trusted_advice_params = AdviceClaimReductionParams::new( AdviceKind::Trusted, @@ -1120,7 +1091,7 @@ impl<'a, F: JoltField, PCS: StreamingCommitmentScheme, ProofTranscrip self.trace.len(), &self.opening_accumulator, ); - // Note: We clone the advice polynomial here because Stage 8 needs the original polynomial + // Note: We clone the advice polynomial here because Stage 7 needs the original polynomial // A future optimization could use Arc with copy-on-write. self.advice_reduction_prover_trusted = { let poly = self @@ -1142,7 +1113,7 @@ impl<'a, F: JoltField, PCS: StreamingCommitmentScheme, ProofTranscrip self.trace.len(), &self.opening_accumulator, ); - // Note: We clone the advice polynomial here because Stage 8 needs the original polynomial + // Note: We clone the advice polynomial here because Stage 7 needs the original polynomial // A future optimization could use Arc with copy-on-write. 
self.advice_reduction_prover_untrusted = { let poly = self @@ -1218,15 +1189,15 @@ impl<'a, F: JoltField, PCS: StreamingCommitmentScheme, ProofTranscrip } #[cfg(feature = "allocative")] - write_instance_flamegraph_svg(&instances, "stage6_start_flamechart.svg"); - tracing::info!("Stage 6 proving"); - let (sumcheck_proof, _r_stage6) = BatchedSumcheck::prove( + write_instance_flamegraph_svg(&instances, "stage5_start_flamechart.svg"); + tracing::info!("Stage 5 proving"); + let (sumcheck_proof, _r_stage5) = BatchedSumcheck::prove( instances.iter_mut().map(|v| &mut **v as _).collect(), &mut self.opening_accumulator, &mut self.transcript, ); #[cfg(feature = "allocative")] - write_instance_flamegraph_svg(&instances, "stage6_end_flamechart.svg"); + write_instance_flamegraph_svg(&instances, "stage5_end_flamechart.svg"); drop_in_background_thread(bytecode_read_raf); drop_in_background_thread(booleanity); drop_in_background_thread(ram_hamming_booleanity); @@ -1237,9 +1208,9 @@ impl<'a, F: JoltField, PCS: StreamingCommitmentScheme, ProofTranscrip sumcheck_proof } - /// Stage 7: HammingWeight + ClaimReduction sumcheck (only log_k_chunk rounds). + /// Stage 6: HammingWeight + ClaimReduction sumcheck (only log_k_chunk rounds). #[tracing::instrument(skip_all)] - fn prove_stage7(&mut self) -> SumcheckInstanceProof { + fn prove_stage6(&mut self) -> SumcheckInstanceProof { // Create params and prover for HammingWeightClaimReduction // (r_cycle and r_addr_bool are extracted from Booleanity opening internally) let hw_params = HammingWeightClaimReductionParams::new( @@ -1257,7 +1228,7 @@ impl<'a, F: JoltField, PCS: StreamingCommitmentScheme, ProofTranscrip #[cfg(feature = "allocative")] print_data_structure_heap_usage("HammingWeightClaimReductionProver", &hw_prover); - // Run Stage 7 batched sumcheck (address rounds only). + // Run Stage 6 batched sumcheck (address rounds only). // Includes HammingWeightClaimReduction plus address phase of advice reduction instances (if needed). 
let mut instances: Vec>> = vec![Box::new(hw_prover)]; @@ -1290,28 +1261,28 @@ impl<'a, F: JoltField, PCS: StreamingCommitmentScheme, ProofTranscrip } #[cfg(feature = "allocative")] - write_boxed_instance_flamegraph_svg(&instances, "stage7_start_flamechart.svg"); - tracing::info!("Stage 7 proving"); + write_boxed_instance_flamegraph_svg(&instances, "stage6_start_flamechart.svg"); + tracing::info!("Stage 6 proving"); let (sumcheck_proof, _) = BatchedSumcheck::prove( instances.iter_mut().map(|v| &mut **v as _).collect(), &mut self.opening_accumulator, &mut self.transcript, ); #[cfg(feature = "allocative")] - write_boxed_instance_flamegraph_svg(&instances, "stage7_end_flamechart.svg"); + write_boxed_instance_flamegraph_svg(&instances, "stage6_end_flamechart.svg"); drop_in_background_thread(instances); sumcheck_proof } - /// Stage 8: Dory batch opening proof. + /// Stage 7: Dory batch opening proof. /// Builds streaming RLC polynomial directly from trace (no witness regeneration needed). #[tracing::instrument(skip_all)] - fn prove_stage8( + fn prove_stage7( &mut self, opening_proof_hints: HashMap, ) -> PCS::Proof { - tracing::info!("Stage 8 proving (Dory batch opening)"); + tracing::info!("Stage 7 proving (Dory batch opening)"); let _guard = DoryGlobals::initialize_context( self.one_hot_params.k_chunk, @@ -1321,19 +1292,19 @@ impl<'a, F: JoltField, PCS: StreamingCommitmentScheme, ProofTranscrip ); // Get the unified opening point from HammingWeightClaimReduction - // This contains (r_address_stage7 || r_cycle_stage6) in big-endian + // This contains (r_address_stage6 || r_cycle_stage5) in big-endian let (opening_point, _) = self.opening_accumulator.get_committed_polynomial_opening( CommittedPolynomial::InstructionRa(0), SumcheckId::HammingWeightClaimReduction, ); let log_k_chunk = self.one_hot_params.log_k_chunk; - let r_address_stage7 = &opening_point.r[..log_k_chunk]; + let r_address_stage6 = &opening_point.r[..log_k_chunk]; // 1. 
Collect all (polynomial, claim) pairs let mut polynomial_claims = Vec::new(); - // Dense polynomials: RamInc and RdInc (from IncClaimReduction in Stage 6) - // These are at r_cycle_stage6 only (length log_T) + // Dense polynomials: RamInc and RdInc (from IncClaimReduction in Stage 5) + // These are at r_cycle_stage5 only (length log_T) let (_ram_inc_point, ram_inc_claim) = self.opening_accumulator.get_committed_polynomial_opening( CommittedPolynomial::RamInc, @@ -1348,16 +1319,16 @@ impl<'a, F: JoltField, PCS: StreamingCommitmentScheme, ProofTranscrip #[cfg(test)] { // Verify that Inc openings are at the same point as r_cycle from HammingWeightClaimReduction - let r_cycle_stage6 = &opening_point.r[log_k_chunk..]; + let r_cycle_stage5 = &opening_point.r[log_k_chunk..]; debug_assert_eq!( _ram_inc_point.r.as_slice(), - r_cycle_stage6, + r_cycle_stage5, "RamInc opening point should match r_cycle from HammingWeightClaimReduction" ); debug_assert_eq!( _rd_inc_point.r.as_slice(), - r_cycle_stage6, + r_cycle_stage5, "RdInc opening point should match r_cycle from HammingWeightClaimReduction" ); } @@ -1365,13 +1336,13 @@ impl<'a, F: JoltField, PCS: StreamingCommitmentScheme, ProofTranscrip // Apply Lagrange factor for dense polys: ∏_{i, ProofTranscrip polynomial_claims.push((CommittedPolynomial::RamRa(i), claim)); } - // Advice polynomials: TrustedAdvice and UntrustedAdvice (from AdviceClaimReduction in Stage 6) + // Advice polynomials: TrustedAdvice and UntrustedAdvice (from AdviceClaimReduction in Stage 5) // These are committed with smaller dimensions, so we apply Lagrange factors to embed // them in the top-left block of the main Dory matrix. 
if let Some((advice_point, advice_claim)) = self @@ -1529,7 +1500,6 @@ where shared: JoltSharedPreprocessing, // max_trace_length: usize, ) -> JoltProverPreprocessing { - use common::constants::ONEHOT_CHUNK_THRESHOLD_LOG_T; let max_T: usize = shared.max_padded_trace_length.next_power_of_two(); let max_log_T = max_T.log_2(); // Use the maximum possible log_k_chunk for generator setup @@ -1852,8 +1822,8 @@ mod tests { DoryGlobals::reset(); // SHA2 guest does not consume advice, but providing both trusted and untrusted advice // should still work correctly through the full pipeline: - // - Trusted: commit in preprocessing-only context, reduce in Stage 6, batch in Stage 8 - // - Untrusted: commit at prove time, reduce in Stage 6, batch in Stage 8 + // - Trusted: commit in preprocessing-only context, reduce in Stage 5, batch in Stage 7 + // - Untrusted: commit at prove time, reduce in Stage 5, batch in Stage 7 let mut program = host::Program::new("sha2-guest"); let (bytecode, init_memory_state, _) = program.decode(); let inputs = postcard::to_stdvec(&[5u8; 32]).unwrap(); @@ -2038,8 +2008,8 @@ mod tests { // Tests that advice opening points are correctly derived from the unified main opening // point using Dory's balanced dimension policy. // - // For a small trace (256 cycles), the advice row coordinates span both Stage 6 (cycle) - // and Stage 7 (address) challenges, verifying the two-phase reduction works correctly. + // For a small trace (256 cycles), the advice row coordinates span both Stage 5 (cycle) + // and Stage 6 (address) challenges, verifying the two-phase reduction works correctly. 
let mut program = host::Program::new("fibonacci-guest"); let inputs = postcard::to_stdvec(&5u32).unwrap(); let trusted_advice = postcard::to_stdvec(&[7u8; 32]).unwrap(); @@ -2394,7 +2364,7 @@ mod tests { prover_preprocessing.generators.to_verifier_setup(), ); - // DoryGlobals is now initialized inside the verifier's verify_stage8 + // DoryGlobals is now initialized inside the verifier's verify_stage7 RV64IMACVerifier::new(&verifier_preprocessing, proof, io_device, None, debug_info) .expect("verifier creation failed") .verify() diff --git a/jolt-core/src/zkvm/r1cs/constraints.rs b/jolt-core/src/zkvm/r1cs/constraints.rs index c000f04570..5c69edf023 100644 --- a/jolt-core/src/zkvm/r1cs/constraints.rs +++ b/jolt-core/src/zkvm/r1cs/constraints.rs @@ -162,7 +162,6 @@ pub enum R1CSConstraintLabel { LeftLookupEqLeftInputOtherwise, RightLookupAdd, RightLookupSub, - RightLookupEqProductIfMul, RightLookupEqRightInputOtherwise, AssertLookupOne, RdWriteEqLookupIfWriteLookupToRd, @@ -301,11 +300,6 @@ pub static R1CS_CONSTRAINTS: [NamedR1CSConstraint; NUM_R1CS_CONSTRAINTS] = [ if { { JoltR1CSInputs::OpFlags(CircuitFlags::SubtractOperands) } } => ( { JoltR1CSInputs::RightLookupOperand } ) == ( { JoltR1CSInputs::LeftInstructionInput } - { JoltR1CSInputs::RightInstructionInput } + { 0x10000000000000000i128 } ) ), - r1cs_eq_conditional!( - label: R1CSConstraintLabel::RightLookupEqProductIfMul, - if { { JoltR1CSInputs::OpFlags(CircuitFlags::MultiplyOperands) } } - => ( { JoltR1CSInputs::RightLookupOperand } ) == ( { JoltR1CSInputs::Product } ) - ), // if !(AddOperands || SubtractOperands || MultiplyOperands || Advice) { // assert!(RightLookupOperand == RightInstructionInput) // } @@ -495,7 +489,6 @@ const fn complement_first_group_labels() -> [R1CSConstraintLabel; NUM_REMAINING_ /// First group: 10 boolean-guarded eq constraints, where Bz is around 64 bits pub const R1CS_CONSTRAINTS_FIRST_GROUP_LABELS: [R1CSConstraintLabel; OUTER_UNIVARIATE_SKIP_DOMAIN_SIZE] = [ - 
R1CSConstraintLabel::RamAddrEqZeroIfNotLoadStore, R1CSConstraintLabel::RamReadEqRamWriteIfLoad, R1CSConstraintLabel::RamReadEqRdWriteIfLoad, R1CSConstraintLabel::Rs2EqRamWriteIfStore, @@ -520,8 +513,8 @@ pub static R1CS_CONSTRAINTS_FIRST_GROUP: [NamedR1CSConstraint; OUTER_UNIVARIATE_ pub static R1CS_CONSTRAINTS_SECOND_GROUP: [NamedR1CSConstraint; NUM_REMAINING_R1CS_CONSTRAINTS] = filter_r1cs_constraints(&R1CS_CONSTRAINTS_SECOND_GROUP_LABELS); -/// Domain sizing for product-virtualization univariate-skip (size-5 window) -pub const NUM_PRODUCT_VIRTUAL: usize = 5; +/// Domain sizing for product-virtualization univariate-skip (size-4 window) +pub const NUM_PRODUCT_VIRTUAL: usize = 4; pub const PRODUCT_VIRTUAL_UNIVARIATE_SKIP_DOMAIN_SIZE: usize = NUM_PRODUCT_VIRTUAL; pub const PRODUCT_VIRTUAL_UNIVARIATE_SKIP_DEGREE: usize = NUM_PRODUCT_VIRTUAL - 1; pub const PRODUCT_VIRTUAL_UNIVARIATE_SKIP_EXTENDED_DOMAIN_SIZE: usize = @@ -534,7 +527,6 @@ pub const PRODUCT_VIRTUAL_FIRST_ROUND_POLY_DEGREE_BOUND: usize = #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, EnumCount, EnumIter)] pub enum ProductConstraintLabel { - Instruction, WriteLookupOutputToRD, WritePCtoRD, ShouldBranch, @@ -565,14 +557,7 @@ pub struct ProductConstraint { /// Canonical list of the product constraints in the same order as /// `PRODUCT_VIRTUAL_TERMS` used by the product virtualization stage. 
pub const PRODUCT_CONSTRAINTS: [ProductConstraint; NUM_PRODUCT_CONSTRAINTS] = [ - // 0: Product = LeftInstructionInput · RightInstructionInput - ProductConstraint { - label: ProductConstraintLabel::Instruction, - left: ProductFactorExpr::Var(VirtualPolynomial::LeftInstructionInput), - right: ProductFactorExpr::Var(VirtualPolynomial::RightInstructionInput), - output: VirtualPolynomial::Product, - }, - // 1: WriteLookupOutputToRD = IsRdNotZero · OpFlags(WriteLookupOutputToRD) + // 0: WriteLookupOutputToRD = IsRdNotZero · OpFlags(WriteLookupOutputToRD) ProductConstraint { label: ProductConstraintLabel::WriteLookupOutputToRD, left: ProductFactorExpr::Var(VirtualPolynomial::InstructionFlags( @@ -583,7 +568,7 @@ pub const PRODUCT_CONSTRAINTS: [ProductConstraint; NUM_PRODUCT_CONSTRAINTS] = [ )), output: VirtualPolynomial::WriteLookupOutputToRD, }, - // 2: WritePCtoRD = IsRdNotZero · OpFlags(Jump) + // 1: WritePCtoRD = IsRdNotZero · OpFlags(Jump) ProductConstraint { label: ProductConstraintLabel::WritePCtoRD, left: ProductFactorExpr::Var(VirtualPolynomial::InstructionFlags( @@ -592,7 +577,7 @@ pub const PRODUCT_CONSTRAINTS: [ProductConstraint; NUM_PRODUCT_CONSTRAINTS] = [ right: ProductFactorExpr::Var(VirtualPolynomial::OpFlags(CircuitFlags::Jump)), output: VirtualPolynomial::WritePCtoRD, }, - // 3: ShouldBranch = LookupOutput · InstructionFlags(Branch) + // 2: ShouldBranch = LookupOutput · InstructionFlags(Branch) ProductConstraint { label: ProductConstraintLabel::ShouldBranch, left: ProductFactorExpr::Var(VirtualPolynomial::LookupOutput), @@ -601,7 +586,7 @@ pub const PRODUCT_CONSTRAINTS: [ProductConstraint; NUM_PRODUCT_CONSTRAINTS] = [ )), output: VirtualPolynomial::ShouldBranch, }, - // 4: ShouldJump = OpFlags(Jump) · (1 − NextIsNoop) + // 3: ShouldJump = OpFlags(Jump) · (1 − NextIsNoop) ProductConstraint { label: ProductConstraintLabel::ShouldJump, left: ProductFactorExpr::Var(VirtualPolynomial::OpFlags(CircuitFlags::Jump)), diff --git 
a/jolt-core/src/zkvm/r1cs/evaluation.rs b/jolt-core/src/zkvm/r1cs/evaluation.rs index 1b204d7a13..0a4e204fa6 100644 --- a/jolt-core/src/zkvm/r1cs/evaluation.rs +++ b/jolt-core/src/zkvm/r1cs/evaluation.rs @@ -128,7 +128,6 @@ pub(crate) const PRODUCT_VIRTUAL_COEFFS_PER_J: [[i32; /// Boolean guards for the first group (univariate-skip base window) #[derive(Clone, Copy, Debug)] pub struct AzFirstGroup { - pub not_load_store: bool, // !(Load || Store) pub load_a: bool, // Load pub load_b: bool, // Load pub store: bool, // Store @@ -150,31 +149,29 @@ impl AzFirstGroup { w: &[F; OUTER_UNIVARIATE_SKIP_DOMAIN_SIZE], acc: &mut Acc5U, ) { - acc.fmadd(&w[0], &self.not_load_store); - acc.fmadd(&w[1], &self.load_a); - acc.fmadd(&w[2], &self.load_b); - acc.fmadd(&w[3], &self.store); - acc.fmadd(&w[4], &self.add_sub_mul); - acc.fmadd(&w[5], &self.not_add_sub_mul); - acc.fmadd(&w[6], &self.assert_flag); - acc.fmadd(&w[7], &self.should_jump); - acc.fmadd(&w[8], &self.virtual_instr_not_last); - acc.fmadd(&w[9], &self.must_start_sequence); + acc.fmadd(&w[0], &self.load_a); + acc.fmadd(&w[1], &self.load_b); + acc.fmadd(&w[2], &self.store); + acc.fmadd(&w[3], &self.add_sub_mul); + acc.fmadd(&w[4], &self.not_add_sub_mul); + acc.fmadd(&w[5], &self.assert_flag); + acc.fmadd(&w[6], &self.should_jump); + acc.fmadd(&w[7], &self.virtual_instr_not_last); + acc.fmadd(&w[8], &self.must_start_sequence); } } /// Magnitudes for the first group (kept small: bool/u64/S64) #[derive(Clone, Copy, Debug)] pub struct BzFirstGroup { - pub ram_addr: u64, // RamAddress - 0 - pub ram_read_minus_ram_write: S64, // RamRead - RamWrite - pub ram_read_minus_rd_write: S64, // RamRead - RdWrite - pub rs2_minus_ram_write: S64, // Rs2 - RamWrite - pub left_lookup: u64, // LeftLookup - 0 - pub left_lookup_minus_left_input: S64, // LeftLookup - LeftInstructionInput - pub lookup_output_minus_one: S64, // LookupOutput - 1 - pub next_unexp_pc_minus_lookup_output: S64, // NextUnexpandedPC - LookupOutput - pub 
next_pc_minus_pc_plus_one: S64, // NextPC - (PC + 1) + pub ram_read_minus_ram_write: S64, // RamRead - RamWrite + pub ram_read_minus_rd_write: S64, // RamRead - RdWrite + pub rs2_minus_ram_write: S64, // Rs2 - RamWrite + pub left_lookup: u64, // LeftLookup - 0 + pub left_lookup_minus_left_input: S64, // LeftLookup - LeftInstructionInput + pub lookup_output_minus_one: S64, // LookupOutput - 1 + pub next_unexp_pc_minus_lookup_output: S64, // NextUnexpandedPC - LookupOutput + pub next_pc_minus_pc_plus_one: S64, // NextPC - (PC + 1) pub one_minus_do_not_update_unexpanded_pc: bool, // 1 - DoNotUpdateUnexpandedPC } @@ -188,16 +185,15 @@ impl BzFirstGroup { w: &[F; OUTER_UNIVARIATE_SKIP_DOMAIN_SIZE], acc: &mut Acc6S, ) { - acc.fmadd(&w[0], &self.ram_addr); - acc.fmadd(&w[1], &self.ram_read_minus_ram_write); - acc.fmadd(&w[2], &self.ram_read_minus_rd_write); - acc.fmadd(&w[3], &self.rs2_minus_ram_write); - acc.fmadd(&w[4], &self.left_lookup); - acc.fmadd(&w[5], &self.left_lookup_minus_left_input); - acc.fmadd(&w[6], &self.lookup_output_minus_one); - acc.fmadd(&w[7], &self.next_unexp_pc_minus_lookup_output); - acc.fmadd(&w[8], &self.next_pc_minus_pc_plus_one); - acc.fmadd(&w[9], &self.one_minus_do_not_update_unexpanded_pc); + acc.fmadd(&w[0], &self.ram_read_minus_ram_write); + acc.fmadd(&w[1], &self.ram_read_minus_rd_write); + acc.fmadd(&w[2], &self.rs2_minus_ram_write); + acc.fmadd(&w[3], &self.left_lookup); + acc.fmadd(&w[4], &self.left_lookup_minus_left_input); + acc.fmadd(&w[5], &self.lookup_output_minus_one); + acc.fmadd(&w[6], &self.next_unexp_pc_minus_lookup_output); + acc.fmadd(&w[7], &self.next_pc_minus_pc_plus_one); + acc.fmadd(&w[8], &self.one_minus_do_not_update_unexpanded_pc); } } @@ -205,9 +201,9 @@ impl BzFirstGroup { #[derive(Clone, Copy, Debug)] pub struct AzSecondGroup { pub load_or_store: bool, // Load || Store + pub not_load_store: bool, // !(Load || Store) pub add: bool, // Add pub sub: bool, // Sub - pub mul: bool, // Mul pub not_add_sub_mul_advice: 
bool, // !(Add || Sub || Mul || Advice) pub write_lookup_to_rd: bool, // write_lookup_output_to_rd_addr (Rd != 0) pub write_pc_to_rd: bool, // write_pc_to_rd_addr (Rd != 0) @@ -226,9 +222,9 @@ impl AzSecondGroup { acc: &mut Acc5U, ) { acc.fmadd(&w[0], &self.load_or_store); - acc.fmadd(&w[1], &self.add); - acc.fmadd(&w[2], &self.sub); - acc.fmadd(&w[3], &self.mul); + acc.fmadd(&w[1], &self.not_load_store); + acc.fmadd(&w[2], &self.add); + acc.fmadd(&w[3], &self.sub); acc.fmadd(&w[4], &self.not_add_sub_mul_advice); acc.fmadd(&w[5], &self.write_lookup_to_rd); acc.fmadd(&w[6], &self.write_pc_to_rd); @@ -241,9 +237,9 @@ impl AzSecondGroup { #[derive(Clone, Copy, Debug)] pub struct BzSecondGroup { pub ram_addr_minus_rs1_plus_imm: i128, // RamAddress - (Rs1 + Imm) + pub ram_addr: u64, // RamAddress - 0 pub right_lookup_minus_add_result: S160, // RightLookup - (Left + Right) pub right_lookup_minus_sub_result: S160, // RightLookup - (Left - Right + 2^64) - pub right_lookup_minus_product: S160, // RightLookup - Product pub right_lookup_minus_right_input: S160, // RightLookup - RightInput pub rd_write_minus_lookup_output: S64, // RdWrite - LookupOutput pub rd_write_minus_pc_plus_const: S64, // RdWrite - (UnexpandedPC + const) @@ -262,9 +258,9 @@ impl BzSecondGroup { acc: &mut Acc7S, ) { acc.fmadd(&w[0], &self.ram_addr_minus_rs1_plus_imm); - acc.fmadd(&w[1], &self.right_lookup_minus_add_result); - acc.fmadd(&w[2], &self.right_lookup_minus_sub_result); - acc.fmadd(&w[3], &self.right_lookup_minus_product); + acc.fmadd(&w[1], &self.ram_addr); + acc.fmadd(&w[2], &self.right_lookup_minus_add_result); + acc.fmadd(&w[3], &self.right_lookup_minus_sub_result); acc.fmadd(&w[4], &self.right_lookup_minus_right_input); acc.fmadd(&w[5], &self.rd_write_minus_lookup_output); acc.fmadd(&w[6], &self.rd_write_minus_pc_plus_const); @@ -301,7 +297,6 @@ impl<'a, F: JoltField> R1CSEval<'a, F> { let inline_seq = flags[CircuitFlags::VirtualInstruction]; AzFirstGroup { - not_load_store: !(ld || st), 
load_a: ld, load_b: ld, store: st, @@ -317,7 +312,6 @@ impl<'a, F: JoltField> R1CSEval<'a, F> { #[inline] pub fn eval_bz_first_group(&self) -> BzFirstGroup { BzFirstGroup { - ram_addr: self.row.ram_addr, ram_read_minus_ram_write: s64_from_diff_u64s( self.row.ram_read_value, self.row.ram_write_value, @@ -353,16 +347,15 @@ impl<'a, F: JoltField> R1CSEval<'a, F> { pub fn az_at_r_first_group(&self, w: &[F; OUTER_UNIVARIATE_SKIP_DOMAIN_SIZE]) -> F { let az = self.eval_az_first_group(); let mut acc: Acc5U = Acc5U::zero(); - acc.fmadd(&w[0], &az.not_load_store); - acc.fmadd(&w[1], &az.load_a); - acc.fmadd(&w[2], &az.load_b); - acc.fmadd(&w[3], &az.store); - acc.fmadd(&w[4], &az.add_sub_mul); - acc.fmadd(&w[5], &az.not_add_sub_mul); - acc.fmadd(&w[6], &az.assert_flag); - acc.fmadd(&w[7], &az.should_jump); - acc.fmadd(&w[8], &az.virtual_instr_not_last); - acc.fmadd(&w[9], &az.must_start_sequence); + acc.fmadd(&w[0], &az.load_a); + acc.fmadd(&w[1], &az.load_b); + acc.fmadd(&w[2], &az.store); + acc.fmadd(&w[3], &az.add_sub_mul); + acc.fmadd(&w[4], &az.not_add_sub_mul); + acc.fmadd(&w[5], &az.assert_flag); + acc.fmadd(&w[6], &az.should_jump); + acc.fmadd(&w[7], &az.virtual_instr_not_last); + acc.fmadd(&w[8], &az.must_start_sequence); acc.barrett_reduce() } @@ -370,16 +363,15 @@ impl<'a, F: JoltField> R1CSEval<'a, F> { pub fn bz_at_r_first_group(&self, w: &[F; OUTER_UNIVARIATE_SKIP_DOMAIN_SIZE]) -> F { let bz = self.eval_bz_first_group(); let mut acc: Acc6S = Acc6S::zero(); - acc.fmadd(&w[0], &bz.ram_addr); - acc.fmadd(&w[1], &bz.ram_read_minus_ram_write); - acc.fmadd(&w[2], &bz.ram_read_minus_rd_write); - acc.fmadd(&w[3], &bz.rs2_minus_ram_write); - acc.fmadd(&w[4], &bz.left_lookup); - acc.fmadd(&w[5], &bz.left_lookup_minus_left_input); - acc.fmadd(&w[6], &bz.lookup_output_minus_one); - acc.fmadd(&w[7], &bz.next_unexp_pc_minus_lookup_output); - acc.fmadd(&w[8], &bz.next_pc_minus_pc_plus_one); - acc.fmadd(&w[9], &bz.one_minus_do_not_update_unexpanded_pc); + acc.fmadd(&w[0], 
&bz.ram_read_minus_ram_write); + acc.fmadd(&w[1], &bz.ram_read_minus_rd_write); + acc.fmadd(&w[2], &bz.rs2_minus_ram_write); + acc.fmadd(&w[3], &bz.left_lookup); + acc.fmadd(&w[4], &bz.left_lookup_minus_left_input); + acc.fmadd(&w[5], &bz.lookup_output_minus_one); + acc.fmadd(&w[6], &bz.next_unexp_pc_minus_lookup_output); + acc.fmadd(&w[7], &bz.next_pc_minus_pc_plus_one); + acc.fmadd(&w[8], &bz.one_minus_do_not_update_unexpanded_pc); acc.barrett_reduce() } @@ -412,73 +404,66 @@ impl<'a, F: JoltField> R1CSEval<'a, F> { let mut bz_eval_s128: S128Sum = S128Sum::zero(); let c0_i32 = coeffs_i32[0]; - if az.not_load_store { + if az.load_a { az_eval_i32 += c0_i32; } else { - bz_eval_s128.fmadd(&c0_i32, &bz.ram_addr); + bz_eval_s128.fmadd(&c0_i32, &bz.ram_read_minus_ram_write); } let c1_i32 = coeffs_i32[1]; - if az.load_a { + if az.load_b { az_eval_i32 += c1_i32; } else { - bz_eval_s128.fmadd(&c1_i32, &bz.ram_read_minus_ram_write); + bz_eval_s128.fmadd(&c1_i32, &bz.ram_read_minus_rd_write); } let c2_i32 = coeffs_i32[2]; - if az.load_b { + if az.store { az_eval_i32 += c2_i32; } else { - bz_eval_s128.fmadd(&c2_i32, &bz.ram_read_minus_rd_write); + bz_eval_s128.fmadd(&c2_i32, &bz.rs2_minus_ram_write); } let c3_i32 = coeffs_i32[3]; - if az.store { + if az.add_sub_mul { az_eval_i32 += c3_i32; } else { - bz_eval_s128.fmadd(&c3_i32, &bz.rs2_minus_ram_write); + bz_eval_s128.fmadd(&c3_i32, &bz.left_lookup); } let c4_i32 = coeffs_i32[4]; - if az.add_sub_mul { + if az.not_add_sub_mul { az_eval_i32 += c4_i32; } else { - bz_eval_s128.fmadd(&c4_i32, &bz.left_lookup); + bz_eval_s128.fmadd(&c4_i32, &bz.left_lookup_minus_left_input); } let c5_i32 = coeffs_i32[5]; - if az.not_add_sub_mul { + if az.assert_flag { az_eval_i32 += c5_i32; } else { - bz_eval_s128.fmadd(&c5_i32, &bz.left_lookup_minus_left_input); + bz_eval_s128.fmadd(&c5_i32, &bz.lookup_output_minus_one); } let c6_i32 = coeffs_i32[6]; - if az.assert_flag { + if az.should_jump { az_eval_i32 += c6_i32; } else { - 
bz_eval_s128.fmadd(&c6_i32, &bz.lookup_output_minus_one); + bz_eval_s128.fmadd(&c6_i32, &bz.next_unexp_pc_minus_lookup_output); } let c7_i32 = coeffs_i32[7]; - if az.should_jump { + if az.virtual_instr_not_last { az_eval_i32 += c7_i32; } else { - bz_eval_s128.fmadd(&c7_i32, &bz.next_unexp_pc_minus_lookup_output); + bz_eval_s128.fmadd(&c7_i32, &bz.next_pc_minus_pc_plus_one); } let c8_i32 = coeffs_i32[8]; - if az.virtual_instr_not_last { - az_eval_i32 += c8_i32; - } else { - bz_eval_s128.fmadd(&c8_i32, &bz.next_pc_minus_pc_plus_one); - } - - let c9_i32 = coeffs_i32[9]; if az.must_start_sequence { - az_eval_i32 += c9_i32; + az_eval_i32 += c8_i32; } else { - bz_eval_s128.fmadd(&c9_i32, &bz.one_minus_do_not_update_unexpanded_pc); + bz_eval_s128.fmadd(&c8_i32, &bz.one_minus_do_not_update_unexpanded_pc); } let az_eval_s64 = S64::from_i64(az_eval_i32 as i64); @@ -503,37 +488,36 @@ impl<'a, F: JoltField> R1CSEval<'a, F> { pub fn assert_constraints_first_group(&self) { let az = self.eval_az_first_group(); let bz = self.eval_bz_first_group(); - self.assert_constraint_first_group(0, az.not_load_store, bz.ram_addr == 0); self.assert_constraint_first_group( - 1, + 0, az.load_a, bz.ram_read_minus_ram_write.to_i128() == 0, ); - self.assert_constraint_first_group(2, az.load_b, bz.ram_read_minus_rd_write.to_i128() == 0); - self.assert_constraint_first_group(3, az.store, bz.rs2_minus_ram_write.to_i128() == 0); - self.assert_constraint_first_group(4, az.add_sub_mul, bz.left_lookup == 0); + self.assert_constraint_first_group(1, az.load_b, bz.ram_read_minus_rd_write.to_i128() == 0); + self.assert_constraint_first_group(2, az.store, bz.rs2_minus_ram_write.to_i128() == 0); + self.assert_constraint_first_group(3, az.add_sub_mul, bz.left_lookup == 0); self.assert_constraint_first_group( - 5, + 4, az.not_add_sub_mul, bz.left_lookup_minus_left_input.to_i128() == 0, ); self.assert_constraint_first_group( - 6, + 5, az.assert_flag, bz.lookup_output_minus_one.to_i128() == 0, ); 
self.assert_constraint_first_group( - 7, + 6, az.should_jump, bz.next_unexp_pc_minus_lookup_output.to_i128() == 0, ); self.assert_constraint_first_group( - 8, + 7, az.virtual_instr_not_last, bz.next_pc_minus_pc_plus_one.to_i128() == 0, ); self.assert_constraint_first_group( - 9, + 8, az.must_start_sequence, !bz.one_minus_do_not_update_unexpanded_pc, ); @@ -554,9 +538,9 @@ impl<'a, F: JoltField> R1CSEval<'a, F> { AzSecondGroup { load_or_store: (flags[CircuitFlags::Load] || flags[CircuitFlags::Store]), + not_load_store: !(flags[CircuitFlags::Load] || flags[CircuitFlags::Store]), add: flags[CircuitFlags::AddOperands], sub: flags[CircuitFlags::SubtractOperands], - mul: flags[CircuitFlags::MultiplyOperands], not_add_sub_mul_advice, write_lookup_to_rd: self.row.write_lookup_output_to_rd_addr, write_pc_to_rd: self.row.write_pc_to_rd_addr, @@ -584,8 +568,6 @@ impl<'a, F: JoltField> R1CSEval<'a, F> { S160::from(self.row.right_lookup) - S160::from(right_add_expected); let right_lookup_minus_sub_result = S160::from(self.row.right_lookup) - S160::from(right_sub_expected); - let right_lookup_minus_product = - S160::from(self.row.right_lookup) - S160::from(self.row.product); let right_lookup_minus_right_input = S160::from(self.row.right_lookup) - S160::from(self.row.right_input); @@ -620,9 +602,9 @@ impl<'a, F: JoltField> R1CSEval<'a, F> { BzSecondGroup { ram_addr_minus_rs1_plus_imm, + ram_addr: self.row.ram_addr, right_lookup_minus_add_result, right_lookup_minus_sub_result, - right_lookup_minus_product, right_lookup_minus_right_input, rd_write_minus_lookup_output, rd_write_minus_pc_plus_const, @@ -632,14 +614,13 @@ impl<'a, F: JoltField> R1CSEval<'a, F> { } #[inline] - pub fn az_at_r_second_group(&self, _w: &[F; OUTER_UNIVARIATE_SKIP_DOMAIN_SIZE]) -> F { - let w = _w; + pub fn az_at_r_second_group(&self, w: &[F; OUTER_UNIVARIATE_SKIP_DOMAIN_SIZE]) -> F { let az = self.eval_az_second_group(); let mut acc: Acc5U = Acc5U::zero(); acc.fmadd(&w[0], &az.load_or_store); - 
acc.fmadd(&w[1], &az.add); - acc.fmadd(&w[2], &az.sub); - acc.fmadd(&w[3], &az.mul); + acc.fmadd(&w[1], &az.not_load_store); + acc.fmadd(&w[2], &az.add); + acc.fmadd(&w[3], &az.sub); acc.fmadd(&w[4], &az.not_add_sub_mul_advice); acc.fmadd(&w[5], &az.write_lookup_to_rd); acc.fmadd(&w[6], &az.write_pc_to_rd); @@ -649,14 +630,13 @@ impl<'a, F: JoltField> R1CSEval<'a, F> { } #[inline] - pub fn bz_at_r_second_group(&self, _w: &[F; OUTER_UNIVARIATE_SKIP_DOMAIN_SIZE]) -> F { - let w = _w; + pub fn bz_at_r_second_group(&self, w: &[F; OUTER_UNIVARIATE_SKIP_DOMAIN_SIZE]) -> F { let bz = self.eval_bz_second_group(); let mut acc: Acc7S = Acc7S::zero(); acc.fmadd(&w[0], &bz.ram_addr_minus_rs1_plus_imm); - acc.fmadd(&w[1], &bz.right_lookup_minus_add_result); - acc.fmadd(&w[2], &bz.right_lookup_minus_sub_result); - acc.fmadd(&w[3], &bz.right_lookup_minus_product); + acc.fmadd(&w[1], &bz.ram_addr); + acc.fmadd(&w[2], &bz.right_lookup_minus_add_result); + acc.fmadd(&w[3], &bz.right_lookup_minus_sub_result); acc.fmadd(&w[4], &bz.right_lookup_minus_right_input); acc.fmadd(&w[5], &bz.rd_write_minus_lookup_output); acc.fmadd(&w[6], &bz.rd_write_minus_pc_plus_const); @@ -701,24 +681,24 @@ impl<'a, F: JoltField> R1CSEval<'a, F> { } let c1 = coeffs_i32[1]; - if az.add { + if az.not_load_store { az_eval_i32 += c1; } else { - bz_eval_s192.fmadd(&c1, &bz.right_lookup_minus_add_result); + bz_eval_s192.fmadd(&c1, &bz.ram_addr); } let c2 = coeffs_i32[2]; - if az.sub { + if az.add { az_eval_i32 += c2; } else { - bz_eval_s192.fmadd(&c2, &bz.right_lookup_minus_sub_result); + bz_eval_s192.fmadd(&c2, &bz.right_lookup_minus_add_result); } let c3 = coeffs_i32[3]; - if az.mul { + if az.sub { az_eval_i32 += c3; } else { - bz_eval_s192.fmadd(&c3, &bz.right_lookup_minus_product); + bz_eval_s192.fmadd(&c3, &bz.right_lookup_minus_sub_result); } let c4 = coeffs_i32[4]; @@ -783,9 +763,9 @@ impl<'a, F: JoltField> R1CSEval<'a, F> { az.load_or_store, bz.ram_addr_minus_rs1_plus_imm == 0i128, ); - 
self.assert_constraint_second_group(1, az.add, bz.right_lookup_minus_add_result.is_zero()); - self.assert_constraint_second_group(2, az.sub, bz.right_lookup_minus_sub_result.is_zero()); - self.assert_constraint_second_group(3, az.mul, bz.right_lookup_minus_product.is_zero()); + self.assert_constraint_second_group(1, az.not_load_store, bz.ram_addr == 0); + self.assert_constraint_second_group(2, az.add, bz.right_lookup_minus_add_result.is_zero()); + self.assert_constraint_second_group(3, az.sub, bz.right_lookup_minus_sub_result.is_zero()); self.assert_constraint_second_group( 4, az.not_add_sub_mul_advice, @@ -837,7 +817,6 @@ impl<'a, F: JoltField> R1CSEval<'a, F> { // If S128 => 7 limbs signed let mut acc_left_input: Acc6U = Acc6U::zero(); let mut acc_right_input: Acc6S = Acc6S::zero(); - let mut acc_product: Acc7S = Acc7S::zero(); let mut acc_wl_left: Acc5U = Acc5U::zero(); let mut acc_wp_left: Acc5U = Acc5U::zero(); let mut acc_sb_right: Acc5U = Acc5U::zero(); @@ -869,7 +848,6 @@ impl<'a, F: JoltField> R1CSEval<'a, F> { acc_left_input.fmadd(&e_in, &row.left_input); acc_right_input.fmadd(&e_in, &row.right_input.to_i128()); - acc_product.fmadd(&e_in, &row.product); acc_wl_left.fmadd(&e_in, &(row.write_lookup_output_to_rd_addr as u64)); acc_wp_left.fmadd(&e_in, &(row.write_pc_to_rd_addr as u64)); @@ -903,8 +881,6 @@ impl<'a, F: JoltField> R1CSEval<'a, F> { eq1_val.mul_unreduced::<9>(acc_left_input.barrett_reduce()); out_unr[JoltR1CSInputs::RightInstructionInput.to_index()] = eq1_val.mul_unreduced::<9>(acc_right_input.barrett_reduce()); - out_unr[JoltR1CSInputs::Product.to_index()] = - eq1_val.mul_unreduced::<9>(acc_product.barrett_reduce()); out_unr[JoltR1CSInputs::WriteLookupOutputToRD.to_index()] = eq1_val.mul_unreduced::<9>(acc_wl_left.barrett_reduce()); out_unr[JoltR1CSInputs::WritePCtoRD.to_index()] = @@ -971,71 +947,61 @@ pub struct ProductVirtualEval; impl ProductVirtualEval { /// Compute both fused left and right factors at r0 weights for a single cycle row. 
- /// Expected order of weights: [Instruction, WriteLookupOutputToRD, WritePCtoRD, ShouldBranch, ShouldJump] + /// Expected order of weights: [WriteLookupOutputToRD, WritePCtoRD, ShouldBranch, ShouldJump] #[inline] pub fn fused_left_right_at_r( row: &ProductCycleInputs, weights_at_r0: &[F], ) -> (F, F) { - // Left: u64/u8/bool let mut left_acc: Acc6U = Acc6U::zero(); - left_acc.fmadd(&weights_at_r0[0], &row.instruction_left_input); + left_acc.fmadd(&weights_at_r0[0], &row.is_rd_not_zero); left_acc.fmadd(&weights_at_r0[1], &row.is_rd_not_zero); - left_acc.fmadd(&weights_at_r0[2], &row.is_rd_not_zero); - left_acc.fmadd(&weights_at_r0[3], &row.should_branch_lookup_output); - left_acc.fmadd(&weights_at_r0[4], &row.jump_flag); + left_acc.fmadd(&weights_at_r0[2], &row.should_branch_lookup_output); + left_acc.fmadd(&weights_at_r0[3], &row.jump_flag); - // Right: i128/bool let mut right_acc: Acc6S = Acc6S::zero(); - right_acc.fmadd(&weights_at_r0[0], &row.instruction_right_input); - right_acc.fmadd(&weights_at_r0[1], &row.write_lookup_output_to_rd_flag); - right_acc.fmadd(&weights_at_r0[2], &row.jump_flag); - right_acc.fmadd(&weights_at_r0[3], &row.should_branch_flag); - right_acc.fmadd(&weights_at_r0[4], &row.not_next_noop); + right_acc.fmadd(&weights_at_r0[0], &row.write_lookup_output_to_rd_flag); + right_acc.fmadd(&weights_at_r0[1], &row.jump_flag); + right_acc.fmadd(&weights_at_r0[2], &row.should_branch_flag); + right_acc.fmadd(&weights_at_r0[3], &row.not_next_noop); (left_acc.barrett_reduce(), right_acc.barrett_reduce()) } /// Compute the fused left·right product at the j-th extended uniskip target for product virtualization. - /// Uses precomputed integer Lagrange coefficients over the size-5 base window and returns an S256 product. + /// Uses precomputed integer Lagrange coefficients over the size-4 base window and returns an S256 product. 
#[inline] pub fn extended_fused_product_at_j(row: &ProductCycleInputs, j: usize) -> S256 { let c: &[i32; PRODUCT_VIRTUAL_UNIVARIATE_SKIP_DOMAIN_SIZE] = &PRODUCT_VIRTUAL_COEFFS_PER_J[j]; - // Weighted components lifted to i128 let mut left_w: [i128; NUM_PRODUCT_VIRTUAL] = [0; NUM_PRODUCT_VIRTUAL]; let mut right_w: [i128; NUM_PRODUCT_VIRTUAL] = [0; NUM_PRODUCT_VIRTUAL]; - // 0: Instruction (LeftInstructionInput × RightInstructionInput) - left_w[0] = (c[0] as i128) * (row.instruction_left_input as i128); - right_w[0] = (c[0] as i128) * row.instruction_right_input; - - // 1: WriteLookupOutputToRD (IsRdNotZero × WriteLookupOutputToRD_flag) - left_w[1] = if row.is_rd_not_zero { c[1] as i128 } else { 0 }; - right_w[1] = if row.write_lookup_output_to_rd_flag { - c[1] as i128 + // 0: WriteLookupOutputToRD (IsRdNotZero × WriteLookupOutputToRD_flag) + left_w[0] = if row.is_rd_not_zero { c[0] as i128 } else { 0 }; + right_w[0] = if row.write_lookup_output_to_rd_flag { + c[0] as i128 } else { 0 }; - // 2: WritePCtoRD (IsRdNotZero × Jump_flag) - left_w[2] = if row.is_rd_not_zero { c[2] as i128 } else { 0 }; - right_w[2] = if row.jump_flag { c[2] as i128 } else { 0 }; + // 1: WritePCtoRD (IsRdNotZero × Jump_flag) + left_w[1] = if row.is_rd_not_zero { c[1] as i128 } else { 0 }; + right_w[1] = if row.jump_flag { c[1] as i128 } else { 0 }; - // 3: ShouldBranch (LookupOutput × Branch_flag) - left_w[3] = (c[3] as i128) * (row.should_branch_lookup_output as i128); - right_w[3] = if row.should_branch_flag { - c[3] as i128 + // 2: ShouldBranch (LookupOutput × Branch_flag) + left_w[2] = (c[2] as i128) * (row.should_branch_lookup_output as i128); + right_w[2] = if row.should_branch_flag { + c[2] as i128 } else { 0 }; - // 4: ShouldJump (Jump_flag × (1 − NextIsNoop)) - left_w[4] = if row.jump_flag { c[4] as i128 } else { 0 }; - right_w[4] = if row.not_next_noop { c[4] as i128 } else { 0 }; + // 3: ShouldJump (Jump_flag × (1 − NextIsNoop)) + left_w[3] = if row.jump_flag { c[3] as i128 } else 
{ 0 }; + right_w[3] = if row.not_next_noop { c[3] as i128 } else { 0 }; - // Fuse in i128, then multiply as S128×S128 → S256 let mut left_sum: i128 = 0; let mut right_sum: i128 = 0; let mut i = 0; @@ -1049,22 +1015,20 @@ impl ProductVirtualEval { left_s128.mul_trunc::<2, 4>(&right_s128) } - /// Compute z(r_cycle) for the 9 de-duplicated factor polynomials used by Product Virtualization. + /// Compute z(r_cycle) for the 7 de-duplicated factor polynomials used by Product Virtualization. /// Order of outputs matches PRODUCT_UNIQUE_FACTOR_VIRTUALS: - /// 0: LeftInstructionInput (u64) - /// 1: RightInstructionInput (i128) - /// 2: IsRdNotZero (bool) - /// 3: OpFlags(WriteLookupOutputToRD) (bool) - /// 4: OpFlags(Jump) (bool) - /// 5: LookupOutput (u64) - /// 6: InstructionFlags(Branch) (bool) - /// 7: NextIsNoop (bool) - /// 8: OpFlags(VirtualInstruction) (bool) — not a product factor, opened for downstream stages + /// 0: IsRdNotZero (bool) + /// 1: OpFlags(WriteLookupOutputToRD) (bool) + /// 2: OpFlags(Jump) (bool) + /// 3: LookupOutput (u64) + /// 4: InstructionFlags(Branch) (bool) + /// 5: NextIsNoop (bool) + /// 6: OpFlags(VirtualInstruction) (bool) — not a product factor, opened for downstream stages #[tracing::instrument(skip_all, name = "ProductVirtualEval::compute_claimed_factors")] pub fn compute_claimed_factors( trace: &[tracer::instruction::Cycle], r_cycle: &OpeningPoint, - ) -> [F; 9] { + ) -> [F; 7] { let m = r_cycle.len() / 2; let (r2, r1) = r_cycle.split_at_r(m); let (eq_one, eq_two) = rayon::join(|| EqPolynomial::evals(r2), || EqPolynomial::evals(r1)); @@ -1076,9 +1040,6 @@ impl ProductVirtualEval { .map(|x1| { let eq1_val = eq_one[x1]; - // Accumulators for 9 outputs - let mut acc_left_u64: Acc6U = Acc6U::zero(); - let mut acc_right_i128: Acc6S = Acc6S::zero(); let mut acc_rd_zero_flag: Acc5U = Acc5U::zero(); let mut acc_wl_flag: Acc5U = Acc5U::zero(); let mut acc_jump_flag: Acc5U = Acc5U::zero(); @@ -1092,42 +1053,29 @@ impl ProductVirtualEval { let 
idx = x1 * eq_two_len + x2; let row = ProductCycleInputs::from_trace::(trace, idx); - // 0: LeftInstructionInput (u64) - acc_left_u64.fmadd(&e_in, &row.instruction_left_input); - // 1: RightInstructionInput (i128) - acc_right_i128.fmadd(&e_in, &row.instruction_right_input); - // 2: IsRdNotZero (bool) acc_rd_zero_flag.fmadd(&e_in, &(row.is_rd_not_zero)); - // 3: OpFlags(WriteLookupOutputToRD) (bool) acc_wl_flag.fmadd(&e_in, &row.write_lookup_output_to_rd_flag); - // 4: OpFlags(Jump) (bool) acc_jump_flag.fmadd(&e_in, &row.jump_flag); - // 5: LookupOutput (u64) acc_lookup_output.fmadd(&e_in, &row.should_branch_lookup_output); - // 6: InstructionFlags(Branch) (bool) acc_branch_flag.fmadd(&e_in, &row.should_branch_flag); - // 7: NextIsNoop (bool) = !not_next_noop acc_next_is_noop.fmadd(&e_in, &(!row.not_next_noop)); - // 8: OpFlags(VirtualInstruction) (bool) acc_virtual_instr_flag.fmadd(&e_in, &row.virtual_instruction_flag); } - let mut out_unr = [F::Unreduced::<9>::zero(); 9]; - out_unr[0] = eq1_val.mul_unreduced::<9>(acc_left_u64.barrett_reduce()); - out_unr[1] = eq1_val.mul_unreduced::<9>(acc_right_i128.barrett_reduce()); - out_unr[2] = eq1_val.mul_unreduced::<9>(acc_rd_zero_flag.barrett_reduce()); - out_unr[3] = eq1_val.mul_unreduced::<9>(acc_wl_flag.barrett_reduce()); - out_unr[4] = eq1_val.mul_unreduced::<9>(acc_jump_flag.barrett_reduce()); - out_unr[5] = eq1_val.mul_unreduced::<9>(acc_lookup_output.barrett_reduce()); - out_unr[6] = eq1_val.mul_unreduced::<9>(acc_branch_flag.barrett_reduce()); - out_unr[7] = eq1_val.mul_unreduced::<9>(acc_next_is_noop.barrett_reduce()); - out_unr[8] = eq1_val.mul_unreduced::<9>(acc_virtual_instr_flag.barrett_reduce()); + let mut out_unr = [F::Unreduced::<9>::zero(); 7]; + out_unr[0] = eq1_val.mul_unreduced::<9>(acc_rd_zero_flag.barrett_reduce()); + out_unr[1] = eq1_val.mul_unreduced::<9>(acc_wl_flag.barrett_reduce()); + out_unr[2] = eq1_val.mul_unreduced::<9>(acc_jump_flag.barrett_reduce()); + out_unr[3] = 
eq1_val.mul_unreduced::<9>(acc_lookup_output.barrett_reduce()); + out_unr[4] = eq1_val.mul_unreduced::<9>(acc_branch_flag.barrett_reduce()); + out_unr[5] = eq1_val.mul_unreduced::<9>(acc_next_is_noop.barrett_reduce()); + out_unr[6] = eq1_val.mul_unreduced::<9>(acc_virtual_instr_flag.barrett_reduce()); out_unr }) .reduce( - || [F::Unreduced::<9>::zero(); 9], + || [F::Unreduced::<9>::zero(); 7], |mut acc, item| { - for i in 0..9 { + for i in 0..7 { acc[i] += item[i]; } acc diff --git a/jolt-core/src/zkvm/r1cs/inputs.rs b/jolt-core/src/zkvm/r1cs/inputs.rs index ef18f3d968..04a7a2213b 100644 --- a/jolt-core/src/zkvm/r1cs/inputs.rs +++ b/jolt-core/src/zkvm/r1cs/inputs.rs @@ -21,7 +21,7 @@ use crate::zkvm::instruction::{ use crate::zkvm::witness::VirtualPolynomial; use crate::field::JoltField; -use ark_ff::biginteger::{S128, S64}; +use ark_ff::biginteger::S64; use common::constants::XLEN; use std::fmt::Debug; use tracer::instruction::Cycle; @@ -44,7 +44,6 @@ pub enum JoltR1CSInputs { RightInstructionInput, // (instruction input) LeftLookupOperand, // (instruction raf) RightLookupOperand, // (instruction raf) - Product, // (product virtualization) WriteLookupOutputToRD, // (product virtualization) WritePCtoRD, // (product virtualization) ShouldBranch, // (product virtualization) @@ -60,10 +59,9 @@ pub enum JoltR1CSInputs { pub const NUM_R1CS_INPUTS: usize = ALL_R1CS_INPUTS.len(); /// This const serves to define a canonical ordering over inputs (and thus indices /// for each input). This is needed for sumcheck. 
-pub const ALL_R1CS_INPUTS: [JoltR1CSInputs; 37] = [ +pub const ALL_R1CS_INPUTS: [JoltR1CSInputs; 36] = [ JoltR1CSInputs::LeftInstructionInput, JoltR1CSInputs::RightInstructionInput, - JoltR1CSInputs::Product, JoltR1CSInputs::WriteLookupOutputToRD, JoltR1CSInputs::WritePCtoRD, JoltR1CSInputs::ShouldBranch, @@ -117,41 +115,40 @@ impl JoltR1CSInputs { match self { JoltR1CSInputs::LeftInstructionInput => 0, JoltR1CSInputs::RightInstructionInput => 1, - JoltR1CSInputs::Product => 2, - JoltR1CSInputs::WriteLookupOutputToRD => 3, - JoltR1CSInputs::WritePCtoRD => 4, - JoltR1CSInputs::ShouldBranch => 5, - JoltR1CSInputs::PC => 6, - JoltR1CSInputs::UnexpandedPC => 7, - JoltR1CSInputs::Imm => 8, - JoltR1CSInputs::RamAddress => 9, - JoltR1CSInputs::Rs1Value => 10, - JoltR1CSInputs::Rs2Value => 11, - JoltR1CSInputs::RdWriteValue => 12, - JoltR1CSInputs::RamReadValue => 13, - JoltR1CSInputs::RamWriteValue => 14, - JoltR1CSInputs::LeftLookupOperand => 15, - JoltR1CSInputs::RightLookupOperand => 16, - JoltR1CSInputs::NextUnexpandedPC => 17, - JoltR1CSInputs::NextPC => 18, - JoltR1CSInputs::NextIsVirtual => 19, - JoltR1CSInputs::NextIsFirstInSequence => 20, - JoltR1CSInputs::LookupOutput => 21, - JoltR1CSInputs::ShouldJump => 22, - JoltR1CSInputs::OpFlags(CircuitFlags::AddOperands) => 23, - JoltR1CSInputs::OpFlags(CircuitFlags::SubtractOperands) => 24, - JoltR1CSInputs::OpFlags(CircuitFlags::MultiplyOperands) => 25, - JoltR1CSInputs::OpFlags(CircuitFlags::Load) => 26, - JoltR1CSInputs::OpFlags(CircuitFlags::Store) => 27, - JoltR1CSInputs::OpFlags(CircuitFlags::Jump) => 28, - JoltR1CSInputs::OpFlags(CircuitFlags::WriteLookupOutputToRD) => 29, - JoltR1CSInputs::OpFlags(CircuitFlags::VirtualInstruction) => 30, - JoltR1CSInputs::OpFlags(CircuitFlags::Assert) => 31, - JoltR1CSInputs::OpFlags(CircuitFlags::DoNotUpdateUnexpandedPC) => 32, - JoltR1CSInputs::OpFlags(CircuitFlags::Advice) => 33, - JoltR1CSInputs::OpFlags(CircuitFlags::IsCompressed) => 34, - 
JoltR1CSInputs::OpFlags(CircuitFlags::IsFirstInSequence) => 35, - JoltR1CSInputs::OpFlags(CircuitFlags::IsLastInSequence) => 36, + JoltR1CSInputs::WriteLookupOutputToRD => 2, + JoltR1CSInputs::WritePCtoRD => 3, + JoltR1CSInputs::ShouldBranch => 4, + JoltR1CSInputs::PC => 5, + JoltR1CSInputs::UnexpandedPC => 6, + JoltR1CSInputs::Imm => 7, + JoltR1CSInputs::RamAddress => 8, + JoltR1CSInputs::Rs1Value => 9, + JoltR1CSInputs::Rs2Value => 10, + JoltR1CSInputs::RdWriteValue => 11, + JoltR1CSInputs::RamReadValue => 12, + JoltR1CSInputs::RamWriteValue => 13, + JoltR1CSInputs::LeftLookupOperand => 14, + JoltR1CSInputs::RightLookupOperand => 15, + JoltR1CSInputs::NextUnexpandedPC => 16, + JoltR1CSInputs::NextPC => 17, + JoltR1CSInputs::NextIsVirtual => 18, + JoltR1CSInputs::NextIsFirstInSequence => 19, + JoltR1CSInputs::LookupOutput => 20, + JoltR1CSInputs::ShouldJump => 21, + JoltR1CSInputs::OpFlags(CircuitFlags::AddOperands) => 22, + JoltR1CSInputs::OpFlags(CircuitFlags::SubtractOperands) => 23, + JoltR1CSInputs::OpFlags(CircuitFlags::MultiplyOperands) => 24, + JoltR1CSInputs::OpFlags(CircuitFlags::Load) => 25, + JoltR1CSInputs::OpFlags(CircuitFlags::Store) => 26, + JoltR1CSInputs::OpFlags(CircuitFlags::Jump) => 27, + JoltR1CSInputs::OpFlags(CircuitFlags::WriteLookupOutputToRD) => 28, + JoltR1CSInputs::OpFlags(CircuitFlags::VirtualInstruction) => 29, + JoltR1CSInputs::OpFlags(CircuitFlags::Assert) => 30, + JoltR1CSInputs::OpFlags(CircuitFlags::DoNotUpdateUnexpandedPC) => 31, + JoltR1CSInputs::OpFlags(CircuitFlags::Advice) => 32, + JoltR1CSInputs::OpFlags(CircuitFlags::IsCompressed) => 33, + JoltR1CSInputs::OpFlags(CircuitFlags::IsFirstInSequence) => 34, + JoltR1CSInputs::OpFlags(CircuitFlags::IsLastInSequence) => 35, } } } @@ -170,7 +167,6 @@ impl From<&JoltR1CSInputs> for VirtualPolynomial { JoltR1CSInputs::RamWriteValue => VirtualPolynomial::RamWriteValue, JoltR1CSInputs::LeftLookupOperand => VirtualPolynomial::LeftLookupOperand, JoltR1CSInputs::RightLookupOperand => 
VirtualPolynomial::RightLookupOperand, - JoltR1CSInputs::Product => VirtualPolynomial::Product, JoltR1CSInputs::NextUnexpandedPC => VirtualPolynomial::NextUnexpandedPC, JoltR1CSInputs::NextPC => VirtualPolynomial::NextPC, JoltR1CSInputs::LookupOutput => VirtualPolynomial::LookupOutput, @@ -205,10 +201,6 @@ pub struct R1CSCycleInputs { /// Right instruction input as signed-magnitude `S64`. /// Typically `Imm` or `Rs2Value` with exact integer semantics. pub right_input: S64, - /// Signed-magnitude `S128` product consistent with the `Product` witness. - /// Computed from `left_input` × `right_input` using the same truncation semantics as the witness. - pub product: S128, - /// Left lookup operand (u64) for the instruction lookup query. /// Matches `LeftLookupOperand` virtual polynomial semantics. pub left_lookup: u64, @@ -290,17 +282,13 @@ impl R1CSCycleInputs { None }; - // Instruction inputs and product let (left_input, right_i128) = LookupQuery::::to_instruction_inputs(cycle); - let left_s64: S64 = S64::from_u64(left_input); let right_mag = right_i128.unsigned_abs(); debug_assert!( right_mag <= u64::MAX as u128, "RightInstructionInput overflow at row {t}: |{right_i128}| > 2^64-1" ); let right_input = S64::from_u64_with_sign(right_mag as u64, right_i128 >= 0); - let right_s128: S128 = S128::from_i128(right_i128); - let product: S128 = left_s64.mul_trunc::<2, 2>(&right_s128); // Lookup operands and output let (left_lookup, right_lookup) = LookupQuery::::to_lookup_operands(cycle); @@ -374,7 +362,6 @@ impl R1CSCycleInputs { Self { left_input, right_input, - product, left_lookup, right_lookup, lookup_output, @@ -416,7 +403,6 @@ impl R1CSCycleInputs { JoltR1CSInputs::RightInstructionInput => self.right_input.to_i128(), JoltR1CSInputs::LeftLookupOperand => self.left_lookup as i128, JoltR1CSInputs::RightLookupOperand => self.right_lookup as i128, - JoltR1CSInputs::Product => self.product.to_i128().expect("product too large for i128"), JoltR1CSInputs::WriteLookupOutputToRD 
=> self.write_lookup_output_to_rd_addr as i128, JoltR1CSInputs::WritePCtoRD => self.write_pc_to_rd_addr as i128, JoltR1CSInputs::ShouldBranch => self.should_branch as i128, @@ -434,18 +420,14 @@ impl R1CSCycleInputs { /// Canonical, de-duplicated list of product-virtual factor polynomials used by /// the Product Virtualization stage. /// Order: -/// 0: LeftInstructionInput -/// 1: RightInstructionInput -/// 2: InstructionFlags(IsRdNotZero) -/// 3: OpFlags(WriteLookupOutputToRD) -/// 4: OpFlags(Jump) -/// 5: LookupOutput -/// 6: InstructionFlags(Branch) -/// 7: NextIsNoop -/// 8: OpFlags(VirtualInstruction) — not a product factor, but opened here for downstream stages -pub const PRODUCT_UNIQUE_FACTOR_VIRTUALS: [VirtualPolynomial; 9] = [ - VirtualPolynomial::LeftInstructionInput, - VirtualPolynomial::RightInstructionInput, +/// 0: InstructionFlags(IsRdNotZero) +/// 1: OpFlags(WriteLookupOutputToRD) +/// 2: OpFlags(Jump) +/// 3: LookupOutput +/// 4: InstructionFlags(Branch) +/// 5: NextIsNoop +/// 6: OpFlags(VirtualInstruction) — not a product factor, but opened here for downstream stages +pub const PRODUCT_UNIQUE_FACTOR_VIRTUALS: [VirtualPolynomial; 7] = [ VirtualPolynomial::InstructionFlags(InstructionFlags::IsRdNotZero), VirtualPolynomial::OpFlags(CircuitFlags::WriteLookupOutputToRD), VirtualPolynomial::OpFlags(CircuitFlags::Jump), @@ -455,21 +437,13 @@ pub const PRODUCT_UNIQUE_FACTOR_VIRTUALS: [VirtualPolynomial; 9] = [ VirtualPolynomial::OpFlags(CircuitFlags::VirtualInstruction), ]; -/// Minimal, unified view for the Product-virtualization round: the 6 product pairs -/// (left, right) materialized from the trace for a single cycle. -/// Total size is small; we keep primitive representations that match witness generation. +/// Minimal, unified view for the Product-virtualization round: the product pairs +/// materialized from the trace for a single cycle. 
#[derive(Clone, Debug)] pub struct ProductCycleInputs { - // 16-byte aligned - /// Instruction: LeftInstructionInput × RightInstructionInput (right input as i128) - pub instruction_right_input: i128, - - // 8-byte aligned - pub instruction_left_input: u64, /// ShouldBranch: LookupOutput × Branch_flag (left side) pub should_branch_lookup_output: u64, - // 1-byte fields /// WriteLookupOutputToRD right flag (boolean) pub write_lookup_output_to_rd_flag: bool, /// Jump flag used by both WritePCtoRD (right) and ShouldJump (left) @@ -497,37 +471,24 @@ impl ProductCycleInputs { let flags_view = instr.circuit_flags(); let instruction_flags = instr.instruction_flags(); - // Instruction inputs - let (left_input, right_input) = LookupQuery::::to_instruction_inputs(cycle); - - // Lookup output let lookup_output = LookupQuery::::to_lookup_output(cycle); - // Jump and Branch flags let jump_flag = flags_view[CircuitFlags::Jump]; let branch_flag = instruction_flags[InstructionFlags::Branch]; - // Next-is-noop and its complement (1 - NextIsNoop) let not_next_noop = { if t + 1 < len { !trace[t + 1].instruction().instruction_flags()[InstructionFlags::IsNoop] } else { - // Needs final not_next_noop to be false for the shift sumcheck - // (since EqPlusOne does not do overflow) false } }; let is_rd_not_zero = instruction_flags[InstructionFlags::IsRdNotZero]; - - // WriteLookupOutputToRD flag let write_lookup_output_to_rd_flag = flags_view[CircuitFlags::WriteLookupOutputToRD]; - let virtual_instruction_flag = flags_view[CircuitFlags::VirtualInstruction]; Self { - instruction_left_input: left_input, - instruction_right_input: right_input, write_lookup_output_to_rd_flag, should_branch_lookup_output: lookup_output, should_branch_flag: branch_flag, @@ -601,7 +562,6 @@ mod tests { } (JoltR1CSInputs::LeftLookupOperand, JoltR1CSInputs::LeftLookupOperand) => true, (JoltR1CSInputs::RightLookupOperand, JoltR1CSInputs::RightLookupOperand) => true, - (JoltR1CSInputs::Product, JoltR1CSInputs::Product) 
=> true, (JoltR1CSInputs::WriteLookupOutputToRD, JoltR1CSInputs::WriteLookupOutputToRD) => { true } diff --git a/jolt-core/src/zkvm/ram/ra_virtual.rs b/jolt-core/src/zkvm/ram/ra_virtual.rs index 646d9eb484..ff51af0d59 100644 --- a/jolt-core/src/zkvm/ram/ra_virtual.rs +++ b/jolt-core/src/zkvm/ram/ra_virtual.rs @@ -5,10 +5,10 @@ //! //! ## Input //! -//! From RA reduction sumcheck (Stage 5), we receive a single claim: +//! From RA reduction sumcheck (Stage 4), we receive a single claim: //! //! ```text -//! ra(r_address_stage2, r_cycle_stage5) = ra_claim_stage5 +//! ra(r_address_stage2, r_cycle_stage4) = ra_claim_stage4 //! ``` //! //! ## Identity @@ -16,7 +16,7 @@ //! We prove the following sumcheck identity over `c ∈ {0,1}^{log_T}`: //! //! ```text -//! Σ_c eq(r_cycle_stage5, c) · Π_{i=0}^{d-1} ra_i(r_address_stage2_i, c) = ra_claim_stage5 +//! Σ_c eq(r_cycle_stage4, c) · Π_{i=0}^{d-1} ra_i(r_address_stage2_i, c) = ra_claim_stage4 //! ``` //! //! where: diff --git a/jolt-core/src/zkvm/spartan/instruction_input.rs b/jolt-core/src/zkvm/spartan/instruction_input.rs index fa0a81d9d0..00730c7a59 100644 --- a/jolt-core/src/zkvm/spartan/instruction_input.rs +++ b/jolt-core/src/zkvm/spartan/instruction_input.rs @@ -36,7 +36,7 @@ const DEGREE_BOUND: usize = 3; #[derive(Allocative, Clone)] pub struct InstructionInputParams { - pub r_cycle_stage_2: OpeningPoint, + pub r_spartan: OpeningPoint, pub gamma: F, } @@ -45,15 +45,12 @@ impl InstructionInputParams { opening_accumulator: &dyn OpeningAccumulator, transcript: &mut impl Transcript, ) -> Self { - let (r_cycle_stage_2, _) = opening_accumulator.get_virtual_polynomial_opening( + let (r_spartan, _) = opening_accumulator.get_virtual_polynomial_opening( VirtualPolynomial::LeftInstructionInput, - SumcheckId::SpartanProductVirtualization, + SumcheckId::SpartanOuter, ); let gamma = transcript.challenge_scalar(); - Self { - r_cycle_stage_2, - gamma, - } + Self { r_spartan, gamma } } } @@ -63,40 +60,20 @@ impl SumcheckInstanceParams 
for InstructionInputParams { } fn num_rounds(&self) -> usize { - self.r_cycle_stage_2.len() + self.r_spartan.len() } fn input_claim(&self, accumulator: &dyn OpeningAccumulator) -> F { - let (r_left_claim_instruction, left_claim_instruction) = accumulator - .get_virtual_polynomial_opening( - VirtualPolynomial::LeftInstructionInput, - SumcheckId::InstructionClaimReduction, - ); - let (r_right_claim_instruction, right_claim_instruction) = accumulator - .get_virtual_polynomial_opening( - VirtualPolynomial::RightInstructionInput, - SumcheckId::InstructionClaimReduction, - ); - - let (r_left_claim_stage_2, left_claim_stage_2) = accumulator - .get_virtual_polynomial_opening( - VirtualPolynomial::LeftInstructionInput, - SumcheckId::SpartanProductVirtualization, - ); - let (r_right_claim_stage_2, right_claim_stage_2) = accumulator - .get_virtual_polynomial_opening( - VirtualPolynomial::RightInstructionInput, - SumcheckId::SpartanProductVirtualization, - ); - - // Soundness: InstructionClaimReduction and SpartanProductVirtualization must produce - // the same claims at the same opening points. 
- assert_eq!(r_left_claim_instruction, r_left_claim_stage_2); - assert_eq!(left_claim_instruction, left_claim_stage_2); - assert_eq!(r_right_claim_instruction, r_right_claim_stage_2); - assert_eq!(right_claim_instruction, right_claim_stage_2); + let (_, left_claim) = accumulator.get_virtual_polynomial_opening( + VirtualPolynomial::LeftInstructionInput, + SumcheckId::SpartanOuter, + ); + let (_, right_claim) = accumulator.get_virtual_polynomial_opening( + VirtualPolynomial::RightInstructionInput, + SumcheckId::SpartanOuter, + ); - right_claim_stage_2 + self.gamma * left_claim_stage_2 + right_claim + self.gamma * left_claim } fn normalize_opening_point( @@ -118,7 +95,7 @@ pub struct InstructionInputSumcheckProver { rs2_value_poly: MultilinearPolynomial, imm_poly: MultilinearPolynomial, unexpanded_pc_poly: MultilinearPolynomial, - eq_r_cycle_stage_2: GruenSplitEqPolynomial, + eq_r_spartan: GruenSplitEqPolynomial, pub params: InstructionInputParams, } @@ -176,8 +153,8 @@ impl InstructionInputSumcheckProver { }, ); - let eq_r_cycle_stage_2 = - GruenSplitEqPolynomial::new(&params.r_cycle_stage_2.r, BindingOrder::LowToHigh); + let eq_r_spartan = + GruenSplitEqPolynomial::new(&params.r_spartan.r, BindingOrder::LowToHigh); Self { left_is_rs1_poly: left_is_rs1_poly.into(), @@ -188,7 +165,7 @@ impl InstructionInputSumcheckProver { rs2_value_poly: rs2_value_poly.into(), imm_poly: imm_poly.into(), unexpanded_pc_poly: unexpanded_pc_poly.into(), - eq_r_cycle_stage_2, + eq_r_spartan, params, } } @@ -204,7 +181,7 @@ impl SumcheckInstanceProver #[tracing::instrument(skip_all, name = "InstructionInputSumcheckProver::compute_message")] fn compute_message(&mut self, _round: usize, previous_claim: F) -> UniPoly { let [eval_at_0, eval_at_inf] = self - .eq_r_cycle_stage_2 + .eq_r_spartan .par_fold_out_in( || [F::Unreduced::<9>::zero(); 2], |inner, j, _x_in, e_in| { @@ -276,7 +253,7 @@ impl SumcheckInstanceProver ) .map(|x| F::from_montgomery_reduce::<9>(x)); - self.eq_r_cycle_stage_2 + 
self.eq_r_spartan .gruen_poly_deg_3(eval_at_0, eval_at_inf, previous_claim) } @@ -291,7 +268,7 @@ impl SumcheckInstanceProver rs2_value_poly, imm_poly, unexpanded_pc_poly, - eq_r_cycle_stage_2, + eq_r_spartan, params: _, } = self; left_is_rs1_poly.bind_parallel(r_j, BindingOrder::LowToHigh); @@ -302,7 +279,7 @@ impl SumcheckInstanceProver rs2_value_poly.bind_parallel(r_j, BindingOrder::LowToHigh); imm_poly.bind_parallel(r_j, BindingOrder::LowToHigh); unexpanded_pc_poly.bind_parallel(r_j, BindingOrder::LowToHigh); - eq_r_cycle_stage_2.bind(r_j); + eq_r_spartan.bind(r_j); } fn cache_openings( @@ -390,7 +367,7 @@ impl SumcheckInstanceProver /// ``` /// /// Note: -/// - `r_cycle_stage_2` is the randomness from instruction product sumcheck (stage 2). +/// - `r_spartan` is the randomness from Spartan outer sumcheck (stage 1). pub struct InstructionInputSumcheckVerifier { params: InstructionInputParams, } @@ -436,7 +413,7 @@ impl SumcheckInstanceVerifier ) -> F { let r = self.params.normalize_opening_point(sumcheck_challenges); - let eq_eval_at_r_cycle_stage_2 = EqPolynomial::mle_endian(&r, &self.params.r_cycle_stage_2); + let eq_eval_at_r_spartan = EqPolynomial::mle_endian(&r, &self.params.r_spartan); let (_, rs1_value_eval) = accumulator.get_virtual_polynomial_opening( VirtualPolynomial::Rs1Value, @@ -476,7 +453,7 @@ impl SumcheckInstanceVerifier let right_instruction_input = right_is_rs2_eval * rs2_value_eval + right_is_imm_eval * imm_eval; - let result = eq_eval_at_r_cycle_stage_2 + let result = eq_eval_at_r_spartan * (right_instruction_input + self.params.gamma * left_instruction_input); #[cfg(test)] @@ -573,24 +550,24 @@ impl SumcheckFrontend for InstructionInputSumcheckVerifier { let left_instruction_input_eval = left_is_rs1 * rs1_value + left_is_pc * unexpanded_pc; let right_instruction_input_eval = right_is_rs2 * rs2_value + right_is_imm * imm; - let eq_r_stage2 = VerifierEvaluablePolynomial::Eq(CachedPointRef { + let eq_r_spartan = 
VerifierEvaluablePolynomial::Eq(CachedPointRef { opening: PolynomialId::Virtual(VirtualPolynomial::LeftInstructionInput), - sumcheck: SumcheckId::SpartanProductVirtualization, + sumcheck: SumcheckId::SpartanOuter, part: ChallengePart::Cycle, }); InputOutputClaims { claims: vec![ Claim { - input_sumcheck_id: SumcheckId::SpartanProductVirtualization, + input_sumcheck_id: SumcheckId::SpartanOuter, input_claim_expr: right_instruction_input, - batching_poly: eq_r_stage2, + batching_poly: eq_r_spartan, expected_output_claim_expr: right_instruction_input_eval, }, Claim { - input_sumcheck_id: SumcheckId::SpartanProductVirtualization, + input_sumcheck_id: SumcheckId::SpartanOuter, input_claim_expr: left_instruction_input, - batching_poly: eq_r_stage2, + batching_poly: eq_r_spartan, expected_output_claim_expr: left_instruction_input_eval, }, ], diff --git a/jolt-core/src/zkvm/spartan/outer.rs b/jolt-core/src/zkvm/spartan/outer.rs index dda52684ff..8657f7213d 100644 --- a/jolt-core/src/zkvm/spartan/outer.rs +++ b/jolt-core/src/zkvm/spartan/outer.rs @@ -2,6 +2,7 @@ use std::marker::PhantomData; use std::sync::Arc; use allocative::Allocative; +use ark_ff::biginteger::{S128, S160, S192}; use ark_std::Zero; use rayon::prelude::*; use tracer::instruction::Cycle; @@ -11,7 +12,7 @@ use crate::field::{FMAdd, JoltField, MontgomeryReduce}; use crate::poly::dense_mlpoly::DensePolynomial; use crate::poly::eq_poly::EqPolynomial; use crate::poly::lagrange_poly::LagrangePolynomial; -use crate::poly::multilinear_polynomial::{BindingOrder, PolynomialBinding}; +use crate::poly::multilinear_polynomial::{BindingOrder, MultilinearPolynomial, PolynomialBinding}; use crate::poly::multiquadratic_poly::MultiquadraticPolynomial; use crate::poly::opening_proof::{ OpeningAccumulator, OpeningPoint, ProverOpeningAccumulator, SumcheckId, @@ -28,11 +29,12 @@ use crate::subprotocols::univariate_skip::build_uniskip_first_round_poly; use crate::transcripts::Transcript; use crate::utils::accumulation::{Acc5U, 
Acc6S, Acc7S, Acc8S}; use crate::utils::expanding_table::ExpandingTable; -use crate::utils::math::Math; +use crate::utils::math::{s64_from_diff_u64s, Math}; #[cfg(feature = "allocative")] use crate::utils::profiling::print_data_structure_heap_usage; use crate::utils::thread::unsafe_allocate_zero_vec; use crate::zkvm::bytecode::BytecodePreprocessing; +use crate::zkvm::instruction::CircuitFlags; use crate::zkvm::r1cs::constraints::OUTER_FIRST_ROUND_POLY_DEGREE_BOUND; use crate::zkvm::r1cs::key::UniformSpartanKey; use crate::zkvm::r1cs::{ @@ -41,44 +43,52 @@ use crate::zkvm::r1cs::{ OUTER_UNIVARIATE_SKIP_DOMAIN_SIZE, OUTER_UNIVARIATE_SKIP_EXTENDED_DOMAIN_SIZE, }, evaluation::R1CSEval, - inputs::{R1CSCycleInputs, ALL_R1CS_INPUTS}, + inputs::{JoltR1CSInputs, R1CSCycleInputs, ALL_R1CS_INPUTS}, }; use crate::zkvm::witness::VirtualPolynomial; /// Degree bound of the sumcheck round polynomials for [`OuterRemainingSumcheckVerifier`]. -const OUTER_REMAINING_DEGREE_BOUND: usize = 3; +/// Degree 4 because the product constraint `is_mul*(rl - li*ri)` is degree 3 inside eq. +const OUTER_REMAINING_DEGREE_BOUND: usize = 4; // this represents the index position in multi-quadratic poly array // This should actually be d where degree is the degree of the streaming data structure // For example : MultiQuadratic has d=2; for cubic this would be 3 etc. 
const INFINITY: usize = 2; // Spartan Outer sumcheck -// (with univariate-skip first round on Z, and no Cz term given all eq conditional constraints) // -// We define a univariate in Z first-round polynomial -// s1(Y) := L(τ_high, Y) · Σ_{x_out ∈ {0,1}^{m_out}} Σ_{x_in ∈ {0,1}^{m_in}} -// E_out(r_out, x_out) · E_in(r_in, x_in) · -// [ Az(x_out, x_in, Y) · Bz(x_out, x_in, Y) ], -// where L(τ_high, Y) is the Lagrange basis polynomial over the univariate-skip -// base domain evaluated at τ_high, and Az(·,·,Y), Bz(·,·,Y) are the -// per-row univariate polynomials in Y induced by the R1CS row (split into two -// internal groups in code, but algebraically composing to Az·Bz at Y). -// The prover sends s1(Y) via univariate-skip by evaluating t1(Y) := Σ Σ E_out·E_in·(Az·Bz) -// on an extended grid Y ∈ {−D..D} outside the base window, interpolating t1, -// multiplying by L(τ_high, Y) to obtain s1, and the verifier samples r0. +// Proves the combined R1CS + product constraint: // -// Subsequent outer rounds bind the cycle variables r_tail = (r1, r2, …) using -// a streaming first cycle-bit round followed by linear-time rounds: -// • Streaming round (after r0): compute -// t(0) = Σ_{x_out} E_out · Σ_{x_in} E_in · (Az(0)·Bz(0)) -// t(∞) = Σ_{x_out} E_out · Σ_{x_in} E_in · ((Az(1)−Az(0))·(Bz(1)−Bz(0))) -// send a cubic built from these endpoints, and bind cached coefficients by r1. -// • Remaining rounds: reuse bound coefficients to compute the same endpoints -// in linear time for each subsequent bit and bind by r_i. +// Σ_x eq(τ, x) · [ Az(x)·Bz(x) + is_mul(x)·(rl(x) − li(x)·ri(x)) ] = 0 // -// Final check (verifier): with r = [r0 || r_tail] and outer binding order from -// the top, evaluate Eq_τ(τ, r) and verify -// L(τ_high, r_high) · Eq_τ(τ, r) · (Az(r) · Bz(r)). +// where Az·Bz encodes the R1CS constraints (no Cz term — all constraints are +// eq-conditional) and the product term enforces the multiply-instruction +// identity rl = li·ri on rows where is_mul = 1. 
+// +// The sumcheck proceeds in three phases: +// +// 1. Univariate-skip first round (on constraint index Z): +// Only Az·Bz participates. The product term is independent of Z. +// s1(Z) = L(τ_high, Z) · Σ_x eq(τ_low, x) · Σ_y eq(r_group, y) · Az(x,y,Z)·Bz(x,y,Z) +// The prover evaluates t1(Z) = s1(Z)/L(τ_high,Z) on the extended grid, +// sends s1 via univariate-skip, and the verifier samples r0. +// The scaling factor L(τ_high, r0) is baked into the split-eq polynomial +// via `new_with_scaling`, so it multiplies all subsequent terms uniformly. +// +// 2. Group-variable round (first cycle-variable round, binding the constraint +// group selector): +// Only Az·Bz participates. The product term is constant in the group +// variable and skipped (degree-3 round polynomial via gruen_poly_deg_3). +// +// 3. Remaining cycle-variable rounds: +// Both Az·Bz and the product term participate. Each round produces a +// degree-4 round polynomial: cubic Az·Bz via gruen_poly_deg_3 plus the +// cubic product term via gruen_poly_deg_4, combined additively. +// The product term's claim starts at 0 (for a valid witness) and is +// tracked separately in prev_claim_product. +// +// Final check (verifier): +// L(τ_high, r0) · eq(τ_low, r_cycle) · [ Az(r)·Bz(r) + product(r) ] #[derive(Allocative, Clone)] pub struct OuterUniSkipParams { @@ -418,10 +428,18 @@ impl SumcheckInstanceVerifier // Randomness used to bind the rows of R1CS matrices A,B. let rx_constr = &[sumcheck_challenges[0], self.params.r0]; // Compute sum_y A(rx_constr, y)*z(y) * sum_y B(rx_constr, y)*z(y). 
- let inner_sum_prod = self + let azbz = self .key .evaluate_inner_sum_product_at_point(rx_constr, r1cs_input_evals); + // Product constraint: is_mul * (right_lookup - left_input * right_input) + let is_mul_eval = + r1cs_input_evals[JoltR1CSInputs::OpFlags(CircuitFlags::MultiplyOperands).to_index()]; + let rl_eval = r1cs_input_evals[JoltR1CSInputs::RightLookupOperand.to_index()]; + let li_eval = r1cs_input_evals[JoltR1CSInputs::LeftInstructionInput.to_index()]; + let ri_eval = r1cs_input_evals[JoltR1CSInputs::RightInstructionInput.to_index()]; + let product_term = is_mul_eval * (rl_eval - li_eval * ri_eval); + let tau = &self.params.tau; let tau_high = &tau[tau.len() - 1]; let tau_high_bound_r0 = LagrangePolynomial::::lagrange_kernel::< @@ -432,7 +450,7 @@ impl SumcheckInstanceVerifier let r_tail_reversed: Vec = sumcheck_challenges.iter().rev().copied().collect(); let tau_bound_r_tail_reversed = EqPolynomial::mle(tau_low, &r_tail_reversed); - tau_high_bound_r0 * tau_bound_r_tail_reversed * inner_sum_prod + tau_high_bound_r0 * tau_bound_r_tail_reversed * (azbz + product_term) } fn cache_openings( @@ -508,6 +526,7 @@ pub struct OuterSharedState { #[allocative(skip)] lagrange_evals_r0: [F; OUTER_UNIVARIATE_SKIP_DOMAIN_SIZE], pub params: OuterStreamingProverParams, + prev_claim_product: F, } impl OuterSharedState { @@ -552,6 +571,7 @@ impl OuterSharedState { r_grid, params: outer_params, lagrange_evals_r0: lagrange_evals_r, + prev_claim_product: F::zero(), } } @@ -801,15 +821,16 @@ impl StreamingSumcheckWindow for OuterStreamingWindow { #[tracing::instrument(skip_all, name = "OuterStreamingWindow::compute_message")] fn compute_message( - &self, + &mut self, shared: &Self::Shared, window_size: usize, previous_claim: F, ) -> UniPoly { + let prev_claim_azbz = previous_claim - shared.prev_claim_product; let (t_prime_0, t_prime_inf) = shared.compute_t_evals(window_size); shared .split_eq_poly - .gruen_poly_deg_3(t_prime_0, t_prime_inf, previous_claim) + 
.gruen_poly_deg_3(t_prime_0, t_prime_inf, prev_claim_azbz) } #[tracing::instrument(skip_all, name = "OuterStreamingWindow::ingest_challenge")] @@ -828,9 +849,220 @@ impl StreamingSumcheckWindow for OuterStreamingWindow { pub struct OuterLinearStage { az: DensePolynomial, bz: DensePolynomial, + product_is_mul: Option>, + product_rl: Option>, + product_li: Option>, + product_ri: Option>, + prev_round_poly_product: Option>, } impl OuterLinearStage { + /// Fused materialization + first-round product eval computation. + /// + /// In a single pass over the trace, simultaneously: + /// 1. Materializes compact `MultilinearPolynomial`s (bool/u128/u64/i128) for + /// subsequent rounds + /// 2. Computes the first product round evals `(t2, t_inf)` using small-scalar + /// arithmetic with deferred reduction. `t0 = 0` is known since + /// `prev_claim_product = 0` on the first round (valid witness). + #[tracing::instrument( + skip_all, + name = "OuterLinearStage::fused_materialise_and_first_product_evals" + )] + fn fused_materialise_and_first_product_evals( + shared: &OuterSharedState, + ) -> ( + F, + F, + MultilinearPolynomial, + MultilinearPolynomial, + MultilinearPolynomial, + MultilinearPolynomial, + ) { + let num_cycles = shared.trace.len(); + let mut is_mul_vec: Vec = vec![false; num_cycles]; + let mut rl_vec: Vec = unsafe_allocate_zero_vec(num_cycles); + let mut li_vec: Vec = unsafe_allocate_zero_vec(num_cycles); + let mut ri_vec: Vec = unsafe_allocate_zero_vec(num_cycles); + + let e_out = shared.split_eq_poly.E_out_current(); + let e_in = shared.split_eq_poly.E_in_current(); + let num_x_in = e_in.len(); + let chunk_size = 2 * num_x_in; + + let (t2_unr, tinf_unr) = is_mul_vec + .par_chunks_mut(chunk_size) + .zip(rl_vec.par_chunks_mut(chunk_size)) + .zip(li_vec.par_chunks_mut(chunk_size)) + .zip(ri_vec.par_chunks_mut(chunk_size)) + .enumerate() + .fold( + || (F::Unreduced::<9>::zero(), F::Unreduced::<9>::zero()), + |(mut acc2, mut acc_inf), (x_out, (((m_chunk, r_chunk), 
l_chunk), i_chunk))| { + let mut inner2 = Acc7S::::zero(); + let mut inner_inf = Acc7S::::zero(); + + for x_in in 0..num_x_in { + let g = x_out * num_x_in + x_in; + let idx_lo = 2 * g; + let idx_hi = idx_lo + 1; + + let row_lo = R1CSCycleInputs::from_trace::( + &shared.bytecode_preprocessing, + &shared.trace, + idx_lo, + ); + let row_hi = R1CSCycleInputs::from_trace::( + &shared.bytecode_preprocessing, + &shared.trace, + idx_hi, + ); + + let m0 = row_lo.flags[CircuitFlags::MultiplyOperands]; + let m1 = row_hi.flags[CircuitFlags::MultiplyOperands]; + + let off_lo = 2 * x_in; + let off_hi = off_lo + 1; + m_chunk[off_lo] = m0; + m_chunk[off_hi] = m1; + r_chunk[off_lo] = row_lo.right_lookup; + r_chunk[off_hi] = row_hi.right_lookup; + l_chunk[off_lo] = row_lo.left_input; + l_chunk[off_hi] = row_hi.left_input; + i_chunk[off_lo] = row_lo.right_input.to_i128(); + i_chunk[off_hi] = row_hi.right_input.to_i128(); + + if !m0 && !m1 { + continue; + } + + let dm: i8 = (m1 as i8) - (m0 as i8); + let dl_i128 = + (row_hi.left_input as i128) - (row_lo.left_input as i128); + let di_i128 = + row_hi.right_input.to_i128() - row_lo.right_input.to_i128(); + + if dm != 0 { + let dl_s64 = + s64_from_diff_u64s(row_hi.left_input, row_lo.left_input); + let di_s128 = S128::from(di_i128); + let dl_di: S192 = dl_s64.mul_trunc::<2, 3>(&di_s128); + let p_inf = if dm == 1 { dl_di.neg() } else { dl_di }; + inner_inf.fmadd(&e_in[x_in], &p_inf); + } + + let m2_val: i8 = (m0 as i8) + 2 * dm; + if m2_val != 0 { + let l2 = (row_lo.left_input as i128) + 2 * dl_i128; + let i2 = row_lo.right_input.to_i128() + 2 * di_i128; + let l2_s160 = S160::from(l2); + let i2_s160 = S160::from(i2); + let li_prod = &l2_s160 * &i2_s160; + let r2 = S160::from_sum_u128( + row_hi.right_lookup, + row_hi.right_lookup, + ) - S160::from(row_lo.right_lookup); + let val = r2 - li_prod; + let p2 = match m2_val { + 1 => val, + 2 => val + val, + _ => -val, + }; + inner2.fmadd(&e_in[x_in], &p2); + } + } + + let e_out_val = e_out[x_out]; 
+ let red2: F = inner2.barrett_reduce(); + let red_inf: F = inner_inf.barrett_reduce(); + acc2 += e_out_val.mul_unreduced::<9>(red2); + acc_inf += e_out_val.mul_unreduced::<9>(red_inf); + + (acc2, acc_inf) + }, + ) + .reduce( + || (F::Unreduced::<9>::zero(), F::Unreduced::<9>::zero()), + |a, b| (a.0 + b.0, a.1 + b.1), + ); + + let t2 = F::from_montgomery_reduce::<9>(t2_unr); + let t_inf = F::from_montgomery_reduce::<9>(tinf_unr); + + ( + t2, + t_inf, + is_mul_vec.into(), + rl_vec.into(), + li_vec.into(), + ri_vec.into(), + ) + } + + #[tracing::instrument(skip_all, name = "OuterLinearStage::compute_product_t_evals")] + fn compute_product_t_evals(&self, shared: &OuterSharedState) -> (F, F, F) { + let is_mul = self.product_is_mul.as_ref().unwrap(); + let rl = self.product_rl.as_ref().unwrap(); + let li = self.product_li.as_ref().unwrap(); + let ri = self.product_ri.as_ref().unwrap(); + + let [t0, t2, tinf] = shared + .split_eq_poly + .par_fold_out_in( + || [F::Unreduced::<9>::zero(); 3], + |inner, g, _x_in, e_in| { + let m0 = is_mul.get_bound_coeff(2 * g); + let m1 = is_mul.get_bound_coeff(2 * g + 1); + if m0.is_zero() && m1.is_zero() { + return; + } + + let r0 = rl.get_bound_coeff(2 * g); + let r1 = rl.get_bound_coeff(2 * g + 1); + let l0 = li.get_bound_coeff(2 * g); + let l1 = li.get_bound_coeff(2 * g + 1); + let i0 = ri.get_bound_coeff(2 * g); + let i1 = ri.get_bound_coeff(2 * g + 1); + + let p0 = m0 * (r0 - l0 * i0); + + let dm = m1 - m0; + let dr = r1 - r0; + let dl = l1 - l0; + let di = i1 - i0; + + let m2 = m0 + dm + dm; + let r2 = r0 + dr + dr; + let l2 = l0 + dl + dl; + let i2 = i0 + di + di; + let p2 = m2 * (r2 - l2 * i2); + + let p_inf = -(dm * dl * di); + + inner[0] += e_in.mul_unreduced::<9>(p0); + inner[1] += e_in.mul_unreduced::<9>(p2); + inner[2] += e_in.mul_unreduced::<9>(p_inf); + }, + |_x_out, e_out, inner| { + let mut outer = [F::Unreduced::<9>::zero(); 3]; + for k in 0..3 { + let red = F::from_montgomery_reduce::<9>(inner[k]); + outer[k] = 
e_out.mul_unreduced::<9>(red); + } + outer + }, + |mut a, b| { + for k in 0..3 { + a[k] += b[k]; + } + a + }, + ) + .map(F::from_montgomery_reduce::<9>); + + (t0, t2, tinf) + } + #[tracing::instrument( skip_all, name = "OuterLinearStage::fused_materialise_polynomials_general_with_multiquadratic" @@ -1400,7 +1632,15 @@ impl LinearSumcheckStage for OuterLinearStage { Self::fused_materialise_polynomials_round_zero(shared, window_size) }; - Self { az, bz } + Self { + az, + bz, + product_is_mul: None, + product_rl: None, + product_li: None, + product_ri: None, + prev_round_poly_product: None, + } } #[tracing::instrument(skip_all, name = "OuterLinearStage::next_window")] @@ -1410,19 +1650,58 @@ impl LinearSumcheckStage for OuterLinearStage { #[tracing::instrument(skip_all, name = "OuterLinearStage::compute_message")] fn compute_message( - &self, + &mut self, shared: &Self::Shared, window_size: usize, previous_claim: F, ) -> UniPoly { + let prev_claim_product = shared.prev_claim_product; + let prev_claim_azbz = previous_claim - prev_claim_product; + let (t_prime_0, t_prime_inf) = shared.compute_t_evals(window_size); - shared - .split_eq_poly - .gruen_poly_deg_3(t_prime_0, t_prime_inf, previous_claim) + let round_poly_azbz = + shared + .split_eq_poly + .gruen_poly_deg_3(t_prime_0, t_prime_inf, prev_claim_azbz); + + let is_group_variable_round = shared.split_eq_poly.num_challenges() == 0; + if is_group_variable_round { + self.prev_round_poly_product = None; + round_poly_azbz + } else { + let (t_prod_0, t_prod_2, t_prod_inf) = if self.product_is_mul.is_none() { + let (t2, t_inf, is_mul, rl, li, ri) = + Self::fused_materialise_and_first_product_evals(shared); + self.product_is_mul = Some(is_mul); + self.product_rl = Some(rl); + self.product_li = Some(li); + self.product_ri = Some(ri); + (F::zero(), t2, t_inf) + } else { + self.compute_product_t_evals(shared) + }; + + let round_poly_product = shared.split_eq_poly.gruen_poly_deg_4( + t_prod_0, + t_prod_2, + t_prod_inf, + 
prev_claim_product, + ); + self.prev_round_poly_product = Some(round_poly_product.clone()); + + let mut combined = round_poly_azbz; + combined += &round_poly_product; + combined + } } #[tracing::instrument(skip_all, name = "OuterLinearStage::ingest_challenge")] fn ingest_challenge(&mut self, shared: &mut Self::Shared, r_j: F::Challenge, _round: usize) { + if let Some(ref poly) = self.prev_round_poly_product { + shared.prev_claim_product = poly.evaluate(&r_j); + } + self.prev_round_poly_product = None; + shared.split_eq_poly.bind(r_j); if let Some(t_prime_poly) = shared.t_prime_poly.as_mut() { @@ -1430,8 +1709,50 @@ impl LinearSumcheckStage for OuterLinearStage { } rayon::join( - || self.az.bind_parallel(r_j, BindingOrder::LowToHigh), - || self.bz.bind_parallel(r_j, BindingOrder::LowToHigh), + || { + rayon::join( + || self.az.bind_parallel(r_j, BindingOrder::LowToHigh), + || self.bz.bind_parallel(r_j, BindingOrder::LowToHigh), + ) + }, + || { + if self.product_is_mul.is_some() { + rayon::join( + || { + rayon::join( + || { + self.product_is_mul + .as_mut() + .unwrap() + .bind_parallel(r_j, BindingOrder::LowToHigh) + }, + || { + self.product_rl + .as_mut() + .unwrap() + .bind_parallel(r_j, BindingOrder::LowToHigh) + }, + ) + }, + || { + rayon::join( + || { + self.product_li + .as_mut() + .unwrap() + .bind_parallel(r_j, BindingOrder::LowToHigh) + }, + || { + self.product_ri + .as_mut() + .unwrap() + .bind_parallel(r_j, BindingOrder::LowToHigh) + }, + ) + }, + ); + } + }, ); } diff --git a/jolt-core/src/zkvm/spartan/product.rs b/jolt-core/src/zkvm/spartan/product.rs index 3876bd75ec..da86e1c4a5 100644 --- a/jolt-core/src/zkvm/spartan/product.rs +++ b/jolt-core/src/zkvm/spartan/product.rs @@ -44,27 +44,20 @@ use tracer::instruction::Cycle; // We define a "combined" left and right polynomial // Left(x, y) = \sum_i L(y, i) * Left_i(x), // Right(x, y) = \sum_i R(y, i) * Right_i(x), -// where Left_i(x) = one of the five left polynomials, Right_i(x) = one of the five right 
polynomials -// Indexing is over i \in {-2, -1, 0, 1, 2}, though this gets mapped to the 0th, 1st, ..., 4th polynomial +// where Left_i(x) = one of the four left polynomials, Right_i(x) = one of the four right polynomials +// Indexing is over i \in {-1, 0, 1, 2} (mapped to the 0th, 1st, 2nd, 3rd polynomial) // -// We also need to define the combined claim: -// claim(y) = \sum_i L(y, i) * claim_i, -// where claim_i is the claim of the i-th product virtualization sumcheck +// The four product constraints are: +// WriteLookupOutputToRD, WritePCtoRD, ShouldBranch, ShouldJump +// (The former Instruction product constraint is now handled directly in the outer sumcheck.) +// +// claim(y) = \sum_i L(y, i) * claim_i // -// The product virtualization sumcheck is then: // \sum_y L(tau_high, y) * \sum_x eq(tau_low, x) * Left(x, y) * Right(x, y) // = claim(tau_high) // // Final claim is: // L(tau_high, r0) * Eq(tau_low, r_tail^rev) * Left(r_tail, r0) * Right(r_tail, r0) -// -// After this, we also need to check the consistency of the Left and Right evaluations with the -// claimed evaluations of the factor polynomials. This is done in the ProductVirtualInner sumcheck. -// -// TODO (Quang): this is essentially Spartan with non-zero claims. We should unify this with Spartan outer/inner. -// Only complication is to generalize the splitting strategy -// (i.e. Spartan outer currently does uni skip for half of the constraints, -// whereas here we do it for all of them) /// Degree of the sumcheck round polynomials for [`ProductVirtualRemainderVerifier`]. 
const PRODUCT_VIRTUAL_REMAINDER_DEGREE: usize = 3; @@ -73,11 +66,11 @@ const PRODUCT_VIRTUAL_REMAINDER_DEGREE: usize = 3; pub struct ProductVirtualUniSkipParams { /// τ = [τ_low || τ_high] /// - τ_low: the cycle-point r_cycle carried from Spartan outer (length = num_cycle_vars) - /// - τ_high: the univariate-skip binding point sampled for the size-5 domain (length = 1) + /// - τ_high: the univariate-skip binding point sampled for the size-4 domain (length = 1) /// Ordering matches outer: variables are MSB→LSB with τ_high last pub tau: Vec, - /// Base evaluations (claims) for the five product terms at the base domain - /// Order: [Product, WriteLookupOutputToRD, WritePCtoRD, ShouldBranch, ShouldJump] + /// Base evaluations (claims) for the four product terms at the base domain + /// Order: [WriteLookupOutputToRD, WritePCtoRD, ShouldBranch, ShouldJump] pub base_evals: [F; NUM_PRODUCT_VIRTUAL], } @@ -86,9 +79,8 @@ impl ProductVirtualUniSkipParams { opening_accumulator: &dyn OpeningAccumulator, transcript: &mut T, ) -> Self { - // Reuse r_cycle from Stage 1 (outer) for τ_low, and sample τ_high let r_cycle = opening_accumulator - .get_virtual_polynomial_opening(VirtualPolynomial::Product, SumcheckId::SpartanOuter) + .get_virtual_polynomial_opening(PRODUCT_CONSTRAINTS[0].output, SumcheckId::SpartanOuter) .0 .r; let tau_high = transcript.challenge_scalar_optimized::(); @@ -173,20 +165,13 @@ impl ProductVirtualUniSkipProver { /// t1(z) = Σ_{x_out} E_out[x_out] · Σ_{x_in} E_in[x_in] · left_z(x) · right_z(x), /// where x is the concatenation of (x_out || x_in) in MSB→LSB order. /// - /// Lagrange fusion per target z on (current) extended window {−4,−3,3,4}: - /// - Compute c[0..4] = LagrangeHelper::shift_coeffs_i32(shift(z)) using the same shifted-kernel - /// as outer.rs (indices correspond to the 5 base points). 
- /// - Define fused values at this z by linearly combining the 5 product witnesses with c: + /// Lagrange fusion per target z on the extended window: + /// - Compute c[0..3] = LagrangeHelper::shift_coeffs_i32(shift(z)) using the same shifted-kernel + /// as outer.rs (indices correspond to the 4 base points). + /// - Define fused values at this z by linearly combining the 4 product witnesses with c: /// left_z(x) = Σ_i c[i] · Left_i(x) /// right_z(x) = Σ_i c[i] · Right_i^eff(x) - /// with Right_4^eff(x) = 1 − NextIsNoop(x) for the ShouldJump term only. - /// - /// Small-value lifting rules for integer accumulation before converting to the field: - /// - Instruction: LeftInstructionInput is u64 → lift to i128; RightInstructionInput is S64 → i128. - /// - WriteLookupOutputToRD: IsRdNotZero is bool/u8 → i32; flag is bool/u8 → i32. - /// - WritePCtoRD: IsRdNotZero is bool/u8 → i32; Jump flag is bool/u8 → i32. - /// - ShouldBranch: LookupOutput is u64 → i128; Branch flag is bool/u8 → i32. - /// - ShouldJump: Jump flag (left) is bool/u8 → i32; Right^eff = (1 − NextIsNoop) is bool/u8 → i32. + /// with Right_3^eff(x) = 1 − NextIsNoop(x) for the ShouldJump term only. 
fn compute_univariate_skip_extended_evals( trace: &[Cycle], tau: &[F::Challenge], @@ -405,17 +390,17 @@ impl SumcheckInstanceParams for ProductVirtualRemainderParams SumcheckInstanceVerifier accumulator: &VerifierOpeningAccumulator, sumcheck_challenges: &[F::Challenge], ) -> F { - // Lagrange weights at r0 let w = LagrangePolynomial::::evals::< F::Challenge, PRODUCT_VIRTUAL_UNIVARIATE_SKIP_DOMAIN_SIZE, >(&self.params.r0); - // Fetch factor claims - let l_inst = accumulator - .get_virtual_polynomial_opening( - VirtualPolynomial::LeftInstructionInput, - SumcheckId::SpartanProductVirtualization, - ) - .1; - let r_inst = accumulator - .get_virtual_polynomial_opening( - VirtualPolynomial::RightInstructionInput, - SumcheckId::SpartanProductVirtualization, - ) - .1; let is_rd_not_zero = accumulator .get_virtual_polynomial_opening( VirtualPolynomial::InstructionFlags(InstructionFlags::IsRdNotZero), @@ -724,18 +695,11 @@ impl SumcheckInstanceVerifier ) .1; - let fused_left = w[0] * l_inst - + w[1] * is_rd_not_zero - + w[2] * is_rd_not_zero - + w[3] * lookup_out - + w[4] * j_flag; - let fused_right = w[0] * r_inst - + w[1] * wl_flag - + w[2] * j_flag - + w[3] * branch_flag - + w[4] * (F::one() - next_is_noop); - - // Multiply by L(τ_high, r0) and Eq(τ_low, r_tail^rev) + let fused_left = + w[0] * is_rd_not_zero + w[1] * is_rd_not_zero + w[2] * lookup_out + w[3] * j_flag; + let fused_right = + w[0] * wl_flag + w[1] * j_flag + w[2] * branch_flag + w[3] * (F::one() - next_is_noop); + let tau_high = &self.params.tau[self.params.tau.len() - 1]; let tau_high_bound_r0 = LagrangePolynomial::::lagrange_kernel::< F::Challenge, diff --git a/jolt-core/src/zkvm/verifier.rs b/jolt-core/src/zkvm/verifier.rs index aa4709c4ba..1337bde9ff 100644 --- a/jolt-core/src/zkvm/verifier.rs +++ b/jolt-core/src/zkvm/verifier.rs @@ -13,7 +13,7 @@ use crate::zkvm::claim_reductions::RegistersClaimReductionSumcheckVerifier; use crate::zkvm::config::OneHotParams; #[cfg(feature = "prover")] use 
crate::zkvm::prover::JoltProverPreprocessing; -use crate::zkvm::ram::RAMPreprocessing; +use crate::zkvm::ram::{gen_ram_initial_memory_state, RAMPreprocessing}; use crate::zkvm::witness::all_committed_polynomials; use crate::zkvm::Serializable; use crate::zkvm::{ @@ -82,10 +82,10 @@ pub struct JoltVerifier< pub preprocessing: &'a JoltVerifierPreprocessing, pub transcript: ProofTranscript, pub opening_accumulator: VerifierOpeningAccumulator, - /// The advice claim reduction sumcheck effectively spans two stages (6 and 7). + /// The advice claim reduction sumcheck effectively spans two stages (5 and 6). /// Cache the verifier state here between stages. advice_reduction_verifier_trusted: Option>, - /// The advice claim reduction sumcheck effectively spans two stages (6 and 7). + /// The advice claim reduction sumcheck effectively spans two stages (5 and 6). /// Cache the verifier state here between stages. advice_reduction_verifier_untrusted: Option>, pub spartan_key: UniformSpartanKey, @@ -217,7 +217,6 @@ impl<'a, F: JoltField, PCS: CommitmentScheme, ProofTranscript: Transc self.verify_stage5()?; self.verify_stage6()?; self.verify_stage7()?; - self.verify_stage8()?; Ok(()) } @@ -265,18 +264,6 @@ impl<'a, F: JoltField, PCS: CommitmentScheme, ProofTranscript: Transc &self.proof.rw_config, ); - let spartan_product_virtual_remainder = ProductVirtualRemainderVerifier::new( - self.proof.trace_length, - uni_skip_params, - &self.opening_accumulator, - ); - - let instruction_claim_reduction = InstructionLookupsClaimReductionSumcheckVerifier::new( - self.proof.trace_length, - &self.opening_accumulator, - &mut self.transcript, - ); - let ram_raf_evaluation = RamRafEvaluationSumcheckVerifier::new( &self.program_io.memory_layout, &self.one_hot_params, @@ -293,29 +280,18 @@ impl<'a, F: JoltField, PCS: CommitmentScheme, ProofTranscript: Transc &self.proof.rw_config, ); - let _r_stage2 = BatchedSumcheck::verify( - &self.proof.stage2_sumcheck_proof, - vec![ - &ram_read_write_checking, 
- &spartan_product_virtual_remainder, - &instruction_claim_reduction, - &ram_raf_evaluation, - &ram_output_check, - ], - &mut self.opening_accumulator, - &mut self.transcript, - ) - .context("Stage 2")?; - - Ok(()) - } + let spartan_product_virtual_remainder = ProductVirtualRemainderVerifier::new( + self.proof.trace_length, + uni_skip_params, + &self.opening_accumulator, + ); - fn verify_stage3(&mut self) -> Result<(), anyhow::Error> { - let spartan_shift = ShiftSumcheckVerifier::new( - self.proof.trace_length.log_2(), + let instruction_claim_reduction = InstructionLookupsClaimReductionSumcheckVerifier::new( + self.proof.trace_length, &self.opening_accumulator, &mut self.transcript, ); + let spartan_instruction_input = InstructionInputSumcheckVerifier::new(&self.opening_accumulator, &mut self.transcript); let spartan_registers_claim_reduction = RegistersClaimReductionSumcheckVerifier::new( @@ -324,22 +300,32 @@ impl<'a, F: JoltField, PCS: CommitmentScheme, ProofTranscript: Transc &mut self.transcript, ); - let _r_stage3 = BatchedSumcheck::verify( - &self.proof.stage3_sumcheck_proof, + let _r_stage2 = BatchedSumcheck::verify( + &self.proof.stage2_sumcheck_proof, vec![ - &spartan_shift, + &ram_read_write_checking, + &ram_raf_evaluation, + &ram_output_check, + &spartan_product_virtual_remainder, + &instruction_claim_reduction, &spartan_instruction_input, &spartan_registers_claim_reduction, ], &mut self.opening_accumulator, &mut self.transcript, ) - .context("Stage 3")?; + .context("Stage 2")?; Ok(()) } - fn verify_stage4(&mut self) -> Result<(), anyhow::Error> { + fn verify_stage3(&mut self) -> Result<(), anyhow::Error> { + let spartan_shift = ShiftSumcheckVerifier::new( + self.proof.trace_length.log_2(), + &self.opening_accumulator, + &mut self.transcript, + ); + let registers_read_write_checking = RegistersReadWriteCheckingVerifier::new( self.proof.trace_length, &self.opening_accumulator, @@ -354,10 +340,9 @@ impl<'a, F: JoltField, PCS: CommitmentScheme, 
ProofTranscript: Transc &mut self.opening_accumulator, &mut self.transcript, ); - // Domain-separate the batching challenge. self.transcript.append_bytes(b"ram_val_check_gamma", &[]); let ram_val_check_gamma: F = self.transcript.challenge_scalar::(); - let initial_ram_state = crate::zkvm::ram::gen_ram_initial_memory_state::( + let initial_ram_state = gen_ram_initial_memory_state::( self.proof.ram_K, &self.preprocessing.shared.ram, &self.program_io, @@ -372,21 +357,22 @@ impl<'a, F: JoltField, PCS: CommitmentScheme, ProofTranscript: Transc &self.opening_accumulator, ); - let _r_stage4 = BatchedSumcheck::verify( - &self.proof.stage4_sumcheck_proof, + let _r_stage3 = BatchedSumcheck::verify( + &self.proof.stage3_sumcheck_proof, vec![ - ®isters_read_write_checking as &dyn SumcheckInstanceVerifier, + &spartan_shift as &dyn SumcheckInstanceVerifier, + ®isters_read_write_checking, &ram_val_check, ], &mut self.opening_accumulator, &mut self.transcript, ) - .context("Stage 4")?; + .context("Stage 3")?; Ok(()) } - fn verify_stage5(&mut self) -> Result<(), anyhow::Error> { + fn verify_stage4(&mut self) -> Result<(), anyhow::Error> { let n_cycle_vars = self.proof.trace_length.log_2(); let lookups_read_raf = InstructionReadRafSumcheckVerifier::new( @@ -404,8 +390,8 @@ impl<'a, F: JoltField, PCS: CommitmentScheme, ProofTranscript: Transc let registers_val_evaluation = RegistersValEvaluationSumcheckVerifier::new(&self.opening_accumulator); - let _r_stage5 = BatchedSumcheck::verify( - &self.proof.stage5_sumcheck_proof, + let _r_stage4 = BatchedSumcheck::verify( + &self.proof.stage4_sumcheck_proof, vec![ &lookups_read_raf, &ram_ra_reduction, @@ -414,12 +400,12 @@ impl<'a, F: JoltField, PCS: CommitmentScheme, ProofTranscript: Transc &mut self.opening_accumulator, &mut self.transcript, ) - .context("Stage 5")?; + .context("Stage 4")?; Ok(()) } - fn verify_stage6(&mut self) -> Result<(), anyhow::Error> { + fn verify_stage5(&mut self) -> Result<(), anyhow::Error> { let n_cycle_vars = 
self.proof.trace_length.log_2(); let bytecode_read_raf = BytecodeReadRafSumcheckVerifier::gen( &self.preprocessing.shared.bytecode, @@ -456,7 +442,7 @@ impl<'a, F: JoltField, PCS: CommitmentScheme, ProofTranscript: Transc &mut self.transcript, ); - // Advice claim reduction (Phase 1 in Stage 6): trusted and untrusted are separate instances. + // Advice claim reduction (Phase 1 in Stage 5): trusted and untrusted are separate instances. if self.trusted_advice_commitment.is_some() { self.advice_reduction_verifier_trusted = Some(AdviceClaimReductionVerifier::new( AdviceKind::Trusted, @@ -489,19 +475,19 @@ impl<'a, F: JoltField, PCS: CommitmentScheme, ProofTranscript: Transc instances.push(advice); } - let _r_stage6 = BatchedSumcheck::verify( - &self.proof.stage6_sumcheck_proof, + let _r_stage5 = BatchedSumcheck::verify( + &self.proof.stage5_sumcheck_proof, instances, &mut self.opening_accumulator, &mut self.transcript, ) - .context("Stage 6")?; + .context("Stage 5")?; Ok(()) } - /// Stage 7: HammingWeight claim reduction verification. - fn verify_stage7(&mut self) -> Result<(), anyhow::Error> { + /// Stage 6: HammingWeight claim reduction verification. + fn verify_stage6(&mut self) -> Result<(), anyhow::Error> { // Create verifier for HammingWeightClaimReduction // (r_cycle and r_addr_bool are extracted from Booleanity opening internally) let hw_verifier = HammingWeightClaimReductionVerifier::new( @@ -533,19 +519,19 @@ impl<'a, F: JoltField, PCS: CommitmentScheme, ProofTranscript: Transc } } - let _r_address_stage7 = BatchedSumcheck::verify( - &self.proof.stage7_sumcheck_proof, + let _r_address_stage6 = BatchedSumcheck::verify( + &self.proof.stage6_sumcheck_proof, instances, &mut self.opening_accumulator, &mut self.transcript, ) - .context("Stage 7")?; + .context("Stage 6")?; Ok(()) } - /// Stage 8: Dory batch opening verification. - fn verify_stage8(&mut self) -> Result<(), anyhow::Error> { + /// Stage 7: Dory batch opening verification. 
+ fn verify_stage7(&mut self) -> Result<(), anyhow::Error> { // Initialize DoryGlobals with the layout from the proof // This ensures the verifier uses the same layout as the prover let _guard = DoryGlobals::initialize_context( @@ -556,18 +542,18 @@ impl<'a, F: JoltField, PCS: CommitmentScheme, ProofTranscript: Transc ); // Get the unified opening point from HammingWeightClaimReduction - // This contains (r_address_stage7 || r_cycle_stage6) in big-endian + // This contains (r_address_stage6 || r_cycle_stage5) in big-endian let (opening_point, _) = self.opening_accumulator.get_committed_polynomial_opening( CommittedPolynomial::InstructionRa(0), SumcheckId::HammingWeightClaimReduction, ); let log_k_chunk = self.one_hot_params.log_k_chunk; - let r_address_stage7 = &opening_point.r[..log_k_chunk]; + let r_address_stage6 = &opening_point.r[..log_k_chunk]; // 1. Collect all (polynomial, claim) pairs let mut polynomial_claims = Vec::new(); - // Dense polynomials: RamInc and RdInc (from IncClaimReduction in Stage 6) + // Dense polynomials: RamInc and RdInc (from IncClaimReduction in Stage 5) let (_, ram_inc_claim) = self.opening_accumulator.get_committed_polynomial_opening( CommittedPolynomial::RamInc, SumcheckId::IncClaimReduction, @@ -579,7 +565,7 @@ impl<'a, F: JoltField, PCS: CommitmentScheme, ProofTranscript: Transc // Apply Lagrange factor for dense polys // Note: r_address is in big-endian, Lagrange factor uses ∏(1 - r_i) - let lagrange_factor: F = r_address_stage7.iter().map(|r| F::one() - *r).product(); + let lagrange_factor: F = r_address_stage6.iter().map(|r| F::one() - *r).product(); polynomial_claims.push((CommittedPolynomial::RamInc, ram_inc_claim * lagrange_factor)); polynomial_claims.push((CommittedPolynomial::RdInc, rd_inc_claim * lagrange_factor)); @@ -607,7 +593,7 @@ impl<'a, F: JoltField, PCS: CommitmentScheme, ProofTranscript: Transc polynomial_claims.push((CommittedPolynomial::RamRa(i), claim)); } - // Advice polynomials: TrustedAdvice and 
UntrustedAdvice (from AdviceClaimReduction in Stage 6) + // Advice polynomials: TrustedAdvice and UntrustedAdvice (from AdviceClaimReduction in Stage 5) // These are committed with smaller dimensions, so we apply Lagrange factors to embed // them in the top-left block of the main Dory matrix. if let Some((advice_point, advice_claim)) = self @@ -694,7 +680,7 @@ impl<'a, F: JoltField, PCS: CommitmentScheme, ProofTranscript: Transc &joint_claim, &joint_commitment, ) - .context("Stage 8") + .context("Stage 7") } /// Compute joint commitment for the batch opening. diff --git a/jolt-core/src/zkvm/witness.rs b/jolt-core/src/zkvm/witness.rs index efcef73652..f0538afcf4 100644 --- a/jolt-core/src/zkvm/witness.rs +++ b/jolt-core/src/zkvm/witness.rs @@ -241,7 +241,6 @@ pub enum VirtualPolynomial { RightLookupOperand, LeftInstructionInput, RightInstructionInput, - Product, ShouldJump, ShouldBranch, WritePCtoRD, From 5b10ef6456b0d7751db22718c98e66b8ed7beacc Mon Sep 17 00:00:00 2001 From: Quang Dao Date: Tue, 3 Mar 2026 15:30:49 -0800 Subject: [PATCH 2/3] refactor: clean up imports and naming across changed files Hoist mid-file and in-function `use` statements to top-level import blocks, replace fully qualified paths (3+ segments) with short names, and add proper cfg gates for conditionally-used imports. 
Made-with: Cursor --- jolt-core/src/poly/split_eq_poly.rs | 2 -- jolt-core/src/zkvm/prover.rs | 27 ++++++++++++++------------ jolt-core/src/zkvm/r1cs/constraints.rs | 22 ++++++++++----------- jolt-core/src/zkvm/r1cs/inputs.rs | 8 ++++---- jolt-core/src/zkvm/spartan/outer.rs | 5 ++--- jolt-core/src/zkvm/verifier.rs | 26 ++++++++++++------------- 6 files changed, 43 insertions(+), 47 deletions(-) diff --git a/jolt-core/src/poly/split_eq_poly.rs b/jolt-core/src/poly/split_eq_poly.rs index b76f4a4d37..fe5a7b61c0 100644 --- a/jolt-core/src/poly/split_eq_poly.rs +++ b/jolt-core/src/poly/split_eq_poly.rs @@ -854,8 +854,6 @@ mod tests { /// Verify that evals_cached returns [1] at index 0 (eq over 0 vars). #[test] fn evals_cached_starts_with_one() { - use crate::poly::eq_poly::EqPolynomial; - let mut rng = test_rng(); for num_vars in 1..=10 { let w: Vec<::Challenge> = diff --git a/jolt-core/src/zkvm/prover.rs b/jolt-core/src/zkvm/prover.rs index 8bea893dbc..f43e13f26e 100644 --- a/jolt-core/src/zkvm/prover.rs +++ b/jolt-core/src/zkvm/prover.rs @@ -18,6 +18,8 @@ use crate::poly::commitment::dory::bind_opening_inputs; #[cfg(feature = "zk")] use crate::poly::commitment::dory::bind_opening_inputs_zk; use crate::poly::commitment::dory::DoryContext; +#[cfg(not(feature = "zk"))] +use crate::zkvm::proof_serialization::Claims; use ark_serialize::{CanonicalDeserialize, CanonicalSerialize}; use crate::zkvm::config::ReadWriteConfig; @@ -145,9 +147,10 @@ use crate::poly::commitment::pedersen::PedersenGenerators; use crate::poly::lagrange_poly::LagrangeHelper; #[cfg(feature = "zk")] use crate::subprotocols::blindfold::{ - pedersen_generator_count_for_r1cs, BakedPublicInputs, BlindFoldProof, BlindFoldProver, - BlindFoldWitness, ExtraConstraintWitness, FinalOutputWitness, RelaxedR1CSInstance, - RoundWitness, StageConfig, StageWitness, VerifierR1CSBuilder, + pedersen_generator_count_for_r1cs, BakedPublicInputs, BlindFoldAccumulator, BlindFoldProof, + BlindFoldProver, BlindFoldWitness, 
ExtraConstraintWitness, FinalOutputWitness, + OpeningProofData, RelaxedR1CSInstance, RoundWitness, StageConfig, StageWitness, + VerifierR1CSBuilder, }; #[cfg(feature = "zk")] use crate::subprotocols::blindfold::{InputClaimConstraint, OutputClaimConstraint, ValueSource}; @@ -192,7 +195,7 @@ pub struct JoltCpuProver< pub pedersen_generators: PedersenGenerators, pub rw_config: ReadWriteConfig, #[cfg(feature = "zk")] - blindfold_accumulator: crate::subprotocols::blindfold::BlindFoldAccumulator, + blindfold_accumulator: BlindFoldAccumulator, #[cfg(not(feature = "zk"))] _curve: std::marker::PhantomData, } @@ -471,7 +474,7 @@ where #[cfg(feature = "zk")] pedersen_generators, #[cfg(feature = "zk")] - blindfold_accumulator: crate::subprotocols::blindfold::BlindFoldAccumulator::new(), + blindfold_accumulator: BlindFoldAccumulator::new(), #[cfg(not(feature = "zk"))] _curve: std::marker::PhantomData, } @@ -529,8 +532,7 @@ where let blindfold_proof = self.prove_blindfold(&joint_opening_proof); #[cfg(not(feature = "zk"))] - let opening_claims = - crate::zkvm::proof_serialization::Claims(self.opening_accumulator.openings.clone()); + let opening_claims = Claims(self.opening_accumulator.openings.clone()); #[cfg(test)] assert!( @@ -1992,14 +1994,13 @@ where { let y_com: C::G1 = PCS::eval_commitment(&proof).expect("ZK proof must have y_com"); bind_opening_inputs_zk::(&mut self.transcript, &opening_point.r, &y_com); - self.blindfold_accumulator.set_opening_proof_data( - crate::subprotocols::blindfold::OpeningProofData { + self.blindfold_accumulator + .set_opening_proof_data(OpeningProofData { opening_ids, constraint_coeffs, joint_claim, y_blinding: _y_blinding.expect("ZK mode requires y_blinding"), - }, - ); + }); } #[cfg(not(feature = "zk"))] { @@ -2121,6 +2122,8 @@ mod tests { multilinear_polynomial::MultilinearPolynomial, opening_proof::{OpeningAccumulator, SumcheckId}, }; + #[cfg(feature = "zk")] + use crate::subprotocols::blindfold::StageWitness; use 
crate::zkvm::claim_reductions::AdviceKind; use crate::zkvm::verifier::JoltSharedPreprocessing; use crate::zkvm::witness::CommittedPolynomial; @@ -2136,7 +2139,7 @@ mod tests { #[cfg(feature = "zk")] fn round_commitment_data( gens: &PedersenGenerators, - stages: &[crate::subprotocols::blindfold::StageWitness], + stages: &[StageWitness], rng: &mut R, ) -> (Vec, Vec>, Vec) { let mut commitments = Vec::new(); diff --git a/jolt-core/src/zkvm/r1cs/constraints.rs b/jolt-core/src/zkvm/r1cs/constraints.rs index 5c69edf023..5722ec08ec 100644 --- a/jolt-core/src/zkvm/r1cs/constraints.rs +++ b/jolt-core/src/zkvm/r1cs/constraints.rs @@ -43,6 +43,12 @@ use strum::EnumCount as EnumCountTrait; use strum_macros::{EnumCount, EnumIter}; pub use super::ops::{Term, LC}; +#[cfg(test)] +use crate::field::JoltField; +#[cfg(test)] +use crate::poly::multilinear_polynomial::MultilinearPolynomial; +#[cfg(test)] +use std::fmt::Write as _; /// A single R1CS constraint row #[derive(Clone, Copy, Debug)] @@ -59,14 +65,12 @@ impl R1CSConstraint { } #[cfg(test)] - pub fn pretty_fmt( + pub fn pretty_fmt( &self, f: &mut String, - flattened_polynomials: &[crate::poly::multilinear_polynomial::MultilinearPolynomial], + flattened_polynomials: &[MultilinearPolynomial], step_index: usize, ) -> std::fmt::Result { - use std::fmt::Write as _; - self.a.pretty_fmt(f)?; write!(f, " ⋅ ")?; self.b.pretty_fmt(f)?; @@ -107,8 +111,6 @@ impl R1CSConstraint { f: &mut String, row: &super::inputs::R1CSCycleInputs, ) -> std::fmt::Result { - use std::fmt::Write as _; - self.a.pretty_fmt(f)?; write!(f, " ⋅ ")?; self.b.pretty_fmt(f)?; @@ -181,14 +183,12 @@ pub struct NamedR1CSConstraint { impl NamedR1CSConstraint { #[cfg(test)] - pub fn pretty_fmt( + pub fn pretty_fmt( &self, f: &mut String, - flattened_polynomials: &[crate::poly::multilinear_polynomial::MultilinearPolynomial], + flattened_polynomials: &[MultilinearPolynomial], step_index: usize, ) -> std::fmt::Result { - use std::fmt::Write as _; - writeln!(f, "[{:?}]", 
self.label)?; self.cons.pretty_fmt(f, flattened_polynomials, step_index) } @@ -199,8 +199,6 @@ impl NamedR1CSConstraint { f: &mut String, row: &super::inputs::R1CSCycleInputs, ) -> std::fmt::Result { - use std::fmt::Write as _; - writeln!(f, "[{:?}]", self.label)?; self.cons.pretty_fmt_with_row(f, row) } diff --git a/jolt-core/src/zkvm/r1cs/inputs.rs b/jolt-core/src/zkvm/r1cs/inputs.rs index ab4dcdf817..ede64fa208 100644 --- a/jolt-core/src/zkvm/r1cs/inputs.rs +++ b/jolt-core/src/zkvm/r1cs/inputs.rs @@ -24,7 +24,7 @@ use crate::field::JoltField; use ark_ff::biginteger::S64; use common::constants::XLEN; use std::fmt::Debug; -use tracer::instruction::Cycle; +use tracer::instruction::{Cycle, RAMAccess}; use strum::IntoEnumIterator; @@ -302,9 +302,9 @@ impl R1CSCycleInputs { // RAM let ram_addr = cycle.ram_access().address() as u64; let (ram_read_value, ram_write_value) = match cycle.ram_access() { - tracer::instruction::RAMAccess::Read(r) => (r.value, r.value), - tracer::instruction::RAMAccess::Write(w) => (w.pre_value, w.post_value), - tracer::instruction::RAMAccess::NoOp => (0u64, 0u64), + RAMAccess::Read(r) => (r.value, r.value), + RAMAccess::Write(w) => (w.pre_value, w.post_value), + RAMAccess::NoOp => (0u64, 0u64), }; // PCs diff --git a/jolt-core/src/zkvm/spartan/outer.rs b/jolt-core/src/zkvm/spartan/outer.rs index d04a0fde0b..f8e7ca90c5 100644 --- a/jolt-core/src/zkvm/spartan/outer.rs +++ b/jolt-core/src/zkvm/spartan/outer.rs @@ -1,3 +1,5 @@ +#[cfg(feature = "zk")] +use std::collections::BTreeSet; use std::marker::PhantomData; use std::sync::Arc; @@ -441,7 +443,6 @@ impl SumcheckInstanceParams for OuterRemainingSumcheckParams // Build structural template by iterating R1CS constraints to find // which input indices appear in A-sides and B-sides. 
- use std::collections::BTreeSet; let mut a_indices = BTreeSet::new(); let mut b_indices = BTreeSet::new(); let mut a_has_const = false; @@ -578,8 +579,6 @@ impl SumcheckInstanceParams for OuterRemainingSumcheckParams #[cfg(feature = "zk")] fn output_constraint_challenge_values(&self, sumcheck_challenges: &[F::Challenge]) -> Vec { - use std::collections::BTreeSet; - let r_stream = sumcheck_challenges[0]; // Lagrange weights at r0 diff --git a/jolt-core/src/zkvm/verifier.rs b/jolt-core/src/zkvm/verifier.rs index 413a9413cd..12bedd274d 100644 --- a/jolt-core/src/zkvm/verifier.rs +++ b/jolt-core/src/zkvm/verifier.rs @@ -4,6 +4,11 @@ use std::io::{Read, Write}; use std::path::Path; use std::sync::Arc; +use ark_serialize::{CanonicalDeserialize, CanonicalSerialize}; +use common::jolt_device::MemoryLayout; +use tracer::instruction::Instruction; +use tracer::JoltDevice; + use crate::curve::JoltCurve; use crate::poly::commitment::commitment_scheme::{CommitmentScheme, ZkEvalCommitment}; #[cfg(feature = "zk")] @@ -11,6 +16,8 @@ use crate::poly::commitment::dory::bind_opening_inputs_zk; use crate::poly::commitment::dory::{bind_opening_inputs, DoryContext, DoryGlobals}; #[cfg(feature = "zk")] use crate::poly::lagrange_poly::LagrangeHelper; +#[cfg(not(feature = "zk"))] +use crate::poly::opening_proof::{OpeningPoint, BIG_ENDIAN}; #[cfg(feature = "zk")] use crate::subprotocols::blindfold::{ pedersen_generator_count_for_r1cs, BakedPublicInputs, BlindFoldVerifier, @@ -184,11 +191,6 @@ fn scale_batching_coefficients( }) .collect() } -use ark_serialize::{CanonicalDeserialize, CanonicalSerialize}; -use common::jolt_device::MemoryLayout; -use tracer::instruction::Instruction; -use tracer::JoltDevice; - pub struct JoltVerifier< 'a, F: JoltField, @@ -268,7 +270,6 @@ where #[cfg(not(feature = "zk"))] { - use crate::poly::opening_proof::{OpeningPoint, BIG_ENDIAN}; for (id, (_, claim)) in &proof.opening_claims.0 { let dummy_point = OpeningPoint::::new(vec![]); opening_accumulator @@ -1003,14 
+1004,11 @@ where let num_rounds = proof.num_rounds(); for round_idx in 0..num_rounds { let poly_degree = match proof { - crate::subprotocols::sumcheck::SumcheckInstanceProof::Clear(std_proof) => { - std_proof.compressed_polys[round_idx] - .coeffs_except_linear_term - .len() - } - crate::subprotocols::sumcheck::SumcheckInstanceProof::Zk(zk_proof) => { - zk_proof.poly_degrees[round_idx] - } + SumcheckInstanceProof::Clear(std_proof) => std_proof.compressed_polys + [round_idx] + .coeffs_except_linear_term + .len(), + SumcheckInstanceProof::Zk(zk_proof) => zk_proof.poly_degrees[round_idx], }; let starts_new_chain = round_idx == 0; let config = if starts_new_chain { From 2d4d0b58057a14fcf118b98bba2764be3eac6b06 Mon Sep 17 00:00:00 2001 From: Quang Dao Date: Tue, 3 Mar 2026 17:15:57 -0800 Subject: [PATCH 3/3] fix(z3-verifier): remove Product from R1CS inputs and VirtualPolynomial match Product was removed from R1CS inputs (hoisted to outer sumcheck) and VirtualPolynomial::Product variant was deleted. Update z3-verifier to match: drop &self.product from r1cs_inputs() array and remove the Product match arm from virtpoly_to_int(). Made-with: Cursor --- z3-verifier/src/cpu_constraints.rs | 2 -- 1 file changed, 2 deletions(-) diff --git a/z3-verifier/src/cpu_constraints.rs b/z3-verifier/src/cpu_constraints.rs index 7b2f8a43c4..e51aeea6e8 100644 --- a/z3-verifier/src/cpu_constraints.rs +++ b/z3-verifier/src/cpu_constraints.rs @@ -152,7 +152,6 @@ impl JoltState { [ &self.left_input, &self.right_input, - &self.product, &self.write_lookup_output_to_rd, &self.write_pc_to_rd, &self.should_branch, @@ -213,7 +212,6 @@ impl JoltState { match poly { VirtualPolynomial::LeftInstructionInput => &self.left_input, VirtualPolynomial::RightInstructionInput => &self.right_input, - VirtualPolynomial::Product => &self.product, VirtualPolynomial::InstructionFlags(InstructionFlags::IsRdNotZero) => { &self.instruction_flags[InstructionFlags::IsRdNotZero as usize] }