diff --git a/api/CCParams.cpp b/api/CCParams.cpp index 95b00779..aee90286 100644 --- a/api/CCParams.cpp +++ b/api/CCParams.cpp @@ -87,6 +87,13 @@ void CCParams::SetSecurityLevel(SecurityLevel level) { params.SetSecurityLevel(sl_openfhe); } +void CCParams::SetCKKSDataType(CKKSDataType ckksdt) { + auto& params = std::any_cast&>(cpu); + auto dt_openfhe = static_cast(ckksdt); + assert((int)dt_openfhe == (int)ckksdt); + params.SetCKKSDataType(dt_openfhe); +} + // ---- Getters ---- SecretKeyDist CCParams::GetSecretKeyDist() const { diff --git a/api/CCParams.hpp b/api/CCParams.hpp index a78559ce..ed4f9a87 100644 --- a/api/CCParams.hpp +++ b/api/CCParams.hpp @@ -49,6 +49,7 @@ template <> class CCParams { void SetDevices(std::vector&& devices); void SetPlaintextAutoload(bool autoload); void SetCiphertextAutoload(bool autoload); + void SetCKKSDataType(CKKSDataType ckksdt); // ---- Getters ---- SecretKeyDist GetSecretKeyDist() const; diff --git a/api/CryptoContext.cpp b/api/CryptoContext.cpp index 0d31263c..8d5aad06 100644 --- a/api/CryptoContext.cpp +++ b/api/CryptoContext.cpp @@ -747,6 +747,32 @@ Ciphertext CryptoContextImpl::EvalAdd(double scalar, const C return EvalAdd(ct, scalar); } +Ciphertext CryptoContextImpl::EvalAdd(const Ciphertext& ct, std::complex scalar) { + + // Fall back to CPU. + if (this->devices.empty()) { + auto& context = std::any_cast&>(this->cpu); + auto& ctImpl = std::any_cast&>(ct->cpu); + auto ct = context->EvalAdd(ctImpl, scalar); + Ciphertext ciphertext = std::make_shared>(this->self_reference.lock()); + ciphertext->cpu = std::make_any>(ct); + return ciphertext; + } + + // GPU path. + this->LoadCiphertext(const_cast&>(ct)); + + Ciphertext result = std::make_shared>(*ct); + auto res_gpu = std::static_pointer_cast(this->GetDeviceCiphertext(result->gpu)); + res_gpu->addScalar(scalar); + + return result; +} + +Ciphertext CryptoContextImpl::EvalAdd(std::complex scalar, const Ciphertext& ct) { + return EvalAdd(ct, scalar); +} + void CryptoContextImpl::EvalAddInPlace(Ciphertext& ct1, const Ciphertext& ct2) { // Fall back to CPU. @@ -815,6 +841,28 @@ void CryptoContextImpl::EvalAddInPlace(double scalar, Ciphertext::EvalAddInPlace(Ciphertext& ct1, std::complex scalar) { + + // Fall back to CPU. + if (this->devices.empty()) { + + auto& context = std::any_cast&>(this->cpu); + auto& ct1Impl = std::any_cast&>(ct1->cpu); + context->EvalAddInPlace(ct1Impl, scalar); + return; + } + + // GPU path. + this->LoadCiphertext(ct1); + + auto res_gpu = std::static_pointer_cast(this->GetDeviceCiphertext(ct1->gpu)); + res_gpu->addScalar(scalar); +} + +void CryptoContextImpl::EvalAddInPlace(std::complex scalar, Ciphertext& ct1) { + EvalAddInPlace(ct1, scalar); +} + Ciphertext CryptoContextImpl::EvalAddMutable(Ciphertext& ct1, Ciphertext& ct2) { return EvalAdd(ct1, ct2); } @@ -1198,6 +1246,33 @@ Ciphertext CryptoContextImpl::EvalMult(double scalar, const return EvalMult(ct1, scalar); } +Ciphertext CryptoContextImpl::EvalMult(const Ciphertext& ct1, std::complex scalar) { + + // Fall back to CPU. + if (this->devices.empty()) { + + auto& context = std::any_cast&>(this->cpu); + auto& ct1Impl = std::any_cast&>(ct1->cpu); + auto ct = context->EvalMult(ct1Impl, scalar); + Ciphertext ciphertext = std::make_shared>(this->self_reference.lock()); + ciphertext->cpu = std::make_any>(ct); + return ciphertext; + } + + // GPU path. + this->LoadCiphertext(const_cast&>(ct1)); + + Ciphertext result = std::make_shared>(*ct1); + auto res_gpu = std::static_pointer_cast(this->GetDeviceCiphertext(result->gpu)); + res_gpu->multScalar(scalar); + + return result; +} + +Ciphertext CryptoContextImpl::EvalMult(std::complex scalar, const Ciphertext& ct1) { + return EvalMult(ct1, scalar); +} + void CryptoContextImpl::EvalMultInPlace(Ciphertext& ct1, Plaintext& pt) { if (this->devices.empty()) { @@ -1241,6 +1316,28 @@ void CryptoContextImpl::EvalMultInPlace(double scalar, Ciphertext::EvalMultInPlace(Ciphertext& ct1, std::complex scalar) { + + // Fall back to CPU. + if (this->devices.empty()) { + + auto& context = std::any_cast&>(this->cpu); + auto& ct1Impl = std::any_cast&>(ct1->cpu); + context->EvalMultInPlace(ct1Impl, scalar); + return; + } + + // GPU path. + this->LoadCiphertext(ct1); + + auto res_gpu = std::static_pointer_cast(this->GetDeviceCiphertext(ct1->gpu)); + res_gpu->multScalar(scalar); +} + +void CryptoContextImpl::EvalMultInPlace(std::complex scalar, Ciphertext& ct1) { + EvalMultInPlace(ct1, scalar); +} + Ciphertext CryptoContextImpl::EvalMultMutable(Ciphertext& ct1, Ciphertext& ct2) { return EvalMult(ct1, ct2); } diff --git a/api/CryptoContext.hpp b/api/CryptoContext.hpp index 5f8f6723..51910bf6 100644 --- a/api/CryptoContext.hpp +++ b/api/CryptoContext.hpp @@ -9,6 +9,7 @@ #include #include #include +#include #include "CCParams.hpp" #include "Ciphertext.hpp" @@ -115,11 +116,15 @@ template <> class CryptoContextImpl { Ciphertext EvalAdd(Plaintext& pt, const Ciphertext& ct); Ciphertext EvalAdd(const Ciphertext& ct, double scalar); Ciphertext EvalAdd(double scalar, const Ciphertext& ct); + Ciphertext EvalAdd(const Ciphertext& ct, std::complex scalar); + Ciphertext EvalAdd(std::complex scalar, const Ciphertext& ct); void EvalAddInPlace(Ciphertext& ct1, const Ciphertext& ct2); void EvalAddInPlace(Ciphertext& ct1, Plaintext& pt); void EvalAddInPlace(Plaintext& pt, Ciphertext& ct1); void EvalAddInPlace(Ciphertext& ct1, double scalar); void EvalAddInPlace(double scalar, Ciphertext& ct1); + void EvalAddInPlace(Ciphertext& ct1, std::complex scalar); + void EvalAddInPlace(std::complex scalar, Ciphertext& ct1); Ciphertext EvalAddMutable(Ciphertext& ct1, Ciphertext& ct2); Ciphertext EvalAddMutable(Ciphertext& ct, Plaintext& pt); Ciphertext EvalAddMutable(Plaintext& pt, Ciphertext& ct); @@ -146,9 +151,13 @@ template <> class CryptoContextImpl { Ciphertext EvalMult(Plaintext& pt, const Ciphertext& ct1); Ciphertext EvalMult(const Ciphertext& ct1, double scalar); Ciphertext EvalMult(double scalar, const Ciphertext& ct1); + Ciphertext EvalMult(const Ciphertext& ct1, std::complex scalar); + Ciphertext EvalMult(std::complex scalar, const Ciphertext& ct1); void EvalMultInPlace(Ciphertext& ct1, Plaintext& pt); void EvalMultInPlace(Ciphertext& ct1, double scalar); void EvalMultInPlace(double scalar, Ciphertext& ct1); + void EvalMultInPlace(Ciphertext& ct1, std::complex scalar); + void EvalMultInPlace(std::complex scalar, Ciphertext& ct1); Ciphertext EvalMultMutable(Ciphertext& ct1, Ciphertext& ct2); Ciphertext EvalMultMutable(Ciphertext& ct1, Plaintext& pt); Ciphertext EvalMultMutable(Plaintext& pt, Ciphertext& ct1); diff --git a/api/Definitions.hpp b/api/Definitions.hpp index 7b308a01..5878280b 100644 --- a/api/Definitions.hpp +++ b/api/Definitions.hpp @@ -86,6 +86,11 @@ enum SecurityLevel { HEStd_NotSet, }; +enum CKKSDataType { + REAL = 0, + COMPLEX, +}; + } // namespace fideslib #endif \ No newline at end of file diff --git a/other/GPUTest.cpp b/other/GPUTest.cpp index 41359b21..2679db29 100644 --- a/other/GPUTest.cpp +++ b/other/GPUTest.cpp @@ -1,5 +1,6 @@ #include #include +#include int main() { int deviceCount, device; diff --git a/src/CKKS/Ciphertext.cpp b/src/CKKS/Ciphertext.cpp index 9355881e..fb705baf 100644 --- a/src/CKKS/Ciphertext.cpp +++ b/src/CKKS/Ciphertext.cpp @@ -732,6 +732,97 @@ void Ciphertext::addScalar(const double c) { //} } +void Ciphertext::addScalar(const std::complex c) { + CudaNvtxRange r(std::string{std::source_location::current().function_name()}.substr(23 + strlen(loc))); + CKKS::SetCurrentContext(cc_); + op_count[OPS::ADDSCALAR]++; + + auto elem1 = cc.ElemForEvalAddOrSub(c0.getLevel(), std::fabs(c.real()), this->NoiseLevel); + auto elem2 = cc.ElemForEvalAddOrSub(c0.getLevel(), std::fabs(c.imag()), this->NoiseLevel); + + uint32_t sizeQl = c0.getLevel() + 1; + + std::vector moduli(sizeQl); + for (int i = 0; i < sizeQl; ++i) + moduli[i] = cc.prime[i].p; + + std::vector> data; + for (uint32_t i = 0; i < sizeQl; i++) { + std::vector vec(cc.N, 0); + // Code works for positive values only, doesn't work for negative values. + vec[0] = (c.real() > 0) ? elem1[i] % (moduli[i]) : (-elem1[i] % (moduli[i]) + moduli[i]) % (moduli[i]); + vec[cc.N / 2] = (c.imag() > 0) ? elem2[i] % (moduli[i]) : (-elem2[i] % (moduli[i]) + moduli[i]) % (moduli[i]); + // End code that only works for positive values. + data.push_back(vec); + } + + RNSPoly elemComplex(cc, data); + elemComplex.NTT(cc.batch, true); + c0.add(elemComplex); +} + +void Ciphertext::multScalar(const std::complex c, bool rescale) { + CudaNvtxRange r(std::string{std::source_location::current().function_name()}.substr(23 + strlen(loc))); + CKKS::SetCurrentContext(cc_); + op_count[OPS::MULTSCALAR]++; + + if (cc.rescaleTechnique == FLEXIBLEAUTO || cc.rescaleTechnique == FLEXIBLEAUTOEXT || + cc.rescaleTechnique == FIXEDAUTO) { + if (NoiseLevel == 2) + this->rescale(); + } + assert(this->NoiseLevel == 1); + + auto elem1 = cc.ElemForEvalMult(c0.getLevel(), c.real()); + auto elem2 = cc.ElemForEvalMult(c0.getLevel(), c.imag()); + + auto level0 = c0.getLevel(); + auto level1 = c1.getLevel(); + RNSPoly real0(cc, level0), imag0(cc, level0), real1(cc, level1), imag1(cc, level1); + real0.copy(c0); + imag0.copy(c0); + real1.copy(c1); + imag1.copy(c1); + + real0.multScalar(elem1); + imag0.multScalar(elem2); + real1.multScalar(elem1); + imag1.multScalar(elem2); + + uint32_t N = cc.N; + uint32_t M = 2 * N; + + std::vector> poly; + uint32_t power = M / 4; + uint32_t powerReduced = power % M; + uint32_t index = power % N; + for (int i = 0; i <= c0.getLevel(); ++i) { + std::vector vec(N, 0); + vec[index] = powerReduced < N ? 1 : cc.prime[0].p - 1; + poly.push_back(vec); + } + RNSPoly monomial(cc, poly); + monomial.NTT(cc.batch, true); + + imag0.multElement(monomial); + imag1.multElement(monomial); + c0.add(real0, imag0); + c1.add(real1, imag1); + + // Manage metadata + NoiseLevel += 1; + NoiseFactor *= cc.param.ScalingFactorReal.at(c0.getLevel()); + if (rescale) { + NoiseFactor /= cc.param.ModReduceFactor.at(c0.getLevel()); + NoiseLevel -= 1; + } + + if (rescale) { + c0.rescale(); + c1.rescale(); + } +} + void Ciphertext::automorph(const int index, const int br) { CudaNvtxRange r(std::string{ sc::current().function_name() }.substr()); CKKS::SetCurrentContext(cc_); diff --git a/src/CKKS/Ciphertext.cuh b/src/CKKS/Ciphertext.cuh index d895fb52..aa5bed4f 100644 --- a/src/CKKS/Ciphertext.cuh +++ b/src/CKKS/Ciphertext.cuh @@ -6,6 +6,7 @@ #define FIDESLIB_CKKS_CIPHERTEXT_CUH #include +#include #include "RNSPoly.cuh" #include "forwardDefs.cuh" #include "openfhe-interface/RawCiphertext.cuh" @@ -210,6 +211,18 @@ class Ciphertext { */ void addScalar(const double c); + /** + * @brief Adds a scalar (complex) to both polynomial components. + * + * The scalar is first converted to the appropriate modulus representation + * via `ElemForEvalAddOrSub`. The sign of the scalar is handled by + * converting the element to its complement when `c < 0`. No metadata + * updates are performed. + * + * @param c Scalar value to add. + */ + void addScalar(const std::complex c); + /** * @brief Multiplies *this* ciphertext by a plaintext. * @@ -314,6 +327,20 @@ class Ciphertext { */ void multScalar(const Ciphertext& b, const double c, bool rescale = false); + /** + * @brief Multiplies both polynomial components by a complex scalar, performing necessary checks. + * + * For flexible rescaling techniques the method may rescale the ciphertext + * when the noise level is 2. The scalar is converted using + * `ElemForEvalMult` and the multiplication is applied to both components. + * Metadata (noise level/factor) is updated accordingly; if `rescale` + * is true and the technique is `FIXEDMANUAL`, a rescaling step follows. + * + * @param c Scalar multiplier. + * @param rescale Perform rescaling after multiplication if true. + */ + void multScalar(const std::complex c, bool rescale = false); + /** * @brief Squares the ciphertext (i.e., multiplies it by itself). * diff --git a/src/CKKS/Context.cu b/src/CKKS/Context.cu index ef34bd77..3c64856d 100644 --- a/src/CKKS/Context.cu +++ b/src/CKKS/Context.cu @@ -755,7 +755,7 @@ void ContextData::PrepareNCCLCommunication() { NCCLCHECK( ncclCommRegister(GPUrank[g], top_limb_buffer2[i], sizeof(uint64_t) * N, &top_limb_buffer2_handle[i])); #else - cudaMalloc((void**)&top_limb_buffer[i], sizeof(uint63_t) * N); + cudaMalloc((void**)&top_limb_buffer[i], sizeof(uint64_t) * N); cudaMalloc((void**)&top_limb_buffer2[i], sizeof(uint64_t) * N); #endif } else {