diff --git a/.github/workflows/lint-fmt.yml b/.github/workflows/lint-fmt.yml index 94ed9b3f..46a9dbc7 100644 --- a/.github/workflows/lint-fmt.yml +++ b/.github/workflows/lint-fmt.yml @@ -20,7 +20,7 @@ jobs: - uses: actions-rs/toolchain@v1 with: profile: minimal - toolchain: nightly + toolchain: stable override: true components: rustfmt - uses: actions-rs/cargo@v1 @@ -35,13 +35,15 @@ jobs: CARGO_TERM_COLOR: always runs-on: ubuntu-latest steps: - - uses: actions/checkout@v6 + - name: Install protoc + run: sudo apt-get update && sudo apt-get install -y protobuf-compiler libprotobuf-dev - uses: actions-rs/toolchain@v1 with: profile: minimal toolchain: stable override: true components: clippy + - uses: actions/checkout@v6 - uses: actions-rs/cargo@v1 with: command: clippy diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml index b1612a1f..02bab990 100644 --- a/.github/workflows/rust.yml +++ b/.github/workflows/rust.yml @@ -21,12 +21,14 @@ jobs: CARGO_TERM_COLOR: always runs-on: ubuntu-latest steps: - - uses: actions/checkout@v6 + - name: Install protoc + run: sudo apt-get update && sudo apt-get install -y protobuf-compiler libprotobuf-dev - uses: actions-rs/toolchain@v1 with: profile: minimal toolchain: stable override: true + - uses: actions/checkout@v6 - uses: actions-rs/cargo@v1 with: command: check @@ -38,12 +40,14 @@ jobs: CARGO_TERM_COLOR: always runs-on: ubuntu-latest steps: - - uses: actions/checkout@v6 + - name: Install protoc + run: sudo apt-get update && sudo apt-get install -y protobuf-compiler libprotobuf-dev - uses: actions-rs/toolchain@v1 with: profile: minimal toolchain: stable override: true + - uses: actions/checkout@v6 - uses: actions-rs/cargo@v1 with: command: test diff --git a/Cargo.lock b/Cargo.lock index 20abe4bc..fda2f342 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -409,6 +409,12 @@ dependencies = [ "syn", ] +[[package]] +name = "equivalent" +version = "1.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "877a4ace8713b0bcf2a4e7eec82529c029f1d0619886d18145fea96c3ffe5c0f" + [[package]] name = "errno" version = "0.3.13" @@ -433,7 +439,7 @@ checksum = "37909eebbb50d72f9059c3b6d82c0463f2ff062c9e95845c43a6c9c0355411be" [[package]] name = "fhe" -version = "0.1.1" +version = "0.2.0" dependencies = [ "clap", "console", @@ -450,6 +456,7 @@ dependencies = [ "num-bigint", "num-traits", "prost", + "prost-build", "rand", "rand_chacha", "thiserror", @@ -459,7 +466,7 @@ dependencies = [ [[package]] name = "fhe-math" -version = "0.1.1" +version = "0.2.0" dependencies = [ "criterion", "ethnum", @@ -469,10 +476,10 @@ dependencies = [ "ndarray", "num-bigint", "num-bigint-dig", - "num-complex", "num-traits", "proptest", "prost", + "prost-build", "pulp", "rand", "rand_chacha", @@ -506,12 +513,24 @@ version = "0.1.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3a3076410a55c90011c298b04d0cfa770b00fa04e1e3c97d3f6c9de105a03844" +[[package]] +name = "fixedbitset" +version = "0.5.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1d674e81391d1e1ab681a28d99df07927c6d4aa5b027d7da16ba32d1d21ecd99" + [[package]] name = "fnv" version = "1.0.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1" +[[package]] +name = "foldhash" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d9c4f5dac5e15c24eb999c26181a6ca40b39fe946cbe4c263c7209467bc83af2" + [[package]] name = "generic-array" version = "0.14.7" @@ -544,12 +563,37 @@ dependencies = [ "crunchy", ] +[[package]] +name = "hashbrown" +version = "0.15.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9229cfe53dfd69f0609a49f65461bd93001ea1ef889cd5529dd176593f5338a1" +dependencies = [ + "foldhash", +] + +[[package]] +name = "hashbrown" +version = "0.16.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "841d1cc9bed7f9236f321df977030373f4a4163ae1a7dbfe1a51a2c1a51d9100" + [[package]] name = "heck" version = "0.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea" +[[package]] +name = "indexmap" +version = "2.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7714e70437a7dc3ac8eb7e6f8df75fd8eb422675fc7678aff7364301092b1017" +dependencies = [ + "equivalent", + "hashbrown 0.16.1", +] + [[package]] name = "indicatif" version = "0.18.3" @@ -667,6 +711,12 @@ version = "2.7.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "32a282da65faaf38286cf3be983213fcf1d2e2a58700e808f83f4ea9a4804bc0" +[[package]] +name = "multimap" +version = "0.10.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1d87ecb2933e8aeadb3e3a02b828fed80a7528047e68b4f424523a0981a3a084" + [[package]] name = "ndarray" version = "0.17.2" @@ -745,7 +795,6 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "071dfc062690e90b734c0b2273ce72ad0ffa95f0c74596bc250dcfd960262841" dependencies = [ "autocfg", - "libm", ] [[package]] @@ -782,6 +831,17 @@ version = "1.0.15" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "57c0d7b74b563b49d38dae00a0c37d4d6de9b432382b2892f0574ddcae73fd0a" +[[package]] +name = "petgraph" +version = "0.8.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8701b58ea97060d5e5b155d383a69952a60943f0e6dfe30b04c287beb0b27455" +dependencies = [ + "fixedbitset", + "hashbrown 0.15.5", + "indexmap", +] + [[package]] name = "plotters" version = "0.3.7" @@ -834,6 +894,16 @@ dependencies = [ "zerocopy", ] +[[package]] +name = "prettyplease" +version = "0.2.37" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "479ca8adacdd7ce8f1fb39ce9ecccbfe93a3f1344b3d0d97f20bc0196208f62b" +dependencies = [ + "proc-macro2", + "syn", +] + [[package]] name = "proc-macro2" version = "1.0.101" @@ -872,6 +942,25 @@ dependencies = [ "prost-derive", ] +[[package]] +name = "prost-build" +version = "0.14.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "343d3bd7056eda839b03204e68deff7d1b13aba7af2b2fd16890697274262ee7" +dependencies = [ + "heck", + "itertools 0.14.0", + "log", + "multimap", + "petgraph", + "prettyplease", + "prost", + "prost-types", + "regex", + "syn", + "tempfile", +] + [[package]] name = "prost-derive" version = "0.14.3" @@ -885,6 +974,15 @@ dependencies = [ "syn", ] +[[package]] +name = "prost-types" +version = "0.14.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8991c4cbdb8bc5b11f0b074ffe286c30e523de90fee5ba8132f1399f23cb3dd7" +dependencies = [ + "prost", +] + [[package]] name = "pulp" version = "0.22.2" diff --git a/Cargo.toml b/Cargo.toml index 3ff9cec7..19083dae 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -21,22 +21,22 @@ tfhe-ntt = "^0.7.0" console = "^0.16.2" criterion = "^0.8.1" doc-comment = "^0.3.4" -env_logger = "^0.11.3" -ethnum = "^1.5.0" +env_logger = "^0.11.8" +ethnum = "^1.5.2" indicatif = "^0.18.3" itertools = "^0.14.0" log = "^0.4.29" ndarray = "^0.17.2" -num-bigint = "^0.4.4" +num-bigint = "^0.4.6" num-bigint-dig = "^0.9.1" -num-traits = "^0.2.18" -num-complex = { version = "^0.4.6", features = ["libm"] } +num-traits = "^0.2.19" proptest = "^1.9.0" prost = "^0.14.3" +prost-build = "^0.14.3" pulp = "^0.22.2" rand = "^0.9.2" rand_chacha = "^0.9.0" -sha2 = "^0.10.8" +sha2 = "^0.10.9" thiserror = "^2.0.18" zeroize = "^1.8.2" zeroize_derive = "^1.4.3" diff --git a/README.md b/README.md index 25937e00..82a63685 100644 --- a/README.md +++ b/README.md @@ -28,8 +28,8 @@ To install, add the following to your project's `Cargo.toml` file: ```toml [dependencies] -fhe = "0.1.1" -fhe-traits = "0.1.1" +fhe = "0.2" +fhe-traits = "0.2" ``` ## Minimum supported version / toolchain diff --git a/crates/fhe-math/Cargo.toml b/crates/fhe-math/Cargo.toml index c09f0752..097d3c35 100644 --- a/crates/fhe-math/Cargo.toml +++ b/crates/fhe-math/Cargo.toml @@ -6,7 +6,7 @@ edition.workspace = true license.workspace = true repository.workspace = true rust-version.workspace = true -version = "0.1.1" +version = "0.2.0" [lints] workspace = true @@ -21,8 +21,8 @@ tfhe-ntt = [] fhe-traits = { version = "=0.1.1", path = "../fhe-traits" } fhe-util = { version = "=0.1.1", path = "../fhe-util" } -tfhe-ntt.workspace = true ethnum.workspace = true +tfhe-ntt.workspace = true itertools.workspace = true ndarray.workspace = true num-bigint.workspace = true @@ -35,12 +35,14 @@ rand_chacha.workspace = true thiserror.workspace = true zeroize.workspace = true sha2.workspace = true -num-complex.workspace = true [dev-dependencies] criterion.workspace = true proptest.workspace = true +[build-dependencies] +prost-build.workspace = true + [[bench]] name = "zq" harness = false diff --git a/crates/fhe-math/README.md b/crates/fhe-math/README.md index c5759b94..650cce58 100644 --- a/crates/fhe-math/README.md +++ b/crates/fhe-math/README.md @@ -15,7 +15,7 @@ Add the following to your `Cargo.toml`: ```toml [dependencies] -fhe-math = "0.1.1" +fhe-math = "0.2.0" ``` ## Testing diff --git a/crates/fhe-math/build.rs b/crates/fhe-math/build.rs new file mode 100644 index 00000000..7d5b8333 --- /dev/null +++ b/crates/fhe-math/build.rs @@ -0,0 +1,12 @@ +#![allow(missing_docs)] + +fn main() -> Result<(), Box> { + let proto_path = "src/proto/rq.proto"; + let proto_dir = "src/proto"; + + println!("cargo:rerun-if-changed={proto_path}"); + + let mut config = prost_build::Config::new(); + config.compile_protos(&[proto_path], &[proto_dir])?; + Ok(()) +} diff --git a/crates/fhe-math/src/proto/rq.rs b/crates/fhe-math/src/proto/rq.rs index ec32b85f..a0d07d8b 100644 --- a/crates/fhe-math/src/proto/rq.rs +++ b/crates/fhe-math/src/proto/rq.rs @@ -1,49 +1 @@ -#[expect( - clippy::derive_partial_eq_without_eq, - reason = "prost-generated types do not need Eq" -)] -#[derive(Clone, PartialEq, ::prost::Message)] -pub struct Rq { - #[prost(enumeration = "Representation", tag = "1")] - pub representation: i32, - #[prost(uint32, tag = "2")] - pub degree: u32, - #[prost(bytes = "vec", tag = "3")] - pub coefficients: ::prost::alloc::vec::Vec, - #[prost(bool, tag = "4")] - pub allow_variable_time: bool, -} -#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)] -#[repr(i32)] -#[non_exhaustive] -pub enum Representation { - Unknown = 0, - Powerbasis = 1, - Ntt = 2, - Nttshoup = 3, -} -impl Representation { - /// String value of the enum field names used in the ProtoBuf definition. - /// - /// The values are not transformed in any way and thus are considered stable - /// (if the ProtoBuf definition does not change) and safe for programmatic - /// use. - pub fn as_str_name(&self) -> &'static str { - match self { - Representation::Unknown => "UNKNOWN", - Representation::Powerbasis => "POWERBASIS", - Representation::Ntt => "NTT", - Representation::Nttshoup => "NTTSHOUP", - } - } - /// Creates an enum from field names used in the ProtoBuf definition. - pub fn from_str_name(value: &str) -> ::core::option::Option { - match value { - "UNKNOWN" => Some(Self::Unknown), - "POWERBASIS" => Some(Self::Powerbasis), - "NTT" => Some(Self::Ntt), - "NTTSHOUP" => Some(Self::Nttshoup), - _ => None, - } - } -} +include!(concat!(env!("OUT_DIR"), "/fhers.rq.rs")); diff --git a/crates/fhe-math/src/rq/serialize.rs b/crates/fhe-math/src/rq/serialize.rs index 698d577f..444110ab 100644 --- a/crates/fhe-math/src/rq/serialize.rs +++ b/crates/fhe-math/src/rq/serialize.rs @@ -30,7 +30,9 @@ mod tests { use fhe_traits::{DeserializeWithContext, Serialize}; use rand::rng; - use crate::rq::{Context, Poly, Representation}; + use crate::proto::rq::{Representation as RepresentationProto, Rq}; + use crate::rq::{Context, Poly, Representation, traits::TryConvertFrom}; + use prost::Message; const Q: &[u64; 3] = &[ 4611686018282684417, @@ -62,4 +64,71 @@ mod tests { Ok(()) } + + #[test] + fn deserialize_unknown_representation_rejected() -> Result<(), Box> { + let mut rng = rng(); + let ctx = Arc::new(Context::new(Q, 16)?); + let p = Poly::random(&ctx, Representation::PowerBasis, &mut rng); + let mut proto = Rq::from(&p); + proto.representation = RepresentationProto::Unknown as i32; + let bytes = proto.encode_to_vec(); + let err = Poly::from_bytes(&bytes, &ctx).unwrap_err(); + assert!(err.to_string().contains("Unknown representation")); + Ok(()) + } + + #[test] + fn deserialize_invalid_degree_rejected() -> Result<(), Box> { + let mut rng = rng(); + let ctx = Arc::new(Context::new(Q, 16)?); + let p = Poly::random(&ctx, Representation::PowerBasis, &mut rng); + let mut proto = Rq::from(&p); + proto.degree = 6; + let bytes = proto.encode_to_vec(); + let err = Poly::from_bytes(&bytes, &ctx).unwrap_err(); + assert!(err.to_string().contains("Invalid degree")); + Ok(()) + } + + #[test] + fn deserialize_invalid_coefficients_rejected() -> Result<(), Box> { + let mut rng = rng(); + let ctx = Arc::new(Context::new(Q, 16)?); + let p = Poly::random(&ctx, Representation::PowerBasis, &mut rng); + let mut proto = Rq::from(&p); + proto.coefficients.clear(); + let bytes = proto.encode_to_vec(); + let err = Poly::from_bytes(&bytes, &ctx).unwrap_err(); + assert!(err.to_string().contains("Invalid coefficients")); + Ok(()) + } + + #[test] + fn deserialize_representation_mismatch_rejected() -> Result<(), Box> { + let mut rng = rng(); + let ctx = Arc::new(Context::new(Q, 16)?); + let p = Poly::random(&ctx, Representation::Ntt, &mut rng); + let proto = Rq::from(&p); + let err = + Poly::try_convert_from(&proto, &ctx, false, Representation::PowerBasis).unwrap_err(); + assert!( + err.to_string() + .contains("representation asked for does not match") + ); + Ok(()) + } + + #[test] + fn deserialize_variable_time_flag_propagates() -> Result<(), Box> { + let mut rng = rng(); + let ctx = Arc::new(Context::new(Q, 16)?); + let p = Poly::random(&ctx, Representation::PowerBasis, &mut rng); + let mut proto = Rq::from(&p); + proto.allow_variable_time = true; + let bytes = proto.encode_to_vec(); + let decoded = Poly::from_bytes(&bytes, &ctx)?; + assert!(decoded.allow_variable_time_computations); + Ok(()) + } } diff --git a/crates/fhe/Cargo.toml b/crates/fhe/Cargo.toml index aac1087d..c98461ca 100644 --- a/crates/fhe/Cargo.toml +++ b/crates/fhe/Cargo.toml @@ -6,7 +6,7 @@ edition.workspace = true license.workspace = true repository.workspace = true rust-version.workspace = true -version = "0.1.1" +version = "0.2.0" [lints] workspace = true @@ -18,7 +18,7 @@ bench = false # Disable default bench (we use criterion) tfhe-ntt = ["fhe-math/tfhe-ntt"] [dependencies] -fhe-math = { version = "=0.1.1", path = "../fhe-math" } +fhe-math = { version = "=0.2.0", path = "../fhe-math" } fhe-traits = { version = "=0.1.1", path = "../fhe-traits" } fhe-util = { version = "=0.1.1", path = "../fhe-util" } @@ -45,6 +45,9 @@ log.workspace = true ndarray.workspace = true rand.workspace = true +[build-dependencies] +prost-build.workspace = true + [[bench]] name = "bfv" harness = false diff --git a/crates/fhe/README.md b/crates/fhe/README.md index cd1a50ac..50a250a7 100644 --- a/crates/fhe/README.md +++ b/crates/fhe/README.md @@ -15,7 +15,7 @@ Add the following to your `Cargo.toml`: ```toml [dependencies] -fhe = "0.1.1" +fhe = "0.2.0" ``` ## Example diff --git a/crates/fhe/build.rs b/crates/fhe/build.rs new file mode 100644 index 00000000..c63f472b --- /dev/null +++ b/crates/fhe/build.rs @@ -0,0 +1,12 @@ +#![allow(missing_docs)] + +fn main() -> Result<(), Box> { + let proto_path = "src/proto/bfv.proto"; + let proto_dir = "src/proto"; + + println!("cargo:rerun-if-changed={proto_path}"); + + let mut config = prost_build::Config::new(); + config.compile_protos(&[proto_path], &[proto_dir])?; + Ok(()) +} diff --git a/crates/fhe/src/bfv/ciphertext.rs b/crates/fhe/src/bfv/ciphertext.rs index fc755036..3987f6a8 100644 --- a/crates/fhe/src/bfv/ciphertext.rs +++ b/crates/fhe/src/bfv/ciphertext.rs @@ -267,7 +267,9 @@ mod tests { BfvParameters::default_arc(6, 16), ] { let sk = SecretKey::random(¶ms, &mut rng); - let v = params.plaintext.random_vec(params.degree(), &mut rng); + let v = fhe_math::zq::Modulus::new(params.plaintext()) + .unwrap() + .random_vec(params.degree(), &mut rng); let pt = Plaintext::try_encode(&v, Encoding::simd(), ¶ms)?; let ct = sk.try_encrypt(&pt, &mut rng)?; let ct_proto = CiphertextProto::from(&ct); @@ -288,7 +290,9 @@ mod tests { BfvParameters::default_arc(6, 16), ] { let sk = SecretKey::random(¶ms, &mut rng); - let v = params.plaintext.random_vec(params.degree(), &mut rng); + let v = fhe_math::zq::Modulus::new(params.plaintext()) + .unwrap() + .random_vec(params.degree(), &mut rng); let pt = Plaintext::try_encode(&v, Encoding::simd(), ¶ms)?; let ct: Ciphertext = sk.try_encrypt(&pt, &mut rng)?; let ct_bytes = ct.to_bytes(); @@ -305,7 +309,9 @@ mod tests { BfvParameters::default_arc(6, 16), ] { let sk = SecretKey::random(¶ms, &mut rng); - let v = params.plaintext.random_vec(params.degree(), &mut rng); + let v = fhe_math::zq::Modulus::new(params.plaintext()) + .unwrap() + .random_vec(params.degree(), &mut rng); let pt = Plaintext::try_encode(&v, Encoding::simd(), ¶ms)?; let ct: Ciphertext = sk.try_encrypt(&pt, &mut rng)?; let mut ct3 = &ct * &ct; @@ -343,7 +349,9 @@ mod tests { BfvParameters::default_arc(6, 16), ] { let sk = SecretKey::random(¶ms, &mut rng); - let v = params.plaintext.random_vec(params.degree(), &mut rng); + let v = fhe_math::zq::Modulus::new(params.plaintext()) + .unwrap() + .random_vec(params.degree(), &mut rng); let pt = Plaintext::try_encode(&v, Encoding::simd(), ¶ms)?; let mut ct: Ciphertext = sk.try_encrypt(&pt, &mut rng)?; @@ -364,7 +372,9 @@ mod tests { let mut rng = rng(); let params = BfvParameters::default_arc(2, 16); let sk = SecretKey::random(¶ms, &mut rng); - let v = params.plaintext.random_vec(params.degree(), &mut rng); + let v = fhe_math::zq::Modulus::new(params.plaintext()) + .unwrap() + .random_vec(params.degree(), &mut rng); let pt = Plaintext::try_encode(&v, Encoding::simd(), ¶ms)?; let mut ct: Ciphertext = sk.try_encrypt(&pt, &mut rng)?; diff --git a/crates/fhe/src/bfv/context/cipher_plain_context.rs b/crates/fhe/src/bfv/context/cipher_plain_context.rs index e805ba07..681ffd59 100644 --- a/crates/fhe/src/bfv/context/cipher_plain_context.rs +++ b/crates/fhe/src/bfv/context/cipher_plain_context.rs @@ -1,4 +1,5 @@ use fhe_math::rq::{Context, Poly, scaler::Scaler}; +use num_bigint::BigUint; use std::sync::Arc; /// Stores pre-computed values relating a ciphertext and plaintext context pair. @@ -11,10 +12,10 @@ pub struct CipherPlainContext { pub(crate) delta: Poly, /// Q modulo the plaintext modulus - pub(crate) q_mod_t: u64, + pub(crate) q_mod_t: BigUint, /// Threshold for centered reduction (plaintext_modulus + 1) / 2 - pub(crate) plain_threshold: u64, + pub(crate) plain_threshold: BigUint, /// Scaler to map a ciphertext polynomial to the plaintext context pub(crate) scaler: Scaler, @@ -33,8 +34,8 @@ impl CipherPlainContext { plaintext_context: &Arc, ciphertext_context: &Arc, delta: Poly, - q_mod_t: u64, - plain_threshold: u64, + q_mod_t: BigUint, + plain_threshold: BigUint, scaler: Scaler, ) -> Arc { Arc::new(CipherPlainContext { diff --git a/crates/fhe/src/bfv/keys/evaluation_key.rs b/crates/fhe/src/bfv/keys/evaluation_key.rs index c1f225eb..25741e33 100644 --- a/crates/fhe/src/bfv/keys/evaluation_key.rs +++ b/crates/fhe/src/bfv/keys/evaluation_key.rs @@ -551,9 +551,11 @@ mod tests { .enable_inner_sum()? .build(&mut rng)?; - let v = params.plaintext.random_vec(params.degree(), &mut rng); - let expected = params - .plaintext + let v = fhe_math::zq::Modulus::new(params.plaintext()) + .unwrap() + .random_vec(params.degree(), &mut rng); + let expected = fhe_math::zq::Modulus::new(params.plaintext()) + .unwrap() .reduce_u128(v.iter().map(|vi| *vi as u128).sum()); let pt = Plaintext::try_encode( @@ -595,7 +597,9 @@ mod tests { .enable_row_rotation()? .build(&mut rng)?; - let v = params.plaintext.random_vec(params.degree(), &mut rng); + let v = fhe_math::zq::Modulus::new(params.plaintext()) + .unwrap() + .random_vec(params.degree(), &mut rng); let row_size = params.degree() >> 1; let mut expected = vec![0u64; params.degree()]; expected[..row_size].copy_from_slice(&v[row_size..]); @@ -642,7 +646,9 @@ mod tests { .enable_column_rotation(i)? .build(&mut rng)?; - let v = params.plaintext.random_vec(params.degree(), &mut rng); + let v = fhe_math::zq::Modulus::new(params.plaintext()) + .unwrap() + .random_vec(params.degree(), &mut rng); let row_size = params.degree() >> 1; let mut expected = vec![0u64; params.degree()]; expected[..row_size - i].copy_from_slice(&v[i..row_size]); @@ -699,7 +705,9 @@ mod tests { assert!(ek.supports_expansion(i)); assert!(!ek.supports_expansion(i + 1)); - let v = params.plaintext.random_vec(1 << i, &mut rng); + let v = fhe_math::zq::Modulus::new(params.plaintext()) + .unwrap() + .random_vec(1 << i, &mut rng); let pt = Plaintext::try_encode( &v, Encoding::poly_at_level(ciphertext_level), @@ -711,7 +719,9 @@ mod tests { assert_eq!(ct2.len(), 1 << i); for (vi, ct2i) in izip!(&v, &ct2) { let mut expected = vec![0u64; params.degree()]; - expected[0] = params.plaintext.mul(*vi, (1 << i) as u64); + expected[0] = fhe_math::zq::Modulus::new(params.plaintext()) + .unwrap() + .mul(*vi, (1 << i) as u64); let pt = sk.try_decrypt(ct2i)?; assert_eq!( expected, diff --git a/crates/fhe/src/bfv/keys/galois_key.rs b/crates/fhe/src/bfv/keys/galois_key.rs index 1d96bd04..b0e24df7 100644 --- a/crates/fhe/src/bfv/keys/galois_key.rs +++ b/crates/fhe/src/bfv/keys/galois_key.rs @@ -173,7 +173,9 @@ mod tests { ] { for _ in 0..30 { let sk = SecretKey::random(¶ms, &mut rng); - let v = params.plaintext.random_vec(params.degree(), &mut rng); + let v = fhe_math::zq::Modulus::new(params.plaintext()) + .unwrap() + .random_vec(params.degree(), &mut rng); let row_size = params.degree() >> 1; let pt = Plaintext::try_encode(&v, Encoding::simd(), ¶ms)?; diff --git a/crates/fhe/src/bfv/keys/public_key.rs b/crates/fhe/src/bfv/keys/public_key.rs index 34d54dc0..893151e3 100644 --- a/crates/fhe/src/bfv/keys/public_key.rs +++ b/crates/fhe/src/bfv/keys/public_key.rs @@ -188,7 +188,9 @@ mod tests { let pk = PublicKey::new(&sk, &mut rng); let pt = Plaintext::try_encode( - ¶ms.plaintext.random_vec(params.degree(), &mut rng), + &fhe_math::zq::Modulus::new(params.plaintext()) + .unwrap() + .random_vec(params.degree(), &mut rng), Encoding::poly_at_level(level), ¶ms, )?; diff --git a/crates/fhe/src/bfv/keys/secret_key.rs b/crates/fhe/src/bfv/keys/secret_key.rs index 753c2717..f491a312 100644 --- a/crates/fhe/src/bfv/keys/secret_key.rs +++ b/crates/fhe/src/bfv/keys/secret_key.rs @@ -1,6 +1,8 @@ //! Secret keys for the BFV encryption scheme -use crate::bfv::{BfvParameters, Ciphertext, Plaintext}; +use crate::bfv::{ + BfvParameters, Ciphertext, Plaintext, parameters::PlaintextModulus, plaintext::PlaintextValues, +}; use crate::proto::bfv::SecretKey as SecretKeyProto; use crate::{Error, Result, SerializationError}; use fhe_math::{ @@ -237,25 +239,55 @@ impl FheDecrypter for SecretKey { let ctx_lvl = self.par.context_level_at(ct.level).unwrap(); let d = Zeroizing::new(c.scale(&ctx_lvl.cipher_plain_context.scaler)?); - // TODO: Can we handle plaintext moduli that are BigUint? - let v = Zeroizing::new( - Vec::::try_from(d.as_ref())? - .iter_mut() - .map(|vi| *vi + *self.par.plaintext) - .collect_vec(), - ); - let mut w = v[..self.par.degree()].to_vec(); - let q = Modulus::new(self.par.moduli[0]).map_err(Error::MathError)?; - q.reduce_vec(&mut w); - self.par.plaintext.reduce_vec(&mut w); - - let mut poly = - Poly::try_convert_from(&w, ct[0].ctx(), false, Representation::PowerBasis)?; + let value = match self.par.plaintext { + PlaintextModulus::Small { .. } => { + let mut v = Vec::::try_from(d.as_ref())?; + let plaintext_modulus = self.par.plaintext(); + v.iter_mut().for_each(|vi| *vi += plaintext_modulus); + let mut w = v[..self.par.degree()].to_vec(); + + let q = Modulus::new(self.par.moduli[0]).map_err(Error::MathError)?; + q.reduce_vec(&mut w); + if let PlaintextModulus::Small { modulus: m, .. } = &self.par.plaintext { + m.reduce_vec(&mut w); + } + PlaintextValues::Small(w.into_boxed_slice()) + } + PlaintextModulus::Large(_) => { + let v: Vec = Vec::::from(d.as_ref()) + .into_iter() + .map(|vi| vi + self.par.plaintext_big()) + .collect_vec(); + + let mut w = v[..self.par.degree()].to_vec(); + let q_poly = d.as_ref().ctx().modulus(); + w.iter_mut().for_each(|wi| *wi %= q_poly); + + self.par.plaintext.reduce_vec(&mut w); + PlaintextValues::Large(w.into_boxed_slice()) + } + }; + + let mut poly = match &value { + PlaintextValues::Small(v) => Poly::try_convert_from( + v.as_ref(), + ct[0].ctx(), + false, + Representation::PowerBasis, + )?, + PlaintextValues::Large(v) => Poly::try_convert_from( + v.as_ref(), + ct[0].ctx(), + false, + Representation::PowerBasis, + )?, + }; + poly.change_representation(Representation::Ntt); let pt = Plaintext { par: self.par.clone(), - value: w.into_boxed_slice(), + value, encoding: None, poly_ntt: poly, level: ct.level, @@ -299,9 +331,10 @@ mod tests { for level in 0..params.max_level() { for _ in 0..20 { let sk = SecretKey::random(¶ms, &mut rng); + let q = fhe_math::zq::Modulus::new(params.plaintext()).unwrap(); let pt = Plaintext::try_encode( - ¶ms.plaintext.random_vec(params.degree(), &mut rng), + &q.random_vec(params.degree(), &mut rng), Encoding::poly_at_level(level), ¶ms, )?; diff --git a/crates/fhe/src/bfv/mod.rs b/crates/fhe/src/bfv/mod.rs index fec4755b..798b9900 100644 --- a/crates/fhe/src/bfv/mod.rs +++ b/crates/fhe/src/bfv/mod.rs @@ -1,4 +1,4 @@ -#![warn(missing_docs, unused_imports)] +#![warn(missing_docs)] // Expect indexing in BFV cryptographic operations for performance #![expect( clippy::indexing_slicing, @@ -24,7 +24,9 @@ pub use encoding::Encoding; pub(crate) use keys::KeySwitchingKey; pub use keys::{EvaluationKey, EvaluationKeyBuilder, PublicKey, RelinearizationKey, SecretKey}; pub use ops::{Multiplicator, dot_product_scalar}; +pub(crate) use parameters::PlaintextModulus; pub use parameters::{BfvParameters, BfvParametersBuilder}; pub use plaintext::Plaintext; +pub(crate) use plaintext::PlaintextValues; pub use plaintext_vec::PlaintextVec; pub use rgsw_ciphertext::RGSWCiphertext; diff --git a/crates/fhe/src/bfv/ops/dot_product.rs b/crates/fhe/src/bfv/ops/dot_product.rs index f51aed13..50f29404 100644 --- a/crates/fhe/src/bfv/ops/dot_product.rs +++ b/crates/fhe/src/bfv/ops/dot_product.rs @@ -180,14 +180,18 @@ mod tests { for size in 1..128 { let ct = (0..size) .map(|_| { - let v = params.plaintext.random_vec(params.degree(), &mut rng); + let v = fhe_math::zq::Modulus::new(params.plaintext()) + .unwrap() + .random_vec(params.degree(), &mut rng); let pt = Plaintext::try_encode(&v, Encoding::simd(), ¶ms).unwrap(); sk.try_encrypt(&pt, &mut rng).unwrap() }) .collect_vec(); let pt = (0..size) .map(|_| { - let v = params.plaintext.random_vec(params.degree(), &mut rng); + let v = fhe_math::zq::Modulus::new(params.plaintext()) + .unwrap() + .random_vec(params.degree(), &mut rng); Plaintext::try_encode(&v, Encoding::simd(), ¶ms).unwrap() }) .collect_vec(); diff --git a/crates/fhe/src/bfv/ops/mod.rs b/crates/fhe/src/bfv/ops/mod.rs index c4602623..95197282 100644 --- a/crates/fhe/src/bfv/ops/mod.rs +++ b/crates/fhe/src/bfv/ops/mod.rs @@ -370,11 +370,12 @@ mod tests { BfvParameters::default_arc(6, 16), ] { let zero = Ciphertext::zero(¶ms); + let q = fhe_math::zq::Modulus::new(params.plaintext()).unwrap(); for _ in 0..50 { - let a = params.plaintext.random_vec(params.degree(), &mut rng); - let b = params.plaintext.random_vec(params.degree(), &mut rng); + let a = q.random_vec(params.degree(), &mut rng); + let b = q.random_vec(params.degree(), &mut rng); let mut c = a.clone(); - params.plaintext.add_vec(&mut c, &b); + q.add_vec(&mut c, &b); let sk = SecretKey::random(¶ms, &mut rng); @@ -410,11 +411,12 @@ mod tests { BfvParameters::default_arc(1, 16), BfvParameters::default_arc(6, 16), ] { + let q = fhe_math::zq::Modulus::new(params.plaintext()).unwrap(); for _ in 0..50 { - let a = params.plaintext.random_vec(params.degree(), &mut rng); - let b = params.plaintext.random_vec(params.degree(), &mut rng); + let a = q.random_vec(params.degree(), &mut rng); + let b = q.random_vec(params.degree(), &mut rng); let mut c = a.clone(); - params.plaintext.add_vec(&mut c, &b); + q.add_vec(&mut c, &b); let sk = SecretKey::random(¶ms, &mut rng); @@ -462,13 +464,14 @@ mod tests { BfvParameters::default_arc(6, 16), ] { let zero = Ciphertext::zero(¶ms); + let q = fhe_math::zq::Modulus::new(params.plaintext()).unwrap(); for _ in 0..50 { - let a = params.plaintext.random_vec(params.degree(), &mut rng); + let a = q.random_vec(params.degree(), &mut rng); let mut a_neg = a.clone(); - params.plaintext.neg_vec(&mut a_neg); - let b = params.plaintext.random_vec(params.degree(), &mut rng); + q.neg_vec(&mut a_neg); + let b = q.random_vec(params.degree(), &mut rng); let mut c = a.clone(); - params.plaintext.sub_vec(&mut c, &b); + q.sub_vec(&mut c, &b); let sk = SecretKey::random(¶ms, &mut rng); @@ -509,13 +512,14 @@ mod tests { BfvParameters::default_arc(1, 16), BfvParameters::default_arc(6, 16), ] { + let q = fhe_math::zq::Modulus::new(params.plaintext()).unwrap(); for _ in 0..50 { - let a = params.plaintext.random_vec(params.degree(), &mut rng); + let a = q.random_vec(params.degree(), &mut rng); let mut a_neg = a.clone(); - params.plaintext.neg_vec(&mut a_neg); - let b = params.plaintext.random_vec(params.degree(), &mut rng); + q.neg_vec(&mut a_neg); + let b = q.random_vec(params.degree(), &mut rng); let mut c = a.clone(); - params.plaintext.sub_vec(&mut c, &b); + q.sub_vec(&mut c, &b); let sk = SecretKey::random(¶ms, &mut rng); @@ -562,10 +566,11 @@ mod tests { BfvParameters::default_arc(1, 16), BfvParameters::default_arc(6, 16), ] { + let q = fhe_math::zq::Modulus::new(params.plaintext()).unwrap(); for _ in 0..50 { - let a = params.plaintext.random_vec(params.degree(), &mut rng); + let a = q.random_vec(params.degree(), &mut rng); let mut c = a.clone(); - params.plaintext.neg_vec(&mut c); + q.neg_vec(&mut c); let sk = SecretKey::random(¶ms, &mut rng); for encoding in [Encoding::poly(), Encoding::simd()] { @@ -595,9 +600,10 @@ mod tests { BfvParameters::default_arc(1, 16), BfvParameters::default_arc(6, 16), ] { + let q = fhe_math::zq::Modulus::new(params.plaintext()).unwrap(); for _ in 0..50 { - let a = params.plaintext.random_vec(params.degree(), &mut rng); - let b = params.plaintext.random_vec(params.degree(), &mut rng); + let a = q.random_vec(params.degree(), &mut rng); + let b = q.random_vec(params.degree(), &mut rng); let sk = SecretKey::random(¶ms, &mut rng); for encoding in [Encoding::poly(), Encoding::simd()] { @@ -607,21 +613,17 @@ mod tests { for i in 0..params.degree() { for j in 0..params.degree() { if i + j >= params.degree() { - c[(i + j) % params.degree()] = params.plaintext.sub( - c[(i + j) % params.degree()], - params.plaintext.mul(a[i], b[j]), - ); + c[(i + j) % params.degree()] = + q.sub(c[(i + j) % params.degree()], q.mul(a[i], b[j])); } else { - c[i + j] = params - .plaintext - .add(c[i + j], params.plaintext.mul(a[i], b[j])); + c[i + j] = q.add(c[i + j], q.mul(a[i], b[j])); } } } } EncodingEnum::Simd => { c.clone_from(&a); - params.plaintext.mul_vec(&mut c, &b); + q.mul_vec(&mut c, &b); } } @@ -652,13 +654,14 @@ mod tests { BfvParameters::default_arc(2, 16), BfvParameters::default_arc(8, 16), ] { + let q = fhe_math::zq::Modulus::new(par.plaintext()).unwrap(); for _ in 0..1 { // We will encode `values` in an Simd format, and check that the product is // computed correctly. - let v1 = par.plaintext.random_vec(par.degree(), &mut rng); - let v2 = par.plaintext.random_vec(par.degree(), &mut rng); + let v1 = q.random_vec(par.degree(), &mut rng); + let v2 = q.random_vec(par.degree(), &mut rng); let mut expected = v1.clone(); - par.plaintext.mul_vec(&mut expected, &v2); + q.mul_vec(&mut expected, &v2); let sk = SecretKey::random(&par, &mut rng); let pt1 = Plaintext::try_encode(&v1, Encoding::simd(), &par)?; @@ -674,7 +677,7 @@ mod tests { assert_eq!(Vec::::try_decode(&pt, Encoding::simd())?, expected); let e = expected.clone(); - par.plaintext.mul_vec(&mut expected, &e); + q.mul_vec(&mut expected, &e); println!("Noise: {}", unsafe { sk.measure_noise(&ct4)? }); let pt = sk.try_decrypt(&ct4)?; assert_eq!(Vec::::try_decode(&pt, Encoding::simd())?, expected); @@ -687,12 +690,13 @@ mod tests { fn square() -> Result<(), Box> { let mut rng = rng(); let par = BfvParameters::default_arc(6, 16); + let q = fhe_math::zq::Modulus::new(par.plaintext()).unwrap(); for _ in 0..20 { // We will encode `values` in an Simd format, and check that the product is // computed correctly. - let v = par.plaintext.random_vec(par.degree(), &mut rng); + let v = q.random_vec(par.degree(), &mut rng); let mut expected = v.clone(); - par.plaintext.mul_vec(&mut expected, &v); + q.mul_vec(&mut expected, &v); let sk = SecretKey::random(&par, &mut rng); let pt = Plaintext::try_encode(&v, Encoding::simd(), &par)?; diff --git a/crates/fhe/src/bfv/ops/mul.rs b/crates/fhe/src/bfv/ops/mul.rs index 993aedb9..bb84dadb 100644 --- a/crates/fhe/src/bfv/ops/mul.rs +++ b/crates/fhe/src/bfv/ops/mul.rs @@ -5,7 +5,6 @@ use fhe_math::{ rq::{Context, Representation, scaler::Scaler}, zq::primes::generate_prime, }; -use num_bigint::BigUint; use crate::{ Error, Result, @@ -121,7 +120,7 @@ impl Multiplicator { ScalingFactor::one(), ScalingFactor::one(), &extended_basis, - ScalingFactor::new(&BigUint::from(*rk.ksk.par.plaintext), ctx.modulus()), + ScalingFactor::new(rk.ksk.par.plaintext_big(), ctx.modulus()), rk.ksk.ciphertext_level, &rk.ksk.par, )?; @@ -254,12 +253,13 @@ mod tests { fn mul() -> Result<(), Box> { let mut rng = rng(); let par = BfvParameters::default_arc(3, 16); + let q = fhe_math::zq::Modulus::new(par.plaintext()).unwrap(); for _ in 0..30 { // We will encode `values` in an Simd format, and check that the product is // computed correctly. - let values = par.plaintext.random_vec(par.degree(), &mut rng); + let values = q.random_vec(par.degree(), &mut rng); let mut expected = values.clone(); - par.plaintext.mul_vec(&mut expected, &values); + q.mul_vec(&mut expected, &values); let sk = SecretKey::random(&par, &mut rng); let rk = RelinearizationKey::new(&sk, &mut rng)?; @@ -287,11 +287,12 @@ mod tests { fn mul_at_level() -> Result<(), Box> { let mut rng = rng(); let par = BfvParameters::default_arc(3, 16); + let q = fhe_math::zq::Modulus::new(par.plaintext()).unwrap(); for _ in 0..15 { for level in 0..2 { - let values = par.plaintext.random_vec(par.degree(), &mut rng); + let values = q.random_vec(par.degree(), &mut rng); let mut expected = values.clone(); - par.plaintext.mul_vec(&mut expected, &values); + q.mul_vec(&mut expected, &values); let sk = SecretKey::random(&par, &mut rng); let rk = RelinearizationKey::new_leveled(&sk, level, level, &mut rng)?; @@ -322,12 +323,13 @@ mod tests { fn mul_no_relin() -> Result<(), Box> { let mut rng = rng(); let par = BfvParameters::default_arc(6, 16); + let q = fhe_math::zq::Modulus::new(par.plaintext()).unwrap(); for _ in 0..30 { // We will encode `values` in an Simd format, and check that the product is // computed correctly. - let values = par.plaintext.random_vec(par.degree(), &mut rng); + let values = q.random_vec(par.degree(), &mut rng); let mut expected = values.clone(); - par.plaintext.mul_vec(&mut expected, &values); + q.mul_vec(&mut expected, &values); let sk = SecretKey::random(&par, &mut rng); let rk = RelinearizationKey::new(&sk, &mut rng)?; @@ -359,6 +361,7 @@ mod tests { let mut rng = rng(); let par = BfvParameters::default_arc(3, 16); + let q = fhe_math::zq::Modulus::new(par.plaintext()).unwrap(); let mut extended_basis = par.moduli().to_vec(); extended_basis .push(generate_prime(62, 2 * par.degree() as u64, extended_basis[2]).unwrap()); @@ -371,9 +374,9 @@ mod tests { for _ in 0..30 { // We will encode `values` in an Simd format, and check that the product is // computed correctly. - let values = par.plaintext.random_vec(par.degree(), &mut rng); + let values = q.random_vec(par.degree(), &mut rng); let mut expected = values.clone(); - par.plaintext.mul_vec(&mut expected, &values); + q.mul_vec(&mut expected, &values); let sk = SecretKey::random(&par, &mut rng); let pt = Plaintext::try_encode(&values, Encoding::simd(), &par)?; diff --git a/crates/fhe/src/bfv/parameters.rs b/crates/fhe/src/bfv/parameters.rs index 4ca6b9c6..41ca62e6 100644 --- a/crates/fhe/src/bfv/parameters.rs +++ b/crates/fhe/src/bfv/parameters.rs @@ -1,7 +1,7 @@ //! Create parameters for the BFV encryption scheme use crate::bfv::{context::CipherPlainContext, context::ContextLevel}; -use crate::proto::bfv::Parameters; +use crate::proto::bfv::{Parameters, parameters::PlaintextModulus as PlaintextModulusProto}; use crate::{Error, ParametersError, Result, SerializationError}; use fhe_math::{ ntt::NttOperator, @@ -18,6 +18,56 @@ use std::collections::HashMap; use std::fmt::Debug; use std::sync::Arc; +/// Enum to support both small (u64) and large (BigUint) plaintext moduli. +#[derive(Debug, PartialEq, Eq, Clone)] +pub(crate) enum PlaintextModulus { + Small { + modulus: Modulus, + modulus_big: BigUint, + }, + Large(BigUint), +} + +impl PlaintextModulus { + pub fn as_biguint(&self) -> &BigUint { + match self { + Self::Small { modulus_big, .. } => modulus_big, + Self::Large(m) => m, + } + } + + pub fn as_u64(&self) -> Option { + match self { + Self::Small { modulus, .. } => Some(**modulus), + Self::Large(_) => None, + } + } + + pub fn reduce_vec(&self, v: &mut [BigUint]) { + match self { + Self::Small { modulus_big, .. } => { + v.iter_mut().for_each(|vi| *vi %= modulus_big); + } + Self::Large(m) => v.iter_mut().for_each(|vi| *vi %= m), + } + } + + // Helper to reduce BigUint vector to i64 (centered), returning as Vec + // or similar? The previous implementation used center_vec_vt returning + // Vec. If modulus is large, we can't fit in i64. + + // We need a scalar multiplication for Plaintext::to_poly + pub fn scalar_mul_vec(&self, a: &mut [BigUint], b: &BigUint) { + match self { + Self::Small { modulus_big, .. } => { + a.iter_mut() + .for_each(|ai| *ai = (ai as &BigUint * b) % modulus_big); + } + Self::Large(m) => a.iter_mut().for_each(|ai| *ai = (ai as &BigUint * b) % m), + } + } +} + /// Parameters for the BFV encryption scheme. /// /// This struct consolidates all parameter-specific data and pre-computed values @@ -28,9 +78,6 @@ pub struct BfvParameters { /// Number of coefficients in a polynomial. polynomial_degree: usize, - /// Modulus of the plaintext. - plaintext_modulus: u64, - /// Vector of coprime moduli q_i for the ciphertext. pub(crate) moduli: Box<[u64]>, @@ -46,8 +93,8 @@ pub struct BfvParameters { /// NTT operator for SIMD plaintext operations, if possible pub(crate) ntt_operator: Option>, - /// Plaintext Modulus as a Modulus type - pub(crate) plaintext: Modulus, + /// Plaintext Modulus as a Modulus type or BigUint + pub(crate) plaintext: PlaintextModulus, pub(crate) matrix_reps_index_map: Box<[usize]>, } @@ -56,14 +103,8 @@ impl Debug for BfvParameters { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("BfvParameters") .field("polynomial_degree", &self.polynomial_degree) - .field("plaintext_modulus", &self.plaintext_modulus) + .field("plaintext_modulus", &self.plaintext.as_biguint()) .field("moduli", &self.moduli) - // .field("moduli_sizes", &self.moduli_sizes) - // .field("variance", &self.variance) - // .field("ctx", &self.ctx) - // .field("op", &self.op) - // .field("plaintext", &self.plaintext) - // .field("matrix_reps_index_map", &self.matrix_reps_index_map) .finish() } } @@ -91,10 +132,17 @@ impl BfvParameters { &self.moduli_sizes } - /// Returns the plaintext modulus + /// Returns the plaintext modulus if it fits in u64. + /// Panics if the modulus is too large. + #[must_use] + pub fn plaintext(&self) -> u64 { + self.plaintext.as_u64().unwrap() + } + + /// Returns the plaintext modulus as BigUint #[must_use] - pub const fn plaintext(&self) -> u64 { - self.plaintext_modulus + pub fn plaintext_big(&self) -> &BigUint { + self.plaintext.as_biguint() } /// Returns the maximum level allowed by these parameters. @@ -266,7 +314,7 @@ impl BfvParameters { #[derive(Debug)] pub struct BfvParametersBuilder { degree: usize, - plaintext: u64, + plaintext: BigUint, variance: usize, ciphertext_moduli: Vec, ciphertext_moduli_sizes: Vec, @@ -296,9 +344,14 @@ impl BfvParametersBuilder { self } - /// Sets the plaintext modulus. Returns an error if the plaintext is not - /// between 2 and 2^62 - 1. + /// Sets the plaintext modulus. pub fn set_plaintext_modulus(&mut self, plaintext: u64) -> &mut Self { + self.plaintext = BigUint::from(plaintext); + self + } + + /// Sets the plaintext modulus as BigUint. + pub fn set_plaintext_modulus_biguint(&mut self, plaintext: BigUint) -> &mut Self { self.plaintext = plaintext; self } @@ -383,14 +436,20 @@ impl BfvParametersBuilder { )); } - // This checks that the plaintext modulus is valid. - // TODO: Check bound on the plaintext modulus. - let plaintext_modulus = Modulus::new(self.plaintext).map_err(|e| { - Error::ParametersError(ParametersError::InvalidPlaintextModulus { - modulus: self.plaintext, - reason: e.to_string(), - }) - })?; + let plaintext_modulus_struct = if let Some(p) = self.plaintext.to_u64() { + PlaintextModulus::Small { + modulus: Modulus::new(p).map_err(|e| { + Error::ParametersError(ParametersError::InvalidPlaintextModulus { + modulus: p, + reason: e.to_string(), + }) + })?, + modulus_big: BigUint::from(p), + } + } else { + PlaintextModulus::Large(self.plaintext.clone()) + }; + let plaintext_big = plaintext_modulus_struct.as_biguint(); // Check that one of `ciphertext_moduli` and `ciphertext_moduli_sizes` is // specified. @@ -416,11 +475,32 @@ impl BfvParametersBuilder { .map(|m| 64 - m.leading_zeros() as usize) .collect_vec(); - // Create plaintext context using the first ciphertext modulus - let plaintext_context = Context::new_arc(&moduli[..1], self.degree)?; + // Determine how many moduli needed for plaintext context + // We need product of moduli > plaintext modulus. + let t_bits = plaintext_big.bits(); + let mut accumulated_bits = 0; + let mut plaintext_moduli_count = 0; + for size in &moduli_sizes { + accumulated_bits += size; + plaintext_moduli_count += 1; + if accumulated_bits as u64 >= t_bits + 60 { + break; + } + } + plaintext_moduli_count = std::cmp::max(plaintext_moduli_count, 1); + plaintext_moduli_count = std::cmp::min(plaintext_moduli_count, moduli.len()); + + // Create plaintext context using sufficient moduli + let plaintext_context = Context::new_arc(&moduli[..plaintext_moduli_count], self.degree)?; // Create NTT operator for SIMD operations if possible - let ntt_operator = NttOperator::new(&plaintext_modulus, self.degree).map(Arc::new); + // Only if plaintext modulus fits in u64 for now + let ntt_operator = match &plaintext_modulus_struct { + PlaintextModulus::Small { modulus, .. } => { + NttOperator::new(modulus, self.degree).map(Arc::new) + } + PlaintextModulus::Large(_) => None, + }; // Create cipher-plain bridge contexts let mut cipher_plain_contexts = Vec::with_capacity(moduli.len()); @@ -433,7 +513,15 @@ impl BfvParametersBuilder { let mut delta_rests = vec![]; for m in level_moduli { let q = Modulus::new(*m)?; - delta_rests.push(q.inv(q.neg(*plaintext_modulus)).unwrap()) + let t_mod_q = (plaintext_big % *m).to_u64().unwrap(); + let neg_t_mod_q = q.neg(t_mod_q); + if let Some(inv) = q.inv(neg_t_mod_q) { + delta_rests.push(inv); + } else { + Err(Error::MathError(fhe_math::Error::Default( + "Inverse failed".to_string(), + )))?; + } } // Use RnsContext to lift the delta values and create the scaling polynomial @@ -447,16 +535,19 @@ impl BfvParametersBuilder { delta.change_representation(Representation::NttShoup); // Compute q_mod_t - let q_mod_t = (rns.modulus() % *plaintext_modulus).to_u64().unwrap(); + let q_mod_t = rns.modulus() % plaintext_big; // Compute plain_threshold - let plain_threshold = self.plaintext.div_ceil(2); + let plain_threshold = match &plaintext_modulus_struct { + PlaintextModulus::Small { modulus, .. } => BigUint::from((**modulus + 1) >> 1), + PlaintextModulus::Large(m) => (m + 1u32) >> 1, + }; // Scaler from ciphertext to plaintext context let scaler = Scaler::new( &cipher_ctx, &plaintext_context, - ScalingFactor::new(&BigUint::from(*plaintext_modulus), rns.modulus()), + ScalingFactor::new(plaintext_big, rns.modulus()), )?; let cipher_plain_ctx = CipherPlainContext::new_arc( @@ -516,10 +607,7 @@ impl BfvParametersBuilder { &node.poly_context, &mul_1_ctx, ScalingFactor::one(), - ScalingFactor::new( - &BigUint::from(*plaintext_modulus), - node.poly_context.modulus(), - ), + ScalingFactor::new(plaintext_big, node.poly_context.modulus()), )?; node.mul_params.set(mp).unwrap(); } @@ -543,13 +631,12 @@ impl BfvParametersBuilder { Ok(BfvParameters { polynomial_degree: self.degree, - plaintext_modulus: self.plaintext, moduli: moduli.into(), moduli_sizes: moduli_sizes.into(), variance: self.variance, context_chain, ntt_operator, - plaintext: plaintext_modulus, + plaintext: plaintext_modulus_struct, matrix_reps_index_map: matrix_reps_index_map.into(), }) } @@ -557,11 +644,19 @@ impl BfvParametersBuilder { impl Serialize for BfvParameters { fn to_bytes(&self) -> Vec { + let plaintext_modulus = if let Some(plaintext_u64) = self.plaintext.as_u64() { + Some(PlaintextModulusProto::Plaintext(plaintext_u64)) + } else { + Some(PlaintextModulusProto::PlaintextBig( + self.plaintext.as_biguint().to_bytes_le(), + )) + }; + Parameters { degree: self.polynomial_degree as u32, - plaintext: self.plaintext_modulus, moduli: self.moduli.to_vec(), variance: self.variance as u32, + plaintext_modulus, } .encode_to_vec() } @@ -574,9 +669,22 @@ impl Deserialize for BfvParameters { message: "Parameters decode".into(), }) })?; + + let plaintext_modulus = match params.plaintext_modulus { + Some(PlaintextModulusProto::Plaintext(value)) => BigUint::from(value), + Some(PlaintextModulusProto::PlaintextBig(bytes)) => BigUint::from_bytes_le(&bytes), + None => { + return Err(Error::SerializationError( + SerializationError::MissingField { + field_name: "Parameters.plaintext_modulus".into(), + }, + )); + } + }; + BfvParametersBuilder::new() .set_degree(params.degree as usize) - .set_plaintext_modulus(params.plaintext) + .set_plaintext_modulus_biguint(plaintext_modulus) .set_moduli(¶ms.moduli) .set_variance(params.variance as usize) .build() @@ -612,7 +720,10 @@ impl MultiplicationParameters { #[cfg(test)] mod tests { use super::{BfvParameters, BfvParametersBuilder}; + use crate::proto::bfv::{Parameters, parameters::PlaintextModulus as PlaintextModulusProto}; use fhe_traits::{Deserialize, Serialize}; + use num_bigint::BigUint; + use prost::Message; use std::error::Error; #[test] @@ -662,6 +773,20 @@ mod tests { Ok(()) } + #[test] + fn big_plaintext_modulus() -> Result<(), Box> { + // Use a 128-bit prime + let p = BigUint::parse_bytes(b"340282366920938463463374607431768211507", 10).unwrap(); + let params = BfvParametersBuilder::new() + .set_degree(16) + .set_plaintext_modulus_biguint(p.clone()) + .set_moduli_sizes(&[62, 62, 62, 62, 62]) // Large enough for product > p + .build()?; + + assert_eq!(params.plaintext_big(), &p); + Ok(()) + } + #[test] fn serialize() -> Result<(), Box> { let params = BfvParametersBuilder::new() @@ -671,10 +796,51 @@ mod tests { .set_variance(4) .build()?; let bytes = params.to_bytes(); + let proto = Parameters::decode(bytes.as_slice())?; + assert!(matches!( + proto.plaintext_modulus, + Some(PlaintextModulusProto::Plaintext(2)) + )); assert_eq!(BfvParameters::try_deserialize(&bytes)?, params); + + // Test with big plaintext + let p = BigUint::parse_bytes(b"340282366920938463463374607431768211507", 10).unwrap(); + let params = BfvParametersBuilder::new() + .set_degree(16) + .set_plaintext_modulus_biguint(p) + .set_moduli_sizes(&[62, 62, 62, 62, 62]) + .set_variance(4) + .build()?; + let bytes = params.to_bytes(); + let proto = Parameters::decode(bytes.as_slice())?; + let proto_plaintext_bytes = match &proto.plaintext_modulus { + Some(PlaintextModulusProto::PlaintextBig(bytes)) => bytes.as_slice(), + _ => return Err("expected plaintext_big variant".into()), + }; + assert_eq!( + proto_plaintext_bytes, + params.plaintext_big().to_bytes_le().as_slice() + ); + let decoded = BfvParameters::try_deserialize(&bytes)?; + assert_eq!(decoded, params); + assert_eq!(decoded.plaintext_big(), params.plaintext_big()); + Ok(()) } + #[test] + fn deserialize_missing_plaintext_modulus() { + let proto = Parameters { + degree: 16, + moduli: vec![4611686018427387617, 4611686018427387329], + variance: 4, + plaintext_modulus: None, + }; + let bytes = proto.encode_to_vec(); + let err = BfvParameters::try_deserialize(&bytes).unwrap_err(); + assert!(format!("{err}").contains("Missing required field")); + } + #[test] fn matrix_reps_index_map_is_permutation() -> Result<(), Box> { let params = BfvParametersBuilder::new() diff --git a/crates/fhe/src/bfv/plaintext.rs b/crates/fhe/src/bfv/plaintext.rs index 2390ee66..036e0202 100644 --- a/crates/fhe/src/bfv/plaintext.rs +++ b/crates/fhe/src/bfv/plaintext.rs @@ -1,22 +1,43 @@ //! Plaintext type in the BFV encryption scheme. use crate::{ Error, Result, - bfv::{BfvParameters, Encoding, PlaintextVec}, + bfv::{BfvParameters, Encoding, PlaintextVec, parameters::PlaintextModulus}, }; use fhe_math::rq::{Context, Poly, Representation, traits::TryConvertFrom}; use fhe_traits::{FheDecoder, FheEncoder, FheParametrized, FhePlaintext}; +use num_bigint::{BigInt, BigUint, Sign}; +use num_traits::{ToPrimitive, Zero}; use std::sync::Arc; use zeroize::{Zeroize, Zeroizing}; use super::encoding::EncodingEnum; +#[derive(Debug, Clone, PartialEq, Eq)] +pub(crate) enum PlaintextValues { + Small(Box<[u64]>), + Large(Box<[BigUint]>), +} + +impl Zeroize for PlaintextValues { + fn zeroize(&mut self) { + match self { + Self::Small(v) => v.zeroize(), + Self::Large(v) => { + for x in v.iter_mut() { + *x = BigUint::zero(); + } + } + } + } +} + /// A plaintext object, that encodes a vector according to a specific encoding. #[derive(Debug, Clone, Eq)] pub struct Plaintext { /// The parameters of the underlying BFV encryption scheme. pub(crate) par: Arc, /// The value after encoding. - pub(crate) value: Box<[u64]>, + pub(crate) value: PlaintextValues, /// The encoding of the plaintext, if known pub(crate) encoding: Option, /// The plaintext as a polynomial. @@ -27,7 +48,6 @@ pub struct Plaintext { impl Zeroize for Plaintext { fn zeroize(&mut self) { - // Only zeroize the sensitive value and polynomial fields self.value.zeroize(); self.poly_ntt.zeroize(); } @@ -49,14 +69,31 @@ impl FhePlaintext for Plaintext { impl Plaintext { pub(crate) fn to_poly(&self) -> Poly { - let mut m_v = Zeroizing::new(self.value.clone()); let ctx_lvl = self.par.context_level_at(self.level).unwrap(); - self.par - .plaintext - .scalar_mul_vec(&mut m_v, ctx_lvl.cipher_plain_context.q_mod_t); let ctx = &ctx_lvl.poly_context; - let mut m = - Poly::try_convert_from(m_v.as_ref(), ctx, false, Representation::PowerBasis).unwrap(); + + let mut m = match &self.value { + PlaintextValues::Small(v) => { + let mut m_v = Zeroizing::new(v.clone()); + if let PlaintextModulus::Small { modulus, .. } = &self.par.plaintext { + let q_mod_t = ctx_lvl.cipher_plain_context.q_mod_t.to_u64().unwrap(); + modulus.scalar_mul_vec(&mut m_v, q_mod_t); + } else { + unreachable!("PlaintextValues::Small but PlaintextModulus::Large"); + } + Poly::try_convert_from(m_v.as_ref(), ctx, false, Representation::PowerBasis) + .unwrap() + } + PlaintextValues::Large(v) => { + let mut m_v = v.clone(); + self.par + .plaintext + .scalar_mul_vec(&mut m_v, &ctx_lvl.cipher_plain_context.q_mod_t); + Poly::try_convert_from(m_v.as_ref(), ctx, false, Representation::PowerBasis) + .unwrap() + } + }; + m.change_representation(Representation::Ntt); m *= &ctx_lvl.cipher_plain_context.delta; m @@ -66,11 +103,18 @@ impl Plaintext { pub fn zero(encoding: Encoding, par: &Arc) -> Result { let level = encoding.level; let ctx = par.context_at_level(level)?; - let value = vec![0u64; par.degree()]; + let value = match par.plaintext { + PlaintextModulus::Small { .. } => { + PlaintextValues::Small(vec![0u64; par.degree()].into_boxed_slice()) + } + PlaintextModulus::Large(_) => { + PlaintextValues::Large(vec![BigUint::zero(); par.degree()].into_boxed_slice()) + } + }; let poly_ntt = Poly::zero(ctx, Representation::Ntt); Ok(Self { par: par.clone(), - value: value.into_boxed_slice(), + value, encoding: Some(encoding), poly_ntt, level, @@ -86,8 +130,6 @@ impl Plaintext { unsafe impl Send for Plaintext {} -// Implement the equality manually; we want to say that two plaintexts are equal -// even if one of them doesn't store its encoding information. impl PartialEq for Plaintext { fn eq(&self, other: &Self) -> bool { let Self { @@ -109,8 +151,6 @@ impl PartialEq for Plaintext { eq &= value == other_value; eq &= poly_ntt == other_poly_ntt; eq &= level == other_level; - // Compare encoding only if both plaintexts have encoding information. - // This allows comparing plaintexts even when one doesn't store encoding. if encoding.is_some() && other_encoding.is_some() { eq &= encoding == other_encoding; } @@ -139,12 +179,20 @@ impl TryConvertFrom<&Plaintext> for Poly { "Incompatible contexts".to_string(), )) } else { - Poly::try_convert_from( - pt.value.as_ref(), - ctx, - variable_time, - Representation::PowerBasis, - ) + match &pt.value { + PlaintextValues::Small(v) => Poly::try_convert_from( + v.as_ref(), + ctx, + variable_time, + Representation::PowerBasis, + ), + PlaintextValues::Large(v) => Poly::try_convert_from( + v.as_ref(), + ctx, + variable_time, + Representation::PowerBasis, + ), + } } } } @@ -171,6 +219,25 @@ where } } +impl<'a> FheEncoder<&'a [BigUint]> for Plaintext { + type Error = Error; + fn try_encode( + value: &'a [BigUint], + encoding: Encoding, + par: &Arc, + ) -> Result { + if value.len() > par.degree() { + return Err(Error::TooManyValues { + actual: value.len(), + limit: par.degree(), + }); + } + + let v = PlaintextVec::try_encode(value, encoding, par)?; + Ok(v[0].clone()) + } +} + impl<'a> FheEncoder<&'a [u64]> for Plaintext { type Error = Error; fn try_encode(value: &'a [u64], encoding: Encoding, par: &Arc) -> Result { @@ -188,16 +255,42 @@ impl<'a> FheEncoder<&'a [u64]> for Plaintext { impl<'a> FheEncoder<&'a [i64]> for Plaintext { type Error = Error; fn try_encode(value: &'a [i64], encoding: Encoding, par: &Arc) -> Result { - let w = Zeroizing::new(par.plaintext.reduce_vec_i64(value)); - Plaintext::try_encode(w.as_ref() as &[u64], encoding, par) + match &par.plaintext { + PlaintextModulus::Small { modulus: m, .. } => { + let w = Zeroizing::new(m.reduce_vec_i64(value)); + Plaintext::try_encode(w.as_ref() as &[u64], encoding, par) + } + PlaintextModulus::Large(m) => { + let modulus_int = BigInt::from_biguint(Sign::Plus, m.clone()); + let v: Vec = value + .iter() + .map(|&x| { + let mut x_int = BigInt::from(x); + x_int %= &modulus_int; + if x_int < BigInt::zero() { + x_int += &modulus_int; + } + x_int.to_biguint().unwrap() + }) + .collect(); + Plaintext::try_encode(v.as_slice(), encoding, par) + } + } } } -impl FheDecoder for Vec<u64> { - fn try_decode<O>(pt: &Plaintext, encoding: O) -> Result<Vec<u64>> +impl FheDecoder<Plaintext> for Vec<BigUint> { + fn try_decode<O>(pt: &Plaintext, encoding: O) -> Result<Vec<BigUint>> where O: Into<Option<Encoding>>, { + // First convert to Vec<BigUint> regardless of internal storage + let w = match &pt.value { + PlaintextValues::Small(v) => v.iter().map(|&x| BigUint::from(x)).collect::<Vec<_>>(), + PlaintextValues::Large(v) => v.to_vec(), + }; + + // Standard decoding logic (e.g. check encoding match) let encoding = encoding.into(); let enc: Encoding; if pt.encoding.is_none() && encoding.is_none() { @@ -226,19 +319,21 @@ impl FheDecoder<Plaintext> for Vec<u64> { } } - let mut w = pt.value.to_vec(); - match enc.encoding { EncodingEnum::Poly => Ok(w), EncodingEnum::Simd => { if let Some(op) = &pt.par.ntt_operator { - op.forward(&mut w); - let mut w_reordered = w.clone(); + // NTT operator works on u64. + // If ntt_operator exists, it means we are in Small modulus case. + let mut w_u64: Vec<u64> = w.iter().map(|x| x.to_u64().unwrap()).collect(); + op.forward(&mut w_u64); + let mut w_reordered = w_u64.clone(); for i in 0..pt.par.degree() { - w_reordered[i] = w[pt.par.matrix_reps_index_map[i]] + w_reordered[i] = w_u64[pt.par.matrix_reps_index_map[i]] } - w.zeroize(); - Ok(w_reordered) + w_u64.zeroize(); + + Ok(w_reordered.into_iter().map(BigUint::from).collect()) } else { Err(Error::EncodingNotSupported { encoding: EncodingEnum::Simd.to_string(), @@ -248,6 +343,80 @@ impl FheDecoder<Plaintext> for Vec<u64> { } } } + type Error = Error; +} + +impl FheDecoder<Plaintext> for Vec<u64> { + fn try_decode<O>(pt: &Plaintext, encoding: O) -> Result<Vec<u64>> + where + O: Into<Option<Encoding>>, + { + // Optimized path for Small values + match &pt.value { + PlaintextValues::Small(v) => { + // Copied logic for validation + let encoding = encoding.into(); + let enc: Encoding; + if pt.encoding.is_none() && encoding.is_none() { + return Err(Error::InvalidPlaintext { + reason: "No encoding specified".into(), + }); + } else if pt.encoding.is_some() { + enc = pt.encoding.as_ref().unwrap().clone(); + if let Some(arg_enc) = encoding + && arg_enc != enc + { + return Err(Error::EncodingMismatch { + found: arg_enc.into(), + expected: enc.into(), + }); + } + } else { + enc = encoding.unwrap(); + if let Some(pt_enc) = pt.encoding.as_ref() + && pt_enc != &enc + { + return Err(Error::EncodingMismatch { + found: pt_enc.into(), + expected: enc.into(), + }); + } + } + + let mut w = v.to_vec(); + + match enc.encoding { + EncodingEnum::Poly => Ok(w), + EncodingEnum::Simd => { + if let Some(op) = &pt.par.ntt_operator { + op.forward(&mut w); + let mut w_reordered = w.clone(); + for i in 0..pt.par.degree() { + w_reordered[i] = w[pt.par.matrix_reps_index_map[i]] + } + w.zeroize(); + Ok(w_reordered) + } else { + Err(Error::EncodingNotSupported { + encoding: EncodingEnum::Simd.to_string(), + reason: "NTT operator not available".into(), + }) + } + } + } + } + PlaintextValues::Large(_) => { + let v = Vec::<BigUint>::try_decode(pt, encoding)?; + v.iter() + .map(|x| { + x.to_u64().ok_or(Error::DefaultError( + "Plaintext value too large for u64".to_string(), + )) + }) + .collect() + } + } + } type Error = Error; } @@ -257,8 +426,33 @@ impl FheDecoder<Plaintext> for Vec<i64> { where E: Into<Option<Encoding>>, { - let v = Vec::<u64>::try_decode(pt, encoding)?; - Ok(pt.par.plaintext.center_vec(&v)) + match &pt.value { + PlaintextValues::Small(_) => { + let v = Vec::<u64>::try_decode(pt, encoding)?; + if let PlaintextModulus::Small { modulus: m, .. } = &pt.par.plaintext { + Ok(m.center_vec(&v)) + } else { + unreachable!() + } + } + PlaintextValues::Large(_) => { + let v = Vec::<BigUint>::try_decode(pt, encoding)?; + let modulus_big = pt.par.plaintext_big(); + let modulus_int = BigInt::from_biguint(Sign::Plus, modulus_big.clone()); + let half_modulus = modulus_big / 2u32; + + Ok(v.iter() + .map(|x| { + if x >= &half_modulus { + let x_int = BigInt::from_biguint(Sign::Plus, x.clone()); + (x_int - &modulus_int).to_i64().unwrap() + } else { + x.to_i64().unwrap() + } + }) + .collect()) + } + } } type Error = Error; @@ -268,8 +462,11 @@ impl FheDecoder<Plaintext> for Vec<i64> { mod tests { use super::{Encoding, Plaintext}; use crate::bfv::parameters::{BfvParameters, BfvParametersBuilder}; + use crate::bfv::plaintext::PlaintextValues; use fhe_math::rq::{Poly, Representation}; use fhe_traits::{FheDecoder, FheEncoder}; + use num_bigint::BigUint; + use num_traits::Zero; use rand::rng; use std::error::Error; use zeroize::Zeroize; @@ -279,15 +476,23 @@ mod tests { let mut rng = rng(); // The default test parameters support both Poly and Simd encodings let params = BfvParameters::default_arc(1, 16); - let a = params.plaintext.random_vec(params.degree(), &mut rng); + // random_vec returns Vec<u64> + let a = params.plaintext(); + // use modulus directly to generate random u64s + let q = fhe_math::zq::Modulus::new(a).unwrap(); + let a_vec = q.random_vec(params.degree(), &mut rng); let plaintext = Plaintext::try_encode(&[0u64; 17], Encoding::poly(), &params); assert!(plaintext.is_err()); - let plaintext = Plaintext::try_encode(&a, Encoding::poly(), &params); + let plaintext = Plaintext::try_encode(&a_vec, Encoding::poly(), &params); assert!(plaintext.is_ok()); + // Verify it used Small variant + if let PlaintextValues::Large(_) = plaintext.unwrap().value { + println!("Expected Small variant"); + } - let plaintext = Plaintext::try_encode(&a, Encoding::simd(), &params); + let plaintext = Plaintext::try_encode(&a_vec, Encoding::simd(), &params); assert!(plaintext.is_ok()); let plaintext = Plaintext::try_encode(&[1u64], Encoding::poly(), &params); @@ -300,38 +505,77 @@ mod tests { .set_moduli(&[4611686018326724609]) .build_arc()?; - let a = params.plaintext.random_vec(params.degree(), &mut rng); + let a = 2u64; + let q = fhe_math::zq::Modulus::new(a).unwrap(); + let a_vec = q.random_vec(params.degree(), &mut rng); - let plaintext = Plaintext::try_encode(&a, Encoding::poly(), &params); + let plaintext = Plaintext::try_encode(&a_vec, Encoding::poly(), &params); assert!(plaintext.is_ok()); - let plaintext = Plaintext::try_encode(&a, Encoding::simd(), &params); + let plaintext = Plaintext::try_encode(&a_vec, Encoding::simd(), &params); assert!(plaintext.is_err()); Ok(()) } + #[test] + fn try_encode_big() -> Result<(), Box<dyn Error>> { + // Test with big plaintext + let p_val = BigUint::parse_bytes(b"340282366920938463463374607431768211507", 10).unwrap(); + let params = BfvParametersBuilder::new() + .set_degree(16) + .set_plaintext_modulus_biguint(p_val.clone()) + .set_moduli_sizes(&[62, 62, 62, 62, 62]) + .build_arc()?; + + let vals = vec![p_val.clone() - 1u32, BigUint::from(123u32)]; + let plaintext = Plaintext::try_encode(&vals, Encoding::poly(), &params)?; + + // Verify it used Large variant + if let PlaintextValues::Small(_) = plaintext.value { + println!("Expected Large variant"); + } + + let decoded: Vec<BigUint> = Vec::<BigUint>::try_decode(&plaintext, Encoding::poly())?; + assert_eq!(decoded[0], p_val - 1u32); + assert_eq!(decoded[1], BigUint::from(123u32)); + assert_eq!(decoded[2], BigUint::zero()); + + Ok(()) + } + #[test] fn encode_decode() -> Result<(), Box<dyn Error>> { let mut rng = rng(); let params = BfvParameters::default_arc(1, 16); - let a = params.plaintext.random_vec(params.degree(), &mut rng); + let a = params.plaintext(); + let q = fhe_math::zq::Modulus::new(a).unwrap(); + let a_vec = q.random_vec(params.degree(), &mut rng); - let plaintext = Plaintext::try_encode(&a, Encoding::simd(), &params); + let plaintext = Plaintext::try_encode(&a_vec, Encoding::simd(), &params); assert!(plaintext.is_ok()); let b = Vec::<u64>::try_decode(&plaintext?, Encoding::simd())?; - assert_eq!(b, a); + assert_eq!(b, a_vec); + + // center_vec replacement logic for test + let mut a_signed = vec![]; + for x in &a_vec { + if *x >= a / 2 { + a_signed.push((*x as i64) - (a as i64)); + } else { + a_signed.push(*x as i64); + } + } - let a = params.plaintext.center_vec(&a); - let plaintext = Plaintext::try_encode(&a, Encoding::poly(), &params); + let plaintext = Plaintext::try_encode(&a_signed, Encoding::poly(), &params); assert!(plaintext.is_ok()); let b = Vec::<i64>::try_decode(&plaintext?, Encoding::poly())?; - assert_eq!(b, a); + assert_eq!(b, a_signed); - let plaintext = Plaintext::try_encode(&a, Encoding::simd(), &params); + let plaintext = Plaintext::try_encode(&a_signed, Encoding::simd(), &params); assert!(plaintext.is_ok()); let b = Vec::<i64>::try_decode(&plaintext?, Encoding::simd())?; - assert_eq!(b, a); + assert_eq!(b, a_signed); Ok(()) } @@ -340,10 +584,12 @@ mod tests { fn partial_eq() -> Result<(), Box<dyn Error>> { let mut rng = rng(); let params = BfvParameters::default_arc(1, 16); - let a = params.plaintext.random_vec(params.degree(), &mut rng); + let a = params.plaintext(); + let q = fhe_math::zq::Modulus::new(a).unwrap(); + let a_vec = q.random_vec(params.degree(), &mut rng); - let plaintext = Plaintext::try_encode(&a, Encoding::poly(), &params)?; - let mut same_plaintext = Plaintext::try_encode(&a, Encoding::poly(), &params)?; + let plaintext = Plaintext::try_encode(&a_vec, Encoding::poly(), &params)?; + let mut same_plaintext = Plaintext::try_encode(&a_vec, Encoding::poly(), &params)?; assert_eq!(plaintext, same_plaintext); // Equality also holds when there is no encoding specified. In this test, we use @@ -360,9 +606,11 @@ mod tests { fn try_decode_errors() -> Result<(), Box<dyn Error>> { let mut rng = rng(); let params = BfvParameters::default_arc(1, 16); - let a = params.plaintext.random_vec(params.degree(), &mut rng); + let a = params.plaintext(); + let q = fhe_math::zq::Modulus::new(a).unwrap(); + let a_vec = q.random_vec(params.degree(), &mut rng); - let mut plaintext = Plaintext::try_encode(&a, Encoding::poly(), &params)?; + let mut plaintext = Plaintext::try_encode(&a_vec, Encoding::poly(), &params)?; assert!(Vec::<u64>::try_decode(&plaintext, None).is_ok()); let e = Vec::<u64>::try_decode(&plaintext, Encoding::simd()); @@ -402,7 +650,10 @@ mod tests { let params = BfvParameters::default_arc(1, 16); let plaintext = Plaintext::zero(Encoding::poly(), &params)?; - assert_eq!(plaintext.value, Box::<[u64]>::from([0u64; 16])); + assert_eq!( + plaintext.value, + PlaintextValues::Small(vec![0u64; 16].into_boxed_slice()) + ); assert_eq!( plaintext.poly_ntt, Poly::zero(params.context_at_level(0)?, Representation::Ntt) @@ -415,8 +666,10 @@ mod tests { fn zeroize() -> Result<(), Box<dyn Error>> { let mut rng = rng(); let params = BfvParameters::default_arc(1, 16); - let a = params.plaintext.random_vec(params.degree(), &mut rng); - let mut plaintext = Plaintext::try_encode(&a, Encoding::poly(), &params)?; + let a = params.plaintext(); + let q = fhe_math::zq::Modulus::new(a).unwrap(); + let a_vec = q.random_vec(params.degree(), &mut rng); + let mut plaintext = Plaintext::try_encode(&a_vec, Encoding::poly(), &params)?; plaintext.zeroize(); @@ -430,12 +683,14 @@ mod tests { let mut rng = rng(); // The default test parameters support both Poly and Simd encodings let params = BfvParameters::default_arc(10, 16); - let a = params.plaintext.random_vec(params.degree(), &mut rng); + let a = params.plaintext(); + let q = fhe_math::zq::Modulus::new(a).unwrap(); + let a_vec = q.random_vec(params.degree(), &mut rng); for level in 0..10 { - let plaintext = Plaintext::try_encode(&a, Encoding::poly_at_level(level), &params)?; + let plaintext = Plaintext::try_encode(&a_vec, Encoding::poly_at_level(level), &params)?; assert_eq!(plaintext.level(), level); - let plaintext = Plaintext::try_encode(&a, Encoding::simd_at_level(level), &params)?; + let plaintext = Plaintext::try_encode(&a_vec, Encoding::simd_at_level(level), &params)?; assert_eq!(plaintext.level(), level); } diff --git a/crates/fhe/src/bfv/plaintext_vec.rs b/crates/fhe/src/bfv/plaintext_vec.rs index 15f47baa..e4ebe06b 100644 --- a/crates/fhe/src/bfv/plaintext_vec.rs +++ b/crates/fhe/src/bfv/plaintext_vec.rs @@ -2,11 +2,13 @@ use std::{cmp::min, ops::Deref, sync::Arc}; use fhe_math::rq::{Poly, Representation, traits::TryConvertFrom}; use fhe_traits::{FheEncoder, FheEncoderVariableTime, FheParametrized, FhePlaintext}; +use num_bigint::BigUint; +use num_traits::{ToPrimitive, Zero}; use zeroize_derive::{Zeroize, ZeroizeOnDrop}; use crate::{ Error, Result, - bfv::{BfvParameters, Encoding, Plaintext}, + bfv::{BfvParameters, Encoding, Plaintext, PlaintextValues}, }; use super::encoding::EncodingEnum; @@ -75,9 +77,97 @@ impl FheEncoderVariableTime<&[u64]> for PlaintextVec { Poly::try_convert_from(&v, ctx, true, Representation::PowerBasis)?; poly.change_representation(Representation::Ntt); + let value_enum = match par.plaintext { + crate::bfv::PlaintextModulus::Small { .. } => { + PlaintextValues::Small(v.into_boxed_slice()) + } + crate::bfv::PlaintextModulus::Large(_) => PlaintextValues::Large( + v.iter() + .map(|&x| BigUint::from(x)) + .collect::<Vec<_>>() + .into_boxed_slice(), + ), + }; + + Ok(Plaintext { + par: par.clone(), + value: value_enum, + encoding: Some(encoding.clone()), + poly_ntt: poly, + level: encoding.level, + }) + }) + .collect::<Result<Vec<Plaintext>>>()?, + )) + } +} + +impl FheEncoder<&[BigUint]> for PlaintextVec { + type Error = Error; + fn try_encode(value: &[BigUint], encoding: Encoding, par: &Arc<BfvParameters>) -> Result<Self> { + if value.is_empty() { + return Ok(PlaintextVec(vec![Plaintext::zero(encoding, par)?])); + } + if encoding.encoding == EncodingEnum::Simd && par.ntt_operator.is_none() { + return Err(Error::EncodingNotSupported { + encoding: EncodingEnum::Simd.to_string(), + reason: "NTT operator not available".into(), + }); + } + let ctx = par.context_at_level(encoding.level)?; + let num_plaintexts = value.len().div_ceil(par.degree()); + + Ok(PlaintextVec( + (0..num_plaintexts) + .map(|i| { + let slice = &value[i * par.degree()..min(value.len(), (i + 1) * par.degree())]; + let mut v = vec![BigUint::zero(); par.degree()]; + match encoding.encoding { + EncodingEnum::Poly => v[..slice.len()].clone_from_slice(slice), + EncodingEnum::Simd => { + let mut v_u64 = vec![0u64; par.degree()]; + for i in 0..slice.len() { + v_u64[par.matrix_reps_index_map[i]] = + slice[i].to_u64().ok_or(Error::DefaultError( + "Value too large for SIMD encoding".to_string(), + ))?; + } + par.ntt_operator + .as_ref() + .ok_or(Error::InvalidPlaintext { + reason: "No Ntt operator".into(), + })? + .backward(&mut v_u64); + + v = v_u64.into_iter().map(BigUint::from).collect(); + } + }; + + let mut poly = Poly::try_convert_from( + v.as_slice(), + ctx, + false, + Representation::PowerBasis, + )?; + poly.change_representation(Representation::Ntt); + + let value_enum = match &par.plaintext { + crate::bfv::PlaintextModulus::Small { modulus_big, .. } => { + PlaintextValues::Small( + v.iter() + .map(|x| (x % modulus_big).to_u64().unwrap()) + .collect::<Vec<_>>() + .into_boxed_slice(), + ) + } + crate::bfv::PlaintextModulus::Large(_) => { + PlaintextValues::Large(v.into_boxed_slice()) + } + }; + Ok(Plaintext { par: par.clone(), - value: v.into(), + value: value_enum, encoding: Some(encoding.clone()), poly_ntt: poly, level: encoding.level, @@ -127,9 +217,21 @@ impl FheEncoder<&[u64]> for PlaintextVec { Poly::try_convert_from(&v, ctx, false, Representation::PowerBasis)?; poly.change_representation(Representation::Ntt); + let value_enum = match par.plaintext { + crate::bfv::PlaintextModulus::Small { .. } => { + PlaintextValues::Small(v.into_boxed_slice()) + } + crate::bfv::PlaintextModulus::Large(_) => PlaintextValues::Large( + v.iter() + .map(|&x| BigUint::from(x)) + .collect::<Vec<_>>() + .into_boxed_slice(), + ), + }; + Ok(Plaintext { par: par.clone(), - value: v.into(), + value: value_enum, encoding: Some(encoding.clone()), poly_ntt: poly, level: encoding.level, @@ -153,18 +255,28 @@ mod tests { for _ in 0..20 { for i in 1..5 { let params = BfvParameters::default_arc(1, 16); - let a = params.plaintext.random_vec(params.degree() * i, &mut rng); + let a = params.plaintext(); + let q = fhe_math::zq::Modulus::new(a).unwrap(); + let a_vec = q.random_vec(params.degree() * i, &mut rng); - let plaintexts = PlaintextVec::try_encode(&a, Encoding::poly_at_level(0), &params)?; + let plaintexts = PlaintextVec::try_encode( + a_vec.as_slice(), + Encoding::poly_at_level(0), + &params, + )?; assert_eq!(plaintexts.0.len(), i); for j in 0..i { let b = Vec::<u64>::try_decode(&plaintexts.0[j], Encoding::poly_at_level(0))?; - assert_eq!(b, &a[j * params.degree()..(j + 1) * params.degree()]); + assert_eq!(b, &a_vec[j * params.degree()..(j + 1) * params.degree()]); } let plaintexts_vt = unsafe { - PlaintextVec::try_encode_vt(&a, Encoding::poly_at_level(0), &params)? + PlaintextVec::try_encode_vt( + a_vec.as_slice(), + Encoding::poly_at_level(0), + &params, + )? }; assert_eq!(plaintexts_vt.0.len(), i); for (pt, pt_vt) in plaintexts.0.iter().zip(plaintexts_vt.0.iter()) { @@ -174,19 +286,21 @@ mod tests { for j in 0..i { let b = Vec::<u64>::try_decode(&plaintexts_vt.0[j], Encoding::poly_at_level(0))?; - assert_eq!(b, &a[j * params.degree()..(j + 1) * params.degree()]); + assert_eq!(b, &a_vec[j * params.degree()..(j + 1) * params.degree()]); } - let plaintexts = PlaintextVec::try_encode(&a, Encoding::simd(), &params)?; + let plaintexts = + PlaintextVec::try_encode(a_vec.as_slice(), Encoding::simd(), &params)?; assert_eq!(plaintexts.0.len(), i); for j in 0..i { let b = Vec::<u64>::try_decode(&plaintexts.0[j], Encoding::simd())?; - assert_eq!(b, &a[j * params.degree()..(j + 1) * params.degree()]); + assert_eq!(b, &a_vec[j * params.degree()..(j + 1) * params.degree()]); } - let plaintexts_vt = - unsafe { PlaintextVec::try_encode_vt(&a, Encoding::simd(), &params)? }; + let plaintexts_vt = unsafe { + PlaintextVec::try_encode_vt(a_vec.as_slice(), Encoding::simd(), &params)? + }; assert_eq!(plaintexts_vt.0.len(), i); for (pt, pt_vt) in plaintexts.0.iter().zip(plaintexts_vt.0.iter()) { assert_eq!(pt.value, pt_vt.value); @@ -194,7 +308,7 @@ mod tests { for j in 0..i { let b = Vec::<u64>::try_decode(&plaintexts_vt.0[j], Encoding::simd())?; - assert_eq!(b, &a[j * params.degree()..(j + 1) * params.degree()]); + assert_eq!(b, &a_vec[j * params.degree()..(j + 1) * params.degree()]); } } } @@ -205,11 +319,11 @@ mod tests { .build_arc()?; let a = vec![1u64]; assert!(matches!( - PlaintextVec::try_encode(&a, Encoding::simd(), &params), + PlaintextVec::try_encode(a.as_slice(), Encoding::simd(), &params), Err(crate::Error::EncodingNotSupported { .. }) )); assert!(matches!( - unsafe { PlaintextVec::try_encode_vt(&a, Encoding::simd(), &params) }, + unsafe { PlaintextVec::try_encode_vt(a.as_slice(), Encoding::simd(), &params) }, Err(crate::Error::EncodingNotSupported { .. }) )); Ok(()) diff --git a/crates/fhe/src/bfv/rgsw_ciphertext.rs b/crates/fhe/src/bfv/rgsw_ciphertext.rs index 8e9e4581..374fc06d 100644 --- a/crates/fhe/src/bfv/rgsw_ciphertext.rs +++ b/crates/fhe/src/bfv/rgsw_ciphertext.rs @@ -187,8 +187,12 @@ mod tests { BfvParameters::default_arc(8, 16), ] { let sk = SecretKey::random(&params, &mut rng); - let v1 = params.plaintext.random_vec(params.degree(), &mut rng); - let v2 = params.plaintext.random_vec(params.degree(), &mut rng); + let v1 = fhe_math::zq::Modulus::new(params.plaintext()) + .unwrap() + .random_vec(params.degree(), &mut rng); + let v2 = fhe_math::zq::Modulus::new(params.plaintext()) + .unwrap() + .random_vec(params.degree(), &mut rng); let pt1 = Plaintext::try_encode(&v1, Encoding::simd(), &params)?; let pt2 = Plaintext::try_encode(&v2, Encoding::simd(), &params)?; @@ -219,7 +223,9 @@ mod tests { BfvParameters::default_arc(5, 16), ] { let sk = SecretKey::random(&params, &mut rng); - let v = params.plaintext.random_vec(params.degree(), &mut rng); + let v = fhe_math::zq::Modulus::new(params.plaintext()) + .unwrap() + .random_vec(params.degree(), &mut rng); let pt = Plaintext::try_encode(&v, Encoding::simd(), &params)?; let ct: RGSWCiphertext = sk.try_encrypt(&pt, &mut rng)?; diff --git a/crates/fhe/src/mbfv/public_key_gen.rs b/crates/fhe/src/mbfv/public_key_gen.rs index 0da3c59c..3915aa80 100644 --- a/crates/fhe/src/mbfv/public_key_gen.rs +++ b/crates/fhe/src/mbfv/public_key_gen.rs @@ -124,7 +124,9 @@ mod tests { // Use it to encrypt a random polynomial let pt = Plaintext::try_encode( - &par.plaintext.random_vec(par.degree(), &mut rng), + &fhe_math::zq::Modulus::new(par.plaintext()) + .unwrap() + .random_vec(par.degree(), &mut rng), Encoding::poly_at_level(level), &par, ) diff --git a/crates/fhe/src/mbfv/public_key_switch.rs b/crates/fhe/src/mbfv/public_key_switch.rs index 4c3081da..f7c711b4 100644 --- a/crates/fhe/src/mbfv/public_key_switch.rs +++ b/crates/fhe/src/mbfv/public_key_switch.rs @@ -160,7 +160,9 @@ mod tests { // Use it to encrypt a random polynomial ct1 let pt1 = Plaintext::try_encode( - &par.plaintext.random_vec(par.degree(), &mut rng), + &fhe_math::zq::Modulus::new(par.plaintext()) + .unwrap() + .random_vec(par.degree(), &mut rng), Encoding::poly_at_level(level), &par, ) diff --git a/crates/fhe/src/mbfv/relin_key_gen.rs b/crates/fhe/src/mbfv/relin_key_gen.rs index ee36b1c9..b5c60455 100644 --- a/crates/fhe/src/mbfv/relin_key_gen.rs +++ b/crates/fhe/src/mbfv/relin_key_gen.rs @@ -440,8 +440,12 @@ mod tests { .unwrap(); // Create a couple random encrypted polynomials - let v1 = par.plaintext.random_vec(par.degree(), &mut rng); - let v2 = par.plaintext.random_vec(par.degree(), &mut rng); + let v1 = fhe_math::zq::Modulus::new(par.plaintext()) + .unwrap() + .random_vec(par.degree(), &mut rng); + let v2 = fhe_math::zq::Modulus::new(par.plaintext()) + .unwrap() + .random_vec(par.degree(), &mut rng); let pt1 = Plaintext::try_encode(&v1, Encoding::simd_at_level(level), &par).unwrap(); let pt2 = Plaintext::try_encode(&v2, Encoding::simd_at_level(level), &par).unwrap(); let ct1 = public_key.try_encrypt(&pt1, &mut rng).unwrap(); @@ -463,7 +467,9 @@ mod tests { .unwrap(); let mut expected = v1.clone(); - par.plaintext.mul_vec(&mut expected, &v2); + fhe_math::zq::Modulus::new(par.plaintext()) + .unwrap() + .mul_vec(&mut expected, &v2); assert_eq!( Vec::<u64>::try_decode(&pt, Encoding::simd_at_level(pt.level)).unwrap(), expected diff --git a/crates/fhe/src/mbfv/secret_key_switch.rs b/crates/fhe/src/mbfv/secret_key_switch.rs index 1064926b..14c83c83 100644 --- a/crates/fhe/src/mbfv/secret_key_switch.rs +++ b/crates/fhe/src/mbfv/secret_key_switch.rs @@ -1,14 +1,13 @@ use std::sync::Arc; -use fhe_math::{ - rq::{Poly, Representation, traits::TryConvertFrom}, - zq::Modulus, -}; +use fhe_math::rq::{Poly, Representation, traits::TryConvertFrom}; use itertools::Itertools; +use num_bigint::BigUint; +use num_traits::ToPrimitive; use rand::{CryptoRng, RngCore}; use zeroize::Zeroizing; -use crate::bfv::{BfvParameters, Ciphertext, Plaintext, SecretKey}; +use crate::bfv::{BfvParameters, Ciphertext, Plaintext, PlaintextValues, SecretKey}; use crate::{Error, Result}; use super::Aggregate; @@ -157,23 +156,35 @@ impl Aggregate<DecryptionShare> for Plaintext { // The true decryption part is done during SKS; all that is left is to scale let ctx_lvl = ct.par.context_level_at(ct.level)?; let d = Zeroizing::new(c.scale(&ctx_lvl.cipher_plain_context.scaler)?); - let v = Zeroizing::new( - Vec::<u64>::try_from(d.as_ref())? - .iter_mut() - .map(|vi| *vi + *ct.par.plaintext) - .collect_vec(), - ); + + let v: Vec<BigUint> = Vec::<BigUint>::from(d.as_ref()) + .into_iter() + .map(|vi| vi + ct.par.plaintext_big()) + .collect_vec(); + let mut w = v[..ct.par.degree()].to_vec(); - let q = Modulus::new(ct.par.moduli[0]).map_err(Error::MathError)?; - q.reduce_vec(&mut w); + let q_poly = d.as_ref().ctx().modulus(); + w.iter_mut().for_each(|wi| *wi %= q_poly); + ct.par.plaintext.reduce_vec(&mut w); - let mut poly = Poly::try_convert_from(&w, ct[0].ctx(), false, Representation::PowerBasis)?; + let mut poly = + Poly::try_convert_from(w.as_slice(), ct[0].ctx(), false, Representation::PowerBasis)?; poly.change_representation(Representation::Ntt); + let value = match ct.par.plaintext { + crate::bfv::PlaintextModulus::Small { .. } => PlaintextValues::Small( + w.iter() + .map(|x| x.to_u64().unwrap()) + .collect::<Vec<_>>() + .into_boxed_slice(), + ), + crate::bfv::PlaintextModulus::Large(_) => PlaintextValues::Large(w.into_boxed_slice()), + }; + let pt = Plaintext { par: ct.par.clone(), - value: w.into_boxed_slice(), + value, encoding: None, poly_ntt: poly, level: ct.level, @@ -232,8 +243,9 @@ mod tests { .unwrap(); // Use it to encrypt a random polynomial + let q = fhe_math::zq::Modulus::new(par.plaintext()).unwrap(); let pt1 = Plaintext::try_encode( - &par.plaintext.random_vec(par.degree(), &mut rng), + &q.random_vec(par.degree(), &mut rng), Encoding::poly_at_level(level), &par, ) @@ -276,8 +288,9 @@ mod tests { PublicKey::from_shares(parties.iter().map(|p| p.pk_share.clone())).unwrap(); // Use it to encrypt a random polynomial ct1 + let q = fhe_math::zq::Modulus::new(par.plaintext()).unwrap(); let pt1 = Plaintext::try_encode( - &par.plaintext.random_vec(par.degree(), &mut rng), + &q.random_vec(par.degree(), &mut rng), Encoding::poly_at_level(level), &par, ) @@ -347,10 +360,11 @@ mod tests { .unwrap(); // Parties encrypt two plaintexts - let a = par.plaintext.random_vec(par.degree(), &mut rng); - let b = par.plaintext.random_vec(par.degree(), &mut rng); + let q = fhe_math::zq::Modulus::new(par.plaintext()).unwrap(); + let a = q.random_vec(par.degree(), &mut rng); + let b = q.random_vec(par.degree(), &mut rng); let mut expected = a.clone(); - par.plaintext.add_vec(&mut expected, &b); + q.add_vec(&mut expected, &b); let pt_a = Plaintext::try_encode(&a, Encoding::poly_at_level(level), &par).unwrap(); diff --git a/crates/fhe/src/proto/bfv.proto b/crates/fhe/src/proto/bfv.proto index b8511232..ce6e6ba7 100644 --- a/crates/fhe/src/proto/bfv.proto +++ b/crates/fhe/src/proto/bfv.proto @@ -40,7 +40,10 @@ message EvaluationKey { message Parameters { uint32 degree = 1; repeated uint64 moduli = 2; - uint64 plaintext = 3; + oneof plaintext_modulus { + uint64 plaintext = 3; + bytes plaintext_big = 5; + } uint32 variance = 4; } diff --git a/crates/fhe/src/proto/bfv.rs b/crates/fhe/src/proto/bfv.rs index 0c439d95..7dbfb51e 100644 --- a/crates/fhe/src/proto/bfv.rs +++ b/crates/fhe/src/proto/bfv.rs @@ -1,75 +1,3 @@ #![expect(missing_docs, reason = "prost-generated types omit docs")] -#[derive(Clone, PartialEq, ::prost::Message)] -pub struct Ciphertext { - #[prost(bytes = "vec", repeated, tag = "1")] - pub c: ::prost::alloc::vec::Vec<::prost::alloc::vec::Vec<u8>>, - #[prost(bytes = "vec", tag = "2")] - pub seed: ::prost::alloc::vec::Vec<u8>, - #[prost(uint32, tag = "3")] - pub level: u32, -} -#[derive(Clone, PartialEq, ::prost::Message)] -pub struct RgswCiphertext { - #[prost(message, optional, tag = "1")] - pub ksk0: ::core::option::Option<KeySwitchingKey>, - #[prost(message, optional, tag = "2")] - pub ksk1: ::core::option::Option<KeySwitchingKey>, -} -#[derive(Clone, PartialEq, ::prost::Message)] -pub struct KeySwitchingKey { - #[prost(bytes = "vec", repeated, tag = "1")] - pub c0: ::prost::alloc::vec::Vec<::prost::alloc::vec::Vec<u8>>, - #[prost(bytes = "vec", repeated, tag = "2")] - pub c1: ::prost::alloc::vec::Vec<::prost::alloc::vec::Vec<u8>>, - #[prost(bytes = "vec", tag = "3")] - pub seed: ::prost::alloc::vec::Vec<u8>, - #[prost(uint32, tag = "4")] - pub ciphertext_level: u32, - #[prost(uint32, tag = "5")] - pub ksk_level: u32, - #[prost(uint32, tag = "6")] - pub log_base: u32, -} -#[derive(Clone, PartialEq, ::prost::Message)] -pub struct RelinearizationKey { - #[prost(message, optional, tag = "1")] - pub ksk: ::core::option::Option<KeySwitchingKey>, -} -#[derive(Clone, PartialEq, ::prost::Message)] -pub struct GaloisKey { - #[prost(message, optional, tag = "1")] - pub ksk: ::core::option::Option<KeySwitchingKey>, - #[prost(uint32, tag = "2")] - pub exponent: u32, -} -#[derive(Clone, PartialEq, ::prost::Message)] -pub struct EvaluationKey { - #[prost(message, repeated, tag = "2")] - pub gk: ::prost::alloc::vec::Vec<GaloisKey>, - #[prost(uint32, tag = "3")] - pub ciphertext_level: u32, - #[prost(uint32, tag = "4")] - pub evaluation_key_level: u32, -} -#[derive(Clone, PartialEq, ::prost::Message)] -pub struct Parameters { - #[prost(uint32, tag = "1")] - pub degree: u32, - #[prost(uint64, repeated, tag = "2")] - pub moduli: ::prost::alloc::vec::Vec<u64>, - #[prost(uint64, tag = "3")] - pub plaintext: u64, - #[prost(uint32, tag = "4")] - pub variance: u32, -} -#[derive(Clone, PartialEq, ::prost::Message)] -pub struct PublicKey { - #[prost(message, optional, tag = "1")] - pub c: ::core::option::Option<Ciphertext>, -} -#[derive(Clone, PartialEq, ::prost::Message)] -pub struct SecretKey { - #[prost(sint64, repeated, tag = "1")] - pub coeffs: ::prost::alloc::vec::Vec<i64>, -} +include!(concat!(env!("OUT_DIR"), "/fhers.bfv.rs")); diff --git a/crates/fhe/tests/biguint.rs b/crates/fhe/tests/biguint.rs new file mode 100644 index 00000000..580e090d --- /dev/null +++ b/crates/fhe/tests/biguint.rs @@ -0,0 +1,206 @@ +#![allow(missing_docs, clippy::indexing_slicing)] +use fhe::bfv::{ + BfvParameters, BfvParametersBuilder, Ciphertext, Encoding, Plaintext, RelinearizationKey, + SecretKey, +}; +use fhe_traits::{FheDecoder, FheDecrypter, FheEncoder as _, FheEncrypter}; +use num_bigint::BigUint; +use rand::rng; +use std::{error::Error, sync::Arc}; + +fn parameters() -> Arc<BfvParameters> { + // Choose a large plaintext modulus: 2^127 - 1 (Mersenne prime M127) + // 170141183460469231731687303715884105727 + let p_str = "170141183460469231731687303715884105727"; + let p = BigUint::parse_bytes(p_str.as_bytes(), 10).unwrap(); + + // Create parameters + BfvParametersBuilder::new() + .set_degree(16) + .set_plaintext_modulus_biguint(p.clone()) + .set_moduli_sizes(&[60, 60, 60, 60, 60]) + .build_arc() + .unwrap() +} + +#[test] +fn test_biguint_plaintext_encryption_decryption() -> Result<(), Box<dyn Error>> { + let mut rng = rng(); + + let params = parameters(); + let sk = SecretKey::random(&params, &mut rng); + + // Create a vector of BigUint values + let mut values = vec![BigUint::from(0u32); params.degree()]; + values[0] = BigUint::from(123456789u64); + values[1] = params.plaintext_big() - 1u32; // -1 + values[2] = params.plaintext_big() / 2u32; + + let pt = Plaintext::try_encode(values.as_slice(), Encoding::poly(), &params)?; + + let ct: Ciphertext = sk.try_encrypt(&pt, &mut rng)?; + + let decrypted_pt = sk.try_decrypt(&ct)?; + + // Decode + let decrypted_values: Vec<BigUint> = + Vec::<BigUint>::try_decode(&decrypted_pt, Encoding::poly())?; + + assert_eq!(decrypted_values, values); + + Ok(()) +} + +#[test] +fn test_biguint_homomorphic_addition() -> Result<(), Box<dyn Error>> { + let mut rng = rng(); + + let params = parameters(); + let sk = SecretKey::random(&params, &mut rng); + + let val1 = BigUint::from(10u32); + let val2 = params.plaintext_big() - 50u32; // -50 + + let mut vec1 = vec![BigUint::from(0u32); params.degree()]; + vec1[0] = val1.clone(); + + let mut vec2 = vec![BigUint::from(0u32); params.degree()]; + vec2[0] = val2.clone(); + + let pt1 = Plaintext::try_encode(vec1.as_slice(), Encoding::poly(), &params)?; + let pt2 = Plaintext::try_encode(vec2.as_slice(), Encoding::poly(), &params)?; + + let ct1: Ciphertext = sk.try_encrypt(&pt1, &mut rng)?; + let ct2: Ciphertext = sk.try_encrypt(&pt2, &mut rng)?; + + let ct_res = &ct1 + &ct2; + + let decrypted_pt = sk.try_decrypt(&ct_res)?; + let decrypted_values: Vec<BigUint> = + Vec::<BigUint>::try_decode(&decrypted_pt, Encoding::poly())?; + + // 10 + (-50) = -40 + assert_eq!( + decrypted_values[0], + params.plaintext_big() - BigUint::from(40u32) + ); + + Ok(()) +} + +#[test] +fn test_biguint_multiplication_without_relin() -> Result<(), Box<dyn Error>> { + let mut rng = rng(); + + let params = parameters(); + let sk = SecretKey::random(&params, &mut rng); + + let val1 = BigUint::from(10u32); + let val2 = params.plaintext_big() - BigUint::from(20u32); + + let mut vec1 = vec![BigUint::from(0u32); params.degree()]; + vec1[0] = val1.clone(); + + let mut vec2 = vec![BigUint::from(0u32); params.degree()]; + vec2[0] = val2.clone(); + + let pt1 = Plaintext::try_encode(vec1.as_slice(), Encoding::poly(), &params)?; + let pt2 = Plaintext::try_encode(vec2.as_slice(), Encoding::poly(), &params)?; + + let ct1: Ciphertext = sk.try_encrypt(&pt1, &mut rng)?; + let ct2: Ciphertext = sk.try_encrypt(&pt2, &mut rng)?; + + let ct_res = &ct1 * &ct2; + + assert_eq!(ct_res.len(), 3); // Degree increases + + let decrypted_pt = sk.try_decrypt(&ct_res)?; + let decrypted_values: Vec<BigUint> = + Vec::<BigUint>::try_decode(&decrypted_pt, Encoding::poly())?; + + // 10 * (-20) = -200 + assert_eq!( + decrypted_values[0], + params.plaintext_big() - BigUint::from(200u32) + ); + + Ok(()) +} + +#[test] +fn test_biguint_multiplication_with_relin() -> Result<(), Box<dyn Error>> { + let mut rng = rng(); + + // Use default parameters with biguint + let params = BfvParametersBuilder::new() + .set_degree(16) + .set_plaintext_modulus(1153) + .set_moduli_sizes(&[62usize; 3]) + .build_arc() + .unwrap(); + let sk = SecretKey::random(&params, &mut rng); + let rk = RelinearizationKey::new(&sk, &mut rng)?; + + let val1 = BigUint::from(10u32); + let val2 = params.plaintext_big() - BigUint::from(20u32); + + let mut vec1 = vec![BigUint::from(0u32); params.degree()]; + vec1[0] = val1.clone(); + + let mut vec2 = vec![BigUint::from(0u32); params.degree()]; + vec2[0] = val2.clone(); + + let pt1 = Plaintext::try_encode(vec1.as_slice(), Encoding::poly(), &params)?; + let pt2 = Plaintext::try_encode(vec2.as_slice(), Encoding::poly(), &params)?; + + let ct1: Ciphertext = sk.try_encrypt(&pt1, &mut rng)?; + let ct2: Ciphertext = sk.try_encrypt(&pt2, &mut rng)?; + + let mut ct_res = &ct1 * &ct2; + rk.relinearizes(&mut ct_res)?; + + assert_eq!(ct_res.len(), 2); // Degree reduced + + let decrypted_pt = sk.try_decrypt(&ct_res)?; + let decrypted_values: Vec<BigUint> = + Vec::<BigUint>::try_decode(&decrypted_pt, Encoding::poly())?; + + // 10 * (-20) = -200 + assert_eq!( + decrypted_values[0], + params.plaintext_big() - BigUint::from(200u32) + ); + + Ok(()) +} + +#[test] +fn test_small_modulus_with_biguint_input() -> Result<(), Box<dyn Error>> { + let mut rng = rng(); + // Standard small modulus parameters + let params = BfvParametersBuilder::new() + .set_degree(16) + .set_plaintext_modulus(1153) + .set_moduli_sizes(&[62usize; 1]) + .build_arc() + .unwrap(); + let sk = SecretKey::random(&params, &mut rng); + + // Let's just pick a value larger than t, but small enough to verify reduction. + // t = 1153 (default for default_arc(1, 16) in parameters.rs) + let t = params.plaintext(); + let val = BigUint::from(t) + 5u32; // Should reduce to 5 + + let mut values = vec![BigUint::from(0u32); params.degree()]; + values[0] = val.clone(); + + let pt = Plaintext::try_encode(values.as_slice(), Encoding::poly(), &params)?; + let ct: Ciphertext = sk.try_encrypt(&pt, &mut rng)?; + let decrypted_pt = sk.try_decrypt(&ct)?; + + let decrypted_values: Vec<u64> = Vec::<u64>::try_decode(&decrypted_pt, Encoding::poly())?; + + assert_eq!(decrypted_values[0], 5); + + Ok(()) +}