diff --git a/willow/src/zk/rlwe_relation.rs b/willow/src/zk/rlwe_relation.rs index 3aa2cef..a77cde1 100644 --- a/willow/src/zk/rlwe_relation.rs +++ b/willow/src/zk/rlwe_relation.rs @@ -327,28 +327,41 @@ pub fn generate_challenge_matrix( result } -// Multiplies a 128 by n matrix m and a length n vector v. -// m is a binary matrix each column of which has entries given by the bits of a single entry in the -// input vector m. -// Both the output and v are vectors of 128 bit signed integers. -pub fn multiply_by_challenge_matrix( +// Applies a challenge matrix R1-R2 to a vector v and checks if the result satisfies the conditions +// for not needing to be rejected. An internal error is returned in the event of rejection otherwise +// the resulting vector z is returned. +// +// To understand the rejection conditions see the comment for generate_range_product. +pub fn try_matrices_and_compute_z( v: &[i128], - m: &[u128], + R1: &[u128], + R2: &[u128], + y: &[i128], + half_loose_bound: i128, ) -> Result, status::StatusError> { let n = v.len(); - if m.len() != n { - return Err(status::failed_precondition("m and v have different lengths".to_string())); + if n != R1.len() || n != R2.len() { + return Err(status::failed_precondition( + "R1, R2, and v must have the same length".to_string(), + )); } - - let mut result = vec![0 as i128; 128]; - for i in 0..n { - for j in 0..128 { - if m[i] & (1u128 << j) != 0 { - result[j] += v[i]; + let mut z = vec![0 as i128; 128]; + for j in 0..128 { + let mut u = 0i128; + for i in 0..n { + if R1[i] & (1u128 << j) != 0 { + u += v[i]; } + if R2[i] & (1u128 << j) != 0 { + u -= v[i]; + } + } + z[j] = u + y[j]; + if u.abs() > half_loose_bound / 128 || z[j].abs() > half_loose_bound { + return Err(status::internal("Sample Rejected")); } } - Ok(result) + Ok(z) } // Linearly combines the 128 vector challenges of a challenge matrix into a single vector challenge @@ -459,7 +472,6 @@ fn generate_range_product( let mut z = vec![0 as i128; 128]; let mut attempts = 0; loop { - let mut done = true; attempts += 1; y = (0..128).map(|_| (rng.gen_range(0..possible_y) as i128)).collect(); for i in 0..128 { @@ -474,21 +486,9 @@ fn generate_range_product( // subtracting the other we get a challenge matrix with the correct distribution. R1 = generate_challenge_matrix(transcript, challenge_label, v.len()); R2 = generate_challenge_matrix(transcript, challenge_label, v.len()); - let u1 = multiply_by_challenge_matrix(v, &R1)?; - let u2 = multiply_by_challenge_matrix(v, &R2)?; - for i in 0..128 { - let u = u1[i] - u2[i]; - if u.abs() > half_loose_bound / 128 { - done = false; - break; - } - z[i] = u + y[i]; - if z[i].abs() > half_loose_bound { - done = false; - break; - } - } - if done { + let z_or_error = try_matrices_and_compute_z(v, &R1, &R2, &y, half_loose_bound); + if z_or_error.is_ok() { + z = z_or_error.unwrap(); break; } if attempts > 1000 { @@ -1244,16 +1244,44 @@ mod tests { } #[test] - fn test_multiply_by_challenge_matrix_basic_case() -> googletest::Result<()> { - let v = &[10i128, 20i128]; - let m = &[(1u128 << 0) | (1u128 << 2), (1u128 << 1) | (1u128 << 2)]; + fn test_try_matrices_and_compute_z_valid() -> googletest::Result<()> { + let v = [1, -2, 3, -4]; + let R1 = [1, 2, 3, 4]; + let R2 = [4, 3, 2, 1]; + let y = [1; 128]; + let half_loose_bound = 10000; + let result = try_matrices_and_compute_z(&v, &R1, &R2, &y, half_loose_bound)?; + let mut expected_z = vec![1; 128]; + expected_z[0] += 10; + expected_z[1] += 0; + expected_z[2] += -5; + verify_eq!(result, expected_z)?; + Ok(()) + } - let mut expected_result = vec![0i128; 128]; - expected_result[0] = 10; - expected_result[1] = 20; - expected_result[2] = 30; + #[test] + fn test_try_matrices_and_compute_z_mismatched_lengths() -> googletest::Result<()> { + let v = [1, -2, 3, -4]; + let R1 = [1, 2, 3]; + let R2 = [4, 3, 2, 1]; + let y = [1; 128]; + let half_loose_bound = 1000; + let result = try_matrices_and_compute_z(&v, &R1, &R2, &y, half_loose_bound); + assert!(result.is_err()); + verify_eq!(result.unwrap_err().message(), "R1, R2, and v must have the same length")?; + Ok(()) + } - assert_eq!(multiply_by_challenge_matrix(v, m).unwrap(), expected_result); + #[test] + fn test_try_matrices_and_compute_z_sample_rejected() -> googletest::Result<()> { + let v = [1000, -2000, 3000, -4000]; + let R1 = [1, 2, 3, 4]; + let R2 = [4, 3, 2, 1]; + let y = [1; 128]; + let half_loose_bound = 100000; + let result = try_matrices_and_compute_z(&v, &R1, &R2, &y, half_loose_bound); + assert!(result.is_err()); + verify_eq!(result.unwrap_err().message(), "Sample Rejected")?; Ok(()) }