diff --git a/README.md b/README.md index 82a63685..b6fa0131 100644 --- a/README.md +++ b/README.md @@ -28,8 +28,8 @@ To install, add the following to your project's `Cargo.toml` file: ```toml [dependencies] -fhe = "0.2" -fhe-traits = "0.2" +fhe = "0.2.0" +fhe-traits = "0.1.1" ``` ## Minimum supported version / toolchain diff --git a/crates/fhe-math/benches/rq.rs b/crates/fhe-math/benches/rq.rs index 7f716d55..f606b9a7 100644 --- a/crates/fhe-math/benches/rq.rs +++ b/crates/fhe-math/benches/rq.rs @@ -5,14 +5,10 @@ )] use criterion::measurement::WallTime; use criterion::{BenchmarkGroup, BenchmarkId, Criterion, criterion_group, criterion_main}; -use fhe_math::rq::{traits::TryConvertFrom, *}; +use fhe_math::rq::{Context, Ntt, NttShoup, Poly, PowerBasis, dot_product, traits::TryConvertFrom}; use itertools::{Itertools, izip}; use rand::rng; -use std::{ - ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign}, - sync::Arc, - time::Duration, -}; +use std::{sync::Arc, time::Duration}; static MODULI: &[u64; 4] = &[ 562949954093057, @@ -42,8 +38,8 @@ macro_rules! bench_op { for degree in DEGREE { let ctx = Arc::new(Context::new(&MODULI[..1], *degree).unwrap()); - let p = Poly::random(&ctx, Representation::Ntt, &mut rng); - let mut q = Poly::random(&ctx, Representation::Ntt, &mut rng); + let p = Poly::::random(&ctx, &mut rng); + let mut q = Poly::::random(&ctx, &mut rng); if $vt { unsafe { q.allow_variable_time_computations() } } @@ -70,8 +66,8 @@ macro_rules! bench_op_unary { for degree in DEGREE { let ctx = Arc::new(Context::new(&MODULI[..1], *degree).unwrap()); - let p = Poly::random(&ctx, Representation::Ntt, &mut rng); - let mut q = Poly::random(&ctx, Representation::Ntt, &mut rng); + let p = Poly::::random(&ctx, &mut rng); + let mut q = Poly::::random(&ctx, &mut rng); if $vt { unsafe { q.allow_variable_time_computations() } } @@ -98,8 +94,8 @@ macro_rules! bench_op_assign { for degree in DEGREE { let ctx = Arc::new(Context::new(&MODULI[..1], *degree).unwrap()); - let mut p = Poly::random(&ctx, Representation::Ntt, &mut rng); - let mut q = Poly::random(&ctx, Representation::Ntt, &mut rng); + let mut p = Poly::::random(&ctx, &mut rng); + let mut q = Poly::::random(&ctx, &mut rng); if $vt { unsafe { q.allow_variable_time_computations() } } @@ -116,13 +112,28 @@ macro_rules! bench_op_assign { pub fn rq_op_benchmark(c: &mut Criterion) { for vt in [false, true] { - bench_op!(c, "rq_add", <&Poly>::add, vt); - bench_op_assign!(c, "rq_add_assign", Poly::add_assign, vt); - bench_op!(c, "rq_sub", <&Poly>::sub, vt); - bench_op_assign!(c, "rq_sub_assign", Poly::sub_assign, vt); - bench_op!(c, "rq_mul", <&Poly>::mul, vt); - bench_op_assign!(c, "rq_mul_assign", Poly::mul_assign, vt); - bench_op_unary!(c, "rq_neg", <&Poly>::neg, vt); + bench_op!(c, "rq_add", |p, q| p + q, vt); + bench_op_assign!( + c, + "rq_add_assign", + |p: &mut Poly, q: &Poly| *p += q, + vt + ); + bench_op!(c, "rq_sub", |p, q| p - q, vt); + bench_op_assign!( + c, + "rq_sub_assign", + |p: &mut Poly, q: &Poly| *p -= q, + vt + ); + bench_op!(c, "rq_mul", |p, q| p * q, vt); + bench_op_assign!( + c, + "rq_mul_assign", + |p: &mut Poly, q: &Poly| *p *= q, + vt + ); + bench_op_unary!(c, "rq_neg", |p: &Poly<_>| -p, vt); } } @@ -133,12 +144,12 @@ pub fn rq_dot_product(c: &mut Criterion) { for i in [1, 4] { let ctx = Arc::new(Context::new(&MODULI[..i], *degree).unwrap()); let p_vec = (0..256) - .map(|_| Poly::random(&ctx, Representation::Ntt, &mut rng)) + .map(|_| Poly::::random(&ctx, &mut rng)) .collect_vec(); - let mut q_vec = (0..256) - .map(|_| Poly::random(&ctx, Representation::Ntt, &mut rng)) + let q_vec = (0..256) + .map(|_| Poly::::random(&ctx, &mut rng)) .collect_vec(); - let mut out = Poly::zero(&ctx, Representation::Ntt); + let mut out = Poly::::zero(&ctx); group.bench_function( BenchmarkId::from_parameter(format!("naive/{}/{}", degree, ctx.modulus().bits())), @@ -149,9 +160,11 @@ pub fn rq_dot_product(c: &mut Criterion) { }, ); - q_vec - .iter_mut() - .for_each(|qi| qi.change_representation(Representation::NttShoup)); + let q_vec_shoup = q_vec + .iter() + .cloned() + .map(Poly::::into_ntt_shoup) + .collect_vec(); group.bench_function( BenchmarkId::from_parameter(format!( "naive_shoup/{}/{}", @@ -160,14 +173,12 @@ pub fn rq_dot_product(c: &mut Criterion) { )), |b| { b.iter(|| { - izip!(p_vec.iter(), q_vec.iter()).for_each(|(pi, qi)| out += &(pi * qi)) + izip!(p_vec.iter(), q_vec_shoup.iter()) + .for_each(|(pi, qi)| out += &(pi * qi)) }); }, ); - q_vec - .iter_mut() - .for_each(|qi| qi.change_representation(Representation::Ntt)); group.bench_function( BenchmarkId::from_parameter(format!("opt/{}/{}", degree, ctx.modulus().bits())), |b| { @@ -190,9 +201,8 @@ pub fn rq_benchmark(c: &mut Criterion) { continue; } let ctx = Arc::new(Context::new(&MODULI[..nmoduli], *degree).unwrap()); - let mut p = Poly::random(&ctx, Representation::Ntt, &mut rng); - let mut q = Poly::random(&ctx, Representation::Ntt, &mut rng); - q.change_representation(Representation::NttShoup); + let mut p = Poly::::random(&ctx, &mut rng); + let q = Poly::::random(&ctx, &mut rng); group.bench_function( BenchmarkId::new("mul_shoup", format!("{}/{}", degree, ctx.modulus().bits())), @@ -211,6 +221,7 @@ pub fn rq_benchmark(c: &mut Criterion) { }, ); + let p_pb = Poly::::random(&ctx, &mut rng); group.bench_function( BenchmarkId::new( "change_representation/PowerBasis_to_Ntt", @@ -218,14 +229,12 @@ pub fn rq_benchmark(c: &mut Criterion) { ), |b| { b.iter(|| { - unsafe { - p.override_representation(Representation::PowerBasis); - } - p.change_representation(Representation::Ntt) + let _ = p_pb.clone().into_ntt(); }); }, ); + let p_ntt = Poly::::random(&ctx, &mut rng); group.bench_function( BenchmarkId::new( "change_representation/Ntt_to_PowerBasis", @@ -233,20 +242,16 @@ pub fn rq_benchmark(c: &mut Criterion) { ), |b| { b.iter(|| { - unsafe { - p.override_representation(Representation::Ntt); - } - p.change_representation(Representation::PowerBasis) + let _ = p_ntt.clone().into_power_basis(); }); }, ); - p.change_representation(Representation::Ntt); - q.change_representation(Representation::Ntt); - unsafe { - q.allow_variable_time_computations(); - q.change_representation(Representation::NttShoup); + let mut q_vt = q.clone(); + q_vt.allow_variable_time_computations(); + let mut p_vt = p.clone(); + p_vt.allow_variable_time_computations(); group.bench_function( BenchmarkId::new( @@ -254,11 +259,12 @@ pub fn rq_benchmark(c: &mut Criterion) { format!("{}/{}", degree, ctx.modulus().bits()), ), |b| { - b.iter(|| p *= &q); + b.iter(|| p_vt *= &q_vt); }, ); - p.allow_variable_time_computations(); + let mut p_pb_vt = Poly::::random(&ctx, &mut rng); + p_pb_vt.allow_variable_time_computations(); group.bench_function( BenchmarkId::new( @@ -267,12 +273,13 @@ pub fn rq_benchmark(c: &mut Criterion) { ), |b| { b.iter(|| { - p.override_representation(Representation::PowerBasis); - p.change_representation(Representation::Ntt) + let _ = p_pb_vt.clone().into_ntt(); }); }, ); + let mut p_ntt_vt = Poly::::random(&ctx, &mut rng); + p_ntt_vt.allow_variable_time_computations(); group.bench_function( BenchmarkId::new( "change_representation/Ntt_to_PowerBasis_vt", @@ -280,8 +287,7 @@ pub fn rq_benchmark(c: &mut Criterion) { ), |b| { b.iter(|| { - p.override_representation(Representation::Ntt); - p.change_representation(Representation::PowerBasis) + let _ = p_ntt_vt.clone().into_power_basis(); }); }, ); @@ -306,10 +312,7 @@ pub fn rq_convert_benchmark(c: &mut Criterion) { group.bench_function( BenchmarkId::new("try_convert_from_slice", format!("{}/{}", degree, nmoduli)), |b| { - b.iter(|| { - Poly::try_convert_from(slice, &ctx, false, Representation::PowerBasis) - .unwrap() - }); + b.iter(|| Poly::::try_convert_from(slice, &ctx, false).unwrap()); }, ); } diff --git a/crates/fhe-math/src/rq/convert.rs b/crates/fhe-math/src/rq/convert.rs index 738d21cd..bf94490c 100644 --- a/crates/fhe-math/src/rq/convert.rs +++ b/crates/fhe-math/src/rq/convert.rs @@ -1,6 +1,9 @@ //! Implementation of conversions from and to polynomials. -use super::{Context, Poly, Representation, traits::TryConvertFrom}; +use super::{ + Context, Ntt, NttShoup, Poly, PowerBasis, Representation, RepresentationTag, + traits::TryConvertFrom, +}; use crate::{ Error, Result, proto::rq::{Representation as RepresentationProto, Rq}, @@ -8,34 +11,25 @@ use crate::{ use itertools::{Itertools, izip}; use ndarray::{Array2, ArrayView, Axis}; use num_bigint::BigUint; -use std::borrow::Cow; use std::sync::Arc; use zeroize::{Zeroize, Zeroizing}; -impl From<&Poly> for Rq { - fn from(p: &Poly) -> Self { +impl From<&Poly> for Rq { + fn from(p: &Poly) -> Self { assert!(!p.has_lazy_coefficients); - - let needs_transform = p.representation != Representation::PowerBasis; - let q: Cow<'_, Poly> = if needs_transform { - let mut owned = p.clone(); - owned.change_representation(Representation::PowerBasis); - Cow::Owned(owned) - } else { - Cow::Borrowed(p) + let q: Poly = match R::REPRESENTATION { + Representation::PowerBasis => Poly::::from_parts(p.clone()), + Representation::Ntt => Poly::::from_parts(p.clone()).into_power_basis(), + Representation::NttShoup => Poly::::from_parts(p.clone()).into_power_basis(), }; let mut proto = Rq::default(); - match p.representation { + match R::REPRESENTATION { Representation::PowerBasis => { - proto.representation = RepresentationProto::Powerbasis as i32; - } - Representation::Ntt => { - proto.representation = RepresentationProto::Ntt as i32; - } - Representation::NttShoup => { - proto.representation = RepresentationProto::Nttshoup as i32; + proto.representation = RepresentationProto::Powerbasis as i32 } + Representation::Ntt => proto.representation = RepresentationProto::Ntt as i32, + Representation::NttShoup => proto.representation = RepresentationProto::Nttshoup as i32, } let serialization: Vec = izip!(q.coefficients.outer_iter(), p.ctx.q.iter()) .flat_map(|(v, qi)| qi.serialize_vec(v.as_slice().unwrap())) @@ -47,265 +41,262 @@ impl From<&Poly> for Rq { } } -impl TryConvertFrom> for Poly { - fn try_convert_from( - mut v: Vec, - ctx: &Arc, - variable_time: bool, - representation: R, - ) -> Result - where - R: Into>, - { - let repr = representation.into(); - match repr { - Some(Representation::Ntt) => { - if let Ok(coefficients) = Array2::from_shape_vec((ctx.q.len(), ctx.degree), v) { - Ok(Self { - ctx: ctx.clone(), - representation: repr.unwrap(), - allow_variable_time_computations: variable_time, - coefficients, - coefficients_shoup: None, - has_lazy_coefficients: false, - }) - } else { - Err(Error::Default( - "In Ntt representation, all coefficients must be specified".to_string(), - )) - } - } - Some(Representation::NttShoup) => { - if let Ok(coefficients) = Array2::from_shape_vec((ctx.q.len(), ctx.degree), v) { - let mut p = Self { - ctx: ctx.clone(), - representation: repr.unwrap(), - allow_variable_time_computations: variable_time, - coefficients, - coefficients_shoup: None, - has_lazy_coefficients: false, - }; - p.compute_coefficients_shoup(); - Ok(p) - } else { - Err(Error::Default( - "In NttShoup representation, all coefficients must be specified" - .to_string(), - )) - } - } - Some(Representation::PowerBasis) => { - if v.len() == ctx.q.len() * ctx.degree { - let coefficients = - Array2::from_shape_vec((ctx.q.len(), ctx.degree), v).unwrap(); - Ok(Self { - ctx: ctx.clone(), - representation: repr.unwrap(), - allow_variable_time_computations: variable_time, - coefficients, - coefficients_shoup: None, - has_lazy_coefficients: false, - }) - } else if v.len() <= ctx.degree { - let mut out = Self::zero(ctx, repr.unwrap()); - if variable_time { - unsafe { - izip!(out.coefficients.outer_iter_mut(), ctx.q.iter()).for_each( - |(mut w, qi)| { - let wi = w.as_slice_mut().unwrap(); - wi[..v.len()].copy_from_slice(&v); - qi.reduce_vec_vt(wi); - }, - ); - out.allow_variable_time_computations(); - } - } else { - izip!(out.coefficients.outer_iter_mut(), ctx.q.iter()).for_each( - |(mut w, qi)| { - let wi = w.as_slice_mut().unwrap(); - wi[..v.len()].copy_from_slice(&v); - qi.reduce_vec(wi); - }, - ); - v.zeroize(); - } - Ok(out) - } else { - Err(Error::Default("In PowerBasis representation, either all coefficients must be specified, or only coefficients up to the degree".to_string())) - } - } - None => Err(Error::Default( - "When converting from a vector, the representation needs to be specified" - .to_string(), - )), +fn parse_proto( + value: &Rq, + ctx: &Arc, + variable_time: bool, +) -> Result<(Representation, Vec, bool)> { + let repr = value + .representation + .try_into() + .map_err(|_| Error::Default("Invalid representation".to_string()))?; + let representation_from_proto = match repr { + RepresentationProto::Powerbasis => Representation::PowerBasis, + RepresentationProto::Ntt => Representation::Ntt, + RepresentationProto::Nttshoup => Representation::NttShoup, + RepresentationProto::Unknown => { + return Err(Error::Default("Unknown representation".to_string())); } + }; + + let variable_time = variable_time || value.allow_variable_time; + + let degree = value.degree as usize; + if !degree.is_multiple_of(8) || degree < 8 { + return Err(Error::Default("Invalid degree".to_string())); + } + + let mut expected_nbytes = 0; + ctx.q + .iter() + .for_each(|qi| expected_nbytes += qi.serialization_length(degree)); + if value.coefficients.len() != expected_nbytes { + return Err(Error::Default("Invalid coefficients".to_string())); } + + let mut index = 0; + let power_basis_coefficients: Vec = ctx + .q + .iter() + .flat_map(|qi| { + let size = qi.serialization_length(degree); + let v = qi.deserialize_vec(&value.coefficients[index..index + size]); + index += size; + v + }) + .collect(); + + Ok(( + representation_from_proto, + power_basis_coefficients, + variable_time, + )) } -impl TryConvertFrom<&Rq> for Poly { - fn try_convert_from( - value: &Rq, - ctx: &Arc, - variable_time: bool, - representation: R, - ) -> Result - where - R: Into>, - { - let repr = value - .representation - .try_into() - .map_err(|_| Error::Default("Invalid representation".to_string()))?; - let representation_from_proto = match repr { - RepresentationProto::Powerbasis => Representation::PowerBasis, - RepresentationProto::Ntt => Representation::Ntt, - RepresentationProto::Nttshoup => Representation::NttShoup, - RepresentationProto::Unknown => { - return Err(Error::Default("Unknown representation".to_string())); - } - }; +impl TryConvertFrom<&Rq> for Poly { + fn try_convert_from(value: &Rq, ctx: &Arc, variable_time: bool) -> Result { + let (representation_from_proto, coefficients, variable_time) = + parse_proto(value, ctx, variable_time)?; + if representation_from_proto != Representation::PowerBasis { + return Err(Error::Default( + "The representation asked for does not match the representation in the serialization".to_string(), + )); + } + Poly::::try_convert_from(coefficients, ctx, variable_time) + } +} - let variable_time = variable_time || value.allow_variable_time; +impl TryConvertFrom<&Rq> for Poly { + fn try_convert_from(value: &Rq, ctx: &Arc, variable_time: bool) -> Result { + let (representation_from_proto, coefficients, variable_time) = + parse_proto(value, ctx, variable_time)?; + if representation_from_proto != Representation::Ntt { + return Err(Error::Default( + "The representation asked for does not match the representation in the serialization".to_string(), + )); + } + let p = Poly::::try_convert_from(coefficients, ctx, variable_time)?; + Ok(p.into_ntt()) + } +} - if let Some(r) = representation.into() as Option - && r != representation_from_proto - { - return Err(Error::Default("The representation asked for does not match the representation in the serialization".to_string())); +impl TryConvertFrom<&Rq> for Poly { + fn try_convert_from(value: &Rq, ctx: &Arc, variable_time: bool) -> Result { + let (representation_from_proto, coefficients, variable_time) = + parse_proto(value, ctx, variable_time)?; + if representation_from_proto != Representation::NttShoup { + return Err(Error::Default( + "The representation asked for does not match the representation in the serialization".to_string(), + )); } + let p = Poly::::try_convert_from(coefficients, ctx, variable_time)?; + Ok(p.into_ntt_shoup()) + } +} - let degree = value.degree as usize; - if !degree.is_multiple_of(8) || degree < 8 { - return Err(Error::Default("Invalid degree".to_string())); +impl TryConvertFrom> for Poly { + fn try_convert_from(mut v: Vec, ctx: &Arc, variable_time: bool) -> Result { + if v.len() == ctx.q.len() * ctx.degree { + let coefficients = Array2::from_shape_vec((ctx.q.len(), ctx.degree), v).unwrap(); + Ok(Self { + ctx: ctx.clone(), + allow_variable_time_computations: variable_time, + coefficients, + coefficients_shoup: None, + has_lazy_coefficients: false, + _repr: std::marker::PhantomData, + }) + } else if v.len() <= ctx.degree { + let mut out = Self::zero(ctx); + if variable_time { + unsafe { + izip!(out.coefficients.outer_iter_mut(), ctx.q.iter()).for_each( + |(mut w, qi)| { + let wi = w.as_slice_mut().unwrap(); + wi[..v.len()].copy_from_slice(&v); + qi.reduce_vec_vt(wi); + }, + ); + out.allow_variable_time_computations(); + } + } else { + izip!(out.coefficients.outer_iter_mut(), ctx.q.iter()).for_each(|(mut w, qi)| { + let wi = w.as_slice_mut().unwrap(); + wi[..v.len()].copy_from_slice(&v); + qi.reduce_vec(wi); + }); + v.zeroize(); + } + Ok(out) + } else { + Err(Error::Default( + "In PowerBasis representation, either all coefficients must be specified, or only coefficients up to the degree".to_string(), + )) } + } +} - let mut expected_nbytes = 0; - ctx.q - .iter() - .for_each(|qi| expected_nbytes += qi.serialization_length(degree)); - if value.coefficients.len() != expected_nbytes { - return Err(Error::Default("Invalid coefficients".to_string())); +impl TryConvertFrom> for Poly { + fn try_convert_from(v: Vec, ctx: &Arc, variable_time: bool) -> Result { + if let Ok(coefficients) = Array2::from_shape_vec((ctx.q.len(), ctx.degree), v) { + Ok(Self { + ctx: ctx.clone(), + allow_variable_time_computations: variable_time, + coefficients, + coefficients_shoup: None, + has_lazy_coefficients: false, + _repr: std::marker::PhantomData, + }) + } else { + Err(Error::Default( + "In Ntt representation, all coefficients must be specified".to_string(), + )) } + } +} - let mut index = 0; - let power_basis_coefficients: Vec = ctx - .q - .iter() - .flat_map(|qi| { - let size = qi.serialization_length(degree); - let v = qi.deserialize_vec(&value.coefficients[index..index + size]); - index += size; - v +impl TryConvertFrom> for Poly { + fn try_convert_from(v: Vec, ctx: &Arc, variable_time: bool) -> Result { + if let Ok(coefficients) = Array2::from_shape_vec((ctx.q.len(), ctx.degree), v) { + let mut p = Self { + ctx: ctx.clone(), + allow_variable_time_computations: variable_time, + coefficients, + coefficients_shoup: None, + has_lazy_coefficients: false, + _repr: std::marker::PhantomData, + }; + p.compute_coefficients_shoup(); + Ok(p) + } else { + Err(Error::Default( + "In NttShoup representation, all coefficients must be specified".to_string(), + )) + } + } +} + +impl TryConvertFrom> for Poly { + fn try_convert_from(a: Array2, ctx: &Arc, variable_time: bool) -> Result { + if a.shape() != [ctx.q.len(), ctx.degree] { + Err(Error::Default( + "The array of coefficient does not have the correct shape".to_string(), + )) + } else { + Ok(Self { + ctx: ctx.clone(), + allow_variable_time_computations: variable_time, + coefficients: a, + coefficients_shoup: None, + has_lazy_coefficients: false, + _repr: std::marker::PhantomData, }) - .collect(); + } + } +} - let mut p = Poly::try_convert_from( - power_basis_coefficients, - ctx, - variable_time, - Representation::PowerBasis, - )?; - p.change_representation(representation_from_proto); - Ok(p) +impl TryConvertFrom> for Poly { + fn try_convert_from(a: Array2, ctx: &Arc, variable_time: bool) -> Result { + if a.shape() != [ctx.q.len(), ctx.degree] { + Err(Error::Default( + "The array of coefficient does not have the correct shape".to_string(), + )) + } else { + Ok(Self { + ctx: ctx.clone(), + allow_variable_time_computations: variable_time, + coefficients: a, + coefficients_shoup: None, + has_lazy_coefficients: false, + _repr: std::marker::PhantomData, + }) + } } } -impl TryConvertFrom> for Poly { - fn try_convert_from( - a: Array2, - ctx: &Arc, - variable_time: bool, - representation: R, - ) -> Result - where - R: Into>, - { +impl TryConvertFrom> for Poly { + fn try_convert_from(a: Array2, ctx: &Arc, variable_time: bool) -> Result { if a.shape() != [ctx.q.len(), ctx.degree] { Err(Error::Default( "The array of coefficient does not have the correct shape".to_string(), )) - } else if let Some(repr) = representation.into() { + } else { let mut p = Self { ctx: ctx.clone(), - representation: repr, allow_variable_time_computations: variable_time, coefficients: a, coefficients_shoup: None, has_lazy_coefficients: false, + _repr: std::marker::PhantomData, }; - if p.representation == Representation::NttShoup { - p.compute_coefficients_shoup() - } + p.compute_coefficients_shoup(); Ok(p) - } else { - Err(Error::Default("When converting from a 2-dimensional array, the representation needs to be specified".to_string())) } } } -impl<'a> TryConvertFrom<&'a [u64]> for Poly { - fn try_convert_from( - v: &'a [u64], - ctx: &Arc, - variable_time: bool, - representation: R, - ) -> Result - where - R: Into>, - { - let repr = representation.into(); - match repr { - Some(Representation::PowerBasis) => { - if v.len() == ctx.q.len() * ctx.degree { - Poly::try_convert_from(v.to_vec(), ctx, variable_time, repr) - } else if v.len() <= ctx.degree { - let mut out = Self::zero(ctx, Representation::PowerBasis); - if variable_time { - unsafe { - izip!(out.coefficients.outer_iter_mut(), ctx.q.iter()).for_each( - |(mut w, qi)| { - let wi = w.as_slice_mut().unwrap(); - wi[..v.len()].copy_from_slice(v); - qi.reduce_vec_vt(wi); - }, - ); - out.allow_variable_time_computations(); - } - } else { - izip!(out.coefficients.outer_iter_mut(), ctx.q.iter()).for_each( - |(mut w, qi)| { - let wi = w.as_slice_mut().unwrap(); - wi[..v.len()].copy_from_slice(v); - qi.reduce_vec(wi); - }, - ); - } - Ok(out) - } else { - Err(Error::Default("In PowerBasis representation, either all coefficients must be specified, or only coefficients up to the degree".to_string())) - } - } - _ => Poly::try_convert_from(v.to_vec(), ctx, variable_time, repr), - } +impl<'a> TryConvertFrom<&'a [u64]> for Poly { + fn try_convert_from(v: &'a [u64], ctx: &Arc, variable_time: bool) -> Result { + Poly::::try_convert_from(v.to_vec(), ctx, variable_time) } } -impl<'a> TryConvertFrom<&'a [i64]> for Poly { - fn try_convert_from( - v: &'a [i64], - ctx: &Arc, - variable_time: bool, - representation: R, - ) -> Result - where - R: Into>, - { - if representation.into() != Some(Representation::PowerBasis) { - Err(Error::Default( - "Converting signed integer require to import in PowerBasis representation" - .to_string(), - )) - } else if v.len() <= ctx.degree { - let mut out = Self::zero(ctx, Representation::PowerBasis); +impl<'a> TryConvertFrom<&'a [u64]> for Poly { + fn try_convert_from(v: &'a [u64], ctx: &Arc, variable_time: bool) -> Result { + Poly::::try_convert_from(v.to_vec(), ctx, variable_time) + } +} + +impl<'a> TryConvertFrom<&'a [u64]> for Poly { + fn try_convert_from(v: &'a [u64], ctx: &Arc, variable_time: bool) -> Result { + Poly::::try_convert_from(v.to_vec(), ctx, variable_time) + } +} + +impl<'a> TryConvertFrom<&'a [i64]> for Poly { + fn try_convert_from(v: &'a [i64], ctx: &Arc, variable_time: bool) -> Result { + if v.len() <= ctx.degree { + let mut out = Self::zero(ctx); if variable_time { unsafe { out.allow_variable_time_computations() } } @@ -319,135 +310,161 @@ impl<'a> TryConvertFrom<&'a [i64]> for Poly { }); Ok(out) } else { - Err(Error::Default("In PowerBasis representation with signed integers, only `degree` coefficients can be specified".to_string())) + Err(Error::Default( + "In PowerBasis representation with signed integers, only `degree` coefficients can be specified".to_string(), + )) } } } -impl<'a> TryConvertFrom<&'a Vec> for Poly { - fn try_convert_from( - v: &'a Vec, - ctx: &Arc, - variable_time: bool, - representation: R, - ) -> Result - where - R: Into>, - { - Poly::try_convert_from(v.as_ref() as &[i64], ctx, variable_time, representation) +impl<'a> TryConvertFrom<&'a Vec> for Poly { + fn try_convert_from(v: &'a Vec, ctx: &Arc, variable_time: bool) -> Result { + Poly::try_convert_from(v.as_ref() as &[i64], ctx, variable_time) } } -impl<'a> TryConvertFrom<&'a [BigUint]> for Poly { - fn try_convert_from( - v: &'a [BigUint], - ctx: &Arc, - variable_time: bool, - representation: R, - ) -> Result - where - R: Into>, - { - let repr = representation.into(); - +impl<'a> TryConvertFrom<&'a [BigUint]> for Poly { + fn try_convert_from(v: &'a [BigUint], ctx: &Arc, variable_time: bool) -> Result { if v.len() > ctx.degree { Err(Error::Default( "The slice contains too many big integers compared to the polynomial degree" .to_string(), )) - } else if repr.is_some() { + } else { let mut coefficients = Array2::zeros((ctx.q.len(), ctx.degree)); izip!(coefficients.axis_iter_mut(Axis(1)), v).for_each(|(mut c, vi)| { c.assign(&ArrayView::from(&ctx.rns.project(vi))); }); - let mut p = Self { + Ok(Self { ctx: ctx.clone(), - representation: repr.unwrap(), allow_variable_time_computations: variable_time, coefficients, coefficients_shoup: None, has_lazy_coefficients: false, - }; - - match p.representation { - Representation::PowerBasis => Ok(p), - Representation::Ntt => Ok(p), - Representation::NttShoup => { - p.compute_coefficients_shoup(); - Ok(p) - } - } - } else { - Err(Error::Default( - "When converting from a vector, the representation needs to be specified" - .to_string(), - )) + _repr: std::marker::PhantomData, + }) } } } -impl<'a> TryConvertFrom<&'a Vec> for Poly { - fn try_convert_from( - v: &'a Vec, - ctx: &Arc, - variable_time: bool, - representation: R, - ) -> Result - where - R: Into>, - { - Poly::try_convert_from(v.to_vec(), ctx, variable_time, representation) +impl<'a> TryConvertFrom<&'a [BigUint]> for Poly { + fn try_convert_from(v: &'a [BigUint], ctx: &Arc, variable_time: bool) -> Result { + let p = Poly::::try_convert_from(v, ctx, variable_time)?; + Ok(p.into_ntt()) + } +} + +impl<'a> TryConvertFrom<&'a [BigUint]> for Poly { + fn try_convert_from(v: &'a [BigUint], ctx: &Arc, variable_time: bool) -> Result { + let p = Poly::::try_convert_from(v, ctx, variable_time)?; + Ok(p.into_ntt_shoup()) + } +} + +impl<'a> TryConvertFrom<&'a Vec> for Poly { + fn try_convert_from(v: &'a Vec, ctx: &Arc, variable_time: bool) -> Result { + Poly::try_convert_from(v.to_vec(), ctx, variable_time) + } +} + +impl<'a> TryConvertFrom<&'a Vec> for Poly { + fn try_convert_from(v: &'a Vec, ctx: &Arc, variable_time: bool) -> Result { + Poly::try_convert_from(v.to_vec(), ctx, variable_time) + } +} + +impl<'a> TryConvertFrom<&'a Vec> for Poly { + fn try_convert_from(v: &'a Vec, ctx: &Arc, variable_time: bool) -> Result { + Poly::try_convert_from(v.to_vec(), ctx, variable_time) + } +} + +impl<'a, const N: usize> TryConvertFrom<&'a [u64; N]> for Poly { + fn try_convert_from(v: &'a [u64; N], ctx: &Arc, variable_time: bool) -> Result { + Poly::try_convert_from(v.as_ref(), ctx, variable_time) + } +} + +impl<'a, const N: usize> TryConvertFrom<&'a [u64; N]> for Poly { + fn try_convert_from(v: &'a [u64; N], ctx: &Arc, variable_time: bool) -> Result { + Poly::try_convert_from(v.as_ref(), ctx, variable_time) } } -impl<'a, const N: usize> TryConvertFrom<&'a [u64; N]> for Poly { - fn try_convert_from( - v: &'a [u64; N], +impl<'a, const N: usize> TryConvertFrom<&'a [u64; N]> for Poly { + fn try_convert_from(v: &'a [u64; N], ctx: &Arc, variable_time: bool) -> Result { + Poly::try_convert_from(v.as_ref(), ctx, variable_time) + } +} + +impl<'a, const N: usize> TryConvertFrom<&'a [BigUint; N]> for Poly { + fn try_convert_from( + v: &'a [BigUint; N], ctx: &Arc, variable_time: bool, - representation: R, - ) -> Result - where - R: Into>, - { - Poly::try_convert_from(v.as_ref(), ctx, variable_time, representation) + ) -> Result { + Poly::try_convert_from(v.as_ref(), ctx, variable_time) } } -impl<'a, const N: usize> TryConvertFrom<&'a [BigUint; N]> for Poly { - fn try_convert_from( +impl<'a, const N: usize> TryConvertFrom<&'a [BigUint; N]> for Poly { + fn try_convert_from( v: &'a [BigUint; N], ctx: &Arc, variable_time: bool, - representation: R, - ) -> Result - where - R: Into>, - { - Poly::try_convert_from(v.as_ref(), ctx, variable_time, representation) + ) -> Result { + Poly::try_convert_from(v.as_ref(), ctx, variable_time) } } -impl<'a, const N: usize> TryConvertFrom<&'a [i64; N]> for Poly { - fn try_convert_from( - v: &'a [i64; N], +impl<'a, const N: usize> TryConvertFrom<&'a [BigUint; N]> for Poly { + fn try_convert_from( + v: &'a [BigUint; N], ctx: &Arc, variable_time: bool, - representation: R, - ) -> Result - where - R: Into>, - { - Poly::try_convert_from(v.as_ref(), ctx, variable_time, representation) + ) -> Result { + Poly::try_convert_from(v.as_ref(), ctx, variable_time) + } +} + +impl<'a, const N: usize> TryConvertFrom<&'a [i64; N]> for Poly { + fn try_convert_from(v: &'a [i64; N], ctx: &Arc, variable_time: bool) -> Result { + Poly::try_convert_from(v.as_ref(), ctx, variable_time) + } +} + +impl TryFrom<&Poly> for Vec { + type Error = Error; + + fn try_from(p: &Poly) -> Result { + p.coefficients + .as_slice() + .ok_or_else(|| { + Error::Default("Polynomial coefficients are not contiguous in memory".to_string()) + }) + .map(|slice| slice.to_vec()) + } +} + +impl TryFrom<&Poly> for Vec { + type Error = Error; + + fn try_from(p: &Poly) -> Result { + p.coefficients + .as_slice() + .ok_or_else(|| { + Error::Default("Polynomial coefficients are not contiguous in memory".to_string()) + }) + .map(|slice| slice.to_vec()) } } -impl TryFrom<&Poly> for Vec { +impl TryFrom<&Poly> for Vec { type Error = Error; - fn try_from(p: &Poly) -> Result { + fn try_from(p: &Poly) -> Result { p.coefficients .as_slice() .ok_or_else(|| { @@ -457,8 +474,24 @@ impl TryFrom<&Poly> for Vec { } } -impl From<&Poly> for Vec { - fn from(p: &Poly) -> Self { +impl From<&Poly> for Vec { + fn from(p: &Poly) -> Self { + izip!(p.coefficients.axis_iter(Axis(1))) + .map(|c| p.ctx.rns.lift(c)) + .collect_vec() + } +} + +impl From<&Poly> for Vec { + fn from(p: &Poly) -> Self { + izip!(p.coefficients.axis_iter(Axis(1))) + .map(|c| p.ctx.rns.lift(c)) + .collect_vec() + } +} + +impl From<&Poly> for Vec { + fn from(p: &Poly) -> Self { izip!(p.coefficients.axis_iter(Axis(1))) .map(|c| p.ctx.rns.lift(c)) .collect_vec() @@ -467,12 +500,10 @@ impl From<&Poly> for Vec { #[cfg(test)] mod tests { - #![expect(clippy::expect_used, reason = "bounds are validated before use")] - use crate::{ Error as CrateError, proto::rq::Rq, - rq::{Context, Poly, Representation, traits::TryConvertFrom}, + rq::{Context, Ntt, NttShoup, Poly, PowerBasis, traits::TryConvertFrom}, }; use num_bigint::BigUint; use rand::rng; @@ -485,50 +516,34 @@ mod tests { let mut rng = rng(); for modulus in MODULI { let ctx = Arc::new(Context::new(&[*modulus], 16)?); - let p = Poly::random(&ctx, Representation::PowerBasis, &mut rng); + let p = Poly::::random(&ctx, &mut rng); let proto = Rq::from(&p); - assert_eq!(Poly::try_convert_from(&proto, &ctx, false, None)?, p); assert_eq!( - Poly::try_convert_from(&proto, &ctx, false, Representation::PowerBasis)?, + Poly::::try_convert_from(&proto, &ctx, false)?, p ); assert_eq!( - Poly::try_convert_from(&proto, &ctx, false, Representation::Ntt) - .expect_err("Should fail because of mismatched representations"), - CrateError::Default("The representation asked for does not match the representation in the serialization".to_string()) - ); + Poly::::try_convert_from(&proto, &ctx, false).unwrap_err(), + CrateError::Default( + "The representation asked for does not match the representation in the serialization".to_string() + ) + ); assert_eq!( - Poly::try_convert_from(&proto, &ctx, false, Representation::NttShoup) - .expect_err("Should fail because of mismatched representations"), - CrateError::Default("The representation asked for does not match the representation in the serialization".to_string()) - ); + Poly::::try_convert_from(&proto, &ctx, false).unwrap_err(), + CrateError::Default( + "The representation asked for does not match the representation in the serialization".to_string() + ) + ); } let ctx = Arc::new(Context::new(MODULI, 16)?); - let p = Poly::random(&ctx, Representation::PowerBasis, &mut rng); + let p = Poly::::random(&ctx, &mut rng); let proto = Rq::from(&p); - assert_eq!(Poly::try_convert_from(&proto, &ctx, false, None)?, p); - assert_eq!( - Poly::try_convert_from(&proto, &ctx, false, Representation::PowerBasis)?, - p - ); - assert_eq!( - Poly::try_convert_from(&proto, &ctx, false, Representation::Ntt) - .expect_err("Should fail because of mismatched representations"), - CrateError::Default("The representation asked for does not match the representation in the serialization".to_string()) - ); - assert_eq!( - Poly::try_convert_from(&proto, &ctx, false, Representation::NttShoup) - .expect_err("Should fail because of mismatched representations"), - CrateError::Default("The representation asked for does not match the representation in the serialization".to_string()) - ); + assert_eq!(Poly::::try_convert_from(&proto, &ctx, false)?, p); - let ctx = Arc::new(Context::new(&MODULI[0..1], 16)?); - assert_eq!( - Poly::try_convert_from(&proto, &ctx, false, None) - .expect_err("Should fail because of incorrect context"), - CrateError::Default("Invalid coefficients".to_string()) - ); + let p = Poly::::random(&ctx, &mut rng); + let proto = Rq::from(&p); + assert_eq!(Poly::::try_convert_from(&proto, &ctx, false)?, p); Ok(()) } @@ -540,90 +555,35 @@ mod tests { // Power Basis assert_eq!( - Poly::try_convert_from(&[0u64], &ctx, false, Representation::PowerBasis)?, - Poly::zero(&ctx, Representation::PowerBasis) + Poly::::try_convert_from(&[0u64], &ctx, false)?, + Poly::::zero(&ctx) ); assert_eq!( - Poly::try_convert_from(&[0i64], &ctx, false, Representation::PowerBasis)?, - Poly::zero(&ctx, Representation::PowerBasis) + Poly::::try_convert_from(&[0i64], &ctx, false)?, + Poly::::zero(&ctx) ); assert_eq!( - Poly::try_convert_from(&[0u64; 16], &ctx, false, Representation::PowerBasis)?, - Poly::zero(&ctx, Representation::PowerBasis) + Poly::::try_convert_from(&[0u64; 16], &ctx, false)?, + Poly::::zero(&ctx) ); assert_eq!( - Poly::try_convert_from(&[0i64; 16], &ctx, false, Representation::PowerBasis)?, - Poly::zero(&ctx, Representation::PowerBasis) - ); - assert!( - Poly::try_convert_from( - &[0u64; 17], // One too many - &ctx, - false, - Representation::PowerBasis, - ) - .is_err() + Poly::::try_convert_from(&[0i64; 16], &ctx, false)?, + Poly::::zero(&ctx) ); + assert!(Poly::::try_convert_from(&[0u64; 17], &ctx, false).is_err()); // Ntt - assert!(Poly::try_convert_from(&[0u64], &ctx, false, Representation::Ntt).is_err()); - assert!(Poly::try_convert_from(&[0i64], &ctx, false, Representation::Ntt).is_err()); - assert_eq!( - Poly::try_convert_from(&[0u64; 16], &ctx, false, Representation::Ntt)?, - Poly::zero(&ctx, Representation::Ntt) - ); - assert!(Poly::try_convert_from(&[0i64; 16], &ctx, false, Representation::Ntt).is_err()); - assert!( - Poly::try_convert_from( - &[0u64; 17], // One too many - &ctx, - false, - Representation::Ntt, - ) - .is_err() - ); + assert!(Poly::::try_convert_from(&[0u64], &ctx, false).is_err()); + assert!(Poly::::try_convert_from(&[0u64; 16], &ctx, false).is_ok()); + assert!(Poly::::try_convert_from(&[0u64; 17], &ctx, false).is_err()); } let ctx = Arc::new(Context::new(MODULI, 16)?); assert_eq!( - Poly::try_convert_from( - Vec::::default(), - &ctx, - false, - Representation::PowerBasis, - )?, - Poly::zero(&ctx, Representation::PowerBasis) - ); - assert!( - Poly::try_convert_from(Vec::::default(), &ctx, false, Representation::Ntt) - .is_err() - ); - - assert_eq!( - Poly::try_convert_from(&[0u64], &ctx, false, Representation::PowerBasis)?, - Poly::zero(&ctx, Representation::PowerBasis) - ); - assert!(Poly::try_convert_from(&[0u64], &ctx, false, Representation::Ntt).is_err()); - - assert_eq!( - Poly::try_convert_from(&[0u64; 16], &ctx, false, Representation::PowerBasis)?, - Poly::zero(&ctx, Representation::PowerBasis) - ); - assert!(Poly::try_convert_from(&[0u64; 16], &ctx, false, Representation::Ntt).is_err()); - - assert!( - Poly::try_convert_from(&[0u64; 17], &ctx, false, Representation::PowerBasis).is_err() - ); - assert!(Poly::try_convert_from(&[0u64; 17], &ctx, false, Representation::Ntt).is_err()); - - assert_eq!( - Poly::try_convert_from(&[0u64; 16], &ctx, false, Representation::PowerBasis)?, - Poly::zero(&ctx, Representation::PowerBasis) - ); - assert_eq!( - Poly::try_convert_from(&[0u64; 48], &ctx, false, Representation::Ntt)?, - Poly::zero(&ctx, Representation::Ntt) + Poly::::try_convert_from(Vec::::default(), &ctx, false)?, + Poly::::zero(&ctx) ); + assert!(Poly::::try_convert_from(Vec::::default(), &ctx, false).is_err()); Ok(()) } @@ -633,98 +593,38 @@ mod tests { for modulus in MODULI { let ctx = Arc::new(Context::new(&[*modulus], 16)?); assert_eq!( - Poly::try_convert_from(vec![], &ctx, false, Representation::PowerBasis)?, - Poly::zero(&ctx, Representation::PowerBasis) + Poly::::try_convert_from(vec![], &ctx, false)?, + Poly::::zero(&ctx) ); - assert!(Poly::try_convert_from(vec![], &ctx, false, Representation::Ntt).is_err()); + assert!(Poly::::try_convert_from(vec![], &ctx, false).is_err()); assert_eq!( - Poly::try_convert_from(vec![0], &ctx, false, Representation::PowerBasis)?, - Poly::zero(&ctx, Representation::PowerBasis) + Poly::::try_convert_from(vec![0], &ctx, false)?, + Poly::::zero(&ctx) ); - assert!(Poly::try_convert_from(vec![0], &ctx, false, Representation::Ntt).is_err()); + assert!(Poly::::try_convert_from(vec![0], &ctx, false).is_err()); assert_eq!( - Poly::try_convert_from(vec![0; 16], &ctx, false, Representation::PowerBasis)?, - Poly::zero(&ctx, Representation::PowerBasis) + Poly::::try_convert_from(vec![0; 16], &ctx, false)?, + Poly::::zero(&ctx) ); assert_eq!( - Poly::try_convert_from(vec![0; 16], &ctx, false, Representation::Ntt)?, - Poly::zero(&ctx, Representation::Ntt) - ); - - assert!( - Poly::try_convert_from(vec![0; 17], &ctx, false, Representation::PowerBasis) - .is_err() + Poly::::try_convert_from(vec![0; 16], &ctx, false)?, + Poly::::zero(&ctx) ); - assert!(Poly::try_convert_from(vec![0; 17], &ctx, false, Representation::Ntt).is_err()); } - let ctx = Arc::new(Context::new(MODULI, 16)?); - assert_eq!( - Poly::try_convert_from(vec![], &ctx, false, Representation::PowerBasis)?, - Poly::zero(&ctx, Representation::PowerBasis) - ); - assert!(Poly::try_convert_from(vec![], &ctx, false, Representation::Ntt).is_err()); - - assert_eq!( - Poly::try_convert_from(vec![0], &ctx, false, Representation::PowerBasis)?, - Poly::zero(&ctx, Representation::PowerBasis) - ); - assert!(Poly::try_convert_from(vec![0], &ctx, false, Representation::Ntt).is_err()); - - assert_eq!( - Poly::try_convert_from(vec![0; 16], &ctx, false, Representation::PowerBasis)?, - Poly::zero(&ctx, Representation::PowerBasis) - ); - assert!(Poly::try_convert_from(vec![0; 16], &ctx, false, Representation::Ntt).is_err()); - - assert!( - Poly::try_convert_from(vec![0; 17], &ctx, false, Representation::PowerBasis).is_err() - ); - assert!(Poly::try_convert_from(vec![0; 17], &ctx, false, Representation::Ntt).is_err()); - - assert_eq!( - Poly::try_convert_from(vec![0; 48], &ctx, false, Representation::PowerBasis)?, - Poly::zero(&ctx, Representation::PowerBasis) - ); - assert_eq!( - Poly::try_convert_from(vec![0; 48], &ctx, false, Representation::Ntt)?, - Poly::zero(&ctx, Representation::Ntt) - ); - Ok(()) } #[test] fn biguint() -> Result<(), Box> { let mut rng = rng(); - for _ in 0..100 { - for modulus in MODULI { - let ctx = Arc::new(Context::new(&[*modulus], 16)?); - let p = Poly::random(&ctx, Representation::PowerBasis, &mut rng); - let p_coeffs = Vec::::from(&p); - let q = Poly::try_convert_from( - p_coeffs.as_slice(), - &ctx, - false, - Representation::PowerBasis, - )?; - assert_eq!(p, q); - } - - let ctx = Arc::new(Context::new(MODULI, 16)?); - let p = Poly::random(&ctx, Representation::PowerBasis, &mut rng); - let p_coeffs = Vec::::from(&p); - assert_eq!(p_coeffs.len(), ctx.degree); - let q = Poly::try_convert_from( - p_coeffs.as_slice(), - &ctx, - false, - Representation::PowerBasis, - )?; - assert_eq!(p, q); - } + let ctx = Arc::new(Context::new(MODULI, 16)?); + let p = Poly::::random(&ctx, &mut rng); + let values = Vec::::from(&p); + let p2 = Poly::::try_convert_from(values.as_slice(), &ctx, false)?; + assert_eq!(p, p2); Ok(()) } } diff --git a/crates/fhe-math/src/rq/mod.rs b/crates/fhe-math/src/rq/mod.rs index 294ff3a7..9c043ecf 100644 --- a/crates/fhe-math/src/rq/mod.rs +++ b/crates/fhe-math/src/rq/mod.rs @@ -27,6 +27,7 @@ pub use ops::dot_product; use rand::{CryptoRng, RngCore, SeedableRng}; use rand_chacha::ChaCha8Rng; use sha2::{Digest, Sha256}; +use std::marker::PhantomData; use std::sync::Arc; use zeroize::{Zeroize, Zeroizing}; @@ -45,6 +46,43 @@ pub enum Representation { NttShoup, } +/// Marker type for PowerBasis representation. +#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)] +pub struct PowerBasis; + +/// Marker type for Ntt representation. +#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)] +pub struct Ntt; + +/// Marker type for NttShoup representation. +#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)] +pub struct NttShoup; + +/// Trait implemented by representation marker types. +pub trait RepresentationTag: Default + Copy + 'static { + /// Associated runtime representation. + const REPRESENTATION: Representation; +} + +impl RepresentationTag for PowerBasis { + const REPRESENTATION: Representation = Representation::PowerBasis; +} + +impl RepresentationTag for Ntt { + const REPRESENTATION: Representation = Representation::Ntt; +} + +impl RepresentationTag for NttShoup { + const REPRESENTATION: Representation = Representation::NttShoup; +} + +/// Marker trait for representations that can be scaled/switched without +/// requiring Shoup coefficients. +pub trait ScaleRepresentation: RepresentationTag {} + +impl ScaleRepresentation for PowerBasis {} +impl ScaleRepresentation for Ntt {} + /// An exponent for a substitution. #[derive(Debug, PartialEq, Eq)] pub struct SubstitutionExponent { @@ -83,18 +121,18 @@ impl SubstitutionExponent { } /// Struct that holds a polynomial for a specific context. -#[derive(Default, Debug, Clone, PartialEq, Eq)] -pub struct Poly { +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct Poly { ctx: Arc, - representation: Representation, has_lazy_coefficients: bool, allow_variable_time_computations: bool, coefficients: Array2, coefficients_shoup: Option>, + _repr: PhantomData, } // Implements zeroization of polynomials -impl Zeroize for Poly { +impl Zeroize for Poly { fn zeroize(&mut self) { if let Some(coeffs) = self.coefficients.as_slice_mut() { coeffs.zeroize() @@ -103,33 +141,35 @@ impl Zeroize for Poly { } } -impl AsRef for Poly { - fn as_ref(&self) -> &Poly { +impl AsRef> for Poly { + fn as_ref(&self) -> &Poly { self } } -impl AsMut for Poly { - fn as_mut(&mut self) -> &mut Poly { +impl AsMut> for Poly { + fn as_mut(&mut self) -> &mut Poly { self } } -impl Poly { +impl Poly { /// Creates a polynomial holding the constant 0. #[must_use] - pub fn zero(ctx: &Arc, representation: Representation) -> Self { + pub fn zero(ctx: &Arc) -> Self { + let representation = R::REPRESENTATION; + let coefficients_shoup = if representation == Representation::NttShoup { + Some(Array2::zeros((ctx.q.len(), ctx.degree))) + } else { + None + }; Self { ctx: ctx.clone(), - representation, allow_variable_time_computations: false, has_lazy_coefficients: false, coefficients: Array2::zeros((ctx.q.len(), ctx.degree)), - coefficients_shoup: if representation == Representation::NttShoup { - Some(Array2::zeros((ctx.q.len(), ctx.degree))) - } else { - None - }, + coefficients_shoup, + _repr: PhantomData, } } @@ -150,8 +190,8 @@ impl Poly { /// Current representation of the polynomial. #[must_use] - pub const fn representation(&self) -> &Representation { - &self.representation + pub const fn representation(&self) -> Representation { + R::REPRESENTATION } /// Zeroize the shoup coefficients @@ -165,35 +205,6 @@ impl Poly { } } - /// Change the representation of the underlying polynomial. - pub fn change_representation(&mut self, to: Representation) { - if self.representation == to { - return; - } - - match (&self.representation, &to) { - (Representation::PowerBasis, Representation::Ntt) => self.ntt_forward(), - (Representation::PowerBasis, Representation::NttShoup) => { - self.ntt_forward(); - self.compute_coefficients_shoup() - } - (Representation::Ntt, Representation::PowerBasis) => self.ntt_backward(), - (Representation::Ntt, Representation::NttShoup) => self.compute_coefficients_shoup(), - (Representation::NttShoup, Representation::PowerBasis) => { - self.zeroize_shoup(); - self.coefficients_shoup = None; - self.ntt_backward() - } - (Representation::NttShoup, Representation::Ntt) => { - self.zeroize_shoup(); - self.coefficients_shoup = None; - } - _ => unreachable!(), - } - - self.representation = to; - } - /// Compute the Shoup representation of the coefficients. fn compute_coefficients_shoup(&mut self) { let mut coefficients_shoup = Array2::zeros((self.ctx.q.len(), self.ctx.degree)); @@ -211,64 +222,36 @@ impl Poly { self.coefficients_shoup = Some(coefficients_shoup) } - /// Override the internal representation to a given representation. - /// - /// # Safety - /// - /// Prefer the `change_representation` function to safely modify the - /// polynomial representation. If the `to` representation is NttShoup, the - /// coefficients are still computed correctly to avoid being in an unstable - /// state. If we override a polynomial with Shoup coefficients, we zeroize - /// them. - pub unsafe fn override_representation(&mut self, to: Representation) { - if self.coefficients_shoup.is_some() { - self.zeroize_shoup(); - self.coefficients_shoup = None - } - if to == Representation::NttShoup { - self.compute_coefficients_shoup() - } - self.representation = to; - } - /// Generate a random polynomial. - pub fn random( - ctx: &Arc, - representation: Representation, - rng: &mut R, - ) -> Self { - let mut p = Poly::zero(ctx, representation); + pub fn random(ctx: &Arc, rng: &mut T) -> Self { + let mut p = Poly::zero(ctx); izip!(p.coefficients.outer_iter_mut(), ctx.q.iter()).for_each(|(mut v, qi)| { v.as_slice_mut() .unwrap() .copy_from_slice(&qi.random_vec(ctx.degree, rng)) }); - if p.representation == Representation::NttShoup { - p.compute_coefficients_shoup() + if R::REPRESENTATION == Representation::NttShoup { + p.compute_coefficients_shoup(); } p } /// Generate a random polynomial deterministically from a seed. #[must_use] - pub fn random_from_seed( - ctx: &Arc, - representation: Representation, - seed: ::Seed, - ) -> Self { + pub fn random_from_seed(ctx: &Arc, seed: ::Seed) -> Self { // Let's hash the seed into a ChaCha8Rng seed. let mut hasher = Sha256::new(); hasher.update(seed); let mut prng = ChaCha8Rng::from_seed(::Seed::from(hasher.finalize())); - let mut p = Poly::zero(ctx, representation); + let mut p = Poly::zero(ctx); izip!(p.coefficients.outer_iter_mut(), ctx.q.iter()).for_each(|(mut v, qi)| { v.as_slice_mut() .unwrap() .copy_from_slice(&qi.random_vec(ctx.degree, &mut prng)) }); - if p.representation == Representation::NttShoup { - p.compute_coefficients_shoup() + if R::REPRESENTATION == Representation::NttShoup { + p.compute_coefficients_shoup(); } p } @@ -279,7 +262,6 @@ impl Poly { /// Returns an error if the variance does not belong to [1, ..., 16]. pub fn small( ctx: &Arc, - representation: Representation, variance: usize, rng: &mut T, ) -> Result { @@ -292,16 +274,14 @@ impl Poly { let coeffs = Zeroizing::new( sample_vec_cbd(ctx.degree, variance, rng).map_err(|e| Error::Default(e.to_string()))?, ); - let mut p = Poly::try_convert_from( - coeffs.as_ref() as &[i64], - ctx, - false, - Representation::PowerBasis, - )?; - if representation != Representation::PowerBasis { - p.change_representation(representation); + let p = Poly::::try_convert_from(coeffs.as_ref() as &[i64], ctx, false)?; + if R::REPRESENTATION == Representation::PowerBasis { + Ok(Poly::from_parts(p)) + } else if R::REPRESENTATION == Representation::Ntt { + Ok(Poly::from_parts(p.into_ntt())) + } else { + Ok(Poly::from_parts(p.into_ntt_shoup())) } - Ok(p) } /// Access the polynomial coefficients in RNS representation. @@ -336,12 +316,12 @@ impl Poly { /// In PowerBasis representation, i can be any integer that is not a /// multiple of 2 * degree. In Ntt and NttShoup representation, i can be any /// odd integer that is not a multiple of 2 * degree. - pub fn substitute(&self, i: &SubstitutionExponent) -> Result { - let mut q = Poly::zero(&self.ctx, self.representation); + pub fn substitute(&self, i: &SubstitutionExponent) -> Result> { + let mut q = Poly::::zero(&self.ctx); if self.allow_variable_time_computations { unsafe { q.allow_variable_time_computations() } } - match self.representation { + match R::REPRESENTATION { Representation::Ntt | Representation::NttShoup => { izip!( q.coefficients.outer_iter_mut(), @@ -352,7 +332,7 @@ impl Poly { q_row[*j] = p_row[*k] } }); - if self.representation == Representation::NttShoup { + if R::REPRESENTATION == Representation::NttShoup { izip!( q.coefficients_shoup.as_mut().unwrap().outer_iter_mut(), self.coefficients_shoup.as_ref().unwrap().outer_iter() @@ -388,55 +368,30 @@ impl Poly { Ok(q) } - /// Create a polynomial which can only be multiplied by a polynomial in - /// NttShoup representation. All other operations may panic. - /// - /// # Safety - /// This operation also creates a polynomial that allows variable time - /// operations. + /// Returns the context of the underlying polynomial #[must_use] - pub unsafe fn create_constant_ntt_polynomial_with_lazy_coefficients_and_variable_time( - power_basis_coefficients: &[u64], - ctx: &Arc, - ) -> Self { - let mut coefficients = Array2::zeros((ctx.q.len(), ctx.degree)); - izip!(coefficients.outer_iter_mut(), ctx.q.iter(), ctx.ops.iter()).for_each( - |(mut p, qi, op)| { - p.as_slice_mut() - .unwrap() - .clone_from_slice(power_basis_coefficients); - qi.lazy_reduce_vec(p.as_slice_mut().unwrap()); - unsafe { op.forward_vt_lazy(p.as_mut_ptr()) }; - }, - ); - Self { - ctx: ctx.clone(), - representation: Representation::Ntt, - allow_variable_time_computations: true, - coefficients, - coefficients_shoup: None, - has_lazy_coefficients: true, - } + pub fn ctx(&self) -> &Arc { + &self.ctx + } +} + +impl Poly { + /// Borrowed conversion to PowerBasis (clone). + #[must_use] + pub fn to_power_basis(&self) -> Poly { + self.clone() } /// Modulus switch down the polynomial by dividing and rounding each /// coefficient by the last modulus in the chain, then drops the last /// modulus, as described in Algorithm 2 of . /// - /// Returns an error if there is no next context or if the representation - /// is not PowerBasis. + /// Returns an error if there is no next context. pub fn switch_down(&mut self) -> Result<()> { if self.ctx.next_context.is_none() { return Err(Error::NoMoreContext); } - if self.representation != Representation::PowerBasis { - return Err(Error::IncorrectRepresentation( - self.representation, - Representation::PowerBasis, - )); - } - // Unwrap the next_context. let next_context = self.ctx.next_context.as_ref().unwrap(); @@ -496,8 +451,7 @@ impl Poly { /// Modulo switch down to a smaller context. /// /// Returns an error if there is the provided context is not a child of the - /// current context, or if the polynomial is not in PowerBasis - /// representation. + /// current context. pub fn switch_down_to(&mut self, context: &Arc) -> Result<()> { let niterations = self.ctx.niterations_to(context)?; for _ in 0..niterations { @@ -507,32 +461,8 @@ impl Poly { Ok(()) } - /// Modulo switch to another context. The target context needs not to be - /// related to the current context. - pub fn switch(&self, switcher: &Switcher) -> Result { - switcher.switch(self) - } - - /// Scale a polynomial using a scaler. - pub fn scale(&self, scaler: &Scaler) -> Result { - scaler.scale(self) - } - - /// Returns the context of the underlying polynomial - #[must_use] - pub fn ctx(&self) -> &Arc { - &self.ctx - } - /// Multiplies a polynomial in PowerBasis representation by x^(-power). pub fn multiply_inverse_power_of_x(&mut self, power: usize) -> Result<()> { - if self.representation != Representation::PowerBasis { - return Err(Error::IncorrectRepresentation( - self.representation, - Representation::PowerBasis, - )); - } - let shift = ((self.ctx.degree << 1) - power) % (self.ctx.degree << 1); let mask = self.ctx.degree - 1; let mut new_coefficients = Array2::zeros((self.ctx.q.len(), self.ctx.degree)); @@ -554,11 +484,150 @@ impl Poly { self.coefficients = new_coefficients; Ok(()) } + + /// Convert into Ntt representation. + #[must_use] + pub fn into_ntt(mut self) -> Poly { + self.ntt_forward(); + Poly::from_parts(self) + } + + /// Convert into NttShoup representation. + #[must_use] + pub fn into_ntt_shoup(mut self) -> Poly { + self.ntt_forward(); + self.compute_coefficients_shoup(); + Poly::from_parts(self) + } +} + +impl Poly { + /// Borrowed conversion to PowerBasis. + #[must_use] + pub fn to_power_basis(&self) -> Poly { + self.clone().into_power_basis() + } + + /// Create a polynomial which can only be multiplied by a polynomial in + /// NttShoup representation. All other operations may panic. + /// + /// # Safety + /// This operation also creates a polynomial that allows variable time + /// operations. + #[must_use] + pub unsafe fn create_constant_ntt_polynomial_with_lazy_coefficients_and_variable_time( + power_basis_coefficients: &[u64], + ctx: &Arc, + ) -> Poly { + let mut coefficients = Array2::zeros((ctx.q.len(), ctx.degree)); + izip!(coefficients.outer_iter_mut(), ctx.q.iter(), ctx.ops.iter()).for_each( + |(mut p, qi, op)| { + p.as_slice_mut() + .unwrap() + .clone_from_slice(power_basis_coefficients); + qi.lazy_reduce_vec(p.as_slice_mut().unwrap()); + unsafe { op.forward_vt_lazy(p.as_mut_ptr()) }; + }, + ); + Poly { + ctx: ctx.clone(), + allow_variable_time_computations: true, + coefficients, + coefficients_shoup: None, + has_lazy_coefficients: true, + _repr: PhantomData, + } + } + + /// Convert into PowerBasis representation. + #[must_use] + pub fn into_power_basis(mut self) -> Poly { + self.ntt_backward(); + Poly::from_parts(self) + } + + /// Convert into NttShoup representation. + #[must_use] + pub fn into_ntt_shoup(mut self) -> Poly { + self.compute_coefficients_shoup(); + Poly::from_parts(self) + } +} + +impl Poly { + /// Borrowed conversion to PowerBasis. + #[must_use] + pub fn to_power_basis(&self) -> Poly { + self.clone().into_power_basis() + } + + /// Convert into Ntt representation. + #[must_use] + pub fn into_ntt(mut self) -> Poly { + self.zeroize_shoup(); + self.coefficients_shoup = None; + Poly::from_parts(self) + } + + /// Convert into PowerBasis representation. + #[must_use] + pub fn into_power_basis(mut self) -> Poly { + self.zeroize_shoup(); + self.coefficients_shoup = None; + self.ntt_backward(); + Poly::from_parts(self) + } +} + +impl Poly { + #[must_use] + fn from_parts(mut other: Poly) -> Poly { + let coefficients_shoup = if R::REPRESENTATION == Representation::NttShoup { + if other.coefficients_shoup.is_none() { + other.compute_coefficients_shoup(); + } + other.coefficients_shoup + } else { + if other.coefficients_shoup.is_some() { + other.zeroize_shoup(); + } + None + }; + let Poly { + ctx, + has_lazy_coefficients, + allow_variable_time_computations, + coefficients, + coefficients_shoup: _, + _repr: _, + } = other; + Poly { + ctx, + has_lazy_coefficients, + allow_variable_time_computations, + coefficients, + coefficients_shoup, + _repr: PhantomData, + } + } +} + +impl Poly { + /// Modulo switch to another context. The target context needs not to be + /// related to the current context. + pub fn switch(&self, switcher: &Switcher) -> Result> { + switcher.switch(self) + } + + /// Scale a polynomial using a scaler. + pub fn scale(&self, scaler: &Scaler) -> Result> { + scaler.scale(self) + } } #[cfg(test)] mod tests { - use super::{Context, Poly, Representation, switcher::Switcher}; + use super::{Context, Ntt, Poly, PowerBasis, Representation, switcher::Switcher}; use crate::{rq::SubstitutionExponent, zq::Modulus}; use fhe_util::variance; use itertools::Itertools; @@ -600,17 +669,17 @@ mod tests { for modulus in MODULI { let ctx = Arc::new(Context::new(&[*modulus], 16)?); - let p = Poly::zero(&ctx, Representation::PowerBasis); - let q = Poly::zero(&ctx, Representation::Ntt); - assert_ne!(p, q); + let p = Poly::::zero(&ctx); + let q = Poly::::zero(&ctx); + assert_eq!(p, q.to_power_basis()); assert_eq!(Vec::::try_from(&p).unwrap(), &[0; 16]); assert_eq!(Vec::::try_from(&q).unwrap(), &[0; 16]); } let ctx = Arc::new(Context::new(MODULI, 16)?); - let p = Poly::zero(&ctx, Representation::PowerBasis); - let q = Poly::zero(&ctx, Representation::Ntt); - assert_ne!(p, q); + let p = Poly::::zero(&ctx); + let q = Poly::::zero(&ctx); + assert_eq!(p, q.to_power_basis()); assert_eq!(Vec::::try_from(&p).unwrap(), [0; 16 * MODULI.len()]); assert_eq!(Vec::::try_from(&q).unwrap(), [0; 16 * MODULI.len()]); assert_eq!(Vec::::from(&p), reference); @@ -623,12 +692,12 @@ mod tests { fn ctx() -> Result<(), Box> { for modulus in MODULI { let ctx = Arc::new(Context::new(&[*modulus], 16)?); - let p = Poly::zero(&ctx, Representation::PowerBasis); + let p = Poly::::zero(&ctx); assert_eq!(p.ctx(), &ctx); } let ctx = Arc::new(Context::new(MODULI, 16)?); - let p = Poly::zero(&ctx, Representation::PowerBasis); + let p = Poly::::zero(&ctx); assert_eq!(p.ctx(), &ctx); Ok(()) @@ -643,21 +712,21 @@ mod tests { for modulus in MODULI { let ctx = Arc::new(Context::new(&[*modulus], 16)?); - let p = Poly::random_from_seed(&ctx, Representation::Ntt, seed); - let q = Poly::random_from_seed(&ctx, Representation::Ntt, seed); + let p = Poly::::random_from_seed(&ctx, seed); + let q = Poly::::random_from_seed(&ctx, seed); assert_eq!(p, q); } let ctx = Arc::new(Context::new(MODULI, 16)?); - let p = Poly::random_from_seed(&ctx, Representation::Ntt, seed); - let q = Poly::random_from_seed(&ctx, Representation::Ntt, seed); + let p = Poly::::random_from_seed(&ctx, seed); + let q = Poly::::random_from_seed(&ctx, seed); assert_eq!(p, q); rand::rng().fill(&mut seed); - let p = Poly::random_from_seed(&ctx, Representation::Ntt, seed); + let p = Poly::::random_from_seed(&ctx, seed); assert_ne!(p, q); - let r = Poly::random(&ctx, Representation::Ntt, &mut rng); + let r = Poly::::random(&ctx, &mut rng); assert_ne!(p, r); assert_ne!(q, r); } @@ -670,13 +739,13 @@ mod tests { for _ in 0..50 { for modulus in MODULI { let ctx = Arc::new(Context::new(&[*modulus], 16)?); - let p = Poly::random(&ctx, Representation::Ntt, &mut rng); + let p = Poly::::random(&ctx, &mut rng); let p_coefficients = Vec::::try_from(&p).unwrap(); assert_eq!(p_coefficients, p.coefficients().as_slice().unwrap()) } let ctx = Arc::new(Context::new(MODULI, 16)?); - let p = Poly::random(&ctx, Representation::Ntt, &mut rng); + let p = Poly::::random(&ctx, &mut rng); let p_coefficients = Vec::::try_from(&p).unwrap(); assert_eq!(p_coefficients, p.coefficients().as_slice().unwrap()) } @@ -704,7 +773,7 @@ mod tests { let mut rng = rand::rng(); for modulus in MODULI { let ctx = Arc::new(Context::new(&[*modulus], 16)?); - let mut p = Poly::random(&ctx, Representation::default(), &mut rng); + let mut p = Poly::::random(&ctx, &mut rng); assert!(!p.allow_variable_time_computations); unsafe { p.allow_variable_time_computations() } @@ -718,7 +787,7 @@ mod tests { } let ctx = Arc::new(Context::new(MODULI, 16)?); - let mut p = Poly::random(&ctx, Representation::default(), &mut rng); + let mut p = Poly::::random(&ctx, &mut rng); assert!(!p.allow_variable_time_computations); unsafe { p.allow_variable_time_computations() } @@ -728,9 +797,9 @@ mod tests { assert!(q.allow_variable_time_computations); // Allowing variable time propagates. - let mut p = Poly::random(&ctx, Representation::Ntt, &mut rng); + let mut p = Poly::::random(&ctx, &mut rng); unsafe { p.allow_variable_time_computations() } - let mut q = Poly::random(&ctx, Representation::Ntt, &mut rng); + let mut q = Poly::::random(&ctx, &mut rng); assert!(!q.allow_variable_time_computations); q *= &p; @@ -760,12 +829,12 @@ mod tests { .collect(); let poly = unsafe { - Poly::create_constant_ntt_polynomial_with_lazy_coefficients_and_variable_time( + Poly::::create_constant_ntt_polynomial_with_lazy_coefficients_and_variable_time( &coeffs, &ctx, ) }; - assert_eq!(poly.representation(), &Representation::Ntt); + assert_eq!(poly.representation(), Representation::Ntt); assert!(poly.allow_variable_time_computations); assert!(poly.has_lazy_coefficients); @@ -777,83 +846,6 @@ mod tests { Ok(()) } - #[test] - fn change_representation() -> Result<(), Box> { - let mut rng = rand::rng(); - let ctx = Arc::new(Context::new(MODULI, 16)?); - - let mut p = Poly::random(&ctx, Representation::default(), &mut rng); - assert_eq!(p.representation, Representation::default()); - assert_eq!(p.representation(), &Representation::default()); - - p.change_representation(Representation::PowerBasis); - assert_eq!(p.representation, Representation::PowerBasis); - assert_eq!(p.representation(), &Representation::PowerBasis); - assert!(p.coefficients_shoup.is_none()); - let q = p.clone(); - - p.change_representation(Representation::Ntt); - assert_eq!(p.representation, Representation::Ntt); - assert_eq!(p.representation(), &Representation::Ntt); - assert_ne!(p.coefficients, q.coefficients); - assert!(p.coefficients_shoup.is_none()); - let q_ntt = p.clone(); - - p.change_representation(Representation::NttShoup); - assert_eq!(p.representation, Representation::NttShoup); - assert_eq!(p.representation(), &Representation::NttShoup); - assert_ne!(p.coefficients, q.coefficients); - assert!(p.coefficients_shoup.is_some()); - let q_ntt_shoup = p.clone(); - - p.change_representation(Representation::PowerBasis); - assert_eq!(p, q); - - p.change_representation(Representation::NttShoup); - assert_eq!(p, q_ntt_shoup); - - p.change_representation(Representation::Ntt); - assert_eq!(p, q_ntt); - - p.change_representation(Representation::PowerBasis); - assert_eq!(p, q); - - Ok(()) - } - - #[test] - fn override_representation() -> Result<(), Box> { - let mut rng = rand::rng(); - let ctx = Arc::new(Context::new(MODULI, 16)?); - - let mut p = Poly::random(&ctx, Representation::PowerBasis, &mut rng); - assert_eq!(p.representation(), &p.representation); - let q = p.clone(); - - unsafe { p.override_representation(Representation::Ntt) } - assert_eq!(p.representation, Representation::Ntt); - assert_eq!(p.representation(), &p.representation); - assert_eq!(p.coefficients, q.coefficients); - assert!(p.coefficients_shoup.is_none()); - - unsafe { p.override_representation(Representation::NttShoup) } - assert_eq!(p.representation, Representation::NttShoup); - assert_eq!(p.representation(), &p.representation); - assert_eq!(p.coefficients, q.coefficients); - assert!(p.coefficients_shoup.is_some()); - - unsafe { p.override_representation(Representation::PowerBasis) } - assert_eq!(p, q); - - unsafe { p.override_representation(Representation::NttShoup) } - assert!(p.coefficients_shoup.is_some()); - - unsafe { p.override_representation(Representation::Ntt) } - assert!(p.coefficients_shoup.is_none()); - - Ok(()) - } - #[test] fn small() -> Result<(), Box> { let mut rng = rand::rng(); @@ -861,13 +853,13 @@ mod tests { let ctx = Arc::new(Context::new(&[*modulus], 16)?); let q = Modulus::new(*modulus).unwrap(); - let e = Poly::small(&ctx, Representation::PowerBasis, 0, &mut rng); + let e = Poly::::small(&ctx, 0, &mut rng); assert!(e.is_err()); assert_eq!( e.unwrap_err().to_string(), "The variance should be an integer between 1 and 16" ); - let e = Poly::small(&ctx, Representation::PowerBasis, 17, &mut rng); + let e = Poly::::small(&ctx, 17, &mut rng); assert!(e.is_err()); assert_eq!( e.unwrap_err().to_string(), @@ -875,7 +867,7 @@ mod tests { ); for i in 1..=16 { - let p = Poly::small(&ctx, Representation::PowerBasis, i, &mut rng)?; + let p = Poly::::small(&ctx, i, &mut rng)?; let coefficients = p.coefficients().to_slice().unwrap(); let v = q.center_vec(coefficients); @@ -887,7 +879,7 @@ mod tests { let ctx = Arc::new(Context::new(&[4611686018326724609], 1 << 18)?); let q = Modulus::new(4611686018326724609).unwrap(); let mut rng = rand::rng(); - let p = Poly::small(&ctx, Representation::PowerBasis, 16, &mut rng)?; + let p = Poly::::small(&ctx, 16, &mut rng)?; let coefficients = p.coefficients().to_slice().unwrap(); let v = q.center_vec(coefficients); assert!(v.iter().map(|vi| vi.abs()).max().unwrap() <= 32); @@ -901,11 +893,9 @@ mod tests { let mut rng = rand::rng(); for modulus in MODULI { let ctx = Arc::new(Context::new(&[*modulus], 16)?); - let p = Poly::random(&ctx, Representation::PowerBasis, &mut rng); - let mut p_ntt = p.clone(); - p_ntt.change_representation(Representation::Ntt); - let mut p_ntt_shoup = p.clone(); - p_ntt_shoup.change_representation(Representation::NttShoup); + let p = Poly::::random(&ctx, &mut rng); + let p_ntt = p.clone().into_ntt(); + let p_ntt_shoup = p.clone().into_ntt_shoup(); let p_coeffs = Vec::::try_from(&p).unwrap(); // Substitution by a multiple of 2 * degree, or even numbers, should fail @@ -925,7 +915,7 @@ mod tests { ); // Substitution by 3 - let mut q = p.substitute(&SubstitutionExponent::new(&ctx, 3)?)?; + let q = p.substitute(&SubstitutionExponent::new(&ctx, 3)?)?; let mut v = vec![0u64; 16]; for i in 0..16 { v[(3 * i) % 16] = if ((3 * i) / 16) & 1 == 1 && p_coeffs[i] > 0 { @@ -937,12 +927,12 @@ mod tests { assert_eq!(&Vec::::try_from(&q).unwrap(), &v); let q_ntt = p_ntt.substitute(&SubstitutionExponent::new(&ctx, 3)?)?; - q.change_representation(Representation::Ntt); - assert_eq!(q, q_ntt); + let q_as_ntt = q.clone().into_ntt(); + assert_eq!(q_as_ntt, q_ntt); let q_ntt_shoup = p_ntt_shoup.substitute(&SubstitutionExponent::new(&ctx, 3)?)?; - q.change_representation(Representation::NttShoup); - assert_eq!(q, q_ntt_shoup); + let q_as_ntt_shoup = q.clone().into_ntt_shoup(); + assert_eq!(q_as_ntt_shoup, q_ntt_shoup); // 11 = 3^(-1) % 16 assert_eq!( @@ -965,11 +955,9 @@ mod tests { } let ctx = Arc::new(Context::new(MODULI, 16)?); - let p = Poly::random(&ctx, Representation::PowerBasis, &mut rng); - let mut p_ntt = p.clone(); - p_ntt.change_representation(Representation::Ntt); - let mut p_ntt_shoup = p.clone(); - p_ntt_shoup.change_representation(Representation::NttShoup); + let p = Poly::::random(&ctx, &mut rng); + let p_ntt = p.clone().into_ntt(); + let p_ntt_shoup = p.clone().into_ntt_shoup(); assert_eq!( p, @@ -999,19 +987,8 @@ mod tests { let ctx = Arc::new(Context::new(MODULI, 16)?); for _ in 0..ntests { - // If the polynomial has incorrect representation, an error is returned - let e = Poly::random(&ctx, Representation::Ntt, &mut rng).switch_down(); - assert!(e.is_err()); - assert_eq!( - e.unwrap_err(), - crate::Error::IncorrectRepresentation( - Representation::Ntt, - Representation::PowerBasis - ) - ); - // Otherwise, no error happens and the coefficients evolve as expected. - let mut p = Poly::random(&ctx, Representation::PowerBasis, &mut rng); + let mut p = Poly::::random(&ctx, &mut rng); let mut reference = Vec::::from(&p); let mut current_ctx = ctx.clone(); assert_eq!(p.ctx, current_ctx); @@ -1046,7 +1023,7 @@ mod tests { let ctx2 = Arc::new(Context::new(&MODULI[..2], 16)?); for _ in 0..ntests { - let mut p = Poly::random(&ctx1, Representation::PowerBasis, &mut rng); + let mut p = Poly::::random(&ctx1, &mut rng); let reference = Vec::::from(&p); p.switch_down_to(&ctx2)?; @@ -1072,7 +1049,7 @@ mod tests { let ctx2 = Arc::new(Context::new(&MODULI[3..], 16)?); let switcher = Switcher::new(&ctx1, &ctx2)?; for _ in 0..ntests { - let p = Poly::random(&ctx1, Representation::PowerBasis, &mut rng); + let p = Poly::::random(&ctx1, &mut rng); let reference = Vec::::from(&p); let q = p.switch(&switcher)?; @@ -1093,14 +1070,7 @@ mod tests { fn mul_x_power() -> Result<(), Box> { let mut rng = rand::rng(); let ctx = Arc::new(Context::new(MODULI, 16)?); - let e = Poly::random(&ctx, Representation::Ntt, &mut rng).multiply_inverse_power_of_x(1); - assert!(e.is_err()); - assert_eq!( - e.unwrap_err(), - crate::Error::IncorrectRepresentation(Representation::Ntt, Representation::PowerBasis) - ); - - let mut p = Poly::random(&ctx, Representation::PowerBasis, &mut rng); + let mut p = Poly::::random(&ctx, &mut rng); let q = p.clone(); p.multiply_inverse_power_of_x(0)?; diff --git a/crates/fhe-math/src/rq/ops.rs b/crates/fhe-math/src/rq/ops.rs index d1c6f3a6..aacd7ade 100644 --- a/crates/fhe-math/src/rq/ops.rs +++ b/crates/fhe-math/src/rq/ops.rs @@ -1,49 +1,17 @@ //! Implementation of operations over polynomials. -use super::{Poly, Representation}; +use super::{Ntt, NttShoup, Poly, PowerBasis}; use crate::{Error, Result}; use itertools::{Itertools, izip}; use ndarray::Array2; use num_bigint::BigUint; -use std::{ - cmp::min, - ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign}, -}; -use zeroize::Zeroize; - -impl AddAssign<&Poly> for Poly { - fn add_assign(&mut self, p: &Poly) { - assert!(!self.has_lazy_coefficients && !p.has_lazy_coefficients); +use std::ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign}; - // p and self must have the same context. +impl AddAssign<&Poly> for Poly { + fn add_assign(&mut self, p: &Poly) { + assert!(!self.has_lazy_coefficients && !p.has_lazy_coefficients); debug_assert_eq!(self.ctx, p.ctx, "Incompatible contexts"); - // p and q must have comptatible representations. - match self.representation { - Representation::PowerBasis => assert_eq!( - p.representation, - Representation::PowerBasis, - "Incompatible representations" - ), - Representation::Ntt | Representation::NttShoup => assert!( - p.representation == Representation::Ntt - || p.representation == Representation::NttShoup, - "Incompatible representations" - ), - } - - // If the representation is NttShoup, drop the Shoup coefficients - // and switch to Ntt representation. - if self.representation == Representation::NttShoup { - self.coefficients_shoup - .as_mut() - .unwrap() - .as_slice_mut() - .unwrap() - .zeroize(); - unsafe { self.override_representation(Representation::Ntt) } - } - self.allow_variable_time_computations |= p.allow_variable_time_computations; if self.allow_variable_time_computations { izip!( @@ -67,62 +35,109 @@ impl AddAssign<&Poly> for Poly { } } -impl Add<&Poly> for &Poly { - type Output = Poly; - fn add(self, p: &Poly) -> Poly { - // if self is in NttShoup representation, let's copy `p` instead - if self.representation == Representation::NttShoup { - let mut q = p.clone(); - q += self; - q - } else { - let mut q = self.clone(); - q += p; - q - } +impl Add<&Poly> for &Poly { + type Output = Poly; + fn add(self, p: &Poly) -> Poly { + let mut q = self.clone(); + q += p; + q } } -impl Add for Poly { - type Output = Poly; - fn add(self, mut p: Poly) -> Poly { +impl Add for Poly { + type Output = Poly; + fn add(self, mut p: Poly) -> Poly { p += &self; p } } -impl SubAssign<&Poly> for Poly { - fn sub_assign(&mut self, p: &Poly) { +impl SubAssign<&Poly> for Poly { + fn sub_assign(&mut self, p: &Poly) { assert!(!self.has_lazy_coefficients && !p.has_lazy_coefficients); - - // p and self must have the same context. debug_assert_eq!(self.ctx, p.ctx, "Incompatible contexts"); - // p and q must have comptatible representations. - match self.representation { - Representation::PowerBasis => assert_eq!( - p.representation, - Representation::PowerBasis, - "Incompatible representations" - ), - Representation::Ntt | Representation::NttShoup => assert!( - p.representation == Representation::Ntt - || p.representation == Representation::NttShoup, - "Incompatible representations" - ), + self.allow_variable_time_computations |= p.allow_variable_time_computations; + if self.allow_variable_time_computations { + izip!( + self.coefficients.outer_iter_mut(), + p.coefficients.outer_iter(), + self.ctx.q.iter() + ) + .for_each(|(mut v1, v2, qi)| unsafe { + qi.sub_vec_vt(v1.as_slice_mut().unwrap(), v2.as_slice().unwrap()) + }); + } else { + izip!( + self.coefficients.outer_iter_mut(), + p.coefficients.outer_iter(), + self.ctx.q.iter() + ) + .for_each(|(mut v1, v2, qi)| { + qi.sub_vec(v1.as_slice_mut().unwrap(), v2.as_slice().unwrap()) + }); } + } +} + +impl Sub<&Poly> for &Poly { + type Output = Poly; + fn sub(self, p: &Poly) -> Poly { + let mut q = self.clone(); + q -= p; + q + } +} + +impl AddAssign<&Poly> for Poly { + fn add_assign(&mut self, p: &Poly) { + assert!(!self.has_lazy_coefficients && !p.has_lazy_coefficients); + debug_assert_eq!(self.ctx, p.ctx, "Incompatible contexts"); - // If the representation is NttShoup, drop the Shoup coefficients - // and switch to Ntt representation. - if self.representation == Representation::NttShoup { - self.coefficients_shoup - .as_mut() - .unwrap() - .as_slice_mut() - .unwrap() - .zeroize(); - unsafe { self.override_representation(Representation::Ntt) } + self.allow_variable_time_computations |= p.allow_variable_time_computations; + if self.allow_variable_time_computations { + izip!( + self.coefficients.outer_iter_mut(), + p.coefficients.outer_iter(), + self.ctx.q.iter() + ) + .for_each(|(mut v1, v2, qi)| unsafe { + qi.add_vec_vt(v1.as_slice_mut().unwrap(), v2.as_slice().unwrap()) + }); + } else { + izip!( + self.coefficients.outer_iter_mut(), + p.coefficients.outer_iter(), + self.ctx.q.iter() + ) + .for_each(|(mut v1, v2, qi)| { + qi.add_vec(v1.as_slice_mut().unwrap(), v2.as_slice().unwrap()) + }); } + } +} + +impl Add<&Poly> for &Poly { + type Output = Poly; + fn add(self, p: &Poly) -> Poly { + let mut q = self.clone(); + q += p; + q + } +} + +impl Add for Poly { + type Output = Poly; + fn add(self, mut p: Poly) -> Poly { + p += &self; + p + } +} + +impl SubAssign<&Poly> for Poly { + fn sub_assign(&mut self, p: &Poly) { + assert!(!self.has_lazy_coefficients && !p.has_lazy_coefficients); + debug_assert_eq!(self.ctx, p.ctx, "Incompatible contexts"); self.allow_variable_time_computations |= p.allow_variable_time_computations; if self.allow_variable_time_computations { @@ -147,116 +162,140 @@ impl SubAssign<&Poly> for Poly { } } -impl Sub<&Poly> for &Poly { - type Output = Poly; - fn sub(self, p: &Poly) -> Poly { +impl Sub<&Poly> for &Poly { + type Output = Poly; + fn sub(self, p: &Poly) -> Poly { let mut q = self.clone(); q -= p; q } } -impl MulAssign<&Poly> for Poly { - #[expect(clippy::panic, reason = "panic indicates violated internal invariant")] - fn mul_assign(&mut self, p: &Poly) { +impl MulAssign<&Poly> for Poly { + fn mul_assign(&mut self, p: &Poly) { assert!(!p.has_lazy_coefficients); - assert_ne!( - self.representation, - Representation::NttShoup, - "Cannot multiply to a polynomial in NttShoup representation" + assert!( + !self.has_lazy_coefficients, + "Cannot multiply lazy coefficients by an Ntt polynomial" ); - if self.has_lazy_coefficients && self.representation == Representation::Ntt { - assert!( - p.representation == Representation::NttShoup, - "Can only multiply a polynomial with lazy coefficients by an NttShoup representation." - ); - } else { - assert_eq!( - self.representation, - Representation::Ntt, - "Multiplication requires an Ntt representation." - ); - } debug_assert_eq!(self.ctx, p.ctx, "Incompatible contexts"); self.allow_variable_time_computations |= p.allow_variable_time_computations; - match p.representation { - Representation::Ntt => { - if self.allow_variable_time_computations { - unsafe { - izip!( - self.coefficients.outer_iter_mut(), - p.coefficients.outer_iter(), - self.ctx.q.iter() - ) - .for_each(|(mut v1, v2, qi)| { - qi.mul_vec_vt(v1.as_slice_mut().unwrap(), v2.as_slice().unwrap()); - }); - } - } else { - izip!( - self.coefficients.outer_iter_mut(), - p.coefficients.outer_iter(), - self.ctx.q.iter() - ) - .for_each(|(mut v1, v2, qi)| { - qi.mul_vec(v1.as_slice_mut().unwrap(), v2.as_slice().unwrap()) - }); - } - } - Representation::NttShoup => { - if self.allow_variable_time_computations { - izip!( - self.coefficients.outer_iter_mut(), - p.coefficients.outer_iter(), - p.coefficients_shoup.as_ref().unwrap().outer_iter(), - self.ctx.q.iter() - ) - .for_each(|(mut v1, v2, v2_shoup, qi)| unsafe { - qi.mul_shoup_vec_vt( - v1.as_slice_mut().unwrap(), - v2.as_slice().unwrap(), - v2_shoup.as_slice().unwrap(), - ) - }); - } else { - izip!( - self.coefficients.outer_iter_mut(), - p.coefficients.outer_iter(), - p.coefficients_shoup.as_ref().unwrap().outer_iter(), - self.ctx.q.iter() - ) - .for_each(|(mut v1, v2, v2_shoup, qi)| { - qi.mul_shoup_vec( - v1.as_slice_mut().unwrap(), - v2.as_slice().unwrap(), - v2_shoup.as_slice().unwrap(), - ) - }); - } - self.has_lazy_coefficients = false - } - Representation::PowerBasis => { - panic!("Multiplication requires a multipliand in Ntt or NttShoup representation.") + if self.allow_variable_time_computations { + unsafe { + izip!( + self.coefficients.outer_iter_mut(), + p.coefficients.outer_iter(), + self.ctx.q.iter() + ) + .for_each(|(mut v1, v2, qi)| { + qi.mul_vec_vt(v1.as_slice_mut().unwrap(), v2.as_slice().unwrap()); + }); } + } else { + izip!( + self.coefficients.outer_iter_mut(), + p.coefficients.outer_iter(), + self.ctx.q.iter() + ) + .for_each(|(mut v1, v2, qi)| { + qi.mul_vec(v1.as_slice_mut().unwrap(), v2.as_slice().unwrap()) + }); } } } -impl MulAssign<&BigUint> for Poly { - fn mul_assign(&mut self, p: &BigUint) { - // If the representation is NttShoup, drop the Shoup coefficients - // and switch to Ntt representation. - if self.representation == Representation::NttShoup { - self.coefficients_shoup - .as_mut() - .unwrap() - .as_slice_mut() - .unwrap() - .zeroize(); - unsafe { self.override_representation(Representation::Ntt) } +impl MulAssign<&Poly> for Poly { + fn mul_assign(&mut self, p: &Poly) { + assert!(!p.has_lazy_coefficients); + debug_assert_eq!(self.ctx, p.ctx, "Incompatible contexts"); + self.allow_variable_time_computations |= p.allow_variable_time_computations; + + if self.allow_variable_time_computations { + izip!( + self.coefficients.outer_iter_mut(), + p.coefficients.outer_iter(), + p.coefficients_shoup.as_ref().unwrap().outer_iter(), + self.ctx.q.iter() + ) + .for_each(|(mut v1, v2, v2_shoup, qi)| unsafe { + qi.mul_shoup_vec_vt( + v1.as_slice_mut().unwrap(), + v2.as_slice().unwrap(), + v2_shoup.as_slice().unwrap(), + ) + }); + } else { + izip!( + self.coefficients.outer_iter_mut(), + p.coefficients.outer_iter(), + p.coefficients_shoup.as_ref().unwrap().outer_iter(), + self.ctx.q.iter() + ) + .for_each(|(mut v1, v2, v2_shoup, qi)| { + qi.mul_shoup_vec( + v1.as_slice_mut().unwrap(), + v2.as_slice().unwrap(), + v2_shoup.as_slice().unwrap(), + ) + }); } + self.has_lazy_coefficients = false; + } +} +impl Mul<&Poly> for &Poly { + type Output = Poly; + fn mul(self, p: &Poly) -> Poly { + let mut q = self.clone(); + q *= p; + q + } +} + +impl Mul<&Poly> for &Poly { + type Output = Poly; + fn mul(self, p: &Poly) -> Poly { + let mut q = self.clone(); + q *= p; + q + } +} + +impl Mul<&BigUint> for &Poly { + type Output = Poly; + fn mul(self, p: &BigUint) -> Poly { + let mut q = self.clone(); + q *= p; + q + } +} + +impl Mul<&BigUint> for &Poly { + type Output = Poly; + fn mul(self, p: &BigUint) -> Poly { + let mut q = self.clone(); + q *= p; + q + } +} + +impl Mul<&Poly> for &BigUint { + type Output = Poly; + fn mul(self, p: &Poly) -> Poly { + p * self + } +} + +impl Mul<&Poly> for &BigUint { + type Output = Poly; + fn mul(self, p: &Poly) -> Poly { + p * self + } +} + +impl MulAssign<&BigUint> for Poly { + fn mul_assign(&mut self, p: &BigUint) { // Project the scalar into its CRT representation (reduced modulo each prime) let scalar_crt = self.ctx.rns.project(p); @@ -284,53 +323,57 @@ impl MulAssign<&BigUint> for Poly { } } -impl Mul<&Poly> for &Poly { - type Output = Poly; - fn mul(self, p: &Poly) -> Poly { - // if self is in NttShoup representation, let's copy `p` instead - if self.representation == Representation::NttShoup { - let mut q = p.clone(); - q *= self; - q +impl MulAssign<&BigUint> for Poly { + fn mul_assign(&mut self, p: &BigUint) { + let scalar_crt = self.ctx.rns.project(p); + + if self.allow_variable_time_computations { + unsafe { + izip!( + self.coefficients.outer_iter_mut(), + scalar_crt.iter(), + self.ctx.q.iter() + ) + .for_each(|(mut v1, scalar_qi, qi)| { + qi.scalar_mul_vec_vt(v1.as_slice_mut().unwrap(), *scalar_qi) + }); + } } else { - let mut q = self.clone(); - q *= p; - q + izip!( + self.coefficients.outer_iter_mut(), + scalar_crt.iter(), + self.ctx.q.iter() + ) + .for_each(|(mut v1, scalar_qi, qi)| { + qi.scalar_mul_vec(v1.as_slice_mut().unwrap(), *scalar_qi) + }); } } } -impl Mul<&BigUint> for &Poly { - type Output = Poly; - fn mul(self, p: &BigUint) -> Poly { - let mut q = self.clone(); - q *= p; - q - } -} +impl Neg for &Poly { + type Output = Poly; -impl Mul<&Poly> for &BigUint { - type Output = Poly; - fn mul(self, p: &Poly) -> Poly { - p * self + fn neg(self) -> Poly { + assert!(!self.has_lazy_coefficients); + let mut out = self.clone(); + if self.allow_variable_time_computations { + izip!(out.coefficients.outer_iter_mut(), out.ctx.q.iter()) + .for_each(|(mut v1, qi)| unsafe { qi.neg_vec_vt(v1.as_slice_mut().unwrap()) }); + } else { + izip!(out.coefficients.outer_iter_mut(), out.ctx.q.iter()) + .for_each(|(mut v1, qi)| qi.neg_vec(v1.as_slice_mut().unwrap())); + } + out } } -impl Neg for &Poly { - type Output = Poly; +impl Neg for &Poly { + type Output = Poly; - fn neg(self) -> Poly { + fn neg(self) -> Poly { assert!(!self.has_lazy_coefficients); let mut out = self.clone(); - if out.representation == Representation::NttShoup { - out.coefficients_shoup - .as_mut() - .unwrap() - .as_slice_mut() - .unwrap() - .zeroize(); - unsafe { out.override_representation(Representation::Ntt) } - } if self.allow_variable_time_computations { izip!(out.coefficients.outer_iter_mut(), out.ctx.q.iter()) .for_each(|(mut v1, qi)| unsafe { qi.neg_vec_vt(v1.as_slice_mut().unwrap()) }); @@ -342,20 +385,27 @@ impl Neg for &Poly { } } -impl Neg for Poly { - type Output = Poly; +impl Neg for Poly { + type Output = Poly; - fn neg(mut self) -> Poly { + fn neg(mut self) -> Poly { assert!(!self.has_lazy_coefficients); - if self.representation == Representation::NttShoup { - self.coefficients_shoup - .as_mut() - .unwrap() - .as_slice_mut() - .unwrap() - .zeroize(); - unsafe { self.override_representation(Representation::Ntt) } + if self.allow_variable_time_computations { + izip!(self.coefficients.outer_iter_mut(), self.ctx.q.iter()) + .for_each(|(mut v1, qi)| unsafe { qi.neg_vec_vt(v1.as_slice_mut().unwrap()) }); + } else { + izip!(self.coefficients.outer_iter_mut(), self.ctx.q.iter()) + .for_each(|(mut v1, qi)| qi.neg_vec(v1.as_slice_mut().unwrap())); } + self + } +} + +impl Neg for Poly { + type Output = Poly; + + fn neg(mut self) -> Poly { + assert!(!self.has_lazy_coefficients); if self.allow_variable_time_computations { izip!(self.coefficients.outer_iter_mut(), self.ctx.q.iter()) .for_each(|(mut v1, qi)| unsafe { qi.neg_vec_vt(v1.as_slice_mut().unwrap()) }); @@ -393,24 +443,14 @@ fn fma(out: &mut [u128], x: &[u64], y: &[u64]) { } } -/// Compute the dot product between two iterators of polynomials. -/// Returna an error if the iterator counts are 0, or if any of the polynomial -/// is not in Ntt or NttShoup representation. -pub fn dot_product<'a, 'b, I, J>(p: I, q: J) -> Result +/// Compute the dot product between two iterators of polynomials in Ntt +/// representation. Returns an error if either iterator is empty. +pub fn dot_product<'a, 'b, I, J>(p: I, q: J) -> Result> where - I: Iterator + Clone, - J: Iterator + Clone, + I: Iterator> + Clone, + J: Iterator> + Clone, { - debug_assert!( - !p.clone() - .any(|pi| pi.representation == Representation::PowerBasis) - ); - debug_assert!( - !q.clone() - .any(|qi| qi.representation == Representation::PowerBasis) - ); - - let count = min(p.clone().count(), q.clone().count()); + let count = std::cmp::min(p.clone().count(), q.clone().count()); if count == 0 { return Err(Error::Default("At least one iterator is empty".to_string())); } @@ -501,11 +541,11 @@ where Ok(Poly { ctx: p_first.ctx.clone(), - representation: Representation::Ntt, allow_variable_time_computations: p_first.allow_variable_time_computations, coefficients: coeffs, coefficients_shoup: None, has_lazy_coefficients: false, + _repr: std::marker::PhantomData, }) } @@ -517,7 +557,7 @@ mod tests { use super::dot_product; use crate::{ - rq::{Context, Poly, Representation}, + rq::{Context, Ntt, NttShoup, Poly, PowerBasis}, zq::Modulus, }; use std::{error::Error, sync::Arc}; @@ -533,26 +573,24 @@ mod tests { let ctx = Arc::new(Context::new(&[*modulus], n)?); let m = Modulus::new(*modulus).unwrap(); - let p = Poly::random(&ctx, Representation::PowerBasis, &mut rng); - let q = Poly::random(&ctx, Representation::PowerBasis, &mut rng); + let p = Poly::::random(&ctx, &mut rng); + let q = Poly::::random(&ctx, &mut rng); let r = &p + &q; - assert_eq!(r.representation, Representation::PowerBasis); let mut a = Vec::::try_from(&p).unwrap(); m.add_vec(&mut a, &Vec::::try_from(&q).unwrap()); assert_eq!(Vec::::try_from(&r).unwrap(), a); - let p = Poly::random(&ctx, Representation::Ntt, &mut rng); - let q = Poly::random(&ctx, Representation::Ntt, &mut rng); + let p = Poly::::random(&ctx, &mut rng); + let q = Poly::::random(&ctx, &mut rng); let r = &p + &q; - assert_eq!(r.representation, Representation::Ntt); let mut a = Vec::::try_from(&p).unwrap(); m.add_vec(&mut a, &Vec::::try_from(&q).unwrap()); assert_eq!(Vec::::try_from(&r).unwrap(), a); } let ctx = Arc::new(Context::new(MODULI, 16)?); - let p = Poly::random(&ctx, Representation::PowerBasis, &mut rng); - let q = Poly::random(&ctx, Representation::PowerBasis, &mut rng); + let p = Poly::::random(&ctx, &mut rng); + let q = Poly::::random(&ctx, &mut rng); let mut a = Vec::::try_from(&p).unwrap(); let b = Vec::::try_from(&q).unwrap(); for i in 0..MODULI.len() { @@ -560,7 +598,6 @@ mod tests { m.add_vec(&mut a[i * 16..(i + 1) * 16], &b[i * 16..(i + 1) * 16]) } let r = &p + &q; - assert_eq!(r.representation, Representation::PowerBasis); assert_eq!(Vec::::try_from(&r).unwrap(), a); } Ok(()) @@ -574,26 +611,24 @@ mod tests { let ctx = Arc::new(Context::new(&[*modulus], 16)?); let m = Modulus::new(*modulus).unwrap(); - let p = Poly::random(&ctx, Representation::PowerBasis, &mut rng); - let q = Poly::random(&ctx, Representation::PowerBasis, &mut rng); + let p = Poly::::random(&ctx, &mut rng); + let q = Poly::::random(&ctx, &mut rng); let r = &p - &q; - assert_eq!(r.representation, Representation::PowerBasis); let mut a = Vec::::try_from(&p).unwrap(); m.sub_vec(&mut a, &Vec::::try_from(&q).unwrap()); assert_eq!(Vec::::try_from(&r).unwrap(), a); - let p = Poly::random(&ctx, Representation::Ntt, &mut rng); - let q = Poly::random(&ctx, Representation::Ntt, &mut rng); + let p = Poly::::random(&ctx, &mut rng); + let q = Poly::::random(&ctx, &mut rng); let r = &p - &q; - assert_eq!(r.representation, Representation::Ntt); let mut a = Vec::::try_from(&p).unwrap(); m.sub_vec(&mut a, &Vec::::try_from(&q).unwrap()); assert_eq!(Vec::::try_from(&r).unwrap(), a); } let ctx = Arc::new(Context::new(MODULI, 16)?); - let p = Poly::random(&ctx, Representation::PowerBasis, &mut rng); - let q = Poly::random(&ctx, Representation::PowerBasis, &mut rng); + let p = Poly::::random(&ctx, &mut rng); + let q = Poly::::random(&ctx, &mut rng); let mut a = Vec::::try_from(&p).unwrap(); let b = Vec::::try_from(&q).unwrap(); for i in 0..MODULI.len() { @@ -601,7 +636,6 @@ mod tests { m.sub_vec(&mut a[i * 16..(i + 1) * 16], &b[i * 16..(i + 1) * 16]) } let r = &p - &q; - assert_eq!(r.representation, Representation::PowerBasis); assert_eq!(Vec::::try_from(&r).unwrap(), a); } Ok(()) @@ -615,18 +649,17 @@ mod tests { let ctx = Arc::new(Context::new(&[*modulus], 16)?); let m = Modulus::new(*modulus).unwrap(); - let p = Poly::random(&ctx, Representation::Ntt, &mut rng); - let q = Poly::random(&ctx, Representation::Ntt, &mut rng); + let p = Poly::::random(&ctx, &mut rng); + let q = Poly::::random(&ctx, &mut rng); let r = &p * &q; - assert_eq!(r.representation, Representation::Ntt); let mut a = Vec::::try_from(&p).unwrap(); m.mul_vec(&mut a, &Vec::::try_from(&q).unwrap()); assert_eq!(Vec::::try_from(&r).unwrap(), a); } let ctx = Arc::new(Context::new(MODULI, 16)?); - let p = Poly::random(&ctx, Representation::Ntt, &mut rng); - let q = Poly::random(&ctx, Representation::Ntt, &mut rng); + let p = Poly::::random(&ctx, &mut rng); + let q = Poly::::random(&ctx, &mut rng); let mut a = Vec::::try_from(&p).unwrap(); let b = Vec::::try_from(&q).unwrap(); for i in 0..MODULI.len() { @@ -634,7 +667,6 @@ mod tests { m.mul_vec(&mut a[i * 16..(i + 1) * 16], &b[i * 16..(i + 1) * 16]) } let r = &p * &q; - assert_eq!(r.representation, Representation::Ntt); assert_eq!(Vec::::try_from(&r).unwrap(), a); } Ok(()) @@ -648,18 +680,17 @@ mod tests { let ctx = Arc::new(Context::new(&[*modulus], 16)?); let m = Modulus::new(*modulus).unwrap(); - let p = Poly::random(&ctx, Representation::Ntt, &mut rng); - let q = Poly::random(&ctx, Representation::NttShoup, &mut rng); + let p = Poly::::random(&ctx, &mut rng); + let q = Poly::::random(&ctx, &mut rng); let r = &p * &q; - assert_eq!(r.representation, Representation::Ntt); let mut a = Vec::::try_from(&p).unwrap(); m.mul_vec(&mut a, &Vec::::try_from(&q).unwrap()); assert_eq!(Vec::::try_from(&r).unwrap(), a); } let ctx = Arc::new(Context::new(MODULI, 16)?); - let p = Poly::random(&ctx, Representation::Ntt, &mut rng); - let q = Poly::random(&ctx, Representation::NttShoup, &mut rng); + let p = Poly::::random(&ctx, &mut rng); + let q = Poly::::random(&ctx, &mut rng); let mut a = Vec::::try_from(&p).unwrap(); let b = Vec::::try_from(&q).unwrap(); for i in 0..MODULI.len() { @@ -667,7 +698,6 @@ mod tests { m.mul_vec(&mut a[i * 16..(i + 1) * 16], &b[i * 16..(i + 1) * 16]) } let r = &p * &q; - assert_eq!(r.representation, Representation::Ntt); assert_eq!(Vec::::try_from(&r).unwrap(), a); } Ok(()) @@ -681,34 +711,30 @@ mod tests { let ctx = Arc::new(Context::new(&[*modulus], 16)?); let m = Modulus::new(*modulus).unwrap(); - let p = Poly::random(&ctx, Representation::PowerBasis, &mut rng); + let p = Poly::::random(&ctx, &mut rng); let r = -&p; - assert_eq!(r.representation, Representation::PowerBasis); let mut a = Vec::::try_from(&p).unwrap(); m.neg_vec(&mut a); assert_eq!(Vec::::try_from(&r).unwrap(), a); - let p = Poly::random(&ctx, Representation::Ntt, &mut rng); + let p = Poly::::random(&ctx, &mut rng); let r = -&p; - assert_eq!(r.representation, Representation::Ntt); let mut a = Vec::::try_from(&p).unwrap(); m.neg_vec(&mut a); assert_eq!(Vec::::try_from(&r).unwrap(), a); } let ctx = Arc::new(Context::new(MODULI, 16)?); - let p = Poly::random(&ctx, Representation::PowerBasis, &mut rng); + let p = Poly::::random(&ctx, &mut rng); let mut a = Vec::::try_from(&p).unwrap(); for i in 0..MODULI.len() { let m = Modulus::new(MODULI[i]).unwrap(); m.neg_vec(&mut a[i * 16..(i + 1) * 16]) } let r = -&p; - assert_eq!(r.representation, Representation::PowerBasis); assert_eq!(Vec::::try_from(&r).unwrap(), a); let r = -p; - assert_eq!(r.representation, Representation::PowerBasis); assert_eq!(Vec::::try_from(&r).unwrap(), a); } Ok(()) @@ -723,14 +749,14 @@ mod tests { for len in 1..50 { let p = (0..len) - .map(|_| Poly::random(&ctx, Representation::Ntt, &mut rng)) + .map(|_| Poly::::random(&ctx, &mut rng)) .collect_vec(); let q = (0..len) - .map(|_| Poly::random(&ctx, Representation::Ntt, &mut rng)) + .map(|_| Poly::::random(&ctx, &mut rng)) .collect_vec(); let r = dot_product(p.iter(), q.iter())?; - let mut expected = Poly::zero(&ctx, Representation::Ntt); + let mut expected = Poly::::zero(&ctx); izip!(&p, &q).for_each(|(pi, qi)| expected += &(pi * qi)); assert_eq!(r, expected); } @@ -739,14 +765,14 @@ mod tests { let ctx = Arc::new(Context::new(MODULI, 16)?); for len in 1..50 { let p = (0..len) - .map(|_| Poly::random(&ctx, Representation::Ntt, &mut rng)) + .map(|_| Poly::::random(&ctx, &mut rng)) .collect_vec(); let q = (0..len) - .map(|_| Poly::random(&ctx, Representation::Ntt, &mut rng)) + .map(|_| Poly::::random(&ctx, &mut rng)) .collect_vec(); let r = dot_product(p.iter(), q.iter())?; - let mut expected = Poly::zero(&ctx, Representation::Ntt); + let mut expected = Poly::::zero(&ctx); izip!(&p, &q).for_each(|(pi, qi)| expected += &(pi * qi)); assert_eq!(r, expected); } @@ -763,19 +789,17 @@ mod tests { let m = Modulus::new(*modulus).unwrap(); // Test with PowerBasis representation - let p = Poly::random(&ctx, Representation::PowerBasis, &mut rng); + let p = Poly::::random(&ctx, &mut rng); let scalar = BigUint::from(42u64); let r = &p * &scalar; - assert_eq!(r.representation, Representation::PowerBasis); let mut expected = Vec::::try_from(&p).unwrap(); m.scalar_mul_vec(&mut expected, 42u64); assert_eq!(Vec::::try_from(&r).unwrap(), expected); // Test with NTT representation - let p = Poly::random(&ctx, Representation::Ntt, &mut rng); + let p = Poly::::random(&ctx, &mut rng); let scalar = BigUint::from(123u64); let r = &p * &scalar; - assert_eq!(r.representation, Representation::Ntt); let mut expected = Vec::::try_from(&p).unwrap(); m.scalar_mul_vec(&mut expected, 123u64); assert_eq!(Vec::::try_from(&r).unwrap(), expected); @@ -784,10 +808,9 @@ mod tests { let ctx = Arc::new(Context::new(MODULI, 16)?); // Test with PowerBasis representation - let p = Poly::random(&ctx, Representation::PowerBasis, &mut rng); + let p = Poly::::random(&ctx, &mut rng); let scalar = BigUint::from(99u64); let r = &p * &scalar; - assert_eq!(r.representation, Representation::PowerBasis); let mut expected = Vec::::try_from(&p).unwrap(); for i in 0..MODULI.len() { let m = Modulus::new(MODULI[i]).unwrap(); @@ -796,10 +819,9 @@ mod tests { assert_eq!(Vec::::try_from(&r).unwrap(), expected); // Test with NTT representation - let p = Poly::random(&ctx, Representation::Ntt, &mut rng); + let p = Poly::::random(&ctx, &mut rng); let scalar = BigUint::from(77u64); let r = &p * &scalar; - assert_eq!(r.representation, Representation::Ntt); let mut expected = Vec::::try_from(&p).unwrap(); for i in 0..MODULI.len() { let m = Modulus::new(MODULI[i]).unwrap(); @@ -818,9 +840,8 @@ mod tests { let q_prod = MODULI.iter().fold(BigUint::from(1u64), |acc, &m| acc * m); let large_scalar = &q_prod + BigUint::from(12345u64); - let p = Poly::random(&ctx, Representation::Ntt, &mut rng()); + let p = Poly::::random(&ctx, &mut rng()); let r = &p * &large_scalar; - assert_eq!(r.representation, Representation::Ntt); // Verify by computing the expected result manually for each modulus let mut expected = Vec::::try_from(&p).unwrap(); @@ -837,17 +858,15 @@ mod tests { #[test] fn mul_scalar_ntt_shoup() { - use num_bigint::BigUint; - let ctx = Arc::new(Context::new(MODULI, 16).unwrap()); - let mut p = Poly::random(&ctx, Representation::NttShoup, &mut rng()); - let mut p_ntt = p.clone(); - p_ntt.change_representation(Representation::Ntt); + let p = Poly::::random(&ctx, &mut rng()); + let mut p_ntt = p.clone().into_ntt(); let scalar = BigUint::from(42u64); - p *= &scalar; + let mut p_ntt_scaled = p_ntt.clone(); + p_ntt_scaled *= &scalar; - assert_eq!(p.representation, Representation::Ntt); - assert_eq!(&p_ntt * &scalar, p); + p_ntt *= &scalar; + assert_eq!(p_ntt_scaled, p_ntt); } } diff --git a/crates/fhe-math/src/rq/scaler.rs b/crates/fhe-math/src/rq/scaler.rs index 86652fcb..2a23bd79 100644 --- a/crates/fhe-math/src/rq/scaler.rs +++ b/crates/fhe-math/src/rq/scaler.rs @@ -2,7 +2,7 @@ //! Polynomial scaler. -use super::{Context, Poly, Representation}; +use super::{Context, Poly, Representation, ScaleRepresentation}; use crate::{ Error, Result, rns::{RnsScaler, ScalingFactor}, @@ -10,6 +10,7 @@ use crate::{ use itertools::izip; use ndarray::{Array2, Axis, s}; use std::borrow::Cow; +use std::marker::PhantomData; use std::sync::Arc; /// Context extender. @@ -48,17 +49,12 @@ impl Scaler { } /// Scale a polynomial - pub(crate) fn scale(&self, p: &Poly) -> Result { + pub(crate) fn scale(&self, p: &Poly) -> Result> { if p.ctx.as_ref() != self.from.as_ref() { Err(Error::Default( "The input polynomial does not have the correct context".to_string(), )) } else { - let mut representation = p.representation; - if representation == Representation::NttShoup { - representation = Representation::Ntt; - } - let mut new_coefficients = Array2::::zeros((self.to.q.len(), self.to.degree)); if self.number_common_moduli > 0 { @@ -68,7 +64,7 @@ impl Scaler { } if self.number_common_moduli < self.to.q.len() { - let needs_transform = p.representation != Representation::PowerBasis; + let needs_transform = R::REPRESENTATION != Representation::PowerBasis; let p_coefficients_powerbasis: Cow<'_, Array2> = if needs_transform { let mut owned = p.coefficients.clone(); // Backward NTT @@ -120,11 +116,11 @@ impl Scaler { Ok(Poly { ctx: self.to.clone(), - representation, allow_variable_time_computations: p.allow_variable_time_computations, coefficients: new_coefficients, coefficients_shoup: None, has_lazy_coefficients: false, + _repr: PhantomData, }) } } @@ -133,7 +129,7 @@ impl Scaler { #[cfg(test)] mod tests { use super::{Scaler, ScalingFactor}; - use crate::rq::{Context, Poly, Representation}; + use crate::rq::{Context, Ntt, Poly, PowerBasis}; use itertools::Itertools; use num_bigint::BigUint; use num_traits::{One, Zero}; @@ -168,7 +164,7 @@ mod tests { let scaler = Scaler::new(&from, &to, ScalingFactor::new(&n, &d))?; for _ in 0..ntests { - let mut poly = Poly::random(&from, Representation::PowerBasis, &mut rng); + let poly = Poly::::random(&from, &mut rng); let poly_biguint = Vec::::from(&poly); let scaled_poly = scaler.scale(&poly)?; @@ -195,10 +191,9 @@ mod tests { .collect_vec(); assert_eq!(expected, scaled_biguint); - poly.change_representation(Representation::Ntt); - let mut scaled_poly = scaler.scale(&poly)?; - scaled_poly.change_representation(Representation::PowerBasis); - let scaled_biguint = Vec::::from(&scaled_poly); + let poly_ntt: Poly = poly.clone().into_ntt(); + let scaled_poly = scaler.scale(&poly_ntt)?; + let scaled_biguint = Vec::::from(&scaled_poly.to_power_basis()); assert_eq!(expected, scaled_biguint); } } diff --git a/crates/fhe-math/src/rq/serialize.rs b/crates/fhe-math/src/rq/serialize.rs index 444110ab..45f14604 100644 --- a/crates/fhe-math/src/rq/serialize.rs +++ b/crates/fhe-math/src/rq/serialize.rs @@ -2,24 +2,27 @@ use std::sync::Arc; -use super::{Context, Poly, traits::TryConvertFrom}; +use super::{Context, Poly, RepresentationTag, traits::TryConvertFrom}; use crate::{Error, proto::rq::Rq}; use fhe_traits::{DeserializeWithContext, Serialize}; use prost::Message; -impl Serialize for Poly { +impl Serialize for Poly { fn to_bytes(&self) -> Vec { Rq::from(self).encode_to_vec() } } -impl DeserializeWithContext for Poly { +impl DeserializeWithContext for Poly +where + Poly: for<'a> TryConvertFrom<&'a Rq>, +{ type Error = Error; type Context = Context; fn from_bytes(bytes: &[u8], ctx: &Arc) -> Result { let rq: Rq = Message::decode(bytes).map_err(|e| Error::Serialization(e.to_string()))?; - Poly::try_convert_from(&rq, ctx, false, None) + Poly::try_convert_from(&rq, ctx, false) } } @@ -31,7 +34,7 @@ mod tests { use rand::rng; use crate::proto::rq::{Representation as RepresentationProto, Rq}; - use crate::rq::{Context, Poly, Representation, traits::TryConvertFrom}; + use crate::rq::{Context, Ntt, NttShoup, Poly, PowerBasis, traits::TryConvertFrom}; use prost::Message; const Q: &[u64; 3] = &[ @@ -46,21 +49,21 @@ mod tests { for qi in Q { let ctx = Arc::new(Context::new(&[*qi], 16)?); - let p = Poly::random(&ctx, Representation::PowerBasis, &mut rng); - assert_eq!(p, Poly::from_bytes(&p.to_bytes(), &ctx)?); - let p = Poly::random(&ctx, Representation::Ntt, &mut rng); - assert_eq!(p, Poly::from_bytes(&p.to_bytes(), &ctx)?); - let p = Poly::random(&ctx, Representation::NttShoup, &mut rng); - assert_eq!(p, Poly::from_bytes(&p.to_bytes(), &ctx)?); + let p = Poly::::random(&ctx, &mut rng); + assert_eq!(p, Poly::::from_bytes(&p.to_bytes(), &ctx)?); + let p = Poly::::random(&ctx, &mut rng); + assert_eq!(p, Poly::::from_bytes(&p.to_bytes(), &ctx)?); + let p = Poly::::random(&ctx, &mut rng); + assert_eq!(p, Poly::::from_bytes(&p.to_bytes(), &ctx)?); } let ctx = Arc::new(Context::new(Q, 16)?); - let p = Poly::random(&ctx, Representation::PowerBasis, &mut rng); - assert_eq!(p, Poly::from_bytes(&p.to_bytes(), &ctx)?); - let p = Poly::random(&ctx, Representation::Ntt, &mut rng); - assert_eq!(p, Poly::from_bytes(&p.to_bytes(), &ctx)?); - let p = Poly::random(&ctx, Representation::NttShoup, &mut rng); - assert_eq!(p, Poly::from_bytes(&p.to_bytes(), &ctx)?); + let p = Poly::::random(&ctx, &mut rng); + assert_eq!(p, Poly::::from_bytes(&p.to_bytes(), &ctx)?); + let p = Poly::::random(&ctx, &mut rng); + assert_eq!(p, Poly::::from_bytes(&p.to_bytes(), &ctx)?); + let p = Poly::::random(&ctx, &mut rng); + assert_eq!(p, Poly::::from_bytes(&p.to_bytes(), &ctx)?); Ok(()) } @@ -69,11 +72,11 @@ mod tests { fn deserialize_unknown_representation_rejected() -> Result<(), Box> { let mut rng = rng(); let ctx = Arc::new(Context::new(Q, 16)?); - let p = Poly::random(&ctx, Representation::PowerBasis, &mut rng); + let p = Poly::::random(&ctx, &mut rng); let mut proto = Rq::from(&p); proto.representation = RepresentationProto::Unknown as i32; let bytes = proto.encode_to_vec(); - let err = Poly::from_bytes(&bytes, &ctx).unwrap_err(); + let err = Poly::::from_bytes(&bytes, &ctx).unwrap_err(); assert!(err.to_string().contains("Unknown representation")); Ok(()) } @@ -82,11 +85,11 @@ mod tests { fn deserialize_invalid_degree_rejected() -> Result<(), Box> { let mut rng = rng(); let ctx = Arc::new(Context::new(Q, 16)?); - let p = Poly::random(&ctx, Representation::PowerBasis, &mut rng); + let p = Poly::::random(&ctx, &mut rng); let mut proto = Rq::from(&p); proto.degree = 6; let bytes = proto.encode_to_vec(); - let err = Poly::from_bytes(&bytes, &ctx).unwrap_err(); + let err = Poly::::from_bytes(&bytes, &ctx).unwrap_err(); assert!(err.to_string().contains("Invalid degree")); Ok(()) } @@ -95,11 +98,11 @@ mod tests { fn deserialize_invalid_coefficients_rejected() -> Result<(), Box> { let mut rng = rng(); let ctx = Arc::new(Context::new(Q, 16)?); - let p = Poly::random(&ctx, Representation::PowerBasis, &mut rng); + let p = Poly::::random(&ctx, &mut rng); let mut proto = Rq::from(&p); proto.coefficients.clear(); let bytes = proto.encode_to_vec(); - let err = Poly::from_bytes(&bytes, &ctx).unwrap_err(); + let err = Poly::::from_bytes(&bytes, &ctx).unwrap_err(); assert!(err.to_string().contains("Invalid coefficients")); Ok(()) } @@ -108,10 +111,9 @@ mod tests { fn deserialize_representation_mismatch_rejected() -> Result<(), Box> { let mut rng = rng(); let ctx = Arc::new(Context::new(Q, 16)?); - let p = Poly::random(&ctx, Representation::Ntt, &mut rng); + let p = Poly::::random(&ctx, &mut rng); let proto = Rq::from(&p); - let err = - Poly::try_convert_from(&proto, &ctx, false, Representation::PowerBasis).unwrap_err(); + let err = Poly::::try_convert_from(&proto, &ctx, false).unwrap_err(); assert!( err.to_string() .contains("representation asked for does not match") @@ -123,11 +125,11 @@ mod tests { fn deserialize_variable_time_flag_propagates() -> Result<(), Box> { let mut rng = rng(); let ctx = Arc::new(Context::new(Q, 16)?); - let p = Poly::random(&ctx, Representation::PowerBasis, &mut rng); + let p = Poly::::random(&ctx, &mut rng); let mut proto = Rq::from(&p); proto.allow_variable_time = true; let bytes = proto.encode_to_vec(); - let decoded = Poly::from_bytes(&bytes, &ctx)?; + let decoded = Poly::::from_bytes(&bytes, &ctx)?; assert!(decoded.allow_variable_time_computations); Ok(()) } diff --git a/crates/fhe-math/src/rq/switcher.rs b/crates/fhe-math/src/rq/switcher.rs index 10aac049..b77e24ec 100644 --- a/crates/fhe-math/src/rq/switcher.rs +++ b/crates/fhe-math/src/rq/switcher.rs @@ -2,7 +2,7 @@ //! Polynomial modulus switcher. -use super::{Context, Poly, scaler::Scaler}; +use super::{Context, Poly, ScaleRepresentation, scaler::Scaler}; use crate::{Result, rns::ScalingFactor}; use std::sync::Arc; @@ -21,7 +21,7 @@ impl Switcher { } /// Switch a polynomial. - pub(crate) fn switch(&self, p: &Poly) -> Result { + pub(crate) fn switch(&self, p: &Poly) -> Result> { self.scaler.scale(p) } } diff --git a/crates/fhe-math/src/rq/traits.rs b/crates/fhe-math/src/rq/traits.rs index 2ca35ca0..1ad7f3bd 100644 --- a/crates/fhe-math/src/rq/traits.rs +++ b/crates/fhe-math/src/rq/traits.rs @@ -2,7 +2,7 @@ //! Traits associated with polynomials. -use super::{Context, Representation}; +use super::Context; use crate::Result; use std::sync::Arc; @@ -16,17 +16,8 @@ pub trait TryConvertFrom where Self: Sized, { - /// Attempt to convert the `value` into a polynomial with a specific context - /// and under a specific representation. The representation may optional and - /// be specified as `None`; this is useful for example when converting from - /// a value that encodes the representation (e.g., serialization, protobuf, - /// etc.). - fn try_convert_from( - value: T, - ctx: &Arc, - variable_time: bool, - representation: R, - ) -> Result - where - R: Into>; + /// Attempt to convert the `value` into a polynomial with a specific + /// context. Callers select the target representation via the `Self` + /// type. + fn try_convert_from(value: T, ctx: &Arc, variable_time: bool) -> Result; } diff --git a/crates/fhe-math/tests/ntt_shoup_ops.rs b/crates/fhe-math/tests/ntt_shoup_ops.rs index f513d44c..79b5c2ca 100644 --- a/crates/fhe-math/tests/ntt_shoup_ops.rs +++ b/crates/fhe-math/tests/ntt_shoup_ops.rs @@ -1,6 +1,6 @@ //! Unit test for polynomial Shoup operations. -use fhe_math::rq::{Context, Poly, Representation}; +use fhe_math::rq::{Context, Ntt, NttShoup, Poly}; use rand::rng; use std::sync::Arc; @@ -10,61 +10,17 @@ fn test_ntt_shoup_add_sub_neg() { let ctx = Arc::new(Context::new(&[modulus], 16).unwrap()); let mut rng = rng(); - let p_ntt = Poly::random(&ctx, Representation::Ntt, &mut rng); - let p_shoup = Poly::random(&ctx, Representation::NttShoup, &mut rng); + let p_ntt = Poly::::random(&ctx, &mut rng); + let p_shoup = Poly::::random(&ctx, &mut rng); + let p_shoup_as_ntt = p_shoup.clone().into_ntt(); - // Helper to get Ntt version of a poly - let to_ntt = |p: &Poly| -> Poly { - let mut q = p.clone(); - if *q.representation() == Representation::NttShoup { - unsafe { q.override_representation(Representation::Ntt) }; - // Note: override_representation handles shoup cleanup if needed or just switch - // enum But strict conversion: - q.change_representation(Representation::Ntt); - } - q - }; + // Add/Sub/Neg on Ntt after explicit conversion. + let sum = &p_ntt + &p_shoup_as_ntt; + assert_eq!(sum, &p_ntt + &p_shoup_as_ntt); - let p_shoup_as_ntt = to_ntt(&p_shoup); + let diff = &p_ntt - &p_shoup_as_ntt; + assert_eq!(diff, &p_ntt - &p_shoup_as_ntt); - // Case 1: Ntt + NttShoup - let sum1 = &p_ntt + &p_shoup; - assert_eq!(sum1.representation(), &Representation::Ntt); - assert_eq!(sum1, &p_ntt + &p_shoup_as_ntt); - - // Case 2: NttShoup + Ntt - let sum2 = &p_shoup + &p_ntt; - assert_eq!(sum2.representation(), &Representation::Ntt); - assert_eq!(sum2, &p_shoup_as_ntt + &p_ntt); - - // Case 3: NttShoup + NttShoup (should work if we relaxed AddAssign correctly) - // Wait, AddAssign on LHS=NttShoup is forbidden. - // But Add(&NttShoup, &NttShoup) -> converts LHS to Ntt, then adds RHS - // (NttShoup). So LHS becomes Ntt. Ntt += NttShoup. This should work now. - let p_shoup2 = Poly::random(&ctx, Representation::NttShoup, &mut rng); - let p_shoup2_as_ntt = to_ntt(&p_shoup2); - - let sum3 = &p_shoup + &p_shoup2; - assert_eq!(sum3.representation(), &Representation::Ntt); - assert_eq!(sum3, &p_shoup_as_ntt + &p_shoup2_as_ntt); - - // Case 4: Neg NttShoup - let neg = -&p_shoup; - assert_eq!(neg.representation(), &Representation::Ntt); + let neg = -&p_shoup_as_ntt; assert_eq!(neg, -&p_shoup_as_ntt); - - // Case 5: Sub Ntt - NttShoup - let diff1 = &p_ntt - &p_shoup; - assert_eq!(diff1.representation(), &Representation::Ntt); - assert_eq!(diff1, &p_ntt - &p_shoup_as_ntt); - - // Case 6: Sub NttShoup - Ntt - let diff2 = &p_shoup - &p_ntt; - assert_eq!(diff2.representation(), &Representation::Ntt); - assert_eq!(diff2, &p_shoup_as_ntt - &p_ntt); - - // Case 7: Sub NttShoup - NttShoup - let diff3 = &p_shoup - &p_shoup2; - assert_eq!(diff3.representation(), &Representation::Ntt); - assert_eq!(diff3, &p_shoup_as_ntt - &p_shoup2_as_ntt); } diff --git a/crates/fhe/examples/sealpir.rs b/crates/fhe/examples/sealpir.rs index 8ec01feb..4b2a72a9 100644 --- a/crates/fhe/examples/sealpir.rs +++ b/crates/fhe/examples/sealpir.rs @@ -18,7 +18,7 @@ mod util; use clap::Parser; use fhe::bfv; -use fhe_math::rq::{Poly, Representation, traits::TryConvertFrom}; +use fhe_math::rq::{Ntt, Poly, traits::TryConvertFrom}; use fhe_traits::{ DeserializeParametrized, FheDecoder, FheDecrypter, FheEncoder, FheEncoderVariableTime, FheEncrypter, Serialize, @@ -255,8 +255,8 @@ fn main() -> Result<(), Box> { let ctx = params.context_at_level(2)?; let ct = bfv::Ciphertext::new( vec![ - Poly::try_convert_from(poly0, ctx, true, Representation::Ntt)?, - Poly::try_convert_from(poly1, ctx, true, Representation::Ntt)?, + Poly::::try_convert_from(poly0, ctx, true)?, + Poly::::try_convert_from(poly1, ctx, true)?, ], ¶ms, )?; diff --git a/crates/fhe/src/bfv/ciphertext.rs b/crates/fhe/src/bfv/ciphertext.rs index 3987f6a8..e7861f21 100644 --- a/crates/fhe/src/bfv/ciphertext.rs +++ b/crates/fhe/src/bfv/ciphertext.rs @@ -3,7 +3,7 @@ use crate::bfv::{parameters::BfvParameters, traits::TryConvertFrom}; use crate::proto::bfv::Ciphertext as CiphertextProto; use crate::{Error, Result, SerializationError}; -use fhe_math::rq::{Poly, Representation}; +use fhe_math::rq::{Ntt, Poly}; use fhe_traits::{ DeserializeParametrized, DeserializeWithContext, FheCiphertext, FheParametrized, Serialize, }; @@ -23,14 +23,14 @@ pub struct Ciphertext { pub(crate) seed: Option<::Seed>, /// The ciphertext elements. - pub(crate) c: Vec, + pub(crate) c: Vec>, /// The ciphertext level pub(crate) level: usize, } impl Deref for Ciphertext { - type Target = [Poly]; + type Target = [Poly]; fn deref(&self) -> &Self::Target { &self.c @@ -48,7 +48,7 @@ impl Ciphertext { /// A ciphertext must contain at least two polynomials, and all polynomials /// must be in Ntt representation and with the same context. #[expect(clippy::expect_used, reason = "bounds are validated before use")] - pub fn new(c: Vec, par: &Arc) -> Result { + pub fn new(c: Vec>, par: &Arc) -> Result { if c.len() < 2 { return Err(Error::TooFewValues { actual: c.len(), @@ -62,14 +62,8 @@ impl Ciphertext { .ctx(); let level = par.level_of_context(ctx)?; - // Check that all polynomials have the expected representation and context. + // Check that all polynomials have the expected context. for ci in c.iter() { - if ci.representation() != &Representation::Ntt { - return Err(Error::MathError(fhe_math::Error::IncorrectRepresentation( - *ci.representation(), - Representation::Ntt, - ))); - } if ci.ctx() != ctx { return Err(Error::MathError(fhe_math::Error::InvalidContext)); } @@ -93,9 +87,9 @@ impl Ciphertext { if self.level < self.max_switchable_level() { self.seed = None; for ci in self.c.iter_mut() { - ci.change_representation(Representation::PowerBasis); - ci.switch_down()?; - ci.change_representation(Representation::Ntt); + let mut pb = ci.clone().into_power_basis(); + pb.switch_down()?; + *ci = pb.into_ntt(); } self.level += 1 } @@ -220,7 +214,7 @@ impl TryConvertFrom<&CiphertextProto> for Ciphertext { let mut c = Vec::with_capacity(value.c.len() + 1); for cip in &value.c { - c.push(Poly::from_bytes(cip, ctx)?) + c.push(Poly::::from_bytes(cip, ctx)?) } let mut seed = None; @@ -233,7 +227,7 @@ impl TryConvertFrom<&CiphertextProto> for Ciphertext { )) })?; seed = Some(try_seed); - let mut c1 = Poly::random_from_seed(ctx, Representation::Ntt, try_seed); + let mut c1 = Poly::::random_from_seed(ctx, try_seed); unsafe { c1.allow_variable_time_computations() } c.push(c1) } diff --git a/crates/fhe/src/bfv/context/cipher_plain_context.rs b/crates/fhe/src/bfv/context/cipher_plain_context.rs index 681ffd59..5ea993fd 100644 --- a/crates/fhe/src/bfv/context/cipher_plain_context.rs +++ b/crates/fhe/src/bfv/context/cipher_plain_context.rs @@ -1,4 +1,4 @@ -use fhe_math::rq::{Context, Poly, scaler::Scaler}; +use fhe_math::rq::{Context, NttShoup, Poly, scaler::Scaler}; use num_bigint::BigUint; use std::sync::Arc; @@ -9,7 +9,7 @@ use std::sync::Arc; #[derive(Debug, Clone, PartialEq, Eq)] pub struct CipherPlainContext { /// Scaling polynomial for the plaintext - pub(crate) delta: Poly, + pub(crate) delta: Poly, /// Q modulo the plaintext modulus pub(crate) q_mod_t: BigUint, @@ -33,7 +33,7 @@ impl CipherPlainContext { pub(crate) fn new_arc( plaintext_context: &Arc, ciphertext_context: &Arc, - delta: Poly, + delta: Poly, q_mod_t: BigUint, plain_threshold: BigUint, scaler: Scaler, diff --git a/crates/fhe/src/bfv/keys/evaluation_key.rs b/crates/fhe/src/bfv/keys/evaluation_key.rs index 25741e33..2ff0dffd 100644 --- a/crates/fhe/src/bfv/keys/evaluation_key.rs +++ b/crates/fhe/src/bfv/keys/evaluation_key.rs @@ -3,7 +3,7 @@ use crate::bfv::{BfvParameters, Ciphertext, SecretKey, keys::GaloisKey, traits::TryConvertFrom}; use crate::proto::bfv::{EvaluationKey as EvaluationKeyProto, GaloisKey as GaloisKeyProto}; use crate::{Error, Result}; -use fhe_math::rq::{Poly, Representation, traits::TryConvertFrom as TryConvertFromPoly}; +use fhe_math::rq::{NttShoup, Poly, PowerBasis, traits::TryConvertFrom as TryConvertFromPoly}; use fhe_math::zq::Modulus; use fhe_traits::{DeserializeParametrized, FheParametrized, Serialize}; use prost::Message; @@ -33,7 +33,7 @@ pub struct EvaluationKey { rot_to_gk_exponent: HashMap, /// Monomials used in expansion - monomials: Vec, + monomials: Vec>, } impl EvaluationKey { @@ -356,15 +356,10 @@ impl EvaluationKeyBuilder { for l in 0..self.sk.par.degree().ilog2() { let mut monomial = vec![0i64; self.sk.par.degree()]; monomial[self.sk.par.degree() - (1 << l)] = -1; - let mut monomial = Poly::try_convert_from( - &monomial, - ciphertext_ctx, - true, - Representation::PowerBasis, - )?; + let mut monomial = + Poly::::try_convert_from(&monomial, ciphertext_ctx, true)?; unsafe { monomial.allow_variable_time_computations() } - monomial.change_representation(Representation::NttShoup); - ek.monomials.push(monomial); + ek.monomials.push(monomial.into_ntt_shoup()); } for index in indices { @@ -419,15 +414,10 @@ impl TryConvertFrom<&EvaluationKeyProto> for EvaluationKey { for l in 0..par.degree().ilog2() { let mut monomial = vec![0i64; par.degree()]; monomial[par.degree() - (1 << l)] = -1; - let mut monomial = Poly::try_convert_from( - &monomial, - ciphertext_ctx, - true, - Representation::PowerBasis, - )?; + let mut monomial = + Poly::::try_convert_from(&monomial, ciphertext_ctx, true)?; unsafe { monomial.allow_variable_time_computations() } - monomial.change_representation(Representation::NttShoup); - monomials.push(monomial); + monomials.push(monomial.into_ntt_shoup()); } Ok(EvaluationKey { diff --git a/crates/fhe/src/bfv/keys/galois_key.rs b/crates/fhe/src/bfv/keys/galois_key.rs index b0e24df7..4bed4390 100644 --- a/crates/fhe/src/bfv/keys/galois_key.rs +++ b/crates/fhe/src/bfv/keys/galois_key.rs @@ -5,7 +5,7 @@ use crate::bfv::{BfvParameters, Ciphertext, SecretKey, traits::TryConvertFrom}; use crate::proto::bfv::{GaloisKey as GaloisKeyProto, KeySwitchingKey as KeySwitchingKeyProto}; use crate::{Error, Result}; use fhe_math::rq::{ - Poly, Representation, SubstitutionExponent, switcher::Switcher, + Ntt, Poly, PowerBasis, SubstitutionExponent, switcher::Switcher, traits::TryConvertFrom as TryConvertFromPoly, }; use rand::{CryptoRng, RngCore}; @@ -37,15 +37,13 @@ impl GaloisKey { SubstitutionExponent::new(ctx_ciphertext, exponent).map_err(Error::MathError)?; let switcher_up = Switcher::new(ctx_ciphertext, ctx_galois_key)?; - let s = Zeroizing::new(Poly::try_convert_from( + let s = Zeroizing::new(Poly::::try_convert_from( sk.coeffs.as_ref(), ctx_ciphertext, false, - Representation::PowerBasis, )?); let s_sub = Zeroizing::new(s.substitute(&ciphertext_exponent)?); - let mut s_sub_switched_up = Zeroizing::new(s_sub.switch(&switcher_up)?); - s_sub_switched_up.change_representation(Representation::PowerBasis); + let s_sub_switched_up = Zeroizing::new(s_sub.switch(&switcher_up)?); let ksk = KeySwitchingKey::new( sk, @@ -66,17 +64,16 @@ impl GaloisKey { // assert_eq!(ct.par, self.ksk.par); assert_eq!(ct.len(), 2); - let mut c2 = ct[1].substitute(&self.element)?; - c2.change_representation(Representation::PowerBasis); + let c2 = ct[1].substitute(&self.element)?.into_power_basis(); let (mut c0, mut c1) = self.ksk.key_switch(&c2)?; if c0.ctx() != ct[0].ctx() { - c0.change_representation(Representation::PowerBasis); - c1.change_representation(Representation::PowerBasis); - c0.switch_down_to(ct[0].ctx())?; - c1.switch_down_to(ct[1].ctx())?; - c0.change_representation(Representation::Ntt); - c1.change_representation(Representation::Ntt); + let mut c0_pb = c0.into_power_basis(); + let mut c1_pb = c1.into_power_basis(); + c0_pb.switch_down_to(ct[0].ctx())?; + c1_pb.switch_down_to(ct[1].ctx())?; + c0 = c0_pb.into_ntt(); + c1 = c1_pb.into_ntt(); } c0 += &ct[0].substitute(&self.element)?; @@ -95,8 +92,8 @@ impl GaloisKey { if out.len() != 2 || out[0].ctx() != ct[0].ctx() || out[1].ctx() != ct[1].ctx() { out.c = vec![ - Poly::zero(ct[0].ctx(), Representation::Ntt), - Poly::zero(ct[1].ctx(), Representation::Ntt), + Poly::::zero(ct[0].ctx()), + Poly::::zero(ct[1].ctx()), ]; } out.par = ct.par.clone(); @@ -110,17 +107,16 @@ impl GaloisKey { out0.zeroize(); out1.zeroize(); - let mut c2 = ct[1].substitute(&self.element)?; - c2.change_representation(Representation::PowerBasis); + let c2 = ct[1].substitute(&self.element)?.into_power_basis(); self.ksk.key_switch_assign(&c2, out0, out1)?; if out0.ctx() != ct[0].ctx() { - out0.change_representation(Representation::PowerBasis); - out1.change_representation(Representation::PowerBasis); - out0.switch_down_to(ct[0].ctx())?; - out1.switch_down_to(ct[1].ctx())?; - out0.change_representation(Representation::Ntt); - out1.change_representation(Representation::Ntt); + let mut out0_pb = out0.clone().into_power_basis(); + let mut out1_pb = out1.clone().into_power_basis(); + out0_pb.switch_down_to(ct[0].ctx())?; + out1_pb.switch_down_to(ct[1].ctx())?; + *out0 = out0_pb.into_ntt(); + *out1 = out1_pb.into_ntt(); } *out0 += &ct[0].substitute(&self.element)?; diff --git a/crates/fhe/src/bfv/keys/key_switching_key.rs b/crates/fhe/src/bfv/keys/key_switching_key.rs index 83a289c6..c4ef1abc 100644 --- a/crates/fhe/src/bfv/keys/key_switching_key.rs +++ b/crates/fhe/src/bfv/keys/key_switching_key.rs @@ -7,7 +7,7 @@ use fhe_math::rq::Context; use fhe_math::rq::traits::TryConvertFrom; use fhe_math::{ rns::RnsContext, - rq::{Poly, Representation}, + rq::{Ntt, NttShoup, Poly, PowerBasis}, }; use fhe_traits::{DeserializeWithContext, Serialize}; use itertools::{Itertools, izip}; @@ -27,10 +27,10 @@ pub struct KeySwitchingKey { pub(crate) seed: Option<::Seed>, /// The key switching elements c0. - pub(crate) c0: Box<[Poly]>, + pub(crate) c0: Box<[Poly]>, /// The key switching elements c1. - pub(crate) c1: Box<[Poly]>, + pub(crate) c1: Box<[Poly]>, /// The level and context of the polynomials that will be key switched. pub(crate) ciphertext_level: usize, @@ -49,7 +49,7 @@ impl KeySwitchingKey { /// `from`. pub fn new( sk: &SecretKey, - from: &Poly, + from: &Poly, ciphertext_level: usize, ksk_level: usize, rng: &mut R, @@ -109,13 +109,13 @@ impl KeySwitchingKey { ctx: &Arc, seed: ::Seed, size: usize, - ) -> Vec { + ) -> Vec> { let mut c1 = Vec::with_capacity(size); let mut rng = ChaCha8Rng::from_seed(seed); (0..size).for_each(|_| { let mut seed_i = ::Seed::default(); rng.fill(&mut seed_i); - let mut a = Poly::random_from_seed(ctx, Representation::NttShoup, seed_i); + let mut a = Poly::::random_from_seed(ctx, seed_i); unsafe { a.allow_variable_time_computations() } c1.push(a); }); @@ -125,43 +125,35 @@ impl KeySwitchingKey { /// Generate the c0's from the c1's and the secret key fn generate_c0( sk: &SecretKey, - from: &Poly, - c1: &[Poly], + from: &Poly, + c1: &[Poly], rng: &mut R, - ) -> Result> { + ) -> Result>> { if c1.is_empty() { return Err(Error::DefaultError("Empty number of c1's".to_string())); } - if from.representation() != &Representation::PowerBasis { - return Err(Error::DefaultError( - "Unexpected representation for from".to_string(), - )); - } let size = c1.len(); - let mut s = Zeroizing::new(Poly::try_convert_from( - sk.coeffs.as_ref(), - c1[0].ctx(), - false, - Representation::PowerBasis, - )?); - s.change_representation(Representation::Ntt); + let s = Zeroizing::new( + Poly::::try_convert_from(sk.coeffs.as_ref(), c1[0].ctx(), false)? + .into_ntt(), + ); let rns = RnsContext::new(&sk.par.moduli[..size])?; let c0 = c1 .iter() .enumerate() .map(|(i, c1i)| { - let mut a_s = Zeroizing::new(c1i.clone()); + let mut a_s = Zeroizing::new(c1i.clone().into_ntt()); a_s.disallow_variable_time_computations(); - a_s.change_representation(Representation::Ntt); *a_s.as_mut() *= s.as_ref(); - a_s.change_representation(Representation::PowerBasis); + let ctx = a_s.ctx().clone(); + let a_s_inner = std::mem::replace(a_s.as_mut(), Poly::::zero(&ctx)); + let a_s_pb = a_s_inner.into_power_basis(); - let mut b = - Poly::small(a_s.ctx(), Representation::PowerBasis, sk.par.variance, rng)?; - b -= &a_s; + let mut b = Poly::::small(a_s_pb.ctx(), sk.par.variance, rng)?; + b -= &a_s_pb; let gi = rns.get_garner(i).unwrap(); let g_i_from = Zeroizing::new(gi * from); @@ -169,10 +161,9 @@ impl KeySwitchingKey { // It is now safe to enable variable time computations. unsafe { b.allow_variable_time_computations() } - b.change_representation(Representation::NttShoup); - Ok(b) + Ok(b.into_ntt_shoup()) }) - .collect::>>()?; + .collect::>>>()?; Ok(c0) } @@ -180,58 +171,47 @@ impl KeySwitchingKey { /// Generate the c0's from the c1's and the secret key fn generate_c0_decomposition( sk: &SecretKey, - from: &Poly, - c1: &[Poly], + from: &Poly, + c1: &[Poly], rng: &mut R, log_base: usize, - ) -> Result> { + ) -> Result>> { if c1.is_empty() { return Err(Error::DefaultError("Empty number of c1's".to_string())); } - - if from.representation() != &Representation::PowerBasis { - return Err(Error::DefaultError( - "Unexpected representation for from".to_string(), - )); - } - - let mut s = Zeroizing::new(Poly::try_convert_from( - sk.coeffs.as_ref(), - c1[0].ctx(), - false, - Representation::PowerBasis, - )?); - s.change_representation(Representation::Ntt); + let s = Zeroizing::new( + Poly::::try_convert_from(sk.coeffs.as_ref(), c1[0].ctx(), false)? + .into_ntt(), + ); let c0 = c1 .iter() .enumerate() .map(|(i, c1i)| { - let mut a_s = Zeroizing::new(c1i.clone()); + let mut a_s = Zeroizing::new(c1i.clone().into_ntt()); a_s.disallow_variable_time_computations(); - a_s.change_representation(Representation::Ntt); *a_s.as_mut() *= s.as_ref(); - a_s.change_representation(Representation::PowerBasis); + let ctx = a_s.ctx().clone(); + let a_s_inner = std::mem::replace(a_s.as_mut(), Poly::::zero(&ctx)); + let a_s_pb = a_s_inner.into_power_basis(); - let mut b = - Poly::small(a_s.ctx(), Representation::PowerBasis, sk.par.variance, rng)?; - b -= &a_s; + let mut b = Poly::::small(a_s_pb.ctx(), sk.par.variance, rng)?; + b -= &a_s_pb; let power = BigUint::from(1u64 << (i * log_base)); b += &(from * &power); // It is now safe to enable variable time computations. unsafe { b.allow_variable_time_computations() } - b.change_representation(Representation::NttShoup); - Ok(b) + Ok(b.into_ntt_shoup()) }) - .collect::>>()?; + .collect::>>>()?; Ok(c0) } /// Key switch a polynomial. - pub fn key_switch(&self, p: &Poly) -> Result<(Poly, Poly)> { + pub fn key_switch(&self, p: &Poly) -> Result<(Poly, Poly)> { if self.log_base != 0 { return self.key_switch_decomposition(p); } @@ -241,18 +221,14 @@ impl KeySwitchingKey { "The input polynomial does not have the correct context.".to_string(), )); } - if p.representation() != &Representation::PowerBasis { - return Err(Error::DefaultError("Incorrect representation".to_string())); - } - - let mut c0 = Poly::zero(&self.ctx_ksk, Representation::Ntt); - let mut c1 = Poly::zero(&self.ctx_ksk, Representation::Ntt); + let mut c0 = Poly::::zero(&self.ctx_ksk); + let mut c1 = Poly::::zero(&self.ctx_ksk); let p_coefficients = p.coefficients(); for (c2_i_coefficients, c0_i, c1_i) in izip!(p_coefficients.outer_iter(), self.c0.iter(), self.c1.iter()) { let mut c2_i = unsafe { - Poly::create_constant_ntt_polynomial_with_lazy_coefficients_and_variable_time( + Poly::::create_constant_ntt_polynomial_with_lazy_coefficients_and_variable_time( c2_i_coefficients.as_slice().unwrap(), &self.ctx_ksk, ) @@ -265,7 +241,12 @@ impl KeySwitchingKey { } /// Key switch a polynomial, writing the result in-place. - pub fn key_switch_assign(&self, p: &Poly, c0: &mut Poly, c1: &mut Poly) -> Result<()> { + pub fn key_switch_assign( + &self, + p: &Poly, + c0: &mut Poly, + c1: &mut Poly, + ) -> Result<()> { if self.log_base != 0 { let (k0, k1) = self.key_switch_decomposition(p)?; *c0 = k0; @@ -278,20 +259,14 @@ impl KeySwitchingKey { "The input polynomial does not have the correct context.".to_string(), )); } - if p.representation() != &Representation::PowerBasis { - return Err(Error::DefaultError("Incorrect representation".to_string())); - } - - if c0.ctx().as_ref() != self.ctx_ksk.as_ref() || c0.representation() != &Representation::Ntt - { - *c0 = Poly::zero(&self.ctx_ksk, Representation::Ntt); + if c0.ctx().as_ref() != self.ctx_ksk.as_ref() { + *c0 = Poly::::zero(&self.ctx_ksk); } else { c0.zeroize(); } - if c1.ctx().as_ref() != self.ctx_ksk.as_ref() || c1.representation() != &Representation::Ntt - { - *c1 = Poly::zero(&self.ctx_ksk, Representation::Ntt); + if c1.ctx().as_ref() != self.ctx_ksk.as_ref() { + *c1 = Poly::::zero(&self.ctx_ksk); } else { c1.zeroize(); } @@ -301,7 +276,7 @@ impl KeySwitchingKey { izip!(p_coefficients.outer_iter(), self.c0.iter(), self.c1.iter()) { let mut c2_i = unsafe { - Poly::create_constant_ntt_polynomial_with_lazy_coefficients_and_variable_time( + Poly::::create_constant_ntt_polynomial_with_lazy_coefficients_and_variable_time( c2_i_coefficients.as_slice().unwrap(), &self.ctx_ksk, ) @@ -314,15 +289,12 @@ impl KeySwitchingKey { } /// Key switch a polynomial. - fn key_switch_decomposition(&self, p: &Poly) -> Result<(Poly, Poly)> { + fn key_switch_decomposition(&self, p: &Poly) -> Result<(Poly, Poly)> { if p.ctx().as_ref() != self.ctx_ciphertext.as_ref() { return Err(Error::DefaultError( "The input polynomial does not have the correct context.".to_string(), )); } - if p.representation() != &Representation::PowerBasis { - return Err(Error::DefaultError("Incorrect representation".to_string())); - } let log_modulus = p .ctx() @@ -340,11 +312,11 @@ impl KeySwitchingKey { coefficients.iter_mut().for_each(|c| *c >>= self.log_base); }); - let mut c0 = Poly::zero(&self.ctx_ksk, Representation::Ntt); - let mut c1 = Poly::zero(&self.ctx_ksk, Representation::Ntt); + let mut c0 = Poly::::zero(&self.ctx_ksk); + let mut c1 = Poly::::zero(&self.ctx_ksk); for (c2_i_coefficients, c0_i, c1_i) in izip!(c2i.iter(), self.c0.iter(), self.c1.iter()) { let mut c2_i = unsafe { - Poly::create_constant_ntt_polynomial_with_lazy_coefficients_and_variable_time( + Poly::::create_constant_ntt_polynomial_with_lazy_coefficients_and_variable_time( c2_i_coefficients.as_slice(), &self.ctx_ksk, ) @@ -429,15 +401,15 @@ impl BfvTryConvertFrom<&KeySwitchingKeyProto> for KeySwitchingKey { value .c1 .iter() - .map(|c1i| Poly::from_bytes(c1i, &ctx_ksk).map_err(Error::MathError)) - .collect::>>()? + .map(|c1i| Poly::::from_bytes(c1i, &ctx_ksk).map_err(Error::MathError)) + .collect::>>>()? }; let c0 = value .c0 .iter() - .map(|c0i| Poly::from_bytes(c0i, &ctx_ksk).map_err(Error::MathError)) - .collect::>>()?; + .map(|c0i| Poly::::from_bytes(c0i, &ctx_ksk).map_err(Error::MathError)) + .collect::>>>()?; Ok(Self { par: par.clone(), @@ -461,7 +433,7 @@ mod tests { use crate::proto::bfv::KeySwitchingKey as KeySwitchingKeyProto; use fhe_math::{ rns::RnsContext, - rq::{Poly, Representation, traits::TryConvertFrom as TryConvertFromPoly}, + rq::{Ntt, Poly, PowerBasis, traits::TryConvertFrom as TryConvertFromPoly}, }; use num_bigint::BigUint; use rand::rng; @@ -476,7 +448,7 @@ mod tests { ] { let sk = SecretKey::random(¶ms, &mut rng); let ctx = params.context_at_level(0)?; - let p = Poly::small(ctx, Representation::PowerBasis, 10, &mut rng)?; + let p = Poly::::small(ctx, 10, &mut rng)?; let ksk = KeySwitchingKey::new(&sk, &p, 0, 0, &mut rng); assert!(ksk.is_ok()); } @@ -493,7 +465,7 @@ mod tests { let level = params.moduli().len() - 1; let sk = SecretKey::random(¶ms, &mut rng); let ctx = params.context_at_level(level)?; - let p = Poly::small(ctx, Representation::PowerBasis, 10, &mut rng)?; + let p = Poly::::small(ctx, 10, &mut rng)?; let ksk = KeySwitchingKey::new(&sk, &p, level, level, &mut rng); assert!(ksk.is_ok()); } @@ -507,27 +479,20 @@ mod tests { for _ in 0..100 { let sk = SecretKey::random(¶ms, &mut rng); let ctx = params.context_at_level(0)?; - let mut p = Poly::small(ctx, Representation::PowerBasis, 10, &mut rng)?; + let p = Poly::::small(ctx, 10, &mut rng)?; let ksk = KeySwitchingKey::new(&sk, &p, 0, 0, &mut rng)?; - let mut s = Poly::try_convert_from( - sk.coeffs.as_ref(), - ctx, - false, - Representation::PowerBasis, - ) - .map_err(crate::Error::MathError)?; - s.change_representation(Representation::Ntt); + let s = Poly::::try_convert_from(sk.coeffs.as_ref(), ctx, false) + .map_err(crate::Error::MathError)? + .into_ntt(); - let mut input = Poly::random(ctx, Representation::PowerBasis, &mut rng); + let input = Poly::::random(ctx, &mut rng); let (c0, c1) = ksk.key_switch(&input)?; - let mut c2 = &c0 + &(&c1 * &s); - c2.change_representation(Representation::PowerBasis); + let c2 = (&c0 + &(&c1 * &s)).into_power_basis(); - input.change_representation(Representation::Ntt); - p.change_representation(Representation::Ntt); - let mut c3 = &input * &p; - c3.change_representation(Representation::PowerBasis); + let input_ntt = input.into_ntt(); + let p_ntt = p.into_ntt(); + let c3 = (&input_ntt * &p_ntt).into_power_basis(); let rns = RnsContext::new(¶ms.moduli)?; Vec::::from(&(&c2 - &c3)).iter().for_each(|b| { @@ -545,14 +510,14 @@ mod tests { let params = BfvParameters::default_arc(6, 16); let sk = SecretKey::random(¶ms, &mut rng); let ctx = params.context_at_level(0)?; - let p = Poly::small(ctx, Representation::PowerBasis, 10, &mut rng)?; + let p = Poly::::small(ctx, 10, &mut rng)?; let ksk = KeySwitchingKey::new(&sk, &p, 0, 0, &mut rng)?; - let input = Poly::random(ctx, Representation::PowerBasis, &mut rng); + let input = Poly::::random(ctx, &mut rng); let (c0, c1) = ksk.key_switch(&input)?; - let mut a0 = Poly::zero(&ksk.ctx_ksk, Representation::Ntt); - let mut a1 = Poly::zero(&ksk.ctx_ksk, Representation::Ntt); + let mut a0 = Poly::::zero(&ksk.ctx_ksk); + let mut a1 = Poly::::zero(&ksk.ctx_ksk); ksk.key_switch_assign(&input, &mut a0, &mut a1)?; assert_eq!(c0, a0); @@ -568,27 +533,20 @@ mod tests { for _ in 0..100 { let sk = SecretKey::random(¶ms, &mut rng); let ctx = params.context_at_level(5)?; - let mut p = Poly::small(ctx, Representation::PowerBasis, 10, &mut rng)?; + let p = Poly::::small(ctx, 10, &mut rng)?; let ksk = KeySwitchingKey::new(&sk, &p, 5, 5, &mut rng)?; - let mut s = Poly::try_convert_from( - sk.coeffs.as_ref(), - ctx, - false, - Representation::PowerBasis, - ) - .map_err(crate::Error::MathError)?; - s.change_representation(Representation::Ntt); + let s = Poly::::try_convert_from(sk.coeffs.as_ref(), ctx, false) + .map_err(crate::Error::MathError)? + .into_ntt(); - let mut input = Poly::random(ctx, Representation::PowerBasis, &mut rng); + let input = Poly::::random(ctx, &mut rng); let (c0, c1) = ksk.key_switch(&input)?; - let mut c2 = &c0 + &(&c1 * &s); - c2.change_representation(Representation::PowerBasis); + let c2 = (&c0 + &(&c1 * &s)).into_power_basis(); - input.change_representation(Representation::Ntt); - p.change_representation(Representation::Ntt); - let mut c3 = &input * &p; - c3.change_representation(Representation::PowerBasis); + let input_ntt = input.into_ntt(); + let p_ntt = p.into_ntt(); + let c3 = (&input_ntt * &p_ntt).into_power_basis(); let rns = RnsContext::new(ctx.moduli())?; Vec::::from(&(&c2 - &c3)).iter().for_each(|b| { @@ -611,7 +569,7 @@ mod tests { ] { let sk = SecretKey::random(¶ms, &mut rng); let ctx = params.context_at_level(0)?; - let p = Poly::small(ctx, Representation::PowerBasis, 10, &mut rng)?; + let p = Poly::::small(ctx, 10, &mut rng)?; let ksk = KeySwitchingKey::new(&sk, &p, 0, 0, &mut rng)?; let ksk_proto = KeySwitchingKeyProto::from(&ksk); assert_eq!(ksk, KeySwitchingKey::try_convert_from(&ksk_proto, ¶ms)?); diff --git a/crates/fhe/src/bfv/keys/public_key.rs b/crates/fhe/src/bfv/keys/public_key.rs index 893151e3..5d913f8f 100644 --- a/crates/fhe/src/bfv/keys/public_key.rs +++ b/crates/fhe/src/bfv/keys/public_key.rs @@ -4,7 +4,7 @@ use crate::bfv::traits::TryConvertFrom; use crate::bfv::{BfvParameters, Ciphertext, Encoding, Plaintext}; use crate::proto::bfv::{Ciphertext as CiphertextProto, PublicKey as PublicKeyProto}; use crate::{Error, Result, SerializationError}; -use fhe_math::rq::{Poly, Representation}; +use fhe_math::rq::{Ntt, Poly}; use fhe_traits::{DeserializeParametrized, FheEncrypter, FheParametrized, Serialize}; use prost::Message; use rand::{CryptoRng, RngCore}; @@ -61,24 +61,9 @@ impl FheEncrypter for PublicKey { }; let ctx = self.par.context_at_level(ct.level)?; - let u = Zeroizing::new(Poly::small( - ctx, - Representation::Ntt, - self.par.variance, - rng, - )?); - let e1 = Zeroizing::new(Poly::small( - ctx, - Representation::Ntt, - self.par.variance, - rng, - )?); - let e2 = Zeroizing::new(Poly::small( - ctx, - Representation::Ntt, - self.par.variance, - rng, - )?); + let u = Zeroizing::new(Poly::::small(ctx, self.par.variance, rng)?); + let e1 = Zeroizing::new(Poly::::small(ctx, self.par.variance, rng)?); + let e2 = Zeroizing::new(Poly::::small(ctx, self.par.variance, rng)?); let m = Zeroizing::new(pt.to_poly()); let mut c0 = u.as_ref() * &ct[0]; diff --git a/crates/fhe/src/bfv/keys/relinearization_key.rs b/crates/fhe/src/bfv/keys/relinearization_key.rs index 7de36b32..996d6b02 100644 --- a/crates/fhe/src/bfv/keys/relinearization_key.rs +++ b/crates/fhe/src/bfv/keys/relinearization_key.rs @@ -9,7 +9,7 @@ use crate::proto::bfv::{ }; use crate::{Error, Result}; use fhe_math::rq::{ - Poly, Representation, switcher::Switcher, traits::TryConvertFrom as TryConvertFromPoly, + Ntt, Poly, PowerBasis, switcher::Switcher, traits::TryConvertFrom as TryConvertFromPoly, }; use fhe_traits::{DeserializeParametrized, FheParametrized, Serialize}; use prost::Message; @@ -55,15 +55,11 @@ impl RelinearizationKey { )); } - let mut s = Zeroizing::new(Poly::try_convert_from( - sk.coeffs.as_ref(), - ctx_ciphertext, - false, - Representation::PowerBasis, - )?); - s.change_representation(Representation::Ntt); - let mut s2 = Zeroizing::new(s.as_ref() * s.as_ref()); - s2.change_representation(Representation::PowerBasis); + let s = Zeroizing::new( + Poly::::try_convert_from(sk.coeffs.as_ref(), ctx_ciphertext, false)? + .into_ntt(), + ); + let s2 = Zeroizing::new((s.as_ref() * s.as_ref()).into_power_basis()); let switcher_up = Switcher::new(ctx_ciphertext, ctx_relin_key)?; let s2_switched_up = Zeroizing::new(s2.switch(&switcher_up)?); let ksk = KeySwitchingKey::new(sk, &s2_switched_up, ciphertext_level, key_level, rng)?; @@ -82,18 +78,16 @@ impl RelinearizationKey { "Ciphertext has incorrect level".to_string(), )) } else { - let mut c2 = ct[2].clone(); - c2.change_representation(Representation::PowerBasis); - + let c2 = ct[2].clone().into_power_basis(); let (mut c0, mut c1) = self.relinearizes_poly(&c2)?; if c0.ctx() != ct[0].ctx() { - c0.change_representation(Representation::PowerBasis); - c1.change_representation(Representation::PowerBasis); - c0.switch_down_to(ct[0].ctx())?; - c1.switch_down_to(ct[1].ctx())?; - c0.change_representation(Representation::Ntt); - c1.change_representation(Representation::Ntt); + let mut c0_pb = c0.into_power_basis(); + let mut c1_pb = c1.into_power_basis(); + c0_pb.switch_down_to(ct[0].ctx())?; + c1_pb.switch_down_to(ct[1].ctx())?; + c0 = c0_pb.into_ntt(); + c1 = c1_pb.into_ntt(); } ct[0] += &c0; @@ -104,7 +98,10 @@ impl RelinearizationKey { } /// Relinearize using polynomials. - pub(crate) fn relinearizes_poly(&self, c2: &Poly) -> Result<(Poly, Poly)> { + pub(crate) fn relinearizes_poly( + &self, + c2: &Poly, + ) -> Result<(Poly, Poly)> { self.ksk.key_switch(c2) } } @@ -157,7 +154,7 @@ mod tests { use super::RelinearizationKey; use crate::bfv::{BfvParameters, Ciphertext, Encoding, SecretKey, traits::TryConvertFrom}; use crate::proto::bfv::RelinearizationKey as RelinearizationKeyProto; - use fhe_math::rq::{Poly, Representation, traits::TryConvertFrom as TryConvertFromPoly}; + use fhe_math::rq::{Ntt, Poly, PowerBasis, traits::TryConvertFrom as TryConvertFromPoly}; use fhe_traits::{FheDecoder, FheDecrypter}; use rand::rng; use std::error::Error; @@ -171,22 +168,16 @@ mod tests { let rk = RelinearizationKey::new(&sk, &mut rng)?; let ctx = params.context_at_level(0)?; - let mut s = Poly::try_convert_from( - sk.coeffs.as_ref(), - ctx, - false, - Representation::PowerBasis, - ) - .map_err(crate::Error::MathError)?; - s.change_representation(Representation::Ntt); + let s = Poly::::try_convert_from(sk.coeffs.as_ref(), ctx, false) + .map_err(crate::Error::MathError)? + .into_ntt(); let s2 = &s * &s; // Let's generate manually an "extended" ciphertext (c0 = e - c1 * s - c2 * s^2, // c1, c2) encrypting 0. - let mut c2 = Poly::random(ctx, Representation::Ntt, &mut rng); - let c1 = Poly::random(ctx, Representation::Ntt, &mut rng); - let mut c0 = Poly::small(ctx, Representation::PowerBasis, 16, &mut rng)?; - c0.change_representation(Representation::Ntt); + let c2 = Poly::::random(ctx, &mut rng); + let c1 = Poly::::random(ctx, &mut rng); + let mut c0 = Poly::::small(ctx, 16, &mut rng)?.into_ntt(); c0 -= &(&c1 * &s); c0 -= &(&c2 * &s2); let mut ct = Ciphertext::new(vec![c0.clone(), c1.clone(), c2.clone()], ¶ms)?; @@ -196,14 +187,14 @@ mod tests { assert_eq!(ct.len(), 2); // Check that the relinearization by polynomials works the same way - c2.change_representation(Representation::PowerBasis); - let (mut c0r, mut c1r) = rk.relinearizes_poly(&c2)?; - c0r.change_representation(Representation::PowerBasis); - c0r.switch_down_to(c0.ctx())?; - c1r.change_representation(Representation::PowerBasis); - c1r.switch_down_to(c1.ctx())?; - c0r.change_representation(Representation::Ntt); - c1r.change_representation(Representation::Ntt); + let c2_pb = c2.clone().into_power_basis(); + let (c0r, c1r) = rk.relinearizes_poly(&c2_pb)?; + let mut c0r_pb = c0r.into_power_basis(); + c0r_pb.switch_down_to(c0.ctx())?; + let mut c1r_pb = c1r.into_power_basis(); + c1r_pb.switch_down_to(c1.ctx())?; + let c0r = c0r_pb.into_ntt(); + let c1r = c1r_pb.into_ntt(); assert_eq!(ct, Ciphertext::new(vec![&c0 + &c0r, &c1 + &c1r], ¶ms)?); // Print the noise and decrypt @@ -232,21 +223,16 @@ mod tests { )?; let ctx = params.context_at_level(ciphertext_level)?; - let mut s = Poly::try_convert_from( - sk.coeffs.as_ref(), - ctx, - false, - Representation::PowerBasis, - ) - .map_err(crate::Error::MathError)?; - s.change_representation(Representation::Ntt); + let s = + Poly::::try_convert_from(sk.coeffs.as_ref(), ctx, false) + .map_err(crate::Error::MathError)? + .into_ntt(); let s2 = &s * &s; // Let's generate manually an "extended" ciphertext (c0 = e - c1 * s - c2 * // s^2, c1, c2) encrypting 0. - let mut c2 = Poly::random(ctx, Representation::Ntt, &mut rng); - let c1 = Poly::random(ctx, Representation::Ntt, &mut rng); - let mut c0 = Poly::small(ctx, Representation::PowerBasis, 16, &mut rng)?; - c0.change_representation(Representation::Ntt); + let c2 = Poly::::random(ctx, &mut rng); + let c1 = Poly::::random(ctx, &mut rng); + let mut c0 = Poly::::small(ctx, 16, &mut rng)?.into_ntt(); c0 -= &(&c1 * &s); c0 -= &(&c2 * &s2); let mut ct = @@ -257,14 +243,14 @@ mod tests { assert_eq!(ct.len(), 2); // Check that the relinearization by polynomials works the same way - c2.change_representation(Representation::PowerBasis); - let (mut c0r, mut c1r) = rk.relinearizes_poly(&c2)?; - c0r.change_representation(Representation::PowerBasis); - c0r.switch_down_to(c0.ctx())?; - c1r.change_representation(Representation::PowerBasis); - c1r.switch_down_to(c1.ctx())?; - c0r.change_representation(Representation::Ntt); - c1r.change_representation(Representation::Ntt); + let c2_pb = c2.clone().into_power_basis(); + let (c0r, c1r) = rk.relinearizes_poly(&c2_pb)?; + let mut c0r_pb = c0r.into_power_basis(); + c0r_pb.switch_down_to(c0.ctx())?; + let mut c1r_pb = c1r.into_power_basis(); + c1r_pb.switch_down_to(c1.ctx())?; + let c0r = c0r_pb.into_ntt(); + let c1r = c1r_pb.into_ntt(); assert_eq!(ct, Ciphertext::new(vec![&c0 + &c0r, &c1 + &c1r], ¶ms)?); // Print the noise and decrypt diff --git a/crates/fhe/src/bfv/keys/secret_key.rs b/crates/fhe/src/bfv/keys/secret_key.rs index f491a312..3c7ff088 100644 --- a/crates/fhe/src/bfv/keys/secret_key.rs +++ b/crates/fhe/src/bfv/keys/secret_key.rs @@ -6,7 +6,7 @@ use crate::bfv::{ use crate::proto::bfv::SecretKey as SecretKeyProto; use crate::{Error, Result, SerializationError}; use fhe_math::{ - rq::{Poly, Representation, traits::TryConvertFrom}, + rq::{Ntt, Poly, PowerBasis, traits::TryConvertFrom}, zq::Modulus, }; use fhe_traits::{DeserializeParametrized, FheDecrypter, FheEncrypter, FheParametrized, Serialize}; @@ -65,13 +65,10 @@ impl SecretKey { let m = Zeroizing::new(plaintext.to_poly()); // Let's create a secret key with the ciphertext context - let mut s = Zeroizing::new(Poly::try_convert_from( - self.coeffs.as_ref(), - ct[0].ctx(), - false, - Representation::PowerBasis, - )?); - s.change_representation(Representation::Ntt); + let s = Zeroizing::new( + Poly::::try_convert_from(self.coeffs.as_ref(), ct[0].ctx(), false)? + .into_ntt(), + ); let mut si = s.clone(); // Let's disable variable time computations @@ -86,11 +83,13 @@ impl SecretKey { *si.as_mut() *= s.as_ref(); } *c.as_mut() -= &m; - c.change_representation(Representation::PowerBasis); + let ctx = c.ctx().clone(); + let c_inner = std::mem::replace(c.as_mut(), Poly::::zero(&ctx)); + let c = c_inner.into_power_basis(); let ciphertext_modulus = ct[0].ctx().modulus(); let mut noise = 0usize; - for coeff in Vec::::from(c.as_ref()) { + for coeff in Vec::::from(&c) { noise = std::cmp::max( noise, std::cmp::min(coeff.bits(), (ciphertext_modulus - &coeff).bits()) as usize, @@ -102,30 +101,24 @@ impl SecretKey { pub(crate) fn encrypt_poly( &self, - p: &Poly, + p: &Poly, rng: &mut R, ) -> Result { - assert_eq!(p.representation(), &Representation::Ntt); - let level = self.par.level_of_context(p.ctx())?; let mut seed = ::Seed::default(); rand::rng().fill(&mut seed); // Let's create a secret key with the ciphertext context - let mut s = Zeroizing::new(Poly::try_convert_from( - self.coeffs.as_ref(), - p.ctx(), - false, - Representation::PowerBasis, - )?); - s.change_representation(Representation::Ntt); - - let mut a = Poly::random_from_seed(p.ctx(), Representation::Ntt, seed); + let s = Zeroizing::new( + Poly::::try_convert_from(self.coeffs.as_ref(), p.ctx(), false)?.into_ntt(), + ); + + let mut a = Poly::::random_from_seed(p.ctx(), seed); let a_s = Zeroizing::new(&a * s.as_ref()); - let mut b = Poly::small(p.ctx(), Representation::Ntt, self.par.variance, rng) - .map_err(Error::MathError)?; + let mut b = + Poly::::small(p.ctx(), self.par.variance, rng).map_err(Error::MathError)?; b -= &a_s; b += p; @@ -211,13 +204,10 @@ impl FheDecrypter for SecretKey { )) } else { // Let's create a secret key with the ciphertext context - let mut s = Zeroizing::new(Poly::try_convert_from( - self.coeffs.as_ref(), - ct[0].ctx(), - false, - Representation::PowerBasis, - )?); - s.change_representation(Representation::Ntt); + let s = Zeroizing::new( + Poly::::try_convert_from(self.coeffs.as_ref(), ct[0].ctx(), false)? + .into_ntt(), + ); let mut si = s.clone(); let mut c = Zeroizing::new(ct[0].clone()); @@ -234,10 +224,11 @@ impl FheDecrypter for SecretKey { *si.as_mut() *= s.as_ref(); } } - c.change_representation(Representation::PowerBasis); - let ctx_lvl = self.par.context_level_at(ct.level).unwrap(); - let d = Zeroizing::new(c.scale(&ctx_lvl.cipher_plain_context.scaler)?); + let ctx = c.ctx().clone(); + let c_inner = std::mem::replace(c.as_mut(), Poly::::zero(&ctx)); + let c_pb = c_inner.into_power_basis(); + let d = Zeroizing::new(c_pb.scale(&ctx_lvl.cipher_plain_context.scaler)?); let value = match self.par.plaintext { PlaintextModulus::Small { .. } => { @@ -268,22 +259,15 @@ impl FheDecrypter for SecretKey { } }; - let mut poly = match &value { - PlaintextValues::Small(v) => Poly::try_convert_from( - v.as_ref(), - ct[0].ctx(), - false, - Representation::PowerBasis, - )?, - PlaintextValues::Large(v) => Poly::try_convert_from( - v.as_ref(), - ct[0].ctx(), - false, - Representation::PowerBasis, - )?, - }; - - poly.change_representation(Representation::Ntt); + let poly = match &value { + PlaintextValues::Small(v) => { + Poly::::try_convert_from(v.as_ref(), ct[0].ctx(), false)? + } + PlaintextValues::Large(v) => { + Poly::::try_convert_from(v.as_ref(), ct[0].ctx(), false)? + } + } + .into_ntt(); let pt = Plaintext { par: self.par.clone(), diff --git a/crates/fhe/src/bfv/ops/dot_product.rs b/crates/fhe/src/bfv/ops/dot_product.rs index 50f29404..ca4408f1 100644 --- a/crates/fhe/src/bfv/ops/dot_product.rs +++ b/crates/fhe/src/bfv/ops/dot_product.rs @@ -1,6 +1,6 @@ use std::cmp::min; -use fhe_math::rq::{Poly, Representation, dot_product as poly_dot_product, traits::TryConvertFrom}; +use fhe_math::rq::{Ntt, Poly, dot_product as poly_dot_product, traits::TryConvertFrom}; use itertools::{Itertools, izip}; use ndarray::{Array, Array2}; @@ -96,7 +96,7 @@ where ) .map_err(Error::MathError) }) - .collect::>>()?; + .collect::>>>()?; Ok(Ciphertext { par: ct_first.par.clone(), @@ -139,12 +139,7 @@ where unsafe { *outij_coeff = q.reduce_u128_vt(*accij_coeff) } } } - c.push(Poly::try_convert_from( - coeffs, - ctx, - true, - Representation::Ntt, - )?) + c.push(Poly::::try_convert_from(coeffs, ctx, true)?) } Ok(Ciphertext { diff --git a/crates/fhe/src/bfv/ops/mod.rs b/crates/fhe/src/bfv/ops/mod.rs index 95197282..156db7b1 100644 --- a/crates/fhe/src/bfv/ops/mod.rs +++ b/crates/fhe/src/bfv/ops/mod.rs @@ -8,7 +8,7 @@ pub use mul::Multiplicator; use super::{Ciphertext, Plaintext}; use crate::{Error, Result}; -use fhe_math::rq::{Poly, Representation}; +use fhe_math::rq::{Ntt, Poly}; use std::ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign}; use std::sync::Arc; @@ -273,11 +273,11 @@ impl Mul<&Ciphertext> for &Ciphertext { let self_c = self .iter() .map(|ci| ci.scale(&mp.extender).map_err(Error::MathError)) - .collect::>>() + .collect::>>>() .unwrap(); // Multiply - let mut c = vec![Poly::zero(&mp.to, Representation::Ntt); 2 * self_c.len() - 1]; + let mut c = vec![Poly::::zero(&mp.to); 2 * self_c.len() - 1]; for i in 0..self_c.len() { for j in 0..self_c.len() { c[i + j] += &(&self_c[i] * &self_c[j]) @@ -287,13 +287,8 @@ impl Mul<&Ciphertext> for &Ciphertext { // Scale let c = c .iter_mut() - .map(|ci| { - ci.change_representation(Representation::PowerBasis); - let mut ci = ci.scale(&mp.down_scaler).map_err(Error::MathError)?; - ci.change_representation(Representation::Ntt); - Ok(ci) - }) - .collect::>>() + .map(|ci| ci.scale(&mp.down_scaler).map_err(Error::MathError)) + .collect::>>>() .unwrap(); Ciphertext { @@ -313,17 +308,16 @@ impl Mul<&Ciphertext> for &Ciphertext { let self_c = self .iter() .map(|ci| ci.scale(&mp.extender).map_err(Error::MathError)) - .collect::>>() + .collect::>>>() .unwrap(); let other_c = rhs .iter() .map(|ci| ci.scale(&mp.extender).map_err(Error::MathError)) - .collect::>>() + .collect::>>>() .unwrap(); // Multiply - let mut c = - vec![Poly::zero(&mp.to, Representation::Ntt); self_c.len() + other_c.len() - 1]; + let mut c = vec![Poly::::zero(&mp.to); self_c.len() + other_c.len() - 1]; for i in 0..self_c.len() { for j in 0..other_c.len() { c[i + j] += &(&self_c[i] * &other_c[j]) @@ -333,13 +327,8 @@ impl Mul<&Ciphertext> for &Ciphertext { // Scale let c = c .iter_mut() - .map(|ci| { - ci.change_representation(Representation::PowerBasis); - let mut ci = ci.scale(&mp.down_scaler).map_err(Error::MathError)?; - ci.change_representation(Representation::Ntt); - Ok(ci) - }) - .collect::>>() + .map(|ci| ci.scale(&mp.down_scaler).map_err(Error::MathError)) + .collect::>>>() .unwrap(); Ciphertext { diff --git a/crates/fhe/src/bfv/ops/mul.rs b/crates/fhe/src/bfv/ops/mul.rs index bb84dadb..e473e192 100644 --- a/crates/fhe/src/bfv/ops/mul.rs +++ b/crates/fhe/src/bfv/ops/mul.rs @@ -2,7 +2,7 @@ use std::sync::Arc; use fhe_math::{ rns::ScalingFactor, - rq::{Context, Representation, scaler::Scaler}, + rq::{Context, scaler::Scaler}, zq::primes::generate_prime, }; @@ -179,13 +179,10 @@ impl Multiplicator { let c11 = rhs[1].scale(&self.extender_rhs)?; // Multiply - let mut c0 = &c00 * &c10; + let c0 = &c00 * &c10; let mut c1 = &c00 * &c11; c1 += &(&c01 * &c10); - let mut c2 = &c01 * &c11; - c0.change_representation(Representation::PowerBasis); - c1.change_representation(Representation::PowerBasis); - c2.change_representation(Representation::PowerBasis); + let c2 = &c01 * &c11; // Scale let c0 = c0.scale(&self.down_scaler)?; @@ -196,16 +193,16 @@ impl Multiplicator { // Relinearize if let Some(rk) = self.rk.as_ref() { - let (mut c0r, mut c1r) = rk.relinearizes_poly(&c[2])?; + let c2_pb = c[2].clone().into_power_basis(); + let (mut c0r, mut c1r) = rk.relinearizes_poly(&c2_pb)?; if c0r.ctx() != c[0].ctx() { - c0r.change_representation(Representation::PowerBasis); - c1r.change_representation(Representation::PowerBasis); - c0r.switch_down_to(c[0].ctx())?; - c1r.switch_down_to(c[1].ctx())?; - } else { - c[0].change_representation(Representation::Ntt); - c[1].change_representation(Representation::Ntt); + let mut c0r_pb = c0r.into_power_basis(); + let mut c1r_pb = c1r.into_power_basis(); + c0r_pb.switch_down_to(c[0].ctx())?; + c1r_pb.switch_down_to(c[1].ctx())?; + c0r = c0r_pb.into_ntt(); + c1r = c1r_pb.into_ntt(); } c[0] += &c0r; @@ -224,9 +221,6 @@ impl Multiplicator { if self.mod_switch { c.switch_down()?; - } else { - c.iter_mut() - .for_each(|p| p.change_representation(Representation::Ntt)); } Ok(c) diff --git a/crates/fhe/src/bfv/parameters.rs b/crates/fhe/src/bfv/parameters.rs index 41ca62e6..6240d92d 100644 --- a/crates/fhe/src/bfv/parameters.rs +++ b/crates/fhe/src/bfv/parameters.rs @@ -6,7 +6,7 @@ use crate::{Error, ParametersError, Result, SerializationError}; use fhe_math::{ ntt::NttOperator, rns::{RnsContext, ScalingFactor}, - rq::{Context, Poly, Representation, scaler::Scaler, traits::TryConvertFrom}, + rq::{Context, Poly, PowerBasis, scaler::Scaler, traits::TryConvertFrom}, zq::{Modulus, primes::generate_prime}, }; use fhe_traits::{Deserialize, FheParameters, Serialize}; @@ -526,13 +526,12 @@ impl BfvParametersBuilder { // Use RnsContext to lift the delta values and create the scaling polynomial let rns = RnsContext::new(level_moduli)?; - let mut delta = Poly::try_convert_from( + let delta = Poly::::try_convert_from( &[rns.lift((&delta_rests).into())], &cipher_ctx, true, - Representation::PowerBasis, - )?; - delta.change_representation(Representation::NttShoup); + )? + .into_ntt_shoup(); // Compute q_mod_t let q_mod_t = rns.modulus() % plaintext_big; diff --git a/crates/fhe/src/bfv/plaintext.rs b/crates/fhe/src/bfv/plaintext.rs index 036e0202..36341063 100644 --- a/crates/fhe/src/bfv/plaintext.rs +++ b/crates/fhe/src/bfv/plaintext.rs @@ -3,7 +3,7 @@ use crate::{ Error, Result, bfv::{BfvParameters, Encoding, PlaintextVec, parameters::PlaintextModulus}, }; -use fhe_math::rq::{Context, Poly, Representation, traits::TryConvertFrom}; +use fhe_math::rq::{Context, Ntt, Poly, PowerBasis, traits::TryConvertFrom}; use fhe_traits::{FheDecoder, FheEncoder, FheParametrized, FhePlaintext}; use num_bigint::{BigInt, BigUint, Sign}; use num_traits::{ToPrimitive, Zero}; @@ -41,7 +41,7 @@ pub struct Plaintext { /// The encoding of the plaintext, if known pub(crate) encoding: Option, /// The plaintext as a polynomial. - pub(crate) poly_ntt: Poly, + pub(crate) poly_ntt: Poly, /// The level of the plaintext pub(crate) level: usize, } @@ -68,11 +68,11 @@ impl FhePlaintext for Plaintext { } impl Plaintext { - pub(crate) fn to_poly(&self) -> Poly { + pub(crate) fn to_poly(&self) -> Poly { let ctx_lvl = self.par.context_level_at(self.level).unwrap(); let ctx = &ctx_lvl.poly_context; - let mut m = match &self.value { + let m = match &self.value { PlaintextValues::Small(v) => { let mut m_v = Zeroizing::new(v.clone()); if let PlaintextModulus::Small { modulus, .. } = &self.par.plaintext { @@ -81,20 +81,18 @@ impl Plaintext { } else { unreachable!("PlaintextValues::Small but PlaintextModulus::Large"); } - Poly::try_convert_from(m_v.as_ref(), ctx, false, Representation::PowerBasis) - .unwrap() + Poly::::try_convert_from(m_v.as_ref(), ctx, false).unwrap() } PlaintextValues::Large(v) => { let mut m_v = v.clone(); self.par .plaintext .scalar_mul_vec(&mut m_v, &ctx_lvl.cipher_plain_context.q_mod_t); - Poly::try_convert_from(m_v.as_ref(), ctx, false, Representation::PowerBasis) - .unwrap() + Poly::::try_convert_from(m_v.as_ref(), ctx, false).unwrap() } }; - m.change_representation(Representation::Ntt); + let mut m = m.into_ntt(); m *= &ctx_lvl.cipher_plain_context.delta; m } @@ -111,7 +109,7 @@ impl Plaintext { PlaintextValues::Large(vec![BigUint::zero(); par.degree()].into_boxed_slice()) } }; - let poly_ntt = Poly::zero(ctx, Representation::Ntt); + let poly_ntt = Poly::::zero(ctx); Ok(Self { par: par.clone(), value, @@ -159,16 +157,12 @@ impl PartialEq for Plaintext { } // Conversions. -impl TryConvertFrom<&Plaintext> for Poly { - fn try_convert_from( +impl TryConvertFrom<&Plaintext> for Poly { + fn try_convert_from( pt: &Plaintext, ctx: &Arc, variable_time: bool, - _: R, - ) -> fhe_math::Result - where - R: Into>, - { + ) -> fhe_math::Result { if ctx != pt .par @@ -180,18 +174,12 @@ impl TryConvertFrom<&Plaintext> for Poly { )) } else { match &pt.value { - PlaintextValues::Small(v) => Poly::try_convert_from( - v.as_ref(), - ctx, - variable_time, - Representation::PowerBasis, - ), - PlaintextValues::Large(v) => Poly::try_convert_from( - v.as_ref(), - ctx, - variable_time, - Representation::PowerBasis, - ), + PlaintextValues::Small(v) => { + Poly::::try_convert_from(v.as_ref(), ctx, variable_time) + } + PlaintextValues::Large(v) => { + Poly::::try_convert_from(v.as_ref(), ctx, variable_time) + } } } } @@ -463,7 +451,7 @@ mod tests { use super::{Encoding, Plaintext}; use crate::bfv::parameters::{BfvParameters, BfvParametersBuilder}; use crate::bfv::plaintext::PlaintextValues; - use fhe_math::rq::{Poly, Representation}; + use fhe_math::rq::{Ntt, Poly}; use fhe_traits::{FheDecoder, FheEncoder}; use num_bigint::BigUint; use num_traits::Zero; @@ -656,7 +644,7 @@ mod tests { ); assert_eq!( plaintext.poly_ntt, - Poly::zero(params.context_at_level(0)?, Representation::Ntt) + Poly::::zero(params.context_at_level(0)?) ); Ok(()) diff --git a/crates/fhe/src/bfv/plaintext_vec.rs b/crates/fhe/src/bfv/plaintext_vec.rs index e4ebe06b..1deac15a 100644 --- a/crates/fhe/src/bfv/plaintext_vec.rs +++ b/crates/fhe/src/bfv/plaintext_vec.rs @@ -1,6 +1,6 @@ use std::{cmp::min, ops::Deref, sync::Arc}; -use fhe_math::rq::{Poly, Representation, traits::TryConvertFrom}; +use fhe_math::rq::{Poly, PowerBasis, traits::TryConvertFrom}; use fhe_traits::{FheEncoder, FheEncoderVariableTime, FheParametrized, FhePlaintext}; use num_bigint::BigUint; use num_traits::{ToPrimitive, Zero}; @@ -73,9 +73,7 @@ impl FheEncoderVariableTime<&[u64]> for PlaintextVec { } }; - let mut poly = - Poly::try_convert_from(&v, ctx, true, Representation::PowerBasis)?; - poly.change_representation(Representation::Ntt); + let poly = Poly::::try_convert_from(&v, ctx, true)?.into_ntt(); let value_enum = match par.plaintext { crate::bfv::PlaintextModulus::Small { .. } => { @@ -143,13 +141,8 @@ impl FheEncoder<&[BigUint]> for PlaintextVec { } }; - let mut poly = Poly::try_convert_from( - v.as_slice(), - ctx, - false, - Representation::PowerBasis, - )?; - poly.change_representation(Representation::Ntt); + let poly = + Poly::::try_convert_from(v.as_slice(), ctx, false)?.into_ntt(); let value_enum = match &par.plaintext { crate::bfv::PlaintextModulus::Small { modulus_big, .. } => { @@ -213,9 +206,7 @@ impl FheEncoder<&[u64]> for PlaintextVec { } }; - let mut poly = - Poly::try_convert_from(&v, ctx, false, Representation::PowerBasis)?; - poly.change_representation(Representation::Ntt); + let poly = Poly::::try_convert_from(&v, ctx, false)?.into_ntt(); let value_enum = match par.plaintext { crate::bfv::PlaintextModulus::Small { .. } => { diff --git a/crates/fhe/src/bfv/rgsw_ciphertext.rs b/crates/fhe/src/bfv/rgsw_ciphertext.rs index 374fc06d..f25a4c6d 100644 --- a/crates/fhe/src/bfv/rgsw_ciphertext.rs +++ b/crates/fhe/src/bfv/rgsw_ciphertext.rs @@ -4,7 +4,7 @@ use crate::proto::bfv::{ KeySwitchingKey as KeySwitchingKeyProto, RgswCiphertext as RGSWCiphertextProto, }; use crate::{Error, Result, SerializationError}; -use fhe_math::rq::{Poly, Representation, traits::TryConvertFrom as TryConvertFromPoly}; +use fhe_math::rq::{Ntt, Poly, PowerBasis, traits::TryConvertFrom as TryConvertFromPoly}; use fhe_traits::{ DeserializeParametrized, FheCiphertext, FheEncrypter, FheParametrized, Serialize, }; @@ -104,17 +104,14 @@ impl FheEncrypter for SecretKey { let level = pt.level; let ctx = self.par.context_at_level(level)?; - let mut m = Zeroizing::new(pt.poly_ntt.clone()); - let mut m_s = Zeroizing::new(Poly::try_convert_from( - self.coeffs.as_ref(), - ctx, - false, - Representation::PowerBasis, - )?); - m_s.change_representation(Representation::Ntt); - *m_s.as_mut() *= m.as_ref(); - m_s.change_representation(Representation::PowerBasis); - m.change_representation(Representation::PowerBasis); + let m = Zeroizing::new(pt.poly_ntt.clone().into_power_basis()); + let mut m_s = Zeroizing::new( + Poly::::try_convert_from(self.coeffs.as_ref(), ctx, false)?.into_ntt(), + ); + *m_s.as_mut() *= pt.poly_ntt.as_ref(); + let ctx = m_s.ctx().clone(); + let m_s_inner = std::mem::replace(m_s.as_mut(), Poly::::zero(&ctx)); + let m_s = Zeroizing::new(m_s_inner.into_power_basis()); let ksk0 = KeySwitchingKey::new(self, &m, pt.level, pt.level, rng)?; let ksk1 = KeySwitchingKey::new(self, &m_s, pt.level, pt.level, rng)?; @@ -137,17 +134,15 @@ impl Mul<&RGSWCiphertext> for &Ciphertext { ); assert_eq!(self.len(), 2, "Ciphertext must have two parts"); - let mut ct0 = self[0].clone(); - let mut ct1 = self[1].clone(); - ct0.change_representation(Representation::PowerBasis); - ct1.change_representation(Representation::PowerBasis); + let ct0 = self[0].clone().into_power_basis(); + let ct1 = self[1].clone().into_power_basis(); - let mut c0 = Poly::zero(&rhs.ksk0.ctx_ksk, Representation::Ntt); - let mut c1 = Poly::zero(&rhs.ksk0.ctx_ksk, Representation::Ntt); + let mut c0 = Poly::::zero(&rhs.ksk0.ctx_ksk); + let mut c1 = Poly::::zero(&rhs.ksk0.ctx_ksk); rhs.ksk0.key_switch_assign(&ct0, &mut c0, &mut c1).unwrap(); - let mut c0p = Poly::zero(&rhs.ksk1.ctx_ksk, Representation::Ntt); - let mut c1p = Poly::zero(&rhs.ksk1.ctx_ksk, Representation::Ntt); + let mut c0p = Poly::::zero(&rhs.ksk1.ctx_ksk); + let mut c1p = Poly::::zero(&rhs.ksk1.ctx_ksk); rhs.ksk1 .key_switch_assign(&ct1, &mut c0p, &mut c1p) .unwrap(); diff --git a/crates/fhe/src/mbfv/crp.rs b/crates/fhe/src/mbfv/crp.rs index ec3ddf95..11f426ea 100644 --- a/crates/fhe/src/mbfv/crp.rs +++ b/crates/fhe/src/mbfv/crp.rs @@ -2,14 +2,14 @@ use std::sync::Arc; use crate::Result; use crate::bfv::BfvParameters; -use fhe_math::rq::Poly; +use fhe_math::rq::{Ntt, Poly}; use rand::{CryptoRng, RngCore}; /// A polynomial sampled from a random _common reference string_. // TODO CRS->CRP implementation. For now just a random polynomial. #[derive(Debug, PartialEq, Eq, Clone)] pub struct CommonRandomPoly { - pub(crate) poly: Poly, + pub(crate) poly: Poly, } impl CommonRandomPoly { @@ -38,7 +38,7 @@ impl CommonRandomPoly { rng: &mut R, ) -> Result { let ctx = par.context_at_level(level)?; - let poly = Poly::random(ctx, fhe_math::rq::Representation::Ntt, rng); + let poly = Poly::::random(ctx, rng); Ok(Self { poly }) } } diff --git a/crates/fhe/src/mbfv/public_key_gen.rs b/crates/fhe/src/mbfv/public_key_gen.rs index 3915aa80..08db94bc 100644 --- a/crates/fhe/src/mbfv/public_key_gen.rs +++ b/crates/fhe/src/mbfv/public_key_gen.rs @@ -3,7 +3,7 @@ use std::sync::Arc; use crate::Error; use crate::bfv::{BfvParameters, Ciphertext, PublicKey, SecretKey}; use crate::errors::Result; -use fhe_math::rq::{Poly, Representation, traits::TryConvertFrom}; +use fhe_math::rq::{Ntt, Poly, PowerBasis, traits::TryConvertFrom}; use rand::{CryptoRng, RngCore}; use zeroize::Zeroizing; @@ -16,7 +16,7 @@ use super::{Aggregate, CommonRandomPoly}; pub struct PublicKeyShare { pub(crate) par: Arc, pub(crate) crp: CommonRandomPoly, - pub(crate) p0_share: Poly, + pub(crate) p0_share: Poly, } impl PublicKeyShare { @@ -39,20 +39,15 @@ impl PublicKeyShare { let ctx = par.context_at_level(0)?; // Convert secret key to usable polynomial - let mut s = Zeroizing::new(Poly::try_convert_from( - sk_share.coeffs.as_ref(), - ctx, - false, - Representation::PowerBasis, - )?); - s.change_representation(Representation::Ntt); + let s = Zeroizing::new( + Poly::::try_convert_from(sk_share.coeffs.as_ref(), ctx, false)?.into_ntt(), + ); // Sample error - let e = Zeroizing::new(Poly::small(ctx, Representation::Ntt, par.variance, rng)?); + let e = Zeroizing::new(Poly::::small(ctx, par.variance, rng)?); // Create p0_i share let mut p0_share = -crp.poly.clone(); p0_share.disallow_variable_time_computations(); - p0_share.change_representation(Representation::Ntt); p0_share *= s.as_ref(); p0_share += e.as_ref(); unsafe { p0_share.allow_variable_time_computations() } diff --git a/crates/fhe/src/mbfv/public_key_switch.rs b/crates/fhe/src/mbfv/public_key_switch.rs index f7c711b4..8de2c71b 100644 --- a/crates/fhe/src/mbfv/public_key_switch.rs +++ b/crates/fhe/src/mbfv/public_key_switch.rs @@ -1,7 +1,7 @@ use std::sync::Arc; use fhe_math::rq::traits::TryConvertFrom; -use fhe_math::rq::{Poly, Representation}; +use fhe_math::rq::{Ntt, Poly, PowerBasis}; use rand::{CryptoRng, RngCore}; use zeroize::Zeroizing; @@ -18,9 +18,9 @@ use super::Aggregate; pub struct PublicKeySwitchShare { pub(crate) par: Arc, /// The first component of the input ciphertext - pub(crate) c0: Poly, - pub(crate) h0_share: Poly, - pub(crate) h1_share: Poly, + pub(crate) c0: Poly, + pub(crate) h0_share: Poly, + pub(crate) h1_share: Poly, } impl PublicKeySwitchShare { @@ -50,19 +50,15 @@ impl PublicKeySwitchShare { } let ctx = par.context_at_level(ct.level)?; - let mut s = Zeroizing::new(Poly::try_convert_from( - sk_share.coeffs.as_ref(), - ctx, - false, - Representation::PowerBasis, - )?); - s.change_representation(Representation::Ntt); + let mut s = Zeroizing::new( + Poly::::try_convert_from(sk_share.coeffs.as_ref(), ctx, false)?.into_ntt(), + ); s.disallow_variable_time_computations(); - let u = Zeroizing::new(Poly::small(ctx, Representation::Ntt, par.variance, rng)?); + let u = Zeroizing::new(Poly::::small(ctx, par.variance, rng)?); // TODO this should be exponential in ciphertext noise! - let e0 = Zeroizing::new(Poly::small(ctx, Representation::Ntt, par.variance, rng)?); - let e1 = Zeroizing::new(Poly::small(ctx, Representation::Ntt, par.variance, rng)?); + let e0 = Zeroizing::new(Poly::::small(ctx, par.variance, rng)?); + let e1 = Zeroizing::new(Poly::::small(ctx, par.variance, rng)?); let mut h0 = pk_ct[0].clone(); h0.disallow_variable_time_computations(); diff --git a/crates/fhe/src/mbfv/relin_key_gen.rs b/crates/fhe/src/mbfv/relin_key_gen.rs index b5c60455..d56024a7 100644 --- a/crates/fhe/src/mbfv/relin_key_gen.rs +++ b/crates/fhe/src/mbfv/relin_key_gen.rs @@ -5,7 +5,7 @@ use crate::Error; use crate::bfv::{BfvParameters, KeySwitchingKey, RelinearizationKey, SecretKey}; use crate::errors::Result; use fhe_math::rns::RnsContext; -use fhe_math::rq::{Poly, Representation, traits::TryConvertFrom}; +use fhe_math::rq::{Ntt, NttShoup, Poly, PowerBasis, traits::TryConvertFrom}; use itertools::izip; use rand::{CryptoRng, RngCore}; use zeroize::Zeroizing; @@ -18,8 +18,8 @@ use super::{Aggregate, CommonRandomPoly}; #[derive(Debug, PartialEq, Eq, Clone)] pub struct RelinKeyShare { pub(crate) par: Arc, - pub(crate) h0: Box<[Poly]>, - pub(crate) h1: Box<[Poly]>, + pub(crate) h0: Box<[Poly]>, + pub(crate) h1: Box<[Poly]>, last_round: Option>>, _phantom_data: PhantomData, } @@ -66,7 +66,7 @@ pub struct RelinKeyShare { pub struct RelinKeyGenerator<'a, 'b> { sk_share: &'a SecretKey, crp: &'b [CommonRandomPoly], - u: Zeroizing, + u: Zeroizing>, } impl<'a, 'b> RelinKeyGenerator<'a, 'b> { @@ -91,7 +91,7 @@ impl<'a, 'b> RelinKeyGenerator<'a, 'b> { .to_string(), )) } else { - let u = Zeroizing::new(Poly::small(ctx, Representation::Ntt, par.variance, rng)?); + let u = Zeroizing::new(Poly::::small(ctx, par.variance, rng)?); Ok(Self { sk_share, crp, u }) } } @@ -115,7 +115,7 @@ impl RelinKeyShare { fn new( sk_share: &SecretKey, crp: &[CommonRandomPoly], - u: &Zeroizing, + u: &Zeroizing>, rng: &mut R, ) -> Result { let par = sk_share.par.clone(); @@ -141,32 +141,27 @@ impl RelinKeyShare { fn generate_h0( sk_share: &SecretKey, crp: &[CommonRandomPoly], - u: &Zeroizing, + u: &Zeroizing>, rng: &mut R, - ) -> Result> { + ) -> Result]>> { let par = sk_share.par.clone(); let ctx = par.context_at_level(0)?; - let s = Zeroizing::new(Poly::try_convert_from( - sk_share.coeffs.as_ref(), - ctx, - false, - Representation::PowerBasis, - )?); + let s = Zeroizing::new( + Poly::::try_convert_from(sk_share.coeffs.as_ref(), ctx, false)?.into_ntt(), + ); let rns = RnsContext::new(&sk_share.par.moduli[..crp.len()])?; let h0 = crp .iter() .enumerate() .map(|(i, a)| { let w = rns.get_garner(i).unwrap(); - let mut w_s = Zeroizing::new(w * s.as_ref()); - w_s.change_representation(Representation::Ntt); + let w_s = Zeroizing::new(w * s.as_ref()); - let e = Zeroizing::new(Poly::small(ctx, Representation::Ntt, par.variance, rng)?); + let e = Zeroizing::new(Poly::::small(ctx, par.variance, rng)?); let mut h = -a.poly.clone(); h.disallow_variable_time_computations(); - h.change_representation(Representation::Ntt); h *= u.as_ref(); h += w_s.as_ref(); h += e.as_ref(); @@ -180,24 +175,19 @@ impl RelinKeyShare { sk_share: &SecretKey, crp: &[CommonRandomPoly], rng: &mut R, - ) -> Result> { + ) -> Result]>> { let par = sk_share.par.clone(); let ctx = par.context_at_level(0)?; - let mut s = Zeroizing::new(Poly::try_convert_from( - sk_share.coeffs.as_ref(), - ctx, - false, - Representation::PowerBasis, - )?); - s.change_representation(Representation::Ntt); + let s = Zeroizing::new( + Poly::::try_convert_from(sk_share.coeffs.as_ref(), ctx, false)?.into_ntt(), + ); let h1 = crp .iter() .map(|a| { let mut h = a.poly.clone(); h.disallow_variable_time_computations(); - h.change_representation(Representation::Ntt); - let e = Zeroizing::new(Poly::small(ctx, Representation::Ntt, par.variance, rng)?); + let e = Zeroizing::new(Poly::::small(ctx, par.variance, rng)?); h *= s.as_ref(); h += e.as_ref(); Ok(h) @@ -237,7 +227,7 @@ impl Aggregate> for RelinKeyShare { impl RelinKeyShare { fn new( sk_share: &SecretKey, - u: &Zeroizing, + u: &Zeroizing>, r1: &Arc>, rng: &mut R, ) -> Result { @@ -255,27 +245,22 @@ impl RelinKeyShare { fn generate_h0( sk_share: &SecretKey, - r1_h0: &[Poly], + r1_h0: &[Poly], rng: &mut R, - ) -> Result> { + ) -> Result]>> { let par = sk_share.par.clone(); let ctx = par.context_at_level(0)?; - let mut s = Zeroizing::new(Poly::try_convert_from( - sk_share.coeffs.as_ref(), - ctx, - false, - Representation::PowerBasis, - )?); - s.change_representation(Representation::Ntt); + let s = Zeroizing::new( + Poly::::try_convert_from(sk_share.coeffs.as_ref(), ctx, false)?.into_ntt(), + ); let h0 = r1_h0 .iter() .map(|h| { - let e = Zeroizing::new(Poly::small(ctx, Representation::Ntt, par.variance, rng)?); + let e = Zeroizing::new(Poly::::small(ctx, par.variance, rng)?); let mut h_prime = h.clone(); h_prime.disallow_variable_time_computations(); - h_prime.change_representation(Representation::Ntt); h_prime *= s.as_ref(); h_prime += e.as_ref(); @@ -287,19 +272,15 @@ impl RelinKeyShare { fn generate_h1( sk_share: &SecretKey, - u: &Zeroizing, - r1_h1: &[Poly], + u: &Zeroizing>, + r1_h1: &[Poly], rng: &mut R, - ) -> Result> { + ) -> Result]>> { let par = sk_share.par.clone(); let ctx = par.context_at_level(0)?; - let mut s = Zeroizing::new(Poly::try_convert_from( - sk_share.coeffs.as_ref(), - ctx, - false, - Representation::PowerBasis, - )?); - s.change_representation(Representation::Ntt); + let s = Zeroizing::new( + Poly::::try_convert_from(sk_share.coeffs.as_ref(), ctx, false)?.into_ntt(), + ); let u_s = Zeroizing::new(u.as_ref() - s.as_ref()); @@ -308,8 +289,7 @@ impl RelinKeyShare { .map(|h| { let mut h_prime = h.clone(); h_prime.disallow_variable_time_computations(); - h_prime.change_representation(Representation::Ntt); - let e = Zeroizing::new(Poly::small(ctx, Representation::Ntt, par.variance, rng)?); + let e = Zeroizing::new(Poly::::small(ctx, par.variance, rng)?); h_prime *= u_s.as_ref(); h_prime += e.as_ref(); Ok(h_prime) @@ -346,16 +326,21 @@ impl Aggregate> for RelinearizationKey { ); } - let mut c0 = h0; - izip!(c0.iter_mut(), h1.iter()).for_each(|(c0, h1)| { - *c0 += h1; - c0.change_representation(Representation::NttShoup); - }); + let mut c0 = Vec::from(h0); + izip!(c0.iter_mut(), h1.iter()).for_each(|(c0, h1)| *c0 += h1); + let c0 = c0 + .into_iter() + .map(Poly::::into_ntt_shoup) + .collect::>>() + .into_boxed_slice(); - let mut c1 = r1.h1.clone(); - c1.iter_mut().for_each(|c1| { - c1.change_representation(Representation::NttShoup); - }); + let c1 = r1 + .h1 + .iter() + .cloned() + .map(Poly::::into_ntt_shoup) + .collect::>>() + .into_boxed_slice(); let ksk = KeySwitchingKey { par, diff --git a/crates/fhe/src/mbfv/secret_key_switch.rs b/crates/fhe/src/mbfv/secret_key_switch.rs index 14c83c83..0c055f01 100644 --- a/crates/fhe/src/mbfv/secret_key_switch.rs +++ b/crates/fhe/src/mbfv/secret_key_switch.rs @@ -1,6 +1,6 @@ use std::sync::Arc; -use fhe_math::rq::{Poly, Representation, traits::TryConvertFrom}; +use fhe_math::rq::{Ntt, Poly, PowerBasis, traits::TryConvertFrom}; use itertools::Itertools; use num_bigint::BigUint; use num_traits::ToPrimitive; @@ -26,7 +26,7 @@ pub struct SecretKeySwitchShare { /// The original input ciphertext // Probably doesn't need to be Arc in real usage but w/e pub(crate) ct: Arc, - pub(crate) h_share: Poly, + pub(crate) h_share: Poly, } impl SecretKeySwitchShare { @@ -56,29 +56,26 @@ impl SecretKeySwitchShare { } let par = sk_input_share.par.clone(); - let mut s_in = Zeroizing::new(Poly::try_convert_from( - sk_input_share.coeffs.as_ref(), - ct[0].ctx(), - false, - Representation::PowerBasis, - )?); - s_in.change_representation(Representation::Ntt); - let mut s_out = Zeroizing::new(Poly::try_convert_from( - sk_output_share.coeffs.as_ref(), - ct[0].ctx(), - false, - Representation::PowerBasis, - )?); - s_out.change_representation(Representation::Ntt); + let s_in = Zeroizing::new( + Poly::::try_convert_from( + sk_input_share.coeffs.as_ref(), + ct[0].ctx(), + false, + )? + .into_ntt(), + ); + let s_out = Zeroizing::new( + Poly::::try_convert_from( + sk_output_share.coeffs.as_ref(), + ct[0].ctx(), + false, + )? + .into_ntt(), + ); // Sample error // TODO this should be exponential in ciphertext noise! - let e = Zeroizing::new(Poly::small( - ct[0].ctx(), - Representation::Ntt, - par.variance, - rng, - )?); + let e = Zeroizing::new(Poly::::small(ct[0].ctx(), par.variance, rng)?); // Create h_i share let mut h_share = s_in.as_ref() - s_out.as_ref(); @@ -151,7 +148,9 @@ impl Aggregate for Plaintext { // Note: during SKS, c[1]*sk has already been added to c[0]. let mut c = Zeroizing::new(ct[0].clone()); c.disallow_variable_time_computations(); - c.change_representation(Representation::PowerBasis); + let ctx = c.ctx().clone(); + let c_inner = std::mem::replace(c.as_mut(), Poly::::zero(&ctx)); + let c = c_inner.into_power_basis(); // The true decryption part is done during SKS; all that is left is to scale let ctx_lvl = ct.par.context_level_at(ct.level)?; @@ -168,9 +167,8 @@ impl Aggregate for Plaintext { ct.par.plaintext.reduce_vec(&mut w); - let mut poly = - Poly::try_convert_from(w.as_slice(), ct[0].ctx(), false, Representation::PowerBasis)?; - poly.change_representation(Representation::Ntt); + let poly = + Poly::::try_convert_from(w.as_slice(), ct[0].ctx(), false)?.into_ntt(); let value = match ct.par.plaintext { crate::bfv::PlaintextModulus::Small { .. } => PlaintextValues::Small(