Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 58 additions & 20 deletions crates/fhe/src/bfv/ops/dot_product.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use std::cmp::min;

use fhe_math::rq::{Poly, Representation, dot_product as poly_dot_product, traits::TryConvertFrom};
use itertools::{Itertools, izip};
use ndarray::{Array, Array2};
use ndarray::Array2;

use crate::{
Error, Result,
Expand Down Expand Up @@ -105,21 +105,28 @@ where
level: ct_first.level,
})
} else {
let mut acc = Array::zeros((ct_first.len(), ctx.moduli().len(), ct_first.par.degree()));
let moduli_len = ctx.moduli().len();
let degree = ct_first.par.degree();
let coeffs_len = moduli_len * degree;
let mut acc = vec![0u128; ct_first.len() * coeffs_len];
for (ciphertext, plaintext) in izip!(ct, pt) {
let pt_coefficients = plaintext.poly_ntt.coefficients();
for (mut acci, ci) in izip!(acc.outer_iter_mut(), ciphertext.iter()) {
let pt_coefficients = pt_coefficients.as_slice().ok_or_else(|| {
Error::DefaultError("Plaintext coefficients are not contiguous".to_string())
})?;
for (part_idx, ci) in ciphertext.iter().enumerate() {
let ci_coefficients = ci.coefficients();
for (mut accij, cij, pij) in izip!(
acci.outer_iter_mut(),
ci_coefficients.outer_iter(),
pt_coefficients.outer_iter()
) {
let ci_coefficients = ci_coefficients.as_slice().ok_or_else(|| {
Error::DefaultError("Ciphertext coefficients are not contiguous".to_string())
})?;
let acc_offset = part_idx * coeffs_len;
for modulus_idx in 0..moduli_len {
let coeff_offset = modulus_idx * degree;
unsafe {
fma(
accij.as_slice_mut().unwrap(),
cij.as_slice().unwrap(),
pij.as_slice().unwrap(),
&mut acc[acc_offset + coeff_offset..acc_offset + coeff_offset + degree],
&ci_coefficients[coeff_offset..coeff_offset + degree],
&pt_coefficients[coeff_offset..coeff_offset + degree],
)
}
}
Expand All @@ -128,15 +135,18 @@ where

// Reduce
let mut c = Vec::with_capacity(ct_first.len());
for acci in acc.outer_iter() {
let mut coeffs = Array2::zeros((ctx.moduli().len(), ct_first.par.degree()));
for (mut outij, accij, q) in izip!(
coeffs.outer_iter_mut(),
acci.outer_iter(),
ctx.moduli_operators()
) {
for (outij_coeff, accij_coeff) in izip!(outij.iter_mut(), accij.iter()) {
unsafe { *outij_coeff = q.reduce_u128_vt(*accij_coeff) }
for part_idx in 0..ct_first.len() {
let mut coeffs = Array2::zeros((moduli_len, degree));
let coeffs_slice = coeffs.as_slice_mut().ok_or_else(|| {
Error::DefaultError("Output coefficients are not contiguous".to_string())
})?;
let acc_slice = &acc[part_idx * coeffs_len..(part_idx + 1) * coeffs_len];
for (modulus_idx, q) in ctx.moduli_operators().iter().enumerate() {
let coeff_offset = modulus_idx * degree;
let out_slice = &mut coeffs_slice[coeff_offset..coeff_offset + degree];
let acc_coeffs = &acc_slice[coeff_offset..coeff_offset + degree];
for (out_coeff, acc_coeff) in out_slice.iter_mut().zip(acc_coeffs.iter()) {
unsafe { *out_coeff = q.reduce_u128_vt(*acc_coeff) }
}
}
c.push(Poly::try_convert_from(
Expand Down Expand Up @@ -201,4 +211,32 @@ mod tests {
}
Ok(())
}

#[test]
fn test_dot_product_scalar_step_by() -> Result<(), Box<dyn Error>> {
let mut rng = rng();
let params = BfvParameters::default_arc(1, 16);
let sk = SecretKey::random(&params, &mut rng);
let ct = (0..8)
.map(|_| {
let v = params.plaintext.random_vec(params.degree(), &mut rng);
let pt = Plaintext::try_encode(&v, Encoding::simd(), &params).unwrap();
sk.try_encrypt(&pt, &mut rng).unwrap()
})
.collect_vec();
let pt = (0..8)
.map(|_| {
let v = params.plaintext.random_vec(params.degree(), &mut rng);
Plaintext::try_encode(&v, Encoding::simd(), &params).unwrap()
})
.collect_vec();

let r = dot_product_scalar(ct.iter().step_by(2), pt.iter().step_by(2))?;

let mut expected = Ciphertext::zero(&params);
izip!(ct.iter().step_by(2), pt.iter().step_by(2))
.for_each(|(cti, pti)| expected += &(cti * pti));
assert_eq!(r, expected);
Ok(())
}
}
Loading