From fbe1b0a37be815f087a4cd82673390bcaf0fd207 Mon Sep 17 00:00:00 2001 From: Quang Dao Date: Sat, 28 Mar 2026 15:25:03 -0400 Subject: [PATCH 1/2] Bound proof deserialization without format changes --- .../src/field/challenge/mont_ark_u128.rs | 32 +- jolt-core/src/poly/commitment/hyrax.rs | 54 +++- jolt-core/src/poly/unipoly.rs | 138 ++++++++- .../src/subprotocols/blindfold/protocol.rs | 123 +++++++- .../subprotocols/blindfold/relaxed_r1cs.rs | 149 +++++++++- jolt-core/src/subprotocols/sumcheck.rs | 67 ++++- jolt-core/src/subprotocols/univariate_skip.rs | 73 ++++- jolt-core/src/utils/errors.rs | 2 + jolt-core/src/utils/mod.rs | 1 + jolt-core/src/utils/serialization.rs | 73 +++++ jolt-core/src/zkvm/proof_serialization.rs | 273 +++++++++++++++++- 11 files changed, 946 insertions(+), 39 deletions(-) create mode 100644 jolt-core/src/utils/serialization.rs diff --git a/jolt-core/src/field/challenge/mont_ark_u128.rs b/jolt-core/src/field/challenge/mont_ark_u128.rs index f558cbbea7..2c7bdd6ba8 100644 --- a/jolt-core/src/field/challenge/mont_ark_u128.rs +++ b/jolt-core/src/field/challenge/mont_ark_u128.rs @@ -36,6 +36,9 @@ pub struct MontU128Challenge { // Custom serialization: serialize as [u64; 4] for compatibility with field element format impl Valid for MontU128Challenge { fn check(&self) -> Result<(), SerializationError> { + if (self.high >> 61) != 0 { + return Err(SerializationError::InvalidData); + } Ok(()) } } @@ -63,12 +66,18 @@ impl CanonicalDeserialize for MontU128Challenge { validate: Validate, ) -> Result { let arr = <[u64; 4]>::deserialize_with_mode(reader, compress, validate)?; - // arr[0] and arr[1] should be 0, arr[2] is low, arr[3] is high - Ok(Self { + if arr[0] != 0 || arr[1] != 0 { + return Err(SerializationError::InvalidData); + } + let value = Self { low: arr[2], high: arr[3], _marker: PhantomData, - }) + }; + if validate == Validate::Yes { + value.check()?; + } + Ok(value) } } @@ -230,6 +239,7 @@ impl OptimizedMul for MontU128Challenge { #[cfg(test)] mod tests { use super::*; + use ark_serialize::CanonicalSerialize; #[test] fn masks_high_three_bits_in_default_challenge_width() { @@ -237,4 +247,20 @@ mod tests { assert_eq!(challenge.low, u64::MAX); assert_eq!(challenge.high, (u64::MAX >> 3)); } + + #[test] + fn rejects_nonzero_lower_limbs_on_deserialize() { + let mut bytes = Vec::new(); + [1u64, 0, 0, 0].serialize_compressed(&mut bytes).unwrap(); + assert!(MontU128Challenge::::deserialize_compressed(&bytes[..]).is_err()); + } + + #[test] + fn rejects_high_bits_outside_challenge_width() { + let mut bytes = Vec::new(); + [0u64, 0, 0, 1u64 << 61] + .serialize_compressed(&mut bytes) + .unwrap(); + assert!(MontU128Challenge::::deserialize_compressed(&bytes[..]).is_err()); + } } diff --git a/jolt-core/src/poly/commitment/hyrax.rs b/jolt-core/src/poly/commitment/hyrax.rs index 9d20a43ea4..845fb98695 100644 --- a/jolt-core/src/poly/commitment/hyrax.rs +++ b/jolt-core/src/poly/commitment/hyrax.rs @@ -1,7 +1,11 @@ -use ark_serialize::{CanonicalDeserialize, CanonicalSerialize}; +use ark_serialize::{CanonicalDeserialize, CanonicalSerialize, Valid}; use crate::field::JoltField; use crate::poly::eq_poly::EqPolynomial; +use crate::utils::serialization::{ + deserialize_bounded_vec, serialize_vec_with_len, serialized_vec_with_len_size, + MAX_BLINDFOLD_VECTOR_LEN, +}; /// combined[k] = Σ_i eq(ry_row, i) · flat[i*cols + k] pub fn combined_row(flat: &[F], cols: usize, ry_row: &[F]) -> Vec { @@ -44,12 +48,58 @@ pub fn combined_blinding(row_blindings: &[F], ry_row: &[F]) -> F { .sum() } -#[derive(Clone, Debug, CanonicalSerialize, CanonicalDeserialize)] +#[derive(Clone, Debug)] pub struct HyraxOpeningProof { pub combined_row: Vec, pub combined_blinding: F, } +impl CanonicalSerialize for HyraxOpeningProof { + fn serialize_with_mode( + &self, + mut writer: W, + compress: ark_serialize::Compress, + ) -> Result<(), ark_serialize::SerializationError> { + serialize_vec_with_len(&self.combined_row, &mut writer, compress)?; + self.combined_blinding + .serialize_with_mode(&mut writer, compress) + } + + fn serialized_size(&self, compress: ark_serialize::Compress) -> usize { + serialized_vec_with_len_size(&self.combined_row, compress) + + self.combined_blinding.serialized_size(compress) + } +} + +impl ark_serialize::Valid for HyraxOpeningProof { + fn check(&self) -> Result<(), ark_serialize::SerializationError> { + self.combined_row.check()?; + self.combined_blinding.check() + } +} + +impl CanonicalDeserialize for HyraxOpeningProof { + fn deserialize_with_mode( + mut reader: R, + compress: ark_serialize::Compress, + validate: ark_serialize::Validate, + ) -> Result { + let proof = Self { + combined_row: deserialize_bounded_vec( + &mut reader, + compress, + validate, + MAX_BLINDFOLD_VECTOR_LEN, + )?, + combined_blinding: F::deserialize_with_mode(&mut reader, compress, validate)?, + }; + if validate == ark_serialize::Validate::Yes { + proof.check()?; + } + Ok(proof) + } +} + #[cfg(test)] mod tests { use super::*; diff --git a/jolt-core/src/poly/unipoly.rs b/jolt-core/src/poly/unipoly.rs index b2a2e0a9b3..be45885bbe 100644 --- a/jolt-core/src/poly/unipoly.rs +++ b/jolt-core/src/poly/unipoly.rs @@ -1,4 +1,8 @@ use crate::field::{ChallengeFieldOps, FieldChallengeOps, JoltField}; +use crate::utils::serialization::{ + deserialize_bounded_vec, serialize_vec_with_len, serialized_vec_with_len_size, + MAX_UNIPOLY_COEFFS, +}; use std::cmp::Ordering; use std::iter::zip; use std::ops::{Add, AddAssign, Index, IndexMut, Mul, MulAssign, Sub}; @@ -14,21 +18,106 @@ use crate::utils::small_scalar::SmallScalar; // ax^2 + bx + c stored as vec![c,b,a] // ax^3 + bx^2 + cx + d stored as vec![d,c,b,a] -#[derive(CanonicalSerialize, CanonicalDeserialize, Debug, Clone, PartialEq, Allocative)] +#[derive(Debug, Clone, PartialEq, Allocative)] pub struct UniPoly { pub coeffs: Vec, } // ax^2 + bx + c stored as vec![c,a] // ax^3 + bx^2 + cx + d stored as vec![d,b,a] -#[derive(CanonicalSerialize, CanonicalDeserialize, Debug, Clone)] +#[derive(Debug, Clone)] pub struct CompressedUniPoly { pub coeffs_except_linear_term: Vec, } +impl CanonicalSerialize for UniPoly { + fn serialize_with_mode( + &self, + writer: W, + compress: Compress, + ) -> Result<(), SerializationError> { + serialize_vec_with_len(&self.coeffs, writer, compress) + } + + fn serialized_size(&self, compress: Compress) -> usize { + serialized_vec_with_len_size(&self.coeffs, compress) + } +} + +impl Valid for UniPoly { + fn check(&self) -> Result<(), SerializationError> { + if self.coeffs.is_empty() { + return Err(SerializationError::InvalidData); + } + self.coeffs.check() + } +} + +impl CanonicalDeserialize for UniPoly { + fn deserialize_with_mode( + reader: R, + compress: Compress, + validate: Validate, + ) -> Result { + let coeffs = deserialize_bounded_vec(reader, compress, validate, MAX_UNIPOLY_COEFFS)?; + let poly = Self { coeffs }; + if validate == Validate::Yes { + poly.check()?; + } + Ok(poly) + } +} + +impl CanonicalSerialize for CompressedUniPoly { + fn serialize_with_mode( + &self, + writer: W, + compress: Compress, + ) -> Result<(), SerializationError> { + serialize_vec_with_len(&self.coeffs_except_linear_term, writer, compress) + } + + fn serialized_size(&self, compress: Compress) -> usize { + serialized_vec_with_len_size(&self.coeffs_except_linear_term, compress) + } +} + +impl Valid for CompressedUniPoly { + fn check(&self) -> Result<(), SerializationError> { + if self.coeffs_except_linear_term.is_empty() { + return Err(SerializationError::InvalidData); + } + self.coeffs_except_linear_term.check() + } +} + +impl CanonicalDeserialize for CompressedUniPoly { + fn deserialize_with_mode( + reader: R, + compress: Compress, + validate: Validate, + ) -> Result { + let coeffs_except_linear_term = + deserialize_bounded_vec(reader, compress, validate, MAX_UNIPOLY_COEFFS)?; + let poly = Self { + coeffs_except_linear_term, + }; + if validate == Validate::Yes { + poly.check()?; + } + Ok(poly) + } +} + impl UniPoly { pub fn from_coeff(coeffs: Vec) -> Self { - UniPoly { coeffs } + if coeffs.is_empty() { + UniPoly { + coeffs: vec![F::zero()], + } + } else { + UniPoly { coeffs } + } } /// Interpolate a polynomial from its evaluations at the points 0, 1, 2, ..., n-1. @@ -188,10 +277,11 @@ impl UniPoly { } pub fn zero() -> Self { - Self::from_coeff(Vec::new()) + Self::from_coeff(vec![F::zero()]) } pub fn degree(&self) -> usize { + debug_assert!(!self.coeffs.is_empty()); self.coeffs.len() - 1 } @@ -297,10 +387,14 @@ impl UniPoly { } pub fn compress(&self) -> CompressedUniPoly { - let mut coeffs_except_linear_term = Vec::with_capacity(self.coeffs.len() - 1); + debug_assert!(!self.coeffs.is_empty()); + let mut coeffs_except_linear_term = + Vec::with_capacity(self.coeffs.len().saturating_sub(1).max(1)); coeffs_except_linear_term.push(self.coeffs[0]); - coeffs_except_linear_term.extend_from_slice(&self.coeffs[2..]); - debug_assert_eq!(coeffs_except_linear_term.len() + 1, self.coeffs.len()); + if self.coeffs.len() > 2 { + coeffs_except_linear_term.extend_from_slice(&self.coeffs[2..]); + debug_assert_eq!(coeffs_except_linear_term.len() + 1, self.coeffs.len()); + } CompressedUniPoly { coeffs_except_linear_term, } @@ -484,6 +578,10 @@ impl CompressedUniPoly { // we require eval(0) + eval(1) = hint, so we can solve for the linear term as: // linear_term = hint - 2 * constant_term - deg2 term - deg3 term pub fn decompress(&self, hint: &F) -> UniPoly { + debug_assert!(!self.coeffs_except_linear_term.is_empty()); + if self.coeffs_except_linear_term.len() == 1 { + return UniPoly::from_coeff(vec![self.coeffs_except_linear_term[0]]); + } let mut linear_term = *hint - self.coeffs_except_linear_term[0] - self.coeffs_except_linear_term[0]; for i in 1..self.coeffs_except_linear_term.len() { @@ -499,6 +597,10 @@ impl CompressedUniPoly { // In the verifier we do not have to check that f(0) + f(1) = hint as we can just // recover the linear term assuming the prover did it right, then eval the poly pub fn eval_from_hint(&self, hint: &F, x: &F::Challenge) -> F { + debug_assert!(!self.coeffs_except_linear_term.is_empty()); + if self.coeffs_except_linear_term.len() == 1 { + return self.coeffs_except_linear_term[0]; + } let mut linear_term = *hint - self.coeffs_except_linear_term[0] - self.coeffs_except_linear_term[0]; for i in 1..self.coeffs_except_linear_term.len() { @@ -515,7 +617,12 @@ impl CompressedUniPoly { } pub fn degree(&self) -> usize { - self.coeffs_except_linear_term.len() + debug_assert!(!self.coeffs_except_linear_term.is_empty()); + if self.coeffs_except_linear_term.len() == 1 { + 0 + } else { + self.coeffs_except_linear_term.len() + } } } @@ -523,6 +630,7 @@ impl CompressedUniPoly { mod tests { use super::*; use ark_bn254::Fr; + use ark_serialize::CanonicalSerialize; use rand_chacha::ChaCha20Rng; use rand_core::SeedableRng; @@ -661,4 +769,18 @@ mod tests { ); assert_eq!(poly.coeffs, true_poly.coeffs); } + + #[test] + fn rejects_empty_unipoly_deserialization() { + let mut bytes = Vec::new(); + 0usize.serialize_compressed(&mut bytes).unwrap(); + assert!(UniPoly::::deserialize_compressed(&bytes[..]).is_err()); + } + + #[test] + fn rejects_empty_compressed_unipoly_deserialization() { + let mut bytes = Vec::new(); + 0usize.serialize_compressed(&mut bytes).unwrap(); + assert!(CompressedUniPoly::::deserialize_compressed(&bytes[..]).is_err()); + } } diff --git a/jolt-core/src/subprotocols/blindfold/protocol.rs b/jolt-core/src/subprotocols/blindfold/protocol.rs index 3e12a80f71..20b9c53197 100644 --- a/jolt-core/src/subprotocols/blindfold/protocol.rs +++ b/jolt-core/src/subprotocols/blindfold/protocol.rs @@ -6,7 +6,7 @@ //! 3. Using Spartan sumcheck to prove R1CS satisfaction without revealing the witness //! 4. Hyrax-style openings to verify W(ry) and E(rx) evaluations -use ark_serialize::{CanonicalDeserialize, CanonicalSerialize}; +use ark_serialize::{CanonicalDeserialize, CanonicalSerialize, Valid}; use crate::curve::{JoltCurve, JoltGroupElement}; use crate::field::JoltField; @@ -16,6 +16,10 @@ use crate::poly::eq_poly::EqPolynomial; use crate::poly::unipoly::CompressedUniPoly; use crate::transcripts::Transcript; use crate::utils::math::Math; +use crate::utils::serialization::{ + deserialize_bounded_vec, serialize_vec_with_len, serialized_vec_with_len_size, + MAX_BLINDFOLD_VECTOR_LEN, MAX_OPENING_CLAIMS, MAX_SUMCHECK_ROUNDS, +}; use super::folding::{commit_cross_term_rows, compute_cross_term, sample_random_satisfying_pair}; use super::r1cs::VerifierR1CS; @@ -27,7 +31,7 @@ use super::spartan::{INNER_SUMCHECK_DEGREE_BOUND, SPARTAN_DEGREE_BOUND}; /// The real instance is NOT included — verifier reconstructs from round_commitments, /// eval_commitments, public_inputs. The random instance IS included — verifier reads /// it from the proof and absorbs into transcript, never learning the random witness. -#[derive(Clone, Debug, CanonicalSerialize, CanonicalDeserialize)] +#[derive(Clone, Debug)] pub struct BlindFoldProof> { pub random_instance: RelaxedR1CSInstance, @@ -49,6 +53,121 @@ pub struct BlindFoldProof> { pub folded_eval_blindings: Vec, } +impl> CanonicalSerialize for BlindFoldProof { + fn serialize_with_mode( + &self, + mut writer: W, + compress: ark_serialize::Compress, + ) -> Result<(), ark_serialize::SerializationError> { + self.random_instance + .serialize_with_mode(&mut writer, compress)?; + serialize_vec_with_len(&self.noncoeff_row_commitments, &mut writer, compress)?; + serialize_vec_with_len(&self.cross_term_row_commitments, &mut writer, compress)?; + serialize_vec_with_len(&self.spartan_proof, &mut writer, compress)?; + self.az_r.serialize_with_mode(&mut writer, compress)?; + self.bz_r.serialize_with_mode(&mut writer, compress)?; + self.cz_r.serialize_with_mode(&mut writer, compress)?; + serialize_vec_with_len(&self.inner_sumcheck_proof, &mut writer, compress)?; + self.w_opening.serialize_with_mode(&mut writer, compress)?; + self.e_opening.serialize_with_mode(&mut writer, compress)?; + serialize_vec_with_len(&self.folded_eval_outputs, &mut writer, compress)?; + serialize_vec_with_len(&self.folded_eval_blindings, writer, compress) + } + + fn serialized_size(&self, compress: ark_serialize::Compress) -> usize { + self.random_instance.serialized_size(compress) + + serialized_vec_with_len_size(&self.noncoeff_row_commitments, compress) + + serialized_vec_with_len_size(&self.cross_term_row_commitments, compress) + + serialized_vec_with_len_size(&self.spartan_proof, compress) + + self.az_r.serialized_size(compress) + + self.bz_r.serialized_size(compress) + + self.cz_r.serialized_size(compress) + + serialized_vec_with_len_size(&self.inner_sumcheck_proof, compress) + + self.w_opening.serialized_size(compress) + + self.e_opening.serialized_size(compress) + + serialized_vec_with_len_size(&self.folded_eval_outputs, compress) + + serialized_vec_with_len_size(&self.folded_eval_blindings, compress) + } +} + +impl> ark_serialize::Valid for BlindFoldProof { + fn check(&self) -> Result<(), ark_serialize::SerializationError> { + self.random_instance.check()?; + self.noncoeff_row_commitments.check()?; + self.cross_term_row_commitments.check()?; + self.spartan_proof.check()?; + self.az_r.check()?; + self.bz_r.check()?; + self.cz_r.check()?; + self.inner_sumcheck_proof.check()?; + self.w_opening.check()?; + self.e_opening.check()?; + self.folded_eval_outputs.check()?; + self.folded_eval_blindings.check() + } +} + +impl> CanonicalDeserialize for BlindFoldProof { + fn deserialize_with_mode( + mut reader: R, + compress: ark_serialize::Compress, + validate: ark_serialize::Validate, + ) -> Result { + let proof = Self { + random_instance: RelaxedR1CSInstance::deserialize_with_mode( + &mut reader, + compress, + validate, + )?, + noncoeff_row_commitments: deserialize_bounded_vec( + &mut reader, + compress, + validate, + MAX_BLINDFOLD_VECTOR_LEN, + )?, + cross_term_row_commitments: deserialize_bounded_vec( + &mut reader, + compress, + validate, + MAX_BLINDFOLD_VECTOR_LEN, + )?, + spartan_proof: deserialize_bounded_vec( + &mut reader, + compress, + validate, + MAX_SUMCHECK_ROUNDS, + )?, + az_r: F::deserialize_with_mode(&mut reader, compress, validate)?, + bz_r: F::deserialize_with_mode(&mut reader, compress, validate)?, + cz_r: F::deserialize_with_mode(&mut reader, compress, validate)?, + inner_sumcheck_proof: deserialize_bounded_vec( + &mut reader, + compress, + validate, + MAX_SUMCHECK_ROUNDS, + )?, + w_opening: HyraxOpeningProof::deserialize_with_mode(&mut reader, compress, validate)?, + e_opening: HyraxOpeningProof::deserialize_with_mode(&mut reader, compress, validate)?, + folded_eval_outputs: deserialize_bounded_vec( + &mut reader, + compress, + validate, + MAX_OPENING_CLAIMS, + )?, + folded_eval_blindings: deserialize_bounded_vec( + &mut reader, + compress, + validate, + MAX_OPENING_CLAIMS, + )?, + }; + if validate == ark_serialize::Validate::Yes { + proof.check()?; + } + Ok(proof) + } +} + pub struct BlindFoldProver<'a, F: JoltField, C: JoltCurve> { gens: &'a PedersenGenerators, r1cs: &'a VerifierR1CS, diff --git a/jolt-core/src/subprotocols/blindfold/relaxed_r1cs.rs b/jolt-core/src/subprotocols/blindfold/relaxed_r1cs.rs index c6026179c7..80cf8b0eb9 100644 --- a/jolt-core/src/subprotocols/blindfold/relaxed_r1cs.rs +++ b/jolt-core/src/subprotocols/blindfold/relaxed_r1cs.rs @@ -17,7 +17,11 @@ use crate::curve::{JoltCurve, JoltGroupElement}; use crate::field::JoltField; -use ark_serialize::{CanonicalDeserialize, CanonicalSerialize}; +use crate::utils::serialization::{ + deserialize_bounded_vec, serialize_vec_with_len, serialized_vec_with_len_size, + MAX_BLINDFOLD_VECTOR_LEN, +}; +use ark_serialize::{CanonicalDeserialize, CanonicalSerialize, Valid}; use rayon::prelude::*; use super::protocol::BlindFoldVerifyError; @@ -30,7 +34,7 @@ use super::r1cs::VerifierR1CS; /// - `round_commitments`: coefficient row commitments (reuse existing sumcheck round commitments) /// - `noncoeff_row_commitments`: non-coefficient row commitments (prover sends in proof) /// - `e_row_commitments`: E row commitments (derived from cross-term and random instance) -#[derive(Clone, Debug, CanonicalSerialize, CanonicalDeserialize)] +#[derive(Clone, Debug)] pub struct RelaxedR1CSInstance> { pub u: F, /// Per-round commitments from ZK sumcheck (= coefficient row commitments) @@ -50,7 +54,7 @@ pub struct RelaxedR1CSInstance> { /// /// W is in grid layout (R' × C). Row blindings cover all W rows. /// E is flat but has per-row blindings for Hyrax opening. -#[derive(Clone, Debug, CanonicalSerialize, CanonicalDeserialize)] +#[derive(Clone, Debug)] pub struct RelaxedR1CSWitness { /// Error vector (zeros for non-relaxed) pub E: Vec, @@ -62,6 +66,145 @@ pub struct RelaxedR1CSWitness { pub e_row_blindings: Vec, } +impl> CanonicalSerialize for RelaxedR1CSInstance { + fn serialize_with_mode( + &self, + mut writer: W, + compress: ark_serialize::Compress, + ) -> Result<(), ark_serialize::SerializationError> { + self.u.serialize_with_mode(&mut writer, compress)?; + serialize_vec_with_len(&self.round_commitments, &mut writer, compress)?; + serialize_vec_with_len(&self.output_claims_row_commitments, &mut writer, compress)?; + serialize_vec_with_len(&self.noncoeff_row_commitments, &mut writer, compress)?; + serialize_vec_with_len(&self.e_row_commitments, &mut writer, compress)?; + serialize_vec_with_len(&self.eval_commitments, writer, compress) + } + + fn serialized_size(&self, compress: ark_serialize::Compress) -> usize { + self.u.serialized_size(compress) + + serialized_vec_with_len_size(&self.round_commitments, compress) + + serialized_vec_with_len_size(&self.output_claims_row_commitments, compress) + + serialized_vec_with_len_size(&self.noncoeff_row_commitments, compress) + + serialized_vec_with_len_size(&self.e_row_commitments, compress) + + serialized_vec_with_len_size(&self.eval_commitments, compress) + } +} + +impl> ark_serialize::Valid for RelaxedR1CSInstance { + fn check(&self) -> Result<(), ark_serialize::SerializationError> { + self.u.check()?; + self.round_commitments.check()?; + self.output_claims_row_commitments.check()?; + self.noncoeff_row_commitments.check()?; + self.e_row_commitments.check()?; + self.eval_commitments.check() + } +} + +impl> CanonicalDeserialize for RelaxedR1CSInstance { + fn deserialize_with_mode( + mut reader: R, + compress: ark_serialize::Compress, + validate: ark_serialize::Validate, + ) -> Result { + let instance = Self { + u: F::deserialize_with_mode(&mut reader, compress, validate)?, + round_commitments: deserialize_bounded_vec( + &mut reader, + compress, + validate, + MAX_BLINDFOLD_VECTOR_LEN, + )?, + output_claims_row_commitments: deserialize_bounded_vec( + &mut reader, + compress, + validate, + MAX_BLINDFOLD_VECTOR_LEN, + )?, + noncoeff_row_commitments: deserialize_bounded_vec( + &mut reader, + compress, + validate, + MAX_BLINDFOLD_VECTOR_LEN, + )?, + e_row_commitments: deserialize_bounded_vec( + &mut reader, + compress, + validate, + MAX_BLINDFOLD_VECTOR_LEN, + )?, + eval_commitments: deserialize_bounded_vec( + &mut reader, + compress, + validate, + MAX_BLINDFOLD_VECTOR_LEN, + )?, + }; + if validate == ark_serialize::Validate::Yes { + instance.check()?; + } + Ok(instance) + } +} + +impl CanonicalSerialize for RelaxedR1CSWitness { + fn serialize_with_mode( + &self, + mut writer: W, + compress: ark_serialize::Compress, + ) -> Result<(), ark_serialize::SerializationError> { + serialize_vec_with_len(&self.E, &mut writer, compress)?; + serialize_vec_with_len(&self.W, &mut writer, compress)?; + serialize_vec_with_len(&self.w_row_blindings, &mut writer, compress)?; + serialize_vec_with_len(&self.e_row_blindings, writer, compress) + } + + fn serialized_size(&self, compress: ark_serialize::Compress) -> usize { + serialized_vec_with_len_size(&self.E, compress) + + serialized_vec_with_len_size(&self.W, compress) + + serialized_vec_with_len_size(&self.w_row_blindings, compress) + + serialized_vec_with_len_size(&self.e_row_blindings, compress) + } +} + +impl ark_serialize::Valid for RelaxedR1CSWitness { + fn check(&self) -> Result<(), ark_serialize::SerializationError> { + self.E.check()?; + self.W.check()?; + self.w_row_blindings.check()?; + self.e_row_blindings.check() + } +} + +impl CanonicalDeserialize for RelaxedR1CSWitness { + fn deserialize_with_mode( + mut reader: R, + compress: ark_serialize::Compress, + validate: ark_serialize::Validate, + ) -> Result { + let witness = Self { + E: deserialize_bounded_vec(&mut reader, compress, validate, MAX_BLINDFOLD_VECTOR_LEN)?, + W: deserialize_bounded_vec(&mut reader, compress, validate, MAX_BLINDFOLD_VECTOR_LEN)?, + w_row_blindings: deserialize_bounded_vec( + &mut reader, + compress, + validate, + MAX_BLINDFOLD_VECTOR_LEN, + )?, + e_row_blindings: deserialize_bounded_vec( + &mut reader, + compress, + validate, + MAX_BLINDFOLD_VECTOR_LEN, + )?, + }; + if validate == ark_serialize::Validate::Yes { + witness.check()?; + } + Ok(witness) + } +} + impl> RelaxedR1CSInstance { #[allow(clippy::too_many_arguments)] pub fn new_non_relaxed( diff --git a/jolt-core/src/subprotocols/sumcheck.rs b/jolt-core/src/subprotocols/sumcheck.rs index 803486ec19..5b0285ec85 100644 --- a/jolt-core/src/subprotocols/sumcheck.rs +++ b/jolt-core/src/subprotocols/sumcheck.rs @@ -15,6 +15,10 @@ use crate::transcripts::Transcript; use crate::utils::errors::ProofVerifyError; #[cfg(not(target_arch = "wasm32"))] use crate::utils::profiling::print_current_memory_usage; +use crate::utils::serialization::{ + deserialize_bounded_vec, serialize_vec_with_len, serialized_vec_with_len_size, + MAX_OPENING_CLAIMS, MAX_SUMCHECK_ROUNDS, +}; use ark_serialize::*; #[cfg(feature = "zk")] @@ -547,12 +551,58 @@ impl BatchedSumcheck { /// Clear (non-ZK) sumcheck proof - coefficients visible to verifier. /// Used in non-ZK mode where the verifier evaluates polynomials directly. -#[derive(CanonicalSerialize, CanonicalDeserialize, Debug, Clone)] +#[derive(Debug, Clone)] pub struct ClearSumcheckProof { pub compressed_polys: Vec>, _marker: PhantomData, } +impl CanonicalSerialize + for ClearSumcheckProof +{ + fn serialize_with_mode( + &self, + writer: W, + compress: Compress, + ) -> Result<(), SerializationError> { + serialize_vec_with_len(&self.compressed_polys, writer, compress) + } + + fn serialized_size(&self, compress: Compress) -> usize { + serialized_vec_with_len_size(&self.compressed_polys, compress) + } +} + +impl Valid for ClearSumcheckProof { + fn check(&self) -> Result<(), SerializationError> { + self.compressed_polys.check() + } +} + +impl CanonicalDeserialize + for ClearSumcheckProof +{ + fn deserialize_with_mode( + reader: R, + compress: Compress, + validate: Validate, + ) -> Result { + let proof = Self { + compressed_polys: deserialize_bounded_vec( + reader, + compress, + validate, + MAX_SUMCHECK_ROUNDS, + )?, + _marker: PhantomData, + }; + if validate == Validate::Yes { + proof.check()?; + } + Ok(proof) + } +} + impl ClearSumcheckProof { pub fn new(compressed_polys: Vec>) -> Self { Self { @@ -579,6 +629,14 @@ impl ClearSumcheckProof degree_bound { return Err(ProofVerifyError::InvalidInputLength( degree_bound, @@ -658,10 +716,11 @@ impl, ProofTranscript: Transcript> CanonicalDe validate: ark_serialize::Validate, ) -> Result { let round_commitments = - Vec::::deserialize_with_mode(&mut reader, compress, validate)?; - let poly_degrees = Vec::::deserialize_with_mode(&mut reader, compress, validate)?; + deserialize_bounded_vec(&mut reader, compress, validate, MAX_SUMCHECK_ROUNDS)?; + let poly_degrees = + deserialize_bounded_vec(&mut reader, compress, validate, MAX_SUMCHECK_ROUNDS)?; let output_claims_commitments = - Vec::::deserialize_with_mode(reader, compress, validate)?; + deserialize_bounded_vec(reader, compress, validate, MAX_OPENING_CLAIMS)?; Ok(Self { round_commitments, poly_degrees, diff --git a/jolt-core/src/subprotocols/univariate_skip.rs b/jolt-core/src/subprotocols/univariate_skip.rs index 64ba1224dd..7518a55c75 100644 --- a/jolt-core/src/subprotocols/univariate_skip.rs +++ b/jolt-core/src/subprotocols/univariate_skip.rs @@ -1,6 +1,8 @@ use std::marker::PhantomData; -use ark_serialize::{CanonicalDeserialize, CanonicalSerialize}; +use ark_serialize::{ + CanonicalDeserialize, CanonicalSerialize, Compress, SerializationError, Valid, Validate, +}; #[cfg(feature = "zk")] use rand_core::CryptoRngCore; @@ -17,6 +19,10 @@ use crate::subprotocols::sumcheck_prover::SumcheckInstanceProver; use crate::subprotocols::sumcheck_verifier::SumcheckInstanceVerifier; use crate::transcripts::Transcript; use crate::utils::errors::ProofVerifyError; +use crate::utils::serialization::{ + deserialize_bounded_vec, serialize_vec_with_len, serialized_vec_with_len_size, + MAX_OPENING_CLAIMS, +}; /// Returns the interleaved symmetric univariate-skip target indices outside the base window. /// @@ -214,12 +220,49 @@ pub fn prove_uniskip_round_zk< /// The sumcheck proof for a univariate skip round /// Consists of the (single) univariate polynomial sent in that round, no omission of any coefficient -#[derive(CanonicalSerialize, CanonicalDeserialize, Debug, Clone)] +#[derive(Debug, Clone)] pub struct UniSkipFirstRoundProof { pub uni_poly: UniPoly, _marker: PhantomData, } +impl CanonicalSerialize for UniSkipFirstRoundProof { + fn serialize_with_mode( + &self, + mut writer: W, + compress: Compress, + ) -> Result<(), SerializationError> { + self.uni_poly.serialize_with_mode(&mut writer, compress) + } + + fn serialized_size(&self, compress: Compress) -> usize { + self.uni_poly.serialized_size(compress) + } +} + +impl Valid for UniSkipFirstRoundProof { + fn check(&self) -> Result<(), SerializationError> { + self.uni_poly.check() + } +} + +impl CanonicalDeserialize for UniSkipFirstRoundProof { + fn deserialize_with_mode( + reader: R, + compress: Compress, + validate: Validate, + ) -> Result { + let proof = Self { + uni_poly: UniPoly::deserialize_with_mode(reader, compress, validate)?, + _marker: PhantomData, + }; + if validate == Validate::Yes { + proof.check()?; + } + Ok(proof) + } +} + impl UniSkipFirstRoundProof { pub fn new(uni_poly: UniPoly) -> Self { Self { @@ -237,6 +280,11 @@ impl UniSkipFirstRoundProof { transcript: &mut T, ) -> Result { let degree_bound = sumcheck_instance.degree(); + if proof.uni_poly.coeffs.is_empty() { + return Err(ProofVerifyError::MalformedProof( + "empty uniskip polynomial".to_string(), + )); + } // Degree check for the high-degree first polynomial if proof.uni_poly.degree() > degree_bound { return Err(ProofVerifyError::InvalidInputLength( @@ -325,19 +373,18 @@ impl, T: Transcript> CanonicalSerialize fn serialize_with_mode( &self, mut writer: W, - compress: ark_serialize::Compress, - ) -> Result<(), ark_serialize::SerializationError> { + compress: Compress, + ) -> Result<(), SerializationError> { self.commitment.serialize_with_mode(&mut writer, compress)?; self.poly_degree .serialize_with_mode(&mut writer, compress)?; - self.output_claims_commitments - .serialize_with_mode(writer, compress) + serialize_vec_with_len(&self.output_claims_commitments, writer, compress) } - fn serialized_size(&self, compress: ark_serialize::Compress) -> usize { + fn serialized_size(&self, compress: Compress) -> usize { self.commitment.serialized_size(compress) + self.poly_degree.serialized_size(compress) - + self.output_claims_commitments.serialized_size(compress) + + serialized_vec_with_len_size(&self.output_claims_commitments, compress) } } @@ -346,13 +393,13 @@ impl, T: Transcript> CanonicalDeserialize { fn deserialize_with_mode( mut reader: R, - compress: ark_serialize::Compress, - validate: ark_serialize::Validate, - ) -> Result { + compress: Compress, + validate: Validate, + ) -> Result { let commitment = C::G1::deserialize_with_mode(&mut reader, compress, validate)?; let poly_degree = usize::deserialize_with_mode(&mut reader, compress, validate)?; let output_claims_commitments = - Vec::::deserialize_with_mode(reader, compress, validate)?; + deserialize_bounded_vec(reader, compress, validate, MAX_OPENING_CLAIMS)?; Ok(Self::new( commitment, poly_degree, @@ -364,7 +411,7 @@ impl, T: Transcript> CanonicalDeserialize impl, T: Transcript> ark_serialize::Valid for ZkUniSkipFirstRoundProof { - fn check(&self) -> Result<(), ark_serialize::SerializationError> { + fn check(&self) -> Result<(), SerializationError> { self.commitment.check()?; self.output_claims_commitments.check() } diff --git a/jolt-core/src/utils/errors.rs b/jolt-core/src/utils/errors.rs index 2ca3d05e99..af5bf39a62 100644 --- a/jolt-core/src/utils/errors.rs +++ b/jolt-core/src/utils/errors.rs @@ -42,4 +42,6 @@ pub enum ProofVerifyError { ZkFeatureRequired, #[error("BlindFold verification failed: {0}")] BlindFoldError(String), + #[error("Malformed proof: {0}")] + MalformedProof(String), } diff --git a/jolt-core/src/utils/mod.rs b/jolt-core/src/utils/mod.rs index 3669d0b6a8..e1260f4ad6 100644 --- a/jolt-core/src/utils/mod.rs +++ b/jolt-core/src/utils/mod.rs @@ -12,6 +12,7 @@ pub mod math; #[cfg(feature = "monitor")] pub mod monitor; pub mod profiling; +pub mod serialization; pub mod small_scalar; pub mod thread; diff --git a/jolt-core/src/utils/serialization.rs b/jolt-core/src/utils/serialization.rs new file mode 100644 index 0000000000..be8c60ea62 --- /dev/null +++ b/jolt-core/src/utils/serialization.rs @@ -0,0 +1,73 @@ +use ark_serialize::{ + CanonicalDeserialize, CanonicalSerialize, Compress, SerializationError, Validate, +}; +use std::io::{Read, Write}; + +pub const MAX_JOLT_COMMITMENTS: usize = 1 << 12; +pub const MAX_OPENING_CLAIMS: usize = 1 << 16; +pub const MAX_SUMCHECK_ROUNDS: usize = 1 << 16; +pub const MAX_UNIPOLY_COEFFS: usize = 1 << 16; +pub const MAX_BLINDFOLD_VECTOR_LEN: usize = 1 << 20; + +pub fn serialize_vec_with_len( + values: &[T], + mut writer: W, + compress: Compress, +) -> Result<(), SerializationError> { + values.len().serialize_with_mode(&mut writer, compress)?; + for value in values { + value.serialize_with_mode(&mut writer, compress)?; + } + Ok(()) +} + +pub fn serialized_vec_with_len_size( + values: &[T], + compress: Compress, +) -> usize { + values.len().serialized_size(compress) + + values + .iter() + .map(|value| value.serialized_size(compress)) + .sum::() +} + +pub fn deserialize_bounded_len( + mut reader: R, + compress: Compress, + validate: Validate, + max_len: usize, +) -> Result { + let len = usize::deserialize_with_mode(&mut reader, compress, validate)?; + if len > max_len { + return Err(SerializationError::InvalidData); + } + Ok(len) +} + +pub fn deserialize_bounded_vec( + mut reader: R, + compress: Compress, + validate: Validate, + max_len: usize, +) -> Result, SerializationError> { + let len = deserialize_bounded_len(&mut reader, compress, validate, max_len)?; + let mut values = Vec::with_capacity(len); + for _ in 0..len { + values.push(T::deserialize_with_mode(&mut reader, compress, validate)?); + } + Ok(values) +} + +pub fn deserialize_bounded_u32_len( + mut reader: R, + compress: Compress, + validate: Validate, + max_len: usize, +) -> Result { + let len = u32::deserialize_with_mode(&mut reader, compress, validate)? as usize; + if len > max_len { + return Err(SerializationError::InvalidData); + } + Ok(len) +} diff --git a/jolt-core/src/zkvm/proof_serialization.rs b/jolt-core/src/zkvm/proof_serialization.rs index 343881168b..fbff8b8051 100644 --- a/jolt-core/src/zkvm/proof_serialization.rs +++ b/jolt-core/src/zkvm/proof_serialization.rs @@ -12,6 +12,8 @@ use strum::EnumCount; use crate::poly::opening_proof::{OpeningPoint, Openings}; #[cfg(feature = "zk")] use crate::subprotocols::blindfold::BlindFoldProof; +#[cfg(not(feature = "zk"))] +use crate::utils::serialization::MAX_OPENING_CLAIMS; use crate::{ curve::JoltCurve, field::JoltField, @@ -25,6 +27,10 @@ use crate::{ sumcheck::SumcheckInstanceProof, univariate_skip::UniSkipFirstRoundProofVariant, }, transcripts::Transcript, + utils::serialization::{ + deserialize_bounded_vec, serialize_vec_with_len, serialized_vec_with_len_size, + MAX_JOLT_COMMITMENTS, + }, zkvm::{ config::{OneHotConfig, ReadWriteConfig}, instruction::{CircuitFlags, InstructionFlags}, @@ -32,7 +38,7 @@ use crate::{ }, }; -#[derive(CanonicalSerialize, CanonicalDeserialize)] +#[derive(Clone)] pub struct JoltProof< F: JoltField, C: JoltCurve, @@ -62,7 +68,218 @@ pub struct JoltProof< pub dory_layout: DoryLayout, } +impl, PCS: CommitmentScheme, FS: Transcript> + CanonicalSerialize for JoltProof +{ + fn serialize_with_mode( + &self, + mut writer: W, + compress: Compress, + ) -> Result<(), SerializationError> { + serialize_vec_with_len(&self.commitments, &mut writer, compress)?; + self.stage1_uni_skip_first_round_proof + .serialize_with_mode(&mut writer, compress)?; + self.stage1_sumcheck_proof + .serialize_with_mode(&mut writer, compress)?; + self.stage2_uni_skip_first_round_proof + .serialize_with_mode(&mut writer, compress)?; + self.stage2_sumcheck_proof + .serialize_with_mode(&mut writer, compress)?; + self.stage3_sumcheck_proof + .serialize_with_mode(&mut writer, compress)?; + self.stage4_sumcheck_proof + .serialize_with_mode(&mut writer, compress)?; + self.stage5_sumcheck_proof + .serialize_with_mode(&mut writer, compress)?; + self.stage6_sumcheck_proof + .serialize_with_mode(&mut writer, compress)?; + self.stage7_sumcheck_proof + .serialize_with_mode(&mut writer, compress)?; + #[cfg(feature = "zk")] + self.blindfold_proof + .serialize_with_mode(&mut writer, compress)?; + self.joint_opening_proof + .serialize_with_mode(&mut writer, compress)?; + self.untrusted_advice_commitment + .serialize_with_mode(&mut writer, compress)?; + #[cfg(not(feature = "zk"))] + self.opening_claims + .serialize_with_mode(&mut writer, compress)?; + self.trace_length + .serialize_with_mode(&mut writer, compress)?; + self.ram_K.serialize_with_mode(&mut writer, compress)?; + self.rw_config.serialize_with_mode(&mut writer, compress)?; + self.one_hot_config + .serialize_with_mode(&mut writer, compress)?; + self.dory_layout.serialize_with_mode(writer, compress) + } + + fn serialized_size(&self, compress: Compress) -> usize { + serialized_vec_with_len_size(&self.commitments, compress) + + self + .stage1_uni_skip_first_round_proof + .serialized_size(compress) + + self.stage1_sumcheck_proof.serialized_size(compress) + + self + .stage2_uni_skip_first_round_proof + .serialized_size(compress) + + self.stage2_sumcheck_proof.serialized_size(compress) + + self.stage3_sumcheck_proof.serialized_size(compress) + + self.stage4_sumcheck_proof.serialized_size(compress) + + self.stage5_sumcheck_proof.serialized_size(compress) + + self.stage6_sumcheck_proof.serialized_size(compress) + + self.stage7_sumcheck_proof.serialized_size(compress) + + { + #[cfg(feature = "zk")] + { + self.blindfold_proof.serialized_size(compress) + } + #[cfg(not(feature = "zk"))] + { + 0 + } + } + + self.joint_opening_proof.serialized_size(compress) + + self.untrusted_advice_commitment.serialized_size(compress) + + { + #[cfg(not(feature = "zk"))] + { + self.opening_claims.serialized_size(compress) + } + #[cfg(feature = "zk")] + { + 0 + } + } + + self.trace_length.serialized_size(compress) + + self.ram_K.serialized_size(compress) + + self.rw_config.serialized_size(compress) + + self.one_hot_config.serialized_size(compress) + + self.dory_layout.serialized_size(compress) + } +} + +impl, PCS: CommitmentScheme, FS: Transcript> Valid + for JoltProof +{ + fn check(&self) -> Result<(), SerializationError> { + self.commitments.check()?; + self.stage1_uni_skip_first_round_proof.check()?; + self.stage1_sumcheck_proof.check()?; + self.stage2_uni_skip_first_round_proof.check()?; + self.stage2_sumcheck_proof.check()?; + self.stage3_sumcheck_proof.check()?; + self.stage4_sumcheck_proof.check()?; + self.stage5_sumcheck_proof.check()?; + self.stage6_sumcheck_proof.check()?; + self.stage7_sumcheck_proof.check()?; + #[cfg(feature = "zk")] + self.blindfold_proof.check()?; + self.joint_opening_proof.check()?; + self.untrusted_advice_commitment.check()?; + #[cfg(not(feature = "zk"))] + self.opening_claims.check()?; + self.rw_config.check()?; + self.one_hot_config.check()?; + self.dory_layout.check() + } +} + +impl, PCS: CommitmentScheme, FS: Transcript> + CanonicalDeserialize for JoltProof +{ + fn deserialize_with_mode( + mut reader: R, + compress: Compress, + validate: Validate, + ) -> Result { + let proof = Self { + commitments: deserialize_bounded_vec( + &mut reader, + compress, + validate, + MAX_JOLT_COMMITMENTS, + )?, + stage1_uni_skip_first_round_proof: + UniSkipFirstRoundProofVariant::deserialize_with_mode( + &mut reader, + compress, + validate, + )?, + stage1_sumcheck_proof: SumcheckInstanceProof::deserialize_with_mode( + &mut reader, + compress, + validate, + )?, + stage2_uni_skip_first_round_proof: + UniSkipFirstRoundProofVariant::deserialize_with_mode( + &mut reader, + compress, + validate, + )?, + stage2_sumcheck_proof: SumcheckInstanceProof::deserialize_with_mode( + &mut reader, + compress, + validate, + )?, + stage3_sumcheck_proof: SumcheckInstanceProof::deserialize_with_mode( + &mut reader, + compress, + validate, + )?, + stage4_sumcheck_proof: SumcheckInstanceProof::deserialize_with_mode( + &mut reader, + compress, + validate, + )?, + stage5_sumcheck_proof: SumcheckInstanceProof::deserialize_with_mode( + &mut reader, + compress, + validate, + )?, + stage6_sumcheck_proof: SumcheckInstanceProof::deserialize_with_mode( + &mut reader, + compress, + validate, + )?, + stage7_sumcheck_proof: SumcheckInstanceProof::deserialize_with_mode( + &mut reader, + compress, + validate, + )?, + #[cfg(feature = "zk")] + blindfold_proof: BlindFoldProof::deserialize_with_mode( + &mut reader, + compress, + validate, + )?, + joint_opening_proof: PCS::Proof::deserialize_with_mode( + &mut reader, + compress, + validate, + )?, + untrusted_advice_commitment: Option::::deserialize_with_mode( + &mut reader, + compress, + validate, + )?, + #[cfg(not(feature = "zk"))] + opening_claims: Claims::deserialize_with_mode(&mut reader, compress, validate)?, + trace_length: usize::deserialize_with_mode(&mut reader, compress, validate)?, + ram_K: usize::deserialize_with_mode(&mut reader, compress, validate)?, + rw_config: ReadWriteConfig::deserialize_with_mode(&mut reader, compress, validate)?, + one_hot_config: OneHotConfig::deserialize_with_mode(&mut reader, compress, validate)?, + dory_layout: DoryLayout::deserialize_with_mode(&mut reader, compress, validate)?, + }; + if validate == Validate::Yes { + proof.check()?; + } + Ok(proof) + } +} + #[cfg(not(feature = "zk"))] +#[derive(Clone)] pub struct Claims(pub Openings); #[cfg(not(feature = "zk"))] @@ -93,7 +310,12 @@ impl CanonicalSerialize for Claims { #[cfg(not(feature = "zk"))] impl Valid for Claims { fn check(&self) -> Result<(), SerializationError> { - Ok(()) + self.0 + .iter() + .try_for_each(|(id, (_point, claim))| -> Result<(), SerializationError> { + id.check()?; + claim.check() + }) } } @@ -105,13 +327,25 @@ impl CanonicalDeserialize for Claims { validate: Validate, ) -> Result { let size = usize::deserialize_with_mode(&mut reader, compress, validate)?; + if size > MAX_OPENING_CLAIMS { + return Err(SerializationError::InvalidData); + } let mut claims = BTreeMap::new(); for _ in 0..size { let key = OpeningId::deserialize_with_mode(&mut reader, compress, validate)?; let claim = F::deserialize_with_mode(&mut reader, compress, validate)?; - claims.insert(key, (OpeningPoint::default(), claim)); + if claims + .insert(key, (OpeningPoint::default(), claim)) + .is_some() + { + return Err(SerializationError::InvalidData); + } } - Ok(Claims(claims)) + let claims = Claims(claims); + if validate == Validate::Yes { + claims.check()?; + } + Ok(claims) } } @@ -515,3 +749,34 @@ pub fn serialize_and_print_size( tracing::info!("{item_name} size: {file_size_kb:.1} kB"); Ok(()) } + +#[cfg(all(test, not(feature = "zk")))] +mod tests { + use super::*; + use ark_bn254::Fr; + + #[test] + fn claims_reject_duplicate_keys() { + let key = OpeningId::committed( + CommittedPolynomial::InstructionRa(0), + SumcheckId::HammingWeightClaimReduction, + ); + let mut bytes = Vec::new(); + 2usize.serialize_compressed(&mut bytes).unwrap(); + key.serialize_compressed(&mut bytes).unwrap(); + Fr::from(1u64).serialize_compressed(&mut bytes).unwrap(); + key.serialize_compressed(&mut bytes).unwrap(); + Fr::from(1u64).serialize_compressed(&mut bytes).unwrap(); + + assert!(Claims::::deserialize_compressed(&bytes[..]).is_err()); + } + + #[test] + fn claims_reject_oversized_length_prefix() { + let mut bytes = Vec::new(); + (MAX_OPENING_CLAIMS + 1) + .serialize_compressed(&mut bytes) + .unwrap(); + assert!(Claims::::deserialize_compressed(&bytes[..]).is_err()); + } +} From fe8c1f9ad8aea6fc89b81e874780b5d43fb344b3 Mon Sep 17 00:00:00 2001 From: Quang Dao Date: Sat, 28 Mar 2026 17:28:37 -0400 Subject: [PATCH 2/2] Fix compressed linear unipoly recovery --- jolt-core/src/poly/unipoly.rs | 40 +++++++++++++++++++++++------------ 1 file changed, 26 insertions(+), 14 deletions(-) diff --git a/jolt-core/src/poly/unipoly.rs b/jolt-core/src/poly/unipoly.rs index be45885bbe..e206a5b1ea 100644 --- a/jolt-core/src/poly/unipoly.rs +++ b/jolt-core/src/poly/unipoly.rs @@ -575,18 +575,23 @@ impl MulAssign<&F> for UniPoly { } impl CompressedUniPoly { + fn recover_linear_term(&self, hint: &F) -> F { + let constant_term = self.coeffs_except_linear_term[0]; + let mut linear_term = *hint - constant_term - constant_term; + for coeff in &self.coeffs_except_linear_term[1..] { + linear_term -= *coeff; + } + linear_term + } + // we require eval(0) + eval(1) = hint, so we can solve for the linear term as: // linear_term = hint - 2 * constant_term - deg2 term - deg3 term pub fn decompress(&self, hint: &F) -> UniPoly { debug_assert!(!self.coeffs_except_linear_term.is_empty()); - if self.coeffs_except_linear_term.len() == 1 { + let linear_term = self.recover_linear_term(hint); + if self.coeffs_except_linear_term.len() == 1 && linear_term.is_zero() { return UniPoly::from_coeff(vec![self.coeffs_except_linear_term[0]]); } - let mut linear_term = - *hint - self.coeffs_except_linear_term[0] - self.coeffs_except_linear_term[0]; - for i in 1..self.coeffs_except_linear_term.len() { - linear_term -= self.coeffs_except_linear_term[i]; - } let mut coeffs = vec![self.coeffs_except_linear_term[0], linear_term]; coeffs.extend(&self.coeffs_except_linear_term[1..]); @@ -598,14 +603,7 @@ impl CompressedUniPoly { // recover the linear term assuming the prover did it right, then eval the poly pub fn eval_from_hint(&self, hint: &F, x: &F::Challenge) -> F { debug_assert!(!self.coeffs_except_linear_term.is_empty()); - if self.coeffs_except_linear_term.len() == 1 { - return self.coeffs_except_linear_term[0]; - } - let mut linear_term = - *hint - self.coeffs_except_linear_term[0] - self.coeffs_except_linear_term[0]; - for i in 1..self.coeffs_except_linear_term.len() { - linear_term -= self.coeffs_except_linear_term[i]; - } + let linear_term = self.recover_linear_term(hint); let mut running_point: F = (*x).into(); let mut running_sum = self.coeffs_except_linear_term[0] + *x * linear_term; @@ -685,6 +683,20 @@ mod tests { fn test_from_evals_cubic() { test_from_evals_cubic_helper::() } + + #[test] + fn test_compressed_linear_round_trip() { + let poly = UniPoly::::from_coeff(vec![Fr::from_u64(5), Fr::from_u64(7)]); + let hint = poly.eval_at_zero() + poly.eval_at_one(); + let compressed_poly = poly.compress(); + let decompressed_poly = compressed_poly.decompress(&hint); + + assert_eq!(decompressed_poly.coeffs, poly.coeffs); + + let x = ::Challenge::from(9u128); + assert_eq!(compressed_poly.eval_from_hint(&hint, &x), poly.evaluate(&x)); + } + fn test_from_evals_cubic_helper() { // polynomial is x^3 + 2x^2 + 3x + 1 let e0 = F::one();