diff --git a/Cargo.toml b/Cargo.toml index a6b75c5..4607ab4 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "chalamet_pir" -version = "0.4.0" +version = "0.5.0" edition = "2024" resolver = "2" rust-version = "1.85.0" @@ -9,14 +9,22 @@ description = "Simple, Stateful, Single-Server Private Information Retrieval for readme = "README.md" repository = "https://github.com/itzmeanjan/ChalametPIR.git" license = "MPL-2.0" -keywords = ["priv-info-retrieval", "lwe-pir", "frodo-pir", "chalamet-pir"] -categories = ["cryptography", "data-structures"] +keywords = [ + "priv-info-retrieval", + "lwe-pir", + "frodo-pir", + "chalamet-pir", + "gpu", +] +categories = ["cryptography", "data-structures", "concurrency"] [dependencies] turboshake = "=0.4.1" rayon = "=1.10.0" rand = "=0.9.0" rand_chacha = "=0.9.0" +vulkano = { version = "=0.35.1", optional = true } +vulkano-shaders = { version = "=0.35.0", optional = true } [dev-dependencies] test-case = "=3.3.1" @@ -34,6 +42,7 @@ required-features = ["mutate_internal_client_state"] [features] mutate_internal_client_state = [] +gpu = ["dep:vulkano", "dep:vulkano-shaders"] [profile.optimized] inherits = "release" diff --git a/README.md b/README.md index 1d05e02..caf3b41 100644 --- a/README.md +++ b/README.md @@ -9,14 +9,12 @@ built on top of FrodoPIR - a practical, single-server, stateful LWE -based PIR s - Binary Fuse Filter was proposed in https://arxiv.org/pdf/2201.01174. - And ChalametPIR was proposed in https://ia.cr/2024/092. -ChalametPIR allows a client to retrieve a specific value from a key-value database on a server without revealing the requested key. -It uses Binary Fuse Filters to encode key-value pairs in form of a matrix. And then it applies FrodoPIR on the encoded database matrix -to actually retrieve values for requested keys. +ChalametPIR allows a client to retrieve a specific value from a key-value database, stored on a server, without revealing the requested key to the server. 
It uses Binary Fuse Filters to encode key-value pairs in form of a matrix. And then it applies FrodoPIR on the encoded database matrix to actually retrieve values for requested keys. The protocol has two participants: **Server:** -* **`setup`:** Initializes the server with a key-value database, generating a public matrix, a hint matrix, and a Binary Fuse Filter (3-wise XOR or 4-wise XOR, compile-time configurable). Returns serialized representations of the hint matrix and filter parameters. This phase can be completed in offline and it's completely client agnostic. +* **`setup`:** Initializes the server with a key-value database, generating a public matrix, a hint matrix, and a Binary Fuse Filter (3-wise XOR or 4-wise XOR, configurable at compile time). It returns serialized representations of the hint matrix and filter parameters. This phase can be completed offline and is completely client-agnostic. But it is very compute-intensive, which is why this library allows you to offload expensive matrix multiplication and transposition to a GPU, gated behind the opt-in `gpu` feature. For large key-value databases (e.g., with >= $2^{18}$ entries), I recommend enabling the `gpu` feature, as it can significantly reduce the cost of the server-setup phase. * **`respond`:** Processes a client's query and returns an encrypted response vector. 
**Client:** @@ -28,8 +26,8 @@ To paint a more practical picture, imagine, we have a database with $2^{20}$ (~1 Machine Type | Machine | Kernel | Compiler | Memory Read Speed --- | --- | --- | --- | --- -aarch64 server | AWS EC2 `m8g.8xlarge` | `Linux 6.8.0-1021-aws aarch64` | `rustc 1.84.1 (e71f9a9a9 2025-01-27)` | 28.25 GB/s -x86_64 server | AWS EC2 `m7i.8xlarge` | `Linux 6.8.0-1021-aws x86_64` | `rustc 1.84.1 (e71f9a9a9 2025-01-27)` | 10.33 GB/s +aarch64 server | AWS EC2 `m8g.8xlarge` | `Linux 6.8.0-1021-aws aarch64` | `rustc 1.85.1 (e71f9a9a9 2025-01-27)` | 28.25 GB/s +x86_64 server | AWS EC2 `m7i.8xlarge` | `Linux 6.8.0-1021-aws x86_64` | `rustc 1.85.1 (e71f9a9a9 2025-01-27)` | 10.33 GB/s and this implementation of ChalametPIR is compiled with specified compiler, in `optimized` profile. See [Cargo.toml](./Cargo.toml). @@ -44,22 +42,34 @@ Step | `(a)` Time Taken on `aarch64` server | `(b)` Time Taken on `x86_64` serve `server_respond` | 18.01 milliseconds | 32.16 milliseconds | 0.56 `client_process_response` | 11.73 microseconds | 16.75 microseconds | 0.7 -> [!NOTE] -> In above table, I show only the median timing measurements, while the DB is encoded using a 3 -wise XOR Binary Fuse Filter. For more results, with more database configurations, see benchmarking [section](#benchmarking) below. - So, the median bandwidth of the `server_respond` algorithm, which needs to traverse through the whole processed database, is - (a) For `aarch64` server: 53.82 GB/s - (b) For `x86_64` server: 30.12 GB/s +For demonstrating the effectiveness of offloading parts of the server-setup phase to a GPU, I benchmark it on AWS EC2 instance `g6e.8xlarge`, which features a NVIDIA L40S Tensor Core GPU and $3^{rd}$ generation AMD EPYC CPUs. 
+ +Number of entries in DB | Key length | Value length | `(a)` Time taken to setup PIR server on CPU | `(b)` Time taken to setup PIR server, partially offloading to GPU | Ratio `a / b` +:-- | --: | --: | --: | --: | --: +$2^{16}$ | 32B | 1kB | 19.55 seconds | 19.39 seconds | 1.0 +$2^{18}$ | 32B | 1kB | 6.0 minutes | 2.23 minutes | 2.69 +$2^{20}$ | 32B | 1kB | 25.89 minutes | 25.58 seconds | 60.72 + +For small key-value databases, it is not worth offloading server-setup to the GPU, but for databases with entries >= $2^{18}$, it is recommended to enable `gpu` feature, when GPU is available. + +> [!NOTE] +> In both of above tables, I show only the median timing measurements, while the DB is encoded using a 3 -wise XOR Binary Fuse Filter. For more results, with more database configurations, see benchmarking [section](#benchmarking) below. + ## Prerequisites -Rust stable toolchain; see https://rustup.rs for installation guide. MSRV for this crate is 1.84.0. +Rust stable toolchain; see https://rustup.rs for installation guide. MSRV for this crate is 1.85.0. ```bash # While developing this library, I was using $ rustc --version -rustc 1.84.1 (e71f9a9a9 2025-01-27) +rustc 1.85.1 (e71f9a9a9 2025-01-27) ``` +If you plan to offload server-setup to GPU, you need to install Vulkan drivers and library for your target setup. I followed https://linux.how2shout.com/how-to-install-vulkan-on-ubuntu-24-04-or-22-04-lts-linux on Ubuntu 24.04 LTS, with Nvidia GPUs - it was easy to setup. + ## Testing The `chalamet_pir` library includes comprehensive tests to ensure functional correctness. @@ -69,8 +79,12 @@ The `chalamet_pir` library includes comprehensive tests to ensure functional cor To run the tests, go to the project's root directory and issue: ```bash -cargo test --profile test-release # Custom profile to make tests run faster! - # Default debug mode is too slow! +# Custom profile to make tests run faster! +# Default debug mode is too slow! 
+cargo test --profile test-release + +# For testing if offloading to GPU works as expected. +cargo test --features gpu --profile test-release ``` @@ -80,9 +94,12 @@ Performance benchmarks are included to evaluate the efficiency of the PIR scheme To run the benchmarks, execute the following command from the root of the project: ```bash -cargo bench --all-features --profile optimized # For benchmarking the online phase of the PIR, - # you need to enable feature `mutate_internal_client_state`, - # passing `--all-features` does that. +# For benchmarking the online phase of the PIR, +# you need to enable feature `mutate_internal_client_state`. +cargo bench --features mutate_internal_client_state --profile optimized + +# For benchmarking only the server-setup phase, offloaded to the GPU. +cargo bench --features gpu --profile optimized --bench offline_phase -q server_setup ``` > [!WARNING] @@ -101,7 +118,11 @@ First, add this library crate as a dependency in your Cargo.toml file. ```toml [dependencies] -chalamet_pir = "=0.4.0" +chalamet_pir = "=0.5.0" +# Or, if you want to offload server-setup to a GPU. 
+# chalamet_pir = { version = "=0.5.0", features = ["gpu"] } +rand = "=0.9.0" +rand_chacha = "=0.9.0" ``` Then, let's code a very simple keyword PIR scheme: diff --git a/shaders/mat_transpose.glsl b/shaders/mat_transpose.glsl new file mode 100644 index 0000000..48e1b4d --- /dev/null +++ b/shaders/mat_transpose.glsl @@ -0,0 +1,37 @@ +#version 460 +#pragma shader_stage(compute) + +layout(local_size_x = 8, local_size_y = 8, local_size_z = 1) in; + +layout(set = 0, binding = 0) buffer readonly MatrixA { + uint rows; + uint cols; + uint[] elems; +} +matrix_a; + +layout(set = 0, binding = 1) buffer writeonly MatrixB { + uint rows; + uint cols; + uint[] elems; +} +matrix_b; + +void main() { + const uint row_idx = gl_GlobalInvocationID.x; + const uint col_idx = gl_GlobalInvocationID.y; + + if (row_idx >= matrix_a.rows || col_idx >= matrix_a.cols) { + return; + } + + if ((row_idx == 0) && (col_idx == 0)) { + matrix_b.rows = matrix_a.cols; + matrix_b.cols = matrix_a.rows; + } + + const uint src_index = row_idx * matrix_a.cols + col_idx; + const uint dst_index = col_idx * matrix_a.rows + row_idx; + + matrix_b.elems[dst_index] = matrix_a.elems[src_index]; +} diff --git a/shaders/mat_x_mat.glsl b/shaders/mat_x_mat.glsl new file mode 100644 index 0000000..8dae1fe --- /dev/null +++ b/shaders/mat_x_mat.glsl @@ -0,0 +1,47 @@ +#version 460 +#pragma shader_stage(compute) + +layout(local_size_x = 8, local_size_y = 8, local_size_z = 1) in; + +layout(set = 0, binding = 0) buffer readonly MatrixA { + uint rows; + uint cols; + uint[] elems; +} +matrix_a; + +layout(set = 0, binding = 1) buffer readonly MatrixB { + uint rows; + uint cols; + uint[] elems; +} +matrix_b; + +layout(set = 0, binding = 2) buffer writeonly MatrixC { + uint rows; + uint cols; + uint[] elems; +} +matrix_c; + +void main() { + const uint row_idx = gl_GlobalInvocationID.x; + const uint col_idx = gl_GlobalInvocationID.y; + + if (row_idx >= matrix_a.rows || col_idx >= matrix_b.cols) { + return; + } + + if ((row_idx == 0) 
&& (col_idx == 0)) { + matrix_c.rows = matrix_a.rows; + matrix_c.cols = matrix_b.cols; + } + + uint sum = 0; + for (uint i = 0; i < matrix_a.cols; i++) { + sum += matrix_a.elems[row_idx * matrix_a.cols + i] * + matrix_b.elems[i * matrix_b.cols + col_idx]; + } + + matrix_c.elems[row_idx * matrix_b.cols + col_idx] = sum; +} diff --git a/src/client.rs b/src/client.rs index 02db179..9db66e2 100644 --- a/src/client.rs +++ b/src/client.rs @@ -42,7 +42,7 @@ impl Client { let filter = BinaryFuseFilter::from_bytes(filter_param_bytes)?; let pub_mat_a_num_rows = LWE_DIMENSION; - let pub_mat_a_num_cols = filter.num_fingerprints; + let pub_mat_a_num_cols = filter.num_fingerprints as u32; let pub_mat_a = Matrix::generate_from_seed(pub_mat_a_num_rows, pub_mat_a_num_cols, seed_μ)?; let hint_mat_m = Matrix::from_bytes(hint_bytes)?; @@ -225,7 +225,7 @@ impl Client { let hashed_key = binary_fuse_filter::hash_of_key(key); let hash = binary_fuse_filter::mix256(&hashed_key, &self.filter.seed); - let recovered_row = (0..response_vector.num_cols()) + let recovered_row = (0..response_vector.num_cols() as usize) .map(|idx| { let unscaled_res = response_vector[(0, idx)].wrapping_sub(secret_vec_c[(0, idx)]); diff --git a/src/lib.rs b/src/lib.rs index 24f3a66..251e16f 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -8,6 +8,7 @@ //! * **Secure Private Information Retrieval:** Allows clients to retrieve value from a PIR server without disclosing corresponding key. Server learns neither the value nor the queried key. //! * **Error Handling:** Comprehensive error handling to catch and report issues during setup, query generation, and response processing. //! * **Flexibility:** Supports both 3-wise and 4-wise XOR Binary Fuse Filters, allowing a choice between trade-offs in client/server computation and communication costs. +//! 
* **Efficient:** It supports offloading parts of the server-setup phase to a GPU, using Vulkan Compute API, which can drastically reduce time taken to setup PIR server, for large key-value databases. //! //! ## Usage //! @@ -18,7 +19,9 @@ //! //! ```toml //! [dependencies] -//! chalametpir = "=0.4.0" +//! chalametpir = "=0.5.0" +//! # Or, if you want to offload server-setup to GPU. +//! # chalamet_pir = { version = "=0.5.0", features = ["gpu"] } //! rand = "=0.9.0" //! rand_chacha = "=0.9.0" //! ``` diff --git a/src/pir_internals/error.rs b/src/pir_internals/error.rs index 00de65c..19cfbaa 100644 --- a/src/pir_internals/error.rs +++ b/src/pir_internals/error.rs @@ -6,6 +6,21 @@ use std::{error::Error, fmt::Display}; /// It includes errors related to matrix operations, binary fuse filter operations, and PIR operations. #[derive(Debug, PartialEq)] pub enum ChalametPIRError { + // GPU + VulkanLibraryNotFound, + VulkanInstanceCreationFailed, + VulkanPhysicalDeviceNotFound, + VulkanDeviceCreationFailed, + VulkanBufferCreationFailed, + VulkanCommandBufferBuilderCreationFailed, + VulkanCommandBufferRecordingFailed, + VulkanCommandBufferBuildingFailed, + VulkanCommandBufferExecutionFailed, + VulkanReadingFromBufferFailed, + VulkanComputeShaderLoadingFailed, + VulkanComputePipelineCreationFailed, + VulkanDescriptorSetCreationFailed, + // Matrix InvalidMatrixDimension, IncompatibleDimensionForMatrixMultiplication, @@ -36,6 +51,20 @@ pub enum ChalametPIRError { impl Display for ChalametPIRError { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { + Self::VulkanLibraryNotFound => write!(f, "Failed to load the default Vulkan library for the system."), + Self::VulkanInstanceCreationFailed => write!(f, "Failed to create a new instance of Vulkan."), + Self::VulkanPhysicalDeviceNotFound => write!(f, "Failed to find a compatible Vulkan physical device."), + Self::VulkanDeviceCreationFailed => write!(f, "Failed to create a Vulkan device and associated 
queue."), + Self::VulkanBufferCreationFailed => write!(f, "Failed to create a Vulkan transfer source buffer."), + Self::VulkanCommandBufferBuilderCreationFailed => write!(f, "Failed to create a Vulkan command buffer builder."), + Self::VulkanCommandBufferRecordingFailed => write!(f, "Failed to record command in a Vulkan command buffer."), + Self::VulkanCommandBufferBuildingFailed => write!(f, "Failed to build a Vulkan command buffer."), + Self::VulkanCommandBufferExecutionFailed => write!(f, "Failed to execute the Vulkan command buffer."), + Self::VulkanReadingFromBufferFailed => write!(f, "Failed to read from Vulkan buuffer."), + Self::VulkanComputeShaderLoadingFailed => write!(f, "Failed to load Vulkan compute shader module."), + Self::VulkanComputePipelineCreationFailed => write!(f, "Failed to create Vulkan compute pipeline."), + Self::VulkanDescriptorSetCreationFailed => write!(f, "Failed to create descriptor set for Vulkan compute pipeline."), + Self::InvalidMatrixDimension => write!(f, "The number of rows and columns in the matrix must be non-zero."), Self::IncompatibleDimensionForMatrixMultiplication => write!(f, "The matrix dimensions do not allow multiplication."), Self::IncompatibleDimensionForMatrixAddition => write!(f, "The matrix dimensions do not allow addition."), diff --git a/src/pir_internals/gpu.rs b/src/pir_internals/gpu.rs new file mode 100644 index 0000000..d22abab --- /dev/null +++ b/src/pir_internals/gpu.rs @@ -0,0 +1,288 @@ +pub use std::sync::Arc; +pub use vulkano::{ + buffer::Subbuffer, + command_buffer::allocator::StandardCommandBufferAllocator, + device::{Device, Queue}, + memory::allocator::StandardMemoryAllocator, +}; + +use super::{mat_transpose_shader, mat_x_mat_shader, matrix::Matrix}; +use crate::ChalametPIRError; +use vulkano::{ + VulkanLibrary, + buffer::{Buffer, BufferCreateInfo, BufferUsage}, + command_buffer::{AutoCommandBufferBuilder, CommandBufferUsage, CopyBufferInfo, PrimaryCommandBufferAbstract}, + 
descriptor_set::{DescriptorSet, WriteDescriptorSet, allocator::StandardDescriptorSetAllocator}, + device::{DeviceCreateInfo, DeviceExtensions, QueueCreateInfo, QueueFlags, physical::PhysicalDeviceType}, + instance::{Instance, InstanceCreateFlags, InstanceCreateInfo}, + memory::allocator::{AllocationCreateInfo, MemoryTypeFilter}, + pipeline::{ + ComputePipeline, Pipeline, PipelineBindPoint, PipelineLayout, PipelineShaderStageCreateInfo, compute::ComputePipelineCreateInfo, + layout::PipelineDescriptorSetLayoutCreateInfo, + }, + sync::GpuFuture, +}; + +pub fn setup_gpu() -> Result<(Arc, Arc, Arc, Arc), ChalametPIRError> { + let library = VulkanLibrary::new().map_err(|_| ChalametPIRError::VulkanLibraryNotFound)?; + let instance = Instance::new( + library, + InstanceCreateInfo { + flags: InstanceCreateFlags::ENUMERATE_PORTABILITY, + ..Default::default() + }, + ) + .map_err(|_| ChalametPIRError::VulkanInstanceCreationFailed)?; + + let device_extensions = DeviceExtensions { + khr_storage_buffer_storage_class: true, + ..DeviceExtensions::empty() + }; + + let (physical_device, queue_family_index) = instance + .enumerate_physical_devices() + .map_err(|_| ChalametPIRError::VulkanPhysicalDeviceNotFound)? 
+ .filter(|p| p.supported_extensions().contains(&device_extensions)) + .filter_map(|p| { + p.queue_family_properties() + .iter() + .position(|q| q.queue_flags.intersects(QueueFlags::COMPUTE | QueueFlags::TRANSFER)) + .map(|i| (p, i as u32)) + }) + .min_by_key(|(p, _)| match p.properties().device_type { + PhysicalDeviceType::DiscreteGpu => 0, + PhysicalDeviceType::IntegratedGpu => 1, + PhysicalDeviceType::VirtualGpu => 2, + PhysicalDeviceType::Cpu => 3, + PhysicalDeviceType::Other => 4, + _ => 5, + }) + .ok_or(ChalametPIRError::VulkanPhysicalDeviceNotFound)?; + + let (device, mut queues) = Device::new( + physical_device, + DeviceCreateInfo { + enabled_extensions: device_extensions, + queue_create_infos: vec![QueueCreateInfo { + queue_family_index, + ..Default::default() + }], + ..Default::default() + }, + ) + .map_err(|_| ChalametPIRError::VulkanDeviceCreationFailed)?; + let queue = queues.next().ok_or(ChalametPIRError::VulkanDeviceCreationFailed)?; + + let memory_allocator = Arc::new(StandardMemoryAllocator::new_default(device.clone())); + let command_buffer_allocator = Arc::new(StandardCommandBufferAllocator::new(device.clone(), Default::default())); + + Ok((device, queue, memory_allocator, command_buffer_allocator)) +} + +pub fn transfer_mat_to_device( + queue: Arc, + mem_alloc: Arc, + cmd_buf_alloc: Arc, + matrix: Matrix, +) -> Result, ChalametPIRError> { + let matrix_as_bytes = matrix.to_bytes(); + let matrix_byte_len = matrix_as_bytes.len() as u64; + + let src_buf = Buffer::from_iter( + mem_alloc.clone(), + BufferCreateInfo { + usage: BufferUsage::TRANSFER_SRC, + ..Default::default() + }, + AllocationCreateInfo { + memory_type_filter: MemoryTypeFilter::HOST_SEQUENTIAL_WRITE | MemoryTypeFilter::PREFER_DEVICE, + ..Default::default() + }, + matrix_as_bytes, + ) + .map_err(|_| ChalametPIRError::VulkanBufferCreationFailed)?; + + let dst_buf = Buffer::new_slice::( + mem_alloc.clone(), + BufferCreateInfo { + usage: BufferUsage::STORAGE_BUFFER | 
BufferUsage::TRANSFER_DST, + ..Default::default() + }, + AllocationCreateInfo { + memory_type_filter: MemoryTypeFilter::PREFER_DEVICE, + ..Default::default() + }, + matrix_byte_len, + ) + .map_err(|_| ChalametPIRError::VulkanBufferCreationFailed)?; + + let cmd_buf = { + let mut builder = AutoCommandBufferBuilder::primary(cmd_buf_alloc, queue.queue_family_index(), CommandBufferUsage::OneTimeSubmit) + .map_err(|_| ChalametPIRError::VulkanCommandBufferBuilderCreationFailed)?; + + builder + .copy_buffer(CopyBufferInfo::buffers(src_buf, dst_buf.clone())) + .map_err(|_| ChalametPIRError::VulkanCommandBufferRecordingFailed)?; + + builder.build().map_err(|_| ChalametPIRError::VulkanCommandBufferBuildingFailed)? + }; + + cmd_buf + .execute(queue) + .map_err(|_| ChalametPIRError::VulkanCommandBufferExecutionFailed)? + .then_signal_fence_and_flush() + .map_err(|_| ChalametPIRError::VulkanCommandBufferExecutionFailed)? + .wait(None) + .map_err(|_| ChalametPIRError::VulkanCommandBufferExecutionFailed)?; + + Ok(dst_buf) +} + +pub fn get_empty_host_readable_buffer(memory_allocator: Arc, byte_len: u64) -> Result, ChalametPIRError> { + Buffer::new_slice::( + memory_allocator.clone(), + BufferCreateInfo { + usage: BufferUsage::STORAGE_BUFFER, + ..Default::default() + }, + AllocationCreateInfo { + memory_type_filter: MemoryTypeFilter::HOST_SEQUENTIAL_WRITE | MemoryTypeFilter::PREFER_DEVICE, + ..Default::default() + }, + byte_len, + ) + .map_err(|_| ChalametPIRError::VulkanBufferCreationFailed) +} + +pub fn mat_x_mat( + device: Arc, + queue: Arc, + command_buffer_allocator: Arc, + left_mat: Subbuffer<[u8]>, + rhs_mat: Subbuffer<[u8]>, + res_mat: Subbuffer<[u8]>, + wg_count: [u32; 3], +) -> Result<(), ChalametPIRError> { + let pipeline = { + let cs = mat_x_mat_shader::load(device.clone()).map_err(|_| ChalametPIRError::VulkanComputeShaderLoadingFailed)?; + let cs_entry_point = cs.entry_point("main").ok_or(ChalametPIRError::VulkanComputeShaderLoadingFailed)?; + let compute_stage = 
PipelineShaderStageCreateInfo::new(cs_entry_point); + + let layout = PipelineLayout::new( + device.clone(), + PipelineDescriptorSetLayoutCreateInfo::from_stages([&compute_stage]) + .into_pipeline_layout_create_info(device.clone()) + .map_err(|_| ChalametPIRError::VulkanComputePipelineCreationFailed)?, + ) + .map_err(|_| ChalametPIRError::VulkanComputePipelineCreationFailed)?; + + ComputePipeline::new(device.clone(), None, ComputePipelineCreateInfo::stage_layout(compute_stage, layout.clone())) + .map_err(|_| ChalametPIRError::VulkanComputePipelineCreationFailed)? + }; + + let descriptor_set_allocator = Arc::new(StandardDescriptorSetAllocator::new(device.clone(), Default::default())); + let descriptor_set_layout = pipeline.layout().set_layouts()[0].clone(); + let descriptor_set = DescriptorSet::new( + descriptor_set_allocator, + descriptor_set_layout, + [ + WriteDescriptorSet::buffer(0, left_mat), + WriteDescriptorSet::buffer(1, rhs_mat), + WriteDescriptorSet::buffer(2, res_mat), + ], + [], + ) + .map_err(|_| ChalametPIRError::VulkanDescriptorSetCreationFailed)?; + + let command_buffer = { + let mut command_buffer_builder = + AutoCommandBufferBuilder::primary(command_buffer_allocator, queue.queue_family_index(), CommandBufferUsage::OneTimeSubmit) + .map_err(|_| ChalametPIRError::VulkanCommandBufferBuilderCreationFailed)?; + + unsafe { + command_buffer_builder + .bind_pipeline_compute(pipeline.clone()) + .map_err(|_| ChalametPIRError::VulkanCommandBufferRecordingFailed)? + .bind_descriptor_sets(PipelineBindPoint::Compute, pipeline.layout().clone(), 0, descriptor_set) + .map_err(|_| ChalametPIRError::VulkanCommandBufferRecordingFailed)? + .dispatch(wg_count) + .map_err(|_| ChalametPIRError::VulkanCommandBufferRecordingFailed)?; + } + + command_buffer_builder + .build() + .map_err(|_| ChalametPIRError::VulkanCommandBufferBuildingFailed)? + }; + + command_buffer + .execute(queue.clone()) + .map_err(|_| ChalametPIRError::VulkanCommandBufferExecutionFailed)? 
+ .then_signal_fence_and_flush() + .map_err(|_| ChalametPIRError::VulkanCommandBufferExecutionFailed)? + .wait(None) + .map_err(|_| ChalametPIRError::VulkanCommandBufferExecutionFailed) +} + +pub fn mat_transpose( + device: Arc, + queue: Arc, + command_buffer_allocator: Arc, + orig_mat: Subbuffer<[u8]>, + res_mat: Subbuffer<[u8]>, + wg_count: [u32; 3], +) -> Result<(), ChalametPIRError> { + let pipeline = { + let cs = mat_transpose_shader::load(device.clone()).map_err(|_| ChalametPIRError::VulkanComputeShaderLoadingFailed)?; + let cs_entry_point = cs.entry_point("main").ok_or(ChalametPIRError::VulkanComputeShaderLoadingFailed)?; + let compute_stage = PipelineShaderStageCreateInfo::new(cs_entry_point); + + let layout = PipelineLayout::new( + device.clone(), + PipelineDescriptorSetLayoutCreateInfo::from_stages([&compute_stage]) + .into_pipeline_layout_create_info(device.clone()) + .map_err(|_| ChalametPIRError::VulkanComputePipelineCreationFailed)?, + ) + .map_err(|_| ChalametPIRError::VulkanComputePipelineCreationFailed)?; + + ComputePipeline::new(device.clone(), None, ComputePipelineCreateInfo::stage_layout(compute_stage, layout.clone())) + .map_err(|_| ChalametPIRError::VulkanComputePipelineCreationFailed)? 
+ }; + + let descriptor_set_allocator = Arc::new(StandardDescriptorSetAllocator::new(device.clone(), Default::default())); + let descriptor_set_layout = pipeline.layout().set_layouts()[0].clone(); + let descriptor_set = DescriptorSet::new( + descriptor_set_allocator, + descriptor_set_layout, + [WriteDescriptorSet::buffer(0, orig_mat), WriteDescriptorSet::buffer(1, res_mat)], + [], + ) + .map_err(|_| ChalametPIRError::VulkanDescriptorSetCreationFailed)?; + + let command_buffer = { + let mut command_buffer_builder = + AutoCommandBufferBuilder::primary(command_buffer_allocator, queue.queue_family_index(), CommandBufferUsage::OneTimeSubmit) + .map_err(|_| ChalametPIRError::VulkanCommandBufferBuilderCreationFailed)?; + + unsafe { + command_buffer_builder + .bind_pipeline_compute(pipeline.clone()) + .map_err(|_| ChalametPIRError::VulkanCommandBufferRecordingFailed)? + .bind_descriptor_sets(PipelineBindPoint::Compute, pipeline.layout().clone(), 0, descriptor_set) + .map_err(|_| ChalametPIRError::VulkanCommandBufferRecordingFailed)? + .dispatch(wg_count) + .map_err(|_| ChalametPIRError::VulkanCommandBufferRecordingFailed)?; + } + + command_buffer_builder + .build() + .map_err(|_| ChalametPIRError::VulkanCommandBufferBuildingFailed)? + }; + + command_buffer + .execute(queue.clone()) + .map_err(|_| ChalametPIRError::VulkanCommandBufferExecutionFailed)? + .then_signal_fence_and_flush() + .map_err(|_| ChalametPIRError::VulkanCommandBufferExecutionFailed)? + .wait(None) + .map_err(|_| ChalametPIRError::VulkanCommandBufferExecutionFailed) +} diff --git a/src/pir_internals/mat_transpose_shader.rs b/src/pir_internals/mat_transpose_shader.rs new file mode 100644 index 0000000..d49b087 --- /dev/null +++ b/src/pir_internals/mat_transpose_shader.rs @@ -0,0 +1,5 @@ +vulkano_shaders::shader! 
{ + ty: "compute", + path: "./shaders/mat_transpose.glsl", + vulkan_version: "1.2", +} diff --git a/src/pir_internals/mat_x_mat_shader.rs b/src/pir_internals/mat_x_mat_shader.rs new file mode 100644 index 0000000..eb2a928 --- /dev/null +++ b/src/pir_internals/mat_x_mat_shader.rs @@ -0,0 +1,5 @@ +vulkano_shaders::shader! { + ty: "compute", + path: "./shaders/mat_x_mat.glsl", + vulkan_version: "1.2", +} diff --git a/src/pir_internals/matrix.rs b/src/pir_internals/matrix.rs index 236457b..4f78317 100644 --- a/src/pir_internals/matrix.rs +++ b/src/pir_internals/matrix.rs @@ -19,8 +19,8 @@ use super::error::ChalametPIRError; #[derive(Clone, Debug, PartialEq)] pub struct Matrix { - rows: usize, - cols: usize, + rows: u32, + cols: u32, elems: Vec, } @@ -36,12 +36,12 @@ impl Matrix { /// /// * `Result` - A new matrix if the input is valid (rows and cols are positive). /// Returns an error if either rows or cols is zero. - pub fn new(rows: usize, cols: usize) -> Result { + pub fn new(rows: u32, cols: u32) -> Result { if branch_opt_util::likely((rows > 0) && (cols > 0)) { Ok(Matrix { rows, cols, - elems: vec![0; rows * cols], + elems: vec![0; (rows * cols) as usize], }) } else { Err(ChalametPIRError::InvalidMatrixDimension) @@ -60,9 +60,9 @@ impl Matrix { /// /// * `Result` - A new matrix if the input is valid (rows and cols are positive and the number of values matches the number of required elements). /// Returns an error if either rows or cols is zero, or if the number of values does not match the number of required elements. 
- pub fn from_values(rows: usize, cols: usize, values: Vec) -> Result { + pub fn from_values(rows: u32, cols: u32, values: Vec) -> Result { if branch_opt_util::likely((rows > 0) && (cols > 0)) { - if branch_opt_util::likely(rows * cols == values.len()) { + if branch_opt_util::likely((rows * cols) as usize == values.len()) { Ok(Matrix { rows, cols, elems: values }) } else { Err(ChalametPIRError::InvalidNumberOfElementsInMatrix) @@ -73,17 +73,21 @@ impl Matrix { } #[inline(always)] - pub const fn num_rows(&self) -> usize { + pub const fn num_rows(&self) -> u32 { self.rows } #[inline(always)] - pub const fn num_cols(&self) -> usize { + pub const fn num_cols(&self) -> u32 { self.cols } #[inline(always)] pub fn num_elems(&self) -> usize { self.elems.len() } + #[inline(always)] + pub fn num_bytes(&self) -> usize { + std::mem::size_of_val(&self.rows) + std::mem::size_of_val(&self.cols) + std::mem::size_of::() * (self.rows * self.cols) as usize + } /// Performs the multiplication of a row vector (1xN matrix) by the transpose of a matrix (MxN). /// @@ -103,13 +107,13 @@ impl Matrix { let res_num_rows = self.rows; let res_num_cols = rhs.rows; - let mut res_elems = vec![0u32; res_num_rows * res_num_cols]; + let mut res_elems = vec![0u32; (res_num_rows * res_num_cols) as usize]; res_elems.par_iter_mut().enumerate().for_each(|(lin_idx, v)| { let r_idx = 0; let c_idx = lin_idx; - *v = (0..self.cols).fold(0u32, |acc, k| acc.wrapping_add(self[(r_idx, k)].wrapping_mul(rhs[(c_idx, k)]))); + *v = (0..self.cols as usize).fold(0u32, |acc, k| acc.wrapping_add(self[(r_idx, k)].wrapping_mul(rhs[(c_idx, k)]))); }); Matrix::from_values(res_num_rows, res_num_cols, res_elems) @@ -126,14 +130,14 @@ impl Matrix { /// * `Result` - A new identity matrix if the input is valid (rows is positive). /// Returns an error if rows is zero. 
#[cfg(test)] - pub fn identity(rows: usize) -> Result { + pub fn identity(rows: u32) -> Result { if branch_opt_util::unlikely(rows == 0) { return Err(ChalametPIRError::InvalidMatrixDimension); } let mut mat = Matrix::new(rows, rows)?; - (0..rows).for_each(|idx| { + (0..mat.rows as usize).for_each(|idx| { mat[(idx, idx)] = 1; }); @@ -148,8 +152,8 @@ impl Matrix { pub fn transpose(&self) -> Matrix { let mut res = unsafe { Matrix::new(self.cols, self.rows).unwrap_unchecked() }; - (0..self.cols) - .flat_map(|ridx| (0..self.rows).map(move |cidx| (ridx, cidx))) + (0..self.cols as usize) + .flat_map(|ridx| (0..self.rows as usize).map(move |cidx| (ridx, cidx))) .for_each(|(ridx, cidx)| { res[(ridx, cidx)] = self[(cidx, ridx)]; }); @@ -169,12 +173,12 @@ impl Matrix { /// /// * `Result` - A new matrix if the input is valid (rows and cols are positive). /// Returns an error if either rows or cols is zero. - pub fn generate_from_seed(rows: usize, cols: usize, seed: &[u8; SEED_BYTE_LEN]) -> Result { + pub fn generate_from_seed(rows: u32, cols: u32, seed: &[u8; SEED_BYTE_LEN]) -> Result { let mut hasher = TurboShake128::default(); hasher.absorb(seed); hasher.finalize::<{ TurboShake128::DEFAULT_DOMAIN_SEPARATOR }>(); - let mut elems = vec![0u32; rows * cols]; + let mut elems = vec![0u32; (rows * cols) as usize]; let elems_byte_len = elems.len() * std::mem::size_of::(); unsafe { @@ -200,7 +204,7 @@ impl Matrix { /// /// * `Result` - A new row/ column vector if the input is valid (rows or cols is 1). /// Returns an error if neither rows nor cols is 1. 
- pub fn sample_from_uniform_ternary_dist(rows: usize, cols: usize) -> Result { + pub fn sample_from_uniform_ternary_dist(rows: u32, cols: u32) -> Result { if branch_opt_util::unlikely(!(rows == 1 || cols == 1)) { return Err(ChalametPIRError::InvalidDimensionForVector); } @@ -211,7 +215,7 @@ impl Matrix { let mut rng = ChaCha8Rng::from_os_rng(); let mut vec = Matrix::new(rows, cols)?; - let num_elems = rows * cols; + let num_elems = vec.num_elems(); let mut elem_idx = 0; while branch_opt_util::likely(elem_idx < num_elems) { @@ -318,8 +322,8 @@ impl Matrix { let max_value_byte_len = unsafe { db.values().map(|v| v.len()).max().unwrap_unchecked() }; let max_value_bit_len = max_value_byte_len * 8; - let rows = filter.num_fingerprints; - let cols: usize = (HASHED_KEY_BIT_LEN + max_value_bit_len + 8).div_ceil(mat_elem_bit_len); + let rows = filter.num_fingerprints as u32; + let cols = (HASHED_KEY_BIT_LEN + max_value_bit_len + 8).div_ceil(mat_elem_bit_len) as u32; let mut mat = Matrix::new(rows, cols)?; let mat_elem_mask = (1u32 << mat_elem_bit_len) - 1; @@ -346,13 +350,13 @@ impl Matrix { let mat_row_idx1 = h012[found + 1] as usize; let mat_row_idx2 = h012[found + 2] as usize; - let elems = (0..cols) + let elems = (0..cols as usize) .map(|elem_idx| { - let f1 = mat.elems[mat_row_idx1 * cols + elem_idx]; + let f1 = mat.elems[mat_row_idx1 * cols as usize + elem_idx]; (elem_idx, row[elem_idx].wrapping_sub(f1)) }) .map(|(elem_idx, elem)| { - let f2 = mat.elems[mat_row_idx2 * cols + elem_idx]; + let f2 = mat.elems[mat_row_idx2 * cols as usize + elem_idx]; (elem_idx, elem.wrapping_sub(f2) & mat_elem_mask) }) .map(|(elem_idx, elem)| { @@ -361,8 +365,8 @@ impl Matrix { }) .collect::>(); - let fingerprints_begin_at = mat_row_idx0 * cols; - let fingerprints_end_at = fingerprints_begin_at + cols; + let fingerprints_begin_at = mat_row_idx0 * cols as usize; + let fingerprints_end_at = fingerprints_begin_at + cols as usize; 
mat.elems[fingerprints_begin_at..fingerprints_end_at].copy_from_slice(&elems); } @@ -396,10 +400,10 @@ impl Matrix { let (h0, h1, h2) = binary_fuse_filter::hash_batch_for_3_wise_xor_filter(hash, filter.segment_length, filter.segment_count_length); - let recovered_row = (0..self.cols) - .map(|elem_idx| (elem_idx, self.elems[h0 as usize * self.cols + elem_idx])) - .map(|(elem_idx, elem)| (elem_idx, elem.wrapping_add(self.elems[h1 as usize * self.cols + elem_idx]))) - .map(|(elem_idx, elem)| (elem_idx, elem.wrapping_add(self.elems[h2 as usize * self.cols + elem_idx]))) + let recovered_row = (0..self.cols as usize) + .map(|elem_idx| (elem_idx, self.elems[h0 as usize * self.cols as usize + elem_idx])) + .map(|(elem_idx, elem)| (elem_idx, elem.wrapping_add(self.elems[h1 as usize * self.cols as usize + elem_idx]))) + .map(|(elem_idx, elem)| (elem_idx, elem.wrapping_add(self.elems[h2 as usize * self.cols as usize + elem_idx]))) .map(|(elem_idx, elem)| elem.wrapping_add((binary_fuse_filter::mix(hash, elem_idx as u64) as u32) & mat_elem_mask) & mat_elem_mask) .collect::>(); @@ -450,8 +454,8 @@ impl Matrix { let max_value_byte_len = unsafe { db.values().map(|v| v.len()).max().unwrap_unchecked() }; let max_value_bit_len = max_value_byte_len * 8; - let rows = filter.num_fingerprints; - let cols: usize = (HASHED_KEY_BIT_LEN + max_value_bit_len + 8).div_ceil(mat_elem_bit_len); + let rows = filter.num_fingerprints as u32; + let cols = (HASHED_KEY_BIT_LEN + max_value_bit_len + 8).div_ceil(mat_elem_bit_len) as u32; let mut mat = Matrix::new(rows, cols)?; let mat_elem_mask = (1u32 << mat_elem_bit_len) - 1; @@ -481,17 +485,17 @@ impl Matrix { let mat_row_idx2 = h0123[found + 2] as usize; let mat_row_idx3 = h0123[found + 3] as usize; - let elems = (0..cols) + let elems = (0..cols as usize) .map(|elem_idx| { - let f1 = mat.elems[mat_row_idx1 * cols + elem_idx]; + let f1 = mat.elems[mat_row_idx1 * cols as usize + elem_idx]; (elem_idx, row[elem_idx].wrapping_sub(f1)) }) .map(|(elem_idx, 
elem)| { - let f2 = mat.elems[mat_row_idx2 * cols + elem_idx]; + let f2 = mat.elems[mat_row_idx2 * cols as usize + elem_idx]; (elem_idx, elem.wrapping_sub(f2) & mat_elem_mask) }) .map(|(elem_idx, elem)| { - let f2 = mat.elems[mat_row_idx3 * cols + elem_idx]; + let f2 = mat.elems[mat_row_idx3 * cols as usize + elem_idx]; (elem_idx, elem.wrapping_sub(f2) & mat_elem_mask) }) .map(|(elem_idx, elem)| { @@ -500,8 +504,8 @@ impl Matrix { }) .collect::>(); - let fingerprints_begin_at = mat_row_idx0 * cols; - let fingerprints_end_at = fingerprints_begin_at + cols; + let fingerprints_begin_at = mat_row_idx0 * cols as usize; + let fingerprints_end_at = fingerprints_begin_at + cols as usize; mat.elems[fingerprints_begin_at..fingerprints_end_at].copy_from_slice(&elems); } @@ -535,11 +539,11 @@ impl Matrix { let (h0, h1, h2, h3) = binary_fuse_filter::hash_batch_for_4_wise_xor_filter(hash, filter.segment_length, filter.segment_count_length); - let recovered_row = (0..self.cols) - .map(|elem_idx| (elem_idx, self.elems[h0 as usize * self.cols + elem_idx])) - .map(|(elem_idx, elem)| (elem_idx, elem.wrapping_add(self.elems[h1 as usize * self.cols + elem_idx]))) - .map(|(elem_idx, elem)| (elem_idx, elem.wrapping_add(self.elems[h2 as usize * self.cols + elem_idx]))) - .map(|(elem_idx, elem)| (elem_idx, elem.wrapping_add(self.elems[h3 as usize * self.cols + elem_idx]))) + let recovered_row = (0..self.cols as usize) + .map(|elem_idx| (elem_idx, self.elems[h0 as usize * self.cols as usize + elem_idx])) + .map(|(elem_idx, elem)| (elem_idx, elem.wrapping_add(self.elems[h1 as usize * self.cols as usize + elem_idx]))) + .map(|(elem_idx, elem)| (elem_idx, elem.wrapping_add(self.elems[h2 as usize * self.cols as usize + elem_idx]))) + .map(|(elem_idx, elem)| (elem_idx, elem.wrapping_add(self.elems[h3 as usize * self.cols as usize + elem_idx]))) .map(|(elem_idx, elem)| elem.wrapping_add((binary_fuse_filter::mix(hash, elem_idx as u64) as u32) & mat_elem_mask) & mat_elem_mask) .collect::>(); @@ 
-567,7 +571,7 @@ impl Matrix { } pub fn to_bytes(&self) -> Vec { - let encoded_elems_byte_len = std::mem::size_of::() * self.rows * self.cols; + let encoded_elems_byte_len = std::mem::size_of::() * (self.rows * self.cols) as usize; let offset0 = 0; let offset1 = offset0 + std::mem::size_of_val(&self.rows); @@ -594,8 +598,8 @@ impl Matrix { pub fn from_bytes(bytes: &[u8]) -> Result { const OFFSET0: usize = 0; - const OFFSET1: usize = OFFSET0 + std::mem::size_of::(); - const OFFSET2: usize = OFFSET1 + std::mem::size_of::(); + const OFFSET1: usize = OFFSET0 + std::mem::size_of::(); + const OFFSET2: usize = OFFSET1 + std::mem::size_of::(); if branch_opt_util::unlikely(bytes.len() <= OFFSET2) { return Err(ChalametPIRError::FailedToDeserializeMatrixFromBytes); @@ -603,11 +607,11 @@ impl Matrix { let (rows, cols) = unsafe { ( - usize::from_le_bytes(bytes.get_unchecked(OFFSET0..OFFSET1).try_into().unwrap()), - usize::from_le_bytes(bytes.get_unchecked(OFFSET1..OFFSET2).try_into().unwrap()), + u32::from_le_bytes(bytes.get_unchecked(OFFSET0..OFFSET1).try_into().unwrap()), + u32::from_le_bytes(bytes.get_unchecked(OFFSET1..OFFSET2).try_into().unwrap()), ) }; - let num_elems = rows * cols; + let num_elems = (rows * cols) as usize; if branch_opt_util::unlikely(num_elems == 0) { return Err(ChalametPIRError::FailedToDeserializeMatrixFromBytes); @@ -638,7 +642,7 @@ impl Index<(usize, usize)> for Matrix { #[inline(always)] fn index(&self, index: (usize, usize)) -> &Self::Output { let (ridx, cidx) = index; - unsafe { self.elems.get_unchecked(ridx * self.cols + cidx) } + unsafe { self.elems.get_unchecked(ridx * self.cols as usize + cidx) } } } @@ -646,7 +650,7 @@ impl IndexMut<(usize, usize)> for Matrix { #[inline(always)] fn index_mut(&mut self, index: (usize, usize)) -> &mut Self::Output { let (ridx, cidx) = index; - unsafe { self.elems.get_unchecked_mut(ridx * self.cols + cidx) } + unsafe { self.elems.get_unchecked_mut(ridx * self.cols as usize + cidx) } } } @@ -667,13 +671,13 @@ 
impl<'b> Mul<&'b Matrix> for &Matrix { return Err(ChalametPIRError::IncompatibleDimensionForMatrixMultiplication); } - let mut res_elems = vec![0u32; self.rows * rhs.cols]; + let mut res_elems = vec![0u32; (self.rows * rhs.cols) as usize]; res_elems.par_iter_mut().enumerate().for_each(|(lin_idx, v)| { - let r_idx = lin_idx / rhs.cols; - let c_idx = lin_idx - r_idx * rhs.cols; + let r_idx = lin_idx / rhs.cols as usize; + let c_idx = lin_idx - r_idx * rhs.cols as usize; - *v = (0..self.cols).fold(0u32, |acc, k| acc.wrapping_add(self[(r_idx, k)].wrapping_mul(rhs[(k, c_idx)]))); + *v = (0..self.cols as usize).fold(0u32, |acc, k| acc.wrapping_add(self[(r_idx, k)].wrapping_mul(rhs[(k, c_idx)]))); }); Matrix::from_values(self.rows, rhs.cols, res_elems) @@ -697,7 +701,7 @@ impl<'b> Add<&'b Matrix> for &Matrix { return Err(ChalametPIRError::IncompatibleDimensionForMatrixAddition); } - let mut res_elems = vec![0u32; self.rows * rhs.cols]; + let mut res_elems = vec![0u32; (self.rows * rhs.cols) as usize]; res_elems.par_iter_mut().enumerate().for_each(|(lin_idx, v)| { *v = unsafe { self.elems.get_unchecked(lin_idx).wrapping_add(*rhs.elems.get_unchecked(lin_idx)) }; @@ -850,7 +854,7 @@ pub mod test { #[test_case(0, 1024 => matches Err(ChalametPIRError::InvalidMatrixDimension); "Number of rows must be greater than zero")] #[test_case(1024, 0 => matches Err(ChalametPIRError::InvalidMatrixDimension); "Number of columns must be greater than zero")] #[test_case(0, 0 => matches Err(ChalametPIRError::InvalidMatrixDimension); "Both number of rows and columns must be greater than zero")] - fn new_empty_matrix_constructor_api(num_rows: usize, num_cols: usize) -> Result { + fn new_empty_matrix_constructor_api(num_rows: u32, num_cols: u32) -> Result { Matrix::new(num_rows, num_cols) } @@ -859,13 +863,13 @@ pub mod test { #[test_case(1024, 0, vec![] => matches Err(ChalametPIRError::InvalidMatrixDimension); "Number of columns must be greater than zero")] #[test_case(0, 0, vec![] => matches 
Err(ChalametPIRError::InvalidMatrixDimension); "Both number of rows and columns must be greater than zero")] #[test_case(1024, 1024, vec![0u32; 1024 * 1024 -1] => matches Err(ChalametPIRError::InvalidNumberOfElementsInMatrix); "Number of elements must be equal to number of rows times number of columns")] - fn from_values_matrix_constructor_api(num_rows: usize, num_cols: usize, elems: Vec) -> Result { + fn from_values_matrix_constructor_api(num_rows: u32, num_cols: u32, elems: Vec) -> Result { Matrix::from_values(num_rows, num_cols, elems) } #[test_case((1024,1),(1,1024) => matches Ok(_); "Matrix multiplication should work for valid dimensions")] #[test_case((1024,1),(1024, 1) => matches Err(ChalametPIRError::IncompatibleDimensionForMatrixMultiplication); "Matrix multiplication should not work for incompatible dimensions")] - fn matrix_multiplication_failures(lhs_mat_dim: (usize, usize), rhs_mat_dim: (usize, usize)) -> Result { + fn matrix_multiplication_failures(lhs_mat_dim: (u32, u32), rhs_mat_dim: (u32, u32)) -> Result { let (lhs_mat_rows, lhs_mat_cols) = lhs_mat_dim; let lhs_mat = Matrix::new(lhs_mat_rows, lhs_mat_cols)?; @@ -877,7 +881,7 @@ pub mod test { #[test_case((1024,1),(1024, 1) => matches Ok(_); "Matrix addition should work for valid dimensions")] #[test_case((1024,1),(1, 1024) => matches Err(ChalametPIRError::IncompatibleDimensionForMatrixAddition); "Matrix addition should not work for incompatible dimensions")] - fn matrix_addition_failures(lhs_mat_dim: (usize, usize), rhs_mat_dim: (usize, usize)) -> Result { + fn matrix_addition_failures(lhs_mat_dim: (u32, u32), rhs_mat_dim: (u32, u32)) -> Result { let (lhs_mat_rows, lhs_mat_cols) = lhs_mat_dim; let lhs_mat = Matrix::new(lhs_mat_rows, lhs_mat_cols)?; @@ -890,8 +894,8 @@ pub mod test { #[test] fn matrix_multiplication_is_correct() { const NUM_ATTEMPT_MATRIX_MULTIPLICATIONS: usize = 100; - const MIN_MATRIX_DIM: usize = 1; - const MAX_MATRIX_DIM: usize = 1024; + const MIN_MATRIX_DIM: u32 = 1; + const 
MAX_MATRIX_DIM: u32 = 1024; let mut rng = ChaCha8Rng::from_os_rng(); @@ -920,8 +924,8 @@ pub mod test { #[test] fn row_vector_transposed_matrix_multiplication_works() { const NUM_ATTEMPT_VECTOR_MATRIX_MULTIPLICATIONS: usize = 100; - const MIN_ROW_VECTOR_DIM: usize = 1; - const MAX_ROW_VECTOR_DIM: usize = 1024; + const MIN_ROW_VECTOR_DIM: u32 = 1; + const MAX_ROW_VECTOR_DIM: u32 = 1024; let mut rng = ChaCha8Rng::from_os_rng(); @@ -934,9 +938,10 @@ pub mod test { let vec_num_cols = rng.random_range(MIN_ROW_VECTOR_DIM..=MAX_ROW_VECTOR_DIM); let mat_num_rows = vec_num_cols; let mat_num_cols = rng.random_range(MIN_ROW_VECTOR_DIM..=MAX_ROW_VECTOR_DIM); + let mat_num_elems = (mat_num_rows * mat_num_cols) as usize; let row_vector = Matrix::generate_from_seed(vec_num_rows, vec_num_cols, &seed).expect("Row vector must be generated from seed"); - let all_ones = Matrix::from_values(mat_num_rows, mat_num_cols, vec![1; mat_num_rows * mat_num_cols]).expect("Matrix of ones must be created"); + let all_ones = Matrix::from_values(mat_num_rows, mat_num_cols, vec![1; mat_num_elems]).expect("Matrix of ones must be created"); let transposed_all_ones = all_ones.transpose(); let res_row_vector = row_vector @@ -945,7 +950,9 @@ pub mod test { let expected_res_row_vector = { let sum_of_elems_in_row_vector = row_vector.elems.iter().fold(0u32, |acc, &cur| acc.wrapping_add(cur)); - Matrix::from_values(vec_num_rows, mat_num_cols, vec![sum_of_elems_in_row_vector; mat_num_cols]).expect("Expected row vector must be created") + let row_vec_elems = vec![sum_of_elems_in_row_vector; mat_num_cols as usize]; + + Matrix::from_values(vec_num_rows, mat_num_cols, row_vec_elems).expect("Expected row vector must be created") }; assert_eq!(expected_res_row_vector, res_row_vector); @@ -956,8 +963,8 @@ pub mod test { #[test] fn matrix_addition_is_correct() { const NUM_ATTEMPT_MATRIX_ADDITIONS: usize = 100; - const MIN_MATRIX_DIM: usize = 1; - const MAX_MATRIX_DIM: usize = 1024; + const MIN_MATRIX_DIM: u32 = 1; + 
const MAX_MATRIX_DIM: u32 = 1024; let mut rng = ChaCha8Rng::from_os_rng(); @@ -987,7 +994,7 @@ pub mod test { #[test_case(1024, 1024 => matches Err(ChalametPIRError::InvalidDimensionForVector); "Either number of rows or columns must be 1 in vector")] #[test_case(0, 1024 => matches Err(ChalametPIRError::InvalidDimensionForVector); "Number of rows in row vector must be 1")] #[test_case(1024, 0 => matches Err(ChalametPIRError::InvalidDimensionForVector); "Number of columns in column vector must be 1")] - fn sampling_from_uniform_ternary_dist_works(num_rows: usize, num_cols: usize) -> Result { + fn sampling_from_uniform_ternary_dist_works(num_rows: u32, num_cols: u32) -> Result { Matrix::sample_from_uniform_ternary_dist(num_rows, num_cols) } @@ -1012,8 +1019,8 @@ pub mod test { #[test] fn serialized_matrix_can_be_deserialized() { const NUM_ATTEMPT_MATRIX_SERIALIZATIONS: usize = 100; - const MIN_MATRIX_DIM: usize = 1; - const MAX_MATRIX_DIM: usize = 1024; + const MIN_MATRIX_DIM: u32 = 1; + const MAX_MATRIX_DIM: u32 = 1024; let mut rng = ChaCha8Rng::from_os_rng(); diff --git a/src/pir_internals/mod.rs b/src/pir_internals/mod.rs index 24566fd..edad805 100644 --- a/src/pir_internals/mod.rs +++ b/src/pir_internals/mod.rs @@ -4,3 +4,10 @@ pub mod error; pub mod matrix; pub mod params; pub mod serialization; + +#[cfg(feature = "gpu")] +pub mod gpu; +#[cfg(feature = "gpu")] +pub mod mat_transpose_shader; +#[cfg(feature = "gpu")] +pub mod mat_x_mat_shader; diff --git a/src/pir_internals/params.rs b/src/pir_internals/params.rs index f5087d9..f1c4514 100644 --- a/src/pir_internals/params.rs +++ b/src/pir_internals/params.rs @@ -1,5 +1,7 @@ +pub const LWE_DIMENSION: u32 = 1774; + pub const BIT_SECURITY_LEVEL: usize = 128; -pub const LWE_DIMENSION: usize = 1774; -pub const SEED_BYTE_LEN: usize = (2 * BIT_SECURITY_LEVEL) / 8; -pub const HASHED_KEY_BYTE_LEN: usize = (2 * BIT_SECURITY_LEVEL) / 8; +pub const SEED_BYTE_LEN: usize = (2 * BIT_SECURITY_LEVEL) / u8::BITS as usize; +pub 
const HASHED_KEY_BYTE_LEN: usize = (2 * BIT_SECURITY_LEVEL) / u8::BITS as usize; + pub const SERVER_SETUP_MAX_ATTEMPT_COUNT: usize = 100; diff --git a/src/pir_internals/serialization.rs b/src/pir_internals/serialization.rs index cd02aa9..2686fda 100644 --- a/src/pir_internals/serialization.rs +++ b/src/pir_internals/serialization.rs @@ -19,7 +19,7 @@ use turboshake::TurboShake128; /// /// A vector of 32-bit unsigned integers representing the encoded key-value pair. #[inline] -pub fn encode_kv_as_row(key: &[u8], value: &[u8], mat_elem_bit_len: usize, num_cols: usize) -> Vec { +pub fn encode_kv_as_row(key: &[u8], value: &[u8], mat_elem_bit_len: usize, num_cols: u32) -> Vec { let hashed_key = { let mut hasher = TurboShake128::default(); hasher.absorb(key); @@ -31,7 +31,7 @@ pub fn encode_kv_as_row(key: &[u8], value: &[u8], mat_elem_bit_len: usize, num_c hashed_key }; - let mut row = vec![0u32; num_cols]; + let mut row = vec![0u32; num_cols as usize]; let mut row_offset = 0; let mat_elem_mask = (1u64 << mat_elem_bit_len) - 1; @@ -268,8 +268,8 @@ mod test { hashed_key }; - let actual_encoded_kv_len = (hashed_key.len() * 8 + (value.len() + 1) * 8).div_ceil(mat_elem_bit_len); - let max_encoded_kv_len = (hashed_key.len() * 8 + (2 * value.len() + 1) * 8).div_ceil(mat_elem_bit_len); + let actual_encoded_kv_len = (hashed_key.len() * 8 + (value.len() + 1) * 8).div_ceil(mat_elem_bit_len) as u32; + let max_encoded_kv_len = (hashed_key.len() * 8 + (2 * value.len() + 1) * 8).div_ceil(mat_elem_bit_len) as u32; for encoded_kv_len in actual_encoded_kv_len..max_encoded_kv_len { let row = encode_kv_as_row(&key, &value, mat_elem_bit_len, encoded_kv_len); diff --git a/src/server.rs b/src/server.rs index faac70a..36b7205 100644 --- a/src/server.rs +++ b/src/server.rs @@ -1,3 +1,5 @@ +#[cfg(feature = "gpu")] +use crate::pir_internals::gpu; use crate::{ ChalametPIRError, pir_internals::{ @@ -40,6 +42,7 @@ impl Server { /// # Returns /// /// A `Result` containing a tuple of the `Server` 
object, the serialized hint matrix bytes, and the serialized filter parameters bytes. Returns an error if any error occurs during setup. + #[cfg(not(feature = "gpu"))] pub fn setup(seed_μ: &[u8; SEED_BYTE_LEN], db: HashMap<&[u8], &[u8]>) -> Result<(Server, Vec, Vec), ChalametPIRError> { let db_num_kv_pairs = db.len(); if branch_opt_util::unlikely(db_num_kv_pairs == 0) { return Err(ChalametPIRError::EmptyKVDatabase); } @@ -50,7 +53,7 @@ impl Server { let (parsed_db_mat_d, filter) = Matrix::from_kv_database::(db, mat_elem_bit_len, SERVER_SETUP_MAX_ATTEMPT_COUNT)?; let pub_mat_a_num_rows = LWE_DIMENSION; - let pub_mat_a_num_cols = filter.num_fingerprints; + let pub_mat_a_num_cols = filter.num_fingerprints as u32; let pub_mat_a = unsafe { Matrix::generate_from_seed(pub_mat_a_num_rows, pub_mat_a_num_cols, seed_μ).unwrap_unchecked() }; @@ -62,6 +65,88 @@ impl Server { Ok((Server { transposed_parsed_db_mat_d }, hint_bytes, filter_param_bytes)) } + /// Sets up the keyword **P**rivate **I**nformation **R**etrieval scheme's server with a given Key-Value database. + /// + /// This function takes a database as input and generates the necessary matrices and parameters for responding to client queries. + /// It involves several steps: + /// 1. **Database Validation:** The database must not be empty and should have at most 2^42 entries. Returns an error if validation fails. + /// 2. **Matrix Generation from Database:** Creates a `Matrix` (`parsed_db_mat_d`) representing the database. Uses the `Matrix::from_kv_database` function, which might involve multiple attempts (`SERVER_SETUP_MAX_ATTEMPT_COUNT`) to generate a suitable matrix. Returns an error if matrix generation fails. This also generates a `filter` object used in later stages of the PIR protocol. + /// 3. **Public Matrix Generation:** Generates a public matrix (`pub_mat_a`) using a provided seed (`seed_μ`). The dimensions of this matrix are determined by `LWE_DIMENSION` and the number of fingerprints in the `filter`. + /// 4.
**Hint Matrix Calculation:** Computes the hint matrix (`hint_mat_m`) by multiplying the public matrix and the parsed database matrix. + /// 5. **Serialization:** Converts the hint matrix and filter parameters into byte vectors for storage and transmission. Returns an error if conversion fails. + /// 6. **Transposition:** Transposes the parsed database matrix (`parsed_db_mat_d`) to optimize memory access patterns during execution of the `respond` function. + /// + /// # Arguments + /// + /// * `seed_μ`: The seed used for generating the public matrix. + /// * `db`: The input database, represented as a hash map of key-value pairs. + /// + /// The constant parameter `ARITY` can be 3 or 4, denoting the use of a 3/4-wise XOR binary fuse filter. + /// This choice affects client/server computation and communication costs. + /// + /// # Returns + /// + /// A `Result` containing a tuple of the `Server` object, the serialized hint matrix bytes, and the serialized filter parameters bytes. Returns an error if any error occurs during setup. 
+ #[cfg(feature = "gpu")] + pub fn setup(seed_μ: &[u8; SEED_BYTE_LEN], db: HashMap<&[u8], &[u8]>) -> Result<(Server, Vec, Vec), ChalametPIRError> { + let db_num_kv_pairs = db.len(); + if branch_opt_util::unlikely(db_num_kv_pairs == 0) { + return Err(ChalametPIRError::EmptyKVDatabase); + } + + let mat_elem_bit_len = Self::find_encoded_db_matrix_element_bit_length(db_num_kv_pairs)?; + let (parsed_db_mat_d, filter) = Matrix::from_kv_database::(db, mat_elem_bit_len, SERVER_SETUP_MAX_ATTEMPT_COUNT)?; + + let pub_mat_a_num_rows = LWE_DIMENSION; + let pub_mat_a_num_cols = filter.num_fingerprints as u32; + + let pub_mat_a = unsafe { Matrix::generate_from_seed(pub_mat_a_num_rows, pub_mat_a_num_cols, seed_μ).unwrap_unchecked() }; + + let (device, queue, mem_alloc, cmd_buf_alloc) = gpu::setup_gpu()?; + + let hint_mat_m_num_rows = pub_mat_a_num_rows; + let hint_mat_m_num_cols = parsed_db_mat_d.num_cols(); + let hint_mat_m_byte_len = (2 * std::mem::size_of::() + (hint_mat_m_num_rows * hint_mat_m_num_cols) as usize * std::mem::size_of::()) as u64; + let hint_mat_m_wg_count = [hint_mat_m_num_rows.div_ceil(8), hint_mat_m_num_cols.div_ceil(8), 1]; + + let parsed_db_mat_d_byte_len = parsed_db_mat_d.num_bytes() as u64; + let parsed_db_mat_d_wg_count = [parsed_db_mat_d.num_rows().div_ceil(8), parsed_db_mat_d.num_cols().div_ceil(8), 1]; + + let pub_mat_a_buf = gpu::transfer_mat_to_device(queue.clone(), mem_alloc.clone(), cmd_buf_alloc.clone(), pub_mat_a)?; + let parsed_db_mat_d_buf = gpu::transfer_mat_to_device(queue.clone(), mem_alloc.clone(), cmd_buf_alloc.clone(), parsed_db_mat_d.clone())?; + let hint_mat_m_buf = gpu::get_empty_host_readable_buffer(mem_alloc.clone(), hint_mat_m_byte_len)?; + let transposed_parsed_db_mat_d_buf = gpu::get_empty_host_readable_buffer(mem_alloc.clone(), parsed_db_mat_d_byte_len)?; + + gpu::mat_x_mat( + device.clone(), + queue.clone(), + cmd_buf_alloc.clone(), + pub_mat_a_buf, + parsed_db_mat_d_buf.clone(), + hint_mat_m_buf.clone(), + hint_mat_m_wg_count, 
+ )?; + + gpu::mat_transpose( + device.clone(), + queue.clone(), + cmd_buf_alloc.clone(), + parsed_db_mat_d_buf, + transposed_parsed_db_mat_d_buf.clone(), + parsed_db_mat_d_wg_count, + )?; + + let transposed_parsed_db_mat_d = Matrix::from_bytes( + &transposed_parsed_db_mat_d_buf + .read() + .map_err(|_| ChalametPIRError::VulkanReadingFromBufferFailed)?, + )?; + let hint_bytes = hint_mat_m_buf.read().map_err(|_| ChalametPIRError::VulkanReadingFromBufferFailed)?.to_vec(); + let filter_param_bytes: Vec = filter.to_bytes(); + + Ok((Server { transposed_parsed_db_mat_d }, hint_bytes, filter_param_bytes)) + } + /// Responds to a client query. /// /// This function takes a client's query (in byte form) as input and uses the transposed database matrix to compute the response.