diff --git a/jolt-core/src/poly/split_eq_poly.rs b/jolt-core/src/poly/split_eq_poly.rs index f93f17af60..fe5a7b61c0 100644 --- a/jolt-core/src/poly/split_eq_poly.rs +++ b/jolt-core/src/poly/split_eq_poly.rs @@ -410,6 +410,65 @@ impl GruenSplitEqPolynomial { ]) } + /// Compute the quartic polynomial s(X) = l(X) * q(X), where l(X) is the + /// current (linear) eq polynomial and q(X) = c + dX + eX^2 + fX^3, given: + /// - c, the constant term of q (i.e. q(0)) + /// - q(2), the evaluation of q at 2 + /// - f, the leading (cubic) coefficient of q + /// - the previous round claim, s(0) + s(1) + pub fn gruen_poly_deg_4( + &self, + q_constant: F, + q_at_2: F, + q_cubic_coeff: F, + s_0_plus_s_1: F, + ) -> UniPoly { + // l(X) = a + bX (linear eq polynomial) + // q(X) = c + dX + eX^2 + fX^3 (cubic inner polynomial) + // s(X) = l(X)*q(X) is degree 4, needs evaluations at {0, 1, 2, 3, 4}. + + let eq_eval_1 = self.current_scalar + * match self.binding_order { + BindingOrder::LowToHigh => self.w[self.current_index - 1], + BindingOrder::HighToLow => self.w[self.current_index], + }; + let eq_eval_0 = self.current_scalar - eq_eval_1; + let eq_m = eq_eval_1 - eq_eval_0; + let eq_eval_2 = eq_eval_1 + eq_m; + let eq_eval_3 = eq_eval_2 + eq_m; + let eq_eval_4 = eq_eval_3 + eq_m; + + let c = q_constant; + let f = q_cubic_coeff; + + // Recover q(1) from the sumcheck identity: s(0) + s(1) = claim + let quartic_eval_0 = eq_eval_0 * c; + let quartic_eval_1 = s_0_plus_s_1 - quartic_eval_0; + let q_1 = quartic_eval_1 / eq_eval_1; + + // q(0) = c, q(1) = c+d+e+f, q(2) = c+2d+4e+8f + // Forward differences: Δ0 = q(1)-q(0) = d+e+f, Δ1 = q(2)-q(1) + // Second differences: Δ²0 = Δ1-Δ0 = 2e+6f + // Third differences: Δ³ = 6f (constant for a cubic) + let delta_0 = q_1 - c; + let delta_1 = q_at_2 - q_1; + let delta2_0 = delta_1 - delta_0; + let f6 = f + f + f + f + f + f; // 6f + // q(3) = q(2) + Δ1 + Δ²0 + Δ³ = q(2) + (Δ1 + Δ²0 + 6f) + let delta_2 = delta_1 + delta2_0 + f6; + let q_3 = q_at_2 + delta_2; + // q(4) = q(3) + Δ2 + Δ²0 + 2*Δ³ = q(3) + (delta_2 + delta2_0 + 6f + 6f) + let q_4 = q_3 + delta_2 + delta2_0 + f6 + f6; + + UniPoly::from_evals(&[ + quartic_eval_0, + quartic_eval_1, + eq_eval_2 * q_at_2, + eq_eval_3 * q_3, + eq_eval_4 * q_4, + ]) + } + /// Compute the quadratic polynomial s(X) = l(X) * q(X), where l(X) is the /// current (linear) Dao-Thaler eq polynomial and q(X) = c + dx /// - c, the constant term of q @@ -795,8 +854,6 @@ mod tests { /// Verify that evals_cached returns [1] at index 0 (eq over 0 vars). #[test] fn evals_cached_starts_with_one() { - use crate::poly::eq_poly::EqPolynomial; - let mut rng = test_rng(); for num_vars in 1..=10 { let w: Vec<::Challenge> = diff --git a/jolt-core/src/subprotocols/streaming_sumcheck.rs b/jolt-core/src/subprotocols/streaming_sumcheck.rs index a6a51230f5..4e26319a23 100644 --- a/jolt-core/src/subprotocols/streaming_sumcheck.rs +++ b/jolt-core/src/subprotocols/streaming_sumcheck.rs @@ -17,7 +17,7 @@ pub trait StreamingSumcheckWindow: Sized + MaybeAllocative + Send fn initialize(shared: &mut Self::Shared, window_size: usize) -> Self; fn compute_message( - &self, + &mut self, shared: &Self::Shared, window_size: usize, previous_claim: F, @@ -39,7 +39,7 @@ pub trait LinearSumcheckStage: Sized + MaybeAllocative + Send + Sy fn next_window(&mut self, shared: &mut Self::Shared, window_size: usize); fn compute_message( - &self, + &mut self, shared: &Self::Shared, window_size: usize, previous_claim: F, @@ -170,9 +170,9 @@ where } } - if let Some(streaming) = &mut self.streaming { + if let Some(streaming) = self.streaming.as_mut() { streaming.compute_message(&self.shared, num_unbound_vars, previous_claim) - } else if let Some(linear) = &mut self.linear { + } else if let Some(linear) = self.linear.as_mut() { linear.compute_message(&self.shared, num_unbound_vars, previous_claim) } else { unreachable!() diff --git a/jolt-core/src/utils/accumulation.rs b/jolt-core/src/utils/accumulation.rs index 8f4bb142bc..d71d154e20 100644 --- a/jolt-core/src/utils/accumulation.rs +++ b/jolt-core/src/utils/accumulation.rs @@ -583,6 +583,16 @@ impl FMAdd for WideAccumS { } } +impl FMAdd for WideAccumS { + #[inline(always)] + fn fmadd(&mut self, field: &F, other: &u64) { + if *other == 0 { + return; + } + self.pos += (*field).mul_u64_unreduced(*other); + } +} + impl BarrettReduce for WideAccumS { #[inline(always)] fn barrett_reduce(&self) -> F { @@ -817,6 +827,18 @@ pub struct S192Sum { pub sum: S192, } +impl FMAdd for S192Sum { + #[inline(always)] + fn fmadd(&mut self, c: &i32, term: &u64) { + if *term == 0 { + return; + } + let c_s64 = S64::from(*c as i64); + let v_s64 = S64::from(*term); + self.sum += c_s64.mul_trunc::<1, 3>(&v_s64); + } +} + impl FMAdd for S192Sum { #[inline(always)] fn fmadd(&mut self, c: &i32, term: &S64) { diff --git a/jolt-core/src/zkvm/bytecode/read_raf_checking.rs b/jolt-core/src/zkvm/bytecode/read_raf_checking.rs index d128f1fe5c..38c4c057e3 100644 --- a/jolt-core/src/zkvm/bytecode/read_raf_checking.rs +++ b/jolt-core/src/zkvm/bytecode/read_raf_checking.rs @@ -57,7 +57,8 @@ const N_STAGES: usize = 5; /// Bytecode instruction: multi-stage Read + RAF sumcheck (N_STAGES = 5). /// /// Stages virtualize different claim families (Stage1: Spartan outer; Stage2: product-virtualized -/// flags; Stage3: Shift; Stage4: Registers RW; Stage5: Registers val-eval + Instruction lookups). +/// flags + instruction input flags; Stage3: Shift flags only; Stage4: Registers RW; +/// Stage5: Registers val-eval + Instruction lookups). /// /// The input claim is a γ-weighted RLC of stage rv_claims plus RAF contributions folded into /// stages 1 and 3 via the identity polynomial. Address vars are bound in `d` chunks; cycle vars @@ -78,11 +79,11 @@ const N_STAGES: usize = 5; /// - Int(k) = 1 for all k (evaluation of the IdentityPolynomial over address variables). /// - Define per-stage Val_s(k) (address-only) as implemented by `compute_val_*`: /// * Stage1: Val_1(k) = unexpanded_pc(k) + β_1·imm(k) + Σ_t β_1^{2+t}·circuit_flag_t(k). -/// * Stage2: Val_2(k) = 1_{jump}(k) + β_2·1_{branch}(k) + β_2^2·rd_addr(k) + β_2^3·1_{write_lookup_to_rd}(k) -/// + β_2^4·1_{VirtualInstruction}(k). -/// * Stage3: Val_3(k) = imm(k) + β_3·unexpanded_pc(k) + β_3^2·1_{L_is_rs1}(k) + β_3^3·1_{L_is_pc}(k) -/// + β_3^4·1_{R_is_rs2}(k) + β_3^5·1_{R_is_imm}(k) + β_3^6·1_{IsNoop}(k) -/// + β_3^7·1_{VirtualInstruction}(k) + β_3^8·1_{IsFirstInSequence}(k). +/// * Stage2: Val_2(k) = 1_{jump}(k) + β_2·1_{branch}(k) + β_2^2·1_{rd≠0}(k) + β_2^3·1_{write_lookup_to_rd}(k) +/// + β_2^4·1_{VirtualInstruction}(k) + β_2^5·imm(k) + β_2^6·unexpanded_pc(k) +/// + β_2^7·1_{L_is_rs1}(k) + β_2^8·1_{L_is_pc}(k) + β_2^9·1_{R_is_rs2}(k) + β_2^10·1_{R_is_imm}(k). +/// * Stage3: Val_3(k) = unexpanded_pc(k) + β_3·1_{IsNoop}(k) + β_3^2·1_{VirtualInstruction}(k) +/// + β_3^3·1_{IsFirstInSequence}(k). /// * Stage4: Val_4(k) = 1_{rd=r}(k) + β_4·1_{rs1=r}(k) + β_4^2·1_{rs2=r}(k), where r is fixed by opening. /// * Stage5: Val_5(k) = 1_{rd=r}(k) + β_5·1_{¬interleaved}(k) + Σ_i β_5^{2+i}·1_{table=i}(k). /// @@ -750,8 +751,8 @@ impl BytecodeReadRafSumcheckParams { // Generate all stage-specific gamma powers upfront (order must match verifier) let stage1_gammas: Vec = transcript.challenge_scalar_powers(2 + NUM_CIRCUIT_FLAGS); - let stage2_gammas: Vec = transcript.challenge_scalar_powers(5); - let stage3_gammas: Vec = transcript.challenge_scalar_powers(9); + let stage2_gammas: Vec = transcript.challenge_scalar_powers(11); + let stage3_gammas: Vec = transcript.challenge_scalar_powers(4); let stage4_gammas: Vec = transcript.challenge_scalar_powers(3); let stage5_gammas: Vec = transcript.challenge_scalar_powers(2 + NUM_LOOKUP_TABLES); @@ -923,15 +924,7 @@ impl BytecodeReadRafSumcheckParams { *o0 = lc; } - // Stage 2 (product virtualization, de-duplicated factors) - // Val(k) = jump_flag(k) + γ·branch_flag(k) - // + γ²·is_rd_not_zero_flag(k) + γ³·write_lookup_output_to_rd_flag(k) - // where jump_flag(k) = 1 if instruction k is a jump, 0 otherwise; - // branch_flag(k) = 1 if instruction k is a branch, 0 otherwise; - // is_rd_not_zero_flag(k) = 1 if instruction k has rd != 0; - // write_lookup_output_to_rd_flag(k) = 1 if instruction k writes lookup output to rd. - // virtual_instruction(k) = 1 if instruction k is a virtual instruction. - // This Val matches the fused product sumcheck. + // Stage 2 (product virtualization + instruction input flags) { let mut lc = F::zero(); if circuit_flags[CircuitFlags::Jump] { @@ -949,38 +942,34 @@ impl BytecodeReadRafSumcheckParams { if circuit_flags[CircuitFlags::VirtualInstruction] { lc += stage2_gammas[4]; } - *o1 = lc; - } - - // Stage 3 (Shift sumcheck) - // Val(k) = imm(k) + γ·unexpanded_pc(k) - // + γ²·left_operand_is_rs1_value(k) + γ³·left_operand_is_pc(k) - // + γ⁴·right_operand_is_rs2_value(k) + γ⁵·right_operand_is_imm(k) - // + γ⁶·is_noop(k) + γ⁷·virtual_instruction(k) + γ⁸·is_first_in_sequence(k) - // This virtualizes claims output by the ShiftSumcheck. - { - let mut lc = F::from_i128(instr.operands.imm); - lc += stage3_gammas[1].mul_u64(instr.address as u64); + lc += instr.operands.imm.field_mul(stage2_gammas[5]); + lc += stage2_gammas[6].mul_u64(instr.address as u64); if instr_flags[InstructionFlags::LeftOperandIsRs1Value] { - lc += stage3_gammas[2]; + lc += stage2_gammas[7]; } if instr_flags[InstructionFlags::LeftOperandIsPC] { - lc += stage3_gammas[3]; + lc += stage2_gammas[8]; } if instr_flags[InstructionFlags::RightOperandIsRs2Value] { - lc += stage3_gammas[4]; + lc += stage2_gammas[9]; } if instr_flags[InstructionFlags::RightOperandIsImm] { - lc += stage3_gammas[5]; + lc += stage2_gammas[10]; } + *o1 = lc; + } + + // Stage 3 (Shift flags only — InstructionInput flags moved to Stage 2) + { + let mut lc = F::from_u64(instr.address as u64); if instr_flags[InstructionFlags::IsNoop] { - lc += stage3_gammas[6]; + lc += stage3_gammas[1]; } if circuit_flags[CircuitFlags::VirtualInstruction] { - lc += stage3_gammas[7]; + lc += stage3_gammas[2]; } if circuit_flags[CircuitFlags::IsFirstInSequence] { - lc += stage3_gammas[8]; + lc += stage3_gammas[3]; } *o2 = lc; } @@ -1089,45 +1078,14 @@ impl BytecodeReadRafSumcheckParams { VirtualPolynomial::OpFlags(CircuitFlags::VirtualInstruction), SumcheckId::SpartanProductVirtualization, ); - - [ - jump_claim, - branch_claim, - rd_wa_claim, - write_lookup_output_to_rd_flag_claim, - virtual_instruction_claim, - ] - .into_iter() - .zip_eq(gamma_powers) - .map(|(claim, gamma)| claim * gamma) - .sum() - } - - fn compute_rv_claim_3( - opening_accumulator: &dyn OpeningAccumulator, - gamma_powers: &[F], - ) -> F { let (_, imm_claim) = opening_accumulator.get_virtual_polynomial_opening( VirtualPolynomial::Imm, SumcheckId::InstructionInputVirtualization, ); - let (_, spartan_shift_unexpanded_pc_claim) = opening_accumulator - .get_virtual_polynomial_opening( - VirtualPolynomial::UnexpandedPC, - SumcheckId::SpartanShift, - ); - let (_, instruction_input_unexpanded_pc_claim) = opening_accumulator - .get_virtual_polynomial_opening( - VirtualPolynomial::UnexpandedPC, - SumcheckId::InstructionInputVirtualization, - ); - - assert_eq!( - spartan_shift_unexpanded_pc_claim, - instruction_input_unexpanded_pc_claim + let (_, unexpanded_pc_claim) = opening_accumulator.get_virtual_polynomial_opening( + VirtualPolynomial::UnexpandedPC, + SumcheckId::InstructionInputVirtualization, ); - - let unexpanded_pc_claim = spartan_shift_unexpanded_pc_claim; let (_, left_is_rs1_claim) = opening_accumulator.get_virtual_polynomial_opening( VirtualPolynomial::InstructionFlags(InstructionFlags::LeftOperandIsRs1Value), SumcheckId::InstructionInputVirtualization, @@ -1144,6 +1102,34 @@ impl BytecodeReadRafSumcheckParams { VirtualPolynomial::InstructionFlags(InstructionFlags::RightOperandIsImm), SumcheckId::InstructionInputVirtualization, ); + + [ + jump_claim, + branch_claim, + rd_wa_claim, + write_lookup_output_to_rd_flag_claim, + virtual_instruction_claim, + imm_claim, + unexpanded_pc_claim, + left_is_rs1_claim, + left_is_pc_claim, + right_is_rs2_claim, + right_is_imm_claim, + ] + .into_iter() + .zip_eq(gamma_powers) + .map(|(claim, gamma)| claim * gamma) + .sum() + } + + fn compute_rv_claim_3( + opening_accumulator: &dyn OpeningAccumulator, + gamma_powers: &[F], + ) -> F { + let (_, unexpanded_pc_claim) = opening_accumulator.get_virtual_polynomial_opening( + VirtualPolynomial::UnexpandedPC, + SumcheckId::SpartanShift, + ); let (_, is_noop_claim) = opening_accumulator.get_virtual_polynomial_opening( VirtualPolynomial::InstructionFlags(InstructionFlags::IsNoop), SumcheckId::SpartanShift, @@ -1158,12 +1144,7 @@ impl BytecodeReadRafSumcheckParams { ); [ - imm_claim, unexpanded_pc_claim, - left_is_rs1_claim, - left_is_pc_claim, - right_is_rs2_claim, - right_is_imm_claim, is_noop_claim, is_virtual_claim, is_first_in_sequence_claim, @@ -1252,8 +1233,8 @@ impl SumcheckInstanceParams for BytecodeReadRafSumcheckParams SumcheckInstanceParams for BytecodeReadRafSumcheckParams SumcheckInstanceParams for BytecodeReadRafSumcheckParams SumcheckInstanceParams for BytecodeReadRafSumcheckParams SumcheckInstanceParams for BytecodeReadRafSumcheckParams SumcheckInstanceParams for BytecodeReadRafSumcheckParams { /// γ^0, γ^1, ..., γ^{3N-1} for batching (3 claims per ra polynomial) @@ -147,7 +147,7 @@ pub struct HammingWeightClaimReductionParams { } impl HammingWeightClaimReductionParams { - /// Create params by fetching claims from Stage 6 and sampling batching challenge. + /// Create params by fetching claims from Stage 5 and sampling batching challenge. /// /// Fetches: /// - HammingWeight claims (from HammingBooleanity virtual polynomial) @@ -207,7 +207,7 @@ impl HammingWeightClaimReductionParams { let mut claims_bool = Vec::with_capacity(N); let mut claims_virt = Vec::with_capacity(N); - // RAM HammingWeight factor: now in Stage 6, so shares r_cycle_stage6 + // RAM HammingWeight factor: now in Stage 5, so shares r_cycle_stage5 let ram_hw_factor = accumulator .get_virtual_polynomial_opening( VirtualPolynomial::RamHammingWeight, diff --git a/jolt-core/src/zkvm/claim_reductions/increments.rs b/jolt-core/src/zkvm/claim_reductions/increments.rs index 6a9415822c..a42a995b55 100644 --- a/jolt-core/src/zkvm/claim_reductions/increments.rs +++ b/jolt-core/src/zkvm/claim_reductions/increments.rs @@ -16,8 +16,8 @@ //! So effectively RamInc has **2 distinct opening points**. //! //! 2. **RdInc**: Claims are emitted from: -//! - `RegistersReadWriteChecking` (Stage 4): opened at `s_cycle_stage4` -//! - `RegistersValEvaluation` (Stage 5): opened at `s_cycle_stage5` +//! - `RegistersReadWriteChecking` (Stage 3): opened at `s_cycle_stage3` +//! - `RegistersValEvaluation` (Stage 4): opened at `s_cycle_stage4` //! //! So RdInc has **2 distinct opening points**. //! @@ -26,20 +26,20 @@ //! Let: //! - v_1 = RamInc(r_cycle_stage2) from RamReadWriteChecking //! - v_2 = RamInc(r_cycle_stage4) from RamValCheck -//! - w_1 = RdInc(s_cycle_stage4) from RegistersReadWriteChecking -//! - w_2 = RdInc(s_cycle_stage5) from RegistersValEvaluation +//! - w_1 = RdInc(s_cycle_stage3) from RegistersReadWriteChecking +//! - w_2 = RdInc(s_cycle_stage4) from RegistersValEvaluation //! //! Input claim: //! v_1 + γ·v_2 + γ²·w_1 + γ³·w_2 //! //! Sumcheck proves (over log T rounds): //! Σ_j RamInc(j) · [eq(r_cycle_stage2, j) + γ·eq(r_cycle_stage4, j)] -//! + γ² · Σ_j RdInc(j) · [eq(s_cycle_stage4, j) + γ·eq(s_cycle_stage5, j)] +//! + γ² · Σ_j RdInc(j) · [eq(s_cycle_stage3, j) + γ·eq(s_cycle_stage4, j)] //! = v_1 + γ·v_2 + γ²·w_1 + γ³·w_2 //! //! After log T rounds with sumcheck challenges ρ, the final claim is: //! RamInc(ρ) · [eq(r_cycle_stage2, ρ) + γ·eq(r_cycle_stage4, ρ)] -//! + γ² · RdInc(ρ) · [eq(s_cycle_stage4, ρ) + γ·eq(s_cycle_stage5, ρ)] +//! + γ² · RdInc(ρ) · [eq(s_cycle_stage3, ρ) + γ·eq(s_cycle_stage4, ρ)] //! //! The verifier computes the eq terms and recovers two openings at the SAME point ρ: //! - RamInc(ρ) @@ -82,8 +82,8 @@ pub struct IncClaimReductionSumcheckParams { pub n_cycle_vars: usize, pub r_cycle_stage2: OpeningPoint, // RamInc from RamReadWriteChecking pub r_cycle_stage4: OpeningPoint, // RamInc from RamValCheck - pub s_cycle_stage4: OpeningPoint, // RdInc from RegistersReadWriteChecking - pub s_cycle_stage5: OpeningPoint, // RdInc from RegistersValEvaluation + pub s_cycle_stage3: OpeningPoint, // RdInc from RegistersReadWriteChecking + pub s_cycle_stage4: OpeningPoint, // RdInc from RegistersValEvaluation } impl IncClaimReductionSumcheckParams { @@ -104,11 +104,11 @@ impl IncClaimReductionSumcheckParams { let (r_cycle_stage4, _) = accumulator .get_committed_polynomial_opening(CommittedPolynomial::RamInc, SumcheckId::RamValCheck); - let (s_cycle_stage4, _) = accumulator.get_committed_polynomial_opening( + let (s_cycle_stage3, _) = accumulator.get_committed_polynomial_opening( CommittedPolynomial::RdInc, SumcheckId::RegistersReadWriteChecking, ); - let (s_cycle_stage5, _) = accumulator.get_committed_polynomial_opening( + let (s_cycle_stage4, _) = accumulator.get_committed_polynomial_opening( CommittedPolynomial::RdInc, SumcheckId::RegistersValEvaluation, ); @@ -118,8 +118,8 @@ impl IncClaimReductionSumcheckParams { n_cycle_vars: trace_len.log_2(), r_cycle_stage2, r_cycle_stage4, + s_cycle_stage3, s_cycle_stage4, - s_cycle_stage5, } } } @@ -204,11 +204,11 @@ impl SumcheckInstanceParams for IncClaimReductionSumcheckParams let eq_r2: F = EqPolynomial::mle(&opening_point.r, &self.r_cycle_stage2.r); let eq_r4: F = EqPolynomial::mle(&opening_point.r, &self.r_cycle_stage4.r); + let eq_s3: F = EqPolynomial::mle(&opening_point.r, &self.s_cycle_stage3.r); let eq_s4: F = EqPolynomial::mle(&opening_point.r, &self.s_cycle_stage4.r); - let eq_s5: F = EqPolynomial::mle(&opening_point.r, &self.s_cycle_stage5.r); let eq_ram_combined = eq_r2 + gamma * eq_r4; - let eq_rd_combined = eq_s4 + gamma * eq_s5; + let eq_rd_combined = eq_s3 + gamma * eq_s4; vec![eq_ram_combined, gamma_sqr * eq_rd_combined] } @@ -317,8 +317,8 @@ struct IncClaimReductionPhase1State { // P_ram[0] = eq(r_cycle_stage2_lo, ·) // P_ram[1] = eq(r_cycle_stage4_lo, ·) P_ram: [MultilinearPolynomial; 2], - // P_rd[0] = eq(s_cycle_stage4_lo, ·) - // P_rd[1] = eq(s_cycle_stage5_lo, ·) + // P_rd[0] = eq(s_cycle_stage3_lo, ·) + // P_rd[1] = eq(s_cycle_stage4_lo, ·) P_rd: [MultilinearPolynomial; 2], // Q buffers: suffix-weighted polynomial evaluations @@ -345,20 +345,20 @@ impl IncClaimReductionPhase1State { // Big-endian: hi is first half, lo is second half let (r2_hi, r2_lo) = params.r_cycle_stage2.split_at(suffix_n_vars); let (r4_hi, r4_lo) = params.r_cycle_stage4.split_at(suffix_n_vars); + let (s3_hi, s3_lo) = params.s_cycle_stage3.split_at(suffix_n_vars); let (s4_hi, s4_lo) = params.s_cycle_stage4.split_at(suffix_n_vars); - let (s5_hi, s5_lo) = params.s_cycle_stage5.split_at(suffix_n_vars); // P buffers: prefix eq evaluations let P_ram_0 = EqPolynomial::evals(&r2_lo.r); let P_ram_1 = EqPolynomial::evals(&r4_lo.r); - let P_rd_0 = EqPolynomial::evals(&s4_lo.r); - let P_rd_1 = EqPolynomial::evals(&s5_lo.r); + let P_rd_0 = EqPolynomial::evals(&s3_lo.r); + let P_rd_1 = EqPolynomial::evals(&s4_lo.r); // Suffix eq evaluations (for computing Q) let eq_r2_hi = EqPolynomial::evals(&r2_hi.r); let eq_r4_hi = EqPolynomial::evals(&r4_hi.r); + let eq_s3_hi = EqPolynomial::evals(&s3_hi.r); let eq_s4_hi = EqPolynomial::evals(&s4_hi.r); - let eq_s5_hi = EqPolynomial::evals(&s5_hi.r); // Q buffers: sum over suffix indices let mut Q_ram_0 = unsafe_allocate_zero_vec(prefix_len); @@ -404,8 +404,8 @@ impl IncClaimReductionPhase1State { acc_ram_0.fmadd(&eq_r2_hi[x_hi], &ram_inc); acc_ram_1.fmadd(&eq_r4_hi[x_hi], &ram_inc); - acc_rd_0.fmadd(&eq_s4_hi[x_hi], &rd_inc); - acc_rd_1.fmadd(&eq_s5_hi[x_hi], &rd_inc); + acc_rd_0.fmadd(&eq_s3_hi[x_hi], &rd_inc); + acc_rd_1.fmadd(&eq_s4_hi[x_hi], &rd_inc); } q_ram_0[i] = acc_ram_0.barrett_reduce(); @@ -506,7 +506,7 @@ struct IncClaimReductionPhase2State { rd_inc: MultilinearPolynomial, // Combined eq polynomials eq_ram: MultilinearPolynomial, // eq(r_stage2, ·) + γ·eq(r_stage4, ·) - eq_rd: MultilinearPolynomial, // eq(s_stage4, ·) + γ·eq(s_stage5, ·) + eq_rd: MultilinearPolynomial, // eq(s_stage3, ·) + γ·eq(s_stage4, ·) } impl IncClaimReductionPhase2State { @@ -528,21 +528,21 @@ impl IncClaimReductionPhase2State { // Compute eq evaluations for prefix bound let (_, r2_lo) = params.r_cycle_stage2.split_at(n_vars - prefix_n_vars); let (_, r4_lo) = params.r_cycle_stage4.split_at(n_vars - prefix_n_vars); + let (_, s3_lo) = params.s_cycle_stage3.split_at(n_vars - prefix_n_vars); let (_, s4_lo) = params.s_cycle_stage4.split_at(n_vars - prefix_n_vars); - let (_, s5_lo) = params.s_cycle_stage5.split_at(n_vars - prefix_n_vars); let eq_r2_prefix = EqPolynomial::mle_endian(&r_prefix, &r2_lo); let eq_r4_prefix = EqPolynomial::mle_endian(&r_prefix, &r4_lo); + let eq_s3_prefix = EqPolynomial::mle_endian(&r_prefix, &s3_lo); let eq_s4_prefix = EqPolynomial::mle_endian(&r_prefix, &s4_lo); - let eq_s5_prefix = EqPolynomial::mle_endian(&r_prefix, &s5_lo); // Suffix eq evaluations scaled by prefix contributions let (r2_hi, _) = params.r_cycle_stage2.split_at(n_vars - prefix_n_vars); let (r4_hi, _) = params.r_cycle_stage4.split_at(n_vars - prefix_n_vars); + let (s3_hi, _) = params.s_cycle_stage3.split_at(n_vars - prefix_n_vars); let (s4_hi, _) = params.s_cycle_stage4.split_at(n_vars - prefix_n_vars); - let (s5_hi, _) = params.s_cycle_stage5.split_at(n_vars - prefix_n_vars); - // Combined eq polynomials: eq_ram = eq_r2 + γ·eq_r4, eq_rd = eq_s4 + γ·eq_s5 + // Combined eq polynomials: eq_ram = eq_r2 + γ·eq_r4, eq_rd = eq_s3 + γ·eq_s4 let (eq_ram, eq_rd) = rayon::join( || { let (eq_r2, eq_r4) = rayon::join( @@ -556,13 +556,13 @@ impl IncClaimReductionPhase2State { .collect::>() }, || { - let (eq_s4, eq_s5) = rayon::join( + let (eq_s3, eq_s4) = rayon::join( + || EqPolynomial::evals_serial(&s3_hi.r, Some(eq_s3_prefix)), || EqPolynomial::evals_serial(&s4_hi.r, Some(eq_s4_prefix)), - || EqPolynomial::evals_serial(&s5_hi.r, Some(eq_s5_prefix)), ); - eq_s4 + eq_s3 .par_iter() - .zip(eq_s5.par_iter()) + .zip(eq_s4.par_iter()) .map(|(e4, e5)| *e4 + gamma * e5) .collect::>() }, @@ -706,11 +706,11 @@ impl SumcheckInstanceVerifier // Compute eq evaluations at final point let eq_r2 = EqPolynomial::mle(&opening_point.r, &self.params.r_cycle_stage2.r); let eq_r4 = EqPolynomial::mle(&opening_point.r, &self.params.r_cycle_stage4.r); + let eq_s3 = EqPolynomial::mle(&opening_point.r, &self.params.s_cycle_stage3.r); let eq_s4 = EqPolynomial::mle(&opening_point.r, &self.params.s_cycle_stage4.r); - let eq_s5 = EqPolynomial::mle(&opening_point.r, &self.params.s_cycle_stage5.r); let eq_ram_combined = eq_r2 + gamma * eq_r4; - let eq_rd_combined = eq_s4 + gamma * eq_s5; + let eq_rd_combined = eq_s3 + gamma * eq_s4; // Fetch final claims from accumulator let (_, ram_inc_claim) = accumulator.get_committed_polynomial_opening( diff --git a/jolt-core/src/zkvm/claim_reductions/instruction_lookups.rs b/jolt-core/src/zkvm/claim_reductions/instruction_lookups.rs index cd5696dbbf..8038b66aa0 100644 --- a/jolt-core/src/zkvm/claim_reductions/instruction_lookups.rs +++ b/jolt-core/src/zkvm/claim_reductions/instruction_lookups.rs @@ -1,10 +1,6 @@ use std::array; use std::sync::Arc; -use allocative::Allocative; -use ark_std::Zero; -use common::constants::XLEN; - use crate::field::JoltField; use crate::poly::eq_poly::EqPolynomial; use crate::poly::multilinear_polynomial::PolynomialBinding; @@ -25,6 +21,9 @@ use crate::utils::math::Math; use crate::utils::thread::unsafe_allocate_zero_vec; use crate::zkvm::instruction::LookupQuery; use crate::zkvm::witness::VirtualPolynomial; +use allocative::Allocative; +use ark_std::Zero; +use common::constants::XLEN; use rayon::prelude::*; use tracer::instruction::Cycle; @@ -35,8 +34,6 @@ const DEGREE_BOUND: usize = 2; pub struct InstructionLookupsClaimReductionSumcheckParams { pub gamma: F, pub gamma_sqr: F, - pub gamma_cub: F, - pub gamma_quart: F, pub n_cycle_vars: usize, pub r_spartan: OpeningPoint, } @@ -49,8 +46,6 @@ impl InstructionLookupsClaimReductionSumcheckParams { ) -> Self { let gamma = transcript.challenge_scalar::(); let gamma_sqr = gamma.square(); - let gamma_cub = gamma_sqr * gamma; - let gamma_quart = gamma_sqr.square(); let (r_spartan, _) = accumulator.get_virtual_polynomial_opening( VirtualPolynomial::LookupOutput, SumcheckId::SpartanOuter, @@ -58,8 +53,6 @@ impl InstructionLookupsClaimReductionSumcheckParams { Self { gamma, gamma_sqr, - gamma_cub, - gamma_quart, n_cycle_vars: trace_len.log_2(), r_spartan, } @@ -80,20 +73,8 @@ impl SumcheckInstanceParams for InstructionLookupsClaimReductio VirtualPolynomial::RightLookupOperand, SumcheckId::SpartanOuter, ); - let (_, left_instruction_input_claim) = accumulator.get_virtual_polynomial_opening( - VirtualPolynomial::LeftInstructionInput, - SumcheckId::SpartanOuter, - ); - let (_, right_instruction_input_claim) = accumulator.get_virtual_polynomial_opening( - VirtualPolynomial::RightInstructionInput, - SumcheckId::SpartanOuter, - ); - lookup_output_claim - + self.gamma * left_operand_claim - + self.gamma_sqr * right_operand_claim - + self.gamma_cub * left_instruction_input_claim - + self.gamma_quart * right_instruction_input_claim + lookup_output_claim + self.gamma * left_operand_claim + self.gamma_sqr * right_operand_claim } fn degree(&self) -> usize { @@ -123,14 +104,6 @@ impl SumcheckInstanceParams for InstructionLookupsClaimReductio VirtualPolynomial::RightLookupOperand, SumcheckId::SpartanOuter, ), - OpeningId::virt( - VirtualPolynomial::LeftInstructionInput, - SumcheckId::SpartanOuter, - ), - OpeningId::virt( - VirtualPolynomial::RightInstructionInput, - SumcheckId::SpartanOuter, - ), ]) } @@ -139,7 +112,7 @@ impl SumcheckInstanceParams for InstructionLookupsClaimReductio &self, _accumulator: &dyn OpeningAccumulator, ) -> Vec { - vec![self.gamma, self.gamma_sqr, self.gamma_cub, self.gamma_quart] + vec![self.gamma, self.gamma_sqr] } #[cfg(feature = "zk")] @@ -157,14 +130,6 @@ impl SumcheckInstanceParams for InstructionLookupsClaimReductio VirtualPolynomial::RightLookupOperand, SumcheckId::InstructionClaimReduction, ), - OpeningId::virt( - VirtualPolynomial::LeftInstructionInput, - SumcheckId::InstructionClaimReduction, - ), - OpeningId::virt( - VirtualPolynomial::RightInstructionInput, - SumcheckId::InstructionClaimReduction, - ), ])) } @@ -172,13 +137,7 @@ impl SumcheckInstanceParams for InstructionLookupsClaimReductio fn output_constraint_challenge_values(&self, sumcheck_challenges: &[F::Challenge]) -> Vec { let opening_point = self.normalize_opening_point(sumcheck_challenges); let eq_eval = EqPolynomial::::mle(&opening_point.r, &self.r_spartan.r); - vec![ - eq_eval, - eq_eval * self.gamma, - eq_eval * self.gamma_sqr, - eq_eval * self.gamma_cub, - eq_eval * self.gamma_quart, - ] + vec![eq_eval, eq_eval * self.gamma, eq_eval * self.gamma_sqr] } } @@ -270,9 +229,6 @@ impl SumcheckInstanceProver let lookup_output_claim = state.lookup_output_poly.final_sumcheck_claim(); let left_lookup_operand_claim = state.left_lookup_operand_poly.final_sumcheck_claim(); let right_lookup_operand_claim = state.right_lookup_operand_poly.final_sumcheck_claim(); - let left_instruction_input_claim = state.left_instruction_input_poly.final_sumcheck_claim(); - let right_instruction_input_claim = - state.right_instruction_input_poly.final_sumcheck_claim(); accumulator.append_virtual( VirtualPolynomial::LookupOutput, @@ -289,20 +245,8 @@ impl SumcheckInstanceProver accumulator.append_virtual( VirtualPolynomial::RightLookupOperand, SumcheckId::InstructionClaimReduction, - opening_point.clone(), - right_lookup_operand_claim, - ); - accumulator.append_virtual( - VirtualPolynomial::LeftInstructionInput, - SumcheckId::InstructionClaimReduction, - opening_point.clone(), - left_instruction_input_claim, - ); - accumulator.append_virtual( - VirtualPolynomial::RightInstructionInput, - SumcheckId::InstructionClaimReduction, opening_point, - right_instruction_input_claim, + right_lookup_operand_claim, ); } @@ -341,8 +285,6 @@ impl InstructionLookupsPhase1State { let gamma = params.gamma; let gamma_sqr = params.gamma_sqr; - let gamma_cub = params.gamma_cub; - let gamma_quart = params.gamma_quart; const BLOCK_SIZE: usize = 32; Q.par_chunks_mut(BLOCK_SIZE) @@ -351,17 +293,12 @@ impl InstructionLookupsPhase1State { let mut q_lookup_output = [F::UnreducedMulU128::zero(); BLOCK_SIZE]; let mut q_left_lookup_operand = [F::UnreducedMulU128::zero(); BLOCK_SIZE]; let mut q_right_lookup_operand = [F::UnreducedMulU128Accum::zero(); BLOCK_SIZE]; - let mut q_left_instruction_input = [F::UnreducedMulU128::zero(); BLOCK_SIZE]; - let mut q_right_instruction_input_pos = [F::UnreducedMulU128::zero(); BLOCK_SIZE]; - let mut q_right_instruction_input_neg = [F::UnreducedMulU128::zero(); BLOCK_SIZE]; for x_hi in 0..(1 << suffix_n_vars) { for i in 0..q_chunk.len() { let x_lo = chunk_i * BLOCK_SIZE + i; let x = x_lo + (x_hi << prefix_n_vars); let cycle = &trace[x]; - let (left_instruction_input, right_instruction_input) = - LookupQuery::::to_instruction_inputs(cycle); let (left_lookup, right_lookup) = LookupQuery::::to_lookup_operands(cycle); let lookup_output = LookupQuery::::to_lookup_output(cycle); @@ -372,29 +309,13 @@ impl InstructionLookupsPhase1State { eq_suffix_evals[x_hi].mul_u64_unreduced(left_lookup); q_right_lookup_operand[i] += eq_suffix_evals[x_hi].mul_u128_unreduced(right_lookup); - - q_left_instruction_input[i] += - eq_suffix_evals[x_hi].mul_u64_unreduced(left_instruction_input); - - let abs: u64 = right_instruction_input.unsigned_abs() as u64; - let term = eq_suffix_evals[x_hi].mul_u64_unreduced(abs); - if right_instruction_input >= 0 { - q_right_instruction_input_pos[i] += term; - } else { - q_right_instruction_input_neg[i] += term; - } } } for (i, q) in q_chunk.iter_mut().enumerate() { - let right_instruction_input = - F::reduce_mul_u128(q_right_instruction_input_pos[i]) - - F::reduce_mul_u128(q_right_instruction_input_neg[i]); *q = F::reduce_mul_u128(q_lookup_output[i]) + gamma * F::reduce_mul_u128(q_left_lookup_operand[i]) + gamma_sqr * F::reduce_mul_u128_accum(q_right_lookup_operand[i]); - *q += gamma_cub * F::reduce_mul_u128(q_left_instruction_input[i]) - + gamma_quart * right_instruction_input; } }); @@ -440,8 +361,6 @@ struct InstructionLookupsPhase2State { lookup_output_poly: MultilinearPolynomial, left_lookup_operand_poly: MultilinearPolynomial, right_lookup_operand_poly: MultilinearPolynomial, - left_instruction_input_poly: MultilinearPolynomial, - right_instruction_input_poly: MultilinearPolynomial, eq_poly: MultilinearPolynomial, } @@ -459,14 +378,10 @@ impl InstructionLookupsPhase2State { let mut lookup_output_poly = unsafe_allocate_zero_vec(1 << n_remaining_rounds); let mut left_lookup_operand_poly = unsafe_allocate_zero_vec(1 << n_remaining_rounds); let mut right_lookup_operand_poly = unsafe_allocate_zero_vec(1 << n_remaining_rounds); - let mut left_instruction_input_poly = unsafe_allocate_zero_vec(1 << n_remaining_rounds); - let mut right_instruction_input_poly = unsafe_allocate_zero_vec(1 << n_remaining_rounds); ( &mut lookup_output_poly, &mut left_lookup_operand_poly, &mut right_lookup_operand_poly, - &mut left_instruction_input_poly, - &mut right_instruction_input_poly, trace.par_chunks(eq_evals.len()), ) .into_par_iter() @@ -475,20 +390,13 @@ impl InstructionLookupsPhase2State { lookup_output_eval, left_lookup_operand_eval, right_lookup_operand_eval, - left_instruction_input_eval, - right_instruction_input_eval, trace_chunk, )| { let mut lookup_output_eval_unreduced = F::UnreducedMulU128::zero(); let mut left_lookup_operand_eval_unreduced = F::UnreducedMulU128::zero(); let mut right_lookup_operand_eval_unreduced = F::UnreducedMulU128Accum::zero(); - let mut left_instruction_input_eval_unreduced = F::UnreducedMulU128::zero(); - let mut right_instruction_input_pos_unreduced = F::UnreducedMulU128::zero(); - let mut right_instruction_input_neg_unreduced = F::UnreducedMulU128::zero(); for (i, cycle) in trace_chunk.iter().enumerate() { - let (left_instruction_input, right_instruction_input) = - LookupQuery::::to_instruction_inputs(cycle); let (left_lookup, right_lookup) = LookupQuery::::to_lookup_operands(cycle); let lookup_output = LookupQuery::::to_lookup_output(cycle); @@ -499,17 +407,6 @@ impl InstructionLookupsPhase2State { eq_evals[i].mul_u64_unreduced(left_lookup); right_lookup_operand_eval_unreduced += eq_evals[i].mul_u128_unreduced(right_lookup); - - left_instruction_input_eval_unreduced += - eq_evals[i].mul_u64_unreduced(left_instruction_input); - - let abs: u64 = right_instruction_input.unsigned_abs() as u64; - let term = eq_evals[i].mul_u64_unreduced(abs); - if right_instruction_input >= 0 { - right_instruction_input_pos_unreduced += term; - } else { - right_instruction_input_neg_unreduced += term; - } } *lookup_output_eval = F::reduce_mul_u128(lookup_output_eval_unreduced); @@ -517,11 +414,6 @@ impl InstructionLookupsPhase2State { F::reduce_mul_u128(left_lookup_operand_eval_unreduced); *right_lookup_operand_eval = F::reduce_mul_u128_accum(right_lookup_operand_eval_unreduced); - *left_instruction_input_eval = - F::reduce_mul_u128(left_instruction_input_eval_unreduced); - *right_instruction_input_eval = - F::reduce_mul_u128(right_instruction_input_pos_unreduced) - - F::reduce_mul_u128(right_instruction_input_neg_unreduced); }, ); @@ -533,8 +425,6 @@ impl InstructionLookupsPhase2State { lookup_output_poly: lookup_output_poly.into(), left_lookup_operand_poly: left_lookup_operand_poly.into(), right_lookup_operand_poly: right_lookup_operand_poly.into(), - left_instruction_input_poly: left_instruction_input_poly.into(), - right_instruction_input_poly: right_instruction_input_poly.into(), eq_poly: eq_suffix_evals.into(), } } @@ -556,12 +446,6 @@ impl InstructionLookupsPhase2State { let right_lookup_operand_evals = self .right_lookup_operand_poly .sumcheck_evals_array::(j, BindingOrder::LowToHigh); - let left_instruction_input_evals = self - .left_instruction_input_poly - .sumcheck_evals_array::(j, BindingOrder::LowToHigh); - let right_instruction_input_evals = self - .right_instruction_input_poly - .sumcheck_evals_array::(j, BindingOrder::LowToHigh); let eq_evals = self .eq_poly .sumcheck_evals_array::(j, BindingOrder::LowToHigh); @@ -570,9 +454,7 @@ impl InstructionLookupsPhase2State { + eq_evals[i] * (lookup_output_evals[i] + params.gamma * left_lookup_operand_evals[i] - + params.gamma_sqr * right_lookup_operand_evals[i] - + params.gamma_cub * left_instruction_input_evals[i] - + params.gamma_quart * right_instruction_input_evals[i]) + + params.gamma_sqr * right_lookup_operand_evals[i]) }); } UniPoly::from_evals_and_hint(previous_claim, &evals) @@ -585,10 +467,6 @@ impl InstructionLookupsPhase2State { .bind_parallel(r_j, BindingOrder::LowToHigh); self.right_lookup_operand_poly .bind_parallel(r_j, BindingOrder::LowToHigh); - self.left_instruction_input_poly - .bind_parallel(r_j, BindingOrder::LowToHigh); - self.right_instruction_input_poly - .bind_parallel(r_j, BindingOrder::LowToHigh); self.eq_poly.bind_parallel(r_j, BindingOrder::LowToHigh); } } @@ -600,16 +478,12 @@ impl InstructionLookupsPhase2State { /// LookupOutput(j) /// + gamma * LeftLookupOperand(j) /// + gamma^2 * RightLookupOperand(j) -/// + gamma^3 * LeftInstructionInput(j) -/// + gamma^4 * RightInstructionInput(j) /// ) /// ``` /// /// where `r_spartan` is the randomness from the log(T) rounds of Spartan outer sumcheck (stage 1). /// -/// The purpose of this sumcheck is to aggregate instruction lookup claims into a single claim. It runs in -/// parallel with the Spartan product sumcheck. This optimization eliminates the need for a separate opening -/// of [`VirtualPolynomial::LookupOutput`] at `r_spartan`, leaving only the opening at `r_product` required. +/// Aggregates the three instruction lookup claims into a single claim. pub struct InstructionLookupsClaimReductionSumcheckVerifier { params: InstructionLookupsClaimReductionSumcheckParams, } @@ -658,21 +532,11 @@ impl SumcheckInstanceVerifier VirtualPolynomial::RightLookupOperand, SumcheckId::InstructionClaimReduction, ); - let (_, left_instruction_input_claim) = accumulator.get_virtual_polynomial_opening( - VirtualPolynomial::LeftInstructionInput, - SumcheckId::InstructionClaimReduction, - ); - let (_, right_instruction_input_claim) = accumulator.get_virtual_polynomial_opening( - VirtualPolynomial::RightInstructionInput, - SumcheckId::InstructionClaimReduction, - ); EqPolynomial::mle(&opening_point.r, &r_spartan.r) * (lookup_output_claim + self.params.gamma * left_lookup_operand_claim - + self.params.gamma_sqr * right_lookup_operand_claim - + self.params.gamma_cub * left_instruction_input_claim - + self.params.gamma_quart * right_instruction_input_claim) + + self.params.gamma_sqr * right_lookup_operand_claim) } fn cache_openings( @@ -696,16 +560,6 @@ impl SumcheckInstanceVerifier accumulator.append_virtual( VirtualPolynomial::RightLookupOperand, SumcheckId::InstructionClaimReduction, - opening_point.clone(), - ); - accumulator.append_virtual( - VirtualPolynomial::LeftInstructionInput, - SumcheckId::InstructionClaimReduction, - opening_point.clone(), - ); - accumulator.append_virtual( - VirtualPolynomial::RightInstructionInput, - SumcheckId::InstructionClaimReduction, opening_point, ); } diff --git a/jolt-core/src/zkvm/proof_serialization.rs b/jolt-core/src/zkvm/proof_serialization.rs index a8058963e4..bd3e20b78f 100644 --- a/jolt-core/src/zkvm/proof_serialization.rs +++ b/jolt-core/src/zkvm/proof_serialization.rs @@ -1,6 +1,9 @@ #[cfg(not(feature = "zk"))] use std::collections::BTreeMap; -use std::io::{Read, Write}; +use std::{ + fs::File, + io::{Read, Write}, +}; use ark_serialize::{ CanonicalDeserialize, CanonicalSerialize, Compress, SerializationError, Valid, Validate, @@ -43,7 +46,6 @@ pub struct JoltProof, pub stage5_sumcheck_proof: SumcheckInstanceProof, pub stage6_sumcheck_proof: SumcheckInstanceProof, - pub stage7_sumcheck_proof: SumcheckInstanceProof, #[cfg(feature = "zk")] pub blindfold_proof: BlindFoldProof, pub joint_opening_proof: PCS::Proof, @@ -333,46 +335,45 @@ impl CanonicalSerialize for VirtualPolynomial { Self::RightLookupOperand => 8u8.serialize_with_mode(&mut writer, compress), Self::LeftInstructionInput => 9u8.serialize_with_mode(&mut writer, compress), Self::RightInstructionInput => 10u8.serialize_with_mode(&mut writer, compress), - Self::Product => 11u8.serialize_with_mode(&mut writer, compress), - Self::ShouldJump => 12u8.serialize_with_mode(&mut writer, compress), - Self::ShouldBranch => 13u8.serialize_with_mode(&mut writer, compress), - Self::WritePCtoRD => 14u8.serialize_with_mode(&mut writer, compress), - Self::WriteLookupOutputToRD => 15u8.serialize_with_mode(&mut writer, compress), - Self::Rd => 16u8.serialize_with_mode(&mut writer, compress), - Self::Imm => 17u8.serialize_with_mode(&mut writer, compress), - Self::Rs1Value => 18u8.serialize_with_mode(&mut writer, compress), - Self::Rs2Value => 19u8.serialize_with_mode(&mut writer, compress), - Self::RdWriteValue => 20u8.serialize_with_mode(&mut writer, compress), - Self::Rs1Ra => 21u8.serialize_with_mode(&mut writer, compress), - Self::Rs2Ra => 22u8.serialize_with_mode(&mut writer, compress), - Self::RdWa => 23u8.serialize_with_mode(&mut writer, compress), - Self::LookupOutput => 24u8.serialize_with_mode(&mut writer, compress), - Self::InstructionRaf => 25u8.serialize_with_mode(&mut writer, compress), - Self::InstructionRafFlag => 26u8.serialize_with_mode(&mut writer, compress), + Self::ShouldJump => 11u8.serialize_with_mode(&mut writer, compress), + Self::ShouldBranch => 12u8.serialize_with_mode(&mut writer, compress), + Self::WritePCtoRD => 13u8.serialize_with_mode(&mut writer, compress), + Self::WriteLookupOutputToRD => 14u8.serialize_with_mode(&mut writer, compress), + Self::Rd => 15u8.serialize_with_mode(&mut writer, compress), + Self::Imm => 16u8.serialize_with_mode(&mut writer, compress), + Self::Rs1Value => 17u8.serialize_with_mode(&mut writer, compress), + Self::Rs2Value => 18u8.serialize_with_mode(&mut writer, compress), + Self::RdWriteValue => 19u8.serialize_with_mode(&mut writer, compress), + Self::Rs1Ra => 20u8.serialize_with_mode(&mut writer, compress), + Self::Rs2Ra => 21u8.serialize_with_mode(&mut writer, compress), + Self::RdWa => 22u8.serialize_with_mode(&mut writer, compress), + Self::LookupOutput => 23u8.serialize_with_mode(&mut writer, compress), + Self::InstructionRaf => 24u8.serialize_with_mode(&mut writer, compress), + Self::InstructionRafFlag => 25u8.serialize_with_mode(&mut writer, compress), Self::InstructionRa(i) => { - 27u8.serialize_with_mode(&mut writer, compress)?; + 26u8.serialize_with_mode(&mut writer, compress)?; (u8::try_from(*i).unwrap()).serialize_with_mode(&mut writer, compress) } - Self::RegistersVal => 28u8.serialize_with_mode(&mut writer, compress), - Self::RamAddress => 29u8.serialize_with_mode(&mut writer, compress), - Self::RamRa => 30u8.serialize_with_mode(&mut writer, compress), - Self::RamReadValue => 31u8.serialize_with_mode(&mut writer, compress), - Self::RamWriteValue => 32u8.serialize_with_mode(&mut writer, compress), - Self::RamVal => 33u8.serialize_with_mode(&mut writer, compress), - Self::RamValInit => 34u8.serialize_with_mode(&mut writer, compress), - Self::RamValFinal => 35u8.serialize_with_mode(&mut writer, compress), - Self::RamHammingWeight => 36u8.serialize_with_mode(&mut writer, compress), - Self::UnivariateSkip => 37u8.serialize_with_mode(&mut writer, compress), + Self::RegistersVal => 27u8.serialize_with_mode(&mut writer, compress), + Self::RamAddress => 28u8.serialize_with_mode(&mut writer, compress), + Self::RamRa => 29u8.serialize_with_mode(&mut writer, compress), + Self::RamReadValue => 30u8.serialize_with_mode(&mut writer, compress), + Self::RamWriteValue => 31u8.serialize_with_mode(&mut writer, compress), + Self::RamVal => 32u8.serialize_with_mode(&mut writer, compress), + Self::RamValInit => 33u8.serialize_with_mode(&mut writer, compress), + Self::RamValFinal => 34u8.serialize_with_mode(&mut writer, compress), + Self::RamHammingWeight => 35u8.serialize_with_mode(&mut writer, compress), + Self::UnivariateSkip => 36u8.serialize_with_mode(&mut writer, compress), Self::OpFlags(flags) => { - 38u8.serialize_with_mode(&mut writer, compress)?; + 37u8.serialize_with_mode(&mut writer, compress)?; (u8::try_from(*flags as usize).unwrap()).serialize_with_mode(&mut writer, compress) } Self::InstructionFlags(flags) => { - 39u8.serialize_with_mode(&mut writer, compress)?; + 38u8.serialize_with_mode(&mut writer, compress)?; (u8::try_from(*flags as usize).unwrap()).serialize_with_mode(&mut writer, compress) } Self::LookupTableFlag(flag) => { - 40u8.serialize_with_mode(&mut writer, compress)?; + 39u8.serialize_with_mode(&mut writer, compress)?; (u8::try_from(*flag).unwrap()).serialize_with_mode(&mut writer, compress) } } @@ -391,7 +392,6 @@ impl CanonicalSerialize for VirtualPolynomial { | Self::RightLookupOperand | Self::LeftInstructionInput | Self::RightInstructionInput - | Self::Product | Self::ShouldJump | Self::ShouldBranch | Self::WritePCtoRD @@ -450,49 +450,48 @@ impl CanonicalDeserialize for VirtualPolynomial { 8 => Self::RightLookupOperand, 9 => Self::LeftInstructionInput, 10 => Self::RightInstructionInput, - 11 => Self::Product, - 12 => Self::ShouldJump, - 13 => Self::ShouldBranch, - 14 => Self::WritePCtoRD, - 15 => Self::WriteLookupOutputToRD, - 16 => Self::Rd, - 17 => Self::Imm, - 18 => Self::Rs1Value, - 19 => Self::Rs2Value, - 20 => Self::RdWriteValue, - 21 => Self::Rs1Ra, - 22 => Self::Rs2Ra, - 23 => Self::RdWa, - 24 => Self::LookupOutput, - 25 => Self::InstructionRaf, - 26 => Self::InstructionRafFlag, - 27 => { + 11 => Self::ShouldJump, + 12 => Self::ShouldBranch, + 13 => Self::WritePCtoRD, + 14 => Self::WriteLookupOutputToRD, + 15 => Self::Rd, + 16 => Self::Imm, + 17 => Self::Rs1Value, + 18 => Self::Rs2Value, + 19 => Self::RdWriteValue, + 20 => Self::Rs1Ra, + 21 => Self::Rs2Ra, + 22 => Self::RdWa, + 23 => Self::LookupOutput, + 24 => Self::InstructionRaf, + 25 => Self::InstructionRafFlag, + 26 => { let i = u8::deserialize_with_mode(&mut reader, compress, validate)?; Self::InstructionRa(i as usize) } - 28 => Self::RegistersVal, - 29 => Self::RamAddress, - 30 => Self::RamRa, - 31 => Self::RamReadValue, - 32 => Self::RamWriteValue, - 33 => Self::RamVal, - 34 => Self::RamValInit, - 35 => Self::RamValFinal, - 36 => Self::RamHammingWeight, - 37 => Self::UnivariateSkip, - 38 => { + 27 => Self::RegistersVal, + 28 => Self::RamAddress, + 29 => Self::RamRa, + 30 => Self::RamReadValue, + 31 => Self::RamWriteValue, + 32 => Self::RamVal, + 33 => Self::RamValInit, + 34 => Self::RamValFinal, + 35 => Self::RamHammingWeight, + 36 => Self::UnivariateSkip, + 37 => { let discriminant = u8::deserialize_with_mode(&mut reader, compress, validate)?; let flags = CircuitFlags::from_repr(discriminant) .ok_or(SerializationError::InvalidData)?; Self::OpFlags(flags) } - 39 => { + 38 => { let discriminant = u8::deserialize_with_mode(&mut reader, compress, validate)?; let flags = InstructionFlags::from_repr(discriminant) .ok_or(SerializationError::InvalidData)?; Self::InstructionFlags(flags) } - 40 => { + 39 => { let flag = u8::deserialize_with_mode(&mut reader, compress, validate)?; Self::LookupTableFlag(flag as usize) } @@ -507,7 +506,6 @@ pub fn serialize_and_print_size( file_name: &str, item: &impl CanonicalSerialize, ) -> Result<(), SerializationError> { - use std::fs::File; let mut file = File::create(file_name)?; item.serialize_compressed(&mut file)?; let file_size_bytes = file.metadata()?.len(); diff --git a/jolt-core/src/zkvm/prover.rs b/jolt-core/src/zkvm/prover.rs index b30ec0c105..c5060d7dbb 100644 --- a/jolt-core/src/zkvm/prover.rs +++ b/jolt-core/src/zkvm/prover.rs @@ -1,3 +1,5 @@ +use common::constants::ONEHOT_CHUNK_THRESHOLD_LOG_T; + #[cfg(feature = "zk")] use crate::zkvm::stage8_opening_ids; use crate::zkvm::{claim_reductions::advice::ReductionPhase, config::OneHotConfig}; @@ -16,9 +18,12 @@ use crate::poly::commitment::dory::bind_opening_inputs; #[cfg(feature = "zk")] use crate::poly::commitment::dory::bind_opening_inputs_zk; use crate::poly::commitment::dory::DoryContext; +#[cfg(not(feature = "zk"))] +use crate::zkvm::proof_serialization::Claims; use ark_serialize::{CanonicalDeserialize, CanonicalSerialize}; use crate::zkvm::config::ReadWriteConfig; +use crate::zkvm::ram::remap_address; use crate::zkvm::verifier::JoltSharedPreprocessing; use crate::zkvm::Serializable; @@ -142,9 +147,10 @@ use crate::poly::commitment::pedersen::PedersenGenerators; use crate::poly::lagrange_poly::LagrangeHelper; #[cfg(feature = "zk")] use crate::subprotocols::blindfold::{ - pedersen_generator_count_for_r1cs, BakedPublicInputs, BlindFoldProof, BlindFoldProver, - BlindFoldWitness, ExtraConstraintWitness, FinalOutputWitness, RelaxedR1CSInstance, - RoundWitness, StageConfig, StageWitness, VerifierR1CSBuilder, + pedersen_generator_count_for_r1cs, BakedPublicInputs, BlindFoldAccumulator, BlindFoldProof, + BlindFoldProver, BlindFoldWitness, ExtraConstraintWitness, FinalOutputWitness, + OpeningProofData, RelaxedR1CSInstance, RoundWitness, StageConfig, StageWitness, + VerifierR1CSBuilder, }; #[cfg(feature = "zk")] use crate::subprotocols::blindfold::{InputClaimConstraint, OutputClaimConstraint, ValueSource}; @@ -171,10 +177,10 @@ pub struct JoltCpuProver< pub lazy_trace: LazyTraceIterator, pub trace: Arc>, pub advice: JoltAdvice, - /// The advice claim reduction sumcheck effectively spans two stages (6 and 7). + /// The advice claim reduction sumcheck effectively spans two stages (5 and 6). /// Cache the prover state here between stages. advice_reduction_prover_trusted: Option>, - /// The advice claim reduction sumcheck effectively spans two stages (6 and 7). + /// The advice claim reduction sumcheck effectively spans two stages (5 and 6). /// Cache the prover state here between stages. advice_reduction_prover_untrusted: Option>, pub unpadded_trace_len: usize, @@ -189,7 +195,7 @@ pub struct JoltCpuProver< pub pedersen_generators: PedersenGenerators, pub rw_config: ReadWriteConfig, #[cfg(feature = "zk")] - blindfold_accumulator: crate::subprotocols::blindfold::BlindFoldAccumulator, + blindfold_accumulator: BlindFoldAccumulator, #[cfg(not(feature = "zk"))] _curve: std::marker::PhantomData, } @@ -401,7 +407,7 @@ where let ram_K = trace .par_iter() .filter_map(|cycle| { - crate::zkvm::ram::remap_address( + remap_address( cycle.ram_access().address() as u64, &preprocessing.shared.memory_layout, ) @@ -409,7 +415,7 @@ where .max() .unwrap_or(0) .max( - crate::zkvm::ram::remap_address( + remap_address( preprocessing.shared.ram.min_bytecode_address, &preprocessing.shared.memory_layout, ) @@ -469,7 +475,7 @@ where #[cfg(feature = "zk")] pedersen_generators, #[cfg(feature = "zk")] - blindfold_accumulator: crate::subprotocols::blindfold::BlindFoldAccumulator::new(), + blindfold_accumulator: BlindFoldAccumulator::new(), #[cfg(not(feature = "zk"))] _curve: std::marker::PhantomData, } @@ -503,7 +509,7 @@ where let untrusted_advice_commitment = self.generate_and_commit_untrusted_advice(); self.generate_and_commit_trusted_advice(); - // Add advice hints for batched Stage 8 opening + // Add advice hints for batched Stage 7 opening if let Some(hint) = self.advice.trusted_advice_hint.take() { opening_proof_hints.insert(CommittedPolynomial::TrustedAdvice, hint); } @@ -519,19 +525,15 @@ where let (stage4_sumcheck_proof, r_stage4) = self.prove_stage4(); let (stage5_sumcheck_proof, r_stage5) = self.prove_stage5(); let (stage6_sumcheck_proof, r_stage6) = self.prove_stage6(); - let (stage7_sumcheck_proof, r_stage7) = self.prove_stage7(); - let _sumcheck_challenges = [ - r_stage1, r_stage2, r_stage3, r_stage4, r_stage5, r_stage6, r_stage7, - ]; + let _sumcheck_challenges = [r_stage1, r_stage2, r_stage3, r_stage4, r_stage5, r_stage6]; - let joint_opening_proof = self.prove_stage8(opening_proof_hints); + let joint_opening_proof = self.prove_stage7(opening_proof_hints); #[cfg(feature = "zk")] let blindfold_proof = self.prove_blindfold(&joint_opening_proof); #[cfg(not(feature = "zk"))] - let opening_claims = - crate::zkvm::proof_serialization::Claims(self.opening_accumulator.openings.clone()); + let opening_claims = Claims(self.opening_accumulator.openings.clone()); #[cfg(test)] assert!( @@ -563,7 +565,6 @@ where stage4_sumcheck_proof, stage5_sumcheck_proof, stage6_sumcheck_proof, - stage7_sumcheck_proof, #[cfg(feature = "zk")] blindfold_proof, joint_opening_proof, @@ -893,17 +894,6 @@ where self.trace.len(), &self.rw_config, ); - let spartan_product_virtual_remainder_params = ProductVirtualRemainderParams::new( - self.trace.len(), - uni_skip_params, - &self.opening_accumulator, - ); - let instruction_claim_reduction_params = - InstructionLookupsClaimReductionSumcheckParams::new( - self.trace.len(), - &self.opening_accumulator, - &mut self.transcript, - ); let ram_raf_evaluation_params = RafEvaluationSumcheckParams::new( &self.program_io.memory_layout, &self.one_hot_params, @@ -918,6 +908,17 @@ where self.trace.len(), &self.rw_config, ); + let spartan_product_virtual_remainder_params = ProductVirtualRemainderParams::new( + self.trace.len(), + uni_skip_params, + &self.opening_accumulator, + ); + let instruction_claim_reduction_params = + InstructionLookupsClaimReductionSumcheckParams::new( + self.trace.len(), + &self.opening_accumulator, + &mut self.transcript, + ); let ram_read_write_checking = RamReadWriteCheckingProver::initialize( ram_read_write_checking_params, &self.trace, @@ -925,15 +926,6 @@ where &self.program_io.memory_layout, &self.initial_ram_state, ); - let spartan_product_virtual_remainder = ProductVirtualRemainderProver::initialize( - spartan_product_virtual_remainder_params, - Arc::clone(&self.trace), - ); - let instruction_claim_reduction = - InstructionLookupsClaimReductionSumcheckProver::initialize( - instruction_claim_reduction_params, - Arc::clone(&self.trace), - ); let ram_raf_evaluation = RamRafEvaluationSumcheckProver::initialize( ram_raf_evaluation_params, &self.trace, @@ -945,60 +937,16 @@ where &self.final_ram_state, &self.program_io.memory_layout, ); - - #[cfg(feature = "allocative")] - { - print_data_structure_heap_usage("RamReadWriteCheckingProver", &ram_read_write_checking); - print_data_structure_heap_usage( - "ProductVirtualRemainderProver", - &spartan_product_virtual_remainder, - ); - print_data_structure_heap_usage( - "InstructionLookupsClaimReductionSumcheckProver", - &instruction_claim_reduction, + let spartan_product_virtual_remainder = ProductVirtualRemainderProver::initialize( + spartan_product_virtual_remainder_params, + Arc::clone(&self.trace), + ); + let instruction_claim_reduction = + InstructionLookupsClaimReductionSumcheckProver::initialize( + instruction_claim_reduction_params, + Arc::clone(&self.trace), ); - print_data_structure_heap_usage("RamRafEvaluationSumcheckProver", &ram_raf_evaluation); - print_data_structure_heap_usage("OutputSumcheckProver", &ram_output_check); - } - - let mut instances: Vec>> = vec![ - Box::new(ram_read_write_checking), - Box::new(spartan_product_virtual_remainder), - Box::new(instruction_claim_reduction), - Box::new(ram_raf_evaluation), - Box::new(ram_output_check), - ]; - - #[cfg(feature = "allocative")] - write_boxed_instance_flamegraph_svg(&instances, "stage2_start_flamechart.svg"); - tracing::info!("Stage 2 proving"); - - let (sumcheck_proof, r_stage2, _initial_claim) = - self.prove_batched_sumcheck(instances.iter_mut().map(|v| &mut **v as _).collect()); - - #[cfg(feature = "allocative")] - write_boxed_instance_flamegraph_svg(&instances, "stage2_end_flamechart.svg"); - drop_in_background_thread(instances); - - (first_round_proof, sumcheck_proof, r_stage2) - } - #[tracing::instrument(skip_all)] - fn prove_stage3( - &mut self, - ) -> ( - SumcheckInstanceProof, - Vec, - ) { - #[cfg(not(target_arch = "wasm32"))] - print_current_memory_usage("Stage 3 baseline"); - - // Initialization params - let spartan_shift_params = ShiftSumcheckParams::new( - self.trace.len().log_2(), - &self.opening_accumulator, - &mut self.transcript, - ); let spartan_instruction_input_params = InstructionInputParams::new(&self.opening_accumulator, &mut self.transcript); let spartan_registers_claim_reduction_params = RegistersClaimReductionSumcheckParams::new( @@ -1006,12 +954,6 @@ where &self.opening_accumulator, &mut self.transcript, ); - - let spartan_shift = ShiftSumcheckProver::initialize( - spartan_shift_params, - Arc::clone(&self.trace), - &self.preprocessing.shared.bytecode, - ); let spartan_instruction_input = InstructionInputSumcheckProver::initialize( spartan_instruction_input_params, &self.trace, @@ -1024,7 +966,17 @@ where #[cfg(feature = "allocative")] { - print_data_structure_heap_usage("ShiftSumcheckProver", &spartan_shift); + print_data_structure_heap_usage("RamReadWriteCheckingProver", &ram_read_write_checking); + print_data_structure_heap_usage("RamRafEvaluationSumcheckProver", &ram_raf_evaluation); + print_data_structure_heap_usage("OutputSumcheckProver", &ram_output_check); + print_data_structure_heap_usage( + "ProductVirtualRemainderProver", + &spartan_product_virtual_remainder, + ); + print_data_structure_heap_usage( + "InstructionLookupsClaimReductionSumcheckProver", + &instruction_claim_reduction, + ); print_data_structure_heap_usage( "InstructionInputSumcheckProver", &spartan_instruction_input, @@ -1036,32 +988,44 @@ where } let mut instances: Vec>> = vec![ - Box::new(spartan_shift), + Box::new(ram_read_write_checking), + Box::new(ram_raf_evaluation), + Box::new(ram_output_check), + Box::new(spartan_product_virtual_remainder), + Box::new(instruction_claim_reduction), Box::new(spartan_instruction_input), Box::new(spartan_registers_claim_reduction), ]; #[cfg(feature = "allocative")] - write_boxed_instance_flamegraph_svg(&instances, "stage3_start_flamechart.svg"); - tracing::info!("Stage 3 proving"); + write_boxed_instance_flamegraph_svg(&instances, "stage2_start_flamechart.svg"); + tracing::info!("Stage 2 proving"); - let (sumcheck_proof, r_stage3, _initial_claim) = + let (sumcheck_proof, r_stage2, _initial_claim) = self.prove_batched_sumcheck(instances.iter_mut().map(|v| &mut **v as _).collect()); + #[cfg(feature = "allocative")] - write_boxed_instance_flamegraph_svg(&instances, "stage3_end_flamechart.svg"); + write_boxed_instance_flamegraph_svg(&instances, "stage2_end_flamechart.svg"); drop_in_background_thread(instances); - (sumcheck_proof, r_stage3) + (first_round_proof, sumcheck_proof, r_stage2) } + #[tracing::instrument(skip_all)] - fn prove_stage4( + fn prove_stage3( &mut self, ) -> ( SumcheckInstanceProof, Vec, ) { #[cfg(not(target_arch = "wasm32"))] - print_current_memory_usage("Stage 4 baseline"); + print_current_memory_usage("Stage 3 baseline"); + + let spartan_shift_params = ShiftSumcheckParams::new( + self.trace.len().log_2(), + &self.opening_accumulator, + &mut self.transcript, + ); let registers_read_write_checking_params = RegistersReadWriteCheckingParams::new( self.trace.len(), @@ -1076,7 +1040,6 @@ where &self.one_hot_params, &mut self.opening_accumulator, ); - // Domain-separate the batching challenge. self.transcript.append_bytes(b"ram_val_check_gamma", &[]); let ram_val_check_gamma: F = self.transcript.challenge_scalar::(); let ram_val_check_params = RamValCheckSumcheckParams::new_from_prover( @@ -1089,6 +1052,11 @@ where &self.program_io, ); + let spartan_shift = ShiftSumcheckProver::initialize( + spartan_shift_params, + Arc::clone(&self.trace), + &self.preprocessing.shared.bytecode, + ); let registers_read_write_checking = RegistersReadWriteCheckingProver::initialize( registers_read_write_checking_params, self.trace.clone(), @@ -1104,6 +1072,7 @@ where #[cfg(feature = "allocative")] { + print_data_structure_heap_usage("ShiftSumcheckProver", &spartan_shift); print_data_structure_heap_usage( "RegistersReadWriteCheckingProver", ®isters_read_write_checking, @@ -1112,32 +1081,33 @@ where } let mut instances: Vec>> = vec![ + Box::new(spartan_shift), Box::new(registers_read_write_checking), Box::new(ram_val_check), ]; #[cfg(feature = "allocative")] - write_boxed_instance_flamegraph_svg(&instances, "stage4_start_flamechart.svg"); - tracing::info!("Stage 4 proving"); + write_boxed_instance_flamegraph_svg(&instances, "stage3_start_flamechart.svg"); + tracing::info!("Stage 3 proving"); - let (sumcheck_proof, r_stage4, _initial_claim) = + let (sumcheck_proof, r_stage3, _initial_claim) = self.prove_batched_sumcheck(instances.iter_mut().map(|v| &mut **v as _).collect()); #[cfg(feature = "allocative")] - write_boxed_instance_flamegraph_svg(&instances, "stage4_end_flamechart.svg"); + write_boxed_instance_flamegraph_svg(&instances, "stage3_end_flamechart.svg"); drop_in_background_thread(instances); - (sumcheck_proof, r_stage4) + (sumcheck_proof, r_stage3) } - #[tracing::instrument(skip_all)] - fn prove_stage5( + fn prove_stage4( &mut self, ) -> ( SumcheckInstanceProof, Vec, ) { #[cfg(not(target_arch = "wasm32"))] - print_current_memory_usage("Stage 5 baseline"); + print_current_memory_usage("Stage 4 baseline"); + // Initialization params (same order as batch) let lookups_read_raf_params = InstructionReadRafSumcheckParams::new( self.trace.len().log_2(), @@ -1188,27 +1158,27 @@ where ]; #[cfg(feature = "allocative")] - write_boxed_instance_flamegraph_svg(&instances, "stage5_start_flamechart.svg"); - tracing::info!("Stage 5 proving"); + write_boxed_instance_flamegraph_svg(&instances, "stage4_start_flamechart.svg"); + tracing::info!("Stage 4 proving"); - let (sumcheck_proof, r_stage5, _initial_claim) = + let (sumcheck_proof, r_stage4, _initial_claim) = self.prove_batched_sumcheck(instances.iter_mut().map(|v| &mut **v as _).collect()); #[cfg(feature = "allocative")] - write_boxed_instance_flamegraph_svg(&instances, "stage5_end_flamechart.svg"); + write_boxed_instance_flamegraph_svg(&instances, "stage4_end_flamechart.svg"); drop_in_background_thread(instances); - (sumcheck_proof, r_stage5) + (sumcheck_proof, r_stage4) } #[tracing::instrument(skip_all)] - fn prove_stage6( + fn prove_stage5( &mut self, ) -> ( SumcheckInstanceProof, Vec, ) { #[cfg(not(target_arch = "wasm32"))] - print_current_memory_usage("Stage 6 baseline"); + print_current_memory_usage("Stage 5 baseline"); let bytecode_read_raf_params = BytecodeReadRafSumcheckParams::gen( &self.preprocessing.shared.bytecode, @@ -1244,7 +1214,7 @@ where &mut self.transcript, ); - // Advice claim reduction (Phase 1 in Stage 6): trusted and untrusted are separate instances. + // Advice claim reduction (Phase 1 in Stage 5): trusted and untrusted are separate instances. if self.advice.trusted_advice_polynomial.is_some() { let trusted_advice_params = AdviceClaimReductionParams::new( AdviceKind::Trusted, @@ -1252,7 +1222,7 @@ where self.trace.len(), &self.opening_accumulator, ); - // Note: We clone the advice polynomial here because Stage 8 needs the original polynomial + // Note: We clone the advice polynomial here because Stage 7 needs the original polynomial // A future optimization could use Arc with copy-on-write. self.advice_reduction_prover_trusted = { let poly = self @@ -1274,7 +1244,7 @@ where self.trace.len(), &self.opening_accumulator, ); - // Note: We clone the advice polynomial here because Stage 8 needs the original polynomial + // Note: We clone the advice polynomial here because Stage 7 needs the original polynomial // A future optimization could use Arc with copy-on-write. self.advice_reduction_prover_untrusted = { let poly = self @@ -1353,13 +1323,13 @@ where } #[cfg(feature = "allocative")] - write_instance_flamegraph_svg(&instances, "stage6_start_flamechart.svg"); - tracing::info!("Stage 6 proving"); + write_instance_flamegraph_svg(&instances, "stage5_start_flamechart.svg"); + tracing::info!("Stage 5 proving"); - let (sumcheck_proof, r_stage6, _initial_claim) = + let (sumcheck_proof, r_stage5, _initial_claim) = self.prove_batched_sumcheck(instances.iter_mut().map(|v| &mut **v as _).collect()); #[cfg(feature = "allocative")] - write_instance_flamegraph_svg(&instances, "stage6_end_flamechart.svg"); + write_instance_flamegraph_svg(&instances, "stage5_end_flamechart.svg"); drop_in_background_thread(bytecode_read_raf); drop_in_background_thread(booleanity); drop_in_background_thread(ram_hamming_booleanity); @@ -1370,7 +1340,7 @@ where self.advice_reduction_prover_trusted = advice_trusted; self.advice_reduction_prover_untrusted = advice_untrusted; - (sumcheck_proof, r_stage6) + (sumcheck_proof, r_stage5) } #[tracing::instrument(skip_all)] @@ -1392,8 +1362,8 @@ where let zk_stages = self.blindfold_accumulator.take_stage_data(); assert_eq!( zk_stages.len(), - 7, - "Expected 7 ZK stages, got {}", + 6, + "Expected 6 ZK stages, got {}", zk_stages.len() ); @@ -1750,9 +1720,9 @@ where prover.prove(&real_instance, &real_witness, &z, &mut blindfold_transcript) } - /// Stage 7: HammingWeight + ClaimReduction sumcheck (only log_k_chunk rounds). + /// Stage 6: HammingWeight + ClaimReduction sumcheck (only log_k_chunk rounds). #[tracing::instrument(skip_all)] - fn prove_stage7( + fn prove_stage6( &mut self, ) -> ( SumcheckInstanceProof, @@ -1775,7 +1745,7 @@ where #[cfg(feature = "allocative")] print_data_structure_heap_usage("HammingWeightClaimReductionProver", &hw_prover); - // Run Stage 7 batched sumcheck (address rounds only). + // Run Stage 6 batched sumcheck (address rounds only). // Includes HammingWeightClaimReduction plus address phase of advice reduction instances (if needed). let mut instances: Vec>> = vec![Box::new(hw_prover)]; @@ -1808,26 +1778,26 @@ where } #[cfg(feature = "allocative")] - write_boxed_instance_flamegraph_svg(&instances, "stage7_start_flamechart.svg"); - tracing::info!("Stage 7 proving"); + write_boxed_instance_flamegraph_svg(&instances, "stage6_start_flamechart.svg"); + tracing::info!("Stage 6 proving"); - let (sumcheck_proof, r_stage7, _initial_claim) = + let (sumcheck_proof, r_stage6, _initial_claim) = self.prove_batched_sumcheck(instances.iter_mut().map(|v| &mut **v as _).collect()); #[cfg(feature = "allocative")] - write_boxed_instance_flamegraph_svg(&instances, "stage7_end_flamechart.svg"); + write_boxed_instance_flamegraph_svg(&instances, "stage6_end_flamechart.svg"); drop_in_background_thread(instances); - (sumcheck_proof, r_stage7) + (sumcheck_proof, r_stage6) } - /// Stage 8: Dory batch opening proof. + /// Stage 7: Dory batch opening proof. /// Builds streaming RLC polynomial directly from trace (no witness regeneration needed). #[tracing::instrument(skip_all)] - fn prove_stage8( + fn prove_stage7( &mut self, opening_proof_hints: HashMap, ) -> PCS::Proof { - tracing::info!("Stage 8 proving (Dory batch opening)"); + tracing::info!("Stage 7 proving (Dory batch opening)"); let _guard = DoryGlobals::initialize_context( self.one_hot_params.k_chunk, @@ -1837,39 +1807,57 @@ where ); // Get the unified opening point from HammingWeightClaimReduction - // This contains (r_address_stage7 || r_cycle_stage6) in big-endian + // This contains (r_address_stage6 || r_cycle_stage5) in big-endian let (opening_point, _) = self.opening_accumulator.get_committed_polynomial_opening( CommittedPolynomial::InstructionRa(0), SumcheckId::HammingWeightClaimReduction, ); let log_k_chunk = self.one_hot_params.log_k_chunk; - let r_address_stage7 = &opening_point.r[..log_k_chunk]; + let r_address_stage6 = &opening_point.r[..log_k_chunk]; let mut polynomial_claims = Vec::new(); let mut scaling_factors = Vec::new(); - // Dense polynomials: RamInc and RdInc (from IncClaimReduction in Stage 6) - // at r_cycle_stage6 only (length log_T) - let (_, ram_inc_claim) = self.opening_accumulator.get_committed_polynomial_opening( - CommittedPolynomial::RamInc, - SumcheckId::IncClaimReduction, - ); - let (_, rd_inc_claim) = self.opening_accumulator.get_committed_polynomial_opening( - CommittedPolynomial::RdInc, - SumcheckId::IncClaimReduction, - ); + // Dense polynomials: RamInc and RdInc (from IncClaimReduction in Stage 5) + // These are at r_cycle_stage5 only (length log_T) + let (_ram_inc_point, ram_inc_claim) = + self.opening_accumulator.get_committed_polynomial_opening( + CommittedPolynomial::RamInc, + SumcheckId::IncClaimReduction, + ); + let (_rd_inc_point, rd_inc_claim) = + self.opening_accumulator.get_committed_polynomial_opening( + CommittedPolynomial::RdInc, + SumcheckId::IncClaimReduction, + ); + + #[cfg(test)] + { + let r_cycle_stage5 = &opening_point.r[log_k_chunk..]; + + debug_assert_eq!( + _ram_inc_point.r.as_slice(), + r_cycle_stage5, + "RamInc opening point should match r_cycle from HammingWeightClaimReduction" + ); + debug_assert_eq!( + _rd_inc_point.r.as_slice(), + r_cycle_stage5, + "RdInc opening point should match r_cycle from HammingWeightClaimReduction" + ); + } // Dense polynomials are zero-padded in the Dory matrix, so their evaluation // includes a factor eq(r_addr, 0) = ∏(1 − r_addr_i). - let lagrange_factor: F = EqPolynomial::zero_selector(r_address_stage7); + let lagrange_factor: F = EqPolynomial::zero_selector(r_address_stage6); polynomial_claims.push((CommittedPolynomial::RamInc, ram_inc_claim * lagrange_factor)); scaling_factors.push(lagrange_factor); polynomial_claims.push((CommittedPolynomial::RdInc, rd_inc_claim * lagrange_factor)); scaling_factors.push(lagrange_factor); // Sparse polynomials: all RA polys (from HammingWeightClaimReduction) - // These are at (r_address_stage7, r_cycle_stage6) + // These are at (r_address_stage6, r_cycle_stage5) for i in 0..self.one_hot_params.instruction_d { let (_, claim) = self.opening_accumulator.get_committed_polynomial_opening( CommittedPolynomial::InstructionRa(i), @@ -1895,7 +1883,7 @@ where scaling_factors.push(F::one()); } - // Advice polynomials: TrustedAdvice and UntrustedAdvice (from AdviceClaimReduction in Stage 6) + // Advice polynomials: TrustedAdvice and UntrustedAdvice (from AdviceClaimReduction in Stage 5) // These are committed with smaller dimensions, so we apply Lagrange factors to embed // them in the top-left block of the main Dory matrix. #[cfg(feature = "zk")] @@ -2006,14 +1994,13 @@ where { let y_com: C::G1 = PCS::eval_commitment(&proof).expect("ZK proof must have y_com"); bind_opening_inputs_zk::(&mut self.transcript, &opening_point.r, &y_com); - self.blindfold_accumulator.set_opening_proof_data( - crate::subprotocols::blindfold::OpeningProofData { + self.blindfold_accumulator + .set_opening_proof_data(OpeningProofData { opening_ids, constraint_coeffs, joint_claim, y_blinding: _y_blinding.expect("ZK mode requires y_blinding"), - }, - ); + }); } #[cfg(not(feature = "zk"))] { @@ -2071,7 +2058,6 @@ where { #[tracing::instrument(skip_all, name = "JoltProverPreprocessing::gen")] pub fn new(shared: JoltSharedPreprocessing) -> JoltProverPreprocessing { - use common::constants::ONEHOT_CHUNK_THRESHOLD_LOG_T; let max_T: usize = shared.max_padded_trace_length.next_power_of_two(); let max_log_T = max_T.log_2(); let max_log_k_chunk = if max_log_T < ONEHOT_CHUNK_THRESHOLD_LOG_T { @@ -2140,6 +2126,8 @@ mod tests { multilinear_polynomial::MultilinearPolynomial, opening_proof::{OpeningAccumulator, SumcheckId}, }; + #[cfg(feature = "zk")] + use crate::subprotocols::blindfold::StageWitness; use crate::zkvm::claim_reductions::AdviceKind; use crate::zkvm::verifier::JoltSharedPreprocessing; use crate::zkvm::witness::CommittedPolynomial; @@ -2155,7 +2143,7 @@ mod tests { #[cfg(feature = "zk")] fn round_commitment_data( gens: &PedersenGenerators, - stages: &[crate::subprotocols::blindfold::StageWitness], + stages: &[StageWitness], rng: &mut R, ) -> (Vec, Vec>, Vec) { let mut commitments = Vec::new(); @@ -2423,8 +2411,8 @@ mod tests { DoryGlobals::reset(); // SHA2 guest does not consume advice, but providing both trusted and untrusted advice // should still work correctly through the full pipeline: - // - Trusted: commit in preprocessing-only context, reduce in Stage 6, batch in Stage 8 - // - Untrusted: commit at prove time, reduce in Stage 6, batch in Stage 8 + // - Trusted: commit in preprocessing-only context, reduce in Stage 5, batch in Stage 7 + // - Untrusted: commit at prove time, reduce in Stage 5, batch in Stage 7 let mut program = host::Program::new("sha2-guest"); let (bytecode, init_memory_state, _) = program.decode(); let inputs = postcard::to_stdvec(&[5u8; 32]).unwrap(); @@ -2609,8 +2597,8 @@ mod tests { // Tests that advice opening points are correctly derived from the unified main opening // point using Dory's balanced dimension policy. // - // For a small trace (256 cycles), the advice row coordinates span both Stage 6 (cycle) - // and Stage 7 (address) challenges, verifying the two-phase reduction works correctly. + // For a small trace (256 cycles), the advice row coordinates span both Stage 5 (cycle) + // and Stage 6 (address) challenges, verifying the two-phase reduction works correctly. let mut program = host::Program::new("fibonacci-guest"); let inputs = postcard::to_stdvec(&5u32).unwrap(); let trusted_advice = postcard::to_stdvec(&[7u8; 32]).unwrap(); @@ -2969,9 +2957,9 @@ mod tests { ); let (jolt_proof, _) = prover.prove(); - println!("\n=== BlindFold R1CS Satisfaction Test (All 7 Stages) ===\n"); + println!("\n=== BlindFold R1CS Satisfaction Test (All 6 Stages) ===\n"); - // Process all 7 stages and verify each one + // Process all 6 stages and verify each one let stage_proofs: Vec<(&str, &SumcheckInstanceProof)> = vec![ ("Stage 1 (Spartan Outer)", &jolt_proof.stage1_sumcheck_proof), ( @@ -2979,15 +2967,14 @@ mod tests { &jolt_proof.stage2_sumcheck_proof, ), ("Stage 3 (Instruction)", &jolt_proof.stage3_sumcheck_proof), - ("Stage 4 (Registers+RAM)", &jolt_proof.stage4_sumcheck_proof), - ("Stage 5 (Value+Lookup)", &jolt_proof.stage5_sumcheck_proof), + ("Stage 4 (Value+Lookup)", &jolt_proof.stage4_sumcheck_proof), ( - "Stage 6 (OneHot+Hamming)", - &jolt_proof.stage6_sumcheck_proof, + "Stage 5 (OneHot+Hamming)", + &jolt_proof.stage5_sumcheck_proof, ), ( - "Stage 7 (HammingWeight+ClaimReduction)", - &jolt_proof.stage7_sumcheck_proof, + "Stage 6 (HammingWeight+ClaimReduction)", + &jolt_proof.stage6_sumcheck_proof, ), ]; @@ -3290,7 +3277,7 @@ mod tests { let verifier_preprocessing = JoltVerifierPreprocessing::from(&prover_preprocessing); - // DoryGlobals is now initialized inside the verifier's verify_stage8 + // DoryGlobals is now initialized inside the verifier's verify_stage7 RV64IMACVerifier::new(&verifier_preprocessing, proof, io_device, None, debug_info) .expect("verifier creation failed") .verify() diff --git a/jolt-core/src/zkvm/r1cs/constraints.rs b/jolt-core/src/zkvm/r1cs/constraints.rs index c000f04570..5722ec08ec 100644 --- a/jolt-core/src/zkvm/r1cs/constraints.rs +++ b/jolt-core/src/zkvm/r1cs/constraints.rs @@ -43,6 +43,12 @@ use strum::EnumCount as EnumCountTrait; use strum_macros::{EnumCount, EnumIter}; pub use super::ops::{Term, LC}; +#[cfg(test)] +use crate::field::JoltField; +#[cfg(test)] +use crate::poly::multilinear_polynomial::MultilinearPolynomial; +#[cfg(test)] +use std::fmt::Write as _; /// A single R1CS constraint row #[derive(Clone, Copy, Debug)] @@ -59,14 +65,12 @@ impl R1CSConstraint { } #[cfg(test)] - pub fn pretty_fmt( + pub fn pretty_fmt( &self, f: &mut String, - flattened_polynomials: &[crate::poly::multilinear_polynomial::MultilinearPolynomial], + flattened_polynomials: &[MultilinearPolynomial], step_index: usize, ) -> std::fmt::Result { - use std::fmt::Write as _; - self.a.pretty_fmt(f)?; write!(f, " ⋅ ")?; self.b.pretty_fmt(f)?; @@ -107,8 +111,6 @@ impl R1CSConstraint { f: &mut String, row: &super::inputs::R1CSCycleInputs, ) -> std::fmt::Result { - use std::fmt::Write as _; - self.a.pretty_fmt(f)?; write!(f, " ⋅ ")?; self.b.pretty_fmt(f)?; @@ -162,7 +164,6 @@ pub enum R1CSConstraintLabel { LeftLookupEqLeftInputOtherwise, RightLookupAdd, RightLookupSub, - RightLookupEqProductIfMul, RightLookupEqRightInputOtherwise, AssertLookupOne, RdWriteEqLookupIfWriteLookupToRd, @@ -182,14 +183,12 @@ pub struct NamedR1CSConstraint { impl NamedR1CSConstraint { #[cfg(test)] - pub fn pretty_fmt( + pub fn pretty_fmt( &self, f: &mut String, - flattened_polynomials: &[crate::poly::multilinear_polynomial::MultilinearPolynomial], + flattened_polynomials: &[MultilinearPolynomial], step_index: usize, ) -> std::fmt::Result { - use std::fmt::Write as _; - writeln!(f, "[{:?}]", self.label)?; self.cons.pretty_fmt(f, flattened_polynomials, step_index) } @@ -200,8 +199,6 @@ impl NamedR1CSConstraint { f: &mut String, row: &super::inputs::R1CSCycleInputs, ) -> std::fmt::Result { - use std::fmt::Write as _; - writeln!(f, "[{:?}]", self.label)?; self.cons.pretty_fmt_with_row(f, row) } @@ -301,11 +298,6 @@ pub static R1CS_CONSTRAINTS: [NamedR1CSConstraint; NUM_R1CS_CONSTRAINTS] = [ if { { JoltR1CSInputs::OpFlags(CircuitFlags::SubtractOperands) } } => ( { JoltR1CSInputs::RightLookupOperand } ) == ( { JoltR1CSInputs::LeftInstructionInput } - { JoltR1CSInputs::RightInstructionInput } + { 0x10000000000000000i128 } ) ), - r1cs_eq_conditional!( - label: R1CSConstraintLabel::RightLookupEqProductIfMul, - if { { JoltR1CSInputs::OpFlags(CircuitFlags::MultiplyOperands) } } - => ( { JoltR1CSInputs::RightLookupOperand } ) == ( { JoltR1CSInputs::Product } ) - ), // if !(AddOperands || SubtractOperands || MultiplyOperands || Advice) { // assert!(RightLookupOperand == RightInstructionInput) // } @@ -495,7 +487,6 @@ const fn complement_first_group_labels() -> [R1CSConstraintLabel; NUM_REMAINING_ /// First group: 10 boolean-guarded eq constraints, where Bz is around 64 bits pub const R1CS_CONSTRAINTS_FIRST_GROUP_LABELS: [R1CSConstraintLabel; OUTER_UNIVARIATE_SKIP_DOMAIN_SIZE] = [ - R1CSConstraintLabel::RamAddrEqZeroIfNotLoadStore, R1CSConstraintLabel::RamReadEqRamWriteIfLoad, R1CSConstraintLabel::RamReadEqRdWriteIfLoad, R1CSConstraintLabel::Rs2EqRamWriteIfStore, @@ -520,8 +511,8 @@ pub static R1CS_CONSTRAINTS_FIRST_GROUP: [NamedR1CSConstraint; OUTER_UNIVARIATE_ pub static R1CS_CONSTRAINTS_SECOND_GROUP: [NamedR1CSConstraint; NUM_REMAINING_R1CS_CONSTRAINTS] = filter_r1cs_constraints(&R1CS_CONSTRAINTS_SECOND_GROUP_LABELS); -/// Domain sizing for product-virtualization univariate-skip (size-5 window) -pub const NUM_PRODUCT_VIRTUAL: usize = 5; +/// Domain sizing for product-virtualization univariate-skip (size-4 window) +pub const NUM_PRODUCT_VIRTUAL: usize = 4; pub const PRODUCT_VIRTUAL_UNIVARIATE_SKIP_DOMAIN_SIZE: usize = NUM_PRODUCT_VIRTUAL; pub const PRODUCT_VIRTUAL_UNIVARIATE_SKIP_DEGREE: usize = NUM_PRODUCT_VIRTUAL - 1; pub const PRODUCT_VIRTUAL_UNIVARIATE_SKIP_EXTENDED_DOMAIN_SIZE: usize = @@ -534,7 +525,6 @@ pub const PRODUCT_VIRTUAL_FIRST_ROUND_POLY_DEGREE_BOUND: usize = #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, EnumCount, EnumIter)] pub enum ProductConstraintLabel { - Instruction, WriteLookupOutputToRD, WritePCtoRD, ShouldBranch, @@ -565,14 +555,7 @@ pub struct ProductConstraint { /// Canonical list of the product constraints in the same order as /// `PRODUCT_VIRTUAL_TERMS` used by the product virtualization stage. pub const PRODUCT_CONSTRAINTS: [ProductConstraint; NUM_PRODUCT_CONSTRAINTS] = [ - // 0: Product = LeftInstructionInput · RightInstructionInput - ProductConstraint { - label: ProductConstraintLabel::Instruction, - left: ProductFactorExpr::Var(VirtualPolynomial::LeftInstructionInput), - right: ProductFactorExpr::Var(VirtualPolynomial::RightInstructionInput), - output: VirtualPolynomial::Product, - }, - // 1: WriteLookupOutputToRD = IsRdNotZero · OpFlags(WriteLookupOutputToRD) + // 0: WriteLookupOutputToRD = IsRdNotZero · OpFlags(WriteLookupOutputToRD) ProductConstraint { label: ProductConstraintLabel::WriteLookupOutputToRD, left: ProductFactorExpr::Var(VirtualPolynomial::InstructionFlags( @@ -583,7 +566,7 @@ pub const PRODUCT_CONSTRAINTS: [ProductConstraint; NUM_PRODUCT_CONSTRAINTS] = [ )), output: VirtualPolynomial::WriteLookupOutputToRD, }, - // 2: WritePCtoRD = IsRdNotZero · OpFlags(Jump) + // 1: WritePCtoRD = IsRdNotZero · OpFlags(Jump) ProductConstraint { label: ProductConstraintLabel::WritePCtoRD, left: ProductFactorExpr::Var(VirtualPolynomial::InstructionFlags( @@ -592,7 +575,7 @@ pub const PRODUCT_CONSTRAINTS: [ProductConstraint; NUM_PRODUCT_CONSTRAINTS] = [ right: ProductFactorExpr::Var(VirtualPolynomial::OpFlags(CircuitFlags::Jump)), output: VirtualPolynomial::WritePCtoRD, }, - // 3: ShouldBranch = LookupOutput · InstructionFlags(Branch) + // 2: ShouldBranch = LookupOutput · InstructionFlags(Branch) ProductConstraint { label: ProductConstraintLabel::ShouldBranch, left: ProductFactorExpr::Var(VirtualPolynomial::LookupOutput), @@ -601,7 +584,7 @@ pub const PRODUCT_CONSTRAINTS: [ProductConstraint; NUM_PRODUCT_CONSTRAINTS] = [ )), output: VirtualPolynomial::ShouldBranch, }, - // 4: ShouldJump = OpFlags(Jump) · (1 − NextIsNoop) + // 3: ShouldJump = OpFlags(Jump) · (1 − NextIsNoop) ProductConstraint { label: ProductConstraintLabel::ShouldJump, left: ProductFactorExpr::Var(VirtualPolynomial::OpFlags(CircuitFlags::Jump)), diff --git a/jolt-core/src/zkvm/r1cs/evaluation.rs b/jolt-core/src/zkvm/r1cs/evaluation.rs index 807593488a..9dfcb4bd64 100644 --- a/jolt-core/src/zkvm/r1cs/evaluation.rs +++ b/jolt-core/src/zkvm/r1cs/evaluation.rs @@ -128,7 +128,6 @@ pub(crate) const PRODUCT_VIRTUAL_COEFFS_PER_J: [[i32; /// Boolean guards for the first group (univariate-skip base window) #[derive(Clone, Copy, Debug)] pub struct AzFirstGroup { - pub not_load_store: bool, // !(Load || Store) pub load_a: bool, // Load pub load_b: bool, // Load pub store: bool, // Store @@ -150,31 +149,29 @@ impl AzFirstGroup { w: &[F; OUTER_UNIVARIATE_SKIP_DOMAIN_SIZE], acc: &mut SmallAccumU, ) { - acc.fmadd(&w[0], &self.not_load_store); - acc.fmadd(&w[1], &self.load_a); - acc.fmadd(&w[2], &self.load_b); - acc.fmadd(&w[3], &self.store); - acc.fmadd(&w[4], &self.add_sub_mul); - acc.fmadd(&w[5], &self.not_add_sub_mul); - acc.fmadd(&w[6], &self.assert_flag); - acc.fmadd(&w[7], &self.should_jump); - acc.fmadd(&w[8], &self.virtual_instr_not_last); - acc.fmadd(&w[9], &self.must_start_sequence); + acc.fmadd(&w[0], &self.load_a); + acc.fmadd(&w[1], &self.load_b); + acc.fmadd(&w[2], &self.store); + acc.fmadd(&w[3], &self.add_sub_mul); + acc.fmadd(&w[4], &self.not_add_sub_mul); + acc.fmadd(&w[5], &self.assert_flag); + acc.fmadd(&w[6], &self.should_jump); + acc.fmadd(&w[7], &self.virtual_instr_not_last); + acc.fmadd(&w[8], &self.must_start_sequence); } } /// Magnitudes for the first group (kept small: bool/u64/S64) #[derive(Clone, Copy, Debug)] pub struct BzFirstGroup { - pub ram_addr: u64, // RamAddress - 0 - pub ram_read_minus_ram_write: S64, // RamRead - RamWrite - pub ram_read_minus_rd_write: S64, // RamRead - RdWrite - pub rs2_minus_ram_write: S64, // Rs2 - RamWrite - pub left_lookup: u64, // LeftLookup - 0 - pub left_lookup_minus_left_input: S64, // LeftLookup - LeftInstructionInput - pub lookup_output_minus_one: S64, // LookupOutput - 1 - pub next_unexp_pc_minus_lookup_output: S64, // NextUnexpandedPC - LookupOutput - pub next_pc_minus_pc_plus_one: S64, // NextPC - (PC + 1) + pub ram_read_minus_ram_write: S64, // RamRead - RamWrite + pub ram_read_minus_rd_write: S64, // RamRead - RdWrite + pub rs2_minus_ram_write: S64, // Rs2 - RamWrite + pub left_lookup: u64, // LeftLookup - 0 + pub left_lookup_minus_left_input: S64, // LeftLookup - LeftInstructionInput + pub lookup_output_minus_one: S64, // LookupOutput - 1 + pub next_unexp_pc_minus_lookup_output: S64, // NextUnexpandedPC - LookupOutput + pub next_pc_minus_pc_plus_one: S64, // NextPC - (PC + 1) pub one_minus_do_not_update_unexpanded_pc: bool, // 1 - DoNotUpdateUnexpandedPC } @@ -188,16 +185,15 @@ impl BzFirstGroup { w: &[F; OUTER_UNIVARIATE_SKIP_DOMAIN_SIZE], acc: &mut MedAccumS, ) { - acc.fmadd(&w[0], &self.ram_addr); - acc.fmadd(&w[1], &self.ram_read_minus_ram_write); - acc.fmadd(&w[2], &self.ram_read_minus_rd_write); - acc.fmadd(&w[3], &self.rs2_minus_ram_write); - acc.fmadd(&w[4], &self.left_lookup); - acc.fmadd(&w[5], &self.left_lookup_minus_left_input); - acc.fmadd(&w[6], &self.lookup_output_minus_one); - acc.fmadd(&w[7], &self.next_unexp_pc_minus_lookup_output); - acc.fmadd(&w[8], &self.next_pc_minus_pc_plus_one); - acc.fmadd(&w[9], &self.one_minus_do_not_update_unexpanded_pc); + acc.fmadd(&w[0], &self.ram_read_minus_ram_write); + acc.fmadd(&w[1], &self.ram_read_minus_rd_write); + acc.fmadd(&w[2], &self.rs2_minus_ram_write); + acc.fmadd(&w[3], &self.left_lookup); + acc.fmadd(&w[4], &self.left_lookup_minus_left_input); + acc.fmadd(&w[5], &self.lookup_output_minus_one); + acc.fmadd(&w[6], &self.next_unexp_pc_minus_lookup_output); + acc.fmadd(&w[7], &self.next_pc_minus_pc_plus_one); + acc.fmadd(&w[8], &self.one_minus_do_not_update_unexpanded_pc); } } @@ -205,9 +201,9 @@ impl BzFirstGroup { #[derive(Clone, Copy, Debug)] pub struct AzSecondGroup { pub load_or_store: bool, // Load || Store + pub not_load_store: bool, // !(Load || Store) pub add: bool, // Add pub sub: bool, // Sub - pub mul: bool, // Mul pub not_add_sub_mul_advice: bool, // !(Add || Sub || Mul || Advice) pub write_lookup_to_rd: bool, // write_lookup_output_to_rd_addr (Rd != 0) pub write_pc_to_rd: bool, // write_pc_to_rd_addr (Rd != 0) @@ -226,9 +222,9 @@ impl AzSecondGroup { acc: &mut SmallAccumU, ) { acc.fmadd(&w[0], &self.load_or_store); - acc.fmadd(&w[1], &self.add); - acc.fmadd(&w[2], &self.sub); - acc.fmadd(&w[3], &self.mul); + acc.fmadd(&w[1], &self.not_load_store); + acc.fmadd(&w[2], &self.add); + acc.fmadd(&w[3], &self.sub); acc.fmadd(&w[4], &self.not_add_sub_mul_advice); acc.fmadd(&w[5], &self.write_lookup_to_rd); acc.fmadd(&w[6], &self.write_pc_to_rd); @@ -241,9 +237,9 @@ impl AzSecondGroup { #[derive(Clone, Copy, Debug)] pub struct BzSecondGroup { pub ram_addr_minus_rs1_plus_imm: i128, // RamAddress - (Rs1 + Imm) + pub ram_addr: u64, // RamAddress - 0 pub right_lookup_minus_add_result: S160, // RightLookup - (Left + Right) pub right_lookup_minus_sub_result: S160, // RightLookup - (Left - Right + 2^64) - pub right_lookup_minus_product: S160, // RightLookup - Product pub right_lookup_minus_right_input: S160, // RightLookup - RightInput pub rd_write_minus_lookup_output: S64, // RdWrite - LookupOutput pub rd_write_minus_pc_plus_const: S64, // RdWrite - (UnexpandedPC + const) @@ -262,9 +258,9 @@ impl BzSecondGroup { acc: &mut WideAccumS, ) { acc.fmadd(&w[0], &self.ram_addr_minus_rs1_plus_imm); - acc.fmadd(&w[1], &self.right_lookup_minus_add_result); - acc.fmadd(&w[2], &self.right_lookup_minus_sub_result); - acc.fmadd(&w[3], &self.right_lookup_minus_product); + acc.fmadd(&w[1], &self.ram_addr); + acc.fmadd(&w[2], &self.right_lookup_minus_add_result); + acc.fmadd(&w[3], &self.right_lookup_minus_sub_result); acc.fmadd(&w[4], &self.right_lookup_minus_right_input); acc.fmadd(&w[5], &self.rd_write_minus_lookup_output); acc.fmadd(&w[6], &self.rd_write_minus_pc_plus_const); @@ -301,7 +297,6 @@ impl<'a, F: JoltField> R1CSEval<'a, F> { let inline_seq = flags[CircuitFlags::VirtualInstruction]; AzFirstGroup { - not_load_store: !(ld || st), load_a: ld, load_b: ld, store: st, @@ -317,7 +312,6 @@ impl<'a, F: JoltField> R1CSEval<'a, F> { #[inline] pub fn eval_bz_first_group(&self) -> BzFirstGroup { BzFirstGroup { - ram_addr: self.row.ram_addr, ram_read_minus_ram_write: s64_from_diff_u64s( self.row.ram_read_value, self.row.ram_write_value, @@ -353,16 +347,15 @@ impl<'a, F: JoltField> R1CSEval<'a, F> { pub fn az_at_r_first_group(&self, w: &[F; OUTER_UNIVARIATE_SKIP_DOMAIN_SIZE]) -> F { let az = self.eval_az_first_group(); let mut acc: SmallAccumU = SmallAccumU::zero(); - acc.fmadd(&w[0], &az.not_load_store); - acc.fmadd(&w[1], &az.load_a); - acc.fmadd(&w[2], &az.load_b); - acc.fmadd(&w[3], &az.store); - acc.fmadd(&w[4], &az.add_sub_mul); - acc.fmadd(&w[5], &az.not_add_sub_mul); - acc.fmadd(&w[6], &az.assert_flag); - acc.fmadd(&w[7], &az.should_jump); - acc.fmadd(&w[8], &az.virtual_instr_not_last); - acc.fmadd(&w[9], &az.must_start_sequence); + acc.fmadd(&w[0], &az.load_a); + acc.fmadd(&w[1], &az.load_b); + acc.fmadd(&w[2], &az.store); + acc.fmadd(&w[3], &az.add_sub_mul); + acc.fmadd(&w[4], &az.not_add_sub_mul); + acc.fmadd(&w[5], &az.assert_flag); + acc.fmadd(&w[6], &az.should_jump); + acc.fmadd(&w[7], &az.virtual_instr_not_last); + acc.fmadd(&w[8], &az.must_start_sequence); acc.barrett_reduce() } @@ -370,16 +363,15 @@ impl<'a, F: JoltField> R1CSEval<'a, F> { pub fn bz_at_r_first_group(&self, w: &[F; OUTER_UNIVARIATE_SKIP_DOMAIN_SIZE]) -> F { let bz = self.eval_bz_first_group(); let mut acc: MedAccumS = MedAccumS::zero(); - acc.fmadd(&w[0], &bz.ram_addr); - acc.fmadd(&w[1], &bz.ram_read_minus_ram_write); - acc.fmadd(&w[2], &bz.ram_read_minus_rd_write); - acc.fmadd(&w[3], &bz.rs2_minus_ram_write); - acc.fmadd(&w[4], &bz.left_lookup); - acc.fmadd(&w[5], &bz.left_lookup_minus_left_input); - acc.fmadd(&w[6], &bz.lookup_output_minus_one); - acc.fmadd(&w[7], &bz.next_unexp_pc_minus_lookup_output); - acc.fmadd(&w[8], &bz.next_pc_minus_pc_plus_one); - acc.fmadd(&w[9], &bz.one_minus_do_not_update_unexpanded_pc); + acc.fmadd(&w[0], &bz.ram_read_minus_ram_write); + acc.fmadd(&w[1], &bz.ram_read_minus_rd_write); + acc.fmadd(&w[2], &bz.rs2_minus_ram_write); + acc.fmadd(&w[3], &bz.left_lookup); + acc.fmadd(&w[4], &bz.left_lookup_minus_left_input); + acc.fmadd(&w[5], &bz.lookup_output_minus_one); + acc.fmadd(&w[6], &bz.next_unexp_pc_minus_lookup_output); + acc.fmadd(&w[7], &bz.next_pc_minus_pc_plus_one); + acc.fmadd(&w[8], &bz.one_minus_do_not_update_unexpanded_pc); acc.barrett_reduce() } @@ -412,73 +404,66 @@ impl<'a, F: JoltField> R1CSEval<'a, F> { let mut bz_eval_s128: S128Sum = S128Sum::zero(); let c0_i32 = coeffs_i32[0]; - if az.not_load_store { + if az.load_a { az_eval_i32 += c0_i32; } else { - bz_eval_s128.fmadd(&c0_i32, &bz.ram_addr); + bz_eval_s128.fmadd(&c0_i32, &bz.ram_read_minus_ram_write); } let c1_i32 = coeffs_i32[1]; - if az.load_a { + if az.load_b { az_eval_i32 += c1_i32; } else { - bz_eval_s128.fmadd(&c1_i32, &bz.ram_read_minus_ram_write); + bz_eval_s128.fmadd(&c1_i32, &bz.ram_read_minus_rd_write); } let c2_i32 = coeffs_i32[2]; - if az.load_b { + if az.store { az_eval_i32 += c2_i32; } else { - bz_eval_s128.fmadd(&c2_i32, &bz.ram_read_minus_rd_write); + bz_eval_s128.fmadd(&c2_i32, &bz.rs2_minus_ram_write); } let c3_i32 = coeffs_i32[3]; - if az.store { + if az.add_sub_mul { az_eval_i32 += c3_i32; } else { - bz_eval_s128.fmadd(&c3_i32, &bz.rs2_minus_ram_write); + bz_eval_s128.fmadd(&c3_i32, &bz.left_lookup); } let c4_i32 = coeffs_i32[4]; - if az.add_sub_mul { + if az.not_add_sub_mul { az_eval_i32 += c4_i32; } else { - bz_eval_s128.fmadd(&c4_i32, &bz.left_lookup); + bz_eval_s128.fmadd(&c4_i32, &bz.left_lookup_minus_left_input); } let c5_i32 = coeffs_i32[5]; - if az.not_add_sub_mul { + if az.assert_flag { az_eval_i32 += c5_i32; } else { - bz_eval_s128.fmadd(&c5_i32, &bz.left_lookup_minus_left_input); + bz_eval_s128.fmadd(&c5_i32, &bz.lookup_output_minus_one); } let c6_i32 = coeffs_i32[6]; - if az.assert_flag { + if az.should_jump { az_eval_i32 += c6_i32; } else { - bz_eval_s128.fmadd(&c6_i32, &bz.lookup_output_minus_one); + bz_eval_s128.fmadd(&c6_i32, &bz.next_unexp_pc_minus_lookup_output); } let c7_i32 = coeffs_i32[7]; - if az.should_jump { + if az.virtual_instr_not_last { az_eval_i32 += c7_i32; } else { - bz_eval_s128.fmadd(&c7_i32, &bz.next_unexp_pc_minus_lookup_output); + bz_eval_s128.fmadd(&c7_i32, &bz.next_pc_minus_pc_plus_one); } let c8_i32 = coeffs_i32[8]; - if az.virtual_instr_not_last { - az_eval_i32 += c8_i32; - } else { - bz_eval_s128.fmadd(&c8_i32, &bz.next_pc_minus_pc_plus_one); - } - - let c9_i32 = coeffs_i32[9]; if az.must_start_sequence { - az_eval_i32 += c9_i32; + az_eval_i32 += c8_i32; } else { - bz_eval_s128.fmadd(&c9_i32, &bz.one_minus_do_not_update_unexpanded_pc); + bz_eval_s128.fmadd(&c8_i32, &bz.one_minus_do_not_update_unexpanded_pc); } let az_eval_s64 = S64::from_i64(az_eval_i32 as i64); @@ -503,37 +488,36 @@ impl<'a, F: JoltField> R1CSEval<'a, F> { pub fn assert_constraints_first_group(&self) { let az = self.eval_az_first_group(); let bz = self.eval_bz_first_group(); - self.assert_constraint_first_group(0, az.not_load_store, bz.ram_addr == 0); self.assert_constraint_first_group( - 1, + 0, az.load_a, bz.ram_read_minus_ram_write.to_i128() == 0, ); - self.assert_constraint_first_group(2, az.load_b, bz.ram_read_minus_rd_write.to_i128() == 0); - self.assert_constraint_first_group(3, az.store, bz.rs2_minus_ram_write.to_i128() == 0); - self.assert_constraint_first_group(4, az.add_sub_mul, bz.left_lookup == 0); + self.assert_constraint_first_group(1, az.load_b, bz.ram_read_minus_rd_write.to_i128() == 0); + self.assert_constraint_first_group(2, az.store, bz.rs2_minus_ram_write.to_i128() == 0); + self.assert_constraint_first_group(3, az.add_sub_mul, bz.left_lookup == 0); self.assert_constraint_first_group( - 5, + 4, az.not_add_sub_mul, bz.left_lookup_minus_left_input.to_i128() == 0, ); self.assert_constraint_first_group( - 6, + 5, az.assert_flag, bz.lookup_output_minus_one.to_i128() == 0, ); self.assert_constraint_first_group( - 7, + 6, az.should_jump, bz.next_unexp_pc_minus_lookup_output.to_i128() == 0, ); self.assert_constraint_first_group( - 8, + 7, az.virtual_instr_not_last, bz.next_pc_minus_pc_plus_one.to_i128() == 0, ); self.assert_constraint_first_group( - 9, + 8, az.must_start_sequence, !bz.one_minus_do_not_update_unexpanded_pc, ); @@ -554,9 +538,9 @@ impl<'a, F: JoltField> R1CSEval<'a, F> { AzSecondGroup { load_or_store: (flags[CircuitFlags::Load] || flags[CircuitFlags::Store]), + not_load_store: !(flags[CircuitFlags::Load] || flags[CircuitFlags::Store]), add: flags[CircuitFlags::AddOperands], sub: flags[CircuitFlags::SubtractOperands], - mul: flags[CircuitFlags::MultiplyOperands], not_add_sub_mul_advice, write_lookup_to_rd: self.row.write_lookup_output_to_rd_addr, write_pc_to_rd: self.row.write_pc_to_rd_addr, @@ -584,8 +568,6 @@ impl<'a, F: JoltField> R1CSEval<'a, F> { S160::from(self.row.right_lookup) - S160::from(right_add_expected); let right_lookup_minus_sub_result = S160::from(self.row.right_lookup) - S160::from(right_sub_expected); - let right_lookup_minus_product = - S160::from(self.row.right_lookup) - S160::from(self.row.product); let right_lookup_minus_right_input = S160::from(self.row.right_lookup) - S160::from(self.row.right_input); @@ -620,9 +602,9 @@ impl<'a, F: JoltField> R1CSEval<'a, F> { BzSecondGroup { ram_addr_minus_rs1_plus_imm, + ram_addr: self.row.ram_addr, right_lookup_minus_add_result, right_lookup_minus_sub_result, - right_lookup_minus_product, right_lookup_minus_right_input, rd_write_minus_lookup_output, rd_write_minus_pc_plus_const, @@ -632,14 +614,13 @@ impl<'a, F: JoltField> R1CSEval<'a, F> { } #[inline] - pub fn az_at_r_second_group(&self, _w: &[F; OUTER_UNIVARIATE_SKIP_DOMAIN_SIZE]) -> F { - let w = _w; + pub fn az_at_r_second_group(&self, w: &[F; OUTER_UNIVARIATE_SKIP_DOMAIN_SIZE]) -> F { let az = self.eval_az_second_group(); let mut acc: SmallAccumU = SmallAccumU::zero(); acc.fmadd(&w[0], &az.load_or_store); - acc.fmadd(&w[1], &az.add); - acc.fmadd(&w[2], &az.sub); - acc.fmadd(&w[3], &az.mul); + acc.fmadd(&w[1], &az.not_load_store); + acc.fmadd(&w[2], &az.add); + acc.fmadd(&w[3], &az.sub); acc.fmadd(&w[4], &az.not_add_sub_mul_advice); acc.fmadd(&w[5], &az.write_lookup_to_rd); acc.fmadd(&w[6], &az.write_pc_to_rd); @@ -649,14 +630,13 @@ impl<'a, F: JoltField> R1CSEval<'a, F> { } #[inline] - pub fn bz_at_r_second_group(&self, _w: &[F; OUTER_UNIVARIATE_SKIP_DOMAIN_SIZE]) -> F { - let w = _w; + pub fn bz_at_r_second_group(&self, w: &[F; OUTER_UNIVARIATE_SKIP_DOMAIN_SIZE]) -> F { let bz = self.eval_bz_second_group(); let mut acc = WideAccumS::::zero(); acc.fmadd(&w[0], &bz.ram_addr_minus_rs1_plus_imm); - acc.fmadd(&w[1], &bz.right_lookup_minus_add_result); - acc.fmadd(&w[2], &bz.right_lookup_minus_sub_result); - acc.fmadd(&w[3], &bz.right_lookup_minus_product); + acc.fmadd(&w[1], &bz.ram_addr); + acc.fmadd(&w[2], &bz.right_lookup_minus_add_result); + acc.fmadd(&w[3], &bz.right_lookup_minus_sub_result); acc.fmadd(&w[4], &bz.right_lookup_minus_right_input); acc.fmadd(&w[5], &bz.rd_write_minus_lookup_output); acc.fmadd(&w[6], &bz.rd_write_minus_pc_plus_const); @@ -701,24 +681,24 @@ impl<'a, F: JoltField> R1CSEval<'a, F> { } let c1 = coeffs_i32[1]; - if az.add { + if az.not_load_store { az_eval_i32 += c1; } else { - bz_eval_s192.fmadd(&c1, &bz.right_lookup_minus_add_result); + bz_eval_s192.fmadd(&c1, &bz.ram_addr); } let c2 = coeffs_i32[2]; - if az.sub { + if az.add { az_eval_i32 += c2; } else { - bz_eval_s192.fmadd(&c2, &bz.right_lookup_minus_sub_result); + bz_eval_s192.fmadd(&c2, &bz.right_lookup_minus_add_result); } let c3 = coeffs_i32[3]; - if az.mul { + if az.sub { az_eval_i32 += c3; } else { - bz_eval_s192.fmadd(&c3, &bz.right_lookup_minus_product); + bz_eval_s192.fmadd(&c3, &bz.right_lookup_minus_sub_result); } let c4 = coeffs_i32[4]; @@ -783,9 +763,9 @@ impl<'a, F: JoltField> R1CSEval<'a, F> { az.load_or_store, bz.ram_addr_minus_rs1_plus_imm == 0i128, ); - self.assert_constraint_second_group(1, az.add, bz.right_lookup_minus_add_result.is_zero()); - self.assert_constraint_second_group(2, az.sub, bz.right_lookup_minus_sub_result.is_zero()); - self.assert_constraint_second_group(3, az.mul, bz.right_lookup_minus_product.is_zero()); + self.assert_constraint_second_group(1, az.not_load_store, bz.ram_addr == 0); + self.assert_constraint_second_group(2, az.add, bz.right_lookup_minus_add_result.is_zero()); + self.assert_constraint_second_group(3, az.sub, bz.right_lookup_minus_sub_result.is_zero()); self.assert_constraint_second_group( 4, az.not_add_sub_mul_advice, @@ -837,7 +817,6 @@ impl<'a, F: JoltField> R1CSEval<'a, F> { // If S128 => 7 limbs signed let mut acc_left_input: MedAccumU = MedAccumU::zero(); let mut acc_right_input: MedAccumS = MedAccumS::zero(); - let mut acc_product = WideAccumS::::zero(); let mut acc_wl_left: SmallAccumU = SmallAccumU::zero(); let mut acc_wp_left: SmallAccumU = SmallAccumU::zero(); let mut acc_sb_right: SmallAccumU = SmallAccumU::zero(); @@ -870,7 +849,6 @@ impl<'a, F: JoltField> R1CSEval<'a, F> { acc_left_input.fmadd(&e_in, &row.left_input); acc_right_input.fmadd(&e_in, &row.right_input.to_i128()); - acc_product.fmadd(&e_in, &row.product); acc_wl_left.fmadd(&e_in, &(row.write_lookup_output_to_rd_addr as u64)); acc_wp_left.fmadd(&e_in, &(row.write_pc_to_rd_addr as u64)); @@ -904,8 +882,6 @@ impl<'a, F: JoltField> R1CSEval<'a, F> { eq1_val.mul_to_product_accum(acc_left_input.barrett_reduce()); out_unr[JoltR1CSInputs::RightInstructionInput.to_index()] = eq1_val.mul_to_product_accum(acc_right_input.barrett_reduce()); - out_unr[JoltR1CSInputs::Product.to_index()] = - eq1_val.mul_to_product_accum(acc_product.barrett_reduce()); out_unr[JoltR1CSInputs::WriteLookupOutputToRD.to_index()] = eq1_val.mul_to_product_accum(acc_wl_left.barrett_reduce()); out_unr[JoltR1CSInputs::WritePCtoRD.to_index()] = @@ -972,71 +948,61 @@ pub struct ProductVirtualEval; impl ProductVirtualEval { /// Compute both fused left and right factors at r0 weights for a single cycle row. - /// Expected order of weights: [Instruction, WriteLookupOutputToRD, WritePCtoRD, ShouldBranch, ShouldJump] + /// Expected order of weights: [WriteLookupOutputToRD, WritePCtoRD, ShouldBranch, ShouldJump] #[inline] pub fn fused_left_right_at_r( row: &ProductCycleInputs, weights_at_r0: &[F], ) -> (F, F) { - // Left: u64/u8/bool let mut left_acc: MedAccumU = MedAccumU::zero(); - left_acc.fmadd(&weights_at_r0[0], &row.instruction_left_input); + left_acc.fmadd(&weights_at_r0[0], &row.is_rd_not_zero); left_acc.fmadd(&weights_at_r0[1], &row.is_rd_not_zero); - left_acc.fmadd(&weights_at_r0[2], &row.is_rd_not_zero); - left_acc.fmadd(&weights_at_r0[3], &row.should_branch_lookup_output); - left_acc.fmadd(&weights_at_r0[4], &row.jump_flag); + left_acc.fmadd(&weights_at_r0[2], &row.should_branch_lookup_output); + left_acc.fmadd(&weights_at_r0[3], &row.jump_flag); - // Right: i128/bool let mut right_acc: MedAccumS = MedAccumS::zero(); - right_acc.fmadd(&weights_at_r0[0], &row.instruction_right_input); - right_acc.fmadd(&weights_at_r0[1], &row.write_lookup_output_to_rd_flag); - right_acc.fmadd(&weights_at_r0[2], &row.jump_flag); - right_acc.fmadd(&weights_at_r0[3], &row.should_branch_flag); - right_acc.fmadd(&weights_at_r0[4], &row.not_next_noop); + right_acc.fmadd(&weights_at_r0[0], &row.write_lookup_output_to_rd_flag); + right_acc.fmadd(&weights_at_r0[1], &row.jump_flag); + right_acc.fmadd(&weights_at_r0[2], &row.should_branch_flag); + right_acc.fmadd(&weights_at_r0[3], &row.not_next_noop); (left_acc.barrett_reduce(), right_acc.barrett_reduce()) } /// Compute the fused left·right product at the j-th extended uniskip target for product virtualization. - /// Uses precomputed integer Lagrange coefficients over the size-5 base window and returns an S256 product. + /// Uses precomputed integer Lagrange coefficients over the size-4 base window and returns an S256 product. #[inline] pub fn extended_fused_product_at_j(row: &ProductCycleInputs, j: usize) -> S256 { let c: &[i32; PRODUCT_VIRTUAL_UNIVARIATE_SKIP_DOMAIN_SIZE] = &PRODUCT_VIRTUAL_COEFFS_PER_J[j]; - // Weighted components lifted to i128 let mut left_w: [i128; NUM_PRODUCT_VIRTUAL] = [0; NUM_PRODUCT_VIRTUAL]; let mut right_w: [i128; NUM_PRODUCT_VIRTUAL] = [0; NUM_PRODUCT_VIRTUAL]; - // 0: Instruction (LeftInstructionInput × RightInstructionInput) - left_w[0] = (c[0] as i128) * (row.instruction_left_input as i128); - right_w[0] = (c[0] as i128) * row.instruction_right_input; - - // 1: WriteLookupOutputToRD (IsRdNotZero × WriteLookupOutputToRD_flag) - left_w[1] = if row.is_rd_not_zero { c[1] as i128 } else { 0 }; - right_w[1] = if row.write_lookup_output_to_rd_flag { - c[1] as i128 + // 0: WriteLookupOutputToRD (IsRdNotZero × WriteLookupOutputToRD_flag) + left_w[0] = if row.is_rd_not_zero { c[0] as i128 } else { 0 }; + right_w[0] = if row.write_lookup_output_to_rd_flag { + c[0] as i128 } else { 0 }; - // 2: WritePCtoRD (IsRdNotZero × Jump_flag) - left_w[2] = if row.is_rd_not_zero { c[2] as i128 } else { 0 }; - right_w[2] = if row.jump_flag { c[2] as i128 } else { 0 }; + // 1: WritePCtoRD (IsRdNotZero × Jump_flag) + left_w[1] = if row.is_rd_not_zero { c[1] as i128 } else { 0 }; + right_w[1] = if row.jump_flag { c[1] as i128 } else { 0 }; - // 3: ShouldBranch (LookupOutput × Branch_flag) - left_w[3] = (c[3] as i128) * (row.should_branch_lookup_output as i128); - right_w[3] = if row.should_branch_flag { - c[3] as i128 + // 2: ShouldBranch (LookupOutput × Branch_flag) + left_w[2] = (c[2] as i128) * (row.should_branch_lookup_output as i128); + right_w[2] = if row.should_branch_flag { + c[2] as i128 } else { 0 }; - // 4: ShouldJump (Jump_flag × (1 − NextIsNoop)) - left_w[4] = if row.jump_flag { c[4] as i128 } else { 0 }; - right_w[4] = if row.not_next_noop { c[4] as i128 } else { 0 }; + // 3: ShouldJump (Jump_flag × (1 − NextIsNoop)) + left_w[3] = if row.jump_flag { c[3] as i128 } else { 0 }; + right_w[3] = if row.not_next_noop { c[3] as i128 } else { 0 }; - // Fuse in i128, then multiply as S128×S128 → S256 let mut left_sum: i128 = 0; let mut right_sum: i128 = 0; let mut i = 0; @@ -1050,22 +1016,20 @@ impl ProductVirtualEval { left_s128.mul_trunc::<2, 4>(&right_s128) } - /// Compute z(r_cycle) for the 9 de-duplicated factor polynomials used by Product Virtualization. + /// Compute z(r_cycle) for the 7 de-duplicated factor polynomials used by Product Virtualization. /// Order of outputs matches PRODUCT_UNIQUE_FACTOR_VIRTUALS: - /// 0: LeftInstructionInput (u64) - /// 1: RightInstructionInput (i128) - /// 2: IsRdNotZero (bool) - /// 3: OpFlags(WriteLookupOutputToRD) (bool) - /// 4: OpFlags(Jump) (bool) - /// 5: LookupOutput (u64) - /// 6: InstructionFlags(Branch) (bool) - /// 7: NextIsNoop (bool) - /// 8: OpFlags(VirtualInstruction) (bool) — not a product factor, opened for downstream stages + /// 0: IsRdNotZero (bool) + /// 1: OpFlags(WriteLookupOutputToRD) (bool) + /// 2: OpFlags(Jump) (bool) + /// 3: LookupOutput (u64) + /// 4: InstructionFlags(Branch) (bool) + /// 5: NextIsNoop (bool) + /// 6: OpFlags(VirtualInstruction) (bool) — not a product factor, opened for downstream stages #[tracing::instrument(skip_all, name = "ProductVirtualEval::compute_claimed_factors")] pub fn compute_claimed_factors( trace: &[tracer::instruction::Cycle], r_cycle: &OpeningPoint, - ) -> [F; 9] { + ) -> [F; 7] { let m = r_cycle.len() / 2; let (r2, r1) = r_cycle.split_at_r(m); let (eq_one, eq_two) = rayon::join(|| EqPolynomial::evals(r2), || EqPolynomial::evals(r1)); @@ -1077,9 +1041,6 @@ impl ProductVirtualEval { .map(|x1| { let eq1_val = eq_one[x1]; - // Accumulators for 9 outputs - let mut acc_left_u64: MedAccumU = MedAccumU::zero(); - let mut acc_right_i128: MedAccumS = MedAccumS::zero(); let mut acc_rd_zero_flag: SmallAccumU = SmallAccumU::zero(); let mut acc_wl_flag: SmallAccumU = SmallAccumU::zero(); let mut acc_jump_flag: SmallAccumU = SmallAccumU::zero(); @@ -1093,42 +1054,29 @@ impl ProductVirtualEval { let idx = x1 * eq_two_len + x2; let row = ProductCycleInputs::from_trace::(trace, idx); - // 0: LeftInstructionInput (u64) - acc_left_u64.fmadd(&e_in, &row.instruction_left_input); - // 1: RightInstructionInput (i128) - acc_right_i128.fmadd(&e_in, &row.instruction_right_input); - // 2: IsRdNotZero (bool) acc_rd_zero_flag.fmadd(&e_in, &(row.is_rd_not_zero)); - // 3: OpFlags(WriteLookupOutputToRD) (bool) acc_wl_flag.fmadd(&e_in, &row.write_lookup_output_to_rd_flag); - // 4: OpFlags(Jump) (bool) acc_jump_flag.fmadd(&e_in, &row.jump_flag); - // 5: LookupOutput (u64) acc_lookup_output.fmadd(&e_in, &row.should_branch_lookup_output); - // 6: InstructionFlags(Branch) (bool) acc_branch_flag.fmadd(&e_in, &row.should_branch_flag); - // 7: NextIsNoop (bool) = !not_next_noop acc_next_is_noop.fmadd(&e_in, &(!row.not_next_noop)); - // 8: OpFlags(VirtualInstruction) (bool) acc_virtual_instr_flag.fmadd(&e_in, &row.virtual_instruction_flag); } - let mut out_unr = [F::UnreducedProductAccum::zero(); 9]; - out_unr[0] = eq1_val.mul_to_product_accum(acc_left_u64.barrett_reduce()); - out_unr[1] = eq1_val.mul_to_product_accum(acc_right_i128.barrett_reduce()); - out_unr[2] = eq1_val.mul_to_product_accum(acc_rd_zero_flag.barrett_reduce()); - out_unr[3] = eq1_val.mul_to_product_accum(acc_wl_flag.barrett_reduce()); - out_unr[4] = eq1_val.mul_to_product_accum(acc_jump_flag.barrett_reduce()); - out_unr[5] = eq1_val.mul_to_product_accum(acc_lookup_output.barrett_reduce()); - out_unr[6] = eq1_val.mul_to_product_accum(acc_branch_flag.barrett_reduce()); - out_unr[7] = eq1_val.mul_to_product_accum(acc_next_is_noop.barrett_reduce()); - out_unr[8] = eq1_val.mul_to_product_accum(acc_virtual_instr_flag.barrett_reduce()); + let mut out_unr = [F::UnreducedProductAccum::zero(); 7]; + out_unr[0] = eq1_val.mul_to_product_accum(acc_rd_zero_flag.barrett_reduce()); + out_unr[1] = eq1_val.mul_to_product_accum(acc_wl_flag.barrett_reduce()); + out_unr[2] = eq1_val.mul_to_product_accum(acc_jump_flag.barrett_reduce()); + out_unr[3] = eq1_val.mul_to_product_accum(acc_lookup_output.barrett_reduce()); + out_unr[4] = eq1_val.mul_to_product_accum(acc_branch_flag.barrett_reduce()); + out_unr[5] = eq1_val.mul_to_product_accum(acc_next_is_noop.barrett_reduce()); + out_unr[6] = eq1_val.mul_to_product_accum(acc_virtual_instr_flag.barrett_reduce()); out_unr }) .reduce( - || [F::UnreducedProductAccum::zero(); 9], + || [F::UnreducedProductAccum::zero(); 7], |mut acc, item| { - for i in 0..9 { + for i in 0..7 { acc[i] += item[i]; } acc diff --git a/jolt-core/src/zkvm/r1cs/inputs.rs b/jolt-core/src/zkvm/r1cs/inputs.rs index 92dda3f1c5..ede64fa208 100644 --- a/jolt-core/src/zkvm/r1cs/inputs.rs +++ b/jolt-core/src/zkvm/r1cs/inputs.rs @@ -21,10 +21,10 @@ use crate::zkvm::instruction::{ use crate::zkvm::witness::VirtualPolynomial; use crate::field::JoltField; -use ark_ff::biginteger::{S128, S64}; +use ark_ff::biginteger::S64; use common::constants::XLEN; use std::fmt::Debug; -use tracer::instruction::Cycle; +use tracer::instruction::{Cycle, RAMAccess}; use strum::IntoEnumIterator; @@ -44,7 +44,6 @@ pub enum JoltR1CSInputs { RightInstructionInput, // (instruction input) LeftLookupOperand, // (instruction raf) RightLookupOperand, // (instruction raf) - Product, // (product virtualization) WriteLookupOutputToRD, // (product virtualization) WritePCtoRD, // (product virtualization) ShouldBranch, // (product virtualization) @@ -60,10 +59,9 @@ pub enum JoltR1CSInputs { pub const NUM_R1CS_INPUTS: usize = ALL_R1CS_INPUTS.len(); /// This const serves to define a canonical ordering over inputs (and thus indices /// for each input). This is needed for sumcheck. -pub const ALL_R1CS_INPUTS: [JoltR1CSInputs; 37] = [ +pub const ALL_R1CS_INPUTS: [JoltR1CSInputs; 36] = [ JoltR1CSInputs::LeftInstructionInput, JoltR1CSInputs::RightInstructionInput, - JoltR1CSInputs::Product, JoltR1CSInputs::WriteLookupOutputToRD, JoltR1CSInputs::WritePCtoRD, JoltR1CSInputs::ShouldBranch, @@ -117,41 +115,40 @@ impl JoltR1CSInputs { match self { JoltR1CSInputs::LeftInstructionInput => 0, JoltR1CSInputs::RightInstructionInput => 1, - JoltR1CSInputs::Product => 2, - JoltR1CSInputs::WriteLookupOutputToRD => 3, - JoltR1CSInputs::WritePCtoRD => 4, - JoltR1CSInputs::ShouldBranch => 5, - JoltR1CSInputs::PC => 6, - JoltR1CSInputs::UnexpandedPC => 7, - JoltR1CSInputs::Imm => 8, - JoltR1CSInputs::RamAddress => 9, - JoltR1CSInputs::Rs1Value => 10, - JoltR1CSInputs::Rs2Value => 11, - JoltR1CSInputs::RdWriteValue => 12, - JoltR1CSInputs::RamReadValue => 13, - JoltR1CSInputs::RamWriteValue => 14, - JoltR1CSInputs::LeftLookupOperand => 15, - JoltR1CSInputs::RightLookupOperand => 16, - JoltR1CSInputs::NextUnexpandedPC => 17, - JoltR1CSInputs::NextPC => 18, - JoltR1CSInputs::NextIsVirtual => 19, - JoltR1CSInputs::NextIsFirstInSequence => 20, - JoltR1CSInputs::LookupOutput => 21, - JoltR1CSInputs::ShouldJump => 22, - JoltR1CSInputs::OpFlags(CircuitFlags::AddOperands) => 23, - JoltR1CSInputs::OpFlags(CircuitFlags::SubtractOperands) => 24, - JoltR1CSInputs::OpFlags(CircuitFlags::MultiplyOperands) => 25, - JoltR1CSInputs::OpFlags(CircuitFlags::Load) => 26, - JoltR1CSInputs::OpFlags(CircuitFlags::Store) => 27, - JoltR1CSInputs::OpFlags(CircuitFlags::Jump) => 28, - JoltR1CSInputs::OpFlags(CircuitFlags::WriteLookupOutputToRD) => 29, - JoltR1CSInputs::OpFlags(CircuitFlags::VirtualInstruction) => 30, - JoltR1CSInputs::OpFlags(CircuitFlags::Assert) => 31, - JoltR1CSInputs::OpFlags(CircuitFlags::DoNotUpdateUnexpandedPC) => 32, - JoltR1CSInputs::OpFlags(CircuitFlags::Advice) => 33, - JoltR1CSInputs::OpFlags(CircuitFlags::IsCompressed) => 34, - JoltR1CSInputs::OpFlags(CircuitFlags::IsFirstInSequence) => 35, - JoltR1CSInputs::OpFlags(CircuitFlags::IsLastInSequence) => 36, + JoltR1CSInputs::WriteLookupOutputToRD => 2, + JoltR1CSInputs::WritePCtoRD => 3, + JoltR1CSInputs::ShouldBranch => 4, + JoltR1CSInputs::PC => 5, + JoltR1CSInputs::UnexpandedPC => 6, + JoltR1CSInputs::Imm => 7, + JoltR1CSInputs::RamAddress => 8, + JoltR1CSInputs::Rs1Value => 9, + JoltR1CSInputs::Rs2Value => 10, + JoltR1CSInputs::RdWriteValue => 11, + JoltR1CSInputs::RamReadValue => 12, + JoltR1CSInputs::RamWriteValue => 13, + JoltR1CSInputs::LeftLookupOperand => 14, + JoltR1CSInputs::RightLookupOperand => 15, + JoltR1CSInputs::NextUnexpandedPC => 16, + JoltR1CSInputs::NextPC => 17, + JoltR1CSInputs::NextIsVirtual => 18, + JoltR1CSInputs::NextIsFirstInSequence => 19, + JoltR1CSInputs::LookupOutput => 20, + JoltR1CSInputs::ShouldJump => 21, + JoltR1CSInputs::OpFlags(CircuitFlags::AddOperands) => 22, + JoltR1CSInputs::OpFlags(CircuitFlags::SubtractOperands) => 23, + JoltR1CSInputs::OpFlags(CircuitFlags::MultiplyOperands) => 24, + JoltR1CSInputs::OpFlags(CircuitFlags::Load) => 25, + JoltR1CSInputs::OpFlags(CircuitFlags::Store) => 26, + JoltR1CSInputs::OpFlags(CircuitFlags::Jump) => 27, + JoltR1CSInputs::OpFlags(CircuitFlags::WriteLookupOutputToRD) => 28, + JoltR1CSInputs::OpFlags(CircuitFlags::VirtualInstruction) => 29, + JoltR1CSInputs::OpFlags(CircuitFlags::Assert) => 30, + JoltR1CSInputs::OpFlags(CircuitFlags::DoNotUpdateUnexpandedPC) => 31, + JoltR1CSInputs::OpFlags(CircuitFlags::Advice) => 32, + JoltR1CSInputs::OpFlags(CircuitFlags::IsCompressed) => 33, + JoltR1CSInputs::OpFlags(CircuitFlags::IsFirstInSequence) => 34, + JoltR1CSInputs::OpFlags(CircuitFlags::IsLastInSequence) => 35, } } } @@ -170,7 +167,6 @@ impl From<&JoltR1CSInputs> for VirtualPolynomial { JoltR1CSInputs::RamWriteValue => VirtualPolynomial::RamWriteValue, JoltR1CSInputs::LeftLookupOperand => VirtualPolynomial::LeftLookupOperand, JoltR1CSInputs::RightLookupOperand => VirtualPolynomial::RightLookupOperand, - JoltR1CSInputs::Product => VirtualPolynomial::Product, JoltR1CSInputs::NextUnexpandedPC => VirtualPolynomial::NextUnexpandedPC, JoltR1CSInputs::NextPC => VirtualPolynomial::NextPC, JoltR1CSInputs::LookupOutput => VirtualPolynomial::LookupOutput, @@ -205,10 +201,6 @@ pub struct R1CSCycleInputs { /// Right instruction input as signed-magnitude `S64`. /// Typically `Imm` or `Rs2Value` with exact integer semantics. pub right_input: S64, - /// Signed-magnitude `S128` product consistent with the `Product` witness. - /// Computed from `left_input` × `right_input` using the same truncation semantics as the witness. - pub product: S128, - /// Left lookup operand (u64) for the instruction lookup query. /// Matches `LeftLookupOperand` virtual polynomial semantics. pub left_lookup: u64, @@ -290,17 +282,13 @@ impl R1CSCycleInputs { None }; - // Instruction inputs and product let (left_input, right_i128) = LookupQuery::::to_instruction_inputs(cycle); - let left_s64: S64 = S64::from_u64(left_input); let right_mag = right_i128.unsigned_abs(); debug_assert!( right_mag <= u64::MAX as u128, "RightInstructionInput overflow at row {t}: |{right_i128}| > 2^64-1" ); let right_input = S64::from_u64_with_sign(right_mag as u64, right_i128 >= 0); - let right_s128: S128 = S128::from_i128(right_i128); - let product: S128 = left_s64.mul_trunc::<2, 2>(&right_s128); // Lookup operands and output let (left_lookup, right_lookup) = LookupQuery::::to_lookup_operands(cycle); @@ -314,9 +302,9 @@ impl R1CSCycleInputs { // RAM let ram_addr = cycle.ram_access().address() as u64; let (ram_read_value, ram_write_value) = match cycle.ram_access() { - tracer::instruction::RAMAccess::Read(r) => (r.value, r.value), - tracer::instruction::RAMAccess::Write(w) => (w.pre_value, w.post_value), - tracer::instruction::RAMAccess::NoOp => (0u64, 0u64), + RAMAccess::Read(r) => (r.value, r.value), + RAMAccess::Write(w) => (w.pre_value, w.post_value), + RAMAccess::NoOp => (0u64, 0u64), }; // PCs @@ -374,7 +362,6 @@ impl R1CSCycleInputs { Self { left_input, right_input, - product, left_lookup, right_lookup, lookup_output, @@ -416,7 +403,6 @@ impl R1CSCycleInputs { JoltR1CSInputs::RightInstructionInput => self.right_input.to_i128(), JoltR1CSInputs::LeftLookupOperand => self.left_lookup as i128, JoltR1CSInputs::RightLookupOperand => self.right_lookup as i128, - JoltR1CSInputs::Product => self.product.to_i128().expect("product too large for i128"), JoltR1CSInputs::WriteLookupOutputToRD => self.write_lookup_output_to_rd_addr as i128, JoltR1CSInputs::WritePCtoRD => self.write_pc_to_rd_addr as i128, JoltR1CSInputs::ShouldBranch => self.should_branch as i128, @@ -434,18 +420,14 @@ impl R1CSCycleInputs { /// Canonical, de-duplicated list of product-virtual factor polynomials used by /// the Product Virtualization stage. /// Order: -/// 0: LeftInstructionInput -/// 1: RightInstructionInput -/// 2: InstructionFlags(IsRdNotZero) -/// 3: OpFlags(WriteLookupOutputToRD) -/// 4: OpFlags(Jump) -/// 5: LookupOutput -/// 6: InstructionFlags(Branch) -/// 7: NextIsNoop -/// 8: OpFlags(VirtualInstruction) — not a product factor, but opened here for downstream stages -pub const PRODUCT_UNIQUE_FACTOR_VIRTUALS: [VirtualPolynomial; 9] = [ - VirtualPolynomial::LeftInstructionInput, - VirtualPolynomial::RightInstructionInput, +/// 0: InstructionFlags(IsRdNotZero) +/// 1: OpFlags(WriteLookupOutputToRD) +/// 2: OpFlags(Jump) +/// 3: LookupOutput +/// 4: InstructionFlags(Branch) +/// 5: NextIsNoop +/// 6: OpFlags(VirtualInstruction) — not a product factor, but opened here for downstream stages +pub const PRODUCT_UNIQUE_FACTOR_VIRTUALS: [VirtualPolynomial; 7] = [ VirtualPolynomial::InstructionFlags(InstructionFlags::IsRdNotZero), VirtualPolynomial::OpFlags(CircuitFlags::WriteLookupOutputToRD), VirtualPolynomial::OpFlags(CircuitFlags::Jump), @@ -455,21 +437,13 @@ pub const PRODUCT_UNIQUE_FACTOR_VIRTUALS: [VirtualPolynomial; 9] = [ VirtualPolynomial::OpFlags(CircuitFlags::VirtualInstruction), ]; -/// Minimal, unified view for the Product-virtualization round: the 6 product pairs -/// (left, right) materialized from the trace for a single cycle. -/// Total size is small; we keep primitive representations that match witness generation. +/// Minimal, unified view for the Product-virtualization round: the product pairs +/// materialized from the trace for a single cycle. #[derive(Clone, Debug)] pub struct ProductCycleInputs { - // 16-byte aligned - /// Instruction: LeftInstructionInput × RightInstructionInput (right input as i128) - pub instruction_right_input: i128, - - // 8-byte aligned - pub instruction_left_input: u64, /// ShouldBranch: LookupOutput × Branch_flag (left side) pub should_branch_lookup_output: u64, - // 1-byte fields /// WriteLookupOutputToRD right flag (boolean) pub write_lookup_output_to_rd_flag: bool, /// Jump flag used by both WritePCtoRD (right) and ShouldJump (left) @@ -497,37 +471,24 @@ impl ProductCycleInputs { let flags_view = instr.circuit_flags(); let instruction_flags = instr.instruction_flags(); - // Instruction inputs - let (left_input, right_input) = LookupQuery::::to_instruction_inputs(cycle); - - // Lookup output let lookup_output = LookupQuery::::to_lookup_output(cycle); - // Jump and Branch flags let jump_flag = flags_view[CircuitFlags::Jump]; let branch_flag = instruction_flags[InstructionFlags::Branch]; - // Next-is-noop and its complement (1 - NextIsNoop) let not_next_noop = { if t + 1 < len { !trace[t + 1].instruction().instruction_flags()[InstructionFlags::IsNoop] } else { - // Needs final not_next_noop to be false for the shift sumcheck - // (since EqPlusOne does not do overflow) false } }; let is_rd_not_zero = instruction_flags[InstructionFlags::IsRdNotZero]; - - // WriteLookupOutputToRD flag let write_lookup_output_to_rd_flag = flags_view[CircuitFlags::WriteLookupOutputToRD]; - let virtual_instruction_flag = flags_view[CircuitFlags::VirtualInstruction]; Self { - instruction_left_input: left_input, - instruction_right_input: right_input, write_lookup_output_to_rd_flag, should_branch_lookup_output: lookup_output, should_branch_flag: branch_flag, @@ -601,7 +562,6 @@ mod tests { } (JoltR1CSInputs::LeftLookupOperand, JoltR1CSInputs::LeftLookupOperand) => true, (JoltR1CSInputs::RightLookupOperand, JoltR1CSInputs::RightLookupOperand) => true, - (JoltR1CSInputs::Product, JoltR1CSInputs::Product) => true, (JoltR1CSInputs::WriteLookupOutputToRD, JoltR1CSInputs::WriteLookupOutputToRD) => { true } diff --git a/jolt-core/src/zkvm/ram/ra_virtual.rs b/jolt-core/src/zkvm/ram/ra_virtual.rs index 7df6975ae9..12e30b1c1e 100644 --- a/jolt-core/src/zkvm/ram/ra_virtual.rs +++ b/jolt-core/src/zkvm/ram/ra_virtual.rs @@ -5,10 +5,10 @@ //! //! ## Input //! -//! From RA reduction sumcheck (Stage 5), we receive a single claim: +//! From RA reduction sumcheck (Stage 4), we receive a single claim: //! //! ```text -//! ra(r_address_stage2, r_cycle_stage5) = ra_claim_stage5 +//! ra(r_address_stage2, r_cycle_stage4) = ra_claim_stage4 //! ``` //! //! ## Identity @@ -16,7 +16,7 @@ //! We prove the following sumcheck identity over `c ∈ {0,1}^{log_T}`: //! //! ```text -//! Σ_c eq(r_cycle_stage5, c) · Π_{i=0}^{d-1} ra_i(r_address_stage2_i, c) = ra_claim_stage5 +//! Σ_c eq(r_cycle_stage4, c) · Π_{i=0}^{d-1} ra_i(r_address_stage2_i, c) = ra_claim_stage4 //! ``` //! //! where: diff --git a/jolt-core/src/zkvm/spartan/instruction_input.rs b/jolt-core/src/zkvm/spartan/instruction_input.rs index ccababf7ee..08793fd670 100644 --- a/jolt-core/src/zkvm/spartan/instruction_input.rs +++ b/jolt-core/src/zkvm/spartan/instruction_input.rs @@ -42,7 +42,7 @@ const DEGREE_BOUND: usize = 3; #[derive(Allocative, Clone)] pub struct InstructionInputParams { - pub r_cycle_stage_2: OpeningPoint, + pub r_spartan: OpeningPoint, pub gamma: F, } @@ -51,15 +51,12 @@ impl InstructionInputParams { opening_accumulator: &dyn OpeningAccumulator, transcript: &mut impl Transcript, ) -> Self { - let (r_cycle_stage_2, _) = opening_accumulator.get_virtual_polynomial_opening( + let (r_spartan, _) = opening_accumulator.get_virtual_polynomial_opening( VirtualPolynomial::LeftInstructionInput, - SumcheckId::SpartanProductVirtualization, + SumcheckId::SpartanOuter, ); let gamma = transcript.challenge_scalar(); - Self { - r_cycle_stage_2, - gamma, - } + Self { r_spartan, gamma } } } @@ -69,40 +66,20 @@ impl SumcheckInstanceParams for InstructionInputParams { } fn num_rounds(&self) -> usize { - self.r_cycle_stage_2.len() + self.r_spartan.len() } fn input_claim(&self, accumulator: &dyn OpeningAccumulator) -> F { - let (r_left_claim_instruction, left_claim_instruction) = accumulator - .get_virtual_polynomial_opening( - VirtualPolynomial::LeftInstructionInput, - SumcheckId::InstructionClaimReduction, - ); - let (r_right_claim_instruction, right_claim_instruction) = accumulator - .get_virtual_polynomial_opening( - VirtualPolynomial::RightInstructionInput, - SumcheckId::InstructionClaimReduction, - ); - - let (r_left_claim_stage_2, left_claim_stage_2) = accumulator - .get_virtual_polynomial_opening( - VirtualPolynomial::LeftInstructionInput, - SumcheckId::SpartanProductVirtualization, - ); - let (r_right_claim_stage_2, right_claim_stage_2) = accumulator - .get_virtual_polynomial_opening( - VirtualPolynomial::RightInstructionInput, - SumcheckId::SpartanProductVirtualization, - ); - - // Soundness: InstructionClaimReduction and SpartanProductVirtualization must produce - // the same claims at the same opening points. - assert_eq!(r_left_claim_instruction, r_left_claim_stage_2); - assert_eq!(left_claim_instruction, left_claim_stage_2); - assert_eq!(r_right_claim_instruction, r_right_claim_stage_2); - assert_eq!(right_claim_instruction, right_claim_stage_2); + let (_, left_claim) = accumulator.get_virtual_polynomial_opening( + VirtualPolynomial::LeftInstructionInput, + SumcheckId::SpartanOuter, + ); + let (_, right_claim) = accumulator.get_virtual_polynomial_opening( + VirtualPolynomial::RightInstructionInput, + SumcheckId::SpartanOuter, + ); - right_claim_stage_2 + self.gamma * left_claim_stage_2 + right_claim + self.gamma * left_claim } fn normalize_opening_point( @@ -117,11 +94,11 @@ impl SumcheckInstanceParams for InstructionInputParams { InputClaimConstraint::weighted_openings(&[ OpeningId::virt( VirtualPolynomial::RightInstructionInput, - SumcheckId::SpartanProductVirtualization, + SumcheckId::SpartanOuter, ), OpeningId::virt( VirtualPolynomial::LeftInstructionInput, - SumcheckId::SpartanProductVirtualization, + SumcheckId::SpartanOuter, ), ]) } @@ -211,7 +188,7 @@ impl SumcheckInstanceParams for InstructionInputParams { #[cfg(feature = "zk")] fn output_constraint_challenge_values(&self, sumcheck_challenges: &[F::Challenge]) -> Vec { let r = self.normalize_opening_point(sumcheck_challenges); - let e2 = EqPolynomial::mle_endian(&r, &self.r_cycle_stage_2); + let e2 = EqPolynomial::mle_endian(&r, &self.r_spartan); vec![e2, self.gamma * e2] } } @@ -227,7 +204,7 @@ pub struct InstructionInputSumcheckProver { rs2_value_poly: MultilinearPolynomial, imm_poly: MultilinearPolynomial, unexpanded_pc_poly: MultilinearPolynomial, - eq_r_cycle_stage_2: GruenSplitEqPolynomial, + eq_r_spartan: GruenSplitEqPolynomial, pub params: InstructionInputParams, } @@ -285,8 +262,8 @@ impl InstructionInputSumcheckProver { }, ); - let eq_r_cycle_stage_2 = - GruenSplitEqPolynomial::new(¶ms.r_cycle_stage_2.r, BindingOrder::LowToHigh); + let eq_r_spartan = + GruenSplitEqPolynomial::new(¶ms.r_spartan.r, BindingOrder::LowToHigh); Self { left_is_rs1_poly: left_is_rs1_poly.into(), @@ -297,7 +274,7 @@ impl InstructionInputSumcheckProver { rs2_value_poly: rs2_value_poly.into(), imm_poly: imm_poly.into(), unexpanded_pc_poly: unexpanded_pc_poly.into(), - eq_r_cycle_stage_2, + eq_r_spartan, params, } } @@ -313,7 +290,7 @@ impl SumcheckInstanceProver #[tracing::instrument(skip_all, name = "InstructionInputSumcheckProver::compute_message")] fn compute_message(&mut self, _round: usize, previous_claim: F) -> UniPoly { let [eval_at_0, eval_at_inf] = self - .eq_r_cycle_stage_2 + .eq_r_spartan .par_fold_out_in( || [F::UnreducedProductAccum::zero(); 2], |inner, j, _x_in, e_in| { @@ -385,7 +362,7 @@ impl SumcheckInstanceProver ) .map(|x| F::reduce_product_accum(x)); - self.eq_r_cycle_stage_2 + self.eq_r_spartan .gruen_poly_deg_3(eval_at_0, eval_at_inf, previous_claim) } @@ -400,7 +377,7 @@ impl SumcheckInstanceProver rs2_value_poly, imm_poly, unexpanded_pc_poly, - eq_r_cycle_stage_2, + eq_r_spartan, params: _, } = self; left_is_rs1_poly.bind_parallel(r_j, BindingOrder::LowToHigh); @@ -411,7 +388,7 @@ impl SumcheckInstanceProver rs2_value_poly.bind_parallel(r_j, BindingOrder::LowToHigh); imm_poly.bind_parallel(r_j, BindingOrder::LowToHigh); unexpanded_pc_poly.bind_parallel(r_j, BindingOrder::LowToHigh); - eq_r_cycle_stage_2.bind(r_j); + eq_r_spartan.bind(r_j); } fn cache_openings( @@ -490,7 +467,7 @@ impl SumcheckInstanceProver /// ``` /// /// Note: -/// - `r_cycle_stage_2` is the randomness from instruction product sumcheck (stage 2). +/// - `r_spartan` is the randomness from Spartan outer sumcheck (stage 1). pub struct InstructionInputSumcheckVerifier { params: InstructionInputParams, } @@ -536,7 +513,7 @@ impl SumcheckInstanceVerifier ) -> F { let r = self.params.normalize_opening_point(sumcheck_challenges); - let eq_eval_at_r_cycle_stage_2 = EqPolynomial::mle_endian(&r, &self.params.r_cycle_stage_2); + let eq_eval_at_r_spartan = EqPolynomial::mle_endian(&r, &self.params.r_spartan); let (_, rs1_value_eval) = accumulator.get_virtual_polynomial_opening( VirtualPolynomial::Rs1Value, @@ -576,7 +553,7 @@ impl SumcheckInstanceVerifier let right_instruction_input = right_is_rs2_eval * rs2_value_eval + right_is_imm_eval * imm_eval; - let result = eq_eval_at_r_cycle_stage_2 + let result = eq_eval_at_r_spartan * (right_instruction_input + self.params.gamma * left_instruction_input); #[cfg(test)] @@ -664,24 +641,24 @@ impl SumcheckFrontend for InstructionInputSumcheckVerifier { let left_instruction_input_eval = left_is_rs1 * rs1_value + left_is_pc * unexpanded_pc; let right_instruction_input_eval = right_is_rs2 * rs2_value + right_is_imm * imm; - let eq_r_stage2 = VerifierEvaluablePolynomial::Eq(CachedPointRef { + let eq_r_spartan = VerifierEvaluablePolynomial::Eq(CachedPointRef { opening: PolynomialId::Virtual(VirtualPolynomial::LeftInstructionInput), - sumcheck: SumcheckId::SpartanProductVirtualization, + sumcheck: SumcheckId::SpartanOuter, part: ChallengePart::Cycle, }); InputOutputClaims { claims: vec![ Claim { - input_sumcheck_id: SumcheckId::SpartanProductVirtualization, + input_sumcheck_id: SumcheckId::SpartanOuter, input_claim_expr: right_instruction_input, - batching_poly: eq_r_stage2, + batching_poly: eq_r_spartan, expected_output_claim_expr: right_instruction_input_eval, }, Claim { - input_sumcheck_id: SumcheckId::SpartanProductVirtualization, + input_sumcheck_id: SumcheckId::SpartanOuter, input_claim_expr: left_instruction_input, - batching_poly: eq_r_stage2, + batching_poly: eq_r_spartan, expected_output_claim_expr: left_instruction_input_eval, }, ], diff --git a/jolt-core/src/zkvm/spartan/outer.rs b/jolt-core/src/zkvm/spartan/outer.rs index 8c774279ea..f8e7ca90c5 100644 --- a/jolt-core/src/zkvm/spartan/outer.rs +++ b/jolt-core/src/zkvm/spartan/outer.rs @@ -1,7 +1,10 @@ +#[cfg(feature = "zk")] +use std::collections::BTreeSet; use std::marker::PhantomData; use std::sync::Arc; use allocative::Allocative; +use ark_ff::biginteger::{S128, S160, S192}; use ark_std::Zero; use rayon::prelude::*; use tracer::instruction::Cycle; @@ -11,7 +14,7 @@ use crate::field::{FMAdd, JoltField, MontgomeryReduce}; use crate::poly::dense_mlpoly::DensePolynomial; use crate::poly::eq_poly::EqPolynomial; use crate::poly::lagrange_poly::LagrangePolynomial; -use crate::poly::multilinear_polynomial::{BindingOrder, PolynomialBinding}; +use crate::poly::multilinear_polynomial::{BindingOrder, MultilinearPolynomial, PolynomialBinding}; use crate::poly::multiquadratic_poly::MultiquadraticPolynomial; #[cfg(feature = "zk")] use crate::poly::opening_proof::OpeningId; @@ -34,11 +37,12 @@ use crate::subprotocols::univariate_skip::build_uniskip_first_round_poly; use crate::transcripts::Transcript; use crate::utils::accumulation::{FullAccumS, MedAccumS, SmallAccumU, WideAccumS}; use crate::utils::expanding_table::ExpandingTable; -use crate::utils::math::Math; +use crate::utils::math::{s64_from_diff_u64s, Math}; #[cfg(feature = "allocative")] use crate::utils::profiling::print_data_structure_heap_usage; use crate::utils::thread::unsafe_allocate_zero_vec; use crate::zkvm::bytecode::BytecodePreprocessing; +use crate::zkvm::instruction::CircuitFlags; use crate::zkvm::r1cs::constraints::OUTER_FIRST_ROUND_POLY_DEGREE_BOUND; #[cfg(feature = "zk")] use crate::zkvm::r1cs::constraints::{R1CS_CONSTRAINTS_FIRST_GROUP, R1CS_CONSTRAINTS_SECOND_GROUP}; @@ -51,44 +55,52 @@ use crate::zkvm::r1cs::{ OUTER_UNIVARIATE_SKIP_DOMAIN_SIZE, OUTER_UNIVARIATE_SKIP_EXTENDED_DOMAIN_SIZE, }, evaluation::R1CSEval, - inputs::{R1CSCycleInputs, ALL_R1CS_INPUTS}, + inputs::{JoltR1CSInputs, R1CSCycleInputs, ALL_R1CS_INPUTS}, }; use crate::zkvm::witness::VirtualPolynomial; /// Degree bound of the sumcheck round polynomials for [`OuterRemainingSumcheckVerifier`]. -const OUTER_REMAINING_DEGREE_BOUND: usize = 3; +/// Degree 4 because the product constraint `is_mul*(rl - li*ri)` is degree 3 inside eq. +const OUTER_REMAINING_DEGREE_BOUND: usize = 4; // this represents the index position in multi-quadratic poly array // This should actually be d where degree is the degree of the streaming data structure // For example : MultiQuadratic has d=2; for cubic this would be 3 etc. const INFINITY: usize = 2; // Spartan Outer sumcheck -// (with univariate-skip first round on Z, and no Cz term given all eq conditional constraints) // -// We define a univariate in Z first-round polynomial -// s1(Y) := L(τ_high, Y) · Σ_{x_out ∈ {0,1}^{m_out}} Σ_{x_in ∈ {0,1}^{m_in}} -// E_out(r_out, x_out) · E_in(r_in, x_in) · -// [ Az(x_out, x_in, Y) · Bz(x_out, x_in, Y) ], -// where L(τ_high, Y) is the Lagrange basis polynomial over the univariate-skip -// base domain evaluated at τ_high, and Az(·,·,Y), Bz(·,·,Y) are the -// per-row univariate polynomials in Y induced by the R1CS row (split into two -// internal groups in code, but algebraically composing to Az·Bz at Y). -// The prover sends s1(Y) via univariate-skip by evaluating t1(Y) := Σ Σ E_out·E_in·(Az·Bz) -// on an extended grid Y ∈ {−D..D} outside the base window, interpolating t1, -// multiplying by L(τ_high, Y) to obtain s1, and the verifier samples r0. +// Proves the combined R1CS + product constraint: +// +// Σ_x eq(τ, x) · [ Az(x)·Bz(x) + is_mul(x)·(rl(x) − li(x)·ri(x)) ] = 0 +// +// where Az·Bz encodes the R1CS constraints (no Cz term — all constraints are +// eq-conditional) and the product term enforces the multiply-instruction +// identity rl = li·ri on rows where is_mul = 1. +// +// The sumcheck proceeds in three phases: +// +// 1. Univariate-skip first round (on constraint index Z): +// Only Az·Bz participates. The product term is independent of Z. +// s1(Z) = L(τ_high, Z) · Σ_x eq(τ_low, x) · Σ_y eq(r_group, y) · Az(x,y,Z)·Bz(x,y,Z) +// The prover evaluates t1(Z) = s1(Z)/L(τ_high,Z) on the extended grid, +// sends s1 via univariate-skip, and the verifier samples r0. +// The scaling factor L(τ_high, r0) is baked into the split-eq polynomial +// via `new_with_scaling`, so it multiplies all subsequent terms uniformly. +// +// 2. Group-variable round (first cycle-variable round, binding the constraint +// group selector): +// Only Az·Bz participates. The product term is constant in the group +// variable and skipped (degree-3 round polynomial via gruen_poly_deg_3). // -// Subsequent outer rounds bind the cycle variables r_tail = (r1, r2, …) using -// a streaming first cycle-bit round followed by linear-time rounds: -// • Streaming round (after r0): compute -// t(0) = Σ_{x_out} E_out · Σ_{x_in} E_in · (Az(0)·Bz(0)) -// t(∞) = Σ_{x_out} E_out · Σ_{x_in} E_in · ((Az(1)−Az(0))·(Bz(1)−Bz(0))) -// send a cubic built from these endpoints, and bind cached coefficients by r1. -// • Remaining rounds: reuse bound coefficients to compute the same endpoints -// in linear time for each subsequent bit and bind by r_i. +// 3. Remaining cycle-variable rounds: +// Both Az·Bz and the product term participate. Each round produces a +// degree-4 round polynomial: cubic Az·Bz via gruen_poly_deg_3 plus the +// cubic product term via gruen_poly_deg_4, combined additively. +// The product term's claim starts at 0 (for a valid witness) and is +// tracked separately in prev_claim_product. // -// Final check (verifier): with r = [r0 || r_tail] and outer binding order from -// the top, evaluate Eq_τ(τ, r) and verify -// L(τ_high, r_high) · Eq_τ(τ, r) · (Az(r) · Bz(r)). +// Final check (verifier): +// L(τ_high, r0) · eq(τ_low, r_cycle) · [ Az(r)·Bz(r) + product(r) ] #[derive(Allocative, Clone)] pub struct OuterUniSkipParams { @@ -431,7 +443,6 @@ impl SumcheckInstanceParams for OuterRemainingSumcheckParams // Build structural template by iterating R1CS constraints to find // which input indices appear in A-sides and B-sides. - use std::collections::BTreeSet; let mut a_indices = BTreeSet::new(); let mut b_indices = BTreeSet::new(); let mut a_has_const = false; @@ -513,6 +524,54 @@ impl SumcheckInstanceParams for OuterRemainingSumcheckParams // Constant term: tau * az_c * bz_c if a_has_const && b_has_const { terms.push(ProductTerm::single(ValueSource::Challenge(challenge_idx))); + challenge_idx += 1; + } + + // Product constraint: tau_kernel * is_mul * (rl - li*ri) + let is_mul_idx = JoltR1CSInputs::OpFlags(CircuitFlags::MultiplyOperands).to_index(); + let rl_idx = JoltR1CSInputs::RightLookupOperand.to_index(); + let li_idx = JoltR1CSInputs::LeftInstructionInput.to_index(); + let ri_idx = JoltR1CSInputs::RightInstructionInput.to_index(); + + let is_mul_opening = OpeningId::virt( + VirtualPolynomial::from(&ALL_R1CS_INPUTS[is_mul_idx]), + SumcheckId::SpartanOuter, + ); + let rl_opening = OpeningId::virt( + VirtualPolynomial::from(&ALL_R1CS_INPUTS[rl_idx]), + SumcheckId::SpartanOuter, + ); + let li_opening = OpeningId::virt( + VirtualPolynomial::from(&ALL_R1CS_INPUTS[li_idx]), + SumcheckId::SpartanOuter, + ); + let ri_opening = OpeningId::virt( + VirtualPolynomial::from(&ALL_R1CS_INPUTS[ri_idx]), + SumcheckId::SpartanOuter, + ); + + // tau_kernel * is_mul * rl + terms.push(ProductTerm::scaled( + ValueSource::Challenge(challenge_idx), + vec![ + ValueSource::Opening(is_mul_opening), + ValueSource::Opening(rl_opening), + ], + )); + challenge_idx += 1; + + // -tau_kernel * is_mul * li * ri + #[allow(unused_assignments)] + { + terms.push(ProductTerm::scaled( + ValueSource::Challenge(challenge_idx), + vec![ + ValueSource::Opening(is_mul_opening), + ValueSource::Opening(li_opening), + ValueSource::Opening(ri_opening), + ], + )); + challenge_idx += 1; } Some(OutputClaimConstraint::sum_of_products(terms)) @@ -520,8 +579,6 @@ impl SumcheckInstanceParams for OuterRemainingSumcheckParams #[cfg(feature = "zk")] fn output_constraint_challenge_values(&self, sumcheck_challenges: &[F::Challenge]) -> Vec { - use std::collections::BTreeSet; - let r_stream = sumcheck_challenges[0]; // Lagrange weights at r0 @@ -665,6 +722,10 @@ impl SumcheckInstanceParams for OuterRemainingSumcheckParams challenges.push(tau_kernel * az_const * bz_const); } + // Product constraint: is_mul * (rl - li*ri) + challenges.push(tau_kernel); + challenges.push(-tau_kernel); + challenges } } @@ -708,10 +769,18 @@ impl SumcheckInstanceVerifier // Randomness used to bind the rows of R1CS matrices A,B. let rx_constr = &[sumcheck_challenges[0], self.params.r0]; // Compute sum_y A(rx_constr, y)*z(y) * sum_y B(rx_constr, y)*z(y). - let inner_sum_prod = self + let azbz = self .key .evaluate_inner_sum_product_at_point(rx_constr, r1cs_input_evals); + // Product constraint: is_mul * (right_lookup - left_input * right_input) + let is_mul_eval = + r1cs_input_evals[JoltR1CSInputs::OpFlags(CircuitFlags::MultiplyOperands).to_index()]; + let rl_eval = r1cs_input_evals[JoltR1CSInputs::RightLookupOperand.to_index()]; + let li_eval = r1cs_input_evals[JoltR1CSInputs::LeftInstructionInput.to_index()]; + let ri_eval = r1cs_input_evals[JoltR1CSInputs::RightInstructionInput.to_index()]; + let product_term = is_mul_eval * (rl_eval - li_eval * ri_eval); + let tau = &self.params.tau; let tau_high = &tau[tau.len() - 1]; let tau_high_bound_r0 = LagrangePolynomial::::lagrange_kernel::< @@ -722,7 +791,7 @@ impl SumcheckInstanceVerifier let r_tail_reversed: Vec = sumcheck_challenges.iter().rev().copied().collect(); let tau_bound_r_tail_reversed = EqPolynomial::mle(tau_low, &r_tail_reversed); - tau_high_bound_r0 * tau_bound_r_tail_reversed * inner_sum_prod + tau_high_bound_r0 * tau_bound_r_tail_reversed * (azbz + product_term) } fn cache_openings( @@ -756,6 +825,7 @@ pub struct OuterSharedState { #[allocative(skip)] lagrange_evals_r0: [F; OUTER_UNIVARIATE_SKIP_DOMAIN_SIZE], pub params: OuterRemainingSumcheckParams, + prev_claim_product: F, } impl OuterSharedState { @@ -801,6 +871,7 @@ impl OuterSharedState { r_grid, params: outer_params, lagrange_evals_r0: lagrange_evals_r, + prev_claim_product: F::zero(), } } @@ -1071,15 +1142,16 @@ impl StreamingSumcheckWindow for OuterStreamingWindow { #[tracing::instrument(skip_all, name = "OuterStreamingWindow::compute_message")] fn compute_message( - &self, + &mut self, shared: &Self::Shared, window_size: usize, previous_claim: F, ) -> UniPoly { + let prev_claim_azbz = previous_claim - shared.prev_claim_product; let (t_prime_0, t_prime_inf) = shared.compute_t_evals(window_size); shared .split_eq_poly - .gruen_poly_deg_3(t_prime_0, t_prime_inf, previous_claim) + .gruen_poly_deg_3(t_prime_0, t_prime_inf, prev_claim_azbz) } #[tracing::instrument(skip_all, name = "OuterStreamingWindow::ingest_challenge")] @@ -1098,9 +1170,225 @@ impl StreamingSumcheckWindow for OuterStreamingWindow { pub struct OuterLinearStage { az: DensePolynomial, bz: DensePolynomial, + product_is_mul: Option>, + product_rl: Option>, + product_li: Option>, + product_ri: Option>, + prev_round_poly_product: Option>, } impl OuterLinearStage { + /// Fused materialization + first-round product eval computation. + /// + /// In a single pass over the trace, simultaneously: + /// 1. Materializes compact `MultilinearPolynomial`s (bool/u128/u64/i128) for + /// subsequent rounds + /// 2. Computes the first product round evals `(t2, t_inf)` using small-scalar + /// arithmetic with deferred reduction. `t0 = 0` is known since + /// `prev_claim_product = 0` on the first round (valid witness). + #[tracing::instrument( + skip_all, + name = "OuterLinearStage::fused_materialise_and_first_product_evals" + )] + fn fused_materialise_and_first_product_evals( + shared: &OuterSharedState, + ) -> ( + F, + F, + MultilinearPolynomial, + MultilinearPolynomial, + MultilinearPolynomial, + MultilinearPolynomial, + ) { + let num_cycles = shared.trace.len(); + let mut is_mul_vec: Vec = vec![false; num_cycles]; + let mut rl_vec: Vec = unsafe_allocate_zero_vec(num_cycles); + let mut li_vec: Vec = unsafe_allocate_zero_vec(num_cycles); + let mut ri_vec: Vec = unsafe_allocate_zero_vec(num_cycles); + + let e_out = shared.split_eq_poly.E_out_current(); + let e_in = shared.split_eq_poly.E_in_current(); + let num_x_in = e_in.len(); + let chunk_size = 2 * num_x_in; + + let (t2_unr, tinf_unr) = is_mul_vec + .par_chunks_mut(chunk_size) + .zip(rl_vec.par_chunks_mut(chunk_size)) + .zip(li_vec.par_chunks_mut(chunk_size)) + .zip(ri_vec.par_chunks_mut(chunk_size)) + .enumerate() + .fold( + || { + ( + F::UnreducedProductAccum::zero(), + F::UnreducedProductAccum::zero(), + ) + }, + |(mut acc2, mut acc_inf), (x_out, (((m_chunk, r_chunk), l_chunk), i_chunk))| { + let mut inner2 = WideAccumS::::zero(); + let mut inner_inf = WideAccumS::::zero(); + + for x_in in 0..num_x_in { + let g = x_out * num_x_in + x_in; + let idx_lo = 2 * g; + let idx_hi = idx_lo + 1; + + let row_lo = R1CSCycleInputs::from_trace::( + &shared.bytecode_preprocessing, + &shared.trace, + idx_lo, + ); + let row_hi = R1CSCycleInputs::from_trace::( + &shared.bytecode_preprocessing, + &shared.trace, + idx_hi, + ); + + let m0 = row_lo.flags[CircuitFlags::MultiplyOperands]; + let m1 = row_hi.flags[CircuitFlags::MultiplyOperands]; + + let off_lo = 2 * x_in; + let off_hi = off_lo + 1; + m_chunk[off_lo] = m0; + m_chunk[off_hi] = m1; + r_chunk[off_lo] = row_lo.right_lookup; + r_chunk[off_hi] = row_hi.right_lookup; + l_chunk[off_lo] = row_lo.left_input; + l_chunk[off_hi] = row_hi.left_input; + i_chunk[off_lo] = row_lo.right_input.to_i128(); + i_chunk[off_hi] = row_hi.right_input.to_i128(); + + if !m0 && !m1 { + continue; + } + + let dm: i8 = (m1 as i8) - (m0 as i8); + let dl_i128 = (row_hi.left_input as i128) - (row_lo.left_input as i128); + let di_i128 = row_hi.right_input.to_i128() - row_lo.right_input.to_i128(); + + if dm != 0 { + let dl_s64 = s64_from_diff_u64s(row_hi.left_input, row_lo.left_input); + let di_s128 = S128::from(di_i128); + let dl_di: S192 = dl_s64.mul_trunc::<2, 3>(&di_s128); + let p_inf = if dm == 1 { dl_di.neg() } else { dl_di }; + inner_inf.fmadd(&e_in[x_in], &p_inf); + } + + let m2_val: i8 = (m0 as i8) + 2 * dm; + if m2_val != 0 { + let l2 = (row_lo.left_input as i128) + 2 * dl_i128; + let i2 = row_lo.right_input.to_i128() + 2 * di_i128; + let l2_s160 = S160::from(l2); + let i2_s160 = S160::from(i2); + let li_prod = &l2_s160 * &i2_s160; + let r2 = S160::from_sum_u128(row_hi.right_lookup, row_hi.right_lookup) + - S160::from(row_lo.right_lookup); + let val = r2 - li_prod; + let p2 = match m2_val { + 1 => val, + 2 => val + val, + _ => -val, + }; + inner2.fmadd(&e_in[x_in], &p2); + } + } + + let e_out_val = e_out[x_out]; + let red2: F = inner2.barrett_reduce(); + let red_inf: F = inner_inf.barrett_reduce(); + acc2 += e_out_val.mul_to_product_accum(red2); + acc_inf += e_out_val.mul_to_product_accum(red_inf); + + (acc2, acc_inf) + }, + ) + .reduce( + || { + ( + F::UnreducedProductAccum::zero(), + F::UnreducedProductAccum::zero(), + ) + }, + |a, b| (a.0 + b.0, a.1 + b.1), + ); + + let t2 = F::reduce_product_accum(t2_unr); + let t_inf = F::reduce_product_accum(tinf_unr); + + ( + t2, + t_inf, + is_mul_vec.into(), + rl_vec.into(), + li_vec.into(), + ri_vec.into(), + ) + } + + #[tracing::instrument(skip_all, name = "OuterLinearStage::compute_product_t_evals")] + fn compute_product_t_evals(&self, shared: &OuterSharedState) -> (F, F, F) { + let is_mul = self.product_is_mul.as_ref().unwrap(); + let rl = self.product_rl.as_ref().unwrap(); + let li = self.product_li.as_ref().unwrap(); + let ri = self.product_ri.as_ref().unwrap(); + + let [t0, t2, tinf] = shared + .split_eq_poly + .par_fold_out_in( + || [F::UnreducedProductAccum::zero(); 3], + |inner, g, _x_in, e_in| { + let m0 = is_mul.get_bound_coeff(2 * g); + let m1 = is_mul.get_bound_coeff(2 * g + 1); + if m0.is_zero() && m1.is_zero() { + return; + } + + let r0 = rl.get_bound_coeff(2 * g); + let r1 = rl.get_bound_coeff(2 * g + 1); + let l0 = li.get_bound_coeff(2 * g); + let l1 = li.get_bound_coeff(2 * g + 1); + let i0 = ri.get_bound_coeff(2 * g); + let i1 = ri.get_bound_coeff(2 * g + 1); + + let p0 = m0 * (r0 - l0 * i0); + + let dm = m1 - m0; + let dr = r1 - r0; + let dl = l1 - l0; + let di = i1 - i0; + + let m2 = m0 + dm + dm; + let r2 = r0 + dr + dr; + let l2 = l0 + dl + dl; + let i2 = i0 + di + di; + let p2 = m2 * (r2 - l2 * i2); + + let p_inf = -(dm * dl * di); + + inner[0] += e_in.mul_to_product_accum(p0); + inner[1] += e_in.mul_to_product_accum(p2); + inner[2] += e_in.mul_to_product_accum(p_inf); + }, + |_x_out, e_out, inner| { + let mut outer = [F::UnreducedProductAccum::zero(); 3]; + for k in 0..3 { + let red = F::reduce_product_accum(inner[k]); + outer[k] = e_out.mul_to_product_accum(red); + } + outer + }, + |mut a, b| { + for k in 0..3 { + a[k] += b[k]; + } + a + }, + ) + .map(F::reduce_product_accum); + + (t0, t2, tinf) + } + #[tracing::instrument( skip_all, name = "OuterLinearStage::fused_materialise_polynomials_general_with_multiquadratic" @@ -1669,7 +1957,15 @@ impl LinearSumcheckStage for OuterLinearStage { Self::fused_materialise_polynomials_round_zero(shared, window_size) }; - Self { az, bz } + Self { + az, + bz, + product_is_mul: None, + product_rl: None, + product_li: None, + product_ri: None, + prev_round_poly_product: None, + } } #[tracing::instrument(skip_all, name = "OuterLinearStage::next_window")] @@ -1679,19 +1975,58 @@ impl LinearSumcheckStage for OuterLinearStage { #[tracing::instrument(skip_all, name = "OuterLinearStage::compute_message")] fn compute_message( - &self, + &mut self, shared: &Self::Shared, window_size: usize, previous_claim: F, ) -> UniPoly { + let prev_claim_product = shared.prev_claim_product; + let prev_claim_azbz = previous_claim - prev_claim_product; + let (t_prime_0, t_prime_inf) = shared.compute_t_evals(window_size); - shared - .split_eq_poly - .gruen_poly_deg_3(t_prime_0, t_prime_inf, previous_claim) + let round_poly_azbz = + shared + .split_eq_poly + .gruen_poly_deg_3(t_prime_0, t_prime_inf, prev_claim_azbz); + + let is_group_variable_round = shared.split_eq_poly.num_challenges() == 0; + if is_group_variable_round { + self.prev_round_poly_product = None; + round_poly_azbz + } else { + let (t_prod_0, t_prod_2, t_prod_inf) = if self.product_is_mul.is_none() { + let (t2, t_inf, is_mul, rl, li, ri) = + Self::fused_materialise_and_first_product_evals(shared); + self.product_is_mul = Some(is_mul); + self.product_rl = Some(rl); + self.product_li = Some(li); + self.product_ri = Some(ri); + (F::zero(), t2, t_inf) + } else { + self.compute_product_t_evals(shared) + }; + + let round_poly_product = shared.split_eq_poly.gruen_poly_deg_4( + t_prod_0, + t_prod_2, + t_prod_inf, + prev_claim_product, + ); + self.prev_round_poly_product = Some(round_poly_product.clone()); + + let mut combined = round_poly_azbz; + combined += &round_poly_product; + combined + } } #[tracing::instrument(skip_all, name = "OuterLinearStage::ingest_challenge")] fn ingest_challenge(&mut self, shared: &mut Self::Shared, r_j: F::Challenge, _round: usize) { + if let Some(ref poly) = self.prev_round_poly_product { + shared.prev_claim_product = poly.evaluate(&r_j); + } + self.prev_round_poly_product = None; + shared.split_eq_poly.bind(r_j); if let Some(t_prime_poly) = shared.t_prime_poly.as_mut() { @@ -1699,8 +2034,50 @@ impl LinearSumcheckStage for OuterLinearStage { } rayon::join( - || self.az.bind_parallel(r_j, BindingOrder::LowToHigh), - || self.bz.bind_parallel(r_j, BindingOrder::LowToHigh), + || { + rayon::join( + || self.az.bind_parallel(r_j, BindingOrder::LowToHigh), + || self.bz.bind_parallel(r_j, BindingOrder::LowToHigh), + ) + }, + || { + if self.product_is_mul.is_some() { + rayon::join( + || { + rayon::join( + || { + self.product_is_mul + .as_mut() + .unwrap() + .bind_parallel(r_j, BindingOrder::LowToHigh) + }, + || { + self.product_rl + .as_mut() + .unwrap() + .bind_parallel(r_j, BindingOrder::LowToHigh) + }, + ) + }, + || { + rayon::join( + || { + self.product_li + .as_mut() + .unwrap() + .bind_parallel(r_j, BindingOrder::LowToHigh) + }, + || { + self.product_ri + .as_mut() + .unwrap() + .bind_parallel(r_j, BindingOrder::LowToHigh) + }, + ) + }, + ); + } + }, ); } diff --git a/jolt-core/src/zkvm/spartan/product.rs b/jolt-core/src/zkvm/spartan/product.rs index 36c91e4cfa..6e4ead2508 100644 --- a/jolt-core/src/zkvm/spartan/product.rs +++ b/jolt-core/src/zkvm/spartan/product.rs @@ -50,27 +50,20 @@ use tracer::instruction::Cycle; // We define a "combined" left and right polynomial // Left(x, y) = \sum_i L(y, i) * Left_i(x), // Right(x, y) = \sum_i R(y, i) * Right_i(x), -// where Left_i(x) = one of the five left polynomials, Right_i(x) = one of the five right polynomials -// Indexing is over i \in {-2, -1, 0, 1, 2}, though this gets mapped to the 0th, 1st, ..., 4th polynomial +// where Left_i(x) = one of the four left polynomials, Right_i(x) = one of the four right polynomials +// Indexing is over i \in {-1, 0, 1, 2} (mapped to the 0th, 1st, 2nd, 3rd polynomial) // -// We also need to define the combined claim: -// claim(y) = \sum_i L(y, i) * claim_i, -// where claim_i is the claim of the i-th product virtualization sumcheck +// The four product constraints are: +// WriteLookupOutputToRD, WritePCtoRD, ShouldBranch, ShouldJump +// (The former Instruction product constraint is now handled directly in the outer sumcheck.) +// +// claim(y) = \sum_i L(y, i) * claim_i // -// The product virtualization sumcheck is then: // \sum_y L(tau_high, y) * \sum_x eq(tau_low, x) * Left(x, y) * Right(x, y) // = claim(tau_high) // // Final claim is: // L(tau_high, r0) * Eq(tau_low, r_tail^rev) * Left(r_tail, r0) * Right(r_tail, r0) -// -// After this, we also need to check the consistency of the Left and Right evaluations with the -// claimed evaluations of the factor polynomials. This is done in the ProductVirtualInner sumcheck. -// -// TODO (Quang): this is essentially Spartan with non-zero claims. We should unify this with Spartan outer/inner. -// Only complication is to generalize the splitting strategy -// (i.e. Spartan outer currently does uni skip for half of the constraints, -// whereas here we do it for all of them) /// Degree of the sumcheck round polynomials for [`ProductVirtualRemainderVerifier`]. const PRODUCT_VIRTUAL_REMAINDER_DEGREE: usize = 3; @@ -79,11 +72,11 @@ const PRODUCT_VIRTUAL_REMAINDER_DEGREE: usize = 3; pub struct ProductVirtualUniSkipParams { /// τ = [τ_low || τ_high] /// - τ_low: the cycle-point r_cycle carried from Spartan outer (length = num_cycle_vars) - /// - τ_high: the univariate-skip binding point sampled for the size-5 domain (length = 1) + /// - τ_high: the univariate-skip binding point sampled for the size-4 domain (length = 1) /// Ordering matches outer: variables are MSB→LSB with τ_high last pub tau: Vec, - /// Base evaluations (claims) for the five product terms at the base domain - /// Order: [Product, WriteLookupOutputToRD, WritePCtoRD, ShouldBranch, ShouldJump] + /// Base evaluations (claims) for the four product terms at the base domain + /// Order: [WriteLookupOutputToRD, WritePCtoRD, ShouldBranch, ShouldJump] pub base_evals: [F; NUM_PRODUCT_VIRTUAL], } @@ -92,9 +85,8 @@ impl ProductVirtualUniSkipParams { opening_accumulator: &dyn OpeningAccumulator, transcript: &mut T, ) -> Self { - // Reuse r_cycle from Stage 1 (outer) for τ_low, and sample τ_high let r_cycle = opening_accumulator - .get_virtual_polynomial_opening(VirtualPolynomial::Product, SumcheckId::SpartanOuter) + .get_virtual_polynomial_opening(PRODUCT_CONSTRAINTS[0].output, SumcheckId::SpartanOuter) .0 .r; let tau_high = transcript.challenge_scalar_optimized::(); @@ -213,20 +205,13 @@ impl ProductVirtualUniSkipProver { /// t1(z) = Σ_{x_out} E_out[x_out] · Σ_{x_in} E_in[x_in] · left_z(x) · right_z(x), /// where x is the concatenation of (x_out || x_in) in MSB→LSB order. /// - /// Lagrange fusion per target z on (current) extended window {−4,−3,3,4}: - /// - Compute c[0..4] = LagrangeHelper::shift_coeffs_i32(shift(z)) using the same shifted-kernel - /// as outer.rs (indices correspond to the 5 base points). - /// - Define fused values at this z by linearly combining the 5 product witnesses with c: + /// Lagrange fusion per target z on the extended window: + /// - Compute c[0..3] = LagrangeHelper::shift_coeffs_i32(shift(z)) using the same shifted-kernel + /// as outer.rs (indices correspond to the 4 base points). + /// - Define fused values at this z by linearly combining the 4 product witnesses with c: /// left_z(x) = Σ_i c[i] · Left_i(x) /// right_z(x) = Σ_i c[i] · Right_i^eff(x) - /// with Right_4^eff(x) = 1 − NextIsNoop(x) for the ShouldJump term only. - /// - /// Small-value lifting rules for integer accumulation before converting to the field: - /// - Instruction: LeftInstructionInput is u64 → lift to i128; RightInstructionInput is S64 → i128. - /// - WriteLookupOutputToRD: IsRdNotZero is bool/u8 → i32; flag is bool/u8 → i32. - /// - WritePCtoRD: IsRdNotZero is bool/u8 → i32; Jump flag is bool/u8 → i32. - /// - ShouldBranch: LookupOutput is u64 → i128; Branch flag is bool/u8 → i32. - /// - ShouldJump: Jump flag (left) is bool/u8 → i32; Right^eff = (1 − NextIsNoop) is bool/u8 → i32. + /// with Right_3^eff(x) = 1 − NextIsNoop(x) for the ShouldJump term only. fn compute_univariate_skip_extended_evals( trace: &[Cycle], tau: &[F::Challenge], @@ -446,11 +431,11 @@ impl SumcheckInstanceParams for ProductVirtualRemainderParams Option { + // With 4 product constraints, the unique left factors are: + // [IsRdNotZero, LookupOutput, Jump] + // and the unique right factors are: + // [WriteLookupOutputToRD, Jump, Branch, NextIsNoop] let left_openings = [ - OpeningId::virt( - VirtualPolynomial::LeftInstructionInput, - SumcheckId::SpartanProductVirtualization, - ), OpeningId::virt( VirtualPolynomial::InstructionFlags(InstructionFlags::IsRdNotZero), SumcheckId::SpartanProductVirtualization, @@ -466,10 +451,6 @@ impl SumcheckInstanceParams for ProductVirtualRemainderParams SumcheckInstanceParams for ProductVirtualRemainderParams SumcheckInstanceParams for ProductVirtualRemainderParams ProductVirtualRemainderParams { PRODUCT_VIRTUAL_UNIVARIATE_SKIP_DOMAIN_SIZE, >(&self.r0); - // Left coefficients: [w[0], w[1]+w[2], w[3], w[4]] - let alpha = [w[0], w[1] + w[2], w[3], w[4]]; + // Left coefficients grouped by unique factor: + // IsRdNotZero (constraints 0,1): w[0]+w[1] + // LookupOutput (constraint 2): w[2] + // Jump (constraint 3): w[3] + let alpha = [w[0] + w[1], w[2], w[3]]; - // Right coefficients: [w[0], w[1], w[2], w[3], -w[4]] - let beta = [w[0], w[1], w[2], w[3], -w[4]]; + // Right coefficients (one per unique right factor): + // WriteLookupOutputToRD: w[0] + // Jump: w[1] + // Branch: w[2] + // NextIsNoop (negated): -w[3] + let beta = [w[0], w[1], w[2], -w[3]]; - let mut challenges = Vec::with_capacity(24); + let mut challenges = Vec::with_capacity(15); // Product coefficients: λ*α_i*β_j for alpha_i in &alpha { @@ -557,9 +547,9 @@ impl ProductVirtualRemainderParams { } } - // Constant contribution coefficients: λ*w[4]*α_i + // Constant contribution from (1 - NextIsNoop): λ*w[3]*α_i for alpha_i in &alpha { - challenges.push(lambda * w[4] * *alpha_i); + challenges.push(lambda * w[3] * *alpha_i); } challenges @@ -576,17 +566,17 @@ impl ProductVirtualRemainderParams { /// bound by this instance (low-to-high from the prover's perspective; the verifier uses the /// reversed vector `r_tail^rev` when evaluating Eq_τ over τ_low). /// -/// Define Lagrange weights over the size-5 domain at r₀: -/// w_i := L_i(r₀) for i ∈ {0..4} corresponding to -/// [Instruction, WriteLookupOutputToRD, WritePCtoRD, ShouldBranch, ShouldJump]. +/// Define Lagrange weights over the size-4 domain at r₀: +/// w_i := L_i(r₀) for i ∈ {0..3} corresponding to +/// [WriteLookupOutputToRD, WritePCtoRD, ShouldBranch, ShouldJump]. /// /// Define fused left/right evaluations at the cycle point r_tail: /// left_eval := Σ_i w_i · eval(Left_i, r_tail) /// right_eval := Σ_i w_i · eval(Right_i, r_tail), except for ShouldJump where -/// Right_4^eff := 1 − NextIsNoop, i.e., use (1 − eval(NextIsNoop, r_tail)). +/// Right_3^eff := 1 − NextIsNoop, i.e., use (1 − eval(NextIsNoop, r_tail)). /// /// Let -/// E_high := L(τ_high, r₀) (Lagrange kernel over the size-5 domain) +/// E_high := L(τ_high, r₀) (Lagrange kernel over the size-4 domain) /// E_low := Eq_τ_low(τ_low, r_tail^rev) (multilinear Eq kernel on the cycle variables) /// /// Then the expected final claim is @@ -847,25 +837,11 @@ impl SumcheckInstanceVerifier accumulator: &VerifierOpeningAccumulator, sumcheck_challenges: &[F::Challenge], ) -> F { - // Lagrange weights at r0 let w = LagrangePolynomial::::evals::< F::Challenge, PRODUCT_VIRTUAL_UNIVARIATE_SKIP_DOMAIN_SIZE, >(&self.params.r0); - // Fetch factor claims - let l_inst = accumulator - .get_virtual_polynomial_opening( - VirtualPolynomial::LeftInstructionInput, - SumcheckId::SpartanProductVirtualization, - ) - .1; - let r_inst = accumulator - .get_virtual_polynomial_opening( - VirtualPolynomial::RightInstructionInput, - SumcheckId::SpartanProductVirtualization, - ) - .1; let is_rd_not_zero = accumulator .get_virtual_polynomial_opening( VirtualPolynomial::InstructionFlags(InstructionFlags::IsRdNotZero), @@ -903,18 +879,11 @@ impl SumcheckInstanceVerifier ) .1; - let fused_left = w[0] * l_inst - + w[1] * is_rd_not_zero - + w[2] * is_rd_not_zero - + w[3] * lookup_out - + w[4] * j_flag; - let fused_right = w[0] * r_inst - + w[1] * wl_flag - + w[2] * j_flag - + w[3] * branch_flag - + w[4] * (F::one() - next_is_noop); - - // Multiply by L(τ_high, r0) and Eq(τ_low, r_tail^rev) + let fused_left = + w[0] * is_rd_not_zero + w[1] * is_rd_not_zero + w[2] * lookup_out + w[3] * j_flag; + let fused_right = + w[0] * wl_flag + w[1] * j_flag + w[2] * branch_flag + w[3] * (F::one() - next_is_noop); + let tau_high = &self.params.tau[self.params.tau.len() - 1]; let tau_high_bound_r0 = LagrangePolynomial::::lagrange_kernel::< F::Challenge, diff --git a/jolt-core/src/zkvm/verifier.rs b/jolt-core/src/zkvm/verifier.rs index 9e63a1a254..6e032f3d23 100644 --- a/jolt-core/src/zkvm/verifier.rs +++ b/jolt-core/src/zkvm/verifier.rs @@ -4,6 +4,11 @@ use std::io::{Read, Write}; use std::path::Path; use std::sync::Arc; +use ark_serialize::{CanonicalDeserialize, CanonicalSerialize}; +use common::jolt_device::MemoryLayout; +use tracer::instruction::Instruction; +use tracer::JoltDevice; + use crate::curve::JoltCurve; use crate::poly::commitment::commitment_scheme::{CommitmentScheme, ZkEvalCommitment}; #[cfg(feature = "zk")] @@ -11,6 +16,8 @@ use crate::poly::commitment::dory::bind_opening_inputs_zk; use crate::poly::commitment::dory::{bind_opening_inputs, DoryContext, DoryGlobals}; #[cfg(feature = "zk")] use crate::poly::lagrange_poly::LagrangeHelper; +#[cfg(not(feature = "zk"))] +use crate::poly::opening_proof::{OpeningPoint, BIG_ENDIAN}; #[cfg(feature = "zk")] use crate::subprotocols::blindfold::{ pedersen_generator_count_for_r1cs, BakedPublicInputs, BlindFoldVerifier, @@ -35,7 +42,7 @@ use crate::zkvm::r1cs::constraints::{ OUTER_FIRST_ROUND_POLY_NUM_COEFFS, OUTER_UNIVARIATE_SKIP_DOMAIN_SIZE, PRODUCT_VIRTUAL_FIRST_ROUND_POLY_NUM_COEFFS, PRODUCT_VIRTUAL_UNIVARIATE_SKIP_DOMAIN_SIZE, }; -use crate::zkvm::ram::RAMPreprocessing; +use crate::zkvm::ram::{gen_ram_initial_memory_state, RAMPreprocessing}; use crate::zkvm::witness::all_committed_polynomials; use crate::zkvm::Serializable; use crate::zkvm::{ @@ -184,11 +191,6 @@ fn scale_batching_coefficients( }) .collect() } -use ark_serialize::{CanonicalDeserialize, CanonicalSerialize}; -use common::jolt_device::MemoryLayout; -use tracer::instruction::Instruction; -use tracer::JoltDevice; - pub struct JoltVerifier< 'a, F: JoltField, @@ -202,10 +204,10 @@ pub struct JoltVerifier< pub preprocessing: &'a JoltVerifierPreprocessing, pub transcript: ProofTranscript, pub opening_accumulator: VerifierOpeningAccumulator, - /// The advice claim reduction sumcheck effectively spans two stages (6 and 7). + /// The advice claim reduction sumcheck effectively spans two stages (5 and 6). /// Cache the verifier state here between stages. advice_reduction_verifier_trusted: Option>, - /// The advice claim reduction sumcheck effectively spans two stages (6 and 7). + /// The advice claim reduction sumcheck effectively spans two stages (5 and 6). /// Cache the verifier state here between stages. advice_reduction_verifier_untrusted: Option>, pub spartan_key: UniformSpartanKey, @@ -268,7 +270,6 @@ where #[cfg(not(feature = "zk"))] { - use crate::poly::opening_proof::{OpeningPoint, BIG_ENDIAN}; for (id, (_, claim)) in &proof.opening_claims.0 { let dummy_point = OpeningPoint::::new(vec![]); opening_accumulator @@ -377,12 +378,9 @@ where let stage6_result = self .verify_stage6() .inspect_err(|e| tracing::error!("Stage 6: {e}"))?; - let stage7_result = self + let stage7_data = self .verify_stage7() .inspect_err(|e| tracing::error!("Stage 7: {e}"))?; - let stage8_data = self - .verify_stage8() - .inspect_err(|e| tracing::error!("Stage 8: {e}"))?; if zk_mode { #[cfg(feature = "zk")] @@ -394,7 +392,6 @@ where stage4_result.challenges.clone(), stage5_result.challenges.clone(), stage6_result.challenges.clone(), - stage7_result.challenges.clone(), ]; let uniskip_challenges = [uniskip_challenge1, uniskip_challenge2]; @@ -405,7 +402,6 @@ where stage4_result.batched_output_constraint, stage5_result.batched_output_constraint, stage6_result.batched_output_constraint, - stage7_result.batched_output_constraint, ]; let stage_input_constraints = [ @@ -415,7 +411,6 @@ where stage4_result.batched_input_constraint.clone(), stage5_result.batched_input_constraint.clone(), stage6_result.batched_input_constraint.clone(), - stage7_result.batched_input_constraint.clone(), ]; let stage_input_constraint_values = [ @@ -429,17 +424,15 @@ where stage4_result.input_constraint_challenge_values.clone(), stage5_result.input_constraint_challenge_values.clone(), stage6_result.input_constraint_challenge_values.clone(), - stage7_result.input_constraint_challenge_values.clone(), ]; - let output_constraint_challenge_values: [Vec; 7] = [ + let output_constraint_challenge_values: [Vec; 6] = [ stage1_result.output_constraint_challenge_values.clone(), stage2_result.output_constraint_challenge_values.clone(), stage3_result.output_constraint_challenge_values.clone(), stage4_result.output_constraint_challenge_values.clone(), stage5_result.output_constraint_challenge_values.clone(), stage6_result.output_constraint_challenge_values.clone(), - stage7_result.output_constraint_challenge_values.clone(), ]; self.verify_blindfold( @@ -453,7 +446,7 @@ where &stage2_result.batched_input_constraint, &stage1_result.input_constraint_challenge_values, &stage2_result.input_constraint_challenge_values, - &stage8_data, + &stage7_data, )?; } #[cfg(not(feature = "zk"))] @@ -564,18 +557,6 @@ where &self.proof.rw_config, ); - let spartan_product_virtual_remainder = ProductVirtualRemainderVerifier::new( - self.proof.trace_length, - uni_skip_params.clone(), - &self.opening_accumulator, - ); - - let instruction_claim_reduction = InstructionLookupsClaimReductionSumcheckVerifier::new( - self.proof.trace_length, - &self.opening_accumulator, - &mut self.transcript, - ); - let ram_raf_evaluation = RamRafEvaluationSumcheckVerifier::new( &self.program_io.memory_layout, &self.one_hot_params, @@ -592,12 +573,40 @@ where &self.proof.rw_config, ); + #[cfg(feature = "zk")] + let uniskip_input_constraint = uni_skip_params.input_claim_constraint(); + #[cfg(feature = "zk")] + let uniskip_input_constraint_challenge_values = + uni_skip_params.input_constraint_challenge_values(&self.opening_accumulator); + + let spartan_product_virtual_remainder = ProductVirtualRemainderVerifier::new( + self.proof.trace_length, + uni_skip_params, + &self.opening_accumulator, + ); + + let instruction_claim_reduction = InstructionLookupsClaimReductionSumcheckVerifier::new( + self.proof.trace_length, + &self.opening_accumulator, + &mut self.transcript, + ); + + let spartan_instruction_input = + InstructionInputSumcheckVerifier::new(&self.opening_accumulator, &mut self.transcript); + let spartan_registers_claim_reduction = RegistersClaimReductionSumcheckVerifier::new( + self.proof.trace_length, + &self.opening_accumulator, + &mut self.transcript, + ); + let instances: Vec<&dyn SumcheckInstanceVerifier> = vec![ &ram_read_write_checking, - &spartan_product_virtual_remainder, - &instruction_claim_reduction, &ram_raf_evaluation, &ram_output_check, + &spartan_product_virtual_remainder, + &instruction_claim_reduction, + &spartan_instruction_input, + &spartan_registers_claim_reduction, ]; let (batching_coefficients, r_stage2) = BatchedSumcheck::verify( @@ -632,10 +641,6 @@ where ); } - let uniskip_input_constraint = uni_skip_params.input_claim_constraint(); - let uniskip_input_constraint_challenge_values = - uni_skip_params.input_constraint_challenge_values(&self.opening_accumulator); - let stage_result = StageVerifyResult::with_uniskip( r_stage2, batched_output_constraint, @@ -664,66 +669,7 @@ where &self.opening_accumulator, &mut self.transcript, ); - let spartan_instruction_input = - InstructionInputSumcheckVerifier::new(&self.opening_accumulator, &mut self.transcript); - let spartan_registers_claim_reduction = RegistersClaimReductionSumcheckVerifier::new( - self.proof.trace_length, - &self.opening_accumulator, - &mut self.transcript, - ); - - let instances: Vec<&dyn SumcheckInstanceVerifier> = vec![ - &spartan_shift, - &spartan_instruction_input, - &spartan_registers_claim_reduction, - ]; - - let (batching_coefficients, r_stage3) = BatchedSumcheck::verify( - &self.proof.stage3_sumcheck_proof, - instances.clone(), - &mut self.opening_accumulator, - &mut self.transcript, - )?; - - #[cfg(feature = "zk")] - { - let batched_output_constraint = batch_output_constraints(&instances); - let batched_input_constraint = batch_input_constraints(&instances); - let max_num_rounds = instances.iter().map(|i| i.num_rounds()).max().unwrap(); - let mut output_constraint_challenge_values: Vec = batching_coefficients.clone(); - let mut input_constraint_challenge_values: Vec = - scale_batching_coefficients(&batching_coefficients, &instances); - for instance in &instances { - let num_rounds = instance.num_rounds(); - let offset = instance.round_offset(max_num_rounds); - let r_slice = &r_stage3[offset..offset + num_rounds]; - output_constraint_challenge_values.extend( - instance - .get_params() - .output_constraint_challenge_values(r_slice), - ); - input_constraint_challenge_values.extend( - instance - .get_params() - .input_constraint_challenge_values(&self.opening_accumulator), - ); - } - Ok(StageVerifyResult::new( - r_stage3, - batched_output_constraint, - output_constraint_challenge_values, - batched_input_constraint, - input_constraint_challenge_values, - )) - } - #[cfg(not(feature = "zk"))] - Ok(StageVerifyResult { - challenges: r_stage3, - }) - } - #[cfg_attr(not(feature = "zk"), allow(unused_variables))] - fn verify_stage4(&mut self) -> Result, ProofVerifyError> { let registers_read_write_checking = RegistersReadWriteCheckingVerifier::new( self.proof.trace_length, &self.opening_accumulator, @@ -737,10 +683,9 @@ where self.trusted_advice_commitment.is_some(), &mut self.opening_accumulator, ); - // Domain-separate the batching challenge. self.transcript.append_bytes(b"ram_val_check_gamma", &[]); let ram_val_check_gamma: F = self.transcript.challenge_scalar::(); - let initial_ram_state = crate::zkvm::ram::gen_ram_initial_memory_state::( + let initial_ram_state = gen_ram_initial_memory_state::( self.proof.ram_K, &self.preprocessing.shared.ram, &self.program_io, @@ -756,11 +701,14 @@ where &self.opening_accumulator, ); - let instances: Vec<&dyn SumcheckInstanceVerifier> = - vec![®isters_read_write_checking, &ram_val_check]; + let instances: Vec<&dyn SumcheckInstanceVerifier> = vec![ + &spartan_shift as &dyn SumcheckInstanceVerifier, + ®isters_read_write_checking, + &ram_val_check, + ]; - let (batching_coefficients, r_stage4) = BatchedSumcheck::verify( - &self.proof.stage4_sumcheck_proof, + let (batching_coefficients, r_stage3) = BatchedSumcheck::verify( + &self.proof.stage3_sumcheck_proof, instances.clone(), &mut self.opening_accumulator, &mut self.transcript, @@ -777,7 +725,7 @@ where for instance in &instances { let num_rounds = instance.num_rounds(); let offset = instance.round_offset(max_num_rounds); - let r_slice = &r_stage4[offset..offset + num_rounds]; + let r_slice = &r_stage3[offset..offset + num_rounds]; output_constraint_challenge_values.extend( instance .get_params() @@ -790,7 +738,7 @@ where ); } Ok(StageVerifyResult::new( - r_stage4, + r_stage3, batched_output_constraint, output_constraint_challenge_values, batched_input_constraint, @@ -799,12 +747,12 @@ where } #[cfg(not(feature = "zk"))] Ok(StageVerifyResult { - challenges: r_stage4, + challenges: r_stage3, }) } #[cfg_attr(not(feature = "zk"), allow(unused_variables))] - fn verify_stage5(&mut self) -> Result, ProofVerifyError> { + fn verify_stage4(&mut self) -> Result, ProofVerifyError> { let n_cycle_vars = self.proof.trace_length.log_2(); let lookups_read_raf = InstructionReadRafSumcheckVerifier::new( @@ -828,8 +776,8 @@ where ®isters_val_evaluation, ]; - let (batching_coefficients, r_stage5) = BatchedSumcheck::verify( - &self.proof.stage5_sumcheck_proof, + let (batching_coefficients, r_stage4) = BatchedSumcheck::verify( + &self.proof.stage4_sumcheck_proof, instances.clone(), &mut self.opening_accumulator, &mut self.transcript, @@ -846,7 +794,7 @@ where for instance in &instances { let num_rounds = instance.num_rounds(); let offset = instance.round_offset(max_num_rounds); - let r_slice = &r_stage5[offset..offset + num_rounds]; + let r_slice = &r_stage4[offset..offset + num_rounds]; output_constraint_challenge_values.extend( instance .get_params() @@ -859,7 +807,7 @@ where ); } Ok(StageVerifyResult::new( - r_stage5, + r_stage4, batched_output_constraint, output_constraint_challenge_values, batched_input_constraint, @@ -868,12 +816,12 @@ where } #[cfg(not(feature = "zk"))] Ok(StageVerifyResult { - challenges: r_stage5, + challenges: r_stage4, }) } #[cfg_attr(not(feature = "zk"), allow(unused_variables))] - fn verify_stage6(&mut self) -> Result, ProofVerifyError> { + fn verify_stage5(&mut self) -> Result, ProofVerifyError> { let n_cycle_vars = self.proof.trace_length.log_2(); let bytecode_read_raf = BytecodeReadRafSumcheckVerifier::gen( &self.preprocessing.shared.bytecode, @@ -910,7 +858,7 @@ where &mut self.transcript, ); - // Advice claim reduction (Phase 1 in Stage 6): trusted and untrusted are separate instances. + // Advice claim reduction (Phase 1 in Stage 5): trusted and untrusted are separate instances. if self.trusted_advice_commitment.is_some() { self.advice_reduction_verifier_trusted = Some(AdviceClaimReductionVerifier::new( AdviceKind::Trusted, @@ -943,8 +891,8 @@ where instances.push(advice); } - let (batching_coefficients, r_stage6) = BatchedSumcheck::verify( - &self.proof.stage6_sumcheck_proof, + let (batching_coefficients, r_stage5) = BatchedSumcheck::verify( + &self.proof.stage5_sumcheck_proof, instances.clone(), &mut self.opening_accumulator, &mut self.transcript, @@ -961,7 +909,7 @@ where for instance in &instances { let num_rounds = instance.num_rounds(); let offset = instance.round_offset(max_num_rounds); - let r_slice = &r_stage6[offset..offset + num_rounds]; + let r_slice = &r_stage5[offset..offset + num_rounds]; output_constraint_challenge_values.extend( instance .get_params() @@ -974,7 +922,7 @@ where ); } Ok(StageVerifyResult::new( - r_stage6, + r_stage5, batched_output_constraint, output_constraint_challenge_values, batched_input_constraint, @@ -983,7 +931,7 @@ where } #[cfg(not(feature = "zk"))] Ok(StageVerifyResult { - challenges: r_stage6, + challenges: r_stage5, }) } @@ -991,21 +939,18 @@ where #[allow(clippy::too_many_arguments)] fn verify_blindfold( &mut self, - sumcheck_challenges: &[Vec; 7], + sumcheck_challenges: &[Vec; 6], uniskip_challenges: [F::Challenge; 2], - stage_output_constraints: &[Option; 7], - output_constraint_challenge_values: &[Vec; 7], - stage_input_constraints: &[InputClaimConstraint; 7], - input_constraint_challenge_values: &[Vec; 7], - // For stages 0-1: batched input constraint for regular rounds (different from uni-skip) + stage_output_constraints: &[Option; 6], + output_constraint_challenge_values: &[Vec; 6], + stage_input_constraints: &[InputClaimConstraint; 6], + input_constraint_challenge_values: &[Vec; 6], stage1_batched_input: &InputClaimConstraint, stage2_batched_input: &InputClaimConstraint, stage1_batched_input_values: &[F], stage2_batched_input_values: &[F], - stage8_data: &Stage8VerifyData, + stage7_data: &Stage8VerifyData, ) -> Result<(), ProofVerifyError> { - // Build stage configurations including uni-skip rounds. - // Uni-skip rounds are the first round of stages 1 and 2 (indices 0 and 1). let stage_proofs = [ &self.proof.stage1_sumcheck_proof, &self.proof.stage2_sumcheck_proof, @@ -1013,10 +958,8 @@ where &self.proof.stage4_sumcheck_proof, &self.proof.stage5_sumcheck_proof, &self.proof.stage6_sumcheck_proof, - &self.proof.stage7_sumcheck_proof, ]; - // Precompute power sums for uni-skip domains let outer_power_sums = LagrangeHelper::power_sums::< OUTER_UNIVARIATE_SKIP_DOMAIN_SIZE, OUTER_FIRST_ROUND_POLY_NUM_COEFFS, @@ -1027,13 +970,11 @@ where >(); let mut stage_configs = Vec::new(); - // Track which stage_config index corresponds to uni-skip and regular first rounds - let mut uniskip_indices: Vec = Vec::new(); // Only 2 elements for stages 0-1 - let mut regular_first_round_indices: Vec = Vec::new(); // 7 elements for all stages + let mut uniskip_indices: Vec = Vec::new(); + let mut regular_first_round_indices: Vec = Vec::new(); let mut last_round_indices: Vec = Vec::new(); for (stage_idx, proof) in stage_proofs.iter().enumerate() { - // For stages 0 and 1 (Jolt stages 1 and 2), add uni-skip config first if stage_idx < 2 { let uniskip_proof = if stage_idx == 0 { &self.proof.stage1_uni_skip_first_round_proof @@ -1048,7 +989,6 @@ where product_power_sums.to_vec() }; - // Record uni-skip index for its input constraint uniskip_indices.push(stage_configs.len()); let config = if stage_idx == 0 { @@ -1059,24 +999,17 @@ where stage_configs.push(config); } - // Record first regular round index for its input constraint regular_first_round_indices.push(stage_configs.len()); - // Add regular sumcheck rounds let num_rounds = proof.num_rounds(); for round_idx in 0..num_rounds { let poly_degree = match proof { - crate::subprotocols::sumcheck::SumcheckInstanceProof::Clear(std_proof) => { - std_proof.compressed_polys[round_idx] - .coeffs_except_linear_term - .len() - } - crate::subprotocols::sumcheck::SumcheckInstanceProof::Zk(zk_proof) => { - zk_proof.poly_degrees[round_idx] - } + SumcheckInstanceProof::Clear(std_proof) => std_proof.compressed_polys + [round_idx] + .coeffs_except_linear_term + .len(), + SumcheckInstanceProof::Zk(zk_proof) => zk_proof.poly_degrees[round_idx], }; - // First regular round ALWAYS starts a new chain - // (batched claims differ from uni-skip output due to batching coefficients) let starts_new_chain = round_idx == 0; let config = if starts_new_chain { StageConfig::new_chain(1, poly_degree) @@ -1086,11 +1019,9 @@ where stage_configs.push(config); } - // Record the last round index for output constraint last_round_indices.push(stage_configs.len() - 1); } - // Add final_output configurations using the batched constraints from verifier instances for (stage_idx, constraint) in stage_output_constraints.iter().enumerate() { if let Some(batched) = constraint { let last_round_idx = last_round_indices[stage_idx]; @@ -1099,11 +1030,9 @@ where } } - // Add initial_input configurations for uni-skip stages (stages 0-1) - // These use the uni-skip's own input constraints let uniskip_constraints = [ - stage_input_constraints[0].clone(), // Stage 0 uni-skip - stage_input_constraints[1].clone(), // Stage 1 uni-skip + stage_input_constraints[0].clone(), + stage_input_constraints[1].clone(), ]; for (i, constraint) in uniskip_constraints.iter().enumerate() { let idx = uniskip_indices[i]; @@ -1111,16 +1040,13 @@ where Some(ClaimBindingConfig::with_constraint(constraint.clone())); } - // Add initial_input configurations for regular first rounds (all 7 stages) - // These use the batched input constraints from the stage results let regular_constraints = [ - stage1_batched_input.clone(), // Stage 0 regular - stage2_batched_input.clone(), // Stage 1 regular - stage_input_constraints[2].clone(), // Stage 2 - stage_input_constraints[3].clone(), // Stage 3 - stage_input_constraints[4].clone(), // Stage 4 - stage_input_constraints[5].clone(), // Stage 5 - stage_input_constraints[6].clone(), // Stage 6 + stage1_batched_input.clone(), + stage2_batched_input.clone(), + stage_input_constraints[2].clone(), + stage_input_constraints[3].clone(), + stage_input_constraints[4].clone(), + stage_input_constraints[5].clone(), ]; for (i, constraint) in regular_constraints.iter().enumerate() { let idx = regular_first_round_indices[i]; @@ -1128,7 +1054,7 @@ where Some(ClaimBindingConfig::with_constraint(constraint.clone())); } - let extra_constraint_terms: Vec<(ValueSource, ValueSource)> = stage8_data + let extra_constraint_terms: Vec<(ValueSource, ValueSource)> = stage7_data .opening_ids .iter() .enumerate() @@ -1137,7 +1063,6 @@ where let extra_constraint = OutputClaimConstraint::linear(extra_constraint_terms); let extra_constraints = vec![extra_constraint]; - // Build baked public inputs from expected values let mut baked_challenges: Vec = Vec::new(); for (stage_idx, stage_challenges) in sumcheck_challenges.iter().enumerate() { if stage_idx < 2 { @@ -1148,7 +1073,7 @@ where } } - let all_input_challenge_values: [&[F]; 9] = [ + let all_input_challenge_values: [&[F]; 8] = [ &input_constraint_challenge_values[0], stage1_batched_input_values, &input_constraint_challenge_values[1], @@ -1157,7 +1082,6 @@ where &input_constraint_challenge_values[3], &input_constraint_challenge_values[4], &input_constraint_challenge_values[5], - &input_constraint_challenge_values[6], ]; let mut baked_input_challenges: Vec = Vec::new(); for expected_values in all_input_challenge_values.iter() { @@ -1175,7 +1099,7 @@ where batching_coefficients: Vec::new(), output_constraint_challenges: baked_output_challenges, input_constraint_challenges: baked_input_challenges, - extra_constraint_challenges: stage8_data.constraint_coeffs.clone(), + extra_constraint_challenges: stage7_data.constraint_coeffs.clone(), }; let builder = @@ -1184,7 +1108,6 @@ where let mut round_commitments: Vec = Vec::new(); for (stage_idx, proof) in stage_proofs.iter().enumerate() { - // For stages 0-1, include uni-skip commitment first if stage_idx < 2 { let uniskip_proof = if stage_idx == 0 { &self.proof.stage1_uni_skip_first_round_proof @@ -1195,7 +1118,6 @@ where round_commitments.push(zk_uniskip.commitment); } } - // Add regular sumcheck round commitments if let SumcheckInstanceProof::Zk(zk_proof) = proof { round_commitments.extend(zk_proof.round_commitments.iter().cloned()); } @@ -1237,7 +1159,7 @@ where } #[cfg_attr(not(feature = "zk"), allow(unused_variables))] - fn verify_stage7(&mut self) -> Result, ProofVerifyError> { + fn verify_stage6(&mut self) -> Result, ProofVerifyError> { // Create verifier for HammingWeightClaimReduction // (r_cycle and r_addr_bool are extracted from Booleanity opening internally) let hw_verifier = HammingWeightClaimReductionVerifier::new( @@ -1269,8 +1191,8 @@ where } } - let (batching_coefficients, r_stage7) = BatchedSumcheck::verify( - &self.proof.stage7_sumcheck_proof, + let (batching_coefficients, r_stage6) = BatchedSumcheck::verify( + &self.proof.stage6_sumcheck_proof, instances.clone(), &mut self.opening_accumulator, &mut self.transcript, @@ -1287,7 +1209,7 @@ where for instance in &instances { let num_rounds = instance.num_rounds(); let offset = instance.round_offset(max_num_rounds); - let r_slice = &r_stage7[offset..offset + num_rounds]; + let r_slice = &r_stage6[offset..offset + num_rounds]; output_constraint_challenge_values.extend( instance .get_params() @@ -1300,7 +1222,7 @@ where ); } Ok(StageVerifyResult::new( - r_stage7, + r_stage6, batched_output_constraint, output_constraint_challenge_values, batched_input_constraint, @@ -1309,12 +1231,11 @@ where } #[cfg(not(feature = "zk"))] Ok(StageVerifyResult { - challenges: r_stage7, + challenges: r_stage6, }) } - /// Stage 8: Dory batch opening verification. - fn verify_stage8(&mut self) -> Result, ProofVerifyError> { + fn verify_stage7(&mut self) -> Result, ProofVerifyError> { // Initialize DoryGlobals with the layout from the proof // This ensures the verifier uses the same layout as the prover let _guard = DoryGlobals::initialize_context( @@ -1325,19 +1246,19 @@ where ); // Get the unified opening point from HammingWeightClaimReduction - // This contains (r_address_stage7 || r_cycle_stage6) in big-endian + // This contains (r_address_stage6 || r_cycle_stage5) in big-endian let (opening_point, _) = self.opening_accumulator.get_committed_polynomial_opening( CommittedPolynomial::InstructionRa(0), SumcheckId::HammingWeightClaimReduction, ); let log_k_chunk = self.one_hot_params.log_k_chunk; - let r_address_stage7 = &opening_point.r[..log_k_chunk]; + let r_address_stage6 = &opening_point.r[..log_k_chunk]; // 1. Collect all (polynomial, claim) pairs let mut polynomial_claims = Vec::new(); let mut scaling_factors = Vec::new(); - // Dense polynomials: RamInc and RdInc (from IncClaimReduction in Stage 6) + // Dense polynomials: RamInc and RdInc (from IncClaimReduction in Stage 5) let (_, ram_inc_claim) = self.opening_accumulator.get_committed_polynomial_opening( CommittedPolynomial::RamInc, SumcheckId::IncClaimReduction, @@ -1347,9 +1268,8 @@ where SumcheckId::IncClaimReduction, ); - // Dense polynomials are zero-padded in the Dory matrix, so their evaluation - // includes a factor eq(r_addr, 0) = ∏(1 − r_addr_i). - let lagrange_factor: F = EqPolynomial::zero_selector(r_address_stage7); + let lagrange_factor: F = EqPolynomial::zero_selector(r_address_stage6); + polynomial_claims.push((CommittedPolynomial::RamInc, ram_inc_claim * lagrange_factor)); scaling_factors.push(lagrange_factor); polynomial_claims.push((CommittedPolynomial::RdInc, rd_inc_claim * lagrange_factor)); @@ -1381,7 +1301,7 @@ where scaling_factors.push(F::one()); } - // Advice polynomials: TrustedAdvice and UntrustedAdvice (from AdviceClaimReduction in Stage 6) + // Advice polynomials: TrustedAdvice and UntrustedAdvice (from AdviceClaimReduction in Stage 5) // These are committed with smaller dimensions, so we apply Lagrange factors to embed // them in the top-left block of the main Dory matrix. let mut include_trusted_advice = false; diff --git a/jolt-core/src/zkvm/witness.rs b/jolt-core/src/zkvm/witness.rs index efcef73652..f0538afcf4 100644 --- a/jolt-core/src/zkvm/witness.rs +++ b/jolt-core/src/zkvm/witness.rs @@ -241,7 +241,6 @@ pub enum VirtualPolynomial { RightLookupOperand, LeftInstructionInput, RightInstructionInput, - Product, ShouldJump, ShouldBranch, WritePCtoRD, diff --git a/z3-verifier/src/cpu_constraints.rs b/z3-verifier/src/cpu_constraints.rs index 7b2f8a43c4..e51aeea6e8 100644 --- a/z3-verifier/src/cpu_constraints.rs +++ b/z3-verifier/src/cpu_constraints.rs @@ -152,7 +152,6 @@ impl JoltState { [ &self.left_input, &self.right_input, - &self.product, &self.write_lookup_output_to_rd, &self.write_pc_to_rd, &self.should_branch, @@ -213,7 +212,6 @@ impl JoltState { match poly { VirtualPolynomial::LeftInstructionInput => &self.left_input, VirtualPolynomial::RightInstructionInput => &self.right_input, - VirtualPolynomial::Product => &self.product, VirtualPolynomial::InstructionFlags(InstructionFlags::IsRdNotZero) => { &self.instruction_flags[InstructionFlags::IsRdNotZero as usize] }