diff --git a/jolt-core/src/zkvm/mod.rs b/jolt-core/src/zkvm/mod.rs index 92d23a1ee..f9cfadc5a 100644 --- a/jolt-core/src/zkvm/mod.rs +++ b/jolt-core/src/zkvm/mod.rs @@ -37,6 +37,7 @@ pub mod r1cs; pub mod ram; pub mod registers; pub mod spartan; +pub(crate) mod transport; pub mod verifier; pub mod witness; diff --git a/jolt-core/src/zkvm/proof_serialization.rs b/jolt-core/src/zkvm/proof_serialization.rs index 343881168..48d5d0f86 100644 --- a/jolt-core/src/zkvm/proof_serialization.rs +++ b/jolt-core/src/zkvm/proof_serialization.rs @@ -1,28 +1,39 @@ #[cfg(not(feature = "zk"))] use std::collections::BTreeMap; -use std::io::{Read, Write}; +use std::{ + any::TypeId, + io::{Cursor, Read, Write}, +}; use ark_serialize::{ CanonicalDeserialize, CanonicalSerialize, Compress, SerializationError, Valid, Validate, }; use num::FromPrimitive; -use strum::EnumCount; +#[cfg(feature = "zk")] +use crate::poly::commitment::hyrax::HyraxOpeningProof; #[cfg(not(feature = "zk"))] use crate::poly::opening_proof::{OpeningPoint, Openings}; #[cfg(feature = "zk")] -use crate::subprotocols::blindfold::BlindFoldProof; +use crate::subprotocols::blindfold::{BlindFoldProof, RelaxedR1CSInstance}; use crate::{ - curve::JoltCurve, + curve::{JoltCurve, JoltGroupElement}, field::JoltField, poly::{ - commitment::{commitment_scheme::CommitmentScheme, dory::DoryLayout}, + commitment::{ + commitment_scheme::CommitmentScheme, + dory::{ArkG1, ArkG2, ArkGT, DoryCommitmentScheme, DoryLayout}, + }, opening_proof::{OpeningId, PolynomialId, SumcheckId}, + unipoly::{CompressedUniPoly, UniPoly}, }, }; use crate::{ subprotocols::{ - sumcheck::SumcheckInstanceProof, univariate_skip::UniSkipFirstRoundProofVariant, + sumcheck::{ClearSumcheckProof, SumcheckInstanceProof, ZkSumcheckProof}, + univariate_skip::{ + UniSkipFirstRoundProof, UniSkipFirstRoundProofVariant, ZkUniSkipFirstRoundProof, + }, }, transcripts::Transcript, zkvm::{ @@ -32,7 +43,18 @@ use crate::{ }, }; -#[derive(CanonicalSerialize, CanonicalDeserialize)] +use crate::zkvm::transport; + +const PROOF_MAGIC: &[u8; 4] = b"JOLT"; +const PROOF_VERSION: u8 = 1; +const PROOF_FLAGS_RESERVED_MASK: u8 = 0xFE; +const PROOF_FLAG_ZK: u8 = 0x01; + +const MAX_PARAMS_LEN: u64 = 1024; +const MAX_SECTION_LEN: u64 = 128 * 1024; +#[cfg(not(feature = "zk"))] +const MIN_OPENING_CLAIM_BYTES: u64 = 33; + pub struct JoltProof< F: JoltField, C: JoltCurve, @@ -62,9 +84,794 @@ pub struct JoltProof< pub dory_layout: DoryLayout, } +#[inline] +fn io_err(e: std::io::Error) -> SerializationError { + SerializationError::IoError(e) +} + +#[inline] +fn invalid_data(message: impl Into) -> SerializationError { + io_err(std::io::Error::new( + std::io::ErrorKind::InvalidData, + message.into(), + )) +} + +fn ensure_section_len( + section_name: &str, + len: u64, + max_len: u64, +) -> Result<(), SerializationError> { + if len > max_len { + return Err(invalid_data(format!( + "{section_name} section size {len} exceeds cap {max_len}" + ))); + } + Ok(()) +} + +macro_rules! write_section { + ($w:expr, $c:expr, $($item:expr),+ $(,)?) => {{ + let len: u64 = 0 $(+ $item.serialized_size($c) as u64)+; + ensure_section_len("proof", len, MAX_SECTION_LEN)?; + transport::write_varint_u64($w, len).map_err(io_err)?; + $($item.serialize_with_mode($w, $c)?;)+ + }}; +} + +macro_rules! section_size { + ($c:expr, $($item:expr),+ $(,)?) => {{ + let payload: u64 = 0 $(+ $item.serialized_size($c) as u64)+; + transport::varint_u64_len(payload) + payload as usize + }}; +} + +fn check_trailing_bytes( + limited: &std::io::Take<&mut R>, +) -> Result<(), SerializationError> { + if limited.limit() != 0 { + return Err(invalid_data(format!( + "{} trailing bytes not consumed", + limited.limit() + ))); + } + Ok(()) +} + +fn check_cursor_consumed( + section_name: &str, + cursor: &Cursor<&[u8]>, +) -> Result<(), SerializationError> { + let trailing = cursor + .get_ref() + .len() + .saturating_sub(cursor.position() as usize); + if trailing != 0 { + return Err(invalid_data(format!( + "{trailing} trailing bytes not consumed in {section_name} section" + ))); + } + Ok(()) +} + +fn cursor_remaining_bytes(cursor: &Cursor<&[u8]>) -> usize { + cursor + .get_ref() + .len() + .saturating_sub(cursor.position() as usize) +} + +fn read_ark_seq_len( + cursor: &mut Cursor<&[u8]>, + min_elem_bytes: usize, + compress: Compress, + validate: Validate, +) -> Result { + let len_u64 = u64::deserialize_with_mode(&mut *cursor, compress, validate)?; + let len = usize::try_from(len_u64).map_err(|_| SerializationError::InvalidData)?; + let max_len = cursor_remaining_bytes(cursor) / min_elem_bytes.max(1); + if len > max_len { + return Err(invalid_data(format!( + "sequence length {len} exceeds remaining byte budget {max_len}" + ))); + } + Ok(len) +} + +fn read_bounded_ark_vec( + cursor: &mut Cursor<&[u8]>, + min_elem_bytes: usize, + compress: Compress, + validate: Validate, + mut read_item: ReadItem, +) -> Result, SerializationError> +where + ReadItem: FnMut(&mut Cursor<&[u8]>, Compress, Validate) -> Result, +{ + let len = read_ark_seq_len(cursor, min_elem_bytes, compress, validate)?; + let mut items = Vec::with_capacity(len); + for _ in 0..len { + items.push(read_item(cursor, compress, validate)?); + } + Ok(items) +} + +fn deserialize_field_vec( + cursor: &mut Cursor<&[u8]>, + compress: Compress, + validate: Validate, +) -> Result, SerializationError> { + read_bounded_ark_vec( + cursor, + F::zero().serialized_size(compress), + compress, + validate, + |cursor, compress, validate| F::deserialize_with_mode(cursor, compress, validate), + ) +} + +fn deserialize_g1_vec>( + cursor: &mut Cursor<&[u8]>, + compress: Compress, + validate: Validate, +) -> Result, SerializationError> { + read_bounded_ark_vec( + cursor, + C::G1::zero().serialized_size(compress), + compress, + validate, + |cursor, compress, validate| C::G1::deserialize_with_mode(cursor, compress, validate), + ) +} + +fn deserialize_usize_vec( + cursor: &mut Cursor<&[u8]>, + compress: Compress, + validate: Validate, +) -> Result, SerializationError> { + read_bounded_ark_vec( + cursor, + 0usize.serialized_size(compress), + compress, + validate, + |cursor, compress, validate| usize::deserialize_with_mode(cursor, compress, validate), + ) +} + +fn deserialize_uni_poly_from_cursor( + cursor: &mut Cursor<&[u8]>, + compress: Compress, + validate: Validate, +) -> Result, SerializationError> { + Ok(UniPoly::from_coeff(deserialize_field_vec( + cursor, compress, validate, + )?)) +} + +fn deserialize_compressed_uni_poly_from_cursor( + cursor: &mut Cursor<&[u8]>, + compress: Compress, + validate: Validate, +) -> Result, SerializationError> { + Ok(CompressedUniPoly { + coeffs_except_linear_term: deserialize_field_vec(cursor, compress, validate)?, + }) +} + +fn deserialize_clear_sumcheck_proof_from_cursor( + cursor: &mut Cursor<&[u8]>, + compress: Compress, + validate: Validate, +) -> Result, SerializationError> { + let compressed_polys = read_bounded_ark_vec( + cursor, + 0u64.serialized_size(compress), + compress, + validate, + |cursor, compress, validate| { + deserialize_compressed_uni_poly_from_cursor::(cursor, compress, validate) + }, + )?; + Ok(ClearSumcheckProof::new(compressed_polys)) +} + +fn deserialize_zk_sumcheck_proof_from_cursor, FS: Transcript>( + cursor: &mut Cursor<&[u8]>, + compress: Compress, + validate: Validate, +) -> Result, SerializationError> { + Ok(ZkSumcheckProof::new( + deserialize_g1_vec::(cursor, compress, validate)?, + deserialize_usize_vec(cursor, compress, validate)?, + deserialize_g1_vec::(cursor, compress, validate)?, + )) +} + +fn deserialize_sumcheck_instance_proof_from_cursor< + F: JoltField, + C: JoltCurve, + FS: Transcript, +>( + cursor: &mut Cursor<&[u8]>, + compress: Compress, + validate: Validate, +) -> Result, SerializationError> { + match u8::deserialize_with_mode(&mut *cursor, compress, validate)? { + 0 => Ok(SumcheckInstanceProof::Clear( + deserialize_clear_sumcheck_proof_from_cursor::(cursor, compress, validate)?, + )), + 1 => Ok(SumcheckInstanceProof::Zk( + deserialize_zk_sumcheck_proof_from_cursor::(cursor, compress, validate)?, + )), + _ => Err(SerializationError::InvalidData), + } +} + +fn deserialize_uniskip_first_round_proof_from_cursor( + cursor: &mut Cursor<&[u8]>, + compress: Compress, + validate: Validate, +) -> Result, SerializationError> { + Ok(UniSkipFirstRoundProof::new( + deserialize_uni_poly_from_cursor(cursor, compress, validate)?, + )) +} + +fn deserialize_zk_uniskip_first_round_proof_from_cursor< + F: JoltField, + C: JoltCurve, + FS: Transcript, +>( + cursor: &mut Cursor<&[u8]>, + compress: Compress, + validate: Validate, +) -> Result, SerializationError> { + Ok(ZkUniSkipFirstRoundProof::new( + C::G1::deserialize_with_mode(&mut *cursor, compress, validate)?, + usize::deserialize_with_mode(&mut *cursor, compress, validate)?, + deserialize_g1_vec::(cursor, compress, validate)?, + )) +} + +fn deserialize_uniskip_first_round_proof_variant_from_cursor< + F: JoltField, + C: JoltCurve, + FS: Transcript, +>( + cursor: &mut Cursor<&[u8]>, + compress: Compress, + validate: Validate, +) -> Result, SerializationError> { + match u8::deserialize_with_mode(&mut *cursor, compress, validate)? { + 0 => Ok(UniSkipFirstRoundProofVariant::Standard( + deserialize_uniskip_first_round_proof_from_cursor::(cursor, compress, validate)?, + )), + 1 => Ok(UniSkipFirstRoundProofVariant::Zk( + deserialize_zk_uniskip_first_round_proof_from_cursor::( + cursor, compress, validate, + )?, + )), + _ => Err(SerializationError::InvalidData), + } +} + +#[cfg(feature = "zk")] +fn deserialize_hyrax_opening_proof_from_cursor( + cursor: &mut Cursor<&[u8]>, + compress: Compress, + validate: Validate, +) -> Result, SerializationError> { + Ok(HyraxOpeningProof { + combined_row: deserialize_field_vec(cursor, compress, validate)?, + combined_blinding: F::deserialize_with_mode(cursor, compress, validate)?, + }) +} + +#[cfg(feature = "zk")] +fn deserialize_relaxed_r1cs_instance_from_cursor>( + cursor: &mut Cursor<&[u8]>, + compress: Compress, + validate: Validate, +) -> Result, SerializationError> { + Ok(RelaxedR1CSInstance { + u: F::deserialize_with_mode(&mut *cursor, compress, validate)?, + round_commitments: deserialize_g1_vec::(cursor, compress, validate)?, + output_claims_row_commitments: deserialize_g1_vec::(cursor, compress, validate)?, + noncoeff_row_commitments: deserialize_g1_vec::(cursor, compress, validate)?, + e_row_commitments: deserialize_g1_vec::(cursor, compress, validate)?, + eval_commitments: deserialize_g1_vec::(cursor, compress, validate)?, + }) +} + +#[cfg(feature = "zk")] +fn deserialize_blindfold_proof_from_cursor>( + cursor: &mut Cursor<&[u8]>, + compress: Compress, + validate: Validate, +) -> Result, SerializationError> { + Ok(BlindFoldProof { + random_instance: deserialize_relaxed_r1cs_instance_from_cursor::( + cursor, compress, validate, + )?, + noncoeff_row_commitments: deserialize_g1_vec::(cursor, compress, validate)?, + cross_term_row_commitments: deserialize_g1_vec::(cursor, compress, validate)?, + spartan_proof: read_bounded_ark_vec( + cursor, + 0u64.serialized_size(compress), + compress, + validate, + |cursor, compress, validate| { + deserialize_compressed_uni_poly_from_cursor::(cursor, compress, validate) + }, + )?, + az_r: F::deserialize_with_mode(&mut *cursor, compress, validate)?, + bz_r: F::deserialize_with_mode(&mut *cursor, compress, validate)?, + cz_r: F::deserialize_with_mode(&mut *cursor, compress, validate)?, + inner_sumcheck_proof: read_bounded_ark_vec( + cursor, + 0u64.serialized_size(compress), + compress, + validate, + |cursor, compress, validate| { + deserialize_compressed_uni_poly_from_cursor::(cursor, compress, validate) + }, + )?, + w_opening: deserialize_hyrax_opening_proof_from_cursor(cursor, compress, validate)?, + e_opening: deserialize_hyrax_opening_proof_from_cursor(cursor, compress, validate)?, + folded_eval_outputs: deserialize_field_vec(cursor, compress, validate)?, + folded_eval_blindings: deserialize_field_vec(cursor, compress, validate)?, + }) +} + +fn ensure_sumcheck_mode, FS: Transcript>( + section_name: &str, + proof: &SumcheckInstanceProof, + proof_is_zk: bool, +) -> Result<(), SerializationError> { + if proof.is_zk() != proof_is_zk { + let mode = if proof_is_zk { "ZK" } else { "standard" }; + return Err(invalid_data(format!( + "{section_name} sumcheck proof mode does not match outer proof {mode} flag" + ))); + } + Ok(()) +} + +fn ensure_uniskip_mode, FS: Transcript>( + section_name: &str, + proof: &UniSkipFirstRoundProofVariant, + proof_is_zk: bool, +) -> Result<(), SerializationError> { + let uniskip_is_zk = matches!(proof, UniSkipFirstRoundProofVariant::Zk(_)); + if uniskip_is_zk != proof_is_zk { + let mode = if proof_is_zk { "ZK" } else { "standard" }; + return Err(invalid_data(format!( + "{section_name} uni-skip proof mode does not match outer proof {mode} flag" + ))); + } + Ok(()) +} + +fn ark_group_size(compress: Compress) -> usize { + T::default().serialized_size(compress) +} + +fn check_dory_round_count( + section_bytes: &[u8], + compress: Compress, + validate: Validate, +) -> Result<(), SerializationError> { + let mut cursor = Cursor::new(section_bytes); + ArkGT::deserialize_with_mode(&mut cursor, compress, validate)?; + ArkGT::deserialize_with_mode(&mut cursor, compress, validate)?; + ArkG1::deserialize_with_mode(&mut cursor, compress, validate)?; + + let num_rounds = u32::deserialize_with_mode(&mut cursor, compress, validate)? as usize; + + let first_round_bytes = 4 * ark_group_size::(compress) + + ark_group_size::(compress) + + ark_group_size::(compress); + let second_round_bytes = 2 * ark_group_size::(compress) + + 2 * ark_group_size::(compress) + + 2 * ark_group_size::(compress); + let min_tail_bytes = ark_group_size::(compress) + + ark_group_size::(compress) + + 2 * std::mem::size_of::(); + let remaining_bytes = section_bytes + .len() + .saturating_sub(cursor.position() as usize); + + if remaining_bytes < min_tail_bytes { + return Err(SerializationError::InvalidData); + } + + let per_round_bytes = first_round_bytes + second_round_bytes; + let max_rounds = remaining_bytes.saturating_sub(min_tail_bytes) / per_round_bytes; + if num_rounds > max_rounds { + return Err(invalid_data(format!( + "Dory opening proof declares {num_rounds} rounds but section only has room for {max_rounds}" + ))); + } + + Ok(()) +} + +fn deserialize_joint_opening_proof_section( + reader: &mut impl Read, + compress: Compress, + validate: Validate, +) -> Result { + let section_bytes = transport::read_section_bytes(reader, MAX_SECTION_LEN).map_err(io_err)?; + if TypeId::of::() == TypeId::of::() { + check_dory_round_count(§ion_bytes, compress, validate)?; + } + + let mut cursor = Cursor::new(section_bytes.as_slice()); + let proof = PCS::Proof::deserialize_with_mode(&mut cursor, compress, validate)?; + check_cursor_consumed("joint opening proof", &cursor)?; + Ok(proof) +} + +impl, PCS: CommitmentScheme, FS: Transcript> + JoltProof +{ + fn params_payload_len(&self, compress: Compress) -> u64 { + (transport::varint_u64_len(self.trace_length as u64) + + transport::varint_u64_len(self.ram_K as u64)) as u64 + + self.rw_config.serialized_size(compress) as u64 + + self.one_hot_config.serialized_size(compress) as u64 + + self.dory_layout.serialized_size(compress) as u64 + } + + fn commitments_payload_len(&self, compress: Compress) -> u64 { + let count_len = transport::varint_u64_len(self.commitments.len() as u64) as u64; + let items_len: u64 = self + .commitments + .iter() + .map(|c| c.serialized_size(compress) as u64) + .sum(); + let untrusted_len = self + .untrusted_advice_commitment + .as_ref() + .map(|c| c.serialized_size(compress) as u64) + .unwrap_or(0); + count_len + items_len + 1 + untrusted_len + } +} + +impl, PCS: CommitmentScheme, FS: Transcript> + CanonicalSerialize for JoltProof +{ + fn serialize_with_mode( + &self, + mut writer: W, + compress: Compress, + ) -> Result<(), SerializationError> { + transport::write_magic_version(&mut writer, PROOF_MAGIC, PROOF_VERSION).map_err(io_err)?; + let flags: u8 = if cfg!(feature = "zk") { + PROOF_FLAG_ZK + } else { + 0 + }; + writer.write_all(&[flags]).map_err(io_err)?; + + let params_len = self.params_payload_len(compress); + ensure_section_len("params", params_len, MAX_PARAMS_LEN)?; + transport::write_varint_u64(&mut writer, params_len).map_err(io_err)?; + transport::write_varint_u64(&mut writer, self.trace_length as u64).map_err(io_err)?; + transport::write_varint_u64(&mut writer, self.ram_K as u64).map_err(io_err)?; + self.rw_config.serialize_with_mode(&mut writer, compress)?; + self.one_hot_config + .serialize_with_mode(&mut writer, compress)?; + self.dory_layout + .serialize_with_mode(&mut writer, compress)?; + + let commitments_len = self.commitments_payload_len(compress); + ensure_section_len("commitments", commitments_len, MAX_SECTION_LEN)?; + transport::write_varint_u64(&mut writer, commitments_len).map_err(io_err)?; + transport::write_varint_u64(&mut writer, self.commitments.len() as u64).map_err(io_err)?; + for c in &self.commitments { + c.serialize_with_mode(&mut writer, compress)?; + } + match &self.untrusted_advice_commitment { + None => writer.write_all(&[0]).map_err(io_err)?, + Some(c) => { + writer.write_all(&[1]).map_err(io_err)?; + c.serialize_with_mode(&mut writer, compress)?; + } + } + + #[cfg(not(feature = "zk"))] + write_section!(&mut writer, compress, &self.opening_claims); + + write_section!( + &mut writer, + compress, + &self.stage1_uni_skip_first_round_proof, + &self.stage1_sumcheck_proof + ); + write_section!( + &mut writer, + compress, + &self.stage2_uni_skip_first_round_proof, + &self.stage2_sumcheck_proof + ); + write_section!(&mut writer, compress, &self.stage3_sumcheck_proof); + write_section!(&mut writer, compress, &self.stage4_sumcheck_proof); + write_section!(&mut writer, compress, &self.stage5_sumcheck_proof); + write_section!(&mut writer, compress, &self.stage6_sumcheck_proof); + write_section!(&mut writer, compress, &self.stage7_sumcheck_proof); + write_section!(&mut writer, compress, &self.joint_opening_proof); + + #[cfg(feature = "zk")] + write_section!(&mut writer, compress, &self.blindfold_proof); + + Ok(()) + } + + fn serialized_size(&self, compress: Compress) -> usize { + let mut size = PROOF_MAGIC.len() + 1 + 1; // magic + version + flags + + let params_len = self.params_payload_len(compress); + size += transport::varint_u64_len(params_len) + params_len as usize; + + let commitments_len = self.commitments_payload_len(compress); + size += transport::varint_u64_len(commitments_len) + commitments_len as usize; + + #[cfg(not(feature = "zk"))] + { + size += section_size!(compress, &self.opening_claims); + } + + size += section_size!( + compress, + &self.stage1_uni_skip_first_round_proof, + &self.stage1_sumcheck_proof + ); + size += section_size!( + compress, + &self.stage2_uni_skip_first_round_proof, + &self.stage2_sumcheck_proof + ); + size += section_size!(compress, &self.stage3_sumcheck_proof); + size += section_size!(compress, &self.stage4_sumcheck_proof); + size += section_size!(compress, &self.stage5_sumcheck_proof); + size += section_size!(compress, &self.stage6_sumcheck_proof); + size += section_size!(compress, &self.stage7_sumcheck_proof); + size += section_size!(compress, &self.joint_opening_proof); + + #[cfg(feature = "zk")] + { + size += section_size!(compress, &self.blindfold_proof); + } + + size + } +} + +impl, PCS: CommitmentScheme, FS: Transcript> Valid + for JoltProof +{ + fn check(&self) -> Result<(), SerializationError> { + Ok(()) + } +} + +impl, PCS: CommitmentScheme, FS: Transcript> + CanonicalDeserialize for JoltProof +{ + fn deserialize_with_mode( + mut reader: R, + compress: Compress, + validate: Validate, + ) -> Result { + let version = transport::read_magic_version(&mut reader, PROOF_MAGIC).map_err(io_err)?; + if version != PROOF_VERSION { + return Err(io_err(std::io::Error::new( + std::io::ErrorKind::InvalidData, + format!("unsupported proof version {version}, expected {PROOF_VERSION}"), + ))); + } + + let mut flags_buf = [0u8; 1]; + reader.read_exact(&mut flags_buf).map_err(io_err)?; + let flags = flags_buf[0]; + if flags & PROOF_FLAGS_RESERVED_MASK != 0 { + return Err(io_err(std::io::Error::new( + std::io::ErrorKind::InvalidData, + format!("unknown proof flags bits set: {flags:#04x}"), + ))); + } + let proof_is_zk = flags & PROOF_FLAG_ZK != 0; + let compiled_for_zk = cfg!(feature = "zk"); + if proof_is_zk != compiled_for_zk { + let mode = if proof_is_zk { "ZK" } else { "standard" }; + let expected = if compiled_for_zk { "ZK" } else { "standard" }; + return Err(io_err(std::io::Error::new( + std::io::ErrorKind::InvalidData, + format!("proof was serialized in {mode} mode but deserializer expects {expected}"), + ))); + } + + // Params + let mut limited = transport::read_section(&mut reader, MAX_PARAMS_LEN).map_err(io_err)?; + let t = transport::read_varint_u64(&mut limited).map_err(io_err)?; + let r = transport::read_varint_u64(&mut limited).map_err(io_err)?; + let trace_length = usize::try_from(t).map_err(|_| SerializationError::InvalidData)?; + if trace_length == 0 { + return Err(io_err(std::io::Error::new( + std::io::ErrorKind::InvalidData, + "trace_length must be nonzero", + ))); + } + let ram_K = usize::try_from(r).map_err(|_| SerializationError::InvalidData)?; + let rw_config = ReadWriteConfig::deserialize_with_mode(&mut limited, compress, validate)?; + let one_hot_config = OneHotConfig::deserialize_with_mode(&mut limited, compress, validate)?; + let dory_layout = DoryLayout::deserialize_with_mode(&mut limited, compress, validate)?; + check_trailing_bytes(&limited)?; + + // Commitments + let commitments_bytes = + transport::read_section_bytes(&mut reader, MAX_SECTION_LEN).map_err(io_err)?; + let mut limited = Cursor::new(commitments_bytes.as_slice()); + let n = transport::read_varint_u64(&mut limited).map_err(io_err)?; + let n_usize = usize::try_from(n).map_err(|_| SerializationError::InvalidData)?; + let remaining_bytes = commitments_bytes + .len() + .saturating_sub(limited.position() as usize); + let min_commitment_bytes = PCS::Commitment::default().serialized_size(compress).max(1); + let max_commitments = remaining_bytes.saturating_sub(1) / min_commitment_bytes; + if n_usize > max_commitments { + return Err(invalid_data(format!( + "commitments section count {n_usize} exceeds byte cap {max_commitments}" + ))); + } + let mut commitments = Vec::with_capacity(n_usize); + for _ in 0..n_usize { + commitments.push(PCS::Commitment::deserialize_with_mode( + &mut limited, + compress, + validate, + )?); + } + let presence = u8::deserialize_with_mode(&mut limited, compress, validate)?; + let untrusted_advice_commitment = match presence { + 0 => None, + 1 => Some(PCS::Commitment::deserialize_with_mode( + &mut limited, + compress, + validate, + )?), + _ => return Err(SerializationError::InvalidData), + }; + check_cursor_consumed("commitments", &limited)?; + + #[cfg(not(feature = "zk"))] + let opening_claims: Claims = { + let mut limited = + transport::read_section(&mut reader, MAX_SECTION_LEN).map_err(io_err)?; + let claims = + CanonicalDeserialize::deserialize_with_mode(&mut limited, compress, validate)?; + check_trailing_bytes(&limited)?; + claims + }; + + // Stage 1 + let stage1_bytes = + transport::read_section_bytes(&mut reader, MAX_SECTION_LEN).map_err(io_err)?; + let mut stage1_cursor = Cursor::new(stage1_bytes.as_slice()); + let stage1_uni_skip_first_round_proof = + deserialize_uniskip_first_round_proof_variant_from_cursor( + &mut stage1_cursor, + compress, + validate, + )?; + let stage1_sumcheck_proof = deserialize_sumcheck_instance_proof_from_cursor( + &mut stage1_cursor, + compress, + validate, + )?; + check_cursor_consumed("stage1", &stage1_cursor)?; + ensure_uniskip_mode("stage1", &stage1_uni_skip_first_round_proof, proof_is_zk)?; + ensure_sumcheck_mode("stage1", &stage1_sumcheck_proof, proof_is_zk)?; + + // Stage 2 + let stage2_bytes = + transport::read_section_bytes(&mut reader, MAX_SECTION_LEN).map_err(io_err)?; + let mut stage2_cursor = Cursor::new(stage2_bytes.as_slice()); + let stage2_uni_skip_first_round_proof = + deserialize_uniskip_first_round_proof_variant_from_cursor( + &mut stage2_cursor, + compress, + validate, + )?; + let stage2_sumcheck_proof = deserialize_sumcheck_instance_proof_from_cursor( + &mut stage2_cursor, + compress, + validate, + )?; + check_cursor_consumed("stage2", &stage2_cursor)?; + ensure_uniskip_mode("stage2", &stage2_uni_skip_first_round_proof, proof_is_zk)?; + ensure_sumcheck_mode("stage2", &stage2_sumcheck_proof, proof_is_zk)?; + + macro_rules! read_sumcheck_section { + ($reader:expr, $section_name:literal) => {{ + let section_bytes = + transport::read_section_bytes($reader, MAX_SECTION_LEN).map_err(io_err)?; + let mut cursor = Cursor::new(section_bytes.as_slice()); + let proof = deserialize_sumcheck_instance_proof_from_cursor( + &mut cursor, + compress, + validate, + )?; + check_cursor_consumed($section_name, &cursor)?; + ensure_sumcheck_mode($section_name, &proof, proof_is_zk)?; + proof + }}; + } + + let stage3_sumcheck_proof = read_sumcheck_section!(&mut reader, "stage3"); + let stage4_sumcheck_proof = read_sumcheck_section!(&mut reader, "stage4"); + let stage5_sumcheck_proof = read_sumcheck_section!(&mut reader, "stage5"); + let stage6_sumcheck_proof = read_sumcheck_section!(&mut reader, "stage6"); + let stage7_sumcheck_proof = read_sumcheck_section!(&mut reader, "stage7"); + let joint_opening_proof = + deserialize_joint_opening_proof_section::(&mut reader, compress, validate)?; + + #[cfg(feature = "zk")] + let blindfold_proof = { + let section_bytes = + transport::read_section_bytes(&mut reader, MAX_SECTION_LEN).map_err(io_err)?; + let mut cursor = Cursor::new(section_bytes.as_slice()); + let proof = deserialize_blindfold_proof_from_cursor(&mut cursor, compress, validate)?; + check_cursor_consumed("blindfold", &cursor)?; + proof + }; + + let mut eof_check = [0u8; 1]; + match reader.read(&mut eof_check) { + Ok(0) => {} + Ok(_) => { + return Err(io_err(std::io::Error::new( + std::io::ErrorKind::InvalidData, + "unexpected trailing bytes after proof", + ))); + } + Err(e) => return Err(io_err(e)), + } + + Ok(Self { + commitments, + stage1_uni_skip_first_round_proof, + stage1_sumcheck_proof, + stage2_uni_skip_first_round_proof, + stage2_sumcheck_proof, + stage3_sumcheck_proof, + stage4_sumcheck_proof, + stage5_sumcheck_proof, + stage6_sumcheck_proof, + stage7_sumcheck_proof, + #[cfg(feature = "zk")] + blindfold_proof, + joint_opening_proof, + untrusted_advice_commitment, + #[cfg(not(feature = "zk"))] + opening_claims, + trace_length, + ram_K, + rw_config, + one_hot_config, + dory_layout, + }) + } +} + #[cfg(not(feature = "zk"))] pub struct Claims(pub Openings); +#[cfg(not(feature = "zk"))] +const MAX_CLAIMS_COUNT: u64 = MAX_SECTION_LEN / MIN_OPENING_CLAIM_BYTES; + #[cfg(not(feature = "zk"))] impl CanonicalSerialize for Claims { fn serialize_with_mode( @@ -72,7 +879,7 @@ impl CanonicalSerialize for Claims { mut writer: W, compress: Compress, ) -> Result<(), SerializationError> { - self.0.len().serialize_with_mode(&mut writer, compress)?; + transport::write_varint_u64(&mut writer, self.0.len() as u64).map_err(io_err)?; for (key, (_opening_point, claim)) in self.0.iter() { key.serialize_with_mode(&mut writer, compress)?; claim.serialize_with_mode(&mut writer, compress)?; @@ -81,7 +888,7 @@ impl CanonicalSerialize for Claims { } fn serialized_size(&self, compress: Compress) -> usize { - let mut size = self.0.len().serialized_size(compress); + let mut size = transport::varint_u64_len(self.0.len() as u64); for (key, (_opening_point, claim)) in self.0.iter() { size += key.serialized_size(compress); size += claim.serialized_size(compress); @@ -104,12 +911,21 @@ impl CanonicalDeserialize for Claims { compress: Compress, validate: Validate, ) -> Result { - let size = usize::deserialize_with_mode(&mut reader, compress, validate)?; + let n = transport::read_varint_u64(&mut reader).map_err(io_err)?; + let n_usize = usize::try_from(n).map_err(|_| SerializationError::InvalidData)?; + if n > MAX_CLAIMS_COUNT { + return Err(SerializationError::InvalidData); + } let mut claims = BTreeMap::new(); - for _ in 0..size { + for _ in 0..n_usize { let key = OpeningId::deserialize_with_mode(&mut reader, compress, validate)?; let claim = F::deserialize_with_mode(&mut reader, compress, validate)?; - claims.insert(key, (OpeningPoint::default(), claim)); + if claims + .insert(key, (OpeningPoint::default(), claim)) + .is_some() + { + return Err(SerializationError::InvalidData); + } } Ok(Claims(claims)) } @@ -149,17 +965,22 @@ impl CanonicalDeserialize for DoryLayout { } } -// Compact encoding for OpeningId: -// Each variant uses a fused byte = BASE + sumcheck_id (1 byte total for advice, 2 bytes for committed/virtual) -// - [0, NUM_SUMCHECKS) = UntrustedAdvice(sumcheck_id) -// - [NUM_SUMCHECKS, 2*NUM_SUMCHECKS) = TrustedAdvice(sumcheck_id) -// - [2*NUM_SUMCHECKS, 3*NUM_SUMCHECKS) + poly_index = Committed(poly, sumcheck_id) -// - [3*NUM_SUMCHECKS, 4*NUM_SUMCHECKS) + poly_index = Virtual(poly, sumcheck_id) -const OPENING_ID_UNTRUSTED_ADVICE_BASE: u8 = 0; -const OPENING_ID_TRUSTED_ADVICE_BASE: u8 = - OPENING_ID_UNTRUSTED_ADVICE_BASE + SumcheckId::COUNT as u8; -const OPENING_ID_COMMITTED_BASE: u8 = OPENING_ID_TRUSTED_ADVICE_BASE + SumcheckId::COUNT as u8; -const OPENING_ID_VIRTUAL_BASE: u8 = OPENING_ID_COMMITTED_BASE + SumcheckId::COUNT as u8; +// OpeningId wire encoding (packed, self-describing): +// +// Header byte: +// bits[7..6] = kind (0..=3) +// bits[5..0] = sumcheck_id if < 63, otherwise 63 as an escape +// +// If bits[5..0] == 63: +// next bytes = varint u64 sumcheck_id (must fit in u8 for current SumcheckId) +// +// If kind indicates a polynomial: +// append polynomial id bytes (CommittedPolynomial / VirtualPolynomial) +const OPENING_ID_KIND_UNTRUSTED_ADVICE: u8 = 0; +const OPENING_ID_KIND_TRUSTED_ADVICE: u8 = 1; +const OPENING_ID_KIND_COMMITTED: u8 = 2; +const OPENING_ID_KIND_VIRTUAL: u8 = 3; +const OPENING_ID_SUMCHECK_ESCAPE: u8 = 63; impl CanonicalSerialize for OpeningId { fn serialize_with_mode( @@ -167,38 +988,61 @@ impl CanonicalSerialize for OpeningId { mut writer: W, compress: Compress, ) -> Result<(), SerializationError> { - match self { + let (kind, sumcheck_u64, poly) = match *self { OpeningId::UntrustedAdvice(sumcheck_id) => { - let fused = OPENING_ID_UNTRUSTED_ADVICE_BASE + (*sumcheck_id as u8); - fused.serialize_with_mode(&mut writer, compress) + (OPENING_ID_KIND_UNTRUSTED_ADVICE, sumcheck_id as u64, None) } OpeningId::TrustedAdvice(sumcheck_id) => { - let fused = OPENING_ID_TRUSTED_ADVICE_BASE + (*sumcheck_id as u8); - fused.serialize_with_mode(&mut writer, compress) + (OPENING_ID_KIND_TRUSTED_ADVICE, sumcheck_id as u64, None) } - OpeningId::Polynomial(PolynomialId::Committed(committed_polynomial), sumcheck_id) => { - let fused = OPENING_ID_COMMITTED_BASE + (*sumcheck_id as u8); - fused.serialize_with_mode(&mut writer, compress)?; - committed_polynomial.serialize_with_mode(&mut writer, compress) - } - OpeningId::Polynomial(PolynomialId::Virtual(virtual_polynomial), sumcheck_id) => { - let fused = OPENING_ID_VIRTUAL_BASE + (*sumcheck_id as u8); - fused.serialize_with_mode(&mut writer, compress)?; - virtual_polynomial.serialize_with_mode(&mut writer, compress) + OpeningId::Polynomial(PolynomialId::Committed(committed_polynomial), sumcheck_id) => ( + OPENING_ID_KIND_COMMITTED, + sumcheck_id as u64, + Some(PolynomialId::Committed(committed_polynomial)), + ), + OpeningId::Polynomial(PolynomialId::Virtual(virtual_polynomial), sumcheck_id) => ( + OPENING_ID_KIND_VIRTUAL, + sumcheck_id as u64, + Some(PolynomialId::Virtual(virtual_polynomial)), + ), + }; + + let header = if sumcheck_u64 < OPENING_ID_SUMCHECK_ESCAPE as u64 { + (kind << 6) | (sumcheck_u64 as u8) + } else { + (kind << 6) | OPENING_ID_SUMCHECK_ESCAPE + }; + header.serialize_with_mode(&mut writer, compress)?; + if (header & 0x3F) == OPENING_ID_SUMCHECK_ESCAPE { + transport::write_varint_u64(&mut writer, sumcheck_u64).map_err(io_err)?; + } + if let Some(poly) = poly { + match poly { + PolynomialId::Committed(p) => p.serialize_with_mode(&mut writer, compress)?, + PolynomialId::Virtual(p) => p.serialize_with_mode(&mut writer, compress)?, } } + Ok(()) } fn serialized_size(&self, compress: Compress) -> usize { - match self { - OpeningId::UntrustedAdvice(_) | OpeningId::TrustedAdvice(_) => 1, - OpeningId::Polynomial(PolynomialId::Committed(committed_polynomial), _) => { - 1 + committed_polynomial.serialized_size(compress) - } - OpeningId::Polynomial(PolynomialId::Virtual(virtual_polynomial), _) => { - 1 + virtual_polynomial.serialized_size(compress) - } - } + let sumcheck_u64 = match self { + OpeningId::UntrustedAdvice(sumcheck_id) + | OpeningId::TrustedAdvice(sumcheck_id) + | OpeningId::Polynomial(_, sumcheck_id) => *sumcheck_id as u64, + }; + let header_len = 1usize; + let sumcheck_ext_len = if sumcheck_u64 < OPENING_ID_SUMCHECK_ESCAPE as u64 { + 0usize + } else { + transport::varint_u64_len(sumcheck_u64) + }; + let poly_len = match self { + OpeningId::UntrustedAdvice(_) | OpeningId::TrustedAdvice(_) => 0, + OpeningId::Polynomial(PolynomialId::Committed(p), _) => p.serialized_size(compress), + OpeningId::Polynomial(PolynomialId::Virtual(p), _) => p.serialized_size(compress), + }; + header_len + sumcheck_ext_len + poly_len } } @@ -214,38 +1058,39 @@ impl CanonicalDeserialize for OpeningId { compress: Compress, validate: Validate, ) -> Result { - let fused = u8::deserialize_with_mode(&mut reader, compress, validate)?; - match fused { - _ if fused < OPENING_ID_TRUSTED_ADVICE_BASE => { - let sumcheck_id = fused - OPENING_ID_UNTRUSTED_ADVICE_BASE; - Ok(OpeningId::UntrustedAdvice( - SumcheckId::from_u8(sumcheck_id).ok_or(SerializationError::InvalidData)?, - )) - } - _ if fused < OPENING_ID_COMMITTED_BASE => { - let sumcheck_id = fused - OPENING_ID_TRUSTED_ADVICE_BASE; - Ok(OpeningId::TrustedAdvice( - SumcheckId::from_u8(sumcheck_id).ok_or(SerializationError::InvalidData)?, - )) + let header = u8::deserialize_with_mode(&mut reader, compress, validate)?; + let kind = header >> 6; + let small = header & 0x3F; + + let sumcheck_u64 = if small == OPENING_ID_SUMCHECK_ESCAPE { + let sumcheck_u64 = transport::read_varint_u64(&mut reader).map_err(io_err)?; + if sumcheck_u64 < OPENING_ID_SUMCHECK_ESCAPE as u64 { + return Err(SerializationError::InvalidData); } - _ if fused < OPENING_ID_VIRTUAL_BASE => { - let sumcheck_id = fused - OPENING_ID_COMMITTED_BASE; + sumcheck_u64 + } else { + small as u64 + }; + + let sumcheck_u8 = + u8::try_from(sumcheck_u64).map_err(|_| SerializationError::InvalidData)?; + let sumcheck_id = + SumcheckId::from_u8(sumcheck_u8).ok_or(SerializationError::InvalidData)?; + + match kind { + OPENING_ID_KIND_UNTRUSTED_ADVICE => Ok(OpeningId::UntrustedAdvice(sumcheck_id)), + OPENING_ID_KIND_TRUSTED_ADVICE => Ok(OpeningId::TrustedAdvice(sumcheck_id)), + OPENING_ID_KIND_COMMITTED => { let polynomial = CommittedPolynomial::deserialize_with_mode(&mut reader, compress, validate)?; - Ok(OpeningId::committed( - polynomial, - SumcheckId::from_u8(sumcheck_id).ok_or(SerializationError::InvalidData)?, - )) + Ok(OpeningId::committed(polynomial, sumcheck_id)) } - _ => { - let sumcheck_id = fused - OPENING_ID_VIRTUAL_BASE; + OPENING_ID_KIND_VIRTUAL => { let polynomial = VirtualPolynomial::deserialize_with_mode(&mut reader, compress, validate)?; - Ok(OpeningId::virt( - polynomial, - SumcheckId::from_u8(sumcheck_id).ok_or(SerializationError::InvalidData)?, - )) + Ok(OpeningId::virt(polynomial, sumcheck_id)) } + _ => Err(SerializationError::InvalidData), } } } @@ -515,3 +1360,231 @@ pub fn serialize_and_print_size( tracing::info!("{item_name} size: {file_size_kb:.1} kB"); Ok(()) } + +#[cfg(test)] +mod tests { + use super::*; + use crate::poly::opening_proof::{OpeningId, SumcheckId}; + use crate::zkvm::witness::{CommittedPolynomial, VirtualPolynomial}; + use crate::zkvm::RV64IMACProof; + use crate::{curve::Bn254Curve, transcripts::Blake2bTranscript}; + use ark_bn254::Fr; + + #[test] + fn opening_id_header_is_packed_common_case() { + let id = OpeningId::UntrustedAdvice(SumcheckId::SpartanOuter); + let mut bytes = Vec::new(); + id.serialize_compressed(&mut bytes).unwrap(); + assert_eq!(bytes.len(), 1); + let expected = (OPENING_ID_KIND_UNTRUSTED_ADVICE << 6) | (SumcheckId::SpartanOuter as u8); + assert_eq!(bytes[0], expected); + + let id = OpeningId::committed(CommittedPolynomial::RdInc, SumcheckId::SpartanOuter); + let mut bytes = Vec::new(); + id.serialize_compressed(&mut bytes).unwrap(); + assert!(bytes.len() >= 2); + assert_eq!(bytes[0] >> 6, OPENING_ID_KIND_COMMITTED); + assert_eq!(bytes[0] & 0x3F, SumcheckId::SpartanOuter as u8); + + let id = OpeningId::virt(VirtualPolynomial::PC, SumcheckId::SpartanOuter); + let mut bytes = Vec::new(); + id.serialize_compressed(&mut bytes).unwrap(); + assert!(bytes.len() >= 2); + assert_eq!(bytes[0] >> 6, OPENING_ID_KIND_VIRTUAL); + assert_eq!(bytes[0] & 0x3F, SumcheckId::SpartanOuter as u8); + } + + #[test] + fn opening_id_roundtrips() { + use crate::zkvm::instruction::{CircuitFlags, InstructionFlags}; + + let sumcheck_ids = [ + SumcheckId::SpartanOuter, + SumcheckId::RamReadWriteChecking, + SumcheckId::HammingWeightClaimReduction, + ]; + + let committed_polys = [ + CommittedPolynomial::RdInc, + CommittedPolynomial::RamInc, + CommittedPolynomial::InstructionRa(0), + CommittedPolynomial::InstructionRa(7), + CommittedPolynomial::BytecodeRa(0), + CommittedPolynomial::RamRa(0), + CommittedPolynomial::TrustedAdvice, + CommittedPolynomial::UntrustedAdvice, + ]; + + let virtual_polys = [ + VirtualPolynomial::PC, + VirtualPolynomial::NextPC, + VirtualPolynomial::UnivariateSkip, + VirtualPolynomial::InstructionRa(0), + VirtualPolynomial::InstructionRa(5), + VirtualPolynomial::OpFlags(CircuitFlags::AddOperands), + VirtualPolynomial::OpFlags(CircuitFlags::IsLastInSequence), + VirtualPolynomial::InstructionFlags(InstructionFlags::LeftOperandIsPC), + VirtualPolynomial::InstructionFlags(InstructionFlags::IsNoop), + VirtualPolynomial::LookupTableFlag(0), + VirtualPolynomial::LookupTableFlag(3), + ]; + + for &sc in &sumcheck_ids { + let id = OpeningId::UntrustedAdvice(sc); + let mut bytes = Vec::new(); + id.serialize_compressed(&mut bytes).unwrap(); + assert_eq!( + OpeningId::deserialize_compressed(bytes.as_slice()).unwrap(), + id + ); + + let id = OpeningId::TrustedAdvice(sc); + let mut bytes = Vec::new(); + id.serialize_compressed(&mut bytes).unwrap(); + assert_eq!( + OpeningId::deserialize_compressed(bytes.as_slice()).unwrap(), + id + ); + + for &cp in &committed_polys { + let id = OpeningId::committed(cp, sc); + let mut bytes = Vec::new(); + id.serialize_compressed(&mut bytes).unwrap(); + assert_eq!( + OpeningId::deserialize_compressed(bytes.as_slice()).unwrap(), + id + ); + } + + for &vp in &virtual_polys { + let id = OpeningId::virt(vp, sc); + let mut bytes = Vec::new(); + id.serialize_compressed(&mut bytes).unwrap(); + assert_eq!( + OpeningId::deserialize_compressed(bytes.as_slice()).unwrap(), + id + ); + } + } + } + + #[test] + fn opening_id_rejects_noncanonical_escape_encoding() { + let bytes = [ + (OPENING_ID_KIND_UNTRUSTED_ADVICE << 6) | OPENING_ID_SUMCHECK_ESCAPE, + 0, + ]; + let res = OpeningId::deserialize_compressed(bytes.as_slice()); + assert!(matches!(res, Err(SerializationError::InvalidData))); + } + + #[test] + fn sumcheck_section_rejects_oversized_nested_vector() { + let mut bytes = Vec::new(); + 0u8.serialize_compressed(&mut bytes).unwrap(); + 1u64.serialize_compressed(&mut bytes).unwrap(); + + let mut cursor = Cursor::new(bytes.as_slice()); + let res = deserialize_sumcheck_instance_proof_from_cursor::< + Fr, + Bn254Curve, + Blake2bTranscript, + >(&mut cursor, Compress::Yes, Validate::Yes); + assert!(res.is_err()); + } + + #[cfg(feature = "zk")] + #[test] + fn blindfold_section_rejects_oversized_nested_vector() { + let mut bytes = Vec::new(); + Fr::from(0u64).serialize_compressed(&mut bytes).unwrap(); + 1u64.serialize_compressed(&mut bytes).unwrap(); + + let mut cursor = Cursor::new(bytes.as_slice()); + let res = deserialize_blindfold_proof_from_cursor::( + &mut cursor, + Compress::Yes, + Validate::Yes, + ); + assert!(res.is_err()); + } + + #[test] + fn proof_version_byte() { + assert_eq!(PROOF_VERSION, 1); + } + + #[test] + fn proof_magic_required() { + let mut just_magic = Vec::new(); + transport::write_magic_version(&mut just_magic, PROOF_MAGIC, PROOF_VERSION).unwrap(); + let res = RV64IMACProof::deserialize_with_mode( + std::io::Cursor::new(&just_magic), + Compress::Yes, + Validate::Yes, + ) + .map(|_| ()); + assert!(res.is_err()); + } + + #[test] + fn wrong_version_rejected() { + let mut buf = Vec::new(); + transport::write_magic_version(&mut buf, PROOF_MAGIC, 99).unwrap(); + // Append enough zeros to avoid EOF on the version read + buf.extend_from_slice(&[0u8; 100]); + let res = RV64IMACProof::deserialize_with_mode( + std::io::Cursor::new(&buf), + Compress::Yes, + Validate::Yes, + ) + .map(|_| ()); + match res { + Err(SerializationError::IoError(e)) => { + assert!( + e.to_string().contains("unsupported proof version"), + "unexpected error: {e}" + ); + } + other => panic!("expected IoError with version message, got {other:?}"), + } + } + + #[test] + fn wrong_magic_rejected() { + let mut buf = Vec::new(); + transport::write_magic_version(&mut buf, b"BAAD", 1).unwrap(); + buf.extend_from_slice(&[0u8; 100]); + let res = RV64IMACProof::deserialize_with_mode( + std::io::Cursor::new(&buf), + Compress::Yes, + Validate::Yes, + ) + .map(|_| ()); + match res { + Err(SerializationError::IoError(e)) => { + assert!( + e.to_string().contains("invalid proof magic"), + "unexpected error: {e}" + ); + } + other => panic!("expected IoError with magic message, got {other:?}"), + } + } + + #[test] + fn dory_round_count_sanity_check_rejects_oversized_count() { + let mut bytes = Vec::new(); + ArkGT::default().serialize_compressed(&mut bytes).unwrap(); + ArkGT::default().serialize_compressed(&mut bytes).unwrap(); + ArkG1::default().serialize_compressed(&mut bytes).unwrap(); + 1_000u32.serialize_compressed(&mut bytes).unwrap(); + ArkG1::default().serialize_compressed(&mut bytes).unwrap(); + ArkG2::default().serialize_compressed(&mut bytes).unwrap(); + 0u32.serialize_compressed(&mut bytes).unwrap(); + 0u32.serialize_compressed(&mut bytes).unwrap(); + + let res = check_dory_round_count(&bytes, Compress::Yes, Validate::Yes); + assert!(res.is_err()); + } +} diff --git a/jolt-core/src/zkvm/prover.rs b/jolt-core/src/zkvm/prover.rs index 8d58def45..a4d9ad9c7 100644 --- a/jolt-core/src/zkvm/prover.rs +++ b/jolt-core/src/zkvm/prover.rs @@ -2225,6 +2225,8 @@ mod tests { extern crate jolt_inlines_sha2; use std::sync::Arc; + #[cfg(feature = "zk")] + use std::thread; use ark_bn254::Fr; use serial_test::serial; @@ -2249,7 +2251,7 @@ mod tests { prover::JoltProverPreprocessing, ram::populate_memory_states, verifier::{JoltVerifier, JoltVerifierPreprocessing}, - RV64IMACProver, RV64IMACVerifier, + RV64IMACProof, RV64IMACProver, RV64IMACVerifier, Serializable, }; #[cfg(feature = "zk")] use crate::{curve::JoltCurve, field::JoltField}; @@ -2303,6 +2305,26 @@ mod tests { (commitment, hint) } + fn with_roundtrip_stack( + f: impl FnOnce() -> T + Send + 'static, + ) -> T { + #[cfg(feature = "zk")] + { + return thread::Builder::new() + .name("proof-roundtrip".to_string()) + .stack_size(32 * 1024 * 1024) + .spawn(f) + .unwrap() + .join() + .unwrap(); + } + + #[cfg(not(feature = "zk"))] + { + f() + } + } + #[test] #[serial] fn fib_e2e_dory() { @@ -2335,17 +2357,20 @@ mod tests { ); let io_device = prover.program_io.clone(); let (jolt_proof, debug_info) = prover.prove(); - + let proof_bytes = jolt_proof.serialize_to_bytes().unwrap(); let verifier_preprocessing = JoltVerifierPreprocessing::from(&prover_preprocessing); - let verifier = RV64IMACVerifier::new( - &verifier_preprocessing, - jolt_proof, - io_device, - None, - debug_info, - ) - .expect("Failed to create verifier"); - verifier.verify().expect("Failed to verify proof"); + with_roundtrip_stack(move || { + let jolt_proof = RV64IMACProof::deserialize_from_bytes(&proof_bytes).unwrap(); + let verifier = RV64IMACVerifier::new( + &verifier_preprocessing, + jolt_proof, + io_device, + None, + debug_info, + ) + .expect("Failed to create verifier"); + verifier.verify().expect("Failed to verify proof"); + }); } #[test] @@ -2437,6 +2462,8 @@ mod tests { ); let io_device = prover.program_io.clone(); let (jolt_proof, debug_info) = prover.prove(); + let proof_bytes = jolt_proof.serialize_to_bytes().unwrap(); + let jolt_proof = RV64IMACProof::deserialize_from_bytes(&proof_bytes).unwrap(); let verifier_preprocessing = JoltVerifierPreprocessing::from(&prover_preprocessing); let verifier = RV64IMACVerifier::new( @@ -2561,6 +2588,8 @@ mod tests { ); let io_device = prover.program_io.clone(); let (jolt_proof, debug_info) = prover.prove(); + let proof_bytes = jolt_proof.serialize_to_bytes().unwrap(); + let jolt_proof = RV64IMACProof::deserialize_from_bytes(&proof_bytes).unwrap(); let verifier_preprocessing = JoltVerifierPreprocessing::from(&prover_preprocessing); RV64IMACVerifier::new( @@ -2688,6 +2717,8 @@ mod tests { ); let io_device = prover.program_io.clone(); let (jolt_proof, debug_info) = prover.prove(); + let proof_bytes = jolt_proof.serialize_to_bytes().unwrap(); + let jolt_proof = RV64IMACProof::deserialize_from_bytes(&proof_bytes).unwrap(); let verifier_preprocessing = JoltVerifierPreprocessing::from(&prover_preprocessing); RV64IMACVerifier::new( @@ -2935,6 +2966,8 @@ mod tests { ); let io_device = prover.program_io.clone(); let (jolt_proof, debug_info) = prover.prove(); + let proof_bytes = jolt_proof.serialize_to_bytes().unwrap(); + let jolt_proof = RV64IMACProof::deserialize_from_bytes(&proof_bytes).unwrap(); let verifier_preprocessing = JoltVerifierPreprocessing::from(&prover_preprocessing); let verifier = RV64IMACVerifier::new( @@ -3524,14 +3557,16 @@ mod tests { ); let io_device = prover.program_io.clone(); let (proof, debug_info) = prover.prove(); - + let proof_bytes = proof.serialize_to_bytes().unwrap(); let verifier_preprocessing = JoltVerifierPreprocessing::from(&prover_preprocessing); - - // DoryGlobals is now initialized inside the verifier's verify_stage8 - RV64IMACVerifier::new(&verifier_preprocessing, proof, io_device, None, debug_info) - .expect("verifier creation failed") - .verify() - .expect("verification failed"); + with_roundtrip_stack(move || { + let proof = RV64IMACProof::deserialize_from_bytes(&proof_bytes).unwrap(); + // DoryGlobals is now initialized inside the verifier's verify_stage8 + RV64IMACVerifier::new(&verifier_preprocessing, proof, io_device, None, debug_info) + .expect("verifier creation failed") + .verify() + .expect("verification failed"); + }); } #[test] @@ -3577,18 +3612,22 @@ mod tests { ); let io_device = prover.program_io.clone(); let (jolt_proof, debug_info) = prover.prove(); - + let proof_bytes = jolt_proof.serialize_to_bytes().unwrap(); let verifier_preprocessing = JoltVerifierPreprocessing::from(&prover_preprocessing); - RV64IMACVerifier::new( - &verifier_preprocessing, - jolt_proof, - io_device.clone(), - Some(trusted_commitment), - debug_info, - ) - .expect("Failed to create verifier") - .verify() - .expect("Verification failed"); + let verifier_io_device = io_device.clone(); + with_roundtrip_stack(move || { + let jolt_proof = RV64IMACProof::deserialize_from_bytes(&proof_bytes).unwrap(); + RV64IMACVerifier::new( + &verifier_preprocessing, + jolt_proof, + verifier_io_device, + Some(trusted_commitment), + debug_info, + ) + .expect("Failed to create verifier") + .verify() + .expect("Verification failed"); + }); // Expected merkle root for leaves [5;32], [6;32], [7;32], [8;32] let expected_output = &[ diff --git a/jolt-core/src/zkvm/transport.rs b/jolt-core/src/zkvm/transport.rs new file mode 100644 index 000000000..8192527b8 --- /dev/null +++ b/jolt-core/src/zkvm/transport.rs @@ -0,0 +1,210 @@ +//! Lightweight length-prefixed transport encoding helpers. +//! +//! Wire format: `[magic: 4B][version: 1B][flags: 1B][section₀][section₁]…` +//! where each section is `[varint payload_len][payload bytes]`. +//! Sections are sequential and untagged; the deserializer reads them in a +//! fixed order defined by the proof schema. + +use std::io::{self, Read, Write}; + +const VARINT_U64_MAX_BYTES: usize = 10; + +#[inline] +pub fn write_magic_version(w: &mut W, magic: &[u8], version: u8) -> io::Result<()> { + w.write_all(magic)?; + w.write_all(&[version]) +} + +#[inline] +pub fn read_magic_version(r: &mut R, magic: &[u8; 4]) -> io::Result { + let mut buf = [0u8; 4]; + r.read_exact(&mut buf)?; + if buf != *magic { + return Err(io::Error::new( + io::ErrorKind::InvalidData, + "invalid proof magic", + )); + } + let mut v = [0u8; 1]; + r.read_exact(&mut v)?; + Ok(v[0]) +} + +#[inline] +pub fn write_varint_u64(w: &mut W, mut x: u64) -> io::Result<()> { + while x >= 0x80 { + w.write_all(&[((x as u8) & 0x7F) | 0x80])?; + x >>= 7; + } + w.write_all(&[x as u8]) +} + +#[inline] +pub fn varint_u64_len(mut x: u64) -> usize { + let mut n = 1usize; + while x >= 0x80 { + n += 1; + x >>= 7; + } + n +} + +#[inline] +pub fn read_varint_u64(r: &mut R) -> io::Result { + let mut x = 0u64; + let mut shift = 0u32; + for i in 0..VARINT_U64_MAX_BYTES { + let mut b = [0u8; 1]; + r.read_exact(&mut b)?; + let byte = b[0]; + let payload = (byte & 0x7F) as u64; + if shift == 63 && payload > 1 { + return Err(io::Error::new( + io::ErrorKind::InvalidData, + "varint overflow", + )); + } + x |= payload << shift; + if (byte & 0x80) == 0 { + let bytes_read = i + 1; + if bytes_read != varint_u64_len(x) { + return Err(io::Error::new( + io::ErrorKind::InvalidData, + "non-canonical varint", + )); + } + return Ok(x); + } + shift += 7; + } + Err(io::Error::new( + io::ErrorKind::InvalidData, + "varint overflow", + )) +} + +/// Reads a varint-prefixed section, enforcing a maximum payload length. +/// Returns a `Take` reader limited to exactly the declared payload length. +/// Callers should check `limited.limit() == 0` after reading to detect trailing bytes. +#[inline] +pub fn read_section(r: &mut R, max_len: u64) -> io::Result> { + let len = read_varint_u64(r)?; + if len > max_len { + return Err(io::Error::new( + io::ErrorKind::InvalidData, + "section too large", + )); + } + Ok(r.take(len)) +} + +#[inline] +pub fn read_section_bytes(r: &mut R, max_len: u64) -> io::Result> { + let mut limited = read_section(r, max_len)?; + let mut bytes = vec![0u8; limited.limit() as usize]; + limited.read_exact(&mut bytes)?; + Ok(bytes) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn varint_roundtrip_edge_cases() { + let cases: &[u64] = &[0, 1, 127, 128, 16383, 16384, u32::MAX as u64, u64::MAX]; + for &val in cases { + let mut buf = Vec::new(); + write_varint_u64(&mut buf, val).unwrap(); + let decoded = read_varint_u64(&mut buf.as_slice()).unwrap(); + assert_eq!(decoded, val, "roundtrip failed for {val}"); + } + } + + #[test] + fn varint_u64_len_matches_encoding() { + let cases: &[u64] = &[0, 1, 127, 128, 16383, 16384, u32::MAX as u64, u64::MAX]; + for &val in cases { + let mut buf = Vec::new(); + write_varint_u64(&mut buf, val).unwrap(); + assert_eq!( + buf.len(), + varint_u64_len(val), + "varint_u64_len mismatch for {val}" + ); + } + } + + #[test] + fn varint_overflow_rejected() { + // 11 continuation bytes — exceeds VARINT_U64_MAX_BYTES + let bad = vec![0x80u8; 11]; + let res = read_varint_u64(&mut bad.as_slice()); + assert!(res.is_err()); + } + + #[test] + fn varint_10th_byte_overflow_rejected() { + // 9 continuation bytes + a 10th byte with payload 2 (only 0 or 1 is valid at shift=63) + let mut bad = vec![0x80u8; 9]; + bad.push(0x02); + let res = read_varint_u64(&mut bad.as_slice()); + assert!(res.is_err()); + } + + #[test] + fn noncanonical_varint_rejected() { + let bad = [0x80u8, 0x00]; + let res = read_varint_u64(&mut bad.as_slice()); + assert!(res.is_err()); + let err = res.unwrap_err(); + assert_eq!(err.kind(), io::ErrorKind::InvalidData); + assert!(err.to_string().contains("non-canonical varint")); + } + + #[test] + fn magic_version_roundtrip() { + let magic = b"JOLT"; + let version = 1u8; + let mut buf = Vec::new(); + write_magic_version(&mut buf, magic, version).unwrap(); + assert_eq!(buf.len(), 5); + + let decoded_version = read_magic_version(&mut buf.as_slice(), magic).unwrap(); + assert_eq!(decoded_version, version); + } + + #[test] + fn wrong_magic_rejected() { + let mut buf = Vec::new(); + write_magic_version(&mut buf, b"JOLT", 1).unwrap(); + let res = read_magic_version(&mut buf.as_slice(), b"BAAD"); + assert!(res.is_err()); + let err = res.unwrap_err(); + assert_eq!(err.kind(), io::ErrorKind::InvalidData); + assert!(err.to_string().contains("invalid proof magic")); + } + + #[test] + fn wrong_version_readable() { + let mut buf = Vec::new(); + write_magic_version(&mut buf, b"JOLT", 2).unwrap(); + let version = read_magic_version(&mut buf.as_slice(), b"JOLT").unwrap(); + assert_eq!(version, 2); + } + + #[test] + fn read_section_enforces_cap() { + let mut buf = Vec::new(); + write_varint_u64(&mut buf, 1000).unwrap(); + buf.extend_from_slice(&[0u8; 1000]); + + let mut too_small = buf.as_slice(); + let res = read_section(&mut too_small, 999); + assert!(res.is_err()); + + let mut cursor = buf.as_slice(); + let limited = read_section(&mut cursor, 1000).unwrap(); + assert_eq!(limited.limit(), 1000); + } +}