From 28a85da5b3b4589a50dd39d187ca5c67c109feef Mon Sep 17 00:00:00 2001 From: Anjan Roy Date: Wed, 19 Mar 2025 16:02:05 +0530 Subject: [PATCH 01/29] Add dependencies for new feature `gpu` Signed-off-by: Anjan Roy --- Cargo.toml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/Cargo.toml b/Cargo.toml index a6b75c5..7c299bc 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -17,6 +17,8 @@ turboshake = "=0.4.1" rayon = "=1.10.0" rand = "=0.9.0" rand_chacha = "=0.9.0" +vulkano = { version = "=0.35.1", optional = true } +vulkano-shaders = { version = "=0.35.0", optional = true } [dev-dependencies] test-case = "=3.3.1" @@ -34,6 +36,7 @@ required-features = ["mutate_internal_client_state"] [features] mutate_internal_client_state = [] +gpu = ["dep:vulkano", "dep:vulkano-shaders"] [profile.optimized] inherits = "release" From 11e5b774a98bc1509068387b08fe24631148b275 Mon Sep 17 00:00:00 2001 From: Anjan Roy Date: Wed, 19 Mar 2025 16:41:09 +0530 Subject: [PATCH 02/29] Use `u32` for matrix dimensions Signed-off-by: Anjan Roy --- src/client.rs | 4 +- src/pir_internals/matrix.rs | 143 +++++++++++++++-------------- src/pir_internals/params.rs | 8 +- src/pir_internals/serialization.rs | 8 +- src/server.rs | 2 +- 5 files changed, 85 insertions(+), 80 deletions(-) diff --git a/src/client.rs b/src/client.rs index 02db179..9db66e2 100644 --- a/src/client.rs +++ b/src/client.rs @@ -42,7 +42,7 @@ impl Client { let filter = BinaryFuseFilter::from_bytes(filter_param_bytes)?; let pub_mat_a_num_rows = LWE_DIMENSION; - let pub_mat_a_num_cols = filter.num_fingerprints; + let pub_mat_a_num_cols = filter.num_fingerprints as u32; let pub_mat_a = Matrix::generate_from_seed(pub_mat_a_num_rows, pub_mat_a_num_cols, seed_μ)?; let hint_mat_m = Matrix::from_bytes(hint_bytes)?; @@ -225,7 +225,7 @@ impl Client { let hashed_key = binary_fuse_filter::hash_of_key(key); let hash = binary_fuse_filter::mix256(&hashed_key, &self.filter.seed); - let recovered_row = (0..response_vector.num_cols()) + let recovered_row 
= (0..response_vector.num_cols() as usize) .map(|idx| { let unscaled_res = response_vector[(0, idx)].wrapping_sub(secret_vec_c[(0, idx)]); diff --git a/src/pir_internals/matrix.rs b/src/pir_internals/matrix.rs index 236457b..85e87f5 100644 --- a/src/pir_internals/matrix.rs +++ b/src/pir_internals/matrix.rs @@ -19,8 +19,8 @@ use super::error::ChalametPIRError; #[derive(Clone, Debug, PartialEq)] pub struct Matrix { - rows: usize, - cols: usize, + rows: u32, + cols: u32, elems: Vec, } @@ -36,12 +36,12 @@ impl Matrix { /// /// * `Result` - A new matrix if the input is valid (rows and cols are positive). /// Returns an error if either rows or cols is zero. - pub fn new(rows: usize, cols: usize) -> Result { + pub fn new(rows: u32, cols: u32) -> Result { if branch_opt_util::likely((rows > 0) && (cols > 0)) { Ok(Matrix { rows, cols, - elems: vec![0; rows * cols], + elems: vec![0; (rows * cols) as usize], }) } else { Err(ChalametPIRError::InvalidMatrixDimension) @@ -60,9 +60,9 @@ impl Matrix { /// /// * `Result` - A new matrix if the input is valid (rows and cols are positive and the number of values matches the number of required elements). /// Returns an error if either rows or cols is zero, or if the number of values does not match the number of required elements. 
- pub fn from_values(rows: usize, cols: usize, values: Vec) -> Result { + pub fn from_values(rows: u32, cols: u32, values: Vec) -> Result { if branch_opt_util::likely((rows > 0) && (cols > 0)) { - if branch_opt_util::likely(rows * cols == values.len()) { + if branch_opt_util::likely((rows * cols) as usize == values.len()) { Ok(Matrix { rows, cols, elems: values }) } else { Err(ChalametPIRError::InvalidNumberOfElementsInMatrix) @@ -73,11 +73,11 @@ impl Matrix { } #[inline(always)] - pub const fn num_rows(&self) -> usize { + pub const fn num_rows(&self) -> u32 { self.rows } #[inline(always)] - pub const fn num_cols(&self) -> usize { + pub const fn num_cols(&self) -> u32 { self.cols } #[inline(always)] @@ -103,13 +103,13 @@ impl Matrix { let res_num_rows = self.rows; let res_num_cols = rhs.rows; - let mut res_elems = vec![0u32; res_num_rows * res_num_cols]; + let mut res_elems = vec![0u32; (res_num_rows * res_num_cols) as usize]; res_elems.par_iter_mut().enumerate().for_each(|(lin_idx, v)| { let r_idx = 0; let c_idx = lin_idx; - *v = (0..self.cols).fold(0u32, |acc, k| acc.wrapping_add(self[(r_idx, k)].wrapping_mul(rhs[(c_idx, k)]))); + *v = (0..self.cols as usize).fold(0u32, |acc, k| acc.wrapping_add(self[(r_idx, k)].wrapping_mul(rhs[(c_idx, k)]))); }); Matrix::from_values(res_num_rows, res_num_cols, res_elems) @@ -126,14 +126,14 @@ impl Matrix { /// * `Result` - A new identity matrix if the input is valid (rows is positive). /// Returns an error if rows is zero. 
#[cfg(test)] - pub fn identity(rows: usize) -> Result { + pub fn identity(rows: u32) -> Result { if branch_opt_util::unlikely(rows == 0) { return Err(ChalametPIRError::InvalidMatrixDimension); } let mut mat = Matrix::new(rows, rows)?; - (0..rows).for_each(|idx| { + (0..mat.rows as usize).for_each(|idx| { mat[(idx, idx)] = 1; }); @@ -148,8 +148,8 @@ impl Matrix { pub fn transpose(&self) -> Matrix { let mut res = unsafe { Matrix::new(self.cols, self.rows).unwrap_unchecked() }; - (0..self.cols) - .flat_map(|ridx| (0..self.rows).map(move |cidx| (ridx, cidx))) + (0..self.cols as usize) + .flat_map(|ridx| (0..self.rows as usize).map(move |cidx| (ridx, cidx))) .for_each(|(ridx, cidx)| { res[(ridx, cidx)] = self[(cidx, ridx)]; }); @@ -169,12 +169,12 @@ impl Matrix { /// /// * `Result` - A new matrix if the input is valid (rows and cols are positive). /// Returns an error if either rows or cols is zero. - pub fn generate_from_seed(rows: usize, cols: usize, seed: &[u8; SEED_BYTE_LEN]) -> Result { + pub fn generate_from_seed(rows: u32, cols: u32, seed: &[u8; SEED_BYTE_LEN]) -> Result { let mut hasher = TurboShake128::default(); hasher.absorb(seed); hasher.finalize::<{ TurboShake128::DEFAULT_DOMAIN_SEPARATOR }>(); - let mut elems = vec![0u32; rows * cols]; + let mut elems = vec![0u32; (rows * cols) as usize]; let elems_byte_len = elems.len() * std::mem::size_of::(); unsafe { @@ -200,7 +200,7 @@ impl Matrix { /// /// * `Result` - A new row/ column vector if the input is valid (rows or cols is 1). /// Returns an error if neither rows nor cols is 1. 
- pub fn sample_from_uniform_ternary_dist(rows: usize, cols: usize) -> Result { + pub fn sample_from_uniform_ternary_dist(rows: u32, cols: u32) -> Result { if branch_opt_util::unlikely(!(rows == 1 || cols == 1)) { return Err(ChalametPIRError::InvalidDimensionForVector); } @@ -211,7 +211,7 @@ impl Matrix { let mut rng = ChaCha8Rng::from_os_rng(); let mut vec = Matrix::new(rows, cols)?; - let num_elems = rows * cols; + let num_elems = vec.num_elems(); let mut elem_idx = 0; while branch_opt_util::likely(elem_idx < num_elems) { @@ -318,8 +318,8 @@ impl Matrix { let max_value_byte_len = unsafe { db.values().map(|v| v.len()).max().unwrap_unchecked() }; let max_value_bit_len = max_value_byte_len * 8; - let rows = filter.num_fingerprints; - let cols: usize = (HASHED_KEY_BIT_LEN + max_value_bit_len + 8).div_ceil(mat_elem_bit_len); + let rows = filter.num_fingerprints as u32; + let cols = (HASHED_KEY_BIT_LEN + max_value_bit_len + 8).div_ceil(mat_elem_bit_len) as u32; let mut mat = Matrix::new(rows, cols)?; let mat_elem_mask = (1u32 << mat_elem_bit_len) - 1; @@ -346,13 +346,13 @@ impl Matrix { let mat_row_idx1 = h012[found + 1] as usize; let mat_row_idx2 = h012[found + 2] as usize; - let elems = (0..cols) + let elems = (0..cols as usize) .map(|elem_idx| { - let f1 = mat.elems[mat_row_idx1 * cols + elem_idx]; + let f1 = mat.elems[mat_row_idx1 * cols as usize + elem_idx]; (elem_idx, row[elem_idx].wrapping_sub(f1)) }) .map(|(elem_idx, elem)| { - let f2 = mat.elems[mat_row_idx2 * cols + elem_idx]; + let f2 = mat.elems[mat_row_idx2 * cols as usize + elem_idx]; (elem_idx, elem.wrapping_sub(f2) & mat_elem_mask) }) .map(|(elem_idx, elem)| { @@ -361,8 +361,8 @@ impl Matrix { }) .collect::>(); - let fingerprints_begin_at = mat_row_idx0 * cols; - let fingerprints_end_at = fingerprints_begin_at + cols; + let fingerprints_begin_at = mat_row_idx0 * cols as usize; + let fingerprints_end_at = fingerprints_begin_at + cols as usize; 
mat.elems[fingerprints_begin_at..fingerprints_end_at].copy_from_slice(&elems); } @@ -396,10 +396,10 @@ impl Matrix { let (h0, h1, h2) = binary_fuse_filter::hash_batch_for_3_wise_xor_filter(hash, filter.segment_length, filter.segment_count_length); - let recovered_row = (0..self.cols) - .map(|elem_idx| (elem_idx, self.elems[h0 as usize * self.cols + elem_idx])) - .map(|(elem_idx, elem)| (elem_idx, elem.wrapping_add(self.elems[h1 as usize * self.cols + elem_idx]))) - .map(|(elem_idx, elem)| (elem_idx, elem.wrapping_add(self.elems[h2 as usize * self.cols + elem_idx]))) + let recovered_row = (0..self.cols as usize) + .map(|elem_idx| (elem_idx, self.elems[h0 as usize * self.cols as usize + elem_idx])) + .map(|(elem_idx, elem)| (elem_idx, elem.wrapping_add(self.elems[h1 as usize * self.cols as usize + elem_idx]))) + .map(|(elem_idx, elem)| (elem_idx, elem.wrapping_add(self.elems[h2 as usize * self.cols as usize + elem_idx]))) .map(|(elem_idx, elem)| elem.wrapping_add((binary_fuse_filter::mix(hash, elem_idx as u64) as u32) & mat_elem_mask) & mat_elem_mask) .collect::>(); @@ -450,8 +450,8 @@ impl Matrix { let max_value_byte_len = unsafe { db.values().map(|v| v.len()).max().unwrap_unchecked() }; let max_value_bit_len = max_value_byte_len * 8; - let rows = filter.num_fingerprints; - let cols: usize = (HASHED_KEY_BIT_LEN + max_value_bit_len + 8).div_ceil(mat_elem_bit_len); + let rows = filter.num_fingerprints as u32; + let cols = (HASHED_KEY_BIT_LEN + max_value_bit_len + 8).div_ceil(mat_elem_bit_len) as u32; let mut mat = Matrix::new(rows, cols)?; let mat_elem_mask = (1u32 << mat_elem_bit_len) - 1; @@ -481,17 +481,17 @@ impl Matrix { let mat_row_idx2 = h0123[found + 2] as usize; let mat_row_idx3 = h0123[found + 3] as usize; - let elems = (0..cols) + let elems = (0..cols as usize) .map(|elem_idx| { - let f1 = mat.elems[mat_row_idx1 * cols + elem_idx]; + let f1 = mat.elems[mat_row_idx1 * cols as usize + elem_idx]; (elem_idx, row[elem_idx].wrapping_sub(f1)) }) .map(|(elem_idx, 
elem)| { - let f2 = mat.elems[mat_row_idx2 * cols + elem_idx]; + let f2 = mat.elems[mat_row_idx2 * cols as usize + elem_idx]; (elem_idx, elem.wrapping_sub(f2) & mat_elem_mask) }) .map(|(elem_idx, elem)| { - let f2 = mat.elems[mat_row_idx3 * cols + elem_idx]; + let f2 = mat.elems[mat_row_idx3 * cols as usize + elem_idx]; (elem_idx, elem.wrapping_sub(f2) & mat_elem_mask) }) .map(|(elem_idx, elem)| { @@ -500,8 +500,8 @@ impl Matrix { }) .collect::>(); - let fingerprints_begin_at = mat_row_idx0 * cols; - let fingerprints_end_at = fingerprints_begin_at + cols; + let fingerprints_begin_at = mat_row_idx0 * cols as usize; + let fingerprints_end_at = fingerprints_begin_at + cols as usize; mat.elems[fingerprints_begin_at..fingerprints_end_at].copy_from_slice(&elems); } @@ -535,11 +535,11 @@ impl Matrix { let (h0, h1, h2, h3) = binary_fuse_filter::hash_batch_for_4_wise_xor_filter(hash, filter.segment_length, filter.segment_count_length); - let recovered_row = (0..self.cols) - .map(|elem_idx| (elem_idx, self.elems[h0 as usize * self.cols + elem_idx])) - .map(|(elem_idx, elem)| (elem_idx, elem.wrapping_add(self.elems[h1 as usize * self.cols + elem_idx]))) - .map(|(elem_idx, elem)| (elem_idx, elem.wrapping_add(self.elems[h2 as usize * self.cols + elem_idx]))) - .map(|(elem_idx, elem)| (elem_idx, elem.wrapping_add(self.elems[h3 as usize * self.cols + elem_idx]))) + let recovered_row = (0..self.cols as usize) + .map(|elem_idx| (elem_idx, self.elems[h0 as usize * self.cols as usize + elem_idx])) + .map(|(elem_idx, elem)| (elem_idx, elem.wrapping_add(self.elems[h1 as usize * self.cols as usize + elem_idx]))) + .map(|(elem_idx, elem)| (elem_idx, elem.wrapping_add(self.elems[h2 as usize * self.cols as usize + elem_idx]))) + .map(|(elem_idx, elem)| (elem_idx, elem.wrapping_add(self.elems[h3 as usize * self.cols as usize + elem_idx]))) .map(|(elem_idx, elem)| elem.wrapping_add((binary_fuse_filter::mix(hash, elem_idx as u64) as u32) & mat_elem_mask) & mat_elem_mask) .collect::>(); @@ 
-567,7 +567,7 @@ impl Matrix { } pub fn to_bytes(&self) -> Vec { - let encoded_elems_byte_len = std::mem::size_of::() * self.rows * self.cols; + let encoded_elems_byte_len = std::mem::size_of::() * (self.rows * self.cols) as usize; let offset0 = 0; let offset1 = offset0 + std::mem::size_of_val(&self.rows); @@ -594,8 +594,8 @@ impl Matrix { pub fn from_bytes(bytes: &[u8]) -> Result { const OFFSET0: usize = 0; - const OFFSET1: usize = OFFSET0 + std::mem::size_of::(); - const OFFSET2: usize = OFFSET1 + std::mem::size_of::(); + const OFFSET1: usize = OFFSET0 + std::mem::size_of::(); + const OFFSET2: usize = OFFSET1 + std::mem::size_of::(); if branch_opt_util::unlikely(bytes.len() <= OFFSET2) { return Err(ChalametPIRError::FailedToDeserializeMatrixFromBytes); @@ -603,11 +603,11 @@ impl Matrix { let (rows, cols) = unsafe { ( - usize::from_le_bytes(bytes.get_unchecked(OFFSET0..OFFSET1).try_into().unwrap()), - usize::from_le_bytes(bytes.get_unchecked(OFFSET1..OFFSET2).try_into().unwrap()), + u32::from_le_bytes(bytes.get_unchecked(OFFSET0..OFFSET1).try_into().unwrap()), + u32::from_le_bytes(bytes.get_unchecked(OFFSET1..OFFSET2).try_into().unwrap()), ) }; - let num_elems = rows * cols; + let num_elems = (rows * cols) as usize; if branch_opt_util::unlikely(num_elems == 0) { return Err(ChalametPIRError::FailedToDeserializeMatrixFromBytes); @@ -638,7 +638,7 @@ impl Index<(usize, usize)> for Matrix { #[inline(always)] fn index(&self, index: (usize, usize)) -> &Self::Output { let (ridx, cidx) = index; - unsafe { self.elems.get_unchecked(ridx * self.cols + cidx) } + unsafe { self.elems.get_unchecked(ridx * self.cols as usize + cidx) } } } @@ -646,7 +646,7 @@ impl IndexMut<(usize, usize)> for Matrix { #[inline(always)] fn index_mut(&mut self, index: (usize, usize)) -> &mut Self::Output { let (ridx, cidx) = index; - unsafe { self.elems.get_unchecked_mut(ridx * self.cols + cidx) } + unsafe { self.elems.get_unchecked_mut(ridx * self.cols as usize + cidx) } } } @@ -667,13 +667,13 @@ 
impl<'b> Mul<&'b Matrix> for &Matrix { return Err(ChalametPIRError::IncompatibleDimensionForMatrixMultiplication); } - let mut res_elems = vec![0u32; self.rows * rhs.cols]; + let mut res_elems = vec![0u32; (self.rows * rhs.cols) as usize]; res_elems.par_iter_mut().enumerate().for_each(|(lin_idx, v)| { - let r_idx = lin_idx / rhs.cols; - let c_idx = lin_idx - r_idx * rhs.cols; + let r_idx = lin_idx / rhs.cols as usize; + let c_idx = lin_idx - r_idx * rhs.cols as usize; - *v = (0..self.cols).fold(0u32, |acc, k| acc.wrapping_add(self[(r_idx, k)].wrapping_mul(rhs[(k, c_idx)]))); + *v = (0..self.cols as usize).fold(0u32, |acc, k| acc.wrapping_add(self[(r_idx, k)].wrapping_mul(rhs[(k, c_idx)]))); }); Matrix::from_values(self.rows, rhs.cols, res_elems) @@ -697,7 +697,7 @@ impl<'b> Add<&'b Matrix> for &Matrix { return Err(ChalametPIRError::IncompatibleDimensionForMatrixAddition); } - let mut res_elems = vec![0u32; self.rows * rhs.cols]; + let mut res_elems = vec![0u32; (self.rows * rhs.cols) as usize]; res_elems.par_iter_mut().enumerate().for_each(|(lin_idx, v)| { *v = unsafe { self.elems.get_unchecked(lin_idx).wrapping_add(*rhs.elems.get_unchecked(lin_idx)) }; @@ -850,7 +850,7 @@ pub mod test { #[test_case(0, 1024 => matches Err(ChalametPIRError::InvalidMatrixDimension); "Number of rows must be greater than zero")] #[test_case(1024, 0 => matches Err(ChalametPIRError::InvalidMatrixDimension); "Number of columns must be greater than zero")] #[test_case(0, 0 => matches Err(ChalametPIRError::InvalidMatrixDimension); "Both number of rows and columns must be greater than zero")] - fn new_empty_matrix_constructor_api(num_rows: usize, num_cols: usize) -> Result { + fn new_empty_matrix_constructor_api(num_rows: u32, num_cols: u32) -> Result { Matrix::new(num_rows, num_cols) } @@ -859,13 +859,13 @@ pub mod test { #[test_case(1024, 0, vec![] => matches Err(ChalametPIRError::InvalidMatrixDimension); "Number of columns must be greater than zero")] #[test_case(0, 0, vec![] => matches 
Err(ChalametPIRError::InvalidMatrixDimension); "Both number of rows and columns must be greater than zero")] #[test_case(1024, 1024, vec![0u32; 1024 * 1024 -1] => matches Err(ChalametPIRError::InvalidNumberOfElementsInMatrix); "Number of elements must be equal to number of rows times number of columns")] - fn from_values_matrix_constructor_api(num_rows: usize, num_cols: usize, elems: Vec) -> Result { + fn from_values_matrix_constructor_api(num_rows: u32, num_cols: u32, elems: Vec) -> Result { Matrix::from_values(num_rows, num_cols, elems) } #[test_case((1024,1),(1,1024) => matches Ok(_); "Matrix multiplication should work for valid dimensions")] #[test_case((1024,1),(1024, 1) => matches Err(ChalametPIRError::IncompatibleDimensionForMatrixMultiplication); "Matrix multiplication should not work for incompatible dimensions")] - fn matrix_multiplication_failures(lhs_mat_dim: (usize, usize), rhs_mat_dim: (usize, usize)) -> Result { + fn matrix_multiplication_failures(lhs_mat_dim: (u32, u32), rhs_mat_dim: (u32, u32)) -> Result { let (lhs_mat_rows, lhs_mat_cols) = lhs_mat_dim; let lhs_mat = Matrix::new(lhs_mat_rows, lhs_mat_cols)?; @@ -877,7 +877,7 @@ pub mod test { #[test_case((1024,1),(1024, 1) => matches Ok(_); "Matrix addition should work for valid dimensions")] #[test_case((1024,1),(1, 1024) => matches Err(ChalametPIRError::IncompatibleDimensionForMatrixAddition); "Matrix addition should not work for incompatible dimensions")] - fn matrix_addition_failures(lhs_mat_dim: (usize, usize), rhs_mat_dim: (usize, usize)) -> Result { + fn matrix_addition_failures(lhs_mat_dim: (u32, u32), rhs_mat_dim: (u32, u32)) -> Result { let (lhs_mat_rows, lhs_mat_cols) = lhs_mat_dim; let lhs_mat = Matrix::new(lhs_mat_rows, lhs_mat_cols)?; @@ -890,8 +890,8 @@ pub mod test { #[test] fn matrix_multiplication_is_correct() { const NUM_ATTEMPT_MATRIX_MULTIPLICATIONS: usize = 100; - const MIN_MATRIX_DIM: usize = 1; - const MAX_MATRIX_DIM: usize = 1024; + const MIN_MATRIX_DIM: u32 = 1; + const 
MAX_MATRIX_DIM: u32 = 1024; let mut rng = ChaCha8Rng::from_os_rng(); @@ -920,8 +920,8 @@ pub mod test { #[test] fn row_vector_transposed_matrix_multiplication_works() { const NUM_ATTEMPT_VECTOR_MATRIX_MULTIPLICATIONS: usize = 100; - const MIN_ROW_VECTOR_DIM: usize = 1; - const MAX_ROW_VECTOR_DIM: usize = 1024; + const MIN_ROW_VECTOR_DIM: u32 = 1; + const MAX_ROW_VECTOR_DIM: u32 = 1024; let mut rng = ChaCha8Rng::from_os_rng(); @@ -934,9 +934,10 @@ pub mod test { let vec_num_cols = rng.random_range(MIN_ROW_VECTOR_DIM..=MAX_ROW_VECTOR_DIM); let mat_num_rows = vec_num_cols; let mat_num_cols = rng.random_range(MIN_ROW_VECTOR_DIM..=MAX_ROW_VECTOR_DIM); + let mat_num_elems = (mat_num_rows * mat_num_cols) as usize; let row_vector = Matrix::generate_from_seed(vec_num_rows, vec_num_cols, &seed).expect("Row vector must be generated from seed"); - let all_ones = Matrix::from_values(mat_num_rows, mat_num_cols, vec![1; mat_num_rows * mat_num_cols]).expect("Matrix of ones must be created"); + let all_ones = Matrix::from_values(mat_num_rows, mat_num_cols, vec![1; mat_num_elems]).expect("Matrix of ones must be created"); let transposed_all_ones = all_ones.transpose(); let res_row_vector = row_vector @@ -945,7 +946,9 @@ pub mod test { let expected_res_row_vector = { let sum_of_elems_in_row_vector = row_vector.elems.iter().fold(0u32, |acc, &cur| acc.wrapping_add(cur)); - Matrix::from_values(vec_num_rows, mat_num_cols, vec![sum_of_elems_in_row_vector; mat_num_cols]).expect("Expected row vector must be created") + let row_vec_elems = vec![sum_of_elems_in_row_vector; mat_num_cols as usize]; + + Matrix::from_values(vec_num_rows, mat_num_cols, row_vec_elems).expect("Expected row vector must be created") }; assert_eq!(expected_res_row_vector, res_row_vector); @@ -956,8 +959,8 @@ pub mod test { #[test] fn matrix_addition_is_correct() { const NUM_ATTEMPT_MATRIX_ADDITIONS: usize = 100; - const MIN_MATRIX_DIM: usize = 1; - const MAX_MATRIX_DIM: usize = 1024; + const MIN_MATRIX_DIM: u32 = 1; + 
const MAX_MATRIX_DIM: u32 = 1024; let mut rng = ChaCha8Rng::from_os_rng(); @@ -987,7 +990,7 @@ pub mod test { #[test_case(1024, 1024 => matches Err(ChalametPIRError::InvalidDimensionForVector); "Either number of rows or columns must be 1 in vector")] #[test_case(0, 1024 => matches Err(ChalametPIRError::InvalidDimensionForVector); "Number of rows in row vector must be 1")] #[test_case(1024, 0 => matches Err(ChalametPIRError::InvalidDimensionForVector); "Number of columns in column vector must be 1")] - fn sampling_from_uniform_ternary_dist_works(num_rows: usize, num_cols: usize) -> Result { + fn sampling_from_uniform_ternary_dist_works(num_rows: u32, num_cols: u32) -> Result { Matrix::sample_from_uniform_ternary_dist(num_rows, num_cols) } @@ -1012,8 +1015,8 @@ pub mod test { #[test] fn serialized_matrix_can_be_deserialized() { const NUM_ATTEMPT_MATRIX_SERIALIZATIONS: usize = 100; - const MIN_MATRIX_DIM: usize = 1; - const MAX_MATRIX_DIM: usize = 1024; + const MIN_MATRIX_DIM: u32 = 1; + const MAX_MATRIX_DIM: u32 = 1024; let mut rng = ChaCha8Rng::from_os_rng(); diff --git a/src/pir_internals/params.rs b/src/pir_internals/params.rs index f5087d9..f1c4514 100644 --- a/src/pir_internals/params.rs +++ b/src/pir_internals/params.rs @@ -1,5 +1,7 @@ +pub const LWE_DIMENSION: u32 = 1774; + pub const BIT_SECURITY_LEVEL: usize = 128; -pub const LWE_DIMENSION: usize = 1774; -pub const SEED_BYTE_LEN: usize = (2 * BIT_SECURITY_LEVEL) / 8; -pub const HASHED_KEY_BYTE_LEN: usize = (2 * BIT_SECURITY_LEVEL) / 8; +pub const SEED_BYTE_LEN: usize = (2 * BIT_SECURITY_LEVEL) / u8::BITS as usize; +pub const HASHED_KEY_BYTE_LEN: usize = (2 * BIT_SECURITY_LEVEL) / u8::BITS as usize; + pub const SERVER_SETUP_MAX_ATTEMPT_COUNT: usize = 100; diff --git a/src/pir_internals/serialization.rs b/src/pir_internals/serialization.rs index cd02aa9..2686fda 100644 --- a/src/pir_internals/serialization.rs +++ b/src/pir_internals/serialization.rs @@ -19,7 +19,7 @@ use turboshake::TurboShake128; /// /// A 
vector of 32-bit unsigned integers representing the encoded key-value pair. #[inline] -pub fn encode_kv_as_row(key: &[u8], value: &[u8], mat_elem_bit_len: usize, num_cols: usize) -> Vec { +pub fn encode_kv_as_row(key: &[u8], value: &[u8], mat_elem_bit_len: usize, num_cols: u32) -> Vec { let hashed_key = { let mut hasher = TurboShake128::default(); hasher.absorb(key); @@ -31,7 +31,7 @@ pub fn encode_kv_as_row(key: &[u8], value: &[u8], mat_elem_bit_len: usize, num_c hashed_key }; - let mut row = vec![0u32; num_cols]; + let mut row = vec![0u32; num_cols as usize]; let mut row_offset = 0; let mat_elem_mask = (1u64 << mat_elem_bit_len) - 1; @@ -268,8 +268,8 @@ mod test { hashed_key }; - let actual_encoded_kv_len = (hashed_key.len() * 8 + (value.len() + 1) * 8).div_ceil(mat_elem_bit_len); - let max_encoded_kv_len = (hashed_key.len() * 8 + (2 * value.len() + 1) * 8).div_ceil(mat_elem_bit_len); + let actual_encoded_kv_len = (hashed_key.len() * 8 + (value.len() + 1) * 8).div_ceil(mat_elem_bit_len) as u32; + let max_encoded_kv_len = (hashed_key.len() * 8 + (2 * value.len() + 1) * 8).div_ceil(mat_elem_bit_len) as u32; for encoded_kv_len in actual_encoded_kv_len..max_encoded_kv_len { let row = encode_kv_as_row(&key, &value, mat_elem_bit_len, encoded_kv_len); diff --git a/src/server.rs b/src/server.rs index faac70a..d845360 100644 --- a/src/server.rs +++ b/src/server.rs @@ -50,7 +50,7 @@ impl Server { let (parsed_db_mat_d, filter) = Matrix::from_kv_database::(db, mat_elem_bit_len, SERVER_SETUP_MAX_ATTEMPT_COUNT)?; let pub_mat_a_num_rows = LWE_DIMENSION; - let pub_mat_a_num_cols = filter.num_fingerprints; + let pub_mat_a_num_cols = filter.num_fingerprints as u32; let pub_mat_a = unsafe { Matrix::generate_from_seed(pub_mat_a_num_rows, pub_mat_a_num_cols, seed_μ).unwrap_unchecked() }; From cfb91243b997638ba977950bc237c84b2e60f7fd Mon Sep 17 00:00:00 2001 From: Anjan Roy Date: Wed, 19 Mar 2025 17:02:11 +0530 Subject: [PATCH 03/29] Add compute shader for matrix-matrix multiplication 
Shader is taken from https://gist.github.com/itzmeanjan/84613bc7595372c5e6b6c22481d42f9a Signed-off-by: Anjan Roy --- shaders/mat_x_mat.glsl | 48 +++++++++++++++++++++++++++ src/pir_internals/mat_x_mat_shader.rs | 5 +++ src/pir_internals/mod.rs | 3 ++ 3 files changed, 56 insertions(+) create mode 100644 shaders/mat_x_mat.glsl create mode 100644 src/pir_internals/mat_x_mat_shader.rs diff --git a/shaders/mat_x_mat.glsl b/shaders/mat_x_mat.glsl new file mode 100644 index 0000000..98641bf --- /dev/null +++ b/shaders/mat_x_mat.glsl @@ -0,0 +1,48 @@ +#version 460 +#pragma shader_stage(compute) + +layout(local_size_x = 8, local_size_y = 8, local_size_z = 1) in; + +layout(set = 0, binding = 0) buffer readonly MatrixA +{ + uint rows; + uint cols; + uint[] elems; +} matrix_a; + +layout(set = 0, binding = 1) buffer readonly MatrixB +{ + uint rows; + uint cols; + uint[] elems; +} matrix_b; + +layout(set = 0, binding = 2) buffer writeonly MatrixC +{ + uint rows; + uint cols; + uint[] elems; +} matrix_c; + +void +main() +{ + const uint row_idx = gl_GlobalInvocationID.x; + const uint col_idx = gl_GlobalInvocationID.y; + + if (row_idx >= matrix_a.rows || col_idx >= matrix_b.cols) { + return; + } + + if ((row_idx == 0) && (col_idx == 0)) { + matrix_c.rows = matrix_a.rows; + matrix_c.cols = matrix_b.cols; + } + + uint sum = 0; + for (uint i = 0; i < matrix_a.cols; i++) { + sum += matrix_a.elems[row_idx * matrix_a.cols + i] * matrix_b.elems[i * matrix_b.cols + col_idx]; + } + + matrix_c.elems[row_idx * matrix_b.cols + col_idx] = sum; +} diff --git a/src/pir_internals/mat_x_mat_shader.rs b/src/pir_internals/mat_x_mat_shader.rs new file mode 100644 index 0000000..eb2a928 --- /dev/null +++ b/src/pir_internals/mat_x_mat_shader.rs @@ -0,0 +1,5 @@ +vulkano_shaders::shader! 
{ + ty: "compute", + path: "./shaders/mat_x_mat.glsl", + vulkan_version: "1.2", +} diff --git a/src/pir_internals/mod.rs b/src/pir_internals/mod.rs index 24566fd..db19967 100644 --- a/src/pir_internals/mod.rs +++ b/src/pir_internals/mod.rs @@ -4,3 +4,6 @@ pub mod error; pub mod matrix; pub mod params; pub mod serialization; + +#[cfg(feature = "gpu")] +pub mod mat_x_mat_shader; From 4cb8d97111873d68d143a2eef556c64870c2e44b Mon Sep 17 00:00:00 2001 From: Anjan Roy Date: Wed, 19 Mar 2025 17:50:57 +0530 Subject: [PATCH 04/29] Setup a Vulkan device and queue so that commands can be submitted to it Signed-off-by: Anjan Roy --- src/pir_internals/error.rs | 11 +++++++ src/pir_internals/gpu.rs | 63 ++++++++++++++++++++++++++++++++++++++ src/pir_internals/mod.rs | 2 ++ 3 files changed, 76 insertions(+) create mode 100644 src/pir_internals/gpu.rs diff --git a/src/pir_internals/error.rs b/src/pir_internals/error.rs index 00de65c..bd8c9fb 100644 --- a/src/pir_internals/error.rs +++ b/src/pir_internals/error.rs @@ -6,6 +6,12 @@ use std::{error::Error, fmt::Display}; /// It includes errors related to matrix operations, binary fuse filter operations, and PIR operations. 
#[derive(Debug, PartialEq)] pub enum ChalametPIRError { + // GPU + VulkanLibraryNotFound, + VulkanInstanceCreationFailed, + VulkanPhysicalDeviceNotFound, + VulkanDeviceCreationFailed, + // Matrix InvalidMatrixDimension, IncompatibleDimensionForMatrixMultiplication, @@ -36,6 +42,11 @@ pub enum ChalametPIRError { impl Display for ChalametPIRError { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { + Self::VulkanLibraryNotFound => write!(f, "Failed to load the default Vulkan library for the system."), + Self::VulkanInstanceCreationFailed => write!(f, "Failed to create a new instance of Vulkan."), + Self::VulkanPhysicalDeviceNotFound => write!(f, "Failed to find a compatible Vulkan physical device."), + Self::VulkanDeviceCreationFailed => write!(f, "Failed to create a Vulkan device and associated queue"), + Self::InvalidMatrixDimension => write!(f, "The number of rows and columns in the matrix must be non-zero."), Self::IncompatibleDimensionForMatrixMultiplication => write!(f, "The matrix dimensions do not allow multiplication."), Self::IncompatibleDimensionForMatrixAddition => write!(f, "The matrix dimensions do not allow addition."), diff --git a/src/pir_internals/gpu.rs b/src/pir_internals/gpu.rs new file mode 100644 index 0000000..845b8cc --- /dev/null +++ b/src/pir_internals/gpu.rs @@ -0,0 +1,63 @@ +use crate::ChalametPIRError; +use std::sync::Arc; +use vulkano::{ + VulkanLibrary, + device::{Device, DeviceCreateInfo, DeviceExtensions, Queue, QueueCreateInfo, QueueFlags, physical::PhysicalDeviceType}, + instance::{Instance, InstanceCreateFlags, InstanceCreateInfo}, +}; + +pub fn setup_gpu() -> Result<(Arc, Arc), ChalametPIRError> { + let library = VulkanLibrary::new().map_err(|_| ChalametPIRError::VulkanLibraryNotFound)?; + let instance = Instance::new( + library, + InstanceCreateInfo { + flags: InstanceCreateFlags::ENUMERATE_PORTABILITY, + ..Default::default() + }, + ) + .map_err(|_| ChalametPIRError::VulkanInstanceCreationFailed)?; 
+ + let device_extensions = DeviceExtensions { + khr_storage_buffer_storage_class: true, + ..DeviceExtensions::empty() + }; + + let (physical_device, queue_family_index) = instance + .enumerate_physical_devices() + .map_err(|_| ChalametPIRError::VulkanPhysicalDeviceNotFound)? + .filter(|p| p.supported_extensions().contains(&device_extensions)) + .filter_map(|p| { + p.queue_family_properties() + .iter() + .position(|q| q.queue_flags.intersects(QueueFlags::COMPUTE | QueueFlags::TRANSFER)) + .map(|i| (p, i as u32)) + }) + .min_by_key(|(p, _)| match p.properties().device_type { + PhysicalDeviceType::DiscreteGpu => 0, + PhysicalDeviceType::IntegratedGpu => 1, + PhysicalDeviceType::VirtualGpu => 2, + PhysicalDeviceType::Cpu => 3, + PhysicalDeviceType::Other => 4, + _ => 5, + }) + .ok_or(ChalametPIRError::VulkanPhysicalDeviceNotFound)?; + + let (device, queue) = { + let (device, mut queues) = Device::new( + physical_device, + DeviceCreateInfo { + enabled_extensions: device_extensions, + queue_create_infos: vec![QueueCreateInfo { + queue_family_index, + ..Default::default() + }], + ..Default::default() + }, + ) + .map_err(|_| ChalametPIRError::VulkanDeviceCreationFailed)?; + + (device, queues.next().ok_or(ChalametPIRError::VulkanDeviceCreationFailed)?) 
+ }; + + Ok((device, queue)) +} diff --git a/src/pir_internals/mod.rs b/src/pir_internals/mod.rs index db19967..2e50f5a 100644 --- a/src/pir_internals/mod.rs +++ b/src/pir_internals/mod.rs @@ -5,5 +5,7 @@ pub mod matrix; pub mod params; pub mod serialization; +#[cfg(feature = "gpu")] +pub mod gpu; #[cfg(feature = "gpu")] pub mod mat_x_mat_shader; From 4b9bac8be3f503511cfbb5408f56d917fbe2bbe9 Mon Sep 17 00:00:00 2001 From: Anjan Roy Date: Wed, 19 Mar 2025 21:47:04 +0530 Subject: [PATCH 05/29] Setup gpu returns a memory allocator and command buffer allocator too Signed-off-by: Anjan Roy --- src/pir_internals/gpu.rs | 38 ++++++++++++++++++++++---------------- 1 file changed, 22 insertions(+), 16 deletions(-) diff --git a/src/pir_internals/gpu.rs b/src/pir_internals/gpu.rs index 845b8cc..db19b00 100644 --- a/src/pir_internals/gpu.rs +++ b/src/pir_internals/gpu.rs @@ -1,12 +1,16 @@ +use super::matrix::Matrix; use crate::ChalametPIRError; use std::sync::Arc; use vulkano::{ VulkanLibrary, + buffer::{Buffer, BufferCreateInfo, BufferUsage, Subbuffer}, + command_buffer::{AutoCommandBufferBuilder, CopyBufferInfo, PrimaryAutoCommandBuffer, allocator::StandardCommandBufferAllocator}, device::{Device, DeviceCreateInfo, DeviceExtensions, Queue, QueueCreateInfo, QueueFlags, physical::PhysicalDeviceType}, instance::{Instance, InstanceCreateFlags, InstanceCreateInfo}, + memory::allocator::{AllocationCreateInfo, MemoryTypeFilter, StandardMemoryAllocator}, }; -pub fn setup_gpu() -> Result<(Arc, Arc), ChalametPIRError> { +pub fn setup_gpu() -> Result<(Arc, Arc, Arc, Arc), ChalametPIRError> { let library = VulkanLibrary::new().map_err(|_| ChalametPIRError::VulkanLibraryNotFound)?; let instance = Instance::new( library, @@ -42,22 +46,24 @@ pub fn setup_gpu() -> Result<(Arc, Arc), ChalametPIRError> { }) .ok_or(ChalametPIRError::VulkanPhysicalDeviceNotFound)?; - let (device, queue) = { - let (device, mut queues) = Device::new( - physical_device, - DeviceCreateInfo { - enabled_extensions: 
device_extensions, - queue_create_infos: vec![QueueCreateInfo { - queue_family_index, - ..Default::default() - }], + let (device, mut queues) = Device::new( + physical_device, + DeviceCreateInfo { + enabled_extensions: device_extensions, + queue_create_infos: vec![QueueCreateInfo { + queue_family_index, ..Default::default() - }, - ) - .map_err(|_| ChalametPIRError::VulkanDeviceCreationFailed)?; + }], + ..Default::default() + }, + ) + .map_err(|_| ChalametPIRError::VulkanDeviceCreationFailed)?; + let queue = queues.next().ok_or(ChalametPIRError::VulkanDeviceCreationFailed)?; - (device, queues.next().ok_or(ChalametPIRError::VulkanDeviceCreationFailed)?) - }; + let memory_allocator = Arc::new(StandardMemoryAllocator::new_default(device.clone())); + let command_buffer_allocator = Arc::new(StandardCommandBufferAllocator::new(device.clone(), Default::default())); + + Ok((device, queue, memory_allocator, command_buffer_allocator)) +} - Ok((device, queue)) } From 9c2ba00ccdcff762b6ff68ebb43da190f7d39857 Mon Sep 17 00:00:00 2001 From: Anjan Roy Date: Wed, 19 Mar 2025 21:48:54 +0530 Subject: [PATCH 06/29] Given a matrix, returns a buffer with transfer-src flag set Signed-off-by: Anjan Roy --- src/pir_internals/gpu.rs | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/src/pir_internals/gpu.rs b/src/pir_internals/gpu.rs index db19b00..2d356f4 100644 --- a/src/pir_internals/gpu.rs +++ b/src/pir_internals/gpu.rs @@ -66,4 +66,21 @@ pub fn setup_gpu() -> Result<(Arc, Arc, Arc, matrix: Matrix) -> Result, ChalametPIRError> { + let matrix_as_bytes = matrix.to_bytes(); + let buffer = Buffer::from_iter( + memory_allocator.clone(), + BufferCreateInfo { + usage: BufferUsage::TRANSFER_SRC, + ..Default::default() + }, + AllocationCreateInfo { + memory_type_filter: MemoryTypeFilter::HOST_SEQUENTIAL_WRITE | MemoryTypeFilter::PREFER_DEVICE, + ..Default::default() + }, + matrix_as_bytes, + ) + .map_err(|_| ChalametPIRError::VulkanBufferCreationFailed)?; + Ok(buffer) +} } From 
dce7da2a455fb1974ce609f0c85db43b3dc9c54e Mon Sep 17 00:00:00 2001 From: Anjan Roy Date: Wed, 19 Mar 2025 21:50:57 +0530 Subject: [PATCH 07/29] Add error enum for vulkan buffer creation failure Signed-off-by: Anjan Roy --- src/pir_internals/error.rs | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/pir_internals/error.rs b/src/pir_internals/error.rs index bd8c9fb..e207452 100644 --- a/src/pir_internals/error.rs +++ b/src/pir_internals/error.rs @@ -11,6 +11,7 @@ pub enum ChalametPIRError { VulkanInstanceCreationFailed, VulkanPhysicalDeviceNotFound, VulkanDeviceCreationFailed, + VulkanBufferCreationFailed, // Matrix InvalidMatrixDimension, @@ -46,6 +47,7 @@ impl Display for ChalametPIRError { Self::VulkanInstanceCreationFailed => write!(f, "Failed to create a new instance of Vulkan."), Self::VulkanPhysicalDeviceNotFound => write!(f, "Failed to find a compatible Vulkan physical device."), Self::VulkanDeviceCreationFailed => write!(f, "Failed to create a Vulkan device and associated queue"), + Self::VulkanBufferCreationFailed => write!(f, "Failed to create a Vulkan buffer"), Self::InvalidMatrixDimension => write!(f, "The number of rows and columns in the matrix must be non-zero."), Self::IncompatibleDimensionForMatrixMultiplication => write!(f, "The matrix dimensions do not allow multiplication."), From 2b8d84c86d736a594419e4c561e36b32d4615c03 Mon Sep 17 00:00:00 2001 From: Anjan Roy Date: Wed, 19 Mar 2025 22:00:53 +0530 Subject: [PATCH 08/29] Simplify return in matrix to transfer source buffer function Signed-off-by: Anjan Roy --- src/pir_internals/gpu.rs | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/src/pir_internals/gpu.rs b/src/pir_internals/gpu.rs index 2d356f4..49a71a6 100644 --- a/src/pir_internals/gpu.rs +++ b/src/pir_internals/gpu.rs @@ -66,9 +66,9 @@ pub fn setup_gpu() -> Result<(Arc, Arc, Arc, matrix: Matrix) -> Result, ChalametPIRError> { +pub fn matrix_to_src_buffer(memory_allocator: Arc, matrix: Matrix) -> Result, 
ChalametPIRError> { let matrix_as_bytes = matrix.to_bytes(); - let buffer = Buffer::from_iter( + Buffer::from_iter( memory_allocator.clone(), BufferCreateInfo { usage: BufferUsage::TRANSFER_SRC, @@ -80,7 +80,6 @@ pub fn make_src_buffer(memory_allocator: Arc, matrix: M }, matrix_as_bytes, ) - .map_err(|_| ChalametPIRError::VulkanBufferCreationFailed)?; - Ok(buffer) + .map_err(|_| ChalametPIRError::VulkanBufferCreationFailed) } } From 96552dc54dca033c0b3ec64a14cbec8b15d8e907 Mon Sep 17 00:00:00 2001 From: Anjan Roy Date: Wed, 19 Mar 2025 22:02:24 +0530 Subject: [PATCH 09/29] Add function recording Vulkan buffer to buffer data transfer command Signed-off-by: Anjan Roy --- src/pir_internals/error.rs | 2 ++ src/pir_internals/gpu.rs | 9 +++++++++ 2 files changed, 11 insertions(+) diff --git a/src/pir_internals/error.rs b/src/pir_internals/error.rs index e207452..a74c851 100644 --- a/src/pir_internals/error.rs +++ b/src/pir_internals/error.rs @@ -12,6 +12,7 @@ pub enum ChalametPIRError { VulkanPhysicalDeviceNotFound, VulkanDeviceCreationFailed, VulkanBufferCreationFailed, + VulkanTransferCommandRecordFailed, // Matrix InvalidMatrixDimension, @@ -48,6 +49,7 @@ impl Display for ChalametPIRError { Self::VulkanPhysicalDeviceNotFound => write!(f, "Failed to find a compatible Vulkan physical device."), Self::VulkanDeviceCreationFailed => write!(f, "Failed to create a Vulkan device and associated queue"), Self::VulkanBufferCreationFailed => write!(f, "Failed to create a Vulkan buffer"), + Self::VulkanTransferCommandRecordFailed => write!(f, "Failed to record Vulkan buffer to buffer data transfer command"), Self::InvalidMatrixDimension => write!(f, "The number of rows and columns in the matrix must be non-zero."), Self::IncompatibleDimensionForMatrixMultiplication => write!(f, "The matrix dimensions do not allow multiplication."), diff --git a/src/pir_internals/gpu.rs b/src/pir_internals/gpu.rs index 49a71a6..f8c93cc 100644 --- a/src/pir_internals/gpu.rs +++ 
b/src/pir_internals/gpu.rs @@ -82,4 +82,13 @@ pub fn matrix_to_src_buffer(memory_allocator: Arc, matr ) .map_err(|_| ChalametPIRError::VulkanBufferCreationFailed) } + +pub fn record_transfer( + cmd_buf_builder: &mut AutoCommandBufferBuilder, + src: Subbuffer<[u8]>, + dst: Subbuffer<[u8]>, +) -> Result<&mut AutoCommandBufferBuilder, ChalametPIRError> { + cmd_buf_builder + .copy_buffer(CopyBufferInfo::buffers(src, dst)) + .map_err(|_| ChalametPIRError::VulkanTransferCommandRecordFailed) } From 0e219342c58a8df26b1ef593f5b2085708cf1900 Mon Sep 17 00:00:00 2001 From: Anjan Roy Date: Wed, 19 Mar 2025 22:03:37 +0530 Subject: [PATCH 10/29] Make error type more explicit Signed-off-by: Anjan Roy --- src/pir_internals/error.rs | 4 ++-- src/pir_internals/gpu.rs | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/pir_internals/error.rs b/src/pir_internals/error.rs index a74c851..acb7082 100644 --- a/src/pir_internals/error.rs +++ b/src/pir_internals/error.rs @@ -11,7 +11,7 @@ pub enum ChalametPIRError { VulkanInstanceCreationFailed, VulkanPhysicalDeviceNotFound, VulkanDeviceCreationFailed, - VulkanBufferCreationFailed, + VulkanSourceBufferCreationFailed, VulkanTransferCommandRecordFailed, // Matrix @@ -48,7 +48,7 @@ impl Display for ChalametPIRError { Self::VulkanInstanceCreationFailed => write!(f, "Failed to create a new instance of Vulkan."), Self::VulkanPhysicalDeviceNotFound => write!(f, "Failed to find a compatible Vulkan physical device."), Self::VulkanDeviceCreationFailed => write!(f, "Failed to create a Vulkan device and associated queue"), - Self::VulkanBufferCreationFailed => write!(f, "Failed to create a Vulkan buffer"), + Self::VulkanSourceBufferCreationFailed => write!(f, "Failed to create a Vulkan transfer source buffer"), Self::VulkanTransferCommandRecordFailed => write!(f, "Failed to record Vulkan buffer to buffer data transfer command"), Self::InvalidMatrixDimension => write!(f, "The number of rows and columns in the matrix must be 
non-zero."), diff --git a/src/pir_internals/gpu.rs b/src/pir_internals/gpu.rs index f8c93cc..b9c5097 100644 --- a/src/pir_internals/gpu.rs +++ b/src/pir_internals/gpu.rs @@ -80,7 +80,7 @@ pub fn matrix_to_src_buffer(memory_allocator: Arc, matr }, matrix_as_bytes, ) - .map_err(|_| ChalametPIRError::VulkanBufferCreationFailed) + .map_err(|_| ChalametPIRError::VulkanSourceBufferCreationFailed) } pub fn record_transfer( From 1b3c3bc53a8580456af62d67e4e288e335317aca Mon Sep 17 00:00:00 2001 From: Anjan Roy Date: Wed, 19 Mar 2025 22:12:35 +0530 Subject: [PATCH 11/29] Add function to create empty Vulkan storage buffer Signed-off-by: Anjan Roy --- src/pir_internals/error.rs | 2 ++ src/pir_internals/gpu.rs | 16 ++++++++++++++++ 2 files changed, 18 insertions(+) diff --git a/src/pir_internals/error.rs b/src/pir_internals/error.rs index acb7082..c4ca323 100644 --- a/src/pir_internals/error.rs +++ b/src/pir_internals/error.rs @@ -12,6 +12,7 @@ pub enum ChalametPIRError { VulkanPhysicalDeviceNotFound, VulkanDeviceCreationFailed, VulkanSourceBufferCreationFailed, + VulkanEmptyBufferCreationFailed, VulkanTransferCommandRecordFailed, // Matrix @@ -49,6 +50,7 @@ impl Display for ChalametPIRError { Self::VulkanPhysicalDeviceNotFound => write!(f, "Failed to find a compatible Vulkan physical device."), Self::VulkanDeviceCreationFailed => write!(f, "Failed to create a Vulkan device and associated queue"), Self::VulkanSourceBufferCreationFailed => write!(f, "Failed to create a Vulkan transfer source buffer"), + Self::VulkanEmptyBufferCreationFailed => write!(f, "Failed to create an empty Vulkan storage buffer"), Self::VulkanTransferCommandRecordFailed => write!(f, "Failed to record Vulkan buffer to buffer data transfer command"), Self::InvalidMatrixDimension => write!(f, "The number of rows and columns in the matrix must be non-zero."), diff --git a/src/pir_internals/gpu.rs b/src/pir_internals/gpu.rs index b9c5097..9517881 100644 --- a/src/pir_internals/gpu.rs +++ 
b/src/pir_internals/gpu.rs @@ -83,6 +83,22 @@ pub fn matrix_to_src_buffer(memory_allocator: Arc, matr .map_err(|_| ChalametPIRError::VulkanSourceBufferCreationFailed) } +pub fn get_empty_storage_buffer(memory_allocator: Arc, byte_len: u64) -> Result, ChalametPIRError> { + Buffer::new_slice::( + memory_allocator.clone(), + BufferCreateInfo { + usage: BufferUsage::STORAGE_BUFFER | BufferUsage::TRANSFER_DST, + ..Default::default() + }, + AllocationCreateInfo { + memory_type_filter: MemoryTypeFilter::PREFER_DEVICE, + ..Default::default() + }, + byte_len, + ) + .map_err(|_| ChalametPIRError::VulkanEmptyBufferCreationFailed) +} + pub fn record_transfer( cmd_buf_builder: &mut AutoCommandBufferBuilder, src: Subbuffer<[u8]>, From e526074d5d8bd8c45b509b2b55258d4d2efcf552 Mon Sep 17 00:00:00 2001 From: Anjan Roy Date: Wed, 19 Mar 2025 23:49:10 +0530 Subject: [PATCH 12/29] Add function to submit transfer command buffer to queue and wait till it finishes Signed-off-by: Anjan Roy --- src/pir_internals/error.rs | 12 ++++++++---- src/pir_internals/gpu.rs | 17 ++++++++++++++++- 2 files changed, 24 insertions(+), 5 deletions(-) diff --git a/src/pir_internals/error.rs b/src/pir_internals/error.rs index c4ca323..6f07049 100644 --- a/src/pir_internals/error.rs +++ b/src/pir_internals/error.rs @@ -14,6 +14,8 @@ pub enum ChalametPIRError { VulkanSourceBufferCreationFailed, VulkanEmptyBufferCreationFailed, VulkanTransferCommandRecordFailed, + VulkanCommandBufferBuildingFailed, + VulkanCommandBufferExecutionFailed, // Matrix InvalidMatrixDimension, @@ -48,10 +50,12 @@ impl Display for ChalametPIRError { Self::VulkanLibraryNotFound => write!(f, "Failed to load the default Vulkan library for the system."), Self::VulkanInstanceCreationFailed => write!(f, "Failed to create a new instance of Vulkan."), Self::VulkanPhysicalDeviceNotFound => write!(f, "Failed to find a compatible Vulkan physical device."), - Self::VulkanDeviceCreationFailed => write!(f, "Failed to create a Vulkan device and 
associated queue"), - Self::VulkanSourceBufferCreationFailed => write!(f, "Failed to create a Vulkan transfer source buffer"), - Self::VulkanEmptyBufferCreationFailed => write!(f, "Failed to create an empty Vulkan storage buffer"), - Self::VulkanTransferCommandRecordFailed => write!(f, "Failed to record Vulkan buffer to buffer data transfer command"), + Self::VulkanDeviceCreationFailed => write!(f, "Failed to create a Vulkan device and associated queue."), + Self::VulkanSourceBufferCreationFailed => write!(f, "Failed to create a Vulkan transfer source buffer."), + Self::VulkanEmptyBufferCreationFailed => write!(f, "Failed to create an empty Vulkan storage buffer."), + Self::VulkanTransferCommandRecordFailed => write!(f, "Failed to record Vulkan buffer to buffer data transfer command."), + Self::VulkanCommandBufferBuildingFailed => write!(f, "Failed to build a Vulkan command buffer."), + Self::VulkanCommandBufferExecutionFailed => write!(f, "Failed to execute the Vulkan command buffer"), Self::InvalidMatrixDimension => write!(f, "The number of rows and columns in the matrix must be non-zero."), Self::IncompatibleDimensionForMatrixMultiplication => write!(f, "The matrix dimensions do not allow multiplication."), diff --git a/src/pir_internals/gpu.rs b/src/pir_internals/gpu.rs index 9517881..e22e207 100644 --- a/src/pir_internals/gpu.rs +++ b/src/pir_internals/gpu.rs @@ -4,10 +4,13 @@ use std::sync::Arc; use vulkano::{ VulkanLibrary, buffer::{Buffer, BufferCreateInfo, BufferUsage, Subbuffer}, - command_buffer::{AutoCommandBufferBuilder, CopyBufferInfo, PrimaryAutoCommandBuffer, allocator::StandardCommandBufferAllocator}, + command_buffer::{ + AutoCommandBufferBuilder, CopyBufferInfo, PrimaryAutoCommandBuffer, PrimaryCommandBufferAbstract, allocator::StandardCommandBufferAllocator, + }, device::{Device, DeviceCreateInfo, DeviceExtensions, Queue, QueueCreateInfo, QueueFlags, physical::PhysicalDeviceType}, instance::{Instance, InstanceCreateFlags, InstanceCreateInfo}, 
memory::allocator::{AllocationCreateInfo, MemoryTypeFilter, StandardMemoryAllocator}, + sync::GpuFuture, }; pub fn setup_gpu() -> Result<(Arc, Arc, Arc, Arc), ChalametPIRError> { @@ -108,3 +111,15 @@ pub fn record_transfer( .copy_buffer(CopyBufferInfo::buffers(src, dst)) .map_err(|_| ChalametPIRError::VulkanTransferCommandRecordFailed) } + +pub fn finish_transfer(cmd_buf_builder: AutoCommandBufferBuilder, queue: Arc) -> Result<(), ChalametPIRError> { + cmd_buf_builder + .build() + .map_err(|_| ChalametPIRError::VulkanCommandBufferBuildingFailed)? + .execute(queue.clone()) + .map_err(|_| ChalametPIRError::VulkanCommandBufferExecutionFailed)? + .then_signal_fence_and_flush() + .map_err(|_| ChalametPIRError::VulkanCommandBufferExecutionFailed)? + .wait(None) + .map_err(|_| ChalametPIRError::VulkanCommandBufferExecutionFailed) +} From db5aca169cf26e17399b91bb71f3eca11233675b Mon Sep 17 00:00:00 2001 From: Anjan Roy Date: Wed, 19 Mar 2025 23:51:44 +0530 Subject: [PATCH 13/29] Rename error enum variant to be more generic Signed-off-by: Anjan Roy --- src/pir_internals/error.rs | 4 ++-- src/pir_internals/gpu.rs | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/pir_internals/error.rs b/src/pir_internals/error.rs index 6f07049..5d29661 100644 --- a/src/pir_internals/error.rs +++ b/src/pir_internals/error.rs @@ -13,7 +13,7 @@ pub enum ChalametPIRError { VulkanDeviceCreationFailed, VulkanSourceBufferCreationFailed, VulkanEmptyBufferCreationFailed, - VulkanTransferCommandRecordFailed, + VulkanCommandBufferRecordingFailed, VulkanCommandBufferBuildingFailed, VulkanCommandBufferExecutionFailed, @@ -53,7 +53,7 @@ impl Display for ChalametPIRError { Self::VulkanDeviceCreationFailed => write!(f, "Failed to create a Vulkan device and associated queue."), Self::VulkanSourceBufferCreationFailed => write!(f, "Failed to create a Vulkan transfer source buffer."), Self::VulkanEmptyBufferCreationFailed => write!(f, "Failed to create an empty Vulkan storage buffer."), - 
Self::VulkanTransferCommandRecordFailed => write!(f, "Failed to record Vulkan buffer to buffer data transfer command."), + Self::VulkanCommandBufferRecordingFailed => write!(f, "Failed to record command in a Vulkan command buffer."), Self::VulkanCommandBufferBuildingFailed => write!(f, "Failed to build a Vulkan command buffer."), Self::VulkanCommandBufferExecutionFailed => write!(f, "Failed to execute the Vulkan command buffer"), diff --git a/src/pir_internals/gpu.rs b/src/pir_internals/gpu.rs index e22e207..f80d5e4 100644 --- a/src/pir_internals/gpu.rs +++ b/src/pir_internals/gpu.rs @@ -109,7 +109,7 @@ pub fn record_transfer( ) -> Result<&mut AutoCommandBufferBuilder, ChalametPIRError> { cmd_buf_builder .copy_buffer(CopyBufferInfo::buffers(src, dst)) - .map_err(|_| ChalametPIRError::VulkanTransferCommandRecordFailed) + .map_err(|_| ChalametPIRError::VulkanCommandBufferRecordingFailed) } pub fn finish_transfer(cmd_buf_builder: AutoCommandBufferBuilder, queue: Arc) -> Result<(), ChalametPIRError> { From 8ff39651c8879dde5287422e5b9afb1024bbbcf5 Mon Sep 17 00:00:00 2001 From: Anjan Roy Date: Thu, 20 Mar 2025 00:04:01 +0530 Subject: [PATCH 14/29] Add function for computing number of bytes required to encode matrix Signed-off-by: Anjan Roy --- src/pir_internals/matrix.rs | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/pir_internals/matrix.rs b/src/pir_internals/matrix.rs index 85e87f5..4f78317 100644 --- a/src/pir_internals/matrix.rs +++ b/src/pir_internals/matrix.rs @@ -84,6 +84,10 @@ impl Matrix { pub fn num_elems(&self) -> usize { self.elems.len() } + #[inline(always)] + pub fn num_bytes(&self) -> usize { + std::mem::size_of_val(&self.rows) + std::mem::size_of_val(&self.cols) + std::mem::size_of::() * (self.rows * self.cols) as usize + } /// Performs the multiplication of a row vector (1xN matrix) by the transpose of a matrix (MxN). 
/// From 9f4e0eaf76e45ece6d04aae7e44244df6f0fdfc9 Mon Sep 17 00:00:00 2001 From: Anjan Roy Date: Thu, 20 Mar 2025 13:02:20 +0530 Subject: [PATCH 15/29] Matrix-matrix multiplication command submission and execution on GPU queue Signed-off-by: Anjan Roy --- src/pir_internals/error.rs | 10 ++++- src/pir_internals/gpu.rs | 79 +++++++++++++++++++++++++++++++++++++- 2 files changed, 86 insertions(+), 3 deletions(-) diff --git a/src/pir_internals/error.rs b/src/pir_internals/error.rs index 5d29661..47b265b 100644 --- a/src/pir_internals/error.rs +++ b/src/pir_internals/error.rs @@ -13,9 +13,13 @@ pub enum ChalametPIRError { VulkanDeviceCreationFailed, VulkanSourceBufferCreationFailed, VulkanEmptyBufferCreationFailed, + VulkanCommandBufferBuilderCreationFailed, VulkanCommandBufferRecordingFailed, VulkanCommandBufferBuildingFailed, VulkanCommandBufferExecutionFailed, + VulkanComputeShaderLoadingFailed, + VulkanComputePipelineCreationFailed, + VulkanDescriptorSetCreationFailed, // Matrix InvalidMatrixDimension, @@ -53,9 +57,13 @@ impl Display for ChalametPIRError { Self::VulkanDeviceCreationFailed => write!(f, "Failed to create a Vulkan device and associated queue."), Self::VulkanSourceBufferCreationFailed => write!(f, "Failed to create a Vulkan transfer source buffer."), Self::VulkanEmptyBufferCreationFailed => write!(f, "Failed to create an empty Vulkan storage buffer."), + Self::VulkanCommandBufferBuilderCreationFailed => write!(f, "Failed to create a Vulkan command buffer builder."), Self::VulkanCommandBufferRecordingFailed => write!(f, "Failed to record command in a Vulkan command buffer."), Self::VulkanCommandBufferBuildingFailed => write!(f, "Failed to build a Vulkan command buffer."), - Self::VulkanCommandBufferExecutionFailed => write!(f, "Failed to execute the Vulkan command buffer"), + Self::VulkanCommandBufferExecutionFailed => write!(f, "Failed to execute the Vulkan command buffer."), + Self::VulkanComputeShaderLoadingFailed => write!(f, "Failed to load Vulkan 
compute shader module."), + Self::VulkanComputePipelineCreationFailed => write!(f, "Failed to create Vulkan compute pipeline."), + Self::VulkanDescriptorSetCreationFailed => write!(f, "Failed to create descriptor set for Vulkan compute pipeline."), Self::InvalidMatrixDimension => write!(f, "The number of rows and columns in the matrix must be non-zero."), Self::IncompatibleDimensionForMatrixMultiplication => write!(f, "The matrix dimensions do not allow multiplication."), diff --git a/src/pir_internals/gpu.rs b/src/pir_internals/gpu.rs index f80d5e4..127a477 100644 --- a/src/pir_internals/gpu.rs +++ b/src/pir_internals/gpu.rs @@ -1,15 +1,21 @@ -use super::matrix::Matrix; +use super::{mat_x_mat_shader, matrix::Matrix}; use crate::ChalametPIRError; use std::sync::Arc; use vulkano::{ VulkanLibrary, buffer::{Buffer, BufferCreateInfo, BufferUsage, Subbuffer}, command_buffer::{ - AutoCommandBufferBuilder, CopyBufferInfo, PrimaryAutoCommandBuffer, PrimaryCommandBufferAbstract, allocator::StandardCommandBufferAllocator, + AutoCommandBufferBuilder, CommandBufferUsage, CopyBufferInfo, PrimaryAutoCommandBuffer, PrimaryCommandBufferAbstract, + allocator::StandardCommandBufferAllocator, }, + descriptor_set::{DescriptorSet, WriteDescriptorSet, allocator::StandardDescriptorSetAllocator}, device::{Device, DeviceCreateInfo, DeviceExtensions, Queue, QueueCreateInfo, QueueFlags, physical::PhysicalDeviceType}, instance::{Instance, InstanceCreateFlags, InstanceCreateInfo}, memory::allocator::{AllocationCreateInfo, MemoryTypeFilter, StandardMemoryAllocator}, + pipeline::{ + ComputePipeline, Pipeline, PipelineBindPoint, PipelineLayout, PipelineShaderStageCreateInfo, compute::ComputePipelineCreateInfo, + layout::PipelineDescriptorSetLayoutCreateInfo, + }, sync::GpuFuture, }; @@ -123,3 +129,72 @@ pub fn finish_transfer(cmd_buf_builder: AutoCommandBufferBuilder, + queue: Arc, + command_buffer_allocator: Arc, + left_mat: Subbuffer<[u8]>, + rhs_mat: Subbuffer<[u8]>, + res_mat: 
Subbuffer<[u8]>, + wg_count: [u32; 3], +) -> Result<(), ChalametPIRError> { + let pipeline = { + let cs = mat_x_mat_shader::load(device.clone()).map_err(|_| ChalametPIRError::VulkanComputeShaderLoadingFailed)?; + let cs_entry_point = cs.entry_point("main").ok_or(ChalametPIRError::VulkanComputeShaderLoadingFailed)?; + let compute_stage = PipelineShaderStageCreateInfo::new(cs_entry_point); + + let layout = PipelineLayout::new( + device.clone(), + PipelineDescriptorSetLayoutCreateInfo::from_stages([&compute_stage]) + .into_pipeline_layout_create_info(device.clone()) + .map_err(|_| ChalametPIRError::VulkanComputePipelineCreationFailed)?, + ) + .map_err(|_| ChalametPIRError::VulkanComputePipelineCreationFailed)?; + + ComputePipeline::new(device.clone(), None, ComputePipelineCreateInfo::stage_layout(compute_stage, layout.clone())) + .map_err(|_| ChalametPIRError::VulkanComputePipelineCreationFailed)? + }; + + let descriptor_set_allocator = Arc::new(StandardDescriptorSetAllocator::new(device.clone(), Default::default())); + let descriptor_set_layout = pipeline.layout().set_layouts()[0].clone(); + let descriptor_set = DescriptorSet::new( + descriptor_set_allocator, + descriptor_set_layout, + [ + WriteDescriptorSet::buffer(0, left_mat), + WriteDescriptorSet::buffer(1, rhs_mat), + WriteDescriptorSet::buffer(2, res_mat), + ], + [], + ) + .map_err(|_| ChalametPIRError::VulkanDescriptorSetCreationFailed)?; + + let command_buffer = { + let mut command_buffer_builder = + AutoCommandBufferBuilder::primary(command_buffer_allocator, queue.queue_family_index(), CommandBufferUsage::OneTimeSubmit) + .map_err(|_| ChalametPIRError::VulkanCommandBufferBuilderCreationFailed)?; + + unsafe { + command_buffer_builder + .bind_pipeline_compute(pipeline.clone()) + .map_err(|_| ChalametPIRError::VulkanCommandBufferRecordingFailed)? + .bind_descriptor_sets(PipelineBindPoint::Compute, pipeline.layout().clone(), 0, descriptor_set) + .map_err(|_| ChalametPIRError::VulkanCommandBufferRecordingFailed)? 
+ .dispatch(wg_count) + .map_err(|_| ChalametPIRError::VulkanCommandBufferRecordingFailed)?; + } + + command_buffer_builder + .build() + .map_err(|_| ChalametPIRError::VulkanCommandBufferBuildingFailed)? + }; + + command_buffer + .execute(queue.clone()) + .map_err(|_| ChalametPIRError::VulkanCommandBufferExecutionFailed)? + .then_signal_fence_and_flush() + .map_err(|_| ChalametPIRError::VulkanCommandBufferExecutionFailed)? + .wait(None) + .map_err(|_| ChalametPIRError::VulkanCommandBufferExecutionFailed) +} From 3d5757b86d160f43e473ec08936aeb9f0f7ce1e0 Mon Sep 17 00:00:00 2001 From: Anjan Roy Date: Thu, 20 Mar 2025 14:41:14 +0530 Subject: [PATCH 16/29] Reformat GLSL compute shader using clang-format Signed-off-by: Anjan Roy --- shaders/mat_x_mat.glsl | 25 ++++++++++++------------- 1 file changed, 12 insertions(+), 13 deletions(-) diff --git a/shaders/mat_x_mat.glsl b/shaders/mat_x_mat.glsl index 98641bf..8dae1fe 100644 --- a/shaders/mat_x_mat.glsl +++ b/shaders/mat_x_mat.glsl @@ -3,30 +3,28 @@ layout(local_size_x = 8, local_size_y = 8, local_size_z = 1) in; -layout(set = 0, binding = 0) buffer readonly MatrixA -{ +layout(set = 0, binding = 0) buffer readonly MatrixA { uint rows; uint cols; uint[] elems; -} matrix_a; +} +matrix_a; -layout(set = 0, binding = 1) buffer readonly MatrixB -{ +layout(set = 0, binding = 1) buffer readonly MatrixB { uint rows; uint cols; uint[] elems; -} matrix_b; +} +matrix_b; -layout(set = 0, binding = 2) buffer writeonly MatrixC -{ +layout(set = 0, binding = 2) buffer writeonly MatrixC { uint rows; uint cols; uint[] elems; -} matrix_c; +} +matrix_c; -void -main() -{ +void main() { const uint row_idx = gl_GlobalInvocationID.x; const uint col_idx = gl_GlobalInvocationID.y; @@ -41,7 +39,8 @@ main() uint sum = 0; for (uint i = 0; i < matrix_a.cols; i++) { - sum += matrix_a.elems[row_idx * matrix_a.cols + i] * matrix_b.elems[i * matrix_b.cols + col_idx]; + sum += matrix_a.elems[row_idx * matrix_a.cols + i] * + matrix_b.elems[i * matrix_b.cols 
+ col_idx]; } matrix_c.elems[row_idx * matrix_b.cols + col_idx] = sum; From 1cc480621f5c2bcbd862ca8df00c7e835ec2e029 Mon Sep 17 00:00:00 2001 From: Anjan Roy Date: Thu, 20 Mar 2025 14:54:27 +0530 Subject: [PATCH 17/29] Add matrix transpose compute shader Signed-off-by: Anjan Roy --- shaders/mat_transpose.glsl | 35 +++++++++++++++++++++++ src/pir_internals/mat_transpose_shader.rs | 5 ++++ src/pir_internals/mod.rs | 2 ++ 3 files changed, 42 insertions(+) create mode 100644 shaders/mat_transpose.glsl create mode 100644 src/pir_internals/mat_transpose_shader.rs diff --git a/shaders/mat_transpose.glsl b/shaders/mat_transpose.glsl new file mode 100644 index 0000000..11bfc2e --- /dev/null +++ b/shaders/mat_transpose.glsl @@ -0,0 +1,35 @@ +#version 460 +#pragma shader_stage(compute) + +layout(local_size_x = 8, local_size_y = 8, local_size_z = 1) in; + +layout(set = 0, binding = 0) buffer readonly MatrixA { + uint rows; + uint cols; + uint[] elems; +} +matrix_a; + +layout(set = 0, binding = 1) buffer writeonly MatrixB { + uint rows; + uint cols; + uint[] elems; +} +matrix_b; + +void main() { + const uint row_idx = gl_GlobalInvocationID.x; + const uint col_idx = gl_GlobalInvocationID.y; + + if (row_idx >= matrix_a.cols || col_idx >= matrix_a.rows) { + return; + } + + if ((row_idx == 0) && (col_idx == 0)) { + matrix_b.rows = matrix_a.cols; + matrix_b.cols = matrix_a.rows; + } + + matrix_b.elems[row_idx * matrix_a.rows + col_idx] = + matrix_a.elems[row_idx * matrix_a.cols + col_idx]; +} diff --git a/src/pir_internals/mat_transpose_shader.rs b/src/pir_internals/mat_transpose_shader.rs new file mode 100644 index 0000000..d49b087 --- /dev/null +++ b/src/pir_internals/mat_transpose_shader.rs @@ -0,0 +1,5 @@ +vulkano_shaders::shader! 
{ + ty: "compute", + path: "./shaders/mat_transpose.glsl", + vulkan_version: "1.2", +} diff --git a/src/pir_internals/mod.rs b/src/pir_internals/mod.rs index 2e50f5a..dae466a 100644 --- a/src/pir_internals/mod.rs +++ b/src/pir_internals/mod.rs @@ -9,3 +9,5 @@ pub mod serialization; pub mod gpu; #[cfg(feature = "gpu")] pub mod mat_x_mat_shader; +#[cfg(feature = "gpu")] +pub mod mat_transpose_shader; From 98a074688f43b815b2f59014fda88fe08cbf7bf0 Mon Sep 17 00:00:00 2001 From: Anjan Roy Date: Thu, 20 Mar 2025 15:27:18 +0530 Subject: [PATCH 18/29] Submit and wait for matrix transpose job to finish on GPU Signed-off-by: Anjan Roy --- src/pir_internals/gpu.rs | 66 +++++++++++++++++++++++++++++++++++++++- 1 file changed, 65 insertions(+), 1 deletion(-) diff --git a/src/pir_internals/gpu.rs b/src/pir_internals/gpu.rs index 127a477..33fdb9d 100644 --- a/src/pir_internals/gpu.rs +++ b/src/pir_internals/gpu.rs @@ -1,4 +1,4 @@ -use super::{mat_x_mat_shader, matrix::Matrix}; +use super::{mat_transpose_shader, mat_x_mat_shader, matrix::Matrix}; use crate::ChalametPIRError; use std::sync::Arc; use vulkano::{ @@ -198,3 +198,67 @@ pub fn mat_x_mat( .wait(None) .map_err(|_| ChalametPIRError::VulkanCommandBufferExecutionFailed) } + +pub fn mat_transpose( + device: Arc, + queue: Arc, + command_buffer_allocator: Arc, + orig_mat: Subbuffer<[u8]>, + res_mat: Subbuffer<[u8]>, + wg_count: [u32; 3], +) -> Result<(), ChalametPIRError> { + let pipeline = { + let cs = mat_transpose_shader::load(device.clone()).map_err(|_| ChalametPIRError::VulkanComputeShaderLoadingFailed)?; + let cs_entry_point = cs.entry_point("main").ok_or(ChalametPIRError::VulkanComputeShaderLoadingFailed)?; + let compute_stage = PipelineShaderStageCreateInfo::new(cs_entry_point); + + let layout = PipelineLayout::new( + device.clone(), + PipelineDescriptorSetLayoutCreateInfo::from_stages([&compute_stage]) + .into_pipeline_layout_create_info(device.clone()) + .map_err(|_| 
ChalametPIRError::VulkanComputePipelineCreationFailed)?, + ) + .map_err(|_| ChalametPIRError::VulkanComputePipelineCreationFailed)?; + + ComputePipeline::new(device.clone(), None, ComputePipelineCreateInfo::stage_layout(compute_stage, layout.clone())) + .map_err(|_| ChalametPIRError::VulkanComputePipelineCreationFailed)? + }; + + let descriptor_set_allocator = Arc::new(StandardDescriptorSetAllocator::new(device.clone(), Default::default())); + let descriptor_set_layout = pipeline.layout().set_layouts()[0].clone(); + let descriptor_set = DescriptorSet::new( + descriptor_set_allocator, + descriptor_set_layout, + [WriteDescriptorSet::buffer(0, orig_mat), WriteDescriptorSet::buffer(1, res_mat)], + [], + ) + .map_err(|_| ChalametPIRError::VulkanDescriptorSetCreationFailed)?; + + let command_buffer = { + let mut command_buffer_builder = + AutoCommandBufferBuilder::primary(command_buffer_allocator, queue.queue_family_index(), CommandBufferUsage::OneTimeSubmit) + .map_err(|_| ChalametPIRError::VulkanCommandBufferBuilderCreationFailed)?; + + unsafe { + command_buffer_builder + .bind_pipeline_compute(pipeline.clone()) + .map_err(|_| ChalametPIRError::VulkanCommandBufferRecordingFailed)? + .bind_descriptor_sets(PipelineBindPoint::Compute, pipeline.layout().clone(), 0, descriptor_set) + .map_err(|_| ChalametPIRError::VulkanCommandBufferRecordingFailed)? + .dispatch(wg_count) + .map_err(|_| ChalametPIRError::VulkanCommandBufferRecordingFailed)?; + } + + command_buffer_builder + .build() + .map_err(|_| ChalametPIRError::VulkanCommandBufferBuildingFailed)? + }; + + command_buffer + .execute(queue.clone()) + .map_err(|_| ChalametPIRError::VulkanCommandBufferExecutionFailed)? + .then_signal_fence_and_flush() + .map_err(|_| ChalametPIRError::VulkanCommandBufferExecutionFailed)? 
+ .wait(None) + .map_err(|_| ChalametPIRError::VulkanCommandBufferExecutionFailed) +} From 679bc17d8ba182f63cffc3f3c9290d5101d11266 Mon Sep 17 00:00:00 2001 From: Anjan Roy Date: Thu, 20 Mar 2025 17:57:39 +0530 Subject: [PATCH 19/29] Fix matrix transpose shader Signed-off-by: Anjan Roy --- shaders/mat_transpose.glsl | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/shaders/mat_transpose.glsl b/shaders/mat_transpose.glsl index 11bfc2e..48e1b4d 100644 --- a/shaders/mat_transpose.glsl +++ b/shaders/mat_transpose.glsl @@ -21,7 +21,7 @@ void main() { const uint row_idx = gl_GlobalInvocationID.x; const uint col_idx = gl_GlobalInvocationID.y; - if (row_idx >= matrix_a.cols || col_idx >= matrix_a.rows) { + if (row_idx >= matrix_a.rows || col_idx >= matrix_a.cols) { return; } @@ -30,6 +30,8 @@ void main() { matrix_b.cols = matrix_a.rows; } - matrix_b.elems[row_idx * matrix_a.rows + col_idx] = - matrix_a.elems[row_idx * matrix_a.cols + col_idx]; + const uint src_index = row_idx * matrix_a.cols + col_idx; + const uint dst_index = col_idx * matrix_a.rows + row_idx; + + matrix_b.elems[dst_index] = matrix_a.elems[src_index]; } From 3f33f81491080ac4794c01dab75b31ad7142550c Mon Sep 17 00:00:00 2001 From: Anjan Roy Date: Thu, 20 Mar 2025 18:47:41 +0530 Subject: [PATCH 20/29] Refactor function for transferring host matrix to device Signed-off-by: Anjan Roy --- src/pir_internals/gpu.rs | 72 +++++++++++++++++++++++++--------------- 1 file changed, 46 insertions(+), 26 deletions(-) diff --git a/src/pir_internals/gpu.rs b/src/pir_internals/gpu.rs index 33fdb9d..92aaf1e 100644 --- a/src/pir_internals/gpu.rs +++ b/src/pir_internals/gpu.rs @@ -75,10 +75,17 @@ pub fn setup_gpu() -> Result<(Arc, Arc, Arc, matrix: Matrix) -> Result, ChalametPIRError> { +pub fn transfer_mat_to_device( + queue: Arc, + mem_alloc: Arc, + cmd_buf_alloc: Arc, + matrix: Matrix, +) -> Result, ChalametPIRError> { let matrix_as_bytes = matrix.to_bytes(); - Buffer::from_iter( - 
memory_allocator.clone(), + let matrix_byte_len = matrix_as_bytes.len() as u64; + + let src_buf = Buffer::from_iter( + mem_alloc.clone(), BufferCreateInfo { usage: BufferUsage::TRANSFER_SRC, ..Default::default() @@ -89,12 +96,10 @@ pub fn matrix_to_src_buffer(memory_allocator: Arc, matr }, matrix_as_bytes, ) - .map_err(|_| ChalametPIRError::VulkanSourceBufferCreationFailed) -} + .map_err(|_| ChalametPIRError::VulkanSourceBufferCreationFailed)?; -pub fn get_empty_storage_buffer(memory_allocator: Arc, byte_len: u64) -> Result, ChalametPIRError> { - Buffer::new_slice::( - memory_allocator.clone(), + let dst_buf = Buffer::new_slice::( + mem_alloc.clone(), BufferCreateInfo { usage: BufferUsage::STORAGE_BUFFER | BufferUsage::TRANSFER_DST, ..Default::default() @@ -103,31 +108,46 @@ pub fn get_empty_storage_buffer(memory_allocator: Arc, memory_type_filter: MemoryTypeFilter::PREFER_DEVICE, ..Default::default() }, - byte_len, + matrix_byte_len, ) - .map_err(|_| ChalametPIRError::VulkanEmptyBufferCreationFailed) -} + .map_err(|_| ChalametPIRError::VulkanEmptyBufferCreationFailed)?; -pub fn record_transfer( - cmd_buf_builder: &mut AutoCommandBufferBuilder, - src: Subbuffer<[u8]>, - dst: Subbuffer<[u8]>, -) -> Result<&mut AutoCommandBufferBuilder, ChalametPIRError> { - cmd_buf_builder - .copy_buffer(CopyBufferInfo::buffers(src, dst)) - .map_err(|_| ChalametPIRError::VulkanCommandBufferRecordingFailed) -} + let cmd_buf = { + let mut builder = AutoCommandBufferBuilder::primary(cmd_buf_alloc, queue.queue_family_index(), CommandBufferUsage::OneTimeSubmit) + .map_err(|_| ChalametPIRError::VulkanCommandBufferBuilderCreationFailed)?; -pub fn finish_transfer(cmd_buf_builder: AutoCommandBufferBuilder, queue: Arc) -> Result<(), ChalametPIRError> { - cmd_buf_builder - .build() - .map_err(|_| ChalametPIRError::VulkanCommandBufferBuildingFailed)? 
- .execute(queue.clone()) + builder + .copy_buffer(CopyBufferInfo::buffers(src_buf, dst_buf.clone())) + .map_err(|_| ChalametPIRError::VulkanCommandBufferRecordingFailed)?; + + builder.build().map_err(|_| ChalametPIRError::VulkanCommandBufferBuildingFailed)? + }; + + cmd_buf + .execute(queue) .map_err(|_| ChalametPIRError::VulkanCommandBufferExecutionFailed)? .then_signal_fence_and_flush() .map_err(|_| ChalametPIRError::VulkanCommandBufferExecutionFailed)? .wait(None) - .map_err(|_| ChalametPIRError::VulkanCommandBufferExecutionFailed) + .map_err(|_| ChalametPIRError::VulkanCommandBufferExecutionFailed)?; + + Ok(dst_buf) +} + +pub fn get_empty_storage_buffer(memory_allocator: Arc, byte_len: u64) -> Result, ChalametPIRError> { + Buffer::new_slice::( + memory_allocator.clone(), + BufferCreateInfo { + usage: BufferUsage::STORAGE_BUFFER, + ..Default::default() + }, + AllocationCreateInfo { + memory_type_filter: MemoryTypeFilter::HOST_RANDOM_ACCESS | MemoryTypeFilter::PREFER_DEVICE, + ..Default::default() + }, + byte_len, + ) + .map_err(|_| ChalametPIRError::VulkanEmptyBufferCreationFailed) } pub fn mat_x_mat( From 9b50f41c838b064e1b5453d80ae5677ce534fba4 Mon Sep 17 00:00:00 2001 From: Anjan Roy Date: Thu, 20 Mar 2025 20:00:45 +0530 Subject: [PATCH 21/29] Maintain two different functions for host-accessible and device-local buffer creation Signed-off-by: Anjan Roy --- src/pir_internals/error.rs | 8 ++++---- src/pir_internals/gpu.rs | 31 ++++++++++++++++++++++--------- 2 files changed, 26 insertions(+), 13 deletions(-) diff --git a/src/pir_internals/error.rs b/src/pir_internals/error.rs index 47b265b..19cfbaa 100644 --- a/src/pir_internals/error.rs +++ b/src/pir_internals/error.rs @@ -11,12 +11,12 @@ pub enum ChalametPIRError { VulkanInstanceCreationFailed, VulkanPhysicalDeviceNotFound, VulkanDeviceCreationFailed, - VulkanSourceBufferCreationFailed, - VulkanEmptyBufferCreationFailed, + VulkanBufferCreationFailed, VulkanCommandBufferBuilderCreationFailed, 
VulkanCommandBufferRecordingFailed,
     VulkanCommandBufferBuildingFailed,
     VulkanCommandBufferExecutionFailed,
+    VulkanReadingFromBufferFailed,
     VulkanComputeShaderLoadingFailed,
     VulkanComputePipelineCreationFailed,
     VulkanDescriptorSetCreationFailed,
@@ -55,12 +55,12 @@ impl Display for ChalametPIRError {
             Self::VulkanInstanceCreationFailed => write!(f, "Failed to create a new instance of Vulkan."),
             Self::VulkanPhysicalDeviceNotFound => write!(f, "Failed to find a compatible Vulkan physical device."),
             Self::VulkanDeviceCreationFailed => write!(f, "Failed to create a Vulkan device and associated queue."),
-            Self::VulkanSourceBufferCreationFailed => write!(f, "Failed to create a Vulkan transfer source buffer."),
-            Self::VulkanEmptyBufferCreationFailed => write!(f, "Failed to create an empty Vulkan storage buffer."),
+            Self::VulkanBufferCreationFailed => write!(f, "Failed to create a Vulkan transfer source buffer."),
             Self::VulkanCommandBufferBuilderCreationFailed => write!(f, "Failed to create a Vulkan command buffer builder."),
             Self::VulkanCommandBufferRecordingFailed => write!(f, "Failed to record command in a Vulkan command buffer."),
             Self::VulkanCommandBufferBuildingFailed => write!(f, "Failed to build a Vulkan command buffer."),
             Self::VulkanCommandBufferExecutionFailed => write!(f, "Failed to execute the Vulkan command buffer."),
+            Self::VulkanReadingFromBufferFailed => write!(f, "Failed to read from Vulkan buffer."),
             Self::VulkanComputeShaderLoadingFailed => write!(f, "Failed to load Vulkan compute shader module."),
             Self::VulkanComputePipelineCreationFailed => write!(f, "Failed to create Vulkan compute pipeline."),
             Self::VulkanDescriptorSetCreationFailed => write!(f, "Failed to create descriptor set for Vulkan compute pipeline."),
diff --git a/src/pir_internals/gpu.rs b/src/pir_internals/gpu.rs
index 92aaf1e..15499e8 100644
--- a/src/pir_internals/gpu.rs
+++ b/src/pir_internals/gpu.rs
@@ -3,11 +3,8 @@ use crate::ChalametPIRError;
 use std::sync::Arc;
 use vulkano::{
VulkanLibrary, - buffer::{Buffer, BufferCreateInfo, BufferUsage, Subbuffer}, - command_buffer::{ - AutoCommandBufferBuilder, CommandBufferUsage, CopyBufferInfo, PrimaryAutoCommandBuffer, PrimaryCommandBufferAbstract, - allocator::StandardCommandBufferAllocator, - }, + buffer::{Buffer, BufferCreateFlags, BufferCreateInfo, BufferUsage, Subbuffer}, + command_buffer::{AutoCommandBufferBuilder, CommandBufferUsage, CopyBufferInfo, PrimaryCommandBufferAbstract, allocator::StandardCommandBufferAllocator}, descriptor_set::{DescriptorSet, WriteDescriptorSet, allocator::StandardDescriptorSetAllocator}, device::{Device, DeviceCreateInfo, DeviceExtensions, Queue, QueueCreateInfo, QueueFlags, physical::PhysicalDeviceType}, instance::{Instance, InstanceCreateFlags, InstanceCreateInfo}, @@ -96,7 +93,7 @@ pub fn transfer_mat_to_device( }, matrix_as_bytes, ) - .map_err(|_| ChalametPIRError::VulkanSourceBufferCreationFailed)?; + .map_err(|_| ChalametPIRError::VulkanBufferCreationFailed)?; let dst_buf = Buffer::new_slice::( mem_alloc.clone(), @@ -110,7 +107,7 @@ pub fn transfer_mat_to_device( }, matrix_byte_len, ) - .map_err(|_| ChalametPIRError::VulkanEmptyBufferCreationFailed)?; + .map_err(|_| ChalametPIRError::VulkanBufferCreationFailed)?; let cmd_buf = { let mut builder = AutoCommandBufferBuilder::primary(cmd_buf_alloc, queue.queue_family_index(), CommandBufferUsage::OneTimeSubmit) @@ -134,7 +131,7 @@ pub fn transfer_mat_to_device( Ok(dst_buf) } -pub fn get_empty_storage_buffer(memory_allocator: Arc, byte_len: u64) -> Result, ChalametPIRError> { +pub fn get_empty_host_readable_buffer(memory_allocator: Arc, byte_len: u64) -> Result, ChalametPIRError> { Buffer::new_slice::( memory_allocator.clone(), BufferCreateInfo { @@ -147,7 +144,23 @@ pub fn get_empty_storage_buffer(memory_allocator: Arc, }, byte_len, ) - .map_err(|_| ChalametPIRError::VulkanEmptyBufferCreationFailed) + .map_err(|_| ChalametPIRError::VulkanBufferCreationFailed) +} + +pub fn 
get_empty_device_local_buffer(memory_allocator: Arc, byte_len: u64) -> Result, ChalametPIRError> { + Buffer::new_slice::( + memory_allocator.clone(), + BufferCreateInfo { + usage: BufferUsage::STORAGE_BUFFER, + ..Default::default() + }, + AllocationCreateInfo { + memory_type_filter: MemoryTypeFilter::PREFER_DEVICE, + ..Default::default() + }, + byte_len, + ) + .map_err(|_| ChalametPIRError::VulkanBufferCreationFailed) } pub fn mat_x_mat( From 450d7dc44bb09020796b42d0d418581d585cc585 Mon Sep 17 00:00:00 2001 From: Anjan Roy Date: Thu, 20 Mar 2025 21:12:57 +0530 Subject: [PATCH 22/29] Implementation server-setup phase for `gpu` feature Signed-off-by: Anjan Roy --- src/pir_internals/gpu.rs | 2 +- src/pir_internals/mod.rs | 4 +- src/server.rs | 83 ++++++++++++++++++++++++++++++++++++++++ 3 files changed, 86 insertions(+), 3 deletions(-) diff --git a/src/pir_internals/gpu.rs b/src/pir_internals/gpu.rs index 15499e8..d8d8bcf 100644 --- a/src/pir_internals/gpu.rs +++ b/src/pir_internals/gpu.rs @@ -3,7 +3,7 @@ use crate::ChalametPIRError; use std::sync::Arc; use vulkano::{ VulkanLibrary, - buffer::{Buffer, BufferCreateFlags, BufferCreateInfo, BufferUsage, Subbuffer}, + buffer::{Buffer, BufferCreateInfo, BufferUsage, Subbuffer}, command_buffer::{AutoCommandBufferBuilder, CommandBufferUsage, CopyBufferInfo, PrimaryCommandBufferAbstract, allocator::StandardCommandBufferAllocator}, descriptor_set::{DescriptorSet, WriteDescriptorSet, allocator::StandardDescriptorSetAllocator}, device::{Device, DeviceCreateInfo, DeviceExtensions, Queue, QueueCreateInfo, QueueFlags, physical::PhysicalDeviceType}, diff --git a/src/pir_internals/mod.rs b/src/pir_internals/mod.rs index dae466a..edad805 100644 --- a/src/pir_internals/mod.rs +++ b/src/pir_internals/mod.rs @@ -8,6 +8,6 @@ pub mod serialization; #[cfg(feature = "gpu")] pub mod gpu; #[cfg(feature = "gpu")] -pub mod mat_x_mat_shader; -#[cfg(feature = "gpu")] pub mod mat_transpose_shader; +#[cfg(feature = "gpu")] +pub mod mat_x_mat_shader; 
diff --git a/src/server.rs b/src/server.rs index d845360..1d15781 100644 --- a/src/server.rs +++ b/src/server.rs @@ -1,3 +1,5 @@ +#[cfg(feature = "gpu")] +use crate::pir_internals::gpu; use crate::{ ChalametPIRError, pir_internals::{ @@ -40,6 +42,7 @@ impl Server { /// # Returns /// /// A `Result` containing a tuple of the `Server` object, the serialized hint matrix bytes, and the serialized filter parameters bytes. Returns an error if any error occurs during setup. + #[cfg(not(feature = "gpu"))] pub fn setup(seed_μ: &[u8; SEED_BYTE_LEN], db: HashMap<&[u8], &[u8]>) -> Result<(Server, Vec, Vec), ChalametPIRError> { let db_num_kv_pairs = db.len(); if branch_opt_util::unlikely(db_num_kv_pairs == 0) { @@ -62,6 +65,86 @@ impl Server { Ok((Server { transposed_parsed_db_mat_d }, hint_bytes, filter_param_bytes)) } + /// TODO: Update following documentation before publishing. + /// + /// Sets up the keyword **P**rivate **I**nformation **R**etrieval scheme's server with a given Key-Value database. + /// + /// This function takes a database as input and generates the necessary matrices and parameters for responding to client queries. + /// It involves several steps: + /// 1. **Database Validation:** The database must not be empty and should have at most 242 entries. Returns an error if validation fails. + /// 2. **Matrix Generation from Database:** Creates a `Matrix` (`parsed_db_mat_d`) representing the database. Uses the `Matrix::from_kv_database` function, which might involve multiple attempts (`SERVER_SETUP_MAX_ATTEMPT_COUNT`) to generate a suitable matrix. Returns an error if matrix generation fails. This also generates a `filter` object used in later stages of the PIR protocol. + /// 3. **Public Matrix Generation:** Generates a public matrix (`pub_mat_a`) using a provided seed (`seed_μ`). The dimensions of this matrix are determined by `LWE_DIMENSION` and the number of fingerprints in the `filter`. + /// 4. 
**Hint Matrix Calculation:** Computes the hint matrix (`hint_mat_m`) by multiplying the public matrix and the parsed database matrix. + /// 5. **Serialization:** Converts the hint matrix and filter parameters into byte vectors for storage and transmission. Returns an error if conversion fails. + /// 6. **Transposition:** Transposes the parsed database matrix (`parsed_db_mat_d`) to optimize memory access patterns during execution of the `respond` function. + /// + /// # Arguments + /// + /// * `seed_μ`: The seed used for generating the public matrix. + /// * `db`: The input database, represented as a hash map of key-value pairs. + /// + /// The constant parameter `ARITY` can be 3 or 4, denoting the use of a 3/4-wise XOR binary fuse filter. + /// This choice affects client/server computation and communication costs. + /// + /// # Returns + /// + /// A `Result` containing a tuple of the `Server` object, the serialized hint matrix bytes, and the serialized filter parameters bytes. Returns an error if any error occurs during setup. 
+ #[cfg(feature = "gpu")] + pub fn setup(seed_μ: &[u8; SEED_BYTE_LEN], db: HashMap<&[u8], &[u8]>) -> Result<(Server, Vec, Vec), ChalametPIRError> { + let db_num_kv_pairs = db.len(); + if branch_opt_util::unlikely(db_num_kv_pairs == 0) { + return Err(ChalametPIRError::EmptyKVDatabase); + } + + let mat_elem_bit_len = Self::find_encoded_db_matrix_element_bit_length(db_num_kv_pairs)?; + let (parsed_db_mat_d, filter) = Matrix::from_kv_database::(db, mat_elem_bit_len, SERVER_SETUP_MAX_ATTEMPT_COUNT)?; + + let pub_mat_a_num_rows = LWE_DIMENSION; + let pub_mat_a_num_cols = filter.num_fingerprints as u32; + + let pub_mat_a = unsafe { Matrix::generate_from_seed(pub_mat_a_num_rows, pub_mat_a_num_cols, seed_μ).unwrap_unchecked() }; + + let (device, queue, mem_alloc, cmd_buf_alloc) = gpu::setup_gpu()?; + + let hint_mat_m_num_rows = pub_mat_a_num_rows; + let hint_mat_m_num_cols = parsed_db_mat_d.num_cols(); + let hint_mat_m_byte_len = (2 * std::mem::size_of::() + (hint_mat_m_num_rows * hint_mat_m_num_cols) as usize * std::mem::size_of::()) as u64; + let hint_mat_m_wg_count = [hint_mat_m_num_rows.div_ceil(8), hint_mat_m_num_cols.div_ceil(8), 1]; + + let parsed_db_mat_d_byte_len = parsed_db_mat_d.num_bytes() as u64; + let parsed_db_mat_d_wg_count = [parsed_db_mat_d.num_rows().div_ceil(8), parsed_db_mat_d.num_cols().div_ceil(8), 1]; + + let pub_mat_a_buf = gpu::transfer_mat_to_device(queue.clone(), mem_alloc.clone(), cmd_buf_alloc.clone(), pub_mat_a)?; + let parsed_db_mat_d_buf = gpu::transfer_mat_to_device(queue.clone(), mem_alloc.clone(), cmd_buf_alloc.clone(), parsed_db_mat_d.clone())?; + let hint_mat_m_buf = gpu::get_empty_host_readable_buffer(mem_alloc.clone(), hint_mat_m_byte_len)?; + let transposed_parsed_db_mat_d_buf = gpu::get_empty_device_local_buffer(mem_alloc.clone(), parsed_db_mat_d_byte_len)?; + + gpu::mat_x_mat( + device.clone(), + queue.clone(), + cmd_buf_alloc.clone(), + pub_mat_a_buf, + parsed_db_mat_d_buf.clone(), + hint_mat_m_buf.clone(), + hint_mat_m_wg_count, 
+ )?; + + gpu::mat_transpose( + device.clone(), + queue.clone(), + cmd_buf_alloc.clone(), + parsed_db_mat_d_buf, + transposed_parsed_db_mat_d_buf.clone(), + parsed_db_mat_d_wg_count, + )?; + + let hint_bytes = hint_mat_m_buf.read().map_err(|_| ChalametPIRError::VulkanReadingFromBufferFailed)?.to_vec(); + let filter_param_bytes: Vec = filter.to_bytes(); + let transposed_parsed_db_mat_d = parsed_db_mat_d.transpose(); + + Ok((Server { transposed_parsed_db_mat_d }, hint_bytes, filter_param_bytes)) + } + /// Responds to a client query. /// /// This function takes a client's query (in byte form) as input and uses the transposed database matrix to compute the response. From 3be9c225621437cc7e0c9340ec6df3b20d2f8acf Mon Sep 17 00:00:00 2001 From: Anjan Roy Date: Tue, 1 Apr 2025 22:51:15 +0530 Subject: [PATCH 23/29] Add row-vector transposed matrix multiplication compute shader Signed-off-by: Anjan Roy --- shaders/vec_x_mat.glsl | 47 +++++++++++++++++++++++++++ src/pir_internals/mod.rs | 2 ++ src/pir_internals/vec_x_mat_shader.rs | 5 +++ 3 files changed, 54 insertions(+) create mode 100644 shaders/vec_x_mat.glsl create mode 100644 src/pir_internals/vec_x_mat_shader.rs diff --git a/shaders/vec_x_mat.glsl b/shaders/vec_x_mat.glsl new file mode 100644 index 0000000..6e29e1d --- /dev/null +++ b/shaders/vec_x_mat.glsl @@ -0,0 +1,47 @@ +#version 460 +#pragma shader_stage(compute) + +layout(local_size_x = 1, local_size_y = 32, local_size_z = 1) in; + +layout(set = 0, binding = 0) buffer readonly MatrixA { + uint rows; + uint cols; + uint[] elems; +} +lhs_vec; + +layout(set = 0, binding = 1) buffer readonly MatrixB { + uint rows; + uint cols; + uint[] elems; +} +rhs_trans_mat; + +layout(set = 0, binding = 2) buffer writeonly MatrixC { + uint rows; + uint cols; + uint[] elems; +} +res_vec; + +void main() { + const uint row_idx = gl_GlobalInvocationID.x; + const uint col_idx = gl_GlobalInvocationID.y; + + if (row_idx >= lhs_vec.rows || col_idx >= rhs_trans_mat.rows) { + return; + } + 
+ if ((row_idx == 0) && (col_idx == 0)) { + res_vec.rows = lhs_vec.rows; + res_vec.cols = rhs_trans_mat.rows; + } + + uint sum = 0; + for (uint i = 0; i < lhs_vec.cols; i++) { + sum += lhs_vec.elems[i] * + rhs_trans_mat.elems[col_idx * rhs_trans_mat.cols + i]; + } + + res_vec.elems[col_idx] = sum; +} diff --git a/src/pir_internals/mod.rs b/src/pir_internals/mod.rs index edad805..e016859 100644 --- a/src/pir_internals/mod.rs +++ b/src/pir_internals/mod.rs @@ -11,3 +11,5 @@ pub mod gpu; pub mod mat_transpose_shader; #[cfg(feature = "gpu")] pub mod mat_x_mat_shader; +#[cfg(feature = "gpu")] +pub mod vec_x_mat_shader; diff --git a/src/pir_internals/vec_x_mat_shader.rs b/src/pir_internals/vec_x_mat_shader.rs new file mode 100644 index 0000000..97f3269 --- /dev/null +++ b/src/pir_internals/vec_x_mat_shader.rs @@ -0,0 +1,5 @@ +vulkano_shaders::shader! { + ty: "compute", + path: "./shaders/vec_x_mat.glsl", + vulkan_version: "1.2", +} From 40ba459ad2c146a05cc7daf0c4c97c57b6cdbede Mon Sep 17 00:00:00 2001 From: Anjan Roy Date: Tue, 1 Apr 2025 23:14:48 +0530 Subject: [PATCH 24/29] Implement server-respond function, using `gpu` feature Signed-off-by: Anjan Roy --- src/pir_internals/gpu.rs | 86 +++++++++++++++++++++++++++++++++++++--- src/server.rs | 63 +++++++++++++++++++++++++++-- 2 files changed, 141 insertions(+), 8 deletions(-) diff --git a/src/pir_internals/gpu.rs b/src/pir_internals/gpu.rs index d8d8bcf..09ec3af 100644 --- a/src/pir_internals/gpu.rs +++ b/src/pir_internals/gpu.rs @@ -1,14 +1,21 @@ -use super::{mat_transpose_shader, mat_x_mat_shader, matrix::Matrix}; +pub use vulkano::{ + buffer::Subbuffer, + command_buffer::allocator::StandardCommandBufferAllocator, + device::{Device, Queue}, + memory::allocator::StandardMemoryAllocator, +}; + +use super::{mat_transpose_shader, mat_x_mat_shader, matrix::Matrix, vec_x_mat_shader}; use crate::ChalametPIRError; use std::sync::Arc; use vulkano::{ VulkanLibrary, - buffer::{Buffer, BufferCreateInfo, BufferUsage, Subbuffer}, - 
command_buffer::{AutoCommandBufferBuilder, CommandBufferUsage, CopyBufferInfo, PrimaryCommandBufferAbstract, allocator::StandardCommandBufferAllocator}, + buffer::{Buffer, BufferCreateInfo, BufferUsage}, + command_buffer::{AutoCommandBufferBuilder, CommandBufferUsage, CopyBufferInfo, PrimaryCommandBufferAbstract}, descriptor_set::{DescriptorSet, WriteDescriptorSet, allocator::StandardDescriptorSetAllocator}, - device::{Device, DeviceCreateInfo, DeviceExtensions, Queue, QueueCreateInfo, QueueFlags, physical::PhysicalDeviceType}, + device::{DeviceCreateInfo, DeviceExtensions, QueueCreateInfo, QueueFlags, physical::PhysicalDeviceType}, instance::{Instance, InstanceCreateFlags, InstanceCreateInfo}, - memory::allocator::{AllocationCreateInfo, MemoryTypeFilter, StandardMemoryAllocator}, + memory::allocator::{AllocationCreateInfo, MemoryTypeFilter}, pipeline::{ ComputePipeline, Pipeline, PipelineBindPoint, PipelineLayout, PipelineShaderStageCreateInfo, compute::ComputePipelineCreateInfo, layout::PipelineDescriptorSetLayoutCreateInfo, @@ -295,3 +302,72 @@ pub fn mat_transpose( .wait(None) .map_err(|_| ChalametPIRError::VulkanCommandBufferExecutionFailed) } + +pub fn vec_x_mat( + device: Arc, + queue: Arc, + command_buffer_allocator: Arc, + left_vec: Subbuffer<[u8]>, + rhs_transposed_mat: Subbuffer<[u8]>, + res_vec: Subbuffer<[u8]>, + wg_count: [u32; 3], +) -> Result<(), ChalametPIRError> { + let pipeline = { + let cs = vec_x_mat_shader::load(device.clone()).map_err(|_| ChalametPIRError::VulkanComputeShaderLoadingFailed)?; + let cs_entry_point = cs.entry_point("main").ok_or(ChalametPIRError::VulkanComputeShaderLoadingFailed)?; + let compute_stage = PipelineShaderStageCreateInfo::new(cs_entry_point); + + let layout = PipelineLayout::new( + device.clone(), + PipelineDescriptorSetLayoutCreateInfo::from_stages([&compute_stage]) + .into_pipeline_layout_create_info(device.clone()) + .map_err(|_| ChalametPIRError::VulkanComputePipelineCreationFailed)?, + ) + .map_err(|_| 
ChalametPIRError::VulkanComputePipelineCreationFailed)?; + + ComputePipeline::new(device.clone(), None, ComputePipelineCreateInfo::stage_layout(compute_stage, layout.clone())) + .map_err(|_| ChalametPIRError::VulkanComputePipelineCreationFailed)? + }; + + let descriptor_set_allocator = Arc::new(StandardDescriptorSetAllocator::new(device.clone(), Default::default())); + let descriptor_set_layout = pipeline.layout().set_layouts()[0].clone(); + let descriptor_set = DescriptorSet::new( + descriptor_set_allocator, + descriptor_set_layout, + [ + WriteDescriptorSet::buffer(0, left_vec), + WriteDescriptorSet::buffer(1, rhs_transposed_mat), + WriteDescriptorSet::buffer(2, res_vec), + ], + [], + ) + .map_err(|_| ChalametPIRError::VulkanDescriptorSetCreationFailed)?; + + let command_buffer = { + let mut command_buffer_builder = + AutoCommandBufferBuilder::primary(command_buffer_allocator, queue.queue_family_index(), CommandBufferUsage::OneTimeSubmit) + .map_err(|_| ChalametPIRError::VulkanCommandBufferBuilderCreationFailed)?; + + unsafe { + command_buffer_builder + .bind_pipeline_compute(pipeline.clone()) + .map_err(|_| ChalametPIRError::VulkanCommandBufferRecordingFailed)? + .bind_descriptor_sets(PipelineBindPoint::Compute, pipeline.layout().clone(), 0, descriptor_set) + .map_err(|_| ChalametPIRError::VulkanCommandBufferRecordingFailed)? + .dispatch(wg_count) + .map_err(|_| ChalametPIRError::VulkanCommandBufferRecordingFailed)?; + } + + command_buffer_builder + .build() + .map_err(|_| ChalametPIRError::VulkanCommandBufferBuildingFailed)? + }; + + command_buffer + .execute(queue.clone()) + .map_err(|_| ChalametPIRError::VulkanCommandBufferExecutionFailed)? + .then_signal_fence_and_flush() + .map_err(|_| ChalametPIRError::VulkanCommandBufferExecutionFailed)? 
+ .wait(None) + .map_err(|_| ChalametPIRError::VulkanCommandBufferExecutionFailed) +} diff --git a/src/server.rs b/src/server.rs index 1d15781..dc68eff 100644 --- a/src/server.rs +++ b/src/server.rs @@ -8,7 +8,7 @@ use crate::{ params::{LWE_DIMENSION, SEED_BYTE_LEN, SERVER_SETUP_MAX_ATTEMPT_COUNT}, }, }; -use std::collections::HashMap; +use std::{collections::HashMap, sync::Arc}; /// Represents the server in the Keyword Private Information Retrieval (PIR) scheme ChalametPIR. /// @@ -16,7 +16,23 @@ use std::collections::HashMap; #[derive(Clone)] pub struct Server { /// This matrix is kept in transposed form to optimize memory access pattern in vector matrix multiplication of server-respond function. + #[cfg(not(feature = "gpu"))] transposed_parsed_db_mat_d: Matrix, + + #[cfg(feature = "gpu")] + device: Arc, + #[cfg(feature = "gpu")] + queue: Arc, + #[cfg(feature = "gpu")] + mem_alloc: Arc, + #[cfg(feature = "gpu")] + cmd_buf_alloc: Arc, + #[cfg(feature = "gpu")] + transposed_parsed_db_mat_d_num_rows: u32, + #[cfg(feature = "gpu")] + transposed_parsed_db_mat_d_num_cols: u32, + #[cfg(feature = "gpu")] + transposed_parsed_db_mat_d_buf: gpu::Subbuffer<[u8]>, } impl Server { @@ -140,9 +156,20 @@ impl Server { let hint_bytes = hint_mat_m_buf.read().map_err(|_| ChalametPIRError::VulkanReadingFromBufferFailed)?.to_vec(); let filter_param_bytes: Vec = filter.to_bytes(); - let transposed_parsed_db_mat_d = parsed_db_mat_d.transpose(); - Ok((Server { transposed_parsed_db_mat_d }, hint_bytes, filter_param_bytes)) + Ok(( + Server { + device, + queue, + mem_alloc, + cmd_buf_alloc, + transposed_parsed_db_mat_d_num_rows: parsed_db_mat_d.num_cols(), + transposed_parsed_db_mat_d_num_cols: parsed_db_mat_d.num_rows(), + transposed_parsed_db_mat_d_buf, + }, + hint_bytes, + filter_param_bytes, + )) } /// Responds to a client query. @@ -160,6 +187,7 @@ impl Server { /// # Returns /// /// A `Result` containing the response as a byte vector. 
Returns an error if any error occurs during response computation or serialization. + #[cfg(not(feature = "gpu"))] pub fn respond(&self, query: &[u8]) -> Result, ChalametPIRError> { let query_vector = Matrix::from_bytes(query)?; let response_vector = query_vector.row_vector_x_transposed_matrix(&self.transposed_parsed_db_mat_d)?; @@ -167,6 +195,35 @@ impl Server { Ok(response_vector.to_bytes()) } + #[cfg(feature = "gpu")] + pub fn respond(&self, query: &[u8]) -> Result, ChalametPIRError> { + let query_vector = Matrix::from_bytes(query)?; + if branch_opt_util::unlikely(!(query_vector.num_rows() == 1 && query_vector.num_cols() == self.transposed_parsed_db_mat_d_num_cols)) { + return Err(ChalametPIRError::IncompatibleDimensionForRowVectorTransposedMatrixMultiplication); + } + + let response_vec_byte_len = (2 * std::mem::size_of::() + + (query_vector.num_rows() * self.transposed_parsed_db_mat_d_num_rows) as usize * std::mem::size_of::()) + as u64; + let response_vec_wg_count = [1, self.transposed_parsed_db_mat_d_num_rows.div_ceil(32), 1]; + + let query_vec_buf = gpu::transfer_mat_to_device(self.queue.clone(), self.mem_alloc.clone(), self.cmd_buf_alloc.clone(), query_vector)?; + let response_vec_buf = gpu::get_empty_host_readable_buffer(self.mem_alloc.clone(), response_vec_byte_len)?; + + gpu::vec_x_mat( + self.device.clone(), + self.queue.clone(), + self.cmd_buf_alloc.clone(), + query_vec_buf, + self.transposed_parsed_db_mat_d_buf.clone(), + response_vec_buf.clone(), + response_vec_wg_count, + )?; + + let response_bytes = response_vec_buf.read().map_err(|_| ChalametPIRError::VulkanReadingFromBufferFailed)?.to_vec(); + Ok(response_bytes) + } + /// This is required to ensure that LWE PIR protocol is correct. See eq. 8 in section 5.1 of the FrodoPIR paper @ https://ia.cr/2022/981. 
fn find_encoded_db_matrix_element_bit_length(db_entry_count: usize) -> Result { const MIN_MAT_ELEM_BIT_LEN: usize = 4; From ec4a802016f7d34d91d1b41e5397537e7427aa32 Mon Sep 17 00:00:00 2001 From: Anjan Roy Date: Fri, 4 Apr 2025 18:48:40 +0530 Subject: [PATCH 25/29] Change work-group size for vector-matrix multiplication shader invocation Signed-off-by: Anjan Roy --- shaders/vec_x_mat.glsl | 10 ++++++---- src/pir_internals/gpu.rs | 2 +- src/server.rs | 13 +++++++------ 3 files changed, 14 insertions(+), 11 deletions(-) diff --git a/shaders/vec_x_mat.glsl b/shaders/vec_x_mat.glsl index 6e29e1d..f179dab 100644 --- a/shaders/vec_x_mat.glsl +++ b/shaders/vec_x_mat.glsl @@ -1,7 +1,7 @@ #version 460 #pragma shader_stage(compute) -layout(local_size_x = 1, local_size_y = 32, local_size_z = 1) in; +layout(local_size_x = 8, local_size_y = 8, local_size_z = 1) in; layout(set = 0, binding = 0) buffer readonly MatrixA { uint rows; @@ -27,8 +27,10 @@ res_vec; void main() { const uint row_idx = gl_GlobalInvocationID.x; const uint col_idx = gl_GlobalInvocationID.y; + const uint res_vec_num_cols_sqrt = uint(sqrt(rhs_trans_mat.rows)) + 1; + const uint lin_idx = row_idx * res_vec_num_cols_sqrt + col_idx; - if (row_idx >= lhs_vec.rows || col_idx >= rhs_trans_mat.rows) { + if (lin_idx >= rhs_trans_mat.rows) { return; } @@ -40,8 +42,8 @@ void main() { uint sum = 0; for (uint i = 0; i < lhs_vec.cols; i++) { sum += lhs_vec.elems[i] * - rhs_trans_mat.elems[col_idx * rhs_trans_mat.cols + i]; + rhs_trans_mat.elems[lin_idx * rhs_trans_mat.cols + i]; } - res_vec.elems[col_idx] = sum; + res_vec.elems[lin_idx] = sum; } diff --git a/src/pir_internals/gpu.rs b/src/pir_internals/gpu.rs index 09ec3af..3c28625 100644 --- a/src/pir_internals/gpu.rs +++ b/src/pir_internals/gpu.rs @@ -1,3 +1,4 @@ +pub use std::sync::Arc; pub use vulkano::{ buffer::Subbuffer, command_buffer::allocator::StandardCommandBufferAllocator, @@ -7,7 +8,6 @@ pub use vulkano::{ use super::{mat_transpose_shader, mat_x_mat_shader, 
matrix::Matrix, vec_x_mat_shader}; use crate::ChalametPIRError; -use std::sync::Arc; use vulkano::{ VulkanLibrary, buffer::{Buffer, BufferCreateInfo, BufferUsage}, diff --git a/src/server.rs b/src/server.rs index dc68eff..f8fbe08 100644 --- a/src/server.rs +++ b/src/server.rs @@ -8,7 +8,7 @@ use crate::{ params::{LWE_DIMENSION, SEED_BYTE_LEN, SERVER_SETUP_MAX_ATTEMPT_COUNT}, }, }; -use std::{collections::HashMap, sync::Arc}; +use std::collections::HashMap; /// Represents the server in the Keyword Private Information Retrieval (PIR) scheme ChalametPIR. /// @@ -20,13 +20,13 @@ pub struct Server { transposed_parsed_db_mat_d: Matrix, #[cfg(feature = "gpu")] - device: Arc, + device: gpu::Arc, #[cfg(feature = "gpu")] - queue: Arc, + queue: gpu::Arc, #[cfg(feature = "gpu")] - mem_alloc: Arc, + mem_alloc: gpu::Arc, #[cfg(feature = "gpu")] - cmd_buf_alloc: Arc, + cmd_buf_alloc: gpu::Arc, #[cfg(feature = "gpu")] transposed_parsed_db_mat_d_num_rows: u32, #[cfg(feature = "gpu")] @@ -205,7 +205,8 @@ impl Server { let response_vec_byte_len = (2 * std::mem::size_of::() + (query_vector.num_rows() * self.transposed_parsed_db_mat_d_num_rows) as usize * std::mem::size_of::()) as u64; - let response_vec_wg_count = [1, self.transposed_parsed_db_mat_d_num_rows.div_ceil(32), 1]; + let response_vec_len_sqrt = (self.transposed_parsed_db_mat_d_num_rows as f32).sqrt() as u32 + 1; + let response_vec_wg_count = [response_vec_len_sqrt.div_ceil(8), response_vec_len_sqrt.div_ceil(8), 1]; let query_vec_buf = gpu::transfer_mat_to_device(self.queue.clone(), self.mem_alloc.clone(), self.cmd_buf_alloc.clone(), query_vector)?; let response_vec_buf = gpu::get_empty_host_readable_buffer(self.mem_alloc.clone(), response_vec_byte_len)?; From 1d2ed91578b560d31047befccf697695fb033dc9 Mon Sep 17 00:00:00 2001 From: Anjan Roy Date: Fri, 4 Apr 2025 19:51:15 +0530 Subject: [PATCH 26/29] Duplicate comment for `gpu` feature-gated version of `server-respond` function Signed-off-by: Anjan Roy --- src/server.rs | 17 
+++++++++++++++-- 1 file changed, 15 insertions(+), 2 deletions(-) diff --git a/src/server.rs b/src/server.rs index f8fbe08..fdfcda7 100644 --- a/src/server.rs +++ b/src/server.rs @@ -81,8 +81,6 @@ impl Server { Ok((Server { transposed_parsed_db_mat_d }, hint_bytes, filter_param_bytes)) } - /// TODO: Update following documentation before publishing. - /// /// Sets up the keyword **P**rivate **I**nformation **R**etrieval scheme's server with a given Key-Value database. /// /// This function takes a database as input and generates the necessary matrices and parameters for responding to client queries. @@ -195,6 +193,21 @@ impl Server { Ok(response_vector.to_bytes()) } + /// Responds to a client query. + /// + /// This function takes a client's query (in byte form) as input and uses the transposed database matrix to compute the response. + /// The process involves: + /// 1. **Query Vectorization:** Converts the query bytes into a row vector. Returns an error if conversion fails. + /// 2. **Vector-Matrix Multiplication:** Performs a row vector-transposed matrix multiplication of the query vector and the server's transposed database matrix. This is optimized for efficiency due to the transposition performed during server setup. Returns an error if multiplication fails. + /// 3. **Response Serialization:** Converts the resulting response vector into a byte vector for transmission to the client. Returns an error if conversion fails. + /// + /// # Arguments + /// + /// * `query`: The client's query, represented as a byte slice. + /// + /// # Returns + /// + /// A `Result` containing the response as a byte vector. Returns an error if any error occurs during response computation or serialization. 
#[cfg(feature = "gpu")] pub fn respond(&self, query: &[u8]) -> Result, ChalametPIRError> { let query_vector = Matrix::from_bytes(query)?; From fe5ce498ba516132b639103b996b7f08d0df7b88 Mon Sep 17 00:00:00 2001 From: Anjan Roy Date: Sat, 5 Apr 2025 08:48:35 +0530 Subject: [PATCH 27/29] Avoid computing vector-matrix multiplication on GPU during `server-respond` Signed-off-by: Anjan Roy --- shaders/vec_x_mat.glsl | 49 --------------- src/pir_internals/gpu.rs | 89 +-------------------------- src/pir_internals/mod.rs | 2 - src/pir_internals/vec_x_mat_shader.rs | 5 -- src/server.rs | 83 +++---------------------- 5 files changed, 9 insertions(+), 219 deletions(-) delete mode 100644 shaders/vec_x_mat.glsl delete mode 100644 src/pir_internals/vec_x_mat_shader.rs diff --git a/shaders/vec_x_mat.glsl b/shaders/vec_x_mat.glsl deleted file mode 100644 index f179dab..0000000 --- a/shaders/vec_x_mat.glsl +++ /dev/null @@ -1,49 +0,0 @@ -#version 460 -#pragma shader_stage(compute) - -layout(local_size_x = 8, local_size_y = 8, local_size_z = 1) in; - -layout(set = 0, binding = 0) buffer readonly MatrixA { - uint rows; - uint cols; - uint[] elems; -} -lhs_vec; - -layout(set = 0, binding = 1) buffer readonly MatrixB { - uint rows; - uint cols; - uint[] elems; -} -rhs_trans_mat; - -layout(set = 0, binding = 2) buffer writeonly MatrixC { - uint rows; - uint cols; - uint[] elems; -} -res_vec; - -void main() { - const uint row_idx = gl_GlobalInvocationID.x; - const uint col_idx = gl_GlobalInvocationID.y; - const uint res_vec_num_cols_sqrt = uint(sqrt(rhs_trans_mat.rows)) + 1; - const uint lin_idx = row_idx * res_vec_num_cols_sqrt + col_idx; - - if (lin_idx >= rhs_trans_mat.rows) { - return; - } - - if ((row_idx == 0) && (col_idx == 0)) { - res_vec.rows = lhs_vec.rows; - res_vec.cols = rhs_trans_mat.rows; - } - - uint sum = 0; - for (uint i = 0; i < lhs_vec.cols; i++) { - sum += lhs_vec.elems[i] * - rhs_trans_mat.elems[lin_idx * rhs_trans_mat.cols + i]; - } - - res_vec.elems[lin_idx] = sum; 
-} diff --git a/src/pir_internals/gpu.rs b/src/pir_internals/gpu.rs index 3c28625..d22abab 100644 --- a/src/pir_internals/gpu.rs +++ b/src/pir_internals/gpu.rs @@ -6,7 +6,7 @@ pub use vulkano::{ memory::allocator::StandardMemoryAllocator, }; -use super::{mat_transpose_shader, mat_x_mat_shader, matrix::Matrix, vec_x_mat_shader}; +use super::{mat_transpose_shader, mat_x_mat_shader, matrix::Matrix}; use crate::ChalametPIRError; use vulkano::{ VulkanLibrary, @@ -146,23 +146,7 @@ pub fn get_empty_host_readable_buffer(memory_allocator: Arc, byte_len: u64) -> Result, ChalametPIRError> { - Buffer::new_slice::( - memory_allocator.clone(), - BufferCreateInfo { - usage: BufferUsage::STORAGE_BUFFER, - ..Default::default() - }, - AllocationCreateInfo { - memory_type_filter: MemoryTypeFilter::PREFER_DEVICE, + memory_type_filter: MemoryTypeFilter::HOST_SEQUENTIAL_WRITE | MemoryTypeFilter::PREFER_DEVICE, ..Default::default() }, byte_len, @@ -302,72 +286,3 @@ pub fn mat_transpose( .wait(None) .map_err(|_| ChalametPIRError::VulkanCommandBufferExecutionFailed) } - -pub fn vec_x_mat( - device: Arc, - queue: Arc, - command_buffer_allocator: Arc, - left_vec: Subbuffer<[u8]>, - rhs_transposed_mat: Subbuffer<[u8]>, - res_vec: Subbuffer<[u8]>, - wg_count: [u32; 3], -) -> Result<(), ChalametPIRError> { - let pipeline = { - let cs = vec_x_mat_shader::load(device.clone()).map_err(|_| ChalametPIRError::VulkanComputeShaderLoadingFailed)?; - let cs_entry_point = cs.entry_point("main").ok_or(ChalametPIRError::VulkanComputeShaderLoadingFailed)?; - let compute_stage = PipelineShaderStageCreateInfo::new(cs_entry_point); - - let layout = PipelineLayout::new( - device.clone(), - PipelineDescriptorSetLayoutCreateInfo::from_stages([&compute_stage]) - .into_pipeline_layout_create_info(device.clone()) - .map_err(|_| ChalametPIRError::VulkanComputePipelineCreationFailed)?, - ) - .map_err(|_| ChalametPIRError::VulkanComputePipelineCreationFailed)?; - - ComputePipeline::new(device.clone(), None, 
ComputePipelineCreateInfo::stage_layout(compute_stage, layout.clone())) - .map_err(|_| ChalametPIRError::VulkanComputePipelineCreationFailed)? - }; - - let descriptor_set_allocator = Arc::new(StandardDescriptorSetAllocator::new(device.clone(), Default::default())); - let descriptor_set_layout = pipeline.layout().set_layouts()[0].clone(); - let descriptor_set = DescriptorSet::new( - descriptor_set_allocator, - descriptor_set_layout, - [ - WriteDescriptorSet::buffer(0, left_vec), - WriteDescriptorSet::buffer(1, rhs_transposed_mat), - WriteDescriptorSet::buffer(2, res_vec), - ], - [], - ) - .map_err(|_| ChalametPIRError::VulkanDescriptorSetCreationFailed)?; - - let command_buffer = { - let mut command_buffer_builder = - AutoCommandBufferBuilder::primary(command_buffer_allocator, queue.queue_family_index(), CommandBufferUsage::OneTimeSubmit) - .map_err(|_| ChalametPIRError::VulkanCommandBufferBuilderCreationFailed)?; - - unsafe { - command_buffer_builder - .bind_pipeline_compute(pipeline.clone()) - .map_err(|_| ChalametPIRError::VulkanCommandBufferRecordingFailed)? - .bind_descriptor_sets(PipelineBindPoint::Compute, pipeline.layout().clone(), 0, descriptor_set) - .map_err(|_| ChalametPIRError::VulkanCommandBufferRecordingFailed)? - .dispatch(wg_count) - .map_err(|_| ChalametPIRError::VulkanCommandBufferRecordingFailed)?; - } - - command_buffer_builder - .build() - .map_err(|_| ChalametPIRError::VulkanCommandBufferBuildingFailed)? - }; - - command_buffer - .execute(queue.clone()) - .map_err(|_| ChalametPIRError::VulkanCommandBufferExecutionFailed)? - .then_signal_fence_and_flush() - .map_err(|_| ChalametPIRError::VulkanCommandBufferExecutionFailed)? 
- .wait(None) - .map_err(|_| ChalametPIRError::VulkanCommandBufferExecutionFailed) -} diff --git a/src/pir_internals/mod.rs b/src/pir_internals/mod.rs index e016859..edad805 100644 --- a/src/pir_internals/mod.rs +++ b/src/pir_internals/mod.rs @@ -11,5 +11,3 @@ pub mod gpu; pub mod mat_transpose_shader; #[cfg(feature = "gpu")] pub mod mat_x_mat_shader; -#[cfg(feature = "gpu")] -pub mod vec_x_mat_shader; diff --git a/src/pir_internals/vec_x_mat_shader.rs b/src/pir_internals/vec_x_mat_shader.rs deleted file mode 100644 index 97f3269..0000000 --- a/src/pir_internals/vec_x_mat_shader.rs +++ /dev/null @@ -1,5 +0,0 @@ -vulkano_shaders::shader! { - ty: "compute", - path: "./shaders/vec_x_mat.glsl", - vulkan_version: "1.2", -} diff --git a/src/server.rs b/src/server.rs index fdfcda7..36b7205 100644 --- a/src/server.rs +++ b/src/server.rs @@ -16,23 +16,7 @@ use std::collections::HashMap; #[derive(Clone)] pub struct Server { /// This matrix is kept in transposed form to optimize memory access pattern in vector matrix multiplication of server-respond function. 
- #[cfg(not(feature = "gpu"))] transposed_parsed_db_mat_d: Matrix, - - #[cfg(feature = "gpu")] - device: gpu::Arc, - #[cfg(feature = "gpu")] - queue: gpu::Arc, - #[cfg(feature = "gpu")] - mem_alloc: gpu::Arc, - #[cfg(feature = "gpu")] - cmd_buf_alloc: gpu::Arc, - #[cfg(feature = "gpu")] - transposed_parsed_db_mat_d_num_rows: u32, - #[cfg(feature = "gpu")] - transposed_parsed_db_mat_d_num_cols: u32, - #[cfg(feature = "gpu")] - transposed_parsed_db_mat_d_buf: gpu::Subbuffer<[u8]>, } impl Server { @@ -131,7 +115,7 @@ impl Server { let pub_mat_a_buf = gpu::transfer_mat_to_device(queue.clone(), mem_alloc.clone(), cmd_buf_alloc.clone(), pub_mat_a)?; let parsed_db_mat_d_buf = gpu::transfer_mat_to_device(queue.clone(), mem_alloc.clone(), cmd_buf_alloc.clone(), parsed_db_mat_d.clone())?; let hint_mat_m_buf = gpu::get_empty_host_readable_buffer(mem_alloc.clone(), hint_mat_m_byte_len)?; - let transposed_parsed_db_mat_d_buf = gpu::get_empty_device_local_buffer(mem_alloc.clone(), parsed_db_mat_d_byte_len)?; + let transposed_parsed_db_mat_d_buf = gpu::get_empty_host_readable_buffer(mem_alloc.clone(), parsed_db_mat_d_byte_len)?; gpu::mat_x_mat( device.clone(), @@ -152,22 +136,15 @@ impl Server { parsed_db_mat_d_wg_count, )?; + let transposed_parsed_db_mat_d = Matrix::from_bytes( + &transposed_parsed_db_mat_d_buf + .read() + .map_err(|_| ChalametPIRError::VulkanReadingFromBufferFailed)?, + )?; let hint_bytes = hint_mat_m_buf.read().map_err(|_| ChalametPIRError::VulkanReadingFromBufferFailed)?.to_vec(); let filter_param_bytes: Vec = filter.to_bytes(); - Ok(( - Server { - device, - queue, - mem_alloc, - cmd_buf_alloc, - transposed_parsed_db_mat_d_num_rows: parsed_db_mat_d.num_cols(), - transposed_parsed_db_mat_d_num_cols: parsed_db_mat_d.num_rows(), - transposed_parsed_db_mat_d_buf, - }, - hint_bytes, - filter_param_bytes, - )) + Ok((Server { transposed_parsed_db_mat_d }, hint_bytes, filter_param_bytes)) } /// Responds to a client query. 
@@ -185,7 +162,6 @@ impl Server { /// # Returns /// /// A `Result` containing the response as a byte vector. Returns an error if any error occurs during response computation or serialization. - #[cfg(not(feature = "gpu"))] pub fn respond(&self, query: &[u8]) -> Result, ChalametPIRError> { let query_vector = Matrix::from_bytes(query)?; let response_vector = query_vector.row_vector_x_transposed_matrix(&self.transposed_parsed_db_mat_d)?; @@ -193,51 +169,6 @@ impl Server { Ok(response_vector.to_bytes()) } - /// Responds to a client query. - /// - /// This function takes a client's query (in byte form) as input and uses the transposed database matrix to compute the response. - /// The process involves: - /// 1. **Query Vectorization:** Converts the query bytes into a row vector. Returns an error if conversion fails. - /// 2. **Vector-Matrix Multiplication:** Performs a row vector-transposed matrix multiplication of the query vector and the server's transposed database matrix. This is optimized for efficiency due to the transposition performed during server setup. Returns an error if multiplication fails. - /// 3. **Response Serialization:** Converts the resulting response vector into a byte vector for transmission to the client. Returns an error if conversion fails. - /// - /// # Arguments - /// - /// * `query`: The client's query, represented as a byte slice. - /// - /// # Returns - /// - /// A `Result` containing the response as a byte vector. Returns an error if any error occurs during response computation or serialization. 
- #[cfg(feature = "gpu")] - pub fn respond(&self, query: &[u8]) -> Result, ChalametPIRError> { - let query_vector = Matrix::from_bytes(query)?; - if branch_opt_util::unlikely(!(query_vector.num_rows() == 1 && query_vector.num_cols() == self.transposed_parsed_db_mat_d_num_cols)) { - return Err(ChalametPIRError::IncompatibleDimensionForRowVectorTransposedMatrixMultiplication); - } - - let response_vec_byte_len = (2 * std::mem::size_of::() - + (query_vector.num_rows() * self.transposed_parsed_db_mat_d_num_rows) as usize * std::mem::size_of::()) - as u64; - let response_vec_len_sqrt = (self.transposed_parsed_db_mat_d_num_rows as f32).sqrt() as u32 + 1; - let response_vec_wg_count = [response_vec_len_sqrt.div_ceil(8), response_vec_len_sqrt.div_ceil(8), 1]; - - let query_vec_buf = gpu::transfer_mat_to_device(self.queue.clone(), self.mem_alloc.clone(), self.cmd_buf_alloc.clone(), query_vector)?; - let response_vec_buf = gpu::get_empty_host_readable_buffer(self.mem_alloc.clone(), response_vec_byte_len)?; - - gpu::vec_x_mat( - self.device.clone(), - self.queue.clone(), - self.cmd_buf_alloc.clone(), - query_vec_buf, - self.transposed_parsed_db_mat_d_buf.clone(), - response_vec_buf.clone(), - response_vec_wg_count, - )?; - - let response_bytes = response_vec_buf.read().map_err(|_| ChalametPIRError::VulkanReadingFromBufferFailed)?.to_vec(); - Ok(response_bytes) - } - /// This is required to ensure that LWE PIR protocol is correct. See eq. 8 in section 5.1 of the FrodoPIR paper @ https://ia.cr/2022/981. 
fn find_encoded_db_matrix_element_bit_length(db_entry_count: usize) -> Result { const MIN_MAT_ELEM_BIT_LEN: usize = 4; From 1e391ad11d21090ff13e128ee5c3f12dd8ed000e Mon Sep 17 00:00:00 2001 From: Anjan Roy Date: Sun, 6 Apr 2025 11:50:43 +0530 Subject: [PATCH 28/29] Update project documentation mentioning about the `gpu` feature gate Signed-off-by: Anjan Roy --- README.md | 53 +++++++++++++++++++++++++++++++++++++---------------- src/lib.rs | 3 +++ 2 files changed, 40 insertions(+), 16 deletions(-) diff --git a/README.md b/README.md index 1d05e02..9aef8e0 100644 --- a/README.md +++ b/README.md @@ -9,14 +9,12 @@ built on top of FrodoPIR - a practical, single-server, stateful LWE -based PIR s - Binary Fuse Filter was proposed in https://arxiv.org/pdf/2201.01174. - And ChalametPIR was proposed in https://ia.cr/2024/092. -ChalametPIR allows a client to retrieve a specific value from a key-value database on a server without revealing the requested key. -It uses Binary Fuse Filters to encode key-value pairs in form of a matrix. And then it applies FrodoPIR on the encoded database matrix -to actually retrieve values for requested keys. +ChalametPIR allows a client to retrieve a specific value from a key-value database, stored on a server, without revealing the requested key to the server. It uses Binary Fuse Filters to encode key-value pairs in form of a matrix. And then it applies FrodoPIR on the encoded database matrix to actually retrieve values for requested keys. The protocol has two participants: **Server:** -* **`setup`:** Initializes the server with a key-value database, generating a public matrix, a hint matrix, and a Binary Fuse Filter (3-wise XOR or 4-wise XOR, compile-time configurable). Returns serialized representations of the hint matrix and filter parameters. This phase can be completed in offline and it's completely client agnostic. 
+* **`setup`:** Initializes the server with a key-value database, generating a public matrix, a hint matrix, and a Binary Fuse Filter (3-wise XOR or 4-wise XOR, configurable at compile time). It returns serialized representations of the hint matrix and filter parameters. This phase can be completed offline and is completely client-agnostic. But it is very compute-intensive, which is why this library allows you to offload expensive matrix multiplication and transposition to a GPU, gated behind the opt-in `gpu` feature. For large key-value databases (e.g., with >= $2^{18}$ entries), I recommend enabling the `gpu` feature, as it can significantly reduce the cost of the server-setup phase. * **`respond`:** Processes a client's query and returns an encrypted response vector. **Client:** @@ -28,8 +26,8 @@ To paint a more practical picture, imagine, we have a database with $2^{20}$ (~1 Machine Type | Machine | Kernel | Compiler | Memory Read Speed --- | --- | --- | --- | --- -aarch64 server | AWS EC2 `m8g.8xlarge` | `Linux 6.8.0-1021-aws aarch64` | `rustc 1.84.1 (e71f9a9a9 2025-01-27)` | 28.25 GB/s -x86_64 server | AWS EC2 `m7i.8xlarge` | `Linux 6.8.0-1021-aws x86_64` | `rustc 1.84.1 (e71f9a9a9 2025-01-27)` | 10.33 GB/s +aarch64 server | AWS EC2 `m8g.8xlarge` | `Linux 6.8.0-1021-aws aarch64` | `rustc 1.85.1 (e71f9a9a9 2025-01-27)` | 28.25 GB/s +x86_64 server | AWS EC2 `m7i.8xlarge` | `Linux 6.8.0-1021-aws x86_64` | `rustc 1.85.1 (e71f9a9a9 2025-01-27)` | 10.33 GB/s and this implementation of ChalametPIR is compiled with specified compiler, in `optimized` profile. See [Cargo.toml](./Cargo.toml). @@ -44,22 +42,34 @@ Step | `(a)` Time Taken on `aarch64` server | `(b)` Time Taken on `x86_64` serve `server_respond` | 18.01 milliseconds | 32.16 milliseconds | 0.56 `client_process_response` | 11.73 microseconds | 16.75 microseconds | 0.7 -> [!NOTE] -> In above table, I show only the median timing measurements, while the DB is encoded using a 3 -wise XOR Binary Fuse Filter. 
For more results, with more database configurations, see benchmarking [section](#benchmarking) below. - So, the median bandwidth of the `server_respond` algorithm, which needs to traverse through the whole processed database, is - (a) For `aarch64` server: 53.82 GB/s - (b) For `x86_64` server: 30.12 GB/s +For demonstrating the effectiveness of offloading parts of the server-setup phase to a GPU, I benchmark it on AWS EC2 instance `g6e.8xlarge`, which features a NVIDIA L40S Tensor Core GPU and $3^{rd}$ generation AMD EPYC CPUs. + +Number of entries in DB | Key length | Value length | `(a)` Time taken to setup PIR server on CPU | `(b)` Time taken to setup PIR server, partially offloading to GPU | Ratio `a / b` +:-- | --: | --: | --: | --: | --: +$2^{16}$ | 32B | 1kB | 19.55 seconds | 19.39 seconds | 1.0 +$2^{18}$ | 32B | 1kB | 6.0 minutes | 2.23 minutes | 2.69 +$2^{20}$ | 32B | 1kB | 25.89 minutes | 25.58 seconds | 60.72 + +For small key-value databases, it is not worth offloading server-setup to the GPU, but for databases with entries >= $2^{18}$, it is recommended to enable `gpu` feature, when GPU is available. + +> [!NOTE] +> In both of above tables, I show only the median timing measurements, while the DB is encoded using a 3 -wise XOR Binary Fuse Filter. For more results, with more database configurations, see benchmarking [section](#benchmarking) below. + ## Prerequisites -Rust stable toolchain; see https://rustup.rs for installation guide. MSRV for this crate is 1.84.0. +Rust stable toolchain; see https://rustup.rs for installation guide. MSRV for this crate is 1.85.0. ```bash # While developing this library, I was using $ rustc --version -rustc 1.84.1 (e71f9a9a9 2025-01-27) +rustc 1.85.1 (e71f9a9a9 2025-01-27) ``` +If you plan to offload server-setup to GPU, you need to install Vulkan drivers and library for your target setup. 
I followed https://linux.how2shout.com/how-to-install-vulkan-on-ubuntu-24-04-or-22-04-lts-linux on Ubuntu 24.04 LTS, with Nvidia GPUs - it was easy to setup. + ## Testing The `chalamet_pir` library includes comprehensive tests to ensure functional correctness. @@ -69,8 +79,12 @@ The `chalamet_pir` library includes comprehensive tests to ensure functional cor To run the tests, go to the project's root directory and issue: ```bash -cargo test --profile test-release # Custom profile to make tests run faster! - # Default debug mode is too slow! +# Custom profile to make tests run faster! +# Default debug mode is too slow! +cargo test --profile test-release + +# For testing if offloading to GPU works as expected. +cargo test --features gpu --profile test-release ``` @@ -80,9 +94,12 @@ Performance benchmarks are included to evaluate the efficiency of the PIR scheme To run the benchmarks, execute the following command from the root of the project: ```bash -cargo bench --all-features --profile optimized # For benchmarking the online phase of the PIR, - # you need to enable feature `mutate_internal_client_state`, - # passing `--all-features` does that. +# For benchmarking the online phase of the PIR, +# you need to enable feature `mutate_internal_client_state`. +cargo bench --features mutate_internal_client_state --profile optimized + +# For benchmarking only the server-setup phase, offloaded to the GPU. +cargo bench --features gpu --profile optimized --bench offline_phase -q server_setup ``` > [!WARNING] @@ -102,6 +119,10 @@ First, add this library crate as a dependency in your Cargo.toml file. ```toml [dependencies] chalamet_pir = "=0.4.0" +# Or, if you want to offload server-setup to a GPU. +# chalamet_pir = { version = "=0.4.0", features = ["gpu"] } +rand = "=0.9.0" +rand_chacha = "=0.9.0" ``` Then, let's code a very simple keyword PIR scheme: diff --git a/src/lib.rs b/src/lib.rs index 24f3a66..fd231c9 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -8,6 +8,7 @@ //! 
* **Secure Private Information Retrieval:** Allows clients to retrieve value from a PIR server without disclosing corresponding key. Server learns neither the value nor the queried key. //! * **Error Handling:** Comprehensive error handling to catch and report issues during setup, query generation, and response processing. //! * **Flexibility:** Supports both 3-wise and 4-wise XOR Binary Fuse Filters, allowing a choice between trade-offs in client/server computation and communication costs. +//! * **Efficient:** It supports offloading parts of the server-setup phase to a GPU, using Vulkan Compute API, which can drastically reduce time taken to setup PIR server, for large key-value databases. //! //! ## Usage //! @@ -19,6 +20,8 @@ //! ```toml //! [dependencies] //! chalametpir = "=0.4.0" +//! # Or, if you want to offload server-setup to GPU. +//! # chalamet_pir = { version = "=0.4.0", features = ["gpu"] } //! rand = "=0.9.0" //! rand_chacha = "=0.9.0" //! ``` From 42a67364ab34e56440bdc05eaad6570534087ca2 Mon Sep 17 00:00:00 2001 From: Anjan Roy Date: Sun, 6 Apr 2025 11:55:50 +0530 Subject: [PATCH 29/29] Prepare for release v0.5.0 Signed-off-by: Anjan Roy --- Cargo.toml | 12 +++++++++--- README.md | 4 ++-- src/lib.rs | 4 ++-- 3 files changed, 13 insertions(+), 7 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 7c299bc..4607ab4 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "chalamet_pir" -version = "0.4.0" +version = "0.5.0" edition = "2024" resolver = "2" rust-version = "1.85.0" @@ -9,8 +9,14 @@ description = "Simple, Stateful, Single-Server Private Information Retrieval for readme = "README.md" repository = "https://github.com/itzmeanjan/ChalametPIR.git" license = "MPL-2.0" -keywords = ["priv-info-retrieval", "lwe-pir", "frodo-pir", "chalamet-pir"] -categories = ["cryptography", "data-structures"] +keywords = [ + "priv-info-retrieval", + "lwe-pir", + "frodo-pir", + "chalamet-pir", + "gpu", +] +categories = ["cryptography", 
"data-structures", "concurrency"] [dependencies] turboshake = "=0.4.1" diff --git a/README.md b/README.md index 9aef8e0..caf3b41 100644 --- a/README.md +++ b/README.md @@ -118,9 +118,9 @@ First, add this library crate as a dependency in your Cargo.toml file. ```toml [dependencies] -chalamet_pir = "=0.4.0" +chalamet_pir = "=0.5.0" # Or, if you want to offload server-setup to a GPU. -# chalamet_pir = { version = "=0.4.0", features = ["gpu"] } +# chalamet_pir = { version = "=0.5.0", features = ["gpu"] } rand = "=0.9.0" rand_chacha = "=0.9.0" ``` diff --git a/src/lib.rs b/src/lib.rs index fd231c9..251e16f 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -19,9 +19,9 @@ //! //! ```toml //! [dependencies] -//! chalametpir = "=0.4.0" +//! chalametpir = "=0.5.0" //! # Or, if you want to offload server-setup to GPU. -//! # chalamet_pir = { version = "=0.4.0", features = ["gpu"] } +//! # chalamet_pir = { version = "=0.5.0", features = ["gpu"] } //! rand = "=0.9.0" //! rand_chacha = "=0.9.0" //! ```