Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions api/CCParams.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,13 @@ void CCParams<CryptoContextCKKSRNS>::SetSecurityLevel(SecurityLevel level) {
params.SetSecurityLevel(sl_openfhe);
}

void CCParams<CryptoContextCKKSRNS>::SetCKKSDataType(CKKSDataType ckksdt) {
auto& params = std::any_cast<lbcrypto::CCParams<lbcrypto::CryptoContextCKKSRNS>&>(cpu);
auto dt_openfhe = static_cast<lbcrypto::CKKSDataType>(ckksdt);
assert((int)dt_openfhe == (int)ckksdt);
params.SetCKKSDataType(dt_openfhe);
}

// ---- Getters ----

SecretKeyDist CCParams<CryptoContextCKKSRNS>::GetSecretKeyDist() const {
Expand Down
1 change: 1 addition & 0 deletions api/CCParams.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ template <> class CCParams<CryptoContextCKKSRNS> {
void SetDevices(std::vector<int>&& devices);
void SetPlaintextAutoload(bool autoload);
void SetCiphertextAutoload(bool autoload);
void SetCKKSDataType(CKKSDataType ckksdt);

// ---- Getters ----
SecretKeyDist GetSecretKeyDist() const;
Expand Down
97 changes: 97 additions & 0 deletions api/CryptoContext.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -747,6 +747,32 @@ Ciphertext<DCRTPoly> CryptoContextImpl<DCRTPoly>::EvalAdd(double scalar, const C
return EvalAdd(ct, scalar);
}

Ciphertext<DCRTPoly> CryptoContextImpl<DCRTPoly>::EvalAdd(const Ciphertext<DCRTPoly>& ct, std::complex<double> scalar) {

// Fall back to CPU.
if (this->devices.empty()) {
auto& context = std::any_cast<const lbcrypto::CryptoContext<lbcrypto::DCRTPoly>&>(this->cpu);
auto& ctImpl = std::any_cast<const lbcrypto::Ciphertext<lbcrypto::DCRTPoly>&>(ct->cpu);
auto ct = context->EvalAdd(ctImpl, scalar);
Ciphertext<DCRTPoly> ciphertext = std::make_shared<CiphertextImpl<DCRTPoly>>(this->self_reference.lock());
ciphertext->cpu = std::make_any<lbcrypto::Ciphertext<lbcrypto::DCRTPoly>>(ct);
return ciphertext;
}

// GPU path.
this->LoadCiphertext(const_cast<Ciphertext<DCRTPoly>&>(ct));

Ciphertext<DCRTPoly> result = std::make_shared<CiphertextImpl<DCRTPoly>>(*ct);
auto res_gpu = std::static_pointer_cast<FIDESlib::CKKS::Ciphertext>(this->GetDeviceCiphertext(result->gpu));
res_gpu->addScalar(scalar);

return result;
}

Ciphertext<DCRTPoly> CryptoContextImpl<DCRTPoly>::EvalAdd(std::complex<double> scalar, const Ciphertext<DCRTPoly>& ct) {
return EvalAdd(ct, scalar);
}

void CryptoContextImpl<DCRTPoly>::EvalAddInPlace(Ciphertext<DCRTPoly>& ct1, const Ciphertext<DCRTPoly>& ct2) {

// Fall back to CPU.
Expand Down Expand Up @@ -815,6 +841,28 @@ void CryptoContextImpl<DCRTPoly>::EvalAddInPlace(double scalar, Ciphertext<DCRTP
EvalAddInPlace(ct1, scalar);
}

void CryptoContextImpl<DCRTPoly>::EvalAddInPlace(Ciphertext<DCRTPoly>& ct1, std::complex<double> scalar) {

// Fall back to CPU.
if (this->devices.empty()) {

auto& context = std::any_cast<const lbcrypto::CryptoContext<lbcrypto::DCRTPoly>&>(this->cpu);
auto& ct1Impl = std::any_cast<lbcrypto::Ciphertext<lbcrypto::DCRTPoly>&>(ct1->cpu);
context->EvalAddInPlace(ct1Impl, scalar);
return;
}

// GPU path.
this->LoadCiphertext(ct1);

auto res_gpu = std::static_pointer_cast<FIDESlib::CKKS::Ciphertext>(this->GetDeviceCiphertext(ct1->gpu));
res_gpu->addScalar(scalar);
}

void CryptoContextImpl<DCRTPoly>::EvalAddInPlace(std::complex<double> scalar, Ciphertext<DCRTPoly>& ct1) {
EvalAddInPlace(ct1, scalar);
}

Ciphertext<DCRTPoly> CryptoContextImpl<DCRTPoly>::EvalAddMutable(Ciphertext<DCRTPoly>& ct1, Ciphertext<DCRTPoly>& ct2) {
return EvalAdd(ct1, ct2);
}
Expand Down Expand Up @@ -1198,6 +1246,33 @@ Ciphertext<DCRTPoly> CryptoContextImpl<DCRTPoly>::EvalMult(double scalar, const
return EvalMult(ct1, scalar);
}

Ciphertext<DCRTPoly> CryptoContextImpl<DCRTPoly>::EvalMult(const Ciphertext<DCRTPoly>& ct1, std::complex<double> scalar) {

// Fall back to CPU.
if (this->devices.empty()) {

auto& context = std::any_cast<const lbcrypto::CryptoContext<lbcrypto::DCRTPoly>&>(this->cpu);
auto& ct1Impl = std::any_cast<const lbcrypto::Ciphertext<lbcrypto::DCRTPoly>&>(ct1->cpu);
auto ct = context->EvalMult(ct1Impl, scalar);
Ciphertext<DCRTPoly> ciphertext = std::make_shared<CiphertextImpl<DCRTPoly>>(this->self_reference.lock());
ciphertext->cpu = std::make_any<lbcrypto::Ciphertext<lbcrypto::DCRTPoly>>(ct);
return ciphertext;
}

// GPU path.
this->LoadCiphertext(const_cast<Ciphertext<DCRTPoly>&>(ct1));

Ciphertext<DCRTPoly> result = std::make_shared<CiphertextImpl<DCRTPoly>>(*ct1);
auto res_gpu = std::static_pointer_cast<FIDESlib::CKKS::Ciphertext>(this->GetDeviceCiphertext(result->gpu));
res_gpu->multScalar(scalar);

return result;
}

Ciphertext<DCRTPoly> CryptoContextImpl<DCRTPoly>::EvalMult(std::complex<double> scalar, const Ciphertext<DCRTPoly>& ct1) {
return EvalMult(ct1, scalar);
}

void CryptoContextImpl<DCRTPoly>::EvalMultInPlace(Ciphertext<DCRTPoly>& ct1, Plaintext& pt) {

if (this->devices.empty()) {
Expand Down Expand Up @@ -1241,6 +1316,28 @@ void CryptoContextImpl<DCRTPoly>::EvalMultInPlace(double scalar, Ciphertext<DCRT
EvalMultInPlace(ct1, scalar);
}

void CryptoContextImpl<DCRTPoly>::EvalMultInPlace(Ciphertext<DCRTPoly>& ct1, std::complex<double> scalar) {

// Fall back to CPU.
if (this->devices.empty()) {

auto& context = std::any_cast<const lbcrypto::CryptoContext<lbcrypto::DCRTPoly>&>(this->cpu);
auto& ct1Impl = std::any_cast<lbcrypto::Ciphertext<lbcrypto::DCRTPoly>&>(ct1->cpu);
context->EvalMultInPlace(ct1Impl, scalar);
return;
}

// GPU path.
this->LoadCiphertext(ct1);

auto res_gpu = std::static_pointer_cast<FIDESlib::CKKS::Ciphertext>(this->GetDeviceCiphertext(ct1->gpu));
res_gpu->multScalar(scalar);
}

void CryptoContextImpl<DCRTPoly>::EvalMultInPlace(std::complex<double> scalar, Ciphertext<DCRTPoly>& ct1) {
EvalMultInPlace(ct1, scalar);
}

Ciphertext<DCRTPoly> CryptoContextImpl<DCRTPoly>::EvalMultMutable(Ciphertext<DCRTPoly>& ct1, Ciphertext<DCRTPoly>& ct2) {
return EvalMult(ct1, ct2);
}
Expand Down
9 changes: 9 additions & 0 deletions api/CryptoContext.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#include <shared_mutex>
#include <unordered_map>
#include <vector>
#include <complex>

#include "CCParams.hpp"
#include "Ciphertext.hpp"
Expand Down Expand Up @@ -115,11 +116,15 @@ template <> class CryptoContextImpl<DCRTPoly> {
Ciphertext<DCRTPoly> EvalAdd(Plaintext& pt, const Ciphertext<DCRTPoly>& ct);
Ciphertext<DCRTPoly> EvalAdd(const Ciphertext<DCRTPoly>& ct, double scalar);
Ciphertext<DCRTPoly> EvalAdd(double scalar, const Ciphertext<DCRTPoly>& ct);
Ciphertext<DCRTPoly> EvalAdd(const Ciphertext<DCRTPoly>& ct, std::complex<double> scalar);
Ciphertext<DCRTPoly> EvalAdd(std::complex<double> scalar, const Ciphertext<DCRTPoly>& ct);
void EvalAddInPlace(Ciphertext<DCRTPoly>& ct1, const Ciphertext<DCRTPoly>& ct2);
void EvalAddInPlace(Ciphertext<DCRTPoly>& ct1, Plaintext& pt);
void EvalAddInPlace(Plaintext& pt, Ciphertext<DCRTPoly>& ct1);
void EvalAddInPlace(Ciphertext<DCRTPoly>& ct1, double scalar);
void EvalAddInPlace(double scalar, Ciphertext<DCRTPoly>& ct1);
void EvalAddInPlace(Ciphertext<DCRTPoly>& ct1, std::complex<double> scalar);
void EvalAddInPlace(std::complex<double> scalar, Ciphertext<DCRTPoly>& ct1);
Ciphertext<DCRTPoly> EvalAddMutable(Ciphertext<DCRTPoly>& ct1, Ciphertext<DCRTPoly>& ct2);
Ciphertext<DCRTPoly> EvalAddMutable(Ciphertext<DCRTPoly>& ct, Plaintext& pt);
Ciphertext<DCRTPoly> EvalAddMutable(Plaintext& pt, Ciphertext<DCRTPoly>& ct);
Expand All @@ -146,9 +151,13 @@ template <> class CryptoContextImpl<DCRTPoly> {
Ciphertext<DCRTPoly> EvalMult(Plaintext& pt, const Ciphertext<DCRTPoly>& ct1);
Ciphertext<DCRTPoly> EvalMult(const Ciphertext<DCRTPoly>& ct1, double scalar);
Ciphertext<DCRTPoly> EvalMult(double scalar, const Ciphertext<DCRTPoly>& ct1);
Ciphertext<DCRTPoly> EvalMult(const Ciphertext<DCRTPoly>& ct1, std::complex<double> scalar);
Ciphertext<DCRTPoly> EvalMult(std::complex<double> scalar, const Ciphertext<DCRTPoly>& ct1);
void EvalMultInPlace(Ciphertext<DCRTPoly>& ct1, Plaintext& pt);
void EvalMultInPlace(Ciphertext<DCRTPoly>& ct1, double scalar);
void EvalMultInPlace(double scalar, Ciphertext<DCRTPoly>& ct1);
void EvalMultInPlace(Ciphertext<DCRTPoly>& ct1, std::complex<double> scalar);
void EvalMultInPlace(std::complex<double> scalar, Ciphertext<DCRTPoly>& ct1);
Ciphertext<DCRTPoly> EvalMultMutable(Ciphertext<DCRTPoly>& ct1, Ciphertext<DCRTPoly>& ct2);
Ciphertext<DCRTPoly> EvalMultMutable(Ciphertext<DCRTPoly>& ct1, Plaintext& pt);
Ciphertext<DCRTPoly> EvalMultMutable(Plaintext& pt, Ciphertext<DCRTPoly>& ct1);
Expand Down
5 changes: 5 additions & 0 deletions api/Definitions.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,11 @@ enum SecurityLevel {
HEStd_NotSet,
};

enum CKKSDataType {
REAL = 0,
COMPLEX,
};

} // namespace fideslib

#endif
1 change: 1 addition & 0 deletions other/GPUTest.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#include <cuda_runtime.h>
#include <iostream>
#include <tuple>

int main() {
int deviceCount, device;
Expand Down
91 changes: 91 additions & 0 deletions src/CKKS/Ciphertext.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -732,6 +732,97 @@ void Ciphertext::addScalar(const double c) {
//}
}

void Ciphertext::addScalar(const std::complex<double> c) {
CudaNvtxRange r(std::string{std::source_location::current().function_name()}.substr(23 + strlen(loc)));
CKKS::SetCurrentContext(cc_);
op_count[OPS::ADDSCALAR]++;

auto elem1 = cc.ElemForEvalAddOrSub(c0.getLevel(), std::fabs(c.real()), this->NoiseLevel);
auto elem2 = cc.ElemForEvalAddOrSub(c0.getLevel(), std::fabs(c.imag()), this->NoiseLevel);

uint32_t sizeQl = c0.getLevel() + 1;

std::vector<uint64_t> moduli(sizeQl);
for (int i = 0; i < sizeQl; ++i)
moduli[i] = cc.prime[i].p;

std::vector<std::vector<uint64_t>> data;
for (uint32_t i = 0; i < sizeQl; i++) {
std::vector<uint64_t> vec(cc.N, 0);
// Code works for positive values only, doesn't work for negative values.
vec[0] = (c.real() > 0) ? elem1[i] % (moduli[i]) : (-elem1[i] % (moduli[i]) + moduli[i]) % (moduli[i]);
vec[cc.N / 2] = (c.imag() > 0) ? elem2[i] % (moduli[i]) : (-elem2[i] % (moduli[i]) + moduli[i]) % (moduli[i]);
Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There's a bug here on lines 753 and 754 for addition of negative complex numbers, whether the real or imaginary part is negative. Apologies, I didn't realize I tested it further just now. The multiplication doesn't have the same problems.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Any idea of what causes this?

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the cause is that those two lines don't quite match up to what OpenFHE has in these two lines. Specifically, the mod.ModSub(elemsRe[i], mod) in OpenFHE (which is used when the real and likewise imaginary part are negative) doesn't match with (-elem1[i] % (moduli[i]) + moduli[i]) % (moduli[i]). It's possible that the positive part may have a subtle bug as well.

// End code that only works for positive values.
data.push_back(vec);
}

RNSPoly elemComplex(cc, data);
elemComplex.NTT(cc.batch, true);
c0.add(elemComplex);
}

void Ciphertext::multScalar(const std::complex<double> c, bool rescale) {
CudaNvtxRange r(std::string{std::source_location::current().function_name()}.substr(23 + strlen(loc)));
CKKS::SetCurrentContext(cc_);
op_count[OPS::MULTSCALAR]++;

if (cc.rescaleTechnique == FLEXIBLEAUTO || cc.rescaleTechnique == FLEXIBLEAUTOEXT ||
cc.rescaleTechnique == FIXEDAUTO) {
if (NoiseLevel == 2)
this->rescale();
}
assert(this->NoiseLevel == 1);

auto elem1 = cc.ElemForEvalMult(c0.getLevel(), c.real());
auto elem2 = cc.ElemForEvalMult(c0.getLevel(), c.imag());

auto level0 = c0.getLevel();
auto level1 = c1.getLevel();
RNSPoly real0(cc, level0), imag0(cc, level0), real1(cc, level1), imag1(cc, level1);
real0.copy(c0);
imag0.copy(c0);
real1.copy(c1);
imag1.copy(c1);

real0.multScalar(elem1);
imag0.multScalar(elem2);
real1.multScalar(elem1);
imag1.multScalar(elem2);

uint32_t N = cc.N;
uint32_t M = 2 * N;

std::vector<std::vector<uint64_t>> poly;
uint32_t power = M / 4;
uint32_t powerReduced = power % M;
uint32_t index = power % N;
for (int i = 0; i <= c0.getLevel(); ++i) {
std::vector<uint64_t> vec(N, 0);
vec[index] = powerReduced < N ? 1 : cc.prime[0].p - 1;
poly.push_back(vec);
}
RNSPoly monomial(cc, poly);
monomial.NTT(cc.batch, true);

imag0.multElement(monomial);
imag1.multElement(monomial);
c0.add(real0, imag0);
c1.add(real1, imag1);

// Manage metadata
NoiseLevel += 1;
NoiseFactor *= cc.param.ScalingFactorReal.at(c0.getLevel());
if (rescale) {
NoiseFactor /= cc.param.ModReduceFactor.at(c0.getLevel());
NoiseLevel -= 1;
}

if (rescale) {
c0.rescale();
c1.rescale();
}
}

void Ciphertext::automorph(const int index, const int br) {
CudaNvtxRange r(std::string{ sc::current().function_name() }.substr());
CKKS::SetCurrentContext(cc_);
Expand Down
27 changes: 27 additions & 0 deletions src/CKKS/Ciphertext.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#define FIDESLIB_CKKS_CIPHERTEXT_CUH

#include <source_location>
#include <complex>
#include "RNSPoly.cuh"
#include "forwardDefs.cuh"
#include "openfhe-interface/RawCiphertext.cuh"
Expand Down Expand Up @@ -210,6 +211,18 @@ class Ciphertext {
*/
void addScalar(const double c);

/**
* @brief Adds a scalar (complex) to both polynomial components.
*
* The scalar is first converted to the appropriate modulus representation
* via `ElemForEvalAddOrSub`. The sign of the scalar is handled by
* converting the element to its complement when `c < 0`. No metadata
* updates are performed.
*
* @param c Scalar value to add.
*/
void addScalar(const std::complex<double> c);

/**
* @brief Multiplies *this* ciphertext by a plaintext.
*
Expand Down Expand Up @@ -314,6 +327,20 @@ class Ciphertext {
*/
void multScalar(const Ciphertext& b, const double c, bool rescale = false);

/**
* @brief Multiplies both polynomial components by a complex scalar, performing necessary checks.
*
* For flexible rescaling techniques the method may rescale the ciphertext
* when the noise level is 2. The scalar is converted using
* `ElemForEvalMult` and the multiplication is applied to both components.
* Metadata (noise level/factor) is updated accordingly; if `rescale`
* is true and the technique is `FIXEDMANUAL`, a rescaling step follows.
*
* @param c Scalar multiplier.
* @param rescale Perform rescaling after multiplication if true.
*/
void multScalar(const std::complex<double> c, bool rescale = false);

/**
* @brief Squares the ciphertext (i.e., multiplies it by itself).
*
Expand Down
2 changes: 1 addition & 1 deletion src/CKKS/Context.cu
Original file line number Diff line number Diff line change
Expand Up @@ -755,7 +755,7 @@ void ContextData::PrepareNCCLCommunication() {
NCCLCHECK(
ncclCommRegister(GPUrank[g], top_limb_buffer2[i], sizeof(uint64_t) * N, &top_limb_buffer2_handle[i]));
#else
cudaMalloc((void**)&top_limb_buffer[i], sizeof(uint63_t) * N);
cudaMalloc((void**)&top_limb_buffer[i], sizeof(uint64_t) * N);
cudaMalloc((void**)&top_limb_buffer2[i], sizeof(uint64_t) * N);
#endif
} else {
Expand Down