diff --git a/Cargo.toml b/Cargo.toml index 22fc81f..7fb2031 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -9,6 +9,8 @@ repository = "https://github.com/fairmath/openfhe-rs" [dependencies] cxx = "1.0" +num-bigint = { version = "0.4", default-features = false } +num-traits = "0.2" [build-dependencies] cxx-build = "1.0" diff --git a/examples/dcrt_poly.rs b/examples/dcrt_poly.rs new file mode 100644 index 0000000..0accea5 --- /dev/null +++ b/examples/dcrt_poly.rs @@ -0,0 +1,109 @@ +use num_bigint::BigUint; +use num_traits::Num; +use openfhe::ffi::{self, GetMatrixElement}; + +fn main() { + let val = String::from("123456789099999"); + // Parameters based on https://github.com/openfheorg/openfhe-development/blob/7b8346f4eac27121543e36c17237b919e03ec058/src/core/unittest/UnitTestTrapdoor.cpp#L314 + let n: u32 = 16; + let size: usize = 4; // Number of CRT + let k_res: usize = 51; + + let const_poly = ffi::DCRTPolyGenFromConst(n, size, k_res, &val); + // print the const_poly + println!("const_poly: {:?}", const_poly); + + let modulus = ffi::GenModulus(n, size, k_res); + println!("modulus: {:?}", modulus); + + let const_poly_2 = ffi::DCRTPolyGenFromConst(n, size, k_res, &val); + // print the const_poly_2 + println!("const_poly_2: {:?}", const_poly_2); + + // assert that the two const_poly are equal + assert_eq!(const_poly, const_poly_2); + + let const_poly_one = ffi::DCRTPolyGenFromConst(n, size, k_res, &String::from("1")); + let negated_poly_one = const_poly_one.Negate(); + println!("negated_poly_one: {:?}", negated_poly_one); + + let coeffs = vec![ + String::from("123456789099999"), + String::from("1234567842539099999"), + String::from("31232189328123893128912983"), + String::from("24535423544252452453"), + ]; + + let poly = ffi::DCRTPolyGenFromVec(n, size, k_res, &coeffs); + let poly_2 = ffi::DCRTPolyGenFromVec(n, size, k_res, &coeffs); + + // assert that the two poly are equal + assert_eq!(poly, poly_2); + + // perform polynomial addition + let poly_add = ffi::DCRTPolyAdd(&poly, &poly_2); + println!("poly_add: {:?}", poly_add); + + // perform polynomial multiplication + let poly_mul = ffi::DCRTPolyMul(&poly, &poly_2); + println!("poly_mul: {:?}", poly_mul); + + // get the coefficients of the polynomials + let coeffs_poly = poly.GetCoefficients(); + println!("coeffs_poly: {:?}", coeffs_poly); + let coeffs_poly_2 = poly_2.GetCoefficients(); + println!("coeffs_poly_2: {:?}", coeffs_poly_2); + let coeffs_poly_add = poly_add.GetCoefficients(); + println!("coeffs_poly_add: {:?}", coeffs_poly_add); + + let poly_modulus = poly.GetModulus(); + assert_eq!(poly_modulus, modulus); + + let sigma = 4.57825; + let base = 2; + let modulus_big_uint = BigUint::from_str_radix(&modulus, 10).unwrap(); + let k = modulus_big_uint.bits() as u32; + + // ** gen trapdoor ** + let trapdoor_output = ffi::DCRTTrapdoorGen(n, size, k_res, sigma, base, false); + let trapdoor = trapdoor_output.GetTrapdoorPair(); + let public_matrix = trapdoor_output.GetPublicMatrix(); + + // sample a target polynomial + let u = ffi::DCRTPolyGenFromDug(n, size, k_res); + + // generate a preimage such that public_matrix * preimage = target_polynomial + let _preimage = ffi::DCRTTrapdoorGaussSamp(n, k, &public_matrix, &trapdoor, &u, base, sigma); + + // ** gen trapdoor for a square matrix target of size 2x2 ** + let d = 2; + let trapdoor_output_square = + ffi::DCRTSquareMatTrapdoorGen(n, size, k_res, d, sigma, base, false); + + let trapdoor_square = trapdoor_output_square.GetTrapdoorPair(); + let public_matrix_square = trapdoor_output_square.GetPublicMatrix(); + + // build the target matrix by sampling a random polynomial for each element + let mut target_matrix = ffi::MatrixGen(n, size, k_res, d, d); + for i in 0..d { + for j in 0..d { + let poly = ffi::DCRTPolyGenFromDug(n, size, k_res); + ffi::SetMatrixElement(target_matrix.as_mut().unwrap(), i, j, &poly); + } + } + + // generate a preimage such that public_matrix_square * preimage = target_matrix + let _preimage_square = ffi::DCRTSquareMatTrapdoorGaussSamp( + n, + k, + &public_matrix_square, + &trapdoor_square, + &target_matrix, + base, + sigma, + ); + + let dummy_poly = ffi::DCRTPolyGenFromDug(n, size, k_res); + let decomposed_poly = dummy_poly.Decompose(); + let poly_0_0 = GetMatrixElement(&decomposed_poly, 0, 0); +} diff --git a/examples/trapdoor.rs b/examples/trapdoor.rs deleted file mode 100644 index bbfb06a..0000000 --- a/examples/trapdoor.rs +++ /dev/null @@ -1,19 +0,0 @@ -use openfhe::ffi; - -fn main() { - // Parameters based on https://github.com/openfheorg/openfhe-development/blob/7b8346f4eac27121543e36c17237b919e03ec058/src/core/unittest/UnitTestTrapdoor.cpp#L314 - let n: u32 = 16; - let size: u32 = 4; - let k_res: u32 = 51; - let base: i64 = 8; - - let params = ffi::GenILDCRTParamsByOrderSizeBits(2 * n, size, k_res); - - let u = ffi::DCRTPolyGenFromDug(¶ms); - - let trapdoor = ffi::DCRTPolyTrapdoorGen(¶ms, base, false); - - let k = 68; // to calculate - - let _res = ffi::DCRTPolyGaussSamp(n.try_into().unwrap(), k, &trapdoor, &u, base); -} diff --git a/src/DCRTPoly.cc b/src/DCRTPoly.cc index 79fbbde..474d34d 100644 --- a/src/DCRTPoly.cc +++ b/src/DCRTPoly.cc @@ -1,5 +1,5 @@ #include "DCRTPoly.h" -#include "openfhe/src/lib.rs.h" +#include namespace openfhe { @@ -8,6 +8,157 @@ DCRTPoly::DCRTPoly(lbcrypto::DCRTPoly&& poly) noexcept : m_poly(std::move(poly)) { } +const lbcrypto::DCRTPoly& DCRTPoly::GetPoly() const noexcept +{ + return m_poly; +} + + +rust::String DCRTPoly::GetString() const +{ + std::stringstream stream; + stream << m_poly; + return rust::String(stream.str()); +} + +rust::String DCRTPoly::GetModulus() const +{ + return m_poly.GetModulus().ToString(); +} + +bool DCRTPoly::IsEqual(const DCRTPoly& other) const noexcept +{ + return m_poly == other.m_poly; +} + +rust::Vec DCRTPoly::GetCoefficients() const +{ + auto tempPoly = m_poly; + tempPoly.SetFormat(Format::COEFFICIENT); + + lbcrypto::DCRTPoly::PolyLargeType polyLarge = tempPoly.CRTInterpolate(); + + const lbcrypto::BigVector &coeffs = polyLarge.GetValues(); + + rust::Vec result; + for (size_t i = 0; i < coeffs.GetLength(); ++i) + { + result.push_back(rust::String(coeffs[i].ToString())); + } + + return result; +} + +std::unique_ptr DCRTPoly::Negate() const +{ + return std::make_unique(-m_poly); +} + +std::unique_ptr DCRTPoly::Decompose() const +{ + std::vector decomposed = m_poly.CRTDecompose(1); + + auto zero_alloc = lbcrypto::DCRTPoly::Allocator(m_poly.GetParams(), Format::COEFFICIENT); + lbcrypto::Matrix decomposedMatrix(zero_alloc, 1, decomposed.size()); + + for (size_t i = 0; i < decomposed.size(); i++) { + decomposedMatrix(0, i) = decomposed[i]; + } + + return std::make_unique(std::move(decomposedMatrix)); +} + +// Arithmetic +std::unique_ptr DCRTPolyAdd(const DCRTPoly& rhs, const DCRTPoly& lhs) +{ + return std::make_unique(rhs.GetPoly() + lhs.GetPoly()); +} + +std::unique_ptr DCRTPolyMul(const DCRTPoly& rhs, const DCRTPoly& lhs) +{ + return std::make_unique(rhs.GetPoly() * lhs.GetPoly()); +} + +// Generator functions +std::unique_ptr DCRTPolyGenFromConst( + usint n, + size_t size, + size_t kRes, + const rust::String& value) +{ + // Create params + auto params = std::make_shared>(2 * n, size, kRes); + + // Create a BigVector + lbcrypto::BigVector bigVec(params -> GetRingDimension(), params ->GetModulus()); + bigVec[0] = lbcrypto::BigInteger(std::string(value)); + + // Create a Poly that supports BigInteger coefficients) + lbcrypto::PolyImpl polyLarge(params, Format::COEFFICIENT); + polyLarge.SetValues(bigVec, Format::COEFFICIENT); + + // Convert polyLarge to a DCRTPoly + lbcrypto::DCRTPoly dcrtPoly(polyLarge, params); + + // switch dcrtPoly to EVALUATION format + dcrtPoly.SetFormat(Format::EVALUATION); + + return std::make_unique(std::move(dcrtPoly)); +} + +std::unique_ptr DCRTPolyGenFromVec( + usint n, + size_t size, + size_t kRes, + const rust::Vec& values) +{ + // Create params + auto params = std::make_shared>(2 * n, size, kRes); + + // Create a BigVector + lbcrypto::BigVector bigVec(params->GetRingDimension(), params->GetModulus()); + for (size_t i = 0; i < values.size() && i < params->GetRingDimension(); i++) { + bigVec[i] = lbcrypto::BigInteger(std::string(values[i])); + } + + // Create a Poly that supports BigInteger coefficients + lbcrypto::PolyImpl polyLarge(params, Format::COEFFICIENT); + polyLarge.SetValues(bigVec, Format::COEFFICIENT); + + // Convert polyLarge to a DCRTPoly + lbcrypto::DCRTPoly dcrtPoly(polyLarge, params); + + // switch dcrtPoly to EVALUATION format + dcrtPoly.SetFormat(Format::EVALUATION); + + return std::make_unique(std::move(dcrtPoly)); +} + +std::unique_ptr DCRTPolyGenFromBug(usint n, size_t size, size_t kRes) +{ + auto params = std::make_shared>(2 * n, size, kRes); + typename lbcrypto::DCRTPoly::BugType bug; + auto poly = lbcrypto::DCRTPoly(bug, params, Format::EVALUATION); + return std::make_unique(std::move(poly)); +} + +std::unique_ptr DCRTPolyGenFromDug(usint n, size_t size, size_t kRes) +{ + auto params = std::make_shared>(2 * n, size, kRes); + typename lbcrypto::DCRTPoly::DugType dug; + auto poly = lbcrypto::DCRTPoly(dug, params, Format::EVALUATION); + return std::make_unique(std::move(poly)); +} + +std::unique_ptr DCRTPolyGenFromDgg(usint n, size_t size, size_t kRes, double sigma) +{ + auto params = std::make_shared>(2 * n, size, kRes); + typename lbcrypto::DCRTPoly::DggType dgg(sigma); + auto poly = lbcrypto::DCRTPoly(dgg, params, Format::EVALUATION); + return std::make_unique(std::move(poly)); +} + + DCRTPolyParams::DCRTPolyParams(const std::shared_ptr& params) noexcept : m_params(params) { } @@ -22,11 +173,36 @@ std::unique_ptr DCRTPolyGenNullParams() return std::make_unique(); } -std::unique_ptr DCRTPolyGenFromDug(const ILDCRTParams& params) +// Matrix functions +std::unique_ptr MatrixGen( + usint n, + size_t size, + size_t kRes, + size_t nrow, + size_t ncol) +{ + auto params = std::make_shared>(2 * n, size, kRes); + auto zero_alloc = lbcrypto::DCRTPoly::Allocator(params, Format::EVALUATION); + Matrix matrix(zero_alloc, nrow, ncol); + return std::make_unique(std::move(matrix)); +} + +void SetMatrixElement( + Matrix& matrix, + size_t row, + size_t col, + const DCRTPoly& element) +{ + matrix(row, col) = element.GetPoly(); +} + +std::unique_ptr GetMatrixElement( + const Matrix& matrix, + size_t row, + size_t col) { - std::shared_ptr params_ptr = std::make_shared(params); - typename DCRTPolyImpl::DugType dug; - return std::make_unique(dug, params_ptr, Format::EVALUATION); + lbcrypto::DCRTPoly copy = matrix(row, col); + return std::make_unique(std::move(copy)); } } // openfhe diff --git a/src/DCRTPoly.h b/src/DCRTPoly.h index 4923b18..13ad703 100644 --- a/src/DCRTPoly.h +++ b/src/DCRTPoly.h @@ -1,13 +1,14 @@ #pragma once #include "openfhe/core/lattice/hal/lat-backend.h" -#include "openfhe/core/lattice/hal/default/dcrtpoly.h" -#include "openfhe/core/math/math-hal.h" -#include "Params.h" +#include "rust/cxx.h" +#include "openfhe/core/math/matrix.h" namespace openfhe { +using Matrix = lbcrypto::Matrix; + class DCRTPoly final { lbcrypto::DCRTPoly m_poly; @@ -17,8 +18,39 @@ class DCRTPoly final DCRTPoly(DCRTPoly&&) = delete; DCRTPoly& operator=(const DCRTPoly&) = delete; DCRTPoly& operator=(DCRTPoly&&) = delete; + + [[nodiscard]] const lbcrypto::DCRTPoly& GetPoly() const noexcept; + [[nodiscard]] rust::String GetString() const; + [[nodiscard]] bool IsEqual(const DCRTPoly& other) const noexcept; + [[nodiscard]] rust::Vec GetCoefficients() const; + [[nodiscard]] rust::String GetModulus() const; + [[nodiscard]] std::unique_ptr Negate() const; + [[nodiscard]] std::unique_ptr Decompose() const; }; +// Generator functions +[[nodiscard]] std::unique_ptr DCRTPolyGenFromConst( + usint n, + size_t size, + size_t kRes, + const rust::String& value +); + +[[nodiscard]] std::unique_ptr DCRTPolyGenFromVec( + usint n, + size_t size, + size_t kRes, + const rust::Vec& values +); + +[[nodiscard]] std::unique_ptr DCRTPolyGenFromBug(usint n, size_t size, size_t kRes); +[[nodiscard]] std::unique_ptr DCRTPolyGenFromDug(usint n, size_t size, size_t kRes); +[[nodiscard]] std::unique_ptr DCRTPolyGenFromDgg(usint n, size_t size, size_t kRes, double sigma); + +// Arithmetic +[[nodiscard]] std::unique_ptr DCRTPolyAdd(const DCRTPoly& rhs, const DCRTPoly& lhs); +[[nodiscard]] std::unique_ptr DCRTPolyMul(const DCRTPoly& rhs, const DCRTPoly& lhs); + class DCRTPolyParams final { std::shared_ptr m_params; @@ -36,9 +68,22 @@ class DCRTPolyParams final // Generator functions [[nodiscard]] std::unique_ptr DCRTPolyGenNullParams(); -using DCRTPolyImpl = lbcrypto::DCRTPolyImpl; +// Matrix functions +[[nodiscard]] std::unique_ptr MatrixGen( + usint n, + size_t size, + size_t kRes, + size_t nrow, + size_t ncol); -// Generator functions -[[nodiscard]] std::unique_ptr DCRTPolyGenFromDug(const ILDCRTParams& params); +void SetMatrixElement( + Matrix& matrix, + size_t row, + size_t col, + const DCRTPoly& element); +[[nodiscard]] std::unique_ptr GetMatrixElement( + const Matrix& matrix, + size_t row, + size_t col); } // openfhe diff --git a/src/Params.cc b/src/Params.cc index a8061b9..25e4665 100644 --- a/src/Params.cc +++ b/src/Params.cc @@ -1,4 +1,5 @@ #include "Params.h" +#include "DCRTPoly.h" namespace openfhe { @@ -36,9 +37,10 @@ std::unique_ptr GenParamsCKKSRNSbyVectorOfString( { return std::make_unique(vals); } -std::unique_ptr GenILDCRTParamsByOrderSizeBits( - uint32_t corder, uint32_t depth, uint32_t bits) -{ - return std::make_unique(corder, depth, bits); +rust::String GenModulus( + usint n, size_t size, size_t kRes) +{ + auto params = std::make_shared>(2 * n, size, kRes); + return rust::String(params->GetModulus().ToString()); } } // openfhe diff --git a/src/Params.h b/src/Params.h index f642947..432655f 100644 --- a/src/Params.h +++ b/src/Params.h @@ -5,8 +5,8 @@ #include "openfhe/pke/scheme/bfvrns/gen-cryptocontext-bfvrns-params.h" #include "openfhe/pke/scheme/bgvrns/gen-cryptocontext-bgvrns-params.h" #include "openfhe/pke/scheme/ckksrns/gen-cryptocontext-ckksrns-params.h" -#include "openfhe/core/lattice/hal/default/ildcrtparams.h" -#include "openfhe/core/math/math-hal.h" + +#include "rust/cxx.h" #include @@ -30,7 +30,6 @@ using Params = lbcrypto::Params; using ParamsBFVRNS = lbcrypto::CCParams; using ParamsBGVRNS = lbcrypto::CCParams; using ParamsCKKSRNS = lbcrypto::CCParams; -using ILDCRTParams = lbcrypto::ILDCRTParams; // Generator functions [[nodiscard]] std::unique_ptr GenParamsBFVRNS(); @@ -45,6 +44,6 @@ using ILDCRTParams = lbcrypto::ILDCRTParams; [[nodiscard]] std::unique_ptr GenParamsCKKSRNS(); [[nodiscard]] std::unique_ptr GenParamsCKKSRNSbyVectorOfString( const std::vector& vals); -[[nodiscard]] std::unique_ptr GenILDCRTParamsByOrderSizeBits( - uint32_t corder, uint32_t depth, uint32_t bits); +[[nodiscard]] rust::String GenModulus( + usint n, size_t size, size_t kRes); } // openfhe diff --git a/src/Trapdoor.cc b/src/Trapdoor.cc index a1512bb..07948a7 100644 --- a/src/Trapdoor.cc +++ b/src/Trapdoor.cc @@ -1,45 +1,119 @@ #include "Trapdoor.h" -#include "openfhe/core/lattice/trapdoor.h" -#include "openfhe/core/lattice/dgsampling.h" +#include "Params.h" namespace openfhe { -std::unique_ptr DCRTPolyTrapdoorGen( - const ILDCRTParams& params, +DCRTTrapdoor::DCRTTrapdoor(Matrix&& publicMatrix, RLWETrapdoorPair&& trapdoorPair) noexcept + : m_publicMatrix(std::move(publicMatrix)), m_trapdoorPair(std::move(trapdoorPair)) +{ } + +std::unique_ptr DCRTTrapdoor::GetTrapdoorPair() const +{ + return std::make_unique(m_trapdoorPair); +} + +std::unique_ptr DCRTTrapdoor::GetPublicMatrix() const +{ + return std::make_unique(m_publicMatrix); +} + +std::unique_ptr DCRTTrapdoor::GetPublicMatrixElement(size_t row, size_t col) const +{ + if (row >= m_publicMatrix.GetRows() || col >= m_publicMatrix.GetCols()) { + return nullptr; + } + + lbcrypto::DCRTPoly copy = m_publicMatrix(row, col); + return std::make_unique(std::move(copy)); +} + +// Generator functions +std::unique_ptr DCRTTrapdoorGen( + usint n, + size_t size, + size_t kRes, + double sigma, int64_t base, bool balanced) { + auto params = std::make_shared>(2 * n, size, kRes); + + auto trapdoor = lbcrypto::RLWETrapdoorUtility::TrapdoorGen( + params, + sigma, + base, + balanced + ); - std::shared_ptr params_ptr = std::make_shared(params); + return std::make_unique( + std::move(trapdoor.first), + std::move(trapdoor.second) + ); +} - auto result = lbcrypto::RLWETrapdoorUtility::TrapdoorGen( - params_ptr, - lbcrypto::SIGMA, +std::unique_ptr DCRTSquareMatTrapdoorGen( + usint n, + size_t size, + size_t kRes, + size_t d, + double sigma, + int64_t base, + bool balanced) +{ + auto params = std::make_shared>(2 * n, size, kRes); + + auto trapdoor = lbcrypto::RLWETrapdoorUtility::TrapdoorGenSquareMat( + params, + sigma, + d, base, balanced ); - - return std::make_unique(TrapdoorOutput{ - std::move(result.first), - std::move(result.second) - }); + + return std::make_unique( + std::move(trapdoor.first), + std::move(trapdoor.second) + ); } -std::unique_ptr DCRTPolyGaussSamp(size_t n, size_t k, const TrapdoorOutput& trapdoor, const DCRTPolyImpl& u, int64_t base) +// Gauss sample functions +std::unique_ptr DCRTTrapdoorGaussSamp(usint n, usint k, const Matrix& publicMatrix, const RLWETrapdoorPair& trapdoor, const DCRTPoly& u, int64_t base, double sigma) { - DCRTPolyImpl::DggType dgg(lbcrypto::SIGMA); + lbcrypto::DCRTPoly::DggType dgg(sigma); - double c = (base + 1) * lbcrypto::SIGMA; + double c = (base + 1) * sigma; double s = lbcrypto::SPECTRAL_BOUND(n, k, base); - DCRTPolyImpl::DggType dggLargeSigma(sqrt(s * s - c * c)); + lbcrypto::DCRTPoly::DggType dggLargeSigma(sqrt(s * s - c * c)); auto result = lbcrypto::RLWETrapdoorUtility::GaussSamp( n, k, - trapdoor.m, - trapdoor.tp, - u, + publicMatrix, + trapdoor, + u.GetPoly(), + dgg, + dggLargeSigma, + base + ); + + return std::make_unique(std::move(result)); +} + +std::unique_ptr DCRTSquareMatTrapdoorGaussSamp(usint n, usint k, const Matrix& publicMatrix, const RLWETrapdoorPair& trapdoor, const Matrix& U, int64_t base, double sigma) +{ + lbcrypto::DCRTPoly::DggType dgg(sigma); + + double c = (base + 1) * sigma; + double s = lbcrypto::SPECTRAL_BOUND(n, k, base); + lbcrypto::DCRTPoly::DggType dggLargeSigma(sqrt(s * s - c * c)); + + auto result = lbcrypto::RLWETrapdoorUtility::GaussSampSquareMat( + n, + k, + publicMatrix, + trapdoor, + U, dgg, dggLargeSigma, base @@ -47,4 +121,4 @@ std::unique_ptr DCRTPolyGaussSamp(size_t n, size_t k, const TrapdoorOutp return std::make_unique(std::move(result)); } -} // openfhe \ No newline at end of file +} // openfhe \ No newline at end of file diff --git a/src/Trapdoor.h b/src/Trapdoor.h index 9201856..2a50fea 100644 --- a/src/Trapdoor.h +++ b/src/Trapdoor.h @@ -1,29 +1,63 @@ #pragma once -#include "Params.h" -#include "openfhe/core/lattice/hal/lat-backend.h" -#include "openfhe/core/math/matrix.h" #include "openfhe/core/lattice/trapdoor.h" #include "DCRTPoly.h" namespace openfhe { -using Matrix = lbcrypto::Matrix; using RLWETrapdoorPair = lbcrypto::RLWETrapdoorPair; -struct TrapdoorOutput +class DCRTTrapdoor final { - Matrix m; - RLWETrapdoorPair tp; -}; + Matrix m_publicMatrix; + RLWETrapdoorPair m_trapdoorPair; +public: + DCRTTrapdoor() = default; + DCRTTrapdoor(Matrix&& publicMatrix, RLWETrapdoorPair&& trapdoorPair) noexcept; + DCRTTrapdoor(const DCRTTrapdoor&) = delete; + DCRTTrapdoor(DCRTTrapdoor&&) = delete; + DCRTTrapdoor& operator=(const DCRTTrapdoor&) = delete; + DCRTTrapdoor& operator=(DCRTTrapdoor&&) = delete; + [[nodiscard]] std::unique_ptr GetTrapdoorPair() const; + [[nodiscard]] std::unique_ptr GetPublicMatrix() const; + [[nodiscard]] std::unique_ptr GetPublicMatrixElement(size_t row, size_t col) const; +}; // Generator functions -[[nodiscard]] std::unique_ptr DCRTPolyTrapdoorGen( - const ILDCRTParams& params, +[[nodiscard]] std::unique_ptr DCRTTrapdoorGen( + usint n, + size_t size, + size_t kRes, + double sigma, + int64_t base, + bool balanced); + +[[nodiscard]] std::unique_ptr DCRTSquareMatTrapdoorGen( + usint n, + size_t size, + size_t kRes, + size_t d, + double sigma, int64_t base, bool balanced); -[[nodiscard]] std::unique_ptr DCRTPolyGaussSamp(size_t n, size_t k, const TrapdoorOutput& trapdoor, const DCRTPolyImpl& u, int64_t base); - +// Gauss sample functions +[[nodiscard]] std::unique_ptr DCRTTrapdoorGaussSamp( + usint n, + usint k, + const Matrix& publicMatrix, + const RLWETrapdoorPair& trapdoor, + const DCRTPoly& u, + int64_t base, + double sigma); + +[[nodiscard]] std::unique_ptr DCRTSquareMatTrapdoorGaussSamp( + usint n, + usint k, + const Matrix& publicMatrix, + const RLWETrapdoorPair& trapdoor, + const Matrix& U, + int64_t base, + double sigma); } // openfhe \ No newline at end of file diff --git a/src/lib.rs b/src/lib.rs index 03cbf81..88abd6e 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -198,13 +198,12 @@ pub mod ffi type CryptoContextDCRTPoly; type CryptoParametersBaseDCRTPoly; type DCRTPoly; - type DCRTPolyImpl; type DCRTPolyParams; + type DCRTTrapdoor; type DecryptResult; type EncodingParams; type EvalKeyDCRTPoly; type KeyPairDCRTPoly; - type ILDCRTParams; type LWEPrivateKey; type MapFromIndexToEvalKey; type MapFromStringToMapFromIndexToEvalKey; @@ -217,6 +216,7 @@ pub mod ffi type Plaintext; type PrivateKeyDCRTPoly; type PublicKeyDCRTPoly; + type RLWETrapdoorPair; type SchemeBaseDCRTPoly; type SetOfUints; type UnorderedMapFromIndexToDCRTPoly; @@ -226,7 +226,6 @@ pub mod ffi type VectorOfLWECiphertexts; type VectorOfPrivateKeys; type VectorOfVectorOfCiphertexts; - type TrapdoorOutput; } // CiphertextDCRTPoly @@ -724,6 +723,30 @@ pub mod ffi fn DCRTPolyGenNullCryptoContext() -> UniquePtr; } + // DCRTPoly + unsafe extern "C++" + { + fn GetString(self: &DCRTPoly) -> String; + fn IsEqual(self: &DCRTPoly, other: &DCRTPoly) -> bool; + fn GetCoefficients(self: &DCRTPoly) -> Vec; + fn GetModulus(self: &DCRTPoly) -> String; + fn Negate(self: &DCRTPoly) -> UniquePtr; + fn Decompose(self: &DCRTPoly) -> UniquePtr; + + // Generator functions + fn DCRTPolyGenFromConst(n: u32, size: usize, k_res: usize, value: &String) + -> UniquePtr; + fn DCRTPolyGenFromVec(n: u32, size: usize, k_res: usize, values: &Vec) + -> UniquePtr; + fn DCRTPolyGenFromBug(n: u32, size: usize, k_res: usize) -> UniquePtr; + fn DCRTPolyGenFromDug(n: u32, size: usize, k_res: usize) -> UniquePtr; + fn DCRTPolyGenFromDgg(n: u32, size: usize, k_res: usize, sigma: f64) -> UniquePtr; + + // Arithmetic + fn DCRTPolyAdd(rhs: &DCRTPoly, lhs: &DCRTPoly) -> UniquePtr; + fn DCRTPolyMul(rhs: &DCRTPoly, lhs: &DCRTPoly) -> UniquePtr; + } + // DCRTPolyParams unsafe extern "C++" { @@ -731,10 +754,12 @@ pub mod ffi fn DCRTPolyGenNullParams() -> UniquePtr; } - // DCRTPolyImpl + // Matrix unsafe extern "C++" { - fn DCRTPolyGenFromDug(params: &ILDCRTParams) -> UniquePtr; + fn MatrixGen(n: u32, size: usize, k_res: usize, nrow: usize, ncol: usize) -> UniquePtr; + fn SetMatrixElement(matrix: Pin<&mut Matrix>, row: usize, col: usize, element: &DCRTPoly); + fn GetMatrixElement(matrix: &Matrix, row: usize, col: usize) -> UniquePtr; } // KeyPairDCRTPoly @@ -744,12 +769,6 @@ pub mod ffi fn GetPublicKey(self: &KeyPairDCRTPoly) -> UniquePtr; } - // ILDCRTParams - unsafe extern "C++" - { - fn GenILDCRTParamsByOrderSizeBits(corder: u32, depth: u32, bits: u32) -> UniquePtr; - } - // Params unsafe extern "C++" { @@ -818,6 +837,7 @@ pub mod ffi // Generator functions fn GenParamsByScheme(scheme: SCHEME) -> UniquePtr; fn GenParamsByVectorOfString(vals: &CxxVector) -> UniquePtr; + fn GenModulus(n: u32, size: usize, k_res: usize) -> String; } // ParamsBFVRNS @@ -1145,11 +1165,71 @@ pub mod ffi } // Trapdoor - unsafe extern "C++" - { + unsafe extern "C++" { + fn GetPublicMatrix(self: &DCRTTrapdoor) -> UniquePtr; + fn GetPublicMatrixElement( + self: &DCRTTrapdoor, + row: usize, + col: usize, + ) -> UniquePtr; + fn GetTrapdoorPair(self: &DCRTTrapdoor) -> UniquePtr; + // Generator functions - fn DCRTPolyTrapdoorGen(params: &ILDCRTParams, base: i64, balanced: bool) -> UniquePtr; - fn DCRTPolyGaussSamp(n: usize, k: usize, trapdoor: &TrapdoorOutput, u: &DCRTPolyImpl, base: i64) -> UniquePtr; + fn DCRTTrapdoorGen( + n: u32, + size: usize, + k_res: usize, + sigma: f64, + base: i64, + balanced: bool, + ) -> UniquePtr; + + fn DCRTSquareMatTrapdoorGen( + n: u32, + size: usize, + k_res: usize, + d: usize, + sigma: f64, + base: i64, + balanced: bool, + ) -> UniquePtr; + + // Gauss sample functions + fn DCRTTrapdoorGaussSamp( + n: u32, + k: u32, + public_matrix: &Matrix, + trapdoor: &RLWETrapdoorPair, + u: &DCRTPoly, + base: i64, + sigma: f64, + ) -> UniquePtr; + + fn DCRTSquareMatTrapdoorGaussSamp( + n: u32, + k: u32, + public_matrix: &Matrix, + trapdoor: &RLWETrapdoorPair, + u: &Matrix, + base: i64, + sigma: f64, + ) -> UniquePtr; + } +} + + +use crate::ffi::DCRTPoly; +use std::fmt; + +impl fmt::Debug for DCRTPoly { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}", self.GetString()) + } +} + +impl PartialEq for DCRTPoly { + fn eq(&self, other: &Self) -> bool { + self.IsEqual(other) } }