diff --git a/willow/src/zk/rlwe_relation.rs b/willow/src/zk/rlwe_relation.rs index c8dbc45..3aa2cef 100644 --- a/willow/src/zk/rlwe_relation.rs +++ b/willow/src/zk/rlwe_relation.rs @@ -266,15 +266,15 @@ fn create_public_vec( fn update_public_vec_for_range_proof( public_vec: &mut Vec, result: &mut Scalar, - R_r: &Vec, - R_e: &Vec, - R_vw: &Vec, - z_r: &Vec, - z_e: &Vec, - z_vw: &Vec, - psi_r: Scalar, - psi_e: Scalar, - psi_vw: Scalar, + R_r: &[Scalar], + R_e: &[Scalar], + R_vw: &[Scalar], + z_r: &[Scalar], + z_e: &[Scalar], + z_vw: &[Scalar], + psi_r: &[Scalar], + psi_e: &[Scalar], + psi_vw: &[Scalar], n: usize, range_comm_offset: usize, samples_required: usize, @@ -298,20 +298,14 @@ fn update_public_vec_for_range_proof( // The range proofs equation also involves length 128 innerproducts involving the relevant // psi these are included in the last 3*128 entries of the inner product vectors. - let mut phi_psi_r_pow = phi; - let mut phi2_psi_e_pow = phi2; - let mut phi3_psi_vw_pow = phi3; for i in 0..128 { - public_vec[i + range_comm_offset] = phi_psi_r_pow; - public_vec[i + range_comm_offset + 128] = phi2_psi_e_pow; - public_vec[i + range_comm_offset + 256] = phi3_psi_vw_pow; + public_vec[i + range_comm_offset] = phi * Scalar::from(psi_r[i]); + public_vec[i + range_comm_offset + 128] = phi2 * Scalar::from(psi_e[i]); + public_vec[i + range_comm_offset + 256] = phi3 * Scalar::from(psi_vw[i]); // Add contributions of the range proofs to the overall inner product result. - *result += z_r[i] * phi_psi_r_pow; - *result += z_e[i] * phi2_psi_e_pow; - *result += z_vw[i] * phi3_psi_vw_pow; - phi_psi_r_pow *= psi_r; - phi2_psi_e_pow *= psi_e; - phi3_psi_vw_pow *= psi_vw; + *result += z_r[i] * public_vec[i + range_comm_offset]; + *result += z_e[i] * public_vec[i + range_comm_offset + 128]; + *result += z_vw[i] * public_vec[i + range_comm_offset + 256]; } } @@ -364,33 +358,38 @@ pub fn flatten_challenge_matrix( R1: Vec, R2: Vec, challenge_label: &'static [u8], -) -> Result<(Vec, Scalar), status::StatusError> { +) -> Result<(Vec, Vec), status::StatusError> { let n = R1.len(); if n != R2.len() { return Err(status::failed_precondition("R1 and R2 have different lengths".to_string())); } - let mut buf = [0u8; 64]; - transcript.challenge_bytes(challenge_label, &mut buf); - let psi = Scalar::from_bytes_mod_order_wide(&buf); - - let mut R = vec![Scalar::from(0 as u64); n]; - let mut psi_powers = [Scalar::from(1 as u64); 128]; - for j in 1..128 { - psi_powers[j] = psi_powers[j - 1] * psi; + let mut Rplus = vec![0u128; n]; + let mut Rminus = vec![0u128; n]; + let mut Rscalar = vec![Scalar::from(0u64); n]; + + let mut psi = [0u128; 128]; + let mut psi_scalar = vec![Scalar::from(0 as u64); 128]; + let mut buf = [0u8; 16]; + for j in 0..128 { + transcript.challenge_bytes(challenge_label, &mut buf); + // We only take challenges up to 2^121 so that the sum of 128 of them will fit in a u128. + psi[j] = u128::from_le_bytes(buf) >> 7; + psi_scalar[j] = Scalar::from(psi[j]); } for i in 0..n { for j in 0..128 { if R1[i] & (1u128 << j) != 0 { - R[i] += psi_powers[j]; + Rplus[i] += psi[j]; } if R2[i] & (1u128 << j) != 0 { - R[i] -= psi_powers[j]; + Rminus[i] += psi[j]; } } + Rscalar[i] = Scalar::from(Rplus[i]) - Scalar::from(Rminus[i]); } - Ok((R, psi)) + Ok((Rscalar, psi_scalar)) } // Check that loose_bound = bound*2500*sqrt(v.len()+1) fits within an i128. @@ -407,6 +406,16 @@ fn check_loose_bound_will_not_overflow(bound: u128, n: usize) -> Result<(), stat Ok(()) } +// Struct to hold the results of the generate_range_product function. +struct RangeProductMetadata { + R: Vec, + comm_y: RistrettoPoint, + y: Vec, + delta_y: Scalar, + psi: Vec, + z: Vec, +} + // Return the inner product that needs to be checked for the range proof, the commitment to y that // the verifier will need to verify it and the blinding information required for the proof. // @@ -429,10 +438,7 @@ fn generate_range_product( start: usize, transcript: &mut (impl Transcript + Clone), challenge_label: &'static [u8], -) -> Result< - (Vec, RistrettoPoint, Vec, Scalar, Scalar, Vec), - status::StatusError, -> { +) -> Result { // Check that computing loose bound does not result in an overflow. check_loose_bound_will_not_overflow(bound, v.len())?; @@ -512,17 +518,19 @@ fn generate_range_product( }) .collect(); - Ok((R, comm_y, scalar_y, delta_y, psi, scalar_z)) + Ok(RangeProductMetadata { R, comm_y, y: scalar_y, delta_y, psi, z: scalar_z }) } +// Verifies the z bound and returns the linear combination of the 128 rows of the range proof +// projection matrix R and a vector psi of the coefficients used in that linear combination. fn generate_range_product_for_verification_and_verify_z_bound( n: usize, bound: u128, comm_y: RistrettoPoint, - z: &Vec, + z: &[Scalar], transcript: &mut impl Transcript, challenge_label: &'static [u8], -) -> Result<(Vec, Scalar), status::StatusError> { +) -> Result<(Vec, Vec), status::StatusError> { // Check that computing loose bound does not result in an overflow. check_loose_bound_will_not_overflow(bound, n)?; @@ -762,7 +770,7 @@ impl<'a> ZeroKnowledgeProver, RlweRelationProofWi // Get inner products to prove for range proofs. We then need to check // + = mod P etc. // This is explained in more detail in the comment above generate_range_product. - let (R_r, comm_y_r, y_r, delta_y_r, psi_r, z_r) = generate_range_product( + let range_product_r = generate_range_product( &signed_r, bound_r, &self.prover, @@ -770,7 +778,7 @@ impl<'a> ZeroKnowledgeProver, RlweRelationProofWi transcript, b"range matrix r", )?; - let (R_e, comm_y_e, y_e, delta_y_e, psi_e, z_e) = generate_range_product( + let range_product_e = generate_range_product( &signed_e, bound_e, &self.prover, @@ -778,7 +786,7 @@ impl<'a> ZeroKnowledgeProver, RlweRelationProofWi transcript, b"range matrix e", )?; - let (R_vw, comm_y_vw, y_vw, delta_y_vw, psi_vw, z_vw) = generate_range_product( + let range_product_vw = generate_range_product( &signed_vw, q * (n as u128), &self.prover, @@ -792,15 +800,15 @@ impl<'a> ZeroKnowledgeProver, RlweRelationProofWi update_public_vec_for_range_proof( &mut public_vec, &mut result, - &R_r, - &R_e, - &R_vw, - &z_r, - &z_e, - &z_vw, - psi_r, - psi_e, - psi_vw, + &range_product_r.R, + &range_product_e.R, + &range_product_vw.R, + &range_product_r.z, + &range_product_e.z, + &range_product_vw.z, + &range_product_r.psi, + &range_product_e.psi, + &range_product_vw.psi, n, range_comm_offset, samples_required, @@ -818,13 +826,21 @@ impl<'a> ZeroKnowledgeProver, RlweRelationProofWi private_vec[i + n + n + n] = scalar_wrho_vec[i]; } for i in 0..128 { - private_vec[i + range_comm_offset] = y_r[i]; - private_vec[i + range_comm_offset + 128] = y_e[i]; - private_vec[i + range_comm_offset + 256] = y_vw[i]; + private_vec[i + range_comm_offset] = range_product_r.y[i]; + private_vec[i + range_comm_offset + 128] = range_product_e.y[i]; + private_vec[i + range_comm_offset + 256] = range_product_vw.y[i]; } - let private_vec_comm = comm_rev + comm_wrho + comm_y_r + comm_y_e + comm_y_vw; - let blinding_factor = delta_rev + delta_w + delta_y_r + delta_y_e + delta_y_vw; + let private_vec_comm = comm_rev + + comm_wrho + + range_product_r.comm_y + + range_product_e.comm_y + + range_product_vw.comm_y; + let blinding_factor = delta_rev + + delta_w + + range_product_r.delta_y + + range_product_e.delta_y + + range_product_vw.delta_y; // Set up linear product statement and prove it let lip_statement = LinearInnerProductProofStatement { @@ -841,12 +857,12 @@ impl<'a> ZeroKnowledgeProver, RlweRelationProofWi Ok(RlweRelationProof { comm_rev: comm_rev.compress(), comm_wrho: comm_wrho.compress(), - comm_y_r: comm_y_r.compress(), - comm_y_e: comm_y_e.compress(), - comm_y_vw: comm_y_vw.compress(), - z_r: z_r, - z_e: z_e, - z_vw: z_vw, + comm_y_r: range_product_r.comm_y.compress(), + comm_y_e: range_product_e.comm_y.compress(), + comm_y_vw: range_product_vw.comm_y.compress(), + z_r: range_product_r.z, + z_e: range_product_e.z, + z_vw: range_product_vw.z, lip_proof: lip_proof, }) } @@ -977,9 +993,9 @@ impl<'a> ZeroKnowledgeVerifier, RlweRelationProof &proof.z_r, &proof.z_e, &proof.z_vw, - psi_r, - psi_e, - psi_vw, + &psi_r, + &psi_e, + &psi_vw, n, range_comm_offset, samples_required, @@ -1247,33 +1263,31 @@ mod tests { let v = [1, -2, 3, -4]; let prover = LinearInnerProductProver::new(b"42", 132); let mut transcript = MerlinTranscript::new(b"42"); - let (R, comm_y, y, delta_y, psi, z) = + let result = generate_range_product(&v, bound, &prover, 4, &mut transcript, b"test vector")?; let mut private_vec = [Scalar::from(0u128); 132]; for i in 0..4 { private_vec[i] = Scalar::from((v[i] + (bound as i128)) as u128) - Scalar::from(bound); } - for i in 4..132 { - private_vec[i] = y[i - 4]; + for i in 0..128 { + private_vec[i + 4] = result.y[i]; } let mut public_vec = [Scalar::from(0u128); 132]; for i in 0..4 { - public_vec[i] = R[i]; + public_vec[i] = result.R[i]; } - let mut psi_pow = Scalar::from(1u128); - let mut result = Scalar::from(0u128); - for i in 4..132 { - public_vec[i] = psi_pow; - result += z[i - 4] * psi_pow; - psi_pow *= psi; + let mut inner_product = Scalar::from(0u128); + for i in 0..128 { + public_vec[i + 4] = result.psi[i]; + inner_product += result.z[i] * result.psi[i]; } let mut expected_result = Scalar::from(0u128); for j in 0..132 { expected_result += public_vec[j] * private_vec[j]; } - assert_eq!(result, expected_result); - let expected_comm_y = prover.commit_partial(&y, delta_y, 4, 132)?; - assert_eq!(comm_y, expected_comm_y); + assert_eq!(inner_product, expected_result); + let expected_comm_y = prover.commit_partial(&result.y, result.delta_y, 4, 132)?; + assert_eq!(result.comm_y, expected_comm_y); Ok(()) }