From 82ce40665a50de21e933459ac8c7a145b95b79ba Mon Sep 17 00:00:00 2001 From: Quang Dao Date: Mon, 16 Feb 2026 12:14:25 -0500 Subject: [PATCH 1/9] feat(zkvm): add framed transport and robust proof serialization Introduce a self-describing framed encoding for proofs and the minimal transport helpers. This removes dependence on brittle enum counts and enables strict parsing with length caps. Co-authored-by: Cursor --- jolt-core/src/zkvm/mod.rs | 1 + jolt-core/src/zkvm/proof_serialization.rs | 759 ++++++++++++++++++++-- jolt-core/src/zkvm/transport.rs | 115 ++++ jolt-sdk/src/host_utils.rs | 1 + 4 files changed, 818 insertions(+), 58 deletions(-) create mode 100644 jolt-core/src/zkvm/transport.rs diff --git a/jolt-core/src/zkvm/mod.rs b/jolt-core/src/zkvm/mod.rs index 9aacba500d..0b75d524bc 100644 --- a/jolt-core/src/zkvm/mod.rs +++ b/jolt-core/src/zkvm/mod.rs @@ -33,6 +33,7 @@ pub mod r1cs; pub mod ram; pub mod registers; pub mod spartan; +pub mod transport; pub mod verifier; pub mod witness; diff --git a/jolt-core/src/zkvm/proof_serialization.rs b/jolt-core/src/zkvm/proof_serialization.rs index c6de7a1c60..0d24fc832d 100644 --- a/jolt-core/src/zkvm/proof_serialization.rs +++ b/jolt-core/src/zkvm/proof_serialization.rs @@ -7,7 +7,6 @@ use ark_serialize::{ CanonicalDeserialize, CanonicalSerialize, Compress, SerializationError, Valid, Validate, }; use num::FromPrimitive; -use strum::EnumCount; use crate::{ field::JoltField, @@ -27,7 +26,34 @@ use crate::{ poly::opening_proof::PolynomialId, subprotocols::univariate_skip::UniSkipFirstRoundProof, }; -#[derive(CanonicalSerialize, CanonicalDeserialize)] +use crate::zkvm::transport; + +/// Stream signature for `JoltProof` bytes (clean rewrite). +/// +/// This is a short fixed header to fail fast on wrong-format inputs. +const PROOF_SIGNATURE: &[u8; 8] = b"JOLTPRF\0"; + +// Frame tags for proof sections. Decoding is strict: unknown tags are rejected. +const TAG_PARAMS: u8 = 1; +const TAG_COMMITMENTS: u8 = 2; +const TAG_OPENING_CLAIMS: u8 = 3; +const TAG_STAGE1: u8 = 10; +const TAG_STAGE2: u8 = 11; +const TAG_STAGE3: u8 = 12; +const TAG_STAGE4: u8 = 13; +const TAG_STAGE5: u8 = 14; +const TAG_STAGE6: u8 = 15; +const TAG_STAGE7: u8 = 16; +const TAG_JOINT_OPENING: u8 = 20; + +// Per-section payload caps (DoS resistance). These can be tuned if legitimate proofs grow. +const MAX_PARAMS_LEN: u64 = 16 * 1024; +const MAX_COMMITMENTS_LEN: u64 = 256 * 1024 * 1024; +const MAX_OPENING_CLAIMS_LEN: u64 = 256 * 1024 * 1024; +const MAX_STAGE_LEN: u64 = 512 * 1024 * 1024; +const MAX_JOINT_OPENING_LEN: u64 = 512 * 1024 * 1024; +const MAX_ANY_SECTION_LEN: u64 = 512 * 1024 * 1024; + pub struct JoltProof, FS: Transcript> { pub opening_claims: Claims, pub commitments: Vec, @@ -50,6 +76,500 @@ pub struct JoltProof, FS: Transcr pub dory_layout: DoryLayout, } +#[inline] +fn io_err(e: std::io::Error) -> SerializationError { + SerializationError::IoError(e) +} + +#[inline] +fn write_u8(w: &mut W, b: u8) -> Result<(), SerializationError> { + w.write_all(&[b]).map_err(io_err) +} + +#[inline] +fn write_varint_u64(w: &mut W, x: u64) -> Result<(), SerializationError> { + transport::write_varint_u64(w, x).map_err(io_err) +} + +#[inline] +fn read_varint_u64(r: &mut R) -> Result { + transport::read_varint_u64(r).map_err(io_err) +} + +#[inline] +fn write_frame_header(w: &mut W, tag: u8, len: u64) -> Result<(), SerializationError> { + transport::write_frame_header(w, tag, len).map_err(io_err) +} + +#[inline] +fn read_frame_header( + r: &mut R, + max_len: u64, +) -> Result, SerializationError> { + transport::read_frame_header(r, max_len).map_err(io_err) +} + +#[inline] +fn section_cap_for_tag(tag: u8) -> u64 { + match tag { + TAG_PARAMS => MAX_PARAMS_LEN, + TAG_COMMITMENTS => MAX_COMMITMENTS_LEN, + TAG_OPENING_CLAIMS => MAX_OPENING_CLAIMS_LEN, + TAG_STAGE1 | TAG_STAGE2 | TAG_STAGE3 | TAG_STAGE4 | TAG_STAGE5 | TAG_STAGE6 + | TAG_STAGE7 => MAX_STAGE_LEN, + TAG_JOINT_OPENING => MAX_JOINT_OPENING_LEN, + _ => 0, + } +} + +impl, FS: Transcript> CanonicalSerialize + for JoltProof +{ + fn serialize_with_mode( + &self, + mut writer: W, + compress: Compress, + ) -> Result<(), SerializationError> { + transport::signature_write(&mut writer, PROOF_SIGNATURE).map_err(io_err)?; + + // ---------------- Params ---------------- + let params_len = (transport::varint_u64_len(self.trace_length as u64) + + transport::varint_u64_len(self.ram_K as u64) + + transport::varint_u64_len(self.bytecode_K as u64)) as u64 + + self.rw_config.serialized_size(compress) as u64 + + self.one_hot_config.serialized_size(compress) as u64 + + self.dory_layout.serialized_size(compress) as u64; + write_frame_header(&mut writer, TAG_PARAMS, params_len)?; + write_varint_u64(&mut writer, self.trace_length as u64)?; + write_varint_u64(&mut writer, self.ram_K as u64)?; + write_varint_u64(&mut writer, self.bytecode_K as u64)?; + self.rw_config.serialize_with_mode(&mut writer, compress)?; + self.one_hot_config + .serialize_with_mode(&mut writer, compress)?; + self.dory_layout + .serialize_with_mode(&mut writer, compress)?; + + // ---------------- Commitments ---------------- + let commitments_count_len = transport::varint_u64_len(self.commitments.len() as u64) as u64; + let commitments_items_len: u64 = self + .commitments + .iter() + .map(|c| c.serialized_size(compress) as u64) + .sum(); + let untrusted_commitment_len = self + .untrusted_advice_commitment + .as_ref() + .map(|c| c.serialized_size(compress) as u64) + .unwrap_or(0); + let commitments_len = + commitments_count_len + commitments_items_len + 1 + untrusted_commitment_len; + write_frame_header(&mut writer, TAG_COMMITMENTS, commitments_len)?; + write_varint_u64(&mut writer, self.commitments.len() as u64)?; + for c in &self.commitments { + c.serialize_with_mode(&mut writer, compress)?; + } + match &self.untrusted_advice_commitment { + None => write_u8(&mut writer, 0)?, + Some(c) => { + write_u8(&mut writer, 1)?; + c.serialize_with_mode(&mut writer, compress)?; + } + } + + // ---------------- Opening claims ---------------- + let claims_count_len = transport::varint_u64_len(self.opening_claims.0.len() as u64) as u64; + let claims_items_len: u64 = self + .opening_claims + .0 + .iter() + .map(|(k, (_p, claim))| { + (k.serialized_size(compress) + claim.serialized_size(compress)) as u64 + }) + .sum(); + let claims_len = claims_count_len + claims_items_len; + write_frame_header(&mut writer, TAG_OPENING_CLAIMS, claims_len)?; + write_varint_u64(&mut writer, self.opening_claims.0.len() as u64)?; + for (k, (_p, claim)) in self.opening_claims.0.iter() { + k.serialize_with_mode(&mut writer, compress)?; + claim.serialize_with_mode(&mut writer, compress)?; + } + + // ---------------- Stages ---------------- + let stage1_len = self + .stage1_uni_skip_first_round_proof + .serialized_size(compress) as u64 + + self.stage1_sumcheck_proof.serialized_size(compress) as u64; + write_frame_header(&mut writer, TAG_STAGE1, stage1_len)?; + self.stage1_uni_skip_first_round_proof + .serialize_with_mode(&mut writer, compress)?; + self.stage1_sumcheck_proof + .serialize_with_mode(&mut writer, compress)?; + + let stage2_len = self + .stage2_uni_skip_first_round_proof + .serialized_size(compress) as u64 + + self.stage2_sumcheck_proof.serialized_size(compress) as u64; + write_frame_header(&mut writer, TAG_STAGE2, stage2_len)?; + self.stage2_uni_skip_first_round_proof + .serialize_with_mode(&mut writer, compress)?; + self.stage2_sumcheck_proof + .serialize_with_mode(&mut writer, compress)?; + + let stage3_len = self.stage3_sumcheck_proof.serialized_size(compress) as u64; + write_frame_header(&mut writer, TAG_STAGE3, stage3_len)?; + self.stage3_sumcheck_proof + .serialize_with_mode(&mut writer, compress)?; + + let stage4_len = self.stage4_sumcheck_proof.serialized_size(compress) as u64; + write_frame_header(&mut writer, TAG_STAGE4, stage4_len)?; + self.stage4_sumcheck_proof + .serialize_with_mode(&mut writer, compress)?; + + let stage5_len = self.stage5_sumcheck_proof.serialized_size(compress) as u64; + write_frame_header(&mut writer, TAG_STAGE5, stage5_len)?; + self.stage5_sumcheck_proof + .serialize_with_mode(&mut writer, compress)?; + + let stage6_len = self.stage6_sumcheck_proof.serialized_size(compress) as u64; + write_frame_header(&mut writer, TAG_STAGE6, stage6_len)?; + self.stage6_sumcheck_proof + .serialize_with_mode(&mut writer, compress)?; + + let stage7_len = self.stage7_sumcheck_proof.serialized_size(compress) as u64; + write_frame_header(&mut writer, TAG_STAGE7, stage7_len)?; + self.stage7_sumcheck_proof + .serialize_with_mode(&mut writer, compress)?; + + // ---------------- Joint opening proof ---------------- + let joint_len = self.joint_opening_proof.serialized_size(compress) as u64; + write_frame_header(&mut writer, TAG_JOINT_OPENING, joint_len)?; + self.joint_opening_proof + .serialize_with_mode(&mut writer, compress)?; + + Ok(()) + } + + fn serialized_size(&self, compress: Compress) -> usize { + let mut size = PROOF_SIGNATURE.len(); + + let params_len = (transport::varint_u64_len(self.trace_length as u64) + + transport::varint_u64_len(self.ram_K as u64) + + transport::varint_u64_len(self.bytecode_K as u64)) as u64 + + self.rw_config.serialized_size(compress) as u64 + + self.one_hot_config.serialized_size(compress) as u64 + + self.dory_layout.serialized_size(compress) as u64; + size += 1 + transport::varint_u64_len(params_len) + params_len as usize; + + let commitments_count_len = transport::varint_u64_len(self.commitments.len() as u64) as u64; + let commitments_items_len: u64 = self + .commitments + .iter() + .map(|c| c.serialized_size(compress) as u64) + .sum(); + let untrusted_commitment_len = self + .untrusted_advice_commitment + .as_ref() + .map(|c| c.serialized_size(compress) as u64) + .unwrap_or(0); + let commitments_len = + commitments_count_len + commitments_items_len + 1 + untrusted_commitment_len; + size += 1 + transport::varint_u64_len(commitments_len) + commitments_len as usize; + + let claims_count_len = transport::varint_u64_len(self.opening_claims.0.len() as u64) as u64; + let claims_items_len: u64 = self + .opening_claims + .0 + .iter() + .map(|(k, (_p, claim))| { + (k.serialized_size(compress) + claim.serialized_size(compress)) as u64 + }) + .sum(); + let claims_len = claims_count_len + claims_items_len; + size += 1 + transport::varint_u64_len(claims_len) + claims_len as usize; + + let stage1_len = self + .stage1_uni_skip_first_round_proof + .serialized_size(compress) as u64 + + self.stage1_sumcheck_proof.serialized_size(compress) as u64; + size += 1 + transport::varint_u64_len(stage1_len) + stage1_len as usize; + + let stage2_len = self + .stage2_uni_skip_first_round_proof + .serialized_size(compress) as u64 + + self.stage2_sumcheck_proof.serialized_size(compress) as u64; + size += 1 + transport::varint_u64_len(stage2_len) + stage2_len as usize; + + let stage3_len = self.stage3_sumcheck_proof.serialized_size(compress) as u64; + size += 1 + transport::varint_u64_len(stage3_len) + stage3_len as usize; + + let stage4_len = self.stage4_sumcheck_proof.serialized_size(compress) as u64; + size += 1 + transport::varint_u64_len(stage4_len) + stage4_len as usize; + + let stage5_len = self.stage5_sumcheck_proof.serialized_size(compress) as u64; + size += 1 + transport::varint_u64_len(stage5_len) + stage5_len as usize; + + let stage6_len = self.stage6_sumcheck_proof.serialized_size(compress) as u64; + size += 1 + transport::varint_u64_len(stage6_len) + stage6_len as usize; + + let stage7_len = self.stage7_sumcheck_proof.serialized_size(compress) as u64; + size += 1 + transport::varint_u64_len(stage7_len) + stage7_len as usize; + + let joint_len = self.joint_opening_proof.serialized_size(compress) as u64; + size += 1 + transport::varint_u64_len(joint_len) + joint_len as usize; + + size + } +} + +impl, FS: Transcript> Valid + for JoltProof +{ + fn check(&self) -> Result<(), SerializationError> { + Ok(()) + } +} + +impl, FS: Transcript> CanonicalDeserialize + for JoltProof +{ + fn deserialize_with_mode( + mut reader: R, + compress: Compress, + validate: Validate, + ) -> Result { + transport::signature_check(&mut reader, PROOF_SIGNATURE).map_err(io_err)?; + + let mut trace_length: Option = None; + let mut ram_K: Option = None; + let mut bytecode_K: Option = None; + let mut rw_config: Option = None; + let mut one_hot_config: Option = None; + let mut dory_layout: Option = None; + + let mut commitments: Option> = None; + let mut untrusted_advice_commitment: Option> = None; + let mut opening_claims: Option> = None; + + let mut stage1_uni: Option> = None; + let mut stage1_sumcheck: Option> = None; + let mut stage2_uni: Option> = None; + let mut stage2_sumcheck: Option> = None; + let mut stage3_sumcheck: Option> = None; + let mut stage4_sumcheck: Option> = None; + let mut stage5_sumcheck: Option> = None; + let mut stage6_sumcheck: Option> = None; + let mut stage7_sumcheck: Option> = None; + let mut joint_opening_proof: Option = None; + + while let Some((tag, len)) = read_frame_header(&mut reader, MAX_ANY_SECTION_LEN)? { + let cap = section_cap_for_tag(tag); + if cap == 0 || len > cap { + return Err(SerializationError::InvalidData); + } + let mut limited = (&mut reader).take(len); + + match tag { + TAG_PARAMS => { + if trace_length.is_some() { + return Err(SerializationError::InvalidData); + } + let t = read_varint_u64(&mut limited)?; + let r = read_varint_u64(&mut limited)?; + let b = read_varint_u64(&mut limited)?; + trace_length = + Some(usize::try_from(t).map_err(|_| SerializationError::InvalidData)?); + ram_K = Some(usize::try_from(r).map_err(|_| SerializationError::InvalidData)?); + bytecode_K = + Some(usize::try_from(b).map_err(|_| SerializationError::InvalidData)?); + rw_config = Some(ReadWriteConfig::deserialize_with_mode( + &mut limited, + compress, + validate, + )?); + one_hot_config = Some(OneHotConfig::deserialize_with_mode( + &mut limited, + compress, + validate, + )?); + dory_layout = Some(DoryLayout::deserialize_with_mode( + &mut limited, + compress, + validate, + )?); + } + TAG_COMMITMENTS => { + if commitments.is_some() { + return Err(SerializationError::InvalidData); + } + let n = read_varint_u64(&mut limited)?; + let n_usize = + usize::try_from(n).map_err(|_| SerializationError::InvalidData)?; + if n_usize > 1_000_000 { + return Err(SerializationError::InvalidData); + } + let mut v = Vec::with_capacity(n_usize.min(1024)); + for _ in 0..n_usize { + v.push(PCS::Commitment::deserialize_with_mode( + &mut limited, + compress, + validate, + )?); + } + let presence = u8::deserialize_with_mode(&mut limited, compress, validate)?; + let opt = match presence { + 0 => None, + 1 => Some(PCS::Commitment::deserialize_with_mode( + &mut limited, + compress, + validate, + )?), + _ => return Err(SerializationError::InvalidData), + }; + commitments = Some(v); + untrusted_advice_commitment = Some(opt); + } + TAG_OPENING_CLAIMS => { + if opening_claims.is_some() { + return Err(SerializationError::InvalidData); + } + let n = read_varint_u64(&mut limited)?; + let n_usize = + usize::try_from(n).map_err(|_| SerializationError::InvalidData)?; + if n_usize > 10_000_000 { + return Err(SerializationError::InvalidData); + } + let mut claims = BTreeMap::new(); + for _ in 0..n_usize { + let key = + OpeningId::deserialize_with_mode(&mut limited, compress, validate)?; + let claim = F::deserialize_with_mode(&mut limited, compress, validate)?; + claims.insert(key, (OpeningPoint::default(), claim)); + } + opening_claims = Some(Claims(claims)); + } + TAG_STAGE1 => { + if stage1_uni.is_some() { + return Err(SerializationError::InvalidData); + } + stage1_uni = Some(UniSkipFirstRoundProof::deserialize_with_mode( + &mut limited, + compress, + validate, + )?); + stage1_sumcheck = Some(SumcheckInstanceProof::deserialize_with_mode( + &mut limited, + compress, + validate, + )?); + } + TAG_STAGE2 => { + if stage2_uni.is_some() { + return Err(SerializationError::InvalidData); + } + stage2_uni = Some(UniSkipFirstRoundProof::deserialize_with_mode( + &mut limited, + compress, + validate, + )?); + stage2_sumcheck = Some(SumcheckInstanceProof::deserialize_with_mode( + &mut limited, + compress, + validate, + )?); + } + TAG_STAGE3 => { + if stage3_sumcheck.is_some() { + return Err(SerializationError::InvalidData); + } + stage3_sumcheck = Some(SumcheckInstanceProof::deserialize_with_mode( + &mut limited, + compress, + validate, + )?); + } + TAG_STAGE4 => { + if stage4_sumcheck.is_some() { + return Err(SerializationError::InvalidData); + } + stage4_sumcheck = Some(SumcheckInstanceProof::deserialize_with_mode( + &mut limited, + compress, + validate, + )?); + } + TAG_STAGE5 => { + if stage5_sumcheck.is_some() { + return Err(SerializationError::InvalidData); + } + stage5_sumcheck = Some(SumcheckInstanceProof::deserialize_with_mode( + &mut limited, + compress, + validate, + )?); + } + TAG_STAGE6 => { + if stage6_sumcheck.is_some() { + return Err(SerializationError::InvalidData); + } + stage6_sumcheck = Some(SumcheckInstanceProof::deserialize_with_mode( + &mut limited, + compress, + validate, + )?); + } + TAG_STAGE7 => { + if stage7_sumcheck.is_some() { + return Err(SerializationError::InvalidData); + } + stage7_sumcheck = Some(SumcheckInstanceProof::deserialize_with_mode( + &mut limited, + compress, + validate, + )?); + } + TAG_JOINT_OPENING => { + if joint_opening_proof.is_some() { + return Err(SerializationError::InvalidData); + } + joint_opening_proof = Some(PCS::Proof::deserialize_with_mode( + &mut limited, + compress, + validate, + )?); + } + _ => return Err(SerializationError::InvalidData), + } + + if limited.limit() != 0 { + return Err(SerializationError::InvalidData); + } + } + + Ok(Self { + opening_claims: opening_claims.ok_or(SerializationError::InvalidData)?, + commitments: commitments.ok_or(SerializationError::InvalidData)?, + stage1_uni_skip_first_round_proof: stage1_uni.ok_or(SerializationError::InvalidData)?, + stage1_sumcheck_proof: stage1_sumcheck.ok_or(SerializationError::InvalidData)?, + stage2_uni_skip_first_round_proof: stage2_uni.ok_or(SerializationError::InvalidData)?, + stage2_sumcheck_proof: stage2_sumcheck.ok_or(SerializationError::InvalidData)?, + stage3_sumcheck_proof: stage3_sumcheck.ok_or(SerializationError::InvalidData)?, + stage4_sumcheck_proof: stage4_sumcheck.ok_or(SerializationError::InvalidData)?, + stage5_sumcheck_proof: stage5_sumcheck.ok_or(SerializationError::InvalidData)?, + stage6_sumcheck_proof: stage6_sumcheck.ok_or(SerializationError::InvalidData)?, + stage7_sumcheck_proof: stage7_sumcheck.ok_or(SerializationError::InvalidData)?, + joint_opening_proof: joint_opening_proof.ok_or(SerializationError::InvalidData)?, + untrusted_advice_commitment: untrusted_advice_commitment + .ok_or(SerializationError::InvalidData)?, + trace_length: trace_length.ok_or(SerializationError::InvalidData)?, + ram_K: ram_K.ok_or(SerializationError::InvalidData)?, + bytecode_K: bytecode_K.ok_or(SerializationError::InvalidData)?, + rw_config: rw_config.ok_or(SerializationError::InvalidData)?, + one_hot_config: one_hot_config.ok_or(SerializationError::InvalidData)?, + dory_layout: dory_layout.ok_or(SerializationError::InvalidData)?, + }) + } +} + impl CanonicalSerialize for DoryLayout { fn serialize_with_mode( &self, @@ -134,17 +654,24 @@ impl CanonicalDeserialize for Claims { } } -// Compact encoding for OpeningId: -// Each variant uses a fused byte = BASE + sumcheck_id (1 byte total for advice, 2 bytes for committed/virtual) -// - [0, NUM_SUMCHECKS) = UntrustedAdvice(sumcheck_id) -// - [NUM_SUMCHECKS, 2*NUM_SUMCHECKS) = TrustedAdvice(sumcheck_id) -// - [2*NUM_SUMCHECKS, 3*NUM_SUMCHECKS) + poly_index = Committed(poly, sumcheck_id) -// - [3*NUM_SUMCHECKS, 4*NUM_SUMCHECKS) + poly_index = Virtual(poly, sumcheck_id) -const OPENING_ID_UNTRUSTED_ADVICE_BASE: u8 = 0; -const OPENING_ID_TRUSTED_ADVICE_BASE: u8 = - OPENING_ID_UNTRUSTED_ADVICE_BASE + SumcheckId::COUNT as u8; -const OPENING_ID_COMMITTED_BASE: u8 = OPENING_ID_TRUSTED_ADVICE_BASE + SumcheckId::COUNT as u8; -const OPENING_ID_VIRTUAL_BASE: u8 = OPENING_ID_COMMITTED_BASE + SumcheckId::COUNT as u8; +// OpeningId wire encoding (packed, self-describing): +// +// Header byte: +// bits[7..6] = kind (0..=3) +// bits[5..0] = sumcheck_id if < 63, otherwise 63 as an escape +// +// If bits[5..0] == 63: +// next bytes = varint u64 sumcheck_id (must fit in u8 for current SumcheckId) +// +// If kind indicates a polynomial: +// append polynomial id bytes (CommittedPolynomial / VirtualPolynomial) +// +// This does NOT depend on `SumcheckId::COUNT` range boundaries. +const OPENING_ID_KIND_UNTRUSTED_ADVICE: u8 = 0; +const OPENING_ID_KIND_TRUSTED_ADVICE: u8 = 1; +const OPENING_ID_KIND_COMMITTED: u8 = 2; +const OPENING_ID_KIND_VIRTUAL: u8 = 3; +const OPENING_ID_SUMCHECK_ESCAPE: u8 = 63; impl CanonicalSerialize for OpeningId { fn serialize_with_mode( @@ -152,40 +679,61 @@ impl CanonicalSerialize for OpeningId { mut writer: W, compress: Compress, ) -> Result<(), SerializationError> { - match self { + let (kind, sumcheck_u64, poly) = match *self { OpeningId::UntrustedAdvice(sumcheck_id) => { - let fused = OPENING_ID_UNTRUSTED_ADVICE_BASE + (*sumcheck_id as u8); - fused.serialize_with_mode(&mut writer, compress) + (OPENING_ID_KIND_UNTRUSTED_ADVICE, sumcheck_id as u64, None) } OpeningId::TrustedAdvice(sumcheck_id) => { - let fused = OPENING_ID_TRUSTED_ADVICE_BASE + (*sumcheck_id as u8); - fused.serialize_with_mode(&mut writer, compress) - } - OpeningId::Polynomial(PolynomialId::Committed(committed_polynomial), sumcheck_id) => { - let fused = OPENING_ID_COMMITTED_BASE + (*sumcheck_id as u8); - fused.serialize_with_mode(&mut writer, compress)?; - committed_polynomial.serialize_with_mode(&mut writer, compress) + (OPENING_ID_KIND_TRUSTED_ADVICE, sumcheck_id as u64, None) } - OpeningId::Polynomial(PolynomialId::Virtual(virtual_polynomial), sumcheck_id) => { - let fused = OPENING_ID_VIRTUAL_BASE + (*sumcheck_id as u8); - fused.serialize_with_mode(&mut writer, compress)?; - virtual_polynomial.serialize_with_mode(&mut writer, compress) + OpeningId::Polynomial(PolynomialId::Committed(committed_polynomial), sumcheck_id) => ( + OPENING_ID_KIND_COMMITTED, + sumcheck_id as u64, + Some(PolynomialId::Committed(committed_polynomial)), + ), + OpeningId::Polynomial(PolynomialId::Virtual(virtual_polynomial), sumcheck_id) => ( + OPENING_ID_KIND_VIRTUAL, + sumcheck_id as u64, + Some(PolynomialId::Virtual(virtual_polynomial)), + ), + }; + + let header = if sumcheck_u64 < OPENING_ID_SUMCHECK_ESCAPE as u64 { + (kind << 6) | (sumcheck_u64 as u8) + } else { + (kind << 6) | OPENING_ID_SUMCHECK_ESCAPE + }; + header.serialize_with_mode(&mut writer, compress)?; + if (header & 0x3F) == OPENING_ID_SUMCHECK_ESCAPE { + write_varint_u64(&mut writer, sumcheck_u64)?; + } + if let Some(poly) = poly { + match poly { + PolynomialId::Committed(p) => p.serialize_with_mode(&mut writer, compress)?, + PolynomialId::Virtual(p) => p.serialize_with_mode(&mut writer, compress)?, } } + Ok(()) } fn serialized_size(&self, compress: Compress) -> usize { - match self { - OpeningId::UntrustedAdvice(_) | OpeningId::TrustedAdvice(_) => 1, - OpeningId::Polynomial(PolynomialId::Committed(committed_polynomial), _) => { - // 1 byte fused (variant + sumcheck_id) + poly index - 1 + committed_polynomial.serialized_size(compress) - } - OpeningId::Polynomial(PolynomialId::Virtual(virtual_polynomial), _) => { - // 1 byte fused (variant + sumcheck_id) + poly index - 1 + virtual_polynomial.serialized_size(compress) - } - } + let sumcheck_u64 = match self { + OpeningId::UntrustedAdvice(sumcheck_id) + | OpeningId::TrustedAdvice(sumcheck_id) + | OpeningId::Polynomial(_, sumcheck_id) => *sumcheck_id as u64, + }; + let header_len = 1usize; + let sumcheck_ext_len = if sumcheck_u64 < OPENING_ID_SUMCHECK_ESCAPE as u64 { + 0usize + } else { + transport::varint_u64_len(sumcheck_u64) + }; + let poly_len = match self { + OpeningId::UntrustedAdvice(_) | OpeningId::TrustedAdvice(_) => 0, + OpeningId::Polynomial(PolynomialId::Committed(p), _) => p.serialized_size(compress), + OpeningId::Polynomial(PolynomialId::Virtual(p), _) => p.serialized_size(compress), + }; + header_len + sumcheck_ext_len + poly_len } } @@ -201,38 +749,41 @@ impl CanonicalDeserialize for OpeningId { compress: Compress, validate: Validate, ) -> Result { - let fused = u8::deserialize_with_mode(&mut reader, compress, validate)?; - match fused { - _ if fused < OPENING_ID_TRUSTED_ADVICE_BASE => { - let sumcheck_id = fused - OPENING_ID_UNTRUSTED_ADVICE_BASE; - Ok(OpeningId::UntrustedAdvice( - SumcheckId::from_u8(sumcheck_id).ok_or(SerializationError::InvalidData)?, - )) - } - _ if fused < OPENING_ID_COMMITTED_BASE => { - let sumcheck_id = fused - OPENING_ID_TRUSTED_ADVICE_BASE; - Ok(OpeningId::TrustedAdvice( - SumcheckId::from_u8(sumcheck_id).ok_or(SerializationError::InvalidData)?, - )) - } - _ if fused < OPENING_ID_VIRTUAL_BASE => { - let sumcheck_id = fused - OPENING_ID_COMMITTED_BASE; + let header = u8::deserialize_with_mode(&mut reader, compress, validate)?; + let kind = header >> 6; + let small = header & 0x3F; + + let sumcheck_u64 = if small == OPENING_ID_SUMCHECK_ESCAPE { + read_varint_u64(&mut reader)? + } else { + small as u64 + }; + + let sumcheck_u8 = + u8::try_from(sumcheck_u64).map_err(|_| SerializationError::InvalidData)?; + let sumcheck_id = + SumcheckId::from_u8(sumcheck_u8).ok_or(SerializationError::InvalidData)?; + + match kind { + OPENING_ID_KIND_UNTRUSTED_ADVICE => Ok(OpeningId::UntrustedAdvice(sumcheck_id)), + OPENING_ID_KIND_TRUSTED_ADVICE => Ok(OpeningId::TrustedAdvice(sumcheck_id)), + OPENING_ID_KIND_COMMITTED => { let polynomial = CommittedPolynomial::deserialize_with_mode(&mut reader, compress, validate)?; Ok(OpeningId::Polynomial( PolynomialId::Committed(polynomial), - SumcheckId::from_u8(sumcheck_id).ok_or(SerializationError::InvalidData)?, + sumcheck_id, )) } - _ => { - let sumcheck_id = fused - OPENING_ID_VIRTUAL_BASE; + OPENING_ID_KIND_VIRTUAL => { let polynomial = VirtualPolynomial::deserialize_with_mode(&mut reader, compress, validate)?; Ok(OpeningId::Polynomial( PolynomialId::Virtual(polynomial), - SumcheckId::from_u8(sumcheck_id).ok_or(SerializationError::InvalidData)?, + sumcheck_id, )) } + _ => Err(SerializationError::InvalidData), } } } @@ -508,3 +1059,95 @@ pub fn serialize_and_print_size( tracing::info!("{item_name} size: {file_size_kb:.1} kB"); Ok(()) } + +#[cfg(test)] +mod tests { + use super::*; + use crate::poly::opening_proof::{OpeningId, PolynomialId, SumcheckId}; + use crate::zkvm::witness::{CommittedPolynomial, VirtualPolynomial}; + use ark_serialize::{CanonicalDeserialize, CanonicalSerialize, Compress, Validate}; + + #[test] + fn opening_id_header_is_packed_common_case() { + let id = OpeningId::UntrustedAdvice(SumcheckId::SpartanOuter); + let mut bytes = Vec::new(); + id.serialize_compressed(&mut bytes).unwrap(); + assert_eq!(bytes.len(), 1); + let expected = (OPENING_ID_KIND_UNTRUSTED_ADVICE << 6) | (SumcheckId::SpartanOuter as u8); + assert_eq!(bytes[0], expected); + + let id = OpeningId::Polynomial( + PolynomialId::Committed(CommittedPolynomial::RdInc), + SumcheckId::SpartanOuter, + ); + let mut bytes = Vec::new(); + id.serialize_compressed(&mut bytes).unwrap(); + assert!(bytes.len() >= 2); // header + poly id + assert_eq!(bytes[0] >> 6, OPENING_ID_KIND_COMMITTED); + assert_eq!(bytes[0] & 0x3F, SumcheckId::SpartanOuter as u8); + + let id = OpeningId::Polynomial( + PolynomialId::Virtual(VirtualPolynomial::PC), + SumcheckId::SpartanOuter, + ); + let mut bytes = Vec::new(); + id.serialize_compressed(&mut bytes).unwrap(); + assert!(bytes.len() >= 2); // header + poly id + assert_eq!(bytes[0] >> 6, OPENING_ID_KIND_VIRTUAL); + assert_eq!(bytes[0] & 0x3F, SumcheckId::SpartanOuter as u8); + } + + #[test] + fn opening_id_roundtrips() { + let cases = [ + OpeningId::UntrustedAdvice(SumcheckId::SpartanOuter), + OpeningId::TrustedAdvice(SumcheckId::SpartanOuter), + OpeningId::Polynomial( + PolynomialId::Committed(CommittedPolynomial::RdInc), + SumcheckId::SpartanOuter, + ), + OpeningId::Polynomial( + PolynomialId::Virtual(VirtualPolynomial::PC), + SumcheckId::SpartanOuter, + ), + ]; + for id in cases { + let mut bytes = Vec::new(); + id.serialize_compressed(&mut bytes).unwrap(); + let decoded = OpeningId::deserialize_compressed(bytes.as_slice()).unwrap(); + assert_eq!(decoded, id); + } + } + + #[test] + fn proof_signature_is_required_and_unknown_tags_reject() { + // Missing sections should reject cleanly (after signature). + let mut just_sig = Vec::new(); + just_sig.extend_from_slice(PROOF_SIGNATURE); + let res = crate::zkvm::RV64IMACProof::deserialize_with_mode( + std::io::Cursor::new(&just_sig), + Compress::Yes, + Validate::Yes, + ) + .map(|_| ()); + match res { + Err(SerializationError::InvalidData) | Err(SerializationError::IoError(_)) => {} + _ => panic!("expected decode error"), + } + + // Unknown tag should reject. + let mut bytes = Vec::new(); + bytes.extend_from_slice(PROOF_SIGNATURE); + transport::write_frame_header(&mut bytes, 99, 0).unwrap(); + let res = crate::zkvm::RV64IMACProof::deserialize_with_mode( + std::io::Cursor::new(&bytes), + Compress::Yes, + Validate::Yes, + ) + .map(|_| ()); + match res { + Err(SerializationError::InvalidData) | Err(SerializationError::IoError(_)) => {} + _ => panic!("expected decode error"), + } + } +} diff --git a/jolt-core/src/zkvm/transport.rs b/jolt-core/src/zkvm/transport.rs new file mode 100644 index 0000000000..2fbe9d56c1 --- /dev/null +++ b/jolt-core/src/zkvm/transport.rs @@ -0,0 +1,115 @@ +//! Lightweight framed transport encoding helpers. +//! +//! This module is intentionally small and dependency-free so we can use it in verifier-facing +//! deserialization paths with explicit caps (DoS resistance) and strict parsing invariants. +//! +//! Design: +//! - Streams begin with a short fixed signature (header bytes). +//! - Then a sequence of frames: (tag: u8, len: varint u64, payload: len bytes). +//! - Decoders should be strict by default: reject unknown tags, reject duplicates for singleton +//! sections, and require full consumption of each payload. + +use std::io::{self, Read, Write}; + +/// Maximum number of bytes in a u64 varint (LEB128-style). +const VARINT_U64_MAX_BYTES: usize = 10; + +#[inline] +pub fn signature_check(r: &mut R, expected: &[u8]) -> io::Result<()> { + let mut got = vec![0u8; expected.len()]; + r.read_exact(&mut got)?; + if got != expected { + return Err(io::Error::new( + io::ErrorKind::InvalidData, + "invalid signature", + )); + } + Ok(()) +} + +#[inline] +pub fn signature_write(w: &mut W, signature: &[u8]) -> io::Result<()> { + w.write_all(signature) +} + +#[inline] +pub fn write_varint_u64(w: &mut W, mut x: u64) -> io::Result<()> { + while x >= 0x80 { + w.write_all(&[((x as u8) & 0x7F) | 0x80])?; + x >>= 7; + } + w.write_all(&[x as u8]) +} + +#[inline] +pub fn varint_u64_len(mut x: u64) -> usize { + let mut n = 1usize; + while x >= 0x80 { + n += 1; + x >>= 7; + } + n +} + +#[inline] +pub fn read_varint_u64(r: &mut R) -> io::Result { + let mut x = 0u64; + let mut shift = 0u32; + for _ in 0..VARINT_U64_MAX_BYTES { + let mut b = [0u8; 1]; + r.read_exact(&mut b)?; + let byte = b[0]; + x |= ((byte & 0x7F) as u64) << shift; + if (byte & 0x80) == 0 { + return Ok(x); + } + shift += 7; + } + Err(io::Error::new( + io::ErrorKind::InvalidData, + "varint overflow", + )) +} + +#[inline] +pub fn read_u8_opt(r: &mut R) -> io::Result> { + let mut b = [0u8; 1]; + match r.read(&mut b) { + Ok(0) => Ok(None), + Ok(1) => Ok(Some(b[0])), + Ok(_) => unreachable!(), + Err(e) => Err(e), + } +} + +#[inline] +pub fn write_frame_header(w: &mut W, tag: u8, len: u64) -> io::Result<()> { + w.write_all(&[tag])?; + write_varint_u64(w, len) +} + +#[inline] +pub fn read_frame_header(r: &mut R, max_len: u64) -> io::Result> { + let Some(tag) = read_u8_opt(r)? else { + return Ok(None); + }; + let len = read_varint_u64(r)?; + if len > max_len { + return Err(io::Error::new( + io::ErrorKind::InvalidData, + "frame too large", + )); + } + Ok(Some((tag, len))) +} + +#[inline] +pub fn skip_exact(r: &mut R, mut n: u64) -> io::Result<()> { + let mut buf = [0u8; 4096]; + while n > 0 { + let k = (n as usize).min(buf.len()); + r.read_exact(&mut buf[..k])?; + n -= k as u64; + } + Ok(()) +} diff --git a/jolt-sdk/src/host_utils.rs b/jolt-sdk/src/host_utils.rs index 74beec573d..53f7d85727 100644 --- a/jolt-sdk/src/host_utils.rs +++ b/jolt-sdk/src/host_utils.rs @@ -10,6 +10,7 @@ pub use jolt_core::ark_bn254::Fr as F; pub use jolt_core::field::JoltField; pub use jolt_core::guest; pub use jolt_core::poly::commitment::dory::DoryCommitmentScheme as PCS; +pub use jolt_core::zkvm::transport; pub use jolt_core::zkvm::{ proof_serialization::JoltProof, verifier::JoltSharedPreprocessing, verifier::JoltVerifierPreprocessing, RV64IMACProof, RV64IMACVerifier, Serializable, From acef26efcfa638c92691a2b73a89aa3578fb7793 Mon Sep 17 00:00:00 2001 From: Quang Dao Date: Mon, 16 Feb 2026 12:41:31 -0500 Subject: [PATCH 2/9] chore(zkvm): clean up proof serialization imports Hoist File import to the top-level import block. Co-authored-by: Cursor --- jolt-core/src/zkvm/proof_serialization.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jolt-core/src/zkvm/proof_serialization.rs b/jolt-core/src/zkvm/proof_serialization.rs index 0d24fc832d..2eb43ee311 100644 --- a/jolt-core/src/zkvm/proof_serialization.rs +++ b/jolt-core/src/zkvm/proof_serialization.rs @@ -1,5 +1,6 @@ use std::{ collections::BTreeMap, + fs::File, io::{Read, Write}, }; @@ -1050,7 +1051,6 @@ pub fn serialize_and_print_size( file_name: &str, item: &impl CanonicalSerialize, ) -> Result<(), SerializationError> { - use std::fs::File; let mut file = File::create(file_name)?; item.serialize_compressed(&mut file)?; let file_size_bytes = file.metadata()?.len(); From 05243f9c24f9f6a6a2aa5c1f9d59a9ee750c1397 Mon Sep 17 00:00:00 2001 From: Quang Dao Date: Tue, 24 Feb 2026 08:52:42 -0800 Subject: [PATCH 3/9] refactor(zkvm): tighten proof serialization and transport Add write_framed_section!/framed_section_size!/read_singleton! macros to eliminate repetitive stage serialization in JoltProof. Collapse 6 per-tag size caps into a single MAX_SECTION_LEN, inline thin transport wrapper functions, compact VirtualPolynomial serialized_size to a wildcard match, and move bundle framing constants to transport.rs. Co-authored-by: Cursor --- jolt-core/src/zkvm/proof_serialization.rs | 387 +++++++--------------- jolt-core/src/zkvm/transport.rs | 7 + 2 files changed, 126 insertions(+), 268 deletions(-) diff --git a/jolt-core/src/zkvm/proof_serialization.rs b/jolt-core/src/zkvm/proof_serialization.rs index 2eb43ee311..c0c643edaf 100644 --- a/jolt-core/src/zkvm/proof_serialization.rs +++ b/jolt-core/src/zkvm/proof_serialization.rs @@ -47,13 +47,7 @@ const TAG_STAGE6: u8 = 15; const TAG_STAGE7: u8 = 16; const TAG_JOINT_OPENING: u8 = 20; -// Per-section payload caps (DoS resistance). These can be tuned if legitimate proofs grow. -const MAX_PARAMS_LEN: u64 = 16 * 1024; -const MAX_COMMITMENTS_LEN: u64 = 256 * 1024 * 1024; -const MAX_OPENING_CLAIMS_LEN: u64 = 256 * 1024 * 1024; -const MAX_STAGE_LEN: u64 = 512 * 1024 * 1024; -const MAX_JOINT_OPENING_LEN: u64 = 512 * 1024 * 1024; -const MAX_ANY_SECTION_LEN: u64 = 512 * 1024 * 1024; +const MAX_SECTION_LEN: u64 = 512 * 1024 * 1024; pub struct JoltProof, FS: Transcript> { pub opening_claims: Claims, @@ -82,45 +76,28 @@ fn io_err(e: std::io::Error) -> SerializationError { SerializationError::IoError(e) } -#[inline] -fn write_u8(w: &mut W, b: u8) -> Result<(), SerializationError> { - w.write_all(&[b]).map_err(io_err) -} - -#[inline] -fn write_varint_u64(w: &mut W, x: u64) -> Result<(), SerializationError> { - transport::write_varint_u64(w, x).map_err(io_err) -} - -#[inline] -fn read_varint_u64(r: &mut R) -> Result { - transport::read_varint_u64(r).map_err(io_err) +macro_rules! write_framed_section { + ($w:expr, $c:expr, $tag:expr, $($item:expr),+ $(,)?) => {{ + let len: u64 = 0 $(+ $item.serialized_size($c) as u64)+; + transport::write_frame_header($w, $tag, len).map_err(io_err)?; + $($item.serialize_with_mode($w, $c)?;)+ + }}; } -#[inline] -fn write_frame_header(w: &mut W, tag: u8, len: u64) -> Result<(), SerializationError> { - transport::write_frame_header(w, tag, len).map_err(io_err) -} - -#[inline] -fn read_frame_header( - r: &mut R, - max_len: u64, -) -> Result, SerializationError> { - transport::read_frame_header(r, max_len).map_err(io_err) +macro_rules! framed_section_size { + ($c:expr, $($item:expr),+ $(,)?) => {{ + let payload: u64 = 0 $(+ $item.serialized_size($c) as u64)+; + 1 + transport::varint_u64_len(payload) + payload as usize + }}; } -#[inline] -fn section_cap_for_tag(tag: u8) -> u64 { - match tag { - TAG_PARAMS => MAX_PARAMS_LEN, - TAG_COMMITMENTS => MAX_COMMITMENTS_LEN, - TAG_OPENING_CLAIMS => MAX_OPENING_CLAIMS_LEN, - TAG_STAGE1 | TAG_STAGE2 | TAG_STAGE3 | TAG_STAGE4 | TAG_STAGE5 | TAG_STAGE6 - | TAG_STAGE7 => MAX_STAGE_LEN, - TAG_JOINT_OPENING => MAX_JOINT_OPENING_LEN, - _ => 0, - } +macro_rules! read_singleton { + ($r:expr, $c:expr, $v:expr, $field:expr) => {{ + if $field.is_some() { + return Err(SerializationError::InvalidData); + } + $field = Some(CanonicalDeserialize::deserialize_with_mode($r, $c, $v)?); + }}; } impl, FS: Transcript> CanonicalSerialize @@ -133,24 +110,22 @@ impl, FS: Transcript> CanonicalSe ) -> Result<(), SerializationError> { transport::signature_write(&mut writer, PROOF_SIGNATURE).map_err(io_err)?; - // ---------------- Params ---------------- let params_len = (transport::varint_u64_len(self.trace_length as u64) + transport::varint_u64_len(self.ram_K as u64) + transport::varint_u64_len(self.bytecode_K as u64)) as u64 + self.rw_config.serialized_size(compress) as u64 + self.one_hot_config.serialized_size(compress) as u64 + self.dory_layout.serialized_size(compress) as u64; - write_frame_header(&mut writer, TAG_PARAMS, params_len)?; - write_varint_u64(&mut writer, self.trace_length as u64)?; - write_varint_u64(&mut writer, self.ram_K as u64)?; - write_varint_u64(&mut writer, self.bytecode_K as u64)?; + transport::write_frame_header(&mut writer, TAG_PARAMS, params_len).map_err(io_err)?; + transport::write_varint_u64(&mut writer, self.trace_length as u64).map_err(io_err)?; + transport::write_varint_u64(&mut writer, self.ram_K as u64).map_err(io_err)?; + transport::write_varint_u64(&mut writer, self.bytecode_K as u64).map_err(io_err)?; self.rw_config.serialize_with_mode(&mut writer, compress)?; self.one_hot_config .serialize_with_mode(&mut writer, compress)?; self.dory_layout .serialize_with_mode(&mut writer, compress)?; - // ---------------- Commitments ---------------- let commitments_count_len = transport::varint_u64_len(self.commitments.len() as u64) as u64; let commitments_items_len: u64 = self .commitments @@ -164,20 +139,20 @@ impl, FS: Transcript> CanonicalSe .unwrap_or(0); let commitments_len = commitments_count_len + commitments_items_len + 1 + untrusted_commitment_len; - write_frame_header(&mut writer, TAG_COMMITMENTS, commitments_len)?; - write_varint_u64(&mut writer, self.commitments.len() as u64)?; + transport::write_frame_header(&mut writer, TAG_COMMITMENTS, commitments_len) + .map_err(io_err)?; + transport::write_varint_u64(&mut writer, self.commitments.len() as u64).map_err(io_err)?; for c in &self.commitments { c.serialize_with_mode(&mut writer, compress)?; } match &self.untrusted_advice_commitment { - None => write_u8(&mut writer, 0)?, + None => writer.write_all(&[0]).map_err(io_err)?, Some(c) => { - write_u8(&mut writer, 1)?; + writer.write_all(&[1]).map_err(io_err)?; c.serialize_with_mode(&mut writer, compress)?; } } - // ---------------- Opening claims ---------------- let claims_count_len = transport::varint_u64_len(self.opening_claims.0.len() as u64) as u64; let claims_items_len: u64 = self .opening_claims @@ -188,64 +163,65 @@ impl, FS: Transcript> CanonicalSe }) .sum(); let claims_len = claims_count_len + claims_items_len; - write_frame_header(&mut writer, TAG_OPENING_CLAIMS, claims_len)?; - write_varint_u64(&mut writer, self.opening_claims.0.len() as u64)?; + transport::write_frame_header(&mut writer, TAG_OPENING_CLAIMS, claims_len) + .map_err(io_err)?; + transport::write_varint_u64(&mut writer, self.opening_claims.0.len() as u64) + .map_err(io_err)?; for (k, (_p, claim)) in self.opening_claims.0.iter() { k.serialize_with_mode(&mut writer, compress)?; claim.serialize_with_mode(&mut writer, compress)?; } - // ---------------- Stages ---------------- - let stage1_len = self - .stage1_uni_skip_first_round_proof - .serialized_size(compress) as u64 - + self.stage1_sumcheck_proof.serialized_size(compress) as u64; - write_frame_header(&mut writer, TAG_STAGE1, stage1_len)?; - self.stage1_uni_skip_first_round_proof - .serialize_with_mode(&mut writer, compress)?; - self.stage1_sumcheck_proof - .serialize_with_mode(&mut writer, compress)?; - - let stage2_len = self - .stage2_uni_skip_first_round_proof - .serialized_size(compress) as u64 - + self.stage2_sumcheck_proof.serialized_size(compress) as u64; - write_frame_header(&mut writer, TAG_STAGE2, stage2_len)?; - self.stage2_uni_skip_first_round_proof - .serialize_with_mode(&mut writer, compress)?; - self.stage2_sumcheck_proof - .serialize_with_mode(&mut writer, compress)?; - - let stage3_len = self.stage3_sumcheck_proof.serialized_size(compress) as u64; - write_frame_header(&mut writer, TAG_STAGE3, stage3_len)?; - self.stage3_sumcheck_proof - .serialize_with_mode(&mut writer, compress)?; - - let stage4_len = self.stage4_sumcheck_proof.serialized_size(compress) as u64; - write_frame_header(&mut writer, TAG_STAGE4, stage4_len)?; - self.stage4_sumcheck_proof - .serialize_with_mode(&mut writer, compress)?; - - let stage5_len = self.stage5_sumcheck_proof.serialized_size(compress) as u64; - write_frame_header(&mut writer, TAG_STAGE5, stage5_len)?; - self.stage5_sumcheck_proof - .serialize_with_mode(&mut writer, compress)?; - - let stage6_len = self.stage6_sumcheck_proof.serialized_size(compress) as u64; - write_frame_header(&mut writer, TAG_STAGE6, stage6_len)?; - self.stage6_sumcheck_proof - .serialize_with_mode(&mut writer, compress)?; - - let stage7_len = self.stage7_sumcheck_proof.serialized_size(compress) as u64; - write_frame_header(&mut writer, TAG_STAGE7, stage7_len)?; - self.stage7_sumcheck_proof - .serialize_with_mode(&mut writer, compress)?; - - // ---------------- Joint opening proof ---------------- - let joint_len = self.joint_opening_proof.serialized_size(compress) as u64; - write_frame_header(&mut writer, TAG_JOINT_OPENING, joint_len)?; - self.joint_opening_proof - .serialize_with_mode(&mut writer, compress)?; + write_framed_section!( + &mut writer, + compress, + TAG_STAGE1, + &self.stage1_uni_skip_first_round_proof, + &self.stage1_sumcheck_proof + ); + write_framed_section!( + &mut writer, + compress, + TAG_STAGE2, + &self.stage2_uni_skip_first_round_proof, + &self.stage2_sumcheck_proof + ); + write_framed_section!( + &mut writer, + compress, + TAG_STAGE3, + &self.stage3_sumcheck_proof + ); + write_framed_section!( + &mut writer, + compress, + TAG_STAGE4, + &self.stage4_sumcheck_proof + ); + write_framed_section!( + &mut writer, + compress, + TAG_STAGE5, + &self.stage5_sumcheck_proof + ); + write_framed_section!( + &mut writer, + compress, + TAG_STAGE6, + &self.stage6_sumcheck_proof + ); + write_framed_section!( + &mut writer, + compress, + TAG_STAGE7, + &self.stage7_sumcheck_proof + ); + write_framed_section!( + &mut writer, + compress, + TAG_JOINT_OPENING, + &self.joint_opening_proof + ); Ok(()) } @@ -288,35 +264,22 @@ impl, FS: Transcript> CanonicalSe let claims_len = claims_count_len + claims_items_len; size += 1 + transport::varint_u64_len(claims_len) + claims_len as usize; - let stage1_len = self - .stage1_uni_skip_first_round_proof - .serialized_size(compress) as u64 - + self.stage1_sumcheck_proof.serialized_size(compress) as u64; - size += 1 + transport::varint_u64_len(stage1_len) + stage1_len as usize; - - let stage2_len = self - .stage2_uni_skip_first_round_proof - .serialized_size(compress) as u64 - + self.stage2_sumcheck_proof.serialized_size(compress) as u64; - size += 1 + transport::varint_u64_len(stage2_len) + stage2_len as usize; - - let stage3_len = self.stage3_sumcheck_proof.serialized_size(compress) as u64; - size += 1 + transport::varint_u64_len(stage3_len) + stage3_len as usize; - - let stage4_len = self.stage4_sumcheck_proof.serialized_size(compress) as u64; - size += 1 + transport::varint_u64_len(stage4_len) + stage4_len as usize; - - let stage5_len = self.stage5_sumcheck_proof.serialized_size(compress) as u64; - size += 1 + transport::varint_u64_len(stage5_len) + stage5_len as usize; - - let stage6_len = self.stage6_sumcheck_proof.serialized_size(compress) as u64; - size += 1 + transport::varint_u64_len(stage6_len) + stage6_len as usize; - - let stage7_len = self.stage7_sumcheck_proof.serialized_size(compress) as u64; - size += 1 + transport::varint_u64_len(stage7_len) + stage7_len as usize; - - let joint_len = self.joint_opening_proof.serialized_size(compress) as u64; - size += 1 + transport::varint_u64_len(joint_len) + joint_len as usize; + size += framed_section_size!( + compress, + &self.stage1_uni_skip_first_round_proof, + &self.stage1_sumcheck_proof + ); + size += framed_section_size!( + compress, + &self.stage2_uni_skip_first_round_proof, + &self.stage2_sumcheck_proof + ); + size += framed_section_size!(compress, &self.stage3_sumcheck_proof); + size += framed_section_size!(compress, &self.stage4_sumcheck_proof); + size += framed_section_size!(compress, &self.stage5_sumcheck_proof); + size += framed_section_size!(compress, &self.stage6_sumcheck_proof); + size += framed_section_size!(compress, &self.stage7_sumcheck_proof); + size += framed_section_size!(compress, &self.joint_opening_proof); size } @@ -362,11 +325,9 @@ impl, FS: Transcript> CanonicalDe let mut stage7_sumcheck: Option> = None; let mut joint_opening_proof: Option = None; - while let Some((tag, len)) = read_frame_header(&mut reader, MAX_ANY_SECTION_LEN)? { - let cap = section_cap_for_tag(tag); - if cap == 0 || len > cap { - return Err(SerializationError::InvalidData); - } + while let Some((tag, len)) = + transport::read_frame_header(&mut reader, MAX_SECTION_LEN).map_err(io_err)? + { let mut limited = (&mut reader).take(len); match tag { @@ -374,9 +335,9 @@ impl, FS: Transcript> CanonicalDe if trace_length.is_some() { return Err(SerializationError::InvalidData); } - let t = read_varint_u64(&mut limited)?; - let r = read_varint_u64(&mut limited)?; - let b = read_varint_u64(&mut limited)?; + let t = transport::read_varint_u64(&mut limited).map_err(io_err)?; + let r = transport::read_varint_u64(&mut limited).map_err(io_err)?; + let b = transport::read_varint_u64(&mut limited).map_err(io_err)?; trace_length = Some(usize::try_from(t).map_err(|_| SerializationError::InvalidData)?); ram_K = Some(usize::try_from(r).map_err(|_| SerializationError::InvalidData)?); @@ -402,7 +363,7 @@ impl, FS: Transcript> CanonicalDe if commitments.is_some() { return Err(SerializationError::InvalidData); } - let n = read_varint_u64(&mut limited)?; + let n = transport::read_varint_u64(&mut limited).map_err(io_err)?; let n_usize = usize::try_from(n).map_err(|_| SerializationError::InvalidData)?; if n_usize > 1_000_000 { @@ -433,7 +394,7 @@ impl, FS: Transcript> CanonicalDe if opening_claims.is_some() { return Err(SerializationError::InvalidData); } - let n = read_varint_u64(&mut limited)?; + let n = transport::read_varint_u64(&mut limited).map_err(io_err)?; let n_usize = usize::try_from(n).map_err(|_| SerializationError::InvalidData)?; if n_usize > 10_000_000 { @@ -449,94 +410,20 @@ impl, FS: Transcript> CanonicalDe opening_claims = Some(Claims(claims)); } TAG_STAGE1 => { - if stage1_uni.is_some() { - return Err(SerializationError::InvalidData); - } - stage1_uni = Some(UniSkipFirstRoundProof::deserialize_with_mode( - &mut limited, - compress, - validate, - )?); - stage1_sumcheck = Some(SumcheckInstanceProof::deserialize_with_mode( - &mut limited, - compress, - validate, - )?); + read_singleton!(&mut limited, compress, validate, stage1_uni); + read_singleton!(&mut limited, compress, validate, stage1_sumcheck); } TAG_STAGE2 => { - if stage2_uni.is_some() { - return Err(SerializationError::InvalidData); - } - stage2_uni = Some(UniSkipFirstRoundProof::deserialize_with_mode( - &mut limited, - compress, - validate, - )?); - stage2_sumcheck = Some(SumcheckInstanceProof::deserialize_with_mode( - &mut limited, - compress, - validate, - )?); - } - TAG_STAGE3 => { - if stage3_sumcheck.is_some() { - return Err(SerializationError::InvalidData); - } - stage3_sumcheck = Some(SumcheckInstanceProof::deserialize_with_mode( - &mut limited, - compress, - validate, - )?); - } - TAG_STAGE4 => { - if stage4_sumcheck.is_some() { - return Err(SerializationError::InvalidData); - } - stage4_sumcheck = Some(SumcheckInstanceProof::deserialize_with_mode( - &mut limited, - compress, - validate, - )?); - } - TAG_STAGE5 => { - if stage5_sumcheck.is_some() { - return Err(SerializationError::InvalidData); - } - stage5_sumcheck = Some(SumcheckInstanceProof::deserialize_with_mode( - &mut limited, - compress, - validate, - )?); - } - TAG_STAGE6 => { - if stage6_sumcheck.is_some() { - return Err(SerializationError::InvalidData); - } - stage6_sumcheck = Some(SumcheckInstanceProof::deserialize_with_mode( - &mut limited, - compress, - validate, - )?); - } - TAG_STAGE7 => { - if stage7_sumcheck.is_some() { - return Err(SerializationError::InvalidData); - } - stage7_sumcheck = Some(SumcheckInstanceProof::deserialize_with_mode( - &mut limited, - compress, - validate, - )?); + read_singleton!(&mut limited, compress, validate, stage2_uni); + read_singleton!(&mut limited, compress, validate, stage2_sumcheck); } + TAG_STAGE3 => read_singleton!(&mut limited, compress, validate, stage3_sumcheck), + TAG_STAGE4 => read_singleton!(&mut limited, compress, validate, stage4_sumcheck), + TAG_STAGE5 => read_singleton!(&mut limited, compress, validate, stage5_sumcheck), + TAG_STAGE6 => read_singleton!(&mut limited, compress, validate, stage6_sumcheck), + TAG_STAGE7 => read_singleton!(&mut limited, compress, validate, stage7_sumcheck), TAG_JOINT_OPENING => { - if joint_opening_proof.is_some() { - return Err(SerializationError::InvalidData); - } - joint_opening_proof = Some(PCS::Proof::deserialize_with_mode( - &mut limited, - compress, - validate, - )?); + read_singleton!(&mut limited, compress, validate, joint_opening_proof) } _ => return Err(SerializationError::InvalidData), } @@ -706,7 +593,7 @@ impl CanonicalSerialize for OpeningId { }; header.serialize_with_mode(&mut writer, compress)?; if (header & 0x3F) == OPENING_ID_SUMCHECK_ESCAPE { - write_varint_u64(&mut writer, sumcheck_u64)?; + transport::write_varint_u64(&mut writer, sumcheck_u64).map_err(io_err)?; } if let Some(poly) = poly { match poly { @@ -755,7 +642,7 @@ impl CanonicalDeserialize for OpeningId { let small = header & 0x3F; let sumcheck_u64 = if small == OPENING_ID_SUMCHECK_ESCAPE { - read_varint_u64(&mut reader)? + transport::read_varint_u64(&mut reader).map_err(io_err)? } else { small as u64 }; @@ -924,47 +811,11 @@ impl CanonicalSerialize for VirtualPolynomial { fn serialized_size(&self, _compress: Compress) -> usize { match self { - Self::PC - | Self::UnexpandedPC - | Self::NextPC - | Self::NextUnexpandedPC - | Self::NextIsNoop - | Self::NextIsVirtual - | Self::NextIsFirstInSequence - | Self::LeftLookupOperand - | Self::RightLookupOperand - | Self::LeftInstructionInput - | Self::RightInstructionInput - | Self::Product - | Self::ShouldJump - | Self::ShouldBranch - | Self::WritePCtoRD - | Self::WriteLookupOutputToRD - | Self::Rd - | Self::Imm - | Self::Rs1Value - | Self::Rs2Value - | Self::RdWriteValue - | Self::Rs1Ra - | Self::Rs2Ra - | Self::RdWa - | Self::LookupOutput - | Self::InstructionRaf - | Self::InstructionRafFlag - | Self::RegistersVal - | Self::RamAddress - | Self::RamRa - | Self::RamReadValue - | Self::RamWriteValue - | Self::RamVal - | Self::RamValInit - | Self::RamValFinal - | Self::RamHammingWeight - | Self::UnivariateSkip => 1, Self::InstructionRa(_) | Self::OpFlags(_) | Self::InstructionFlags(_) | Self::LookupTableFlag(_) => 2, + _ => 1, } } } diff --git a/jolt-core/src/zkvm/transport.rs b/jolt-core/src/zkvm/transport.rs index 2fbe9d56c1..4a623ff08c 100644 --- a/jolt-core/src/zkvm/transport.rs +++ b/jolt-core/src/zkvm/transport.rs @@ -14,6 +14,13 @@ use std::io::{self, Read, Write}; /// Maximum number of bytes in a u64 varint (LEB128-style). const VARINT_U64_MAX_BYTES: usize = 10; +// Recursion bundle framing constants (shared between host and guest). +pub const BUNDLE_SIGNATURE: &[u8; 8] = b"JOLTBDL\0"; +pub const BUNDLE_TAG_PREPROCESSING: u8 = 1; +pub const BUNDLE_TAG_RECORD: u8 = 2; +pub const RECORD_TAG_DEVICE: u8 = 1; +pub const RECORD_TAG_PROOF: u8 = 2; + #[inline] pub fn signature_check(r: &mut R, expected: &[u8]) -> io::Result<()> { let mut got = vec![0u8; expected.len()]; From f3112b3fe6e621908d7b46865e2aa5a23da29d1f Mon Sep 17 00:00:00 2001 From: Quang Dao Date: Tue, 24 Feb 2026 09:18:08 -0800 Subject: [PATCH 4/9] =?UTF-8?q?fix(zkvm):=20harden=20proof=20format=20?= =?UTF-8?q?=E2=80=94=20version=20byte,=20per-tag=20caps,=20error=20context?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Embed format version (v1) in the proof signature so future format changes produce a clear mismatch instead of opaque InvalidData. - Restore per-tag section caps (params: 16 KiB, commitments/claims: 256 MiB, stages: 512 MiB) instead of a single 512 MiB cap. - Add descriptive error messages to all deserialization failure paths: duplicate sections, unknown tags, trailing bytes, missing fields. - Remove unused bundle framing constants from transport.rs (will be re-added when the recursion example PR lands). Co-authored-by: Cursor --- jolt-core/src/zkvm/proof_serialization.rs | 123 ++++++++++++++++------ jolt-core/src/zkvm/transport.rs | 6 -- 2 files changed, 91 insertions(+), 38 deletions(-) diff --git a/jolt-core/src/zkvm/proof_serialization.rs b/jolt-core/src/zkvm/proof_serialization.rs index c0c643edaf..25fb980439 100644 --- a/jolt-core/src/zkvm/proof_serialization.rs +++ b/jolt-core/src/zkvm/proof_serialization.rs @@ -29,10 +29,10 @@ use crate::{ use crate::zkvm::transport; -/// Stream signature for `JoltProof` bytes (clean rewrite). +/// Stream signature for `JoltProof` bytes. /// -/// This is a short fixed header to fail fast on wrong-format inputs. -const PROOF_SIGNATURE: &[u8; 8] = b"JOLTPRF\0"; +/// Last byte is the format version (bump when the wire format changes). +const PROOF_SIGNATURE: &[u8; 8] = b"JOLTPRF\x01"; // Frame tags for proof sections. Decoding is strict: unknown tags are rejected. const TAG_PARAMS: u8 = 1; @@ -47,7 +47,11 @@ const TAG_STAGE6: u8 = 15; const TAG_STAGE7: u8 = 16; const TAG_JOINT_OPENING: u8 = 20; -const MAX_SECTION_LEN: u64 = 512 * 1024 * 1024; +const MAX_PARAMS_LEN: u64 = 16 * 1024; +const MAX_COMMITMENTS_LEN: u64 = 1024 * 1024; +const MAX_OPENING_CLAIMS_LEN: u64 = 1024 * 1024; +const MAX_STAGE_LEN: u64 = 1024 * 1024; +const MAX_JOINT_OPENING_LEN: u64 = 1024 * 1024; pub struct JoltProof, FS: Transcript> { pub opening_claims: Claims, @@ -94,7 +98,10 @@ macro_rules! framed_section_size { macro_rules! read_singleton { ($r:expr, $c:expr, $v:expr, $field:expr) => {{ if $field.is_some() { - return Err(SerializationError::InvalidData); + return Err(io_err(std::io::Error::new( + std::io::ErrorKind::InvalidData, + concat!("duplicate field: ", stringify!($field)), + ))); } $field = Some(CanonicalDeserialize::deserialize_with_mode($r, $c, $v)?); }}; @@ -326,14 +333,37 @@ impl, FS: Transcript> CanonicalDe let mut joint_opening_proof: Option = None; while let Some((tag, len)) = - transport::read_frame_header(&mut reader, MAX_SECTION_LEN).map_err(io_err)? + transport::read_frame_header(&mut reader, MAX_STAGE_LEN).map_err(io_err)? { + let cap = match tag { + TAG_PARAMS => MAX_PARAMS_LEN, + TAG_COMMITMENTS => MAX_COMMITMENTS_LEN, + TAG_OPENING_CLAIMS => MAX_OPENING_CLAIMS_LEN, + TAG_STAGE1 | TAG_STAGE2 | TAG_STAGE3 | TAG_STAGE4 | TAG_STAGE5 | TAG_STAGE6 + | TAG_STAGE7 => MAX_STAGE_LEN, + TAG_JOINT_OPENING => MAX_JOINT_OPENING_LEN, + _ => { + return Err(io_err(std::io::Error::new( + std::io::ErrorKind::InvalidData, + format!("unknown proof section tag {tag}"), + ))); + } + }; + if len > cap { + return Err(io_err(std::io::Error::new( + std::io::ErrorKind::InvalidData, + format!("section tag {tag} payload {len} exceeds cap {cap}"), + ))); + } let mut limited = (&mut reader).take(len); match tag { TAG_PARAMS => { if trace_length.is_some() { - return Err(SerializationError::InvalidData); + return Err(io_err(std::io::Error::new( + std::io::ErrorKind::InvalidData, + "duplicate params section", + ))); } let t = transport::read_varint_u64(&mut limited).map_err(io_err)?; let r = transport::read_varint_u64(&mut limited).map_err(io_err)?; @@ -361,7 +391,10 @@ impl, FS: Transcript> CanonicalDe } TAG_COMMITMENTS => { if commitments.is_some() { - return Err(SerializationError::InvalidData); + return Err(io_err(std::io::Error::new( + std::io::ErrorKind::InvalidData, + "duplicate commitments section", + ))); } let n = transport::read_varint_u64(&mut limited).map_err(io_err)?; let n_usize = @@ -392,7 +425,10 @@ impl, FS: Transcript> CanonicalDe } TAG_OPENING_CLAIMS => { if opening_claims.is_some() { - return Err(SerializationError::InvalidData); + return Err(io_err(std::io::Error::new( + std::io::ErrorKind::InvalidData, + "duplicate opening claims section", + ))); } let n = transport::read_varint_u64(&mut limited).map_err(io_err)?; let n_usize = @@ -425,35 +461,54 @@ impl, FS: Transcript> CanonicalDe TAG_JOINT_OPENING => { read_singleton!(&mut limited, compress, validate, joint_opening_proof) } - _ => return Err(SerializationError::InvalidData), + _ => unreachable!("unknown tags rejected above"), } if limited.limit() != 0 { - return Err(SerializationError::InvalidData); + return Err(io_err(std::io::Error::new( + std::io::ErrorKind::InvalidData, + format!( + "frame tag {tag}: {} trailing bytes not consumed", + limited.limit() + ), + ))); } } + macro_rules! require { + ($opt:expr, $name:expr) => { + $opt.ok_or_else(|| { + io_err(std::io::Error::new( + std::io::ErrorKind::InvalidData, + concat!("missing required section: ", $name), + )) + })? + }; + } + Ok(Self { - opening_claims: opening_claims.ok_or(SerializationError::InvalidData)?, - commitments: commitments.ok_or(SerializationError::InvalidData)?, - stage1_uni_skip_first_round_proof: stage1_uni.ok_or(SerializationError::InvalidData)?, - stage1_sumcheck_proof: stage1_sumcheck.ok_or(SerializationError::InvalidData)?, - stage2_uni_skip_first_round_proof: stage2_uni.ok_or(SerializationError::InvalidData)?, - stage2_sumcheck_proof: stage2_sumcheck.ok_or(SerializationError::InvalidData)?, - stage3_sumcheck_proof: stage3_sumcheck.ok_or(SerializationError::InvalidData)?, - stage4_sumcheck_proof: stage4_sumcheck.ok_or(SerializationError::InvalidData)?, - stage5_sumcheck_proof: stage5_sumcheck.ok_or(SerializationError::InvalidData)?, - stage6_sumcheck_proof: stage6_sumcheck.ok_or(SerializationError::InvalidData)?, - stage7_sumcheck_proof: stage7_sumcheck.ok_or(SerializationError::InvalidData)?, - joint_opening_proof: joint_opening_proof.ok_or(SerializationError::InvalidData)?, - untrusted_advice_commitment: untrusted_advice_commitment - .ok_or(SerializationError::InvalidData)?, - trace_length: trace_length.ok_or(SerializationError::InvalidData)?, - ram_K: ram_K.ok_or(SerializationError::InvalidData)?, - bytecode_K: bytecode_K.ok_or(SerializationError::InvalidData)?, - rw_config: rw_config.ok_or(SerializationError::InvalidData)?, - one_hot_config: one_hot_config.ok_or(SerializationError::InvalidData)?, - dory_layout: dory_layout.ok_or(SerializationError::InvalidData)?, + opening_claims: require!(opening_claims, "opening_claims"), + commitments: require!(commitments, "commitments"), + stage1_uni_skip_first_round_proof: require!(stage1_uni, "stage1_uni"), + stage1_sumcheck_proof: require!(stage1_sumcheck, "stage1_sumcheck"), + stage2_uni_skip_first_round_proof: require!(stage2_uni, "stage2_uni"), + stage2_sumcheck_proof: require!(stage2_sumcheck, "stage2_sumcheck"), + stage3_sumcheck_proof: require!(stage3_sumcheck, "stage3_sumcheck"), + stage4_sumcheck_proof: require!(stage4_sumcheck, "stage4_sumcheck"), + stage5_sumcheck_proof: require!(stage5_sumcheck, "stage5_sumcheck"), + stage6_sumcheck_proof: require!(stage6_sumcheck, "stage6_sumcheck"), + stage7_sumcheck_proof: require!(stage7_sumcheck, "stage7_sumcheck"), + joint_opening_proof: require!(joint_opening_proof, "joint_opening"), + untrusted_advice_commitment: require!( + untrusted_advice_commitment, + "untrusted_advice_commitment" + ), + trace_length: require!(trace_length, "params"), + ram_K: require!(ram_K, "params"), + bytecode_K: require!(bytecode_K, "params"), + rw_config: require!(rw_config, "params"), + one_hot_config: require!(one_hot_config, "params"), + dory_layout: require!(dory_layout, "params"), }) } } @@ -970,9 +1025,13 @@ mod tests { } } + #[test] + fn proof_signature_version_byte() { + assert_eq!(PROOF_SIGNATURE[7], 1, "format version should be 1"); + } + #[test] fn proof_signature_is_required_and_unknown_tags_reject() { - // Missing sections should reject cleanly (after signature). let mut just_sig = Vec::new(); just_sig.extend_from_slice(PROOF_SIGNATURE); let res = crate::zkvm::RV64IMACProof::deserialize_with_mode( diff --git a/jolt-core/src/zkvm/transport.rs b/jolt-core/src/zkvm/transport.rs index 4a623ff08c..b148c20ccb 100644 --- a/jolt-core/src/zkvm/transport.rs +++ b/jolt-core/src/zkvm/transport.rs @@ -14,12 +14,6 @@ use std::io::{self, Read, Write}; /// Maximum number of bytes in a u64 varint (LEB128-style). const VARINT_U64_MAX_BYTES: usize = 10; -// Recursion bundle framing constants (shared between host and guest). -pub const BUNDLE_SIGNATURE: &[u8; 8] = b"JOLTBDL\0"; -pub const BUNDLE_TAG_PREPROCESSING: u8 = 1; -pub const BUNDLE_TAG_RECORD: u8 = 2; -pub const RECORD_TAG_DEVICE: u8 = 1; -pub const RECORD_TAG_PROOF: u8 = 2; #[inline] pub fn signature_check(r: &mut R, expected: &[u8]) -> io::Result<()> { From 0f4ace80ad0e8ffc22801b6e30d9c6ad3c22d436 Mon Sep 17 00:00:00 2001 From: Quang Dao Date: Wed, 4 Mar 2026 10:16:05 -0800 Subject: [PATCH 5/9] =?UTF-8?q?refactor(zkvm):=20simplify=20proof=20format?= =?UTF-8?q?=20=E2=80=94=20sequential=20sections,=20tighter=20caps,=20harde?= =?UTF-8?q?ned=20deserialization?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Replace TLV tag-dispatch with sequential length-prefixed format: no tags, no Option temporaries, no require! macro. Extract shared payload-length helpers to eliminate serialize/serialized_size duplication. Restore exhaustive VirtualPolynomial::serialized_size() match. Harden deserialization: reject trace_length=0, reject duplicate opening-claim keys, tighten all section caps to 128 KiB (params 1 KiB), entry count cap 10K. Separate magic from version byte for specific "unsupported version" errors. Simplify transport.rs: remove dead code (skip_exact, read_u8_opt, frame headers), add read_section helper, add unit tests. Remove leaked SDK re-export of transport module. Made-with: Cursor --- jolt-core/src/zkvm/proof_serialization.rs | 679 ++++++++++------------ jolt-core/src/zkvm/transport.rs | 156 +++-- jolt-sdk/src/host_utils.rs | 1 - 3 files changed, 407 insertions(+), 429 deletions(-) diff --git a/jolt-core/src/zkvm/proof_serialization.rs b/jolt-core/src/zkvm/proof_serialization.rs index 23c455553a..de132e2a2c 100644 --- a/jolt-core/src/zkvm/proof_serialization.rs +++ b/jolt-core/src/zkvm/proof_serialization.rs @@ -1,5 +1,6 @@ #[cfg(not(feature = "zk"))] use std::collections::BTreeMap; +use std::fs::File; use std::io::{Read, Write}; use ark_serialize::{ @@ -33,31 +34,11 @@ use crate::{ use crate::zkvm::transport; -const PROOF_SIGNATURE: &[u8; 8] = b"JOLTPRF\x01"; +const PROOF_MAGIC: &[u8; 7] = b"JOLTPRF"; +const PROOF_VERSION: u8 = 1; -const TAG_PARAMS: u8 = 1; -const TAG_COMMITMENTS: u8 = 2; -#[cfg(not(feature = "zk"))] -const TAG_OPENING_CLAIMS: u8 = 3; -const TAG_STAGE1: u8 = 10; -const TAG_STAGE2: u8 = 11; -const TAG_STAGE3: u8 = 12; -const TAG_STAGE4: u8 = 13; -const TAG_STAGE5: u8 = 14; -const TAG_STAGE6: u8 = 15; -const TAG_STAGE7: u8 = 16; -const TAG_JOINT_OPENING: u8 = 20; -#[cfg(feature = "zk")] -const TAG_BLINDFOLD: u8 = 21; - -const MAX_PARAMS_LEN: u64 = 16 * 1024; -const MAX_COMMITMENTS_LEN: u64 = 1024 * 1024; -#[cfg(not(feature = "zk"))] -const MAX_OPENING_CLAIMS_LEN: u64 = 1024 * 1024; -const MAX_STAGE_LEN: u64 = 1024 * 1024; -const MAX_JOINT_OPENING_LEN: u64 = 1024 * 1024; -#[cfg(feature = "zk")] -const MAX_BLINDFOLD_LEN: u64 = 16 * 1024 * 1024; +const MAX_PARAMS_LEN: u64 = 1024; +const MAX_SECTION_LEN: u64 = 128 * 1024; pub struct JoltProof, FS: Transcript> { pub commitments: Vec, @@ -88,31 +69,70 @@ fn io_err(e: std::io::Error) -> SerializationError { SerializationError::IoError(e) } -macro_rules! write_framed_section { - ($w:expr, $c:expr, $tag:expr, $($item:expr),+ $(,)?) => {{ +macro_rules! write_section { + ($w:expr, $c:expr, $($item:expr),+ $(,)?) => {{ let len: u64 = 0 $(+ $item.serialized_size($c) as u64)+; - transport::write_frame_header($w, $tag, len).map_err(io_err)?; + transport::write_varint_u64($w, len).map_err(io_err)?; $($item.serialize_with_mode($w, $c)?;)+ }}; } -macro_rules! framed_section_size { +macro_rules! section_size { ($c:expr, $($item:expr),+ $(,)?) => {{ let payload: u64 = 0 $(+ $item.serialized_size($c) as u64)+; - 1 + transport::varint_u64_len(payload) + payload as usize + transport::varint_u64_len(payload) + payload as usize }}; } -macro_rules! read_singleton { - ($r:expr, $c:expr, $v:expr, $field:expr) => {{ - if $field.is_some() { - return Err(io_err(std::io::Error::new( - std::io::ErrorKind::InvalidData, - concat!("duplicate field: ", stringify!($field)), - ))); - } - $field = Some(CanonicalDeserialize::deserialize_with_mode($r, $c, $v)?); - }}; +fn check_trailing_bytes(limited: &std::io::Take<&mut impl Read>) -> Result<(), SerializationError> { + if limited.limit() != 0 { + return Err(io_err(std::io::Error::new( + std::io::ErrorKind::InvalidData, + format!("{} trailing bytes not consumed", limited.limit()), + ))); + } + Ok(()) +} + +impl, FS: Transcript> + JoltProof +{ + fn params_payload_len(&self, compress: Compress) -> u64 { + (transport::varint_u64_len(self.trace_length as u64) + + transport::varint_u64_len(self.ram_K as u64)) as u64 + + self.rw_config.serialized_size(compress) as u64 + + self.one_hot_config.serialized_size(compress) as u64 + + self.dory_layout.serialized_size(compress) as u64 + } + + fn commitments_payload_len(&self, compress: Compress) -> u64 { + let count_len = transport::varint_u64_len(self.commitments.len() as u64) as u64; + let items_len: u64 = self + .commitments + .iter() + .map(|c| c.serialized_size(compress) as u64) + .sum(); + let untrusted_len = self + .untrusted_advice_commitment + .as_ref() + .map(|c| c.serialized_size(compress) as u64) + .unwrap_or(0); + count_len + items_len + 1 + untrusted_len + } + + #[cfg(not(feature = "zk"))] + fn claims_payload_len(&self, compress: Compress) -> u64 { + let count_len = transport::varint_u64_len(self.opening_claims.0.len() as u64) as u64; + let items_len: u64 = self + .opening_claims + .0 + .iter() + .map(|(k, (_p, claim))| { + (k.serialized_size(compress) + claim.serialized_size(compress)) as u64 + }) + .sum(); + count_len + items_len + } } impl, FS: Transcript> @@ -123,14 +143,10 @@ impl, FS: Transcrip mut writer: W, compress: Compress, ) -> Result<(), SerializationError> { - transport::signature_write(&mut writer, PROOF_SIGNATURE).map_err(io_err)?; + transport::write_magic_version(&mut writer, PROOF_MAGIC, PROOF_VERSION).map_err(io_err)?; - let params_len = (transport::varint_u64_len(self.trace_length as u64) - + transport::varint_u64_len(self.ram_K as u64)) as u64 - + self.rw_config.serialized_size(compress) as u64 - + self.one_hot_config.serialized_size(compress) as u64 - + self.dory_layout.serialized_size(compress) as u64; - transport::write_frame_header(&mut writer, TAG_PARAMS, params_len).map_err(io_err)?; + let params_len = self.params_payload_len(compress); + transport::write_varint_u64(&mut writer, params_len).map_err(io_err)?; transport::write_varint_u64(&mut writer, self.trace_length as u64).map_err(io_err)?; transport::write_varint_u64(&mut writer, self.ram_K as u64).map_err(io_err)?; self.rw_config.serialize_with_mode(&mut writer, compress)?; @@ -139,21 +155,8 @@ impl, FS: Transcrip self.dory_layout .serialize_with_mode(&mut writer, compress)?; - let commitments_count_len = transport::varint_u64_len(self.commitments.len() as u64) as u64; - let commitments_items_len: u64 = self - .commitments - .iter() - .map(|c| c.serialized_size(compress) as u64) - .sum(); - let untrusted_commitment_len = self - .untrusted_advice_commitment - .as_ref() - .map(|c| c.serialized_size(compress) as u64) - .unwrap_or(0); - let commitments_len = - commitments_count_len + commitments_items_len + 1 + untrusted_commitment_len; - transport::write_frame_header(&mut writer, TAG_COMMITMENTS, commitments_len) - .map_err(io_err)?; + let commitments_len = self.commitments_payload_len(compress); + transport::write_varint_u64(&mut writer, commitments_len).map_err(io_err)?; transport::write_varint_u64(&mut writer, self.commitments.len() as u64).map_err(io_err)?; for c in &self.commitments { c.serialize_with_mode(&mut writer, compress)?; @@ -168,19 +171,8 @@ impl, FS: Transcrip #[cfg(not(feature = "zk"))] { - let claims_count_len = - transport::varint_u64_len(self.opening_claims.0.len() as u64) as u64; - let claims_items_len: u64 = self - .opening_claims - .0 - .iter() - .map(|(k, (_p, claim))| { - (k.serialized_size(compress) + claim.serialized_size(compress)) as u64 - }) - .sum(); - let claims_len = claims_count_len + claims_items_len; - transport::write_frame_header(&mut writer, TAG_OPENING_CLAIMS, claims_len) - .map_err(io_err)?; + let claims_len = self.claims_payload_len(compress); + transport::write_varint_u64(&mut writer, claims_len).map_err(io_err)?; transport::write_varint_u64(&mut writer, self.opening_claims.0.len() as u64) .map_err(io_err)?; for (k, (_p, claim)) in self.opening_claims.0.iter() { @@ -189,124 +181,66 @@ impl, FS: Transcrip } } - write_framed_section!( + write_section!( &mut writer, compress, - TAG_STAGE1, &self.stage1_uni_skip_first_round_proof, &self.stage1_sumcheck_proof ); - write_framed_section!( + write_section!( &mut writer, compress, - TAG_STAGE2, &self.stage2_uni_skip_first_round_proof, &self.stage2_sumcheck_proof ); - write_framed_section!( - &mut writer, - compress, - TAG_STAGE3, - &self.stage3_sumcheck_proof - ); - write_framed_section!( - &mut writer, - compress, - TAG_STAGE4, - &self.stage4_sumcheck_proof - ); - write_framed_section!( - &mut writer, - compress, - TAG_STAGE5, - &self.stage5_sumcheck_proof - ); - write_framed_section!( - &mut writer, - compress, - TAG_STAGE6, - &self.stage6_sumcheck_proof - ); - write_framed_section!( - &mut writer, - compress, - TAG_STAGE7, - &self.stage7_sumcheck_proof - ); - write_framed_section!( - &mut writer, - compress, - TAG_JOINT_OPENING, - &self.joint_opening_proof - ); + write_section!(&mut writer, compress, &self.stage3_sumcheck_proof); + write_section!(&mut writer, compress, &self.stage4_sumcheck_proof); + write_section!(&mut writer, compress, &self.stage5_sumcheck_proof); + write_section!(&mut writer, compress, &self.stage6_sumcheck_proof); + write_section!(&mut writer, compress, &self.stage7_sumcheck_proof); + write_section!(&mut writer, compress, &self.joint_opening_proof); #[cfg(feature = "zk")] - write_framed_section!(&mut writer, compress, TAG_BLINDFOLD, &self.blindfold_proof); + write_section!(&mut writer, compress, &self.blindfold_proof); Ok(()) } fn serialized_size(&self, compress: Compress) -> usize { - let mut size = PROOF_SIGNATURE.len(); + let mut size = PROOF_MAGIC.len() + 1; // magic + version byte - let params_len = (transport::varint_u64_len(self.trace_length as u64) - + transport::varint_u64_len(self.ram_K as u64)) as u64 - + self.rw_config.serialized_size(compress) as u64 - + self.one_hot_config.serialized_size(compress) as u64 - + self.dory_layout.serialized_size(compress) as u64; - size += 1 + transport::varint_u64_len(params_len) + params_len as usize; + let params_len = self.params_payload_len(compress); + size += transport::varint_u64_len(params_len) + params_len as usize; - let commitments_count_len = transport::varint_u64_len(self.commitments.len() as u64) as u64; - let commitments_items_len: u64 = self - .commitments - .iter() - .map(|c| c.serialized_size(compress) as u64) - .sum(); - let untrusted_commitment_len = self - .untrusted_advice_commitment - .as_ref() - .map(|c| c.serialized_size(compress) as u64) - .unwrap_or(0); - let commitments_len = - commitments_count_len + commitments_items_len + 1 + untrusted_commitment_len; - size += 1 + transport::varint_u64_len(commitments_len) + commitments_len as usize; + let commitments_len = self.commitments_payload_len(compress); + size += transport::varint_u64_len(commitments_len) + commitments_len as usize; #[cfg(not(feature = "zk"))] { - let claims_count_len = - transport::varint_u64_len(self.opening_claims.0.len() as u64) as u64; - let claims_items_len: u64 = self - .opening_claims - .0 - .iter() - .map(|(k, (_p, claim))| { - (k.serialized_size(compress) + claim.serialized_size(compress)) as u64 - }) - .sum(); - let claims_len = claims_count_len + claims_items_len; - size += 1 + transport::varint_u64_len(claims_len) + claims_len as usize; + let claims_len = self.claims_payload_len(compress); + size += transport::varint_u64_len(claims_len) + claims_len as usize; } - size += framed_section_size!( + size += section_size!( compress, &self.stage1_uni_skip_first_round_proof, &self.stage1_sumcheck_proof ); - size += framed_section_size!( + size += section_size!( compress, &self.stage2_uni_skip_first_round_proof, &self.stage2_sumcheck_proof ); - size += framed_section_size!(compress, &self.stage3_sumcheck_proof); - size += framed_section_size!(compress, &self.stage4_sumcheck_proof); - size += framed_section_size!(compress, &self.stage5_sumcheck_proof); - size += framed_section_size!(compress, &self.stage6_sumcheck_proof); - size += framed_section_size!(compress, &self.stage7_sumcheck_proof); - size += framed_section_size!(compress, &self.joint_opening_proof); + size += section_size!(compress, &self.stage3_sumcheck_proof); + size += section_size!(compress, &self.stage4_sumcheck_proof); + size += section_size!(compress, &self.stage5_sumcheck_proof); + size += section_size!(compress, &self.stage6_sumcheck_proof); + size += section_size!(compress, &self.stage7_sumcheck_proof); + size += section_size!(compress, &self.joint_opening_proof); #[cfg(feature = "zk")] { - size += framed_section_size!(compress, &self.blindfold_proof); + size += section_size!(compress, &self.blindfold_proof); } size @@ -329,223 +263,147 @@ impl, FS: Transcrip compress: Compress, validate: Validate, ) -> Result { - transport::signature_check(&mut reader, PROOF_SIGNATURE).map_err(io_err)?; + let version = transport::read_magic_version(&mut reader, PROOF_MAGIC).map_err(io_err)?; + if version != PROOF_VERSION { + return Err(io_err(std::io::Error::new( + std::io::ErrorKind::InvalidData, + format!("unsupported proof version {version}, expected {PROOF_VERSION}"), + ))); + } - let mut trace_length: Option = None; - let mut ram_K: Option = None; - let mut rw_config: Option = None; - let mut one_hot_config: Option = None; - let mut dory_layout: Option = None; + // Params + let mut limited = transport::read_section(&mut reader, MAX_PARAMS_LEN).map_err(io_err)?; + let t = transport::read_varint_u64(&mut limited).map_err(io_err)?; + let r = transport::read_varint_u64(&mut limited).map_err(io_err)?; + let trace_length = usize::try_from(t).map_err(|_| SerializationError::InvalidData)?; + if trace_length == 0 { + return Err(io_err(std::io::Error::new( + std::io::ErrorKind::InvalidData, + "trace_length must be nonzero", + ))); + } + let ram_K = usize::try_from(r).map_err(|_| SerializationError::InvalidData)?; + let rw_config = ReadWriteConfig::deserialize_with_mode(&mut limited, compress, validate)?; + let one_hot_config = OneHotConfig::deserialize_with_mode(&mut limited, compress, validate)?; + let dory_layout = DoryLayout::deserialize_with_mode(&mut limited, compress, validate)?; + check_trailing_bytes(&limited)?; + + // Commitments + let mut limited = transport::read_section(&mut reader, MAX_SECTION_LEN).map_err(io_err)?; + let n = transport::read_varint_u64(&mut limited).map_err(io_err)?; + let n_usize = usize::try_from(n).map_err(|_| SerializationError::InvalidData)?; + if n_usize > 10_000 { + return Err(SerializationError::InvalidData); + } + let mut commitments = Vec::with_capacity(n_usize); + for _ in 0..n_usize { + commitments.push(PCS::Commitment::deserialize_with_mode( + &mut limited, + compress, + validate, + )?); + } + let presence = u8::deserialize_with_mode(&mut limited, compress, validate)?; + let untrusted_advice_commitment = match presence { + 0 => None, + 1 => Some(PCS::Commitment::deserialize_with_mode( + &mut limited, + compress, + validate, + )?), + _ => return Err(SerializationError::InvalidData), + }; + check_trailing_bytes(&limited)?; - let mut commitments: Option> = None; - let mut untrusted_advice_commitment: Option> = None; + // Opening claims (non-ZK only) #[cfg(not(feature = "zk"))] - let mut opening_claims: Option> = None; - - let mut stage1_uni: Option> = None; - let mut stage1_sumcheck: Option> = None; - let mut stage2_uni: Option> = None; - let mut stage2_sumcheck: Option> = None; - let mut stage3_sumcheck: Option> = None; - let mut stage4_sumcheck: Option> = None; - let mut stage5_sumcheck: Option> = None; - let mut stage6_sumcheck: Option> = None; - let mut stage7_sumcheck: Option> = None; - let mut joint_opening_proof: Option = None; - #[cfg(feature = "zk")] - let mut blindfold_proof: Option> = None; - - while let Some((tag, len)) = - transport::read_frame_header(&mut reader, MAX_BLINDFOLD_LEN_VAL).map_err(io_err)? - { - let cap = match tag { - TAG_PARAMS => MAX_PARAMS_LEN, - TAG_COMMITMENTS => MAX_COMMITMENTS_LEN, - #[cfg(not(feature = "zk"))] - TAG_OPENING_CLAIMS => MAX_OPENING_CLAIMS_LEN, - TAG_STAGE1 | TAG_STAGE2 | TAG_STAGE3 | TAG_STAGE4 | TAG_STAGE5 | TAG_STAGE6 - | TAG_STAGE7 => MAX_STAGE_LEN, - TAG_JOINT_OPENING => MAX_JOINT_OPENING_LEN, - #[cfg(feature = "zk")] - TAG_BLINDFOLD => MAX_BLINDFOLD_LEN, - _ => { - return Err(io_err(std::io::Error::new( - std::io::ErrorKind::InvalidData, - format!("unknown proof section tag {tag}"), - ))); - } - }; - if len > cap { - return Err(io_err(std::io::Error::new( - std::io::ErrorKind::InvalidData, - format!("section tag {tag} payload {len} exceeds cap {cap}"), - ))); + let opening_claims = { + let mut limited = + transport::read_section(&mut reader, MAX_SECTION_LEN).map_err(io_err)?; + let n = transport::read_varint_u64(&mut limited).map_err(io_err)?; + let n_usize = usize::try_from(n).map_err(|_| SerializationError::InvalidData)?; + if n_usize > 10_000 { + return Err(SerializationError::InvalidData); } - let mut limited = (&mut reader).take(len); - - match tag { - TAG_PARAMS => { - if trace_length.is_some() { - return Err(io_err(std::io::Error::new( - std::io::ErrorKind::InvalidData, - "duplicate params section", - ))); - } - let t = transport::read_varint_u64(&mut limited).map_err(io_err)?; - let r = transport::read_varint_u64(&mut limited).map_err(io_err)?; - trace_length = - Some(usize::try_from(t).map_err(|_| SerializationError::InvalidData)?); - ram_K = Some(usize::try_from(r).map_err(|_| SerializationError::InvalidData)?); - rw_config = Some(ReadWriteConfig::deserialize_with_mode( - &mut limited, - compress, - validate, - )?); - one_hot_config = Some(OneHotConfig::deserialize_with_mode( - &mut limited, - compress, - validate, - )?); - dory_layout = Some(DoryLayout::deserialize_with_mode( - &mut limited, - compress, - validate, - )?); - } - TAG_COMMITMENTS => { - if commitments.is_some() { - return Err(io_err(std::io::Error::new( - std::io::ErrorKind::InvalidData, - "duplicate commitments section", - ))); - } - let n = transport::read_varint_u64(&mut limited).map_err(io_err)?; - let n_usize = - usize::try_from(n).map_err(|_| SerializationError::InvalidData)?; - if n_usize > 1_000_000 { - return Err(SerializationError::InvalidData); - } - let mut v = Vec::with_capacity(n_usize.min(1024)); - for _ in 0..n_usize { - v.push(PCS::Commitment::deserialize_with_mode( - &mut limited, - compress, - validate, - )?); - } - let presence = u8::deserialize_with_mode(&mut limited, compress, validate)?; - let opt = match presence { - 0 => None, - 1 => Some(PCS::Commitment::deserialize_with_mode( - &mut limited, - compress, - validate, - )?), - _ => return Err(SerializationError::InvalidData), - }; - commitments = Some(v); - untrusted_advice_commitment = Some(opt); - } - #[cfg(not(feature = "zk"))] - TAG_OPENING_CLAIMS => { - if opening_claims.is_some() { - return Err(io_err(std::io::Error::new( - std::io::ErrorKind::InvalidData, - "duplicate opening claims section", - ))); - } - let n = transport::read_varint_u64(&mut limited).map_err(io_err)?; - let n_usize = - usize::try_from(n).map_err(|_| SerializationError::InvalidData)?; - if n_usize > 10_000_000 { - return Err(SerializationError::InvalidData); - } - let mut claims = BTreeMap::new(); - for _ in 0..n_usize { - let key = - OpeningId::deserialize_with_mode(&mut limited, compress, validate)?; - let claim = F::deserialize_with_mode(&mut limited, compress, validate)?; - claims.insert(key, (OpeningPoint::default(), claim)); - } - opening_claims = Some(Claims(claims)); - } - TAG_STAGE1 => { - read_singleton!(&mut limited, compress, validate, stage1_uni); - read_singleton!(&mut limited, compress, validate, stage1_sumcheck); - } - TAG_STAGE2 => { - read_singleton!(&mut limited, compress, validate, stage2_uni); - read_singleton!(&mut limited, compress, validate, stage2_sumcheck); - } - TAG_STAGE3 => read_singleton!(&mut limited, compress, validate, stage3_sumcheck), - TAG_STAGE4 => read_singleton!(&mut limited, compress, validate, stage4_sumcheck), - TAG_STAGE5 => read_singleton!(&mut limited, compress, validate, stage5_sumcheck), - TAG_STAGE6 => read_singleton!(&mut limited, compress, validate, stage6_sumcheck), - TAG_STAGE7 => read_singleton!(&mut limited, compress, validate, stage7_sumcheck), - TAG_JOINT_OPENING => { - read_singleton!(&mut limited, compress, validate, joint_opening_proof) + let mut claims = BTreeMap::new(); + for _ in 0..n_usize { + let key = OpeningId::deserialize_with_mode(&mut limited, compress, validate)?; + let claim = F::deserialize_with_mode(&mut limited, compress, validate)?; + if claims + .insert(key, (OpeningPoint::default(), claim)) + .is_some() + { + return Err(SerializationError::InvalidData); } - #[cfg(feature = "zk")] - TAG_BLINDFOLD => { - read_singleton!(&mut limited, compress, validate, blindfold_proof) - } - _ => unreachable!("unknown tags rejected above"), } + check_trailing_bytes(&limited)?; + Claims(claims) + }; - if limited.limit() != 0 { - return Err(io_err(std::io::Error::new( - std::io::ErrorKind::InvalidData, - format!( - "frame tag {tag}: {} trailing bytes not consumed", - limited.limit() - ), - ))); - } + // Stage 1 + let mut limited = transport::read_section(&mut reader, MAX_SECTION_LEN).map_err(io_err)?; + let stage1_uni_skip_first_round_proof = + CanonicalDeserialize::deserialize_with_mode(&mut limited, compress, validate)?; + let stage1_sumcheck_proof = + CanonicalDeserialize::deserialize_with_mode(&mut limited, compress, validate)?; + check_trailing_bytes(&limited)?; + + // Stage 2 + let mut limited = transport::read_section(&mut reader, MAX_SECTION_LEN).map_err(io_err)?; + let stage2_uni_skip_first_round_proof = + CanonicalDeserialize::deserialize_with_mode(&mut limited, compress, validate)?; + let stage2_sumcheck_proof = + CanonicalDeserialize::deserialize_with_mode(&mut limited, compress, validate)?; + check_trailing_bytes(&limited)?; + + // Stages 3-7 + macro_rules! read_single_section { + ($reader:expr) => {{ + let mut limited = + transport::read_section($reader, MAX_SECTION_LEN).map_err(io_err)?; + let val = + CanonicalDeserialize::deserialize_with_mode(&mut limited, compress, validate)?; + check_trailing_bytes(&limited)?; + val + }}; } - macro_rules! require { - ($opt:expr, $name:expr) => { - $opt.ok_or_else(|| { - io_err(std::io::Error::new( - std::io::ErrorKind::InvalidData, - concat!("missing required section: ", $name), - )) - })? - }; - } + let stage3_sumcheck_proof = read_single_section!(&mut reader); + let stage4_sumcheck_proof = read_single_section!(&mut reader); + let stage5_sumcheck_proof = read_single_section!(&mut reader); + let stage6_sumcheck_proof = read_single_section!(&mut reader); + let stage7_sumcheck_proof = read_single_section!(&mut reader); + let joint_opening_proof = read_single_section!(&mut reader); + + #[cfg(feature = "zk")] + let blindfold_proof = read_single_section!(&mut reader); Ok(Self { - commitments: require!(commitments, "commitments"), - stage1_uni_skip_first_round_proof: require!(stage1_uni, "stage1_uni"), - stage1_sumcheck_proof: require!(stage1_sumcheck, "stage1_sumcheck"), - stage2_uni_skip_first_round_proof: require!(stage2_uni, "stage2_uni"), - stage2_sumcheck_proof: require!(stage2_sumcheck, "stage2_sumcheck"), - stage3_sumcheck_proof: require!(stage3_sumcheck, "stage3_sumcheck"), - stage4_sumcheck_proof: require!(stage4_sumcheck, "stage4_sumcheck"), - stage5_sumcheck_proof: require!(stage5_sumcheck, "stage5_sumcheck"), - stage6_sumcheck_proof: require!(stage6_sumcheck, "stage6_sumcheck"), - stage7_sumcheck_proof: require!(stage7_sumcheck, "stage7_sumcheck"), + commitments, + stage1_uni_skip_first_round_proof, + stage1_sumcheck_proof, + stage2_uni_skip_first_round_proof, + stage2_sumcheck_proof, + stage3_sumcheck_proof, + stage4_sumcheck_proof, + stage5_sumcheck_proof, + stage6_sumcheck_proof, + stage7_sumcheck_proof, #[cfg(feature = "zk")] - blindfold_proof: require!(blindfold_proof, "blindfold"), - joint_opening_proof: require!(joint_opening_proof, "joint_opening"), - untrusted_advice_commitment: require!( - untrusted_advice_commitment, - "untrusted_advice_commitment" - ), + blindfold_proof, + joint_opening_proof, + untrusted_advice_commitment, #[cfg(not(feature = "zk"))] - opening_claims: require!(opening_claims, "opening_claims"), - trace_length: require!(trace_length, "params"), - ram_K: require!(ram_K, "params"), - rw_config: require!(rw_config, "params"), - one_hot_config: require!(one_hot_config, "params"), - dory_layout: require!(dory_layout, "params"), + opening_claims, + trace_length, + ram_K, + rw_config, + one_hot_config, + dory_layout, }) } } -/// Inline constant so both cfg branches can reference it in `read_frame_header`. -const MAX_BLINDFOLD_LEN_VAL: u64 = 16 * 1024 * 1024; - #[cfg(not(feature = "zk"))] pub struct Claims(pub Openings); @@ -593,7 +451,12 @@ impl CanonicalDeserialize for Claims { for _ in 0..size { let key = OpeningId::deserialize_with_mode(&mut reader, compress, validate)?; let claim = F::deserialize_with_mode(&mut reader, compress, validate)?; - claims.insert(key, (OpeningPoint::default(), claim)); + if claims + .insert(key, (OpeningPoint::default(), claim)) + .is_some() + { + return Err(SerializationError::InvalidData); + } } Ok(Claims(claims)) @@ -895,11 +758,47 @@ impl CanonicalSerialize for VirtualPolynomial { fn serialized_size(&self, _compress: Compress) -> usize { match self { + Self::PC + | Self::UnexpandedPC + | Self::NextPC + | Self::NextUnexpandedPC + | Self::NextIsNoop + | Self::NextIsVirtual + | Self::NextIsFirstInSequence + | Self::LeftLookupOperand + | Self::RightLookupOperand + | Self::LeftInstructionInput + | Self::RightInstructionInput + | Self::Product + | Self::ShouldJump + | Self::ShouldBranch + | Self::WritePCtoRD + | Self::WriteLookupOutputToRD + | Self::Rd + | Self::Imm + | Self::Rs1Value + | Self::Rs2Value + | Self::RdWriteValue + | Self::Rs1Ra + | Self::Rs2Ra + | Self::RdWa + | Self::LookupOutput + | Self::InstructionRaf + | Self::InstructionRafFlag + | Self::RegistersVal + | Self::RamAddress + | Self::RamRa + | Self::RamReadValue + | Self::RamWriteValue + | Self::RamVal + | Self::RamValInit + | Self::RamValFinal + | Self::RamHammingWeight + | Self::UnivariateSkip => 1, Self::InstructionRa(_) | Self::OpFlags(_) | Self::InstructionFlags(_) | Self::LookupTableFlag(_) => 2, - _ => 1, } } } @@ -986,7 +885,6 @@ pub fn serialize_and_print_size( file_name: &str, item: &impl CanonicalSerialize, ) -> Result<(), SerializationError> { - use std::fs::File; let mut file = File::create(file_name)?; item.serialize_compressed(&mut file)?; let file_size_bytes = file.metadata()?.len(); @@ -1001,6 +899,7 @@ mod tests { use super::*; use crate::poly::opening_proof::{OpeningId, SumcheckId}; use crate::zkvm::witness::{CommittedPolynomial, VirtualPolynomial}; + use crate::zkvm::RV64IMACProof; #[test] fn opening_id_header_is_packed_common_case() { @@ -1043,37 +942,65 @@ mod tests { } #[test] - fn proof_signature_version_byte() { - assert_eq!(PROOF_SIGNATURE[7], 1, "format version should be 1"); + fn proof_version_byte() { + assert_eq!(PROOF_VERSION, 1); + } + + #[test] + fn proof_magic_required() { + let mut just_magic = Vec::new(); + transport::write_magic_version(&mut just_magic, PROOF_MAGIC, PROOF_VERSION).unwrap(); + let res = RV64IMACProof::deserialize_with_mode( + std::io::Cursor::new(&just_magic), + Compress::Yes, + Validate::Yes, + ) + .map(|_| ()); + assert!(res.is_err()); } #[test] - fn proof_signature_is_required_and_unknown_tags_reject() { - let mut just_sig = Vec::new(); - just_sig.extend_from_slice(PROOF_SIGNATURE); - let res = crate::zkvm::RV64IMACProof::deserialize_with_mode( - std::io::Cursor::new(&just_sig), + fn wrong_version_rejected() { + let mut buf = Vec::new(); + transport::write_magic_version(&mut buf, PROOF_MAGIC, 99).unwrap(); + // Append enough zeros to avoid EOF on the version read + buf.extend_from_slice(&[0u8; 100]); + let res = RV64IMACProof::deserialize_with_mode( + std::io::Cursor::new(&buf), Compress::Yes, Validate::Yes, ) .map(|_| ()); match res { - Err(SerializationError::InvalidData) | Err(SerializationError::IoError(_)) => {} - _ => panic!("expected decode error"), + Err(SerializationError::IoError(e)) => { + assert!( + e.to_string().contains("unsupported proof version"), + "unexpected error: {e}" + ); + } + other => panic!("expected IoError with version message, got {other:?}"), } + } - let mut bytes = Vec::new(); - bytes.extend_from_slice(PROOF_SIGNATURE); - transport::write_frame_header(&mut bytes, 99, 0).unwrap(); - let res = crate::zkvm::RV64IMACProof::deserialize_with_mode( - std::io::Cursor::new(&bytes), + #[test] + fn wrong_magic_rejected() { + let mut buf = Vec::new(); + transport::write_magic_version(&mut buf, b"BADMAGC", 1).unwrap(); + buf.extend_from_slice(&[0u8; 100]); + let res = RV64IMACProof::deserialize_with_mode( + std::io::Cursor::new(&buf), Compress::Yes, Validate::Yes, ) .map(|_| ()); match res { - Err(SerializationError::InvalidData) | Err(SerializationError::IoError(_)) => {} - _ => panic!("expected decode error"), + Err(SerializationError::IoError(e)) => { + assert!( + e.to_string().contains("invalid proof magic"), + "unexpected error: {e}" + ); + } + other => panic!("expected IoError with magic message, got {other:?}"), } } } diff --git a/jolt-core/src/zkvm/transport.rs b/jolt-core/src/zkvm/transport.rs index 2fbe9d56c1..c8ec7da5d5 100644 --- a/jolt-core/src/zkvm/transport.rs +++ b/jolt-core/src/zkvm/transport.rs @@ -1,35 +1,31 @@ -//! Lightweight framed transport encoding helpers. +//! Lightweight length-prefixed transport encoding helpers. //! -//! This module is intentionally small and dependency-free so we can use it in verifier-facing -//! deserialization paths with explicit caps (DoS resistance) and strict parsing invariants. -//! -//! Design: -//! - Streams begin with a short fixed signature (header bytes). -//! - Then a sequence of frames: (tag: u8, len: varint u64, payload: len bytes). -//! - Decoders should be strict by default: reject unknown tags, reject duplicates for singleton -//! sections, and require full consumption of each payload. +//! Used in proof serialization for magic/version header and varint-prefixed sections +//! with explicit caps (DoS resistance) and trailing-byte validation. use std::io::{self, Read, Write}; -/// Maximum number of bytes in a u64 varint (LEB128-style). const VARINT_U64_MAX_BYTES: usize = 10; #[inline] -pub fn signature_check(r: &mut R, expected: &[u8]) -> io::Result<()> { - let mut got = vec![0u8; expected.len()]; - r.read_exact(&mut got)?; - if got != expected { +pub fn write_magic_version(w: &mut W, magic: &[u8], version: u8) -> io::Result<()> { + w.write_all(magic)?; + w.write_all(&[version]) +} + +#[inline] +pub fn read_magic_version(r: &mut R, magic: &[u8]) -> io::Result { + let mut buf = vec![0u8; magic.len()]; + r.read_exact(&mut buf)?; + if buf != magic { return Err(io::Error::new( io::ErrorKind::InvalidData, - "invalid signature", + "invalid proof magic", )); } - Ok(()) -} - -#[inline] -pub fn signature_write(w: &mut W, signature: &[u8]) -> io::Result<()> { - w.write_all(signature) + let mut v = [0u8; 1]; + r.read_exact(&mut v)?; + Ok(v[0]) } #[inline] @@ -71,45 +67,101 @@ pub fn read_varint_u64(r: &mut R) -> io::Result { )) } +/// Reads a varint-prefixed section, enforcing a maximum payload length. +/// Returns a `Take` reader limited to exactly the declared payload length. +/// Callers should check `limited.limit() == 0` after reading to detect trailing bytes. #[inline] -pub fn read_u8_opt(r: &mut R) -> io::Result> { - let mut b = [0u8; 1]; - match r.read(&mut b) { - Ok(0) => Ok(None), - Ok(1) => Ok(Some(b[0])), - Ok(_) => unreachable!(), - Err(e) => Err(e), - } -} - -#[inline] -pub fn write_frame_header(w: &mut W, tag: u8, len: u64) -> io::Result<()> { - w.write_all(&[tag])?; - write_varint_u64(w, len) -} - -#[inline] -pub fn read_frame_header(r: &mut R, max_len: u64) -> io::Result> { - let Some(tag) = read_u8_opt(r)? else { - return Ok(None); - }; +pub fn read_section(r: &mut R, max_len: u64) -> io::Result> { let len = read_varint_u64(r)?; if len > max_len { return Err(io::Error::new( io::ErrorKind::InvalidData, - "frame too large", + "section too large", )); } - Ok(Some((tag, len))) + Ok(r.take(len)) } -#[inline] -pub fn skip_exact(r: &mut R, mut n: u64) -> io::Result<()> { - let mut buf = [0u8; 4096]; - while n > 0 { - let k = (n as usize).min(buf.len()); - r.read_exact(&mut buf[..k])?; - n -= k as u64; +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn varint_roundtrip_edge_cases() { + let cases: &[u64] = &[0, 1, 127, 128, 16383, 16384, u32::MAX as u64, u64::MAX]; + for &val in cases { + let mut buf = Vec::new(); + write_varint_u64(&mut buf, val).unwrap(); + let decoded = read_varint_u64(&mut buf.as_slice()).unwrap(); + assert_eq!(decoded, val, "roundtrip failed for {val}"); + } + } + + #[test] + fn varint_u64_len_matches_encoding() { + let cases: &[u64] = &[0, 1, 127, 128, 16383, 16384, u32::MAX as u64, u64::MAX]; + for &val in cases { + let mut buf = Vec::new(); + write_varint_u64(&mut buf, val).unwrap(); + assert_eq!( + buf.len(), + varint_u64_len(val), + "varint_u64_len mismatch for {val}" + ); + } + } + + #[test] + fn varint_overflow_rejected() { + // 11 continuation bytes — exceeds VARINT_U64_MAX_BYTES + let bad = vec![0x80u8; 11]; + let res = read_varint_u64(&mut bad.as_slice()); + assert!(res.is_err()); + } + + #[test] + fn magic_version_roundtrip() { + let magic = b"JOLTPRF"; + let version = 1u8; + let mut buf = Vec::new(); + write_magic_version(&mut buf, magic, version).unwrap(); + assert_eq!(buf.len(), 8); + + let decoded_version = read_magic_version(&mut buf.as_slice(), magic).unwrap(); + assert_eq!(decoded_version, version); + } + + #[test] + fn wrong_magic_rejected() { + let mut buf = Vec::new(); + write_magic_version(&mut buf, b"JOLTPRF", 1).unwrap(); + let res = read_magic_version(&mut buf.as_slice(), b"BADMAGC"); + assert!(res.is_err()); + let err = res.unwrap_err(); + assert_eq!(err.kind(), io::ErrorKind::InvalidData); + assert!(err.to_string().contains("invalid proof magic")); + } + + #[test] + fn wrong_version_readable() { + let mut buf = Vec::new(); + write_magic_version(&mut buf, b"JOLTPRF", 2).unwrap(); + let version = read_magic_version(&mut buf.as_slice(), b"JOLTPRF").unwrap(); + assert_eq!(version, 2); + } + + #[test] + fn read_section_enforces_cap() { + let mut buf = Vec::new(); + write_varint_u64(&mut buf, 1000).unwrap(); + buf.extend_from_slice(&[0u8; 1000]); + + let mut too_small = buf.as_slice(); + let res = read_section(&mut too_small, 999); + assert!(res.is_err()); + + let mut cursor = buf.as_slice(); + let limited = read_section(&mut cursor, 1000).unwrap(); + assert_eq!(limited.limit(), 1000); } - Ok(()) } diff --git a/jolt-sdk/src/host_utils.rs b/jolt-sdk/src/host_utils.rs index 616511de26..95dc7c29ee 100644 --- a/jolt-sdk/src/host_utils.rs +++ b/jolt-sdk/src/host_utils.rs @@ -11,7 +11,6 @@ pub use jolt_core::curve::{Bn254Curve, JoltCurve}; pub use jolt_core::field::JoltField; pub use jolt_core::guest; pub use jolt_core::poly::commitment::dory::DoryCommitmentScheme as PCS; -pub use jolt_core::zkvm::transport; pub use jolt_core::zkvm::{ proof_serialization::JoltProof, verifier::JoltSharedPreprocessing, verifier::JoltVerifierPreprocessing, RV64IMACProof, RV64IMACVerifier, Serializable, From 631b1d32b3b5519650818f3f65c8930b1dd8b8be Mon Sep 17 00:00:00 2001 From: Quang Dao Date: Sat, 7 Mar 2026 04:16:50 +0800 Subject: [PATCH 6/9] =?UTF-8?q?fix(zkvm):=20address=20audit=20findings=20?= =?UTF-8?q?=E2=80=94=20EOF=20error=20propagation,=20Claims=20unification,?= =?UTF-8?q?=20ZK=20discriminator?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Propagate I/O errors in EOF check instead of silently swallowing them - Unify standalone Claims ser/de with JoltProof's varint-based format; delegate via write_section!/read_single_section! instead of duplicating - Add debug_assert! in write_section! to catch ser/de cap mismatches - Add flags byte to wire header (bit 0 = is_zk) with clear cross-mode error messages; shorten magic from b"JOLTPRF" to b"JOLT" - Move `use std::fs::File` back inside function scope - Expand OpeningId roundtrip tests to cover all parameterized variants (CommittedPolynomial, VirtualPolynomial) x multiple SumcheckId values Made-with: Cursor --- jolt-core/src/zkvm/proof_serialization.rs | 218 ++++++++++++++-------- jolt-core/src/zkvm/transport.rs | 36 +++- 2 files changed, 166 insertions(+), 88 deletions(-) diff --git a/jolt-core/src/zkvm/proof_serialization.rs b/jolt-core/src/zkvm/proof_serialization.rs index 3f9dd3092f..3c8e18e921 100644 --- a/jolt-core/src/zkvm/proof_serialization.rs +++ b/jolt-core/src/zkvm/proof_serialization.rs @@ -1,6 +1,5 @@ #[cfg(not(feature = "zk"))] use std::collections::BTreeMap; -use std::fs::File; use std::io::{Read, Write}; use ark_serialize::{ @@ -34,8 +33,10 @@ use crate::{ use crate::zkvm::transport; -const PROOF_MAGIC: &[u8; 7] = b"JOLTPRF"; +const PROOF_MAGIC: &[u8; 4] = b"JOLT"; const PROOF_VERSION: u8 = 1; +const PROOF_FLAGS_RESERVED_MASK: u8 = 0xFE; +const PROOF_FLAG_ZK: u8 = 0x01; const MAX_PARAMS_LEN: u64 = 1024; const MAX_SECTION_LEN: u64 = 128 * 1024; @@ -72,6 +73,7 @@ fn io_err(e: std::io::Error) -> SerializationError { macro_rules! write_section { ($w:expr, $c:expr, $($item:expr),+ $(,)?) => {{ let len: u64 = 0 $(+ $item.serialized_size($c) as u64)+; + debug_assert!(len <= MAX_SECTION_LEN, "section size {len} exceeds MAX_SECTION_LEN"); transport::write_varint_u64($w, len).map_err(io_err)?; $($item.serialize_with_mode($w, $c)?;)+ }}; @@ -84,7 +86,9 @@ macro_rules! section_size { }}; } -fn check_trailing_bytes(limited: &std::io::Take<&mut impl Read>) -> Result<(), SerializationError> { +fn check_trailing_bytes( + limited: &std::io::Take<&mut R>, +) -> Result<(), SerializationError> { if limited.limit() != 0 { return Err(io_err(std::io::Error::new( std::io::ErrorKind::InvalidData, @@ -119,20 +123,6 @@ impl, FS: Transcrip .unwrap_or(0); count_len + items_len + 1 + untrusted_len } - - #[cfg(not(feature = "zk"))] - fn claims_payload_len(&self, compress: Compress) -> u64 { - let count_len = transport::varint_u64_len(self.opening_claims.0.len() as u64) as u64; - let items_len: u64 = self - .opening_claims - .0 - .iter() - .map(|(k, (_p, claim))| { - (k.serialized_size(compress) + claim.serialized_size(compress)) as u64 - }) - .sum(); - count_len + items_len - } } impl, FS: Transcript> @@ -144,6 +134,12 @@ impl, FS: Transcrip compress: Compress, ) -> Result<(), SerializationError> { transport::write_magic_version(&mut writer, PROOF_MAGIC, PROOF_VERSION).map_err(io_err)?; + let flags: u8 = if cfg!(feature = "zk") { + PROOF_FLAG_ZK + } else { + 0 + }; + writer.write_all(&[flags]).map_err(io_err)?; let params_len = self.params_payload_len(compress); transport::write_varint_u64(&mut writer, params_len).map_err(io_err)?; @@ -170,16 +166,7 @@ impl, FS: Transcrip } #[cfg(not(feature = "zk"))] - { - let claims_len = self.claims_payload_len(compress); - transport::write_varint_u64(&mut writer, claims_len).map_err(io_err)?; - transport::write_varint_u64(&mut writer, self.opening_claims.0.len() as u64) - .map_err(io_err)?; - for (k, (_p, claim)) in self.opening_claims.0.iter() { - k.serialize_with_mode(&mut writer, compress)?; - claim.serialize_with_mode(&mut writer, compress)?; - } - } + write_section!(&mut writer, compress, &self.opening_claims); write_section!( &mut writer, @@ -207,7 +194,7 @@ impl, FS: Transcrip } fn serialized_size(&self, compress: Compress) -> usize { - let mut size = PROOF_MAGIC.len() + 1; // magic + version byte + let mut size = PROOF_MAGIC.len() + 1 + 1; // magic + version + flags let params_len = self.params_payload_len(compress); size += transport::varint_u64_len(params_len) + params_len as usize; @@ -217,8 +204,7 @@ impl, FS: Transcrip #[cfg(not(feature = "zk"))] { - let claims_len = self.claims_payload_len(compress); - size += transport::varint_u64_len(claims_len) + claims_len as usize; + size += section_size!(compress, &self.opening_claims); } size += section_size!( @@ -271,6 +257,26 @@ impl, FS: Transcrip ))); } + let mut flags_buf = [0u8; 1]; + reader.read_exact(&mut flags_buf).map_err(io_err)?; + let flags = flags_buf[0]; + if flags & PROOF_FLAGS_RESERVED_MASK != 0 { + return Err(io_err(std::io::Error::new( + std::io::ErrorKind::InvalidData, + format!("unknown proof flags bits set: {flags:#04x}"), + ))); + } + let proof_is_zk = flags & PROOF_FLAG_ZK != 0; + let compiled_for_zk = cfg!(feature = "zk"); + if proof_is_zk != compiled_for_zk { + let mode = if proof_is_zk { "ZK" } else { "standard" }; + let expected = if compiled_for_zk { "ZK" } else { "standard" }; + return Err(io_err(std::io::Error::new( + std::io::ErrorKind::InvalidData, + format!("proof was serialized in {mode} mode but deserializer expects {expected}"), + ))); + } + // Params let mut limited = transport::read_section(&mut reader, MAX_PARAMS_LEN).map_err(io_err)?; let t = transport::read_varint_u64(&mut limited).map_err(io_err)?; @@ -315,30 +321,19 @@ impl, FS: Transcrip }; check_trailing_bytes(&limited)?; - // Opening claims (non-ZK only) + macro_rules! read_single_section { + ($reader:expr) => {{ + let mut limited = + transport::read_section($reader, MAX_SECTION_LEN).map_err(io_err)?; + let val = + CanonicalDeserialize::deserialize_with_mode(&mut limited, compress, validate)?; + check_trailing_bytes(&limited)?; + val + }}; + } + #[cfg(not(feature = "zk"))] - let opening_claims = { - let mut limited = - transport::read_section(&mut reader, MAX_SECTION_LEN).map_err(io_err)?; - let n = transport::read_varint_u64(&mut limited).map_err(io_err)?; - let n_usize = usize::try_from(n).map_err(|_| SerializationError::InvalidData)?; - if n_usize > 10_000 { - return Err(SerializationError::InvalidData); - } - let mut claims = BTreeMap::new(); - for _ in 0..n_usize { - let key = OpeningId::deserialize_with_mode(&mut limited, compress, validate)?; - let claim = F::deserialize_with_mode(&mut limited, compress, validate)?; - if claims - .insert(key, (OpeningPoint::default(), claim)) - .is_some() - { - return Err(SerializationError::InvalidData); - } - } - check_trailing_bytes(&limited)?; - Claims(claims) - }; + let opening_claims: Claims = read_single_section!(&mut reader); // Stage 1 let mut limited = transport::read_section(&mut reader, MAX_SECTION_LEN).map_err(io_err)?; @@ -356,18 +351,6 @@ impl, FS: Transcrip CanonicalDeserialize::deserialize_with_mode(&mut limited, compress, validate)?; check_trailing_bytes(&limited)?; - // Stages 3-7 - macro_rules! read_single_section { - ($reader:expr) => {{ - let mut limited = - transport::read_section($reader, MAX_SECTION_LEN).map_err(io_err)?; - let val = - CanonicalDeserialize::deserialize_with_mode(&mut limited, compress, validate)?; - check_trailing_bytes(&limited)?; - val - }}; - } - let stage3_sumcheck_proof = read_single_section!(&mut reader); let stage4_sumcheck_proof = read_single_section!(&mut reader); let stage5_sumcheck_proof = read_single_section!(&mut reader); @@ -378,6 +361,18 @@ impl, FS: Transcrip #[cfg(feature = "zk")] let blindfold_proof = read_single_section!(&mut reader); + let mut eof_check = [0u8; 1]; + match reader.read(&mut eof_check) { + Ok(0) => {} + Ok(_) => { + return Err(io_err(std::io::Error::new( + std::io::ErrorKind::InvalidData, + "unexpected trailing bytes after proof", + ))); + } + Err(e) => return Err(io_err(e)), + } + Ok(Self { commitments, stage1_uni_skip_first_round_proof, @@ -407,6 +402,9 @@ impl, FS: Transcrip #[cfg(not(feature = "zk"))] pub struct Claims(pub Openings); +#[cfg(not(feature = "zk"))] +const MAX_CLAIMS_COUNT: u64 = 10_000; + #[cfg(not(feature = "zk"))] impl CanonicalSerialize for Claims { fn serialize_with_mode( @@ -414,7 +412,7 @@ impl CanonicalSerialize for Claims { mut writer: W, compress: Compress, ) -> Result<(), SerializationError> { - self.0.len().serialize_with_mode(&mut writer, compress)?; + transport::write_varint_u64(&mut writer, self.0.len() as u64).map_err(io_err)?; for (key, (_opening_point, claim)) in self.0.iter() { key.serialize_with_mode(&mut writer, compress)?; claim.serialize_with_mode(&mut writer, compress)?; @@ -423,7 +421,7 @@ impl CanonicalSerialize for Claims { } fn serialized_size(&self, compress: Compress) -> usize { - let mut size = self.0.len().serialized_size(compress); + let mut size = transport::varint_u64_len(self.0.len() as u64); for (key, (_opening_point, claim)) in self.0.iter() { size += key.serialized_size(compress); size += claim.serialized_size(compress); @@ -446,9 +444,13 @@ impl CanonicalDeserialize for Claims { compress: Compress, validate: Validate, ) -> Result { - let size = usize::deserialize_with_mode(&mut reader, compress, validate)?; + let n = transport::read_varint_u64(&mut reader).map_err(io_err)?; + let n_usize = usize::try_from(n).map_err(|_| SerializationError::InvalidData)?; + if n > MAX_CLAIMS_COUNT { + return Err(SerializationError::InvalidData); + } let mut claims = BTreeMap::new(); - for _ in 0..size { + for _ in 0..n_usize { let key = OpeningId::deserialize_with_mode(&mut reader, compress, validate)?; let claim = F::deserialize_with_mode(&mut reader, compress, validate)?; if claims @@ -458,7 +460,6 @@ impl CanonicalDeserialize for Claims { return Err(SerializationError::InvalidData); } } - Ok(Claims(claims)) } } @@ -879,6 +880,7 @@ pub fn serialize_and_print_size( file_name: &str, item: &impl CanonicalSerialize, ) -> Result<(), SerializationError> { + use std::fs::File; let mut file = File::create(file_name)?; item.serialize_compressed(&mut file)?; let file_size_bytes = file.metadata()?.len(); @@ -921,17 +923,75 @@ mod tests { #[test] fn opening_id_roundtrips() { - let cases = [ - OpeningId::UntrustedAdvice(SumcheckId::SpartanOuter), - OpeningId::TrustedAdvice(SumcheckId::SpartanOuter), - OpeningId::committed(CommittedPolynomial::RdInc, SumcheckId::SpartanOuter), - OpeningId::virt(VirtualPolynomial::PC, SumcheckId::SpartanOuter), + use crate::zkvm::instruction::{CircuitFlags, InstructionFlags}; + + let sumcheck_ids = [ + SumcheckId::SpartanOuter, + SumcheckId::RamReadWriteChecking, + SumcheckId::HammingWeightClaimReduction, + ]; + + let committed_polys = [ + CommittedPolynomial::RdInc, + CommittedPolynomial::RamInc, + CommittedPolynomial::InstructionRa(0), + CommittedPolynomial::InstructionRa(7), + CommittedPolynomial::BytecodeRa(0), + CommittedPolynomial::RamRa(0), + CommittedPolynomial::TrustedAdvice, + CommittedPolynomial::UntrustedAdvice, + ]; + + let virtual_polys = [ + VirtualPolynomial::PC, + VirtualPolynomial::NextPC, + VirtualPolynomial::UnivariateSkip, + VirtualPolynomial::InstructionRa(0), + VirtualPolynomial::InstructionRa(5), + VirtualPolynomial::OpFlags(CircuitFlags::AddOperands), + VirtualPolynomial::OpFlags(CircuitFlags::IsLastInSequence), + VirtualPolynomial::InstructionFlags(InstructionFlags::LeftOperandIsPC), + VirtualPolynomial::InstructionFlags(InstructionFlags::IsNoop), + VirtualPolynomial::LookupTableFlag(0), + VirtualPolynomial::LookupTableFlag(3), ]; - for id in cases { + + for &sc in &sumcheck_ids { + let id = OpeningId::UntrustedAdvice(sc); let mut bytes = Vec::new(); id.serialize_compressed(&mut bytes).unwrap(); - let decoded = OpeningId::deserialize_compressed(bytes.as_slice()).unwrap(); - assert_eq!(decoded, id); + assert_eq!( + OpeningId::deserialize_compressed(bytes.as_slice()).unwrap(), + id + ); + + let id = OpeningId::TrustedAdvice(sc); + let mut bytes = Vec::new(); + id.serialize_compressed(&mut bytes).unwrap(); + assert_eq!( + OpeningId::deserialize_compressed(bytes.as_slice()).unwrap(), + id + ); + + for &cp in &committed_polys { + let id = OpeningId::committed(cp, sc); + let mut bytes = Vec::new(); + id.serialize_compressed(&mut bytes).unwrap(); + assert_eq!( + OpeningId::deserialize_compressed(bytes.as_slice()).unwrap(), + id + ); + } + + for &vp in &virtual_polys { + let id = OpeningId::virt(vp, sc); + let mut bytes = Vec::new(); + id.serialize_compressed(&mut bytes).unwrap(); + assert_eq!( + OpeningId::deserialize_compressed(bytes.as_slice()).unwrap(), + id + ); + } } } @@ -979,7 +1039,7 @@ mod tests { #[test] fn wrong_magic_rejected() { let mut buf = Vec::new(); - transport::write_magic_version(&mut buf, b"BADMAGC", 1).unwrap(); + transport::write_magic_version(&mut buf, b"BAAD", 1).unwrap(); buf.extend_from_slice(&[0u8; 100]); let res = RV64IMACProof::deserialize_with_mode( std::io::Cursor::new(&buf), diff --git a/jolt-core/src/zkvm/transport.rs b/jolt-core/src/zkvm/transport.rs index c8ec7da5d5..064beecd96 100644 --- a/jolt-core/src/zkvm/transport.rs +++ b/jolt-core/src/zkvm/transport.rs @@ -1,7 +1,9 @@ //! Lightweight length-prefixed transport encoding helpers. //! -//! Used in proof serialization for magic/version header and varint-prefixed sections -//! with explicit caps (DoS resistance) and trailing-byte validation. +//! Wire format: `[magic: 4B][version: 1B][flags: 1B][section₀][section₁]…` +//! where each section is `[varint payload_len][payload bytes]`. +//! Sections are sequential and untagged; the deserializer reads them in a +//! fixed order defined by the proof schema. use std::io::{self, Read, Write}; @@ -55,7 +57,14 @@ pub fn read_varint_u64(r: &mut R) -> io::Result { let mut b = [0u8; 1]; r.read_exact(&mut b)?; let byte = b[0]; - x |= ((byte & 0x7F) as u64) << shift; + let payload = (byte & 0x7F) as u64; + if shift == 63 && payload > 1 { + return Err(io::Error::new( + io::ErrorKind::InvalidData, + "varint overflow", + )); + } + x |= payload << shift; if (byte & 0x80) == 0 { return Ok(x); } @@ -119,13 +128,22 @@ mod tests { assert!(res.is_err()); } + #[test] + fn varint_10th_byte_overflow_rejected() { + // 9 continuation bytes + a 10th byte with payload 2 (only 0 or 1 is valid at shift=63) + let mut bad = vec![0x80u8; 9]; + bad.push(0x02); + let res = read_varint_u64(&mut bad.as_slice()); + assert!(res.is_err()); + } + #[test] fn magic_version_roundtrip() { - let magic = b"JOLTPRF"; + let magic = b"JOLT"; let version = 1u8; let mut buf = Vec::new(); write_magic_version(&mut buf, magic, version).unwrap(); - assert_eq!(buf.len(), 8); + assert_eq!(buf.len(), 5); let decoded_version = read_magic_version(&mut buf.as_slice(), magic).unwrap(); assert_eq!(decoded_version, version); @@ -134,8 +152,8 @@ mod tests { #[test] fn wrong_magic_rejected() { let mut buf = Vec::new(); - write_magic_version(&mut buf, b"JOLTPRF", 1).unwrap(); - let res = read_magic_version(&mut buf.as_slice(), b"BADMAGC"); + write_magic_version(&mut buf, b"JOLT", 1).unwrap(); + let res = read_magic_version(&mut buf.as_slice(), b"BAAD"); assert!(res.is_err()); let err = res.unwrap_err(); assert_eq!(err.kind(), io::ErrorKind::InvalidData); @@ -145,8 +163,8 @@ mod tests { #[test] fn wrong_version_readable() { let mut buf = Vec::new(); - write_magic_version(&mut buf, b"JOLTPRF", 2).unwrap(); - let version = read_magic_version(&mut buf.as_slice(), b"JOLTPRF").unwrap(); + write_magic_version(&mut buf, b"JOLT", 2).unwrap(); + let version = read_magic_version(&mut buf.as_slice(), b"JOLT").unwrap(); assert_eq!(version, 2); } From e8dc11158668516fdee5d8ef5f0283732373cf72 Mon Sep 17 00:00:00 2001 From: Quang Dao Date: Tue, 31 Mar 2026 08:22:36 -0400 Subject: [PATCH 7/9] fix(zkvm): address framed proof serialization review --- jolt-core/src/zkvm/mod.rs | 2 +- jolt-core/src/zkvm/proof_serialization.rs | 173 ++++++++++++++++++++-- jolt-core/src/zkvm/prover.rs | 12 +- jolt-core/src/zkvm/transport.rs | 14 +- 4 files changed, 183 insertions(+), 18 deletions(-) diff --git a/jolt-core/src/zkvm/mod.rs b/jolt-core/src/zkvm/mod.rs index cb9e095fc7..6c09e8d32f 100644 --- a/jolt-core/src/zkvm/mod.rs +++ b/jolt-core/src/zkvm/mod.rs @@ -37,7 +37,7 @@ pub mod r1cs; pub mod ram; pub mod registers; pub mod spartan; -pub mod transport; +pub(crate) mod transport; pub mod verifier; pub mod witness; diff --git a/jolt-core/src/zkvm/proof_serialization.rs b/jolt-core/src/zkvm/proof_serialization.rs index 3c8e18e921..f07b55f1eb 100644 --- a/jolt-core/src/zkvm/proof_serialization.rs +++ b/jolt-core/src/zkvm/proof_serialization.rs @@ -1,6 +1,9 @@ #[cfg(not(feature = "zk"))] use std::collections::BTreeMap; -use std::io::{Read, Write}; +use std::{ + any::TypeId, + io::{Cursor, Read, Write}, +}; use ark_serialize::{ CanonicalDeserialize, CanonicalSerialize, Compress, SerializationError, Valid, Validate, @@ -15,7 +18,10 @@ use crate::{ curve::JoltCurve, field::JoltField, poly::{ - commitment::{commitment_scheme::CommitmentScheme, dory::DoryLayout}, + commitment::{ + commitment_scheme::CommitmentScheme, + dory::{ArkG1, ArkG2, ArkGT, DoryCommitmentScheme, DoryLayout}, + }, opening_proof::{OpeningId, PolynomialId, SumcheckId}, }, }; @@ -40,6 +46,8 @@ const PROOF_FLAG_ZK: u8 = 0x01; const MAX_PARAMS_LEN: u64 = 1024; const MAX_SECTION_LEN: u64 = 128 * 1024; +#[cfg(not(feature = "zk"))] +const MIN_OPENING_CLAIM_BYTES: u64 = 33; pub struct JoltProof, FS: Transcript> { pub commitments: Vec, @@ -70,10 +78,31 @@ fn io_err(e: std::io::Error) -> SerializationError { SerializationError::IoError(e) } +#[inline] +fn invalid_data(message: impl Into) -> SerializationError { + io_err(std::io::Error::new( + std::io::ErrorKind::InvalidData, + message.into(), + )) +} + +fn ensure_section_len( + section_name: &str, + len: u64, + max_len: u64, +) -> Result<(), SerializationError> { + if len > max_len { + return Err(invalid_data(format!( + "{section_name} section size {len} exceeds cap {max_len}" + ))); + } + Ok(()) +} + macro_rules! write_section { ($w:expr, $c:expr, $($item:expr),+ $(,)?) => {{ let len: u64 = 0 $(+ $item.serialized_size($c) as u64)+; - debug_assert!(len <= MAX_SECTION_LEN, "section size {len} exceeds MAX_SECTION_LEN"); + ensure_section_len("proof", len, MAX_SECTION_LEN)?; transport::write_varint_u64($w, len).map_err(io_err)?; $($item.serialize_with_mode($w, $c)?;)+ }}; @@ -90,14 +119,90 @@ fn check_trailing_bytes( limited: &std::io::Take<&mut R>, ) -> Result<(), SerializationError> { if limited.limit() != 0 { - return Err(io_err(std::io::Error::new( - std::io::ErrorKind::InvalidData, - format!("{} trailing bytes not consumed", limited.limit()), + return Err(invalid_data(format!( + "{} trailing bytes not consumed", + limited.limit() + ))); + } + Ok(()) +} + +fn check_cursor_consumed( + section_name: &str, + cursor: &Cursor<&[u8]>, +) -> Result<(), SerializationError> { + let trailing = cursor + .get_ref() + .len() + .saturating_sub(cursor.position() as usize); + if trailing != 0 { + return Err(invalid_data(format!( + "{trailing} trailing bytes not consumed in {section_name} section" ))); } Ok(()) } +fn ark_group_size(compress: Compress) -> usize { + T::default().serialized_size(compress) +} + +fn check_dory_round_count( + section_bytes: &[u8], + compress: Compress, + validate: Validate, +) -> Result<(), SerializationError> { + let mut cursor = Cursor::new(section_bytes); + ArkGT::deserialize_with_mode(&mut cursor, compress, validate)?; + ArkGT::deserialize_with_mode(&mut cursor, compress, validate)?; + ArkG1::deserialize_with_mode(&mut cursor, compress, validate)?; + + let num_rounds = u32::deserialize_with_mode(&mut cursor, compress, validate)? as usize; + + let first_round_bytes = 4 * ark_group_size::(compress) + + ark_group_size::(compress) + + ark_group_size::(compress); + let second_round_bytes = 2 * ark_group_size::(compress) + + 2 * ark_group_size::(compress) + + 2 * ark_group_size::(compress); + let min_tail_bytes = ark_group_size::(compress) + + ark_group_size::(compress) + + 2 * std::mem::size_of::(); + let remaining_bytes = section_bytes + .len() + .saturating_sub(cursor.position() as usize); + + if remaining_bytes < min_tail_bytes { + return Err(SerializationError::InvalidData); + } + + let per_round_bytes = first_round_bytes + second_round_bytes; + let max_rounds = remaining_bytes.saturating_sub(min_tail_bytes) / per_round_bytes; + if num_rounds > max_rounds { + return Err(invalid_data(format!( + "Dory opening proof declares {num_rounds} rounds but section only has room for {max_rounds}" + ))); + } + + Ok(()) +} + +fn deserialize_joint_opening_proof_section( + reader: &mut impl Read, + compress: Compress, + validate: Validate, +) -> Result { + let section_bytes = transport::read_section_bytes(reader, MAX_SECTION_LEN).map_err(io_err)?; + if TypeId::of::() == TypeId::of::() { + check_dory_round_count(§ion_bytes, compress, validate)?; + } + + let mut cursor = Cursor::new(section_bytes.as_slice()); + let proof = PCS::Proof::deserialize_with_mode(&mut cursor, compress, validate)?; + check_cursor_consumed("joint opening proof", &cursor)?; + Ok(proof) +} + impl, FS: Transcript> JoltProof { @@ -142,6 +247,7 @@ impl, FS: Transcrip writer.write_all(&[flags]).map_err(io_err)?; let params_len = self.params_payload_len(compress); + ensure_section_len("params", params_len, MAX_PARAMS_LEN)?; transport::write_varint_u64(&mut writer, params_len).map_err(io_err)?; transport::write_varint_u64(&mut writer, self.trace_length as u64).map_err(io_err)?; transport::write_varint_u64(&mut writer, self.ram_K as u64).map_err(io_err)?; @@ -152,6 +258,7 @@ impl, FS: Transcrip .serialize_with_mode(&mut writer, compress)?; let commitments_len = self.commitments_payload_len(compress); + ensure_section_len("commitments", commitments_len, MAX_SECTION_LEN)?; transport::write_varint_u64(&mut writer, commitments_len).map_err(io_err)?; transport::write_varint_u64(&mut writer, self.commitments.len() as u64).map_err(io_err)?; for c in &self.commitments { @@ -295,11 +402,20 @@ impl, FS: Transcrip check_trailing_bytes(&limited)?; // Commitments - let mut limited = transport::read_section(&mut reader, MAX_SECTION_LEN).map_err(io_err)?; + let commitments_bytes = + transport::read_section_bytes(&mut reader, MAX_SECTION_LEN).map_err(io_err)?; + let mut limited = Cursor::new(commitments_bytes.as_slice()); let n = transport::read_varint_u64(&mut limited).map_err(io_err)?; let n_usize = usize::try_from(n).map_err(|_| SerializationError::InvalidData)?; - if n_usize > 10_000 { - return Err(SerializationError::InvalidData); + let remaining_bytes = commitments_bytes + .len() + .saturating_sub(limited.position() as usize); + let min_commitment_bytes = PCS::Commitment::default().serialized_size(compress).max(1); + let max_commitments = remaining_bytes.saturating_sub(1) / min_commitment_bytes; + if n_usize > max_commitments { + return Err(invalid_data(format!( + "commitments section count {n_usize} exceeds byte cap {max_commitments}" + ))); } let mut commitments = Vec::with_capacity(n_usize); for _ in 0..n_usize { @@ -319,7 +435,7 @@ impl, FS: Transcrip )?), _ => return Err(SerializationError::InvalidData), }; - check_trailing_bytes(&limited)?; + check_cursor_consumed("commitments", &limited)?; macro_rules! read_single_section { ($reader:expr) => {{ @@ -356,7 +472,8 @@ impl, FS: Transcrip let stage5_sumcheck_proof = read_single_section!(&mut reader); let stage6_sumcheck_proof = read_single_section!(&mut reader); let stage7_sumcheck_proof = read_single_section!(&mut reader); - let joint_opening_proof = read_single_section!(&mut reader); + let joint_opening_proof = + deserialize_joint_opening_proof_section::(&mut reader, compress, validate)?; #[cfg(feature = "zk")] let blindfold_proof = read_single_section!(&mut reader); @@ -403,7 +520,7 @@ impl, FS: Transcrip pub struct Claims(pub Openings); #[cfg(not(feature = "zk"))] -const MAX_CLAIMS_COUNT: u64 = 10_000; +const MAX_CLAIMS_COUNT: u64 = MAX_SECTION_LEN / MIN_OPENING_CLAIM_BYTES; #[cfg(not(feature = "zk"))] impl CanonicalSerialize for Claims { @@ -596,7 +713,11 @@ impl CanonicalDeserialize for OpeningId { let small = header & 0x3F; let sumcheck_u64 = if small == OPENING_ID_SUMCHECK_ESCAPE { - transport::read_varint_u64(&mut reader).map_err(io_err)? + let sumcheck_u64 = transport::read_varint_u64(&mut reader).map_err(io_err)?; + if sumcheck_u64 < OPENING_ID_SUMCHECK_ESCAPE as u64 { + return Err(SerializationError::InvalidData); + } + sumcheck_u64 } else { small as u64 }; @@ -995,6 +1116,16 @@ mod tests { } } + #[test] + fn opening_id_rejects_noncanonical_escape_encoding() { + let bytes = [ + (OPENING_ID_KIND_UNTRUSTED_ADVICE << 6) | OPENING_ID_SUMCHECK_ESCAPE, + 0, + ]; + let res = OpeningId::deserialize_compressed(bytes.as_slice()); + assert!(matches!(res, Err(SerializationError::InvalidData))); + } + #[test] fn proof_version_byte() { assert_eq!(PROOF_VERSION, 1); @@ -1057,4 +1188,20 @@ mod tests { other => panic!("expected IoError with magic message, got {other:?}"), } } + + #[test] + fn dory_round_count_sanity_check_rejects_oversized_count() { + let mut bytes = Vec::new(); + ArkGT::default().serialize_compressed(&mut bytes).unwrap(); + ArkGT::default().serialize_compressed(&mut bytes).unwrap(); + ArkG1::default().serialize_compressed(&mut bytes).unwrap(); + 1_000u32.serialize_compressed(&mut bytes).unwrap(); + ArkG1::default().serialize_compressed(&mut bytes).unwrap(); + ArkG2::default().serialize_compressed(&mut bytes).unwrap(); + 0u32.serialize_compressed(&mut bytes).unwrap(); + 0u32.serialize_compressed(&mut bytes).unwrap(); + + let res = check_dory_round_count(&bytes, Compress::Yes, Validate::Yes); + assert!(res.is_err()); + } } diff --git a/jolt-core/src/zkvm/prover.rs b/jolt-core/src/zkvm/prover.rs index 7e025a65a0..35f6146463 100644 --- a/jolt-core/src/zkvm/prover.rs +++ b/jolt-core/src/zkvm/prover.rs @@ -2229,7 +2229,7 @@ mod tests { prover::JoltProverPreprocessing, ram::populate_memory_states, verifier::{JoltVerifier, JoltVerifierPreprocessing}, - RV64IMACProver, RV64IMACVerifier, + RV64IMACProof, RV64IMACProver, RV64IMACVerifier, Serializable, }; #[cfg(feature = "zk")] use crate::{curve::JoltCurve, field::JoltField}; @@ -2313,6 +2313,8 @@ mod tests { ); let io_device = prover.program_io.clone(); let (jolt_proof, debug_info) = prover.prove(); + let proof_bytes = jolt_proof.serialize_to_bytes().unwrap(); + let jolt_proof = RV64IMACProof::deserialize_from_bytes(&proof_bytes).unwrap(); let verifier_preprocessing = JoltVerifierPreprocessing::from(&prover_preprocessing); let verifier = RV64IMACVerifier::new( @@ -2416,6 +2418,8 @@ mod tests { ); let io_device = prover.program_io.clone(); let (jolt_proof, debug_info) = prover.prove(); + let proof_bytes = jolt_proof.serialize_to_bytes().unwrap(); + let jolt_proof = RV64IMACProof::deserialize_from_bytes(&proof_bytes).unwrap(); let verifier_preprocessing = JoltVerifierPreprocessing::from(&prover_preprocessing); let verifier = RV64IMACVerifier::new( @@ -2539,6 +2543,8 @@ mod tests { ); let io_device = prover.program_io.clone(); let (jolt_proof, debug_info) = prover.prove(); + let proof_bytes = jolt_proof.serialize_to_bytes().unwrap(); + let jolt_proof = RV64IMACProof::deserialize_from_bytes(&proof_bytes).unwrap(); let verifier_preprocessing = JoltVerifierPreprocessing::from(&prover_preprocessing); RV64IMACVerifier::new( @@ -2661,6 +2667,8 @@ mod tests { ); let io_device = prover.program_io.clone(); let (jolt_proof, debug_info) = prover.prove(); + let proof_bytes = jolt_proof.serialize_to_bytes().unwrap(); + let jolt_proof = RV64IMACProof::deserialize_from_bytes(&proof_bytes).unwrap(); let verifier_preprocessing = JoltVerifierPreprocessing::from(&prover_preprocessing); RV64IMACVerifier::new( @@ -2900,6 +2908,8 @@ mod tests { ); let io_device = prover.program_io.clone(); let (jolt_proof, debug_info) = prover.prove(); + let proof_bytes = jolt_proof.serialize_to_bytes().unwrap(); + let jolt_proof = RV64IMACProof::deserialize_from_bytes(&proof_bytes).unwrap(); let verifier_preprocessing = JoltVerifierPreprocessing::from(&prover_preprocessing); let verifier = RV64IMACVerifier::new( diff --git a/jolt-core/src/zkvm/transport.rs b/jolt-core/src/zkvm/transport.rs index 064beecd96..110693e424 100644 --- a/jolt-core/src/zkvm/transport.rs +++ b/jolt-core/src/zkvm/transport.rs @@ -16,10 +16,10 @@ pub fn write_magic_version(w: &mut W, magic: &[u8], version: u8) -> io } #[inline] -pub fn read_magic_version(r: &mut R, magic: &[u8]) -> io::Result { - let mut buf = vec![0u8; magic.len()]; +pub fn read_magic_version(r: &mut R, magic: &[u8; 4]) -> io::Result { + let mut buf = [0u8; 4]; r.read_exact(&mut buf)?; - if buf != magic { + if buf != *magic { return Err(io::Error::new( io::ErrorKind::InvalidData, "invalid proof magic", @@ -91,6 +91,14 @@ pub fn read_section(r: &mut R, max_len: u64) -> io::Result(r: &mut R, max_len: u64) -> io::Result> { + let mut limited = read_section(r, max_len)?; + let mut bytes = vec![0u8; limited.limit() as usize]; + limited.read_exact(&mut bytes)?; + Ok(bytes) +} + #[cfg(test)] mod tests { use super::*; From 7ce8285b09f1b90f6b16cc00e31e5535e03ab417 Mon Sep 17 00:00:00 2001 From: Quang Dao Date: Tue, 31 Mar 2026 10:03:24 -0400 Subject: [PATCH 8/9] fix(zkvm): harden framed proof deserialization --- jolt-core/src/zkvm/proof_serialization.rs | 440 ++++++++++++++++++++-- jolt-core/src/zkvm/prover.rs | 4 + jolt-core/src/zkvm/transport.rs | 19 +- 3 files changed, 431 insertions(+), 32 deletions(-) diff --git a/jolt-core/src/zkvm/proof_serialization.rs b/jolt-core/src/zkvm/proof_serialization.rs index dd3a78a17e..48d5d0f867 100644 --- a/jolt-core/src/zkvm/proof_serialization.rs +++ b/jolt-core/src/zkvm/proof_serialization.rs @@ -10,12 +10,14 @@ use ark_serialize::{ }; use num::FromPrimitive; +#[cfg(feature = "zk")] +use crate::poly::commitment::hyrax::HyraxOpeningProof; #[cfg(not(feature = "zk"))] use crate::poly::opening_proof::{OpeningPoint, Openings}; #[cfg(feature = "zk")] -use crate::subprotocols::blindfold::BlindFoldProof; +use crate::subprotocols::blindfold::{BlindFoldProof, RelaxedR1CSInstance}; use crate::{ - curve::JoltCurve, + curve::{JoltCurve, JoltGroupElement}, field::JoltField, poly::{ commitment::{ @@ -23,11 +25,15 @@ use crate::{ dory::{ArkG1, ArkG2, ArkGT, DoryCommitmentScheme, DoryLayout}, }, opening_proof::{OpeningId, PolynomialId, SumcheckId}, + unipoly::{CompressedUniPoly, UniPoly}, }, }; use crate::{ subprotocols::{ - sumcheck::SumcheckInstanceProof, univariate_skip::UniSkipFirstRoundProofVariant, + sumcheck::{ClearSumcheckProof, SumcheckInstanceProof, ZkSumcheckProof}, + univariate_skip::{ + UniSkipFirstRoundProof, UniSkipFirstRoundProofVariant, ZkUniSkipFirstRoundProof, + }, }, transcripts::Transcript, zkvm::{ @@ -148,6 +154,304 @@ fn check_cursor_consumed( Ok(()) } +fn cursor_remaining_bytes(cursor: &Cursor<&[u8]>) -> usize { + cursor + .get_ref() + .len() + .saturating_sub(cursor.position() as usize) +} + +fn read_ark_seq_len( + cursor: &mut Cursor<&[u8]>, + min_elem_bytes: usize, + compress: Compress, + validate: Validate, +) -> Result { + let len_u64 = u64::deserialize_with_mode(&mut *cursor, compress, validate)?; + let len = usize::try_from(len_u64).map_err(|_| SerializationError::InvalidData)?; + let max_len = cursor_remaining_bytes(cursor) / min_elem_bytes.max(1); + if len > max_len { + return Err(invalid_data(format!( + "sequence length {len} exceeds remaining byte budget {max_len}" + ))); + } + Ok(len) +} + +fn read_bounded_ark_vec( + cursor: &mut Cursor<&[u8]>, + min_elem_bytes: usize, + compress: Compress, + validate: Validate, + mut read_item: ReadItem, +) -> Result, SerializationError> +where + ReadItem: FnMut(&mut Cursor<&[u8]>, Compress, Validate) -> Result, +{ + let len = read_ark_seq_len(cursor, min_elem_bytes, compress, validate)?; + let mut items = Vec::with_capacity(len); + for _ in 0..len { + items.push(read_item(cursor, compress, validate)?); + } + Ok(items) +} + +fn deserialize_field_vec( + cursor: &mut Cursor<&[u8]>, + compress: Compress, + validate: Validate, +) -> Result, SerializationError> { + read_bounded_ark_vec( + cursor, + F::zero().serialized_size(compress), + compress, + validate, + |cursor, compress, validate| F::deserialize_with_mode(cursor, compress, validate), + ) +} + +fn deserialize_g1_vec>( + cursor: &mut Cursor<&[u8]>, + compress: Compress, + validate: Validate, +) -> Result, SerializationError> { + read_bounded_ark_vec( + cursor, + C::G1::zero().serialized_size(compress), + compress, + validate, + |cursor, compress, validate| C::G1::deserialize_with_mode(cursor, compress, validate), + ) +} + +fn deserialize_usize_vec( + cursor: &mut Cursor<&[u8]>, + compress: Compress, + validate: Validate, +) -> Result, SerializationError> { + read_bounded_ark_vec( + cursor, + 0usize.serialized_size(compress), + compress, + validate, + |cursor, compress, validate| usize::deserialize_with_mode(cursor, compress, validate), + ) +} + +fn deserialize_uni_poly_from_cursor( + cursor: &mut Cursor<&[u8]>, + compress: Compress, + validate: Validate, +) -> Result, SerializationError> { + Ok(UniPoly::from_coeff(deserialize_field_vec( + cursor, compress, validate, + )?)) +} + +fn deserialize_compressed_uni_poly_from_cursor( + cursor: &mut Cursor<&[u8]>, + compress: Compress, + validate: Validate, +) -> Result, SerializationError> { + Ok(CompressedUniPoly { + coeffs_except_linear_term: deserialize_field_vec(cursor, compress, validate)?, + }) +} + +fn deserialize_clear_sumcheck_proof_from_cursor( + cursor: &mut Cursor<&[u8]>, + compress: Compress, + validate: Validate, +) -> Result, SerializationError> { + let compressed_polys = read_bounded_ark_vec( + cursor, + 0u64.serialized_size(compress), + compress, + validate, + |cursor, compress, validate| { + deserialize_compressed_uni_poly_from_cursor::(cursor, compress, validate) + }, + )?; + Ok(ClearSumcheckProof::new(compressed_polys)) +} + +fn deserialize_zk_sumcheck_proof_from_cursor, FS: Transcript>( + cursor: &mut Cursor<&[u8]>, + compress: Compress, + validate: Validate, +) -> Result, SerializationError> { + Ok(ZkSumcheckProof::new( + deserialize_g1_vec::(cursor, compress, validate)?, + deserialize_usize_vec(cursor, compress, validate)?, + deserialize_g1_vec::(cursor, compress, validate)?, + )) +} + +fn deserialize_sumcheck_instance_proof_from_cursor< + F: JoltField, + C: JoltCurve, + FS: Transcript, +>( + cursor: &mut Cursor<&[u8]>, + compress: Compress, + validate: Validate, +) -> Result, SerializationError> { + match u8::deserialize_with_mode(&mut *cursor, compress, validate)? { + 0 => Ok(SumcheckInstanceProof::Clear( + deserialize_clear_sumcheck_proof_from_cursor::(cursor, compress, validate)?, + )), + 1 => Ok(SumcheckInstanceProof::Zk( + deserialize_zk_sumcheck_proof_from_cursor::(cursor, compress, validate)?, + )), + _ => Err(SerializationError::InvalidData), + } +} + +fn deserialize_uniskip_first_round_proof_from_cursor( + cursor: &mut Cursor<&[u8]>, + compress: Compress, + validate: Validate, +) -> Result, SerializationError> { + Ok(UniSkipFirstRoundProof::new( + deserialize_uni_poly_from_cursor(cursor, compress, validate)?, + )) +} + +fn deserialize_zk_uniskip_first_round_proof_from_cursor< + F: JoltField, + C: JoltCurve, + FS: Transcript, +>( + cursor: &mut Cursor<&[u8]>, + compress: Compress, + validate: Validate, +) -> Result, SerializationError> { + Ok(ZkUniSkipFirstRoundProof::new( + C::G1::deserialize_with_mode(&mut *cursor, compress, validate)?, + usize::deserialize_with_mode(&mut *cursor, compress, validate)?, + deserialize_g1_vec::(cursor, compress, validate)?, + )) +} + +fn deserialize_uniskip_first_round_proof_variant_from_cursor< + F: JoltField, + C: JoltCurve, + FS: Transcript, +>( + cursor: &mut Cursor<&[u8]>, + compress: Compress, + validate: Validate, +) -> Result, SerializationError> { + match u8::deserialize_with_mode(&mut *cursor, compress, validate)? { + 0 => Ok(UniSkipFirstRoundProofVariant::Standard( + deserialize_uniskip_first_round_proof_from_cursor::(cursor, compress, validate)?, + )), + 1 => Ok(UniSkipFirstRoundProofVariant::Zk( + deserialize_zk_uniskip_first_round_proof_from_cursor::( + cursor, compress, validate, + )?, + )), + _ => Err(SerializationError::InvalidData), + } +} + +#[cfg(feature = "zk")] +fn deserialize_hyrax_opening_proof_from_cursor( + cursor: &mut Cursor<&[u8]>, + compress: Compress, + validate: Validate, +) -> Result, SerializationError> { + Ok(HyraxOpeningProof { + combined_row: deserialize_field_vec(cursor, compress, validate)?, + combined_blinding: F::deserialize_with_mode(cursor, compress, validate)?, + }) +} + +#[cfg(feature = "zk")] +fn deserialize_relaxed_r1cs_instance_from_cursor>( + cursor: &mut Cursor<&[u8]>, + compress: Compress, + validate: Validate, +) -> Result, SerializationError> { + Ok(RelaxedR1CSInstance { + u: F::deserialize_with_mode(&mut *cursor, compress, validate)?, + round_commitments: deserialize_g1_vec::(cursor, compress, validate)?, + output_claims_row_commitments: deserialize_g1_vec::(cursor, compress, validate)?, + noncoeff_row_commitments: deserialize_g1_vec::(cursor, compress, validate)?, + e_row_commitments: deserialize_g1_vec::(cursor, compress, validate)?, + eval_commitments: deserialize_g1_vec::(cursor, compress, validate)?, + }) +} + +#[cfg(feature = "zk")] +fn deserialize_blindfold_proof_from_cursor>( + cursor: &mut Cursor<&[u8]>, + compress: Compress, + validate: Validate, +) -> Result, SerializationError> { + Ok(BlindFoldProof { + random_instance: deserialize_relaxed_r1cs_instance_from_cursor::( + cursor, compress, validate, + )?, + noncoeff_row_commitments: deserialize_g1_vec::(cursor, compress, validate)?, + cross_term_row_commitments: deserialize_g1_vec::(cursor, compress, validate)?, + spartan_proof: read_bounded_ark_vec( + cursor, + 0u64.serialized_size(compress), + compress, + validate, + |cursor, compress, validate| { + deserialize_compressed_uni_poly_from_cursor::(cursor, compress, validate) + }, + )?, + az_r: F::deserialize_with_mode(&mut *cursor, compress, validate)?, + bz_r: F::deserialize_with_mode(&mut *cursor, compress, validate)?, + cz_r: F::deserialize_with_mode(&mut *cursor, compress, validate)?, + inner_sumcheck_proof: read_bounded_ark_vec( + cursor, + 0u64.serialized_size(compress), + compress, + validate, + |cursor, compress, validate| { + deserialize_compressed_uni_poly_from_cursor::(cursor, compress, validate) + }, + )?, + w_opening: deserialize_hyrax_opening_proof_from_cursor(cursor, compress, validate)?, + e_opening: deserialize_hyrax_opening_proof_from_cursor(cursor, compress, validate)?, + folded_eval_outputs: deserialize_field_vec(cursor, compress, validate)?, + folded_eval_blindings: deserialize_field_vec(cursor, compress, validate)?, + }) +} + +fn ensure_sumcheck_mode, FS: Transcript>( + section_name: &str, + proof: &SumcheckInstanceProof, + proof_is_zk: bool, +) -> Result<(), SerializationError> { + if proof.is_zk() != proof_is_zk { + let mode = if proof_is_zk { "ZK" } else { "standard" }; + return Err(invalid_data(format!( + "{section_name} sumcheck proof mode does not match outer proof {mode} flag" + ))); + } + Ok(()) +} + +fn ensure_uniskip_mode, FS: Transcript>( + section_name: &str, + proof: &UniSkipFirstRoundProofVariant, + proof_is_zk: bool, +) -> Result<(), SerializationError> { + let uniskip_is_zk = matches!(proof, UniSkipFirstRoundProofVariant::Zk(_)); + if uniskip_is_zk != proof_is_zk { + let mode = if proof_is_zk { "ZK" } else { "standard" }; + return Err(invalid_data(format!( + "{section_name} uni-skip proof mode does not match outer proof {mode} flag" + ))); + } + Ok(()) +} + fn ark_group_size(compress: Compress) -> usize { T::default().serialized_size(compress) } @@ -442,46 +746,87 @@ impl, PCS: CommitmentScheme, FS: Tr }; check_cursor_consumed("commitments", &limited)?; - macro_rules! read_single_section { - ($reader:expr) => {{ - let mut limited = - transport::read_section($reader, MAX_SECTION_LEN).map_err(io_err)?; - let val = - CanonicalDeserialize::deserialize_with_mode(&mut limited, compress, validate)?; - check_trailing_bytes(&limited)?; - val - }}; - } - #[cfg(not(feature = "zk"))] - let opening_claims: Claims = read_single_section!(&mut reader); + let opening_claims: Claims = { + let mut limited = + transport::read_section(&mut reader, MAX_SECTION_LEN).map_err(io_err)?; + let claims = + CanonicalDeserialize::deserialize_with_mode(&mut limited, compress, validate)?; + check_trailing_bytes(&limited)?; + claims + }; // Stage 1 - let mut limited = transport::read_section(&mut reader, MAX_SECTION_LEN).map_err(io_err)?; + let stage1_bytes = + transport::read_section_bytes(&mut reader, MAX_SECTION_LEN).map_err(io_err)?; + let mut stage1_cursor = Cursor::new(stage1_bytes.as_slice()); let stage1_uni_skip_first_round_proof = - CanonicalDeserialize::deserialize_with_mode(&mut limited, compress, validate)?; - let stage1_sumcheck_proof = - CanonicalDeserialize::deserialize_with_mode(&mut limited, compress, validate)?; - check_trailing_bytes(&limited)?; + deserialize_uniskip_first_round_proof_variant_from_cursor( + &mut stage1_cursor, + compress, + validate, + )?; + let stage1_sumcheck_proof = deserialize_sumcheck_instance_proof_from_cursor( + &mut stage1_cursor, + compress, + validate, + )?; + check_cursor_consumed("stage1", &stage1_cursor)?; + ensure_uniskip_mode("stage1", &stage1_uni_skip_first_round_proof, proof_is_zk)?; + ensure_sumcheck_mode("stage1", &stage1_sumcheck_proof, proof_is_zk)?; // Stage 2 - let mut limited = transport::read_section(&mut reader, MAX_SECTION_LEN).map_err(io_err)?; + let stage2_bytes = + transport::read_section_bytes(&mut reader, MAX_SECTION_LEN).map_err(io_err)?; + let mut stage2_cursor = Cursor::new(stage2_bytes.as_slice()); let stage2_uni_skip_first_round_proof = - CanonicalDeserialize::deserialize_with_mode(&mut limited, compress, validate)?; - let stage2_sumcheck_proof = - CanonicalDeserialize::deserialize_with_mode(&mut limited, compress, validate)?; - check_trailing_bytes(&limited)?; + deserialize_uniskip_first_round_proof_variant_from_cursor( + &mut stage2_cursor, + compress, + validate, + )?; + let stage2_sumcheck_proof = deserialize_sumcheck_instance_proof_from_cursor( + &mut stage2_cursor, + compress, + validate, + )?; + check_cursor_consumed("stage2", &stage2_cursor)?; + ensure_uniskip_mode("stage2", &stage2_uni_skip_first_round_proof, proof_is_zk)?; + ensure_sumcheck_mode("stage2", &stage2_sumcheck_proof, proof_is_zk)?; + + macro_rules! read_sumcheck_section { + ($reader:expr, $section_name:literal) => {{ + let section_bytes = + transport::read_section_bytes($reader, MAX_SECTION_LEN).map_err(io_err)?; + let mut cursor = Cursor::new(section_bytes.as_slice()); + let proof = deserialize_sumcheck_instance_proof_from_cursor( + &mut cursor, + compress, + validate, + )?; + check_cursor_consumed($section_name, &cursor)?; + ensure_sumcheck_mode($section_name, &proof, proof_is_zk)?; + proof + }}; + } - let stage3_sumcheck_proof = read_single_section!(&mut reader); - let stage4_sumcheck_proof = read_single_section!(&mut reader); - let stage5_sumcheck_proof = read_single_section!(&mut reader); - let stage6_sumcheck_proof = read_single_section!(&mut reader); - let stage7_sumcheck_proof = read_single_section!(&mut reader); + let stage3_sumcheck_proof = read_sumcheck_section!(&mut reader, "stage3"); + let stage4_sumcheck_proof = read_sumcheck_section!(&mut reader, "stage4"); + let stage5_sumcheck_proof = read_sumcheck_section!(&mut reader, "stage5"); + let stage6_sumcheck_proof = read_sumcheck_section!(&mut reader, "stage6"); + let stage7_sumcheck_proof = read_sumcheck_section!(&mut reader, "stage7"); let joint_opening_proof = deserialize_joint_opening_proof_section::(&mut reader, compress, validate)?; #[cfg(feature = "zk")] - let blindfold_proof = read_single_section!(&mut reader); + let blindfold_proof = { + let section_bytes = + transport::read_section_bytes(&mut reader, MAX_SECTION_LEN).map_err(io_err)?; + let mut cursor = Cursor::new(section_bytes.as_slice()); + let proof = deserialize_blindfold_proof_from_cursor(&mut cursor, compress, validate)?; + check_cursor_consumed("blindfold", &cursor)?; + proof + }; let mut eof_check = [0u8; 1]; match reader.read(&mut eof_check) { @@ -1022,6 +1367,8 @@ mod tests { use crate::poly::opening_proof::{OpeningId, SumcheckId}; use crate::zkvm::witness::{CommittedPolynomial, VirtualPolynomial}; use crate::zkvm::RV64IMACProof; + use crate::{curve::Bn254Curve, transcripts::Blake2bTranscript}; + use ark_bn254::Fr; #[test] fn opening_id_header_is_packed_common_case() { @@ -1131,6 +1478,37 @@ mod tests { assert!(matches!(res, Err(SerializationError::InvalidData))); } + #[test] + fn sumcheck_section_rejects_oversized_nested_vector() { + let mut bytes = Vec::new(); + 0u8.serialize_compressed(&mut bytes).unwrap(); + 1u64.serialize_compressed(&mut bytes).unwrap(); + + let mut cursor = Cursor::new(bytes.as_slice()); + let res = deserialize_sumcheck_instance_proof_from_cursor::< + Fr, + Bn254Curve, + Blake2bTranscript, + >(&mut cursor, Compress::Yes, Validate::Yes); + assert!(res.is_err()); + } + + #[cfg(feature = "zk")] + #[test] + fn blindfold_section_rejects_oversized_nested_vector() { + let mut bytes = Vec::new(); + Fr::from(0u64).serialize_compressed(&mut bytes).unwrap(); + 1u64.serialize_compressed(&mut bytes).unwrap(); + + let mut cursor = Cursor::new(bytes.as_slice()); + let res = deserialize_blindfold_proof_from_cursor::( + &mut cursor, + Compress::Yes, + Validate::Yes, + ); + assert!(res.is_err()); + } + #[test] fn proof_version_byte() { assert_eq!(PROOF_VERSION, 1); diff --git a/jolt-core/src/zkvm/prover.rs b/jolt-core/src/zkvm/prover.rs index 3a0f8419ce..6baf81b0bb 100644 --- a/jolt-core/src/zkvm/prover.rs +++ b/jolt-core/src/zkvm/prover.rs @@ -3534,6 +3534,8 @@ mod tests { ); let io_device = prover.program_io.clone(); let (proof, debug_info) = prover.prove(); + let proof_bytes = proof.serialize_to_bytes().unwrap(); + let proof = RV64IMACProof::deserialize_from_bytes(&proof_bytes).unwrap(); let verifier_preprocessing = JoltVerifierPreprocessing::from(&prover_preprocessing); @@ -3587,6 +3589,8 @@ mod tests { ); let io_device = prover.program_io.clone(); let (jolt_proof, debug_info) = prover.prove(); + let proof_bytes = jolt_proof.serialize_to_bytes().unwrap(); + let jolt_proof = RV64IMACProof::deserialize_from_bytes(&proof_bytes).unwrap(); let verifier_preprocessing = JoltVerifierPreprocessing::from(&prover_preprocessing); RV64IMACVerifier::new( diff --git a/jolt-core/src/zkvm/transport.rs b/jolt-core/src/zkvm/transport.rs index 110693e424..8192527b87 100644 --- a/jolt-core/src/zkvm/transport.rs +++ b/jolt-core/src/zkvm/transport.rs @@ -53,7 +53,7 @@ pub fn varint_u64_len(mut x: u64) -> usize { pub fn read_varint_u64(r: &mut R) -> io::Result { let mut x = 0u64; let mut shift = 0u32; - for _ in 0..VARINT_U64_MAX_BYTES { + for i in 0..VARINT_U64_MAX_BYTES { let mut b = [0u8; 1]; r.read_exact(&mut b)?; let byte = b[0]; @@ -66,6 +66,13 @@ pub fn read_varint_u64(r: &mut R) -> io::Result { } x |= payload << shift; if (byte & 0x80) == 0 { + let bytes_read = i + 1; + if bytes_read != varint_u64_len(x) { + return Err(io::Error::new( + io::ErrorKind::InvalidData, + "non-canonical varint", + )); + } return Ok(x); } shift += 7; @@ -145,6 +152,16 @@ mod tests { assert!(res.is_err()); } + #[test] + fn noncanonical_varint_rejected() { + let bad = [0x80u8, 0x00]; + let res = read_varint_u64(&mut bad.as_slice()); + assert!(res.is_err()); + let err = res.unwrap_err(); + assert_eq!(err.kind(), io::ErrorKind::InvalidData); + assert!(err.to_string().contains("non-canonical varint")); + } + #[test] fn magic_version_roundtrip() { let magic = b"JOLT"; From c97fc3883bc42186a155e8e9a5997d98ff3883d3 Mon Sep 17 00:00:00 2001 From: Quang Dao Date: Wed, 1 Apr 2026 09:13:55 -0400 Subject: [PATCH 9/9] test(zkvm): avoid ZK roundtrip stack overflow --- jolt-core/src/zkvm/prover.rs | 87 +++++++++++++++++++++++------------- 1 file changed, 56 insertions(+), 31 deletions(-) diff --git a/jolt-core/src/zkvm/prover.rs b/jolt-core/src/zkvm/prover.rs index 6baf81b0bb..a4d9ad9c70 100644 --- a/jolt-core/src/zkvm/prover.rs +++ b/jolt-core/src/zkvm/prover.rs @@ -2225,6 +2225,8 @@ mod tests { extern crate jolt_inlines_sha2; use std::sync::Arc; + #[cfg(feature = "zk")] + use std::thread; use ark_bn254::Fr; use serial_test::serial; @@ -2303,6 +2305,26 @@ mod tests { (commitment, hint) } + fn with_roundtrip_stack( + f: impl FnOnce() -> T + Send + 'static, + ) -> T { + #[cfg(feature = "zk")] + { + return thread::Builder::new() + .name("proof-roundtrip".to_string()) + .stack_size(32 * 1024 * 1024) + .spawn(f) + .unwrap() + .join() + .unwrap(); + } + + #[cfg(not(feature = "zk"))] + { + f() + } + } + #[test] #[serial] fn fib_e2e_dory() { @@ -2336,18 +2358,19 @@ mod tests { let io_device = prover.program_io.clone(); let (jolt_proof, debug_info) = prover.prove(); let proof_bytes = jolt_proof.serialize_to_bytes().unwrap(); - let jolt_proof = RV64IMACProof::deserialize_from_bytes(&proof_bytes).unwrap(); - let verifier_preprocessing = JoltVerifierPreprocessing::from(&prover_preprocessing); - let verifier = RV64IMACVerifier::new( - &verifier_preprocessing, - jolt_proof, - io_device, - None, - debug_info, - ) - .expect("Failed to create verifier"); - verifier.verify().expect("Failed to verify proof"); + with_roundtrip_stack(move || { + let jolt_proof = RV64IMACProof::deserialize_from_bytes(&proof_bytes).unwrap(); + let verifier = RV64IMACVerifier::new( + &verifier_preprocessing, + jolt_proof, + io_device, + None, + debug_info, + ) + .expect("Failed to create verifier"); + verifier.verify().expect("Failed to verify proof"); + }); } #[test] @@ -3535,15 +3558,15 @@ mod tests { let io_device = prover.program_io.clone(); let (proof, debug_info) = prover.prove(); let proof_bytes = proof.serialize_to_bytes().unwrap(); - let proof = RV64IMACProof::deserialize_from_bytes(&proof_bytes).unwrap(); - let verifier_preprocessing = JoltVerifierPreprocessing::from(&prover_preprocessing); - - // DoryGlobals is now initialized inside the verifier's verify_stage8 - RV64IMACVerifier::new(&verifier_preprocessing, proof, io_device, None, debug_info) - .expect("verifier creation failed") - .verify() - .expect("verification failed"); + with_roundtrip_stack(move || { + let proof = RV64IMACProof::deserialize_from_bytes(&proof_bytes).unwrap(); + // DoryGlobals is now initialized inside the verifier's verify_stage8 + RV64IMACVerifier::new(&verifier_preprocessing, proof, io_device, None, debug_info) + .expect("verifier creation failed") + .verify() + .expect("verification failed"); + }); } #[test] @@ -3590,19 +3613,21 @@ mod tests { let io_device = prover.program_io.clone(); let (jolt_proof, debug_info) = prover.prove(); let proof_bytes = jolt_proof.serialize_to_bytes().unwrap(); - let jolt_proof = RV64IMACProof::deserialize_from_bytes(&proof_bytes).unwrap(); - let verifier_preprocessing = JoltVerifierPreprocessing::from(&prover_preprocessing); - RV64IMACVerifier::new( - &verifier_preprocessing, - jolt_proof, - io_device.clone(), - Some(trusted_commitment), - debug_info, - ) - .expect("Failed to create verifier") - .verify() - .expect("Verification failed"); + let verifier_io_device = io_device.clone(); + with_roundtrip_stack(move || { + let jolt_proof = RV64IMACProof::deserialize_from_bytes(&proof_bytes).unwrap(); + RV64IMACVerifier::new( + &verifier_preprocessing, + jolt_proof, + verifier_io_device, + Some(trusted_commitment), + debug_info, + ) + .expect("Failed to create verifier") + .verify() + .expect("Verification failed"); + }); // Expected merkle root for leaves [5;32], [6;32], [7;32], [8;32] let expected_output = &[