From a6bd540a6c16bcd6a69969a5c0e2c8808ecf7935 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tancr=C3=A8de=20Lepoint?= Date: Tue, 27 Jan 2026 20:33:42 -0500 Subject: [PATCH] Optimize dot product accumulation --- crates/fhe/src/bfv/ops/dot_product.rs | 78 ++++++++++++++++++++------- 1 file changed, 58 insertions(+), 20 deletions(-) diff --git a/crates/fhe/src/bfv/ops/dot_product.rs b/crates/fhe/src/bfv/ops/dot_product.rs index f51aed13..c08cf902 100644 --- a/crates/fhe/src/bfv/ops/dot_product.rs +++ b/crates/fhe/src/bfv/ops/dot_product.rs @@ -2,7 +2,7 @@ use std::cmp::min; use fhe_math::rq::{Poly, Representation, dot_product as poly_dot_product, traits::TryConvertFrom}; use itertools::{Itertools, izip}; -use ndarray::{Array, Array2}; +use ndarray::Array2; use crate::{ Error, Result, @@ -105,21 +105,28 @@ where level: ct_first.level, }) } else { - let mut acc = Array::zeros((ct_first.len(), ctx.moduli().len(), ct_first.par.degree())); + let moduli_len = ctx.moduli().len(); + let degree = ct_first.par.degree(); + let coeffs_len = moduli_len * degree; + let mut acc = vec![0u128; ct_first.len() * coeffs_len]; for (ciphertext, plaintext) in izip!(ct, pt) { let pt_coefficients = plaintext.poly_ntt.coefficients(); - for (mut acci, ci) in izip!(acc.outer_iter_mut(), ciphertext.iter()) { + let pt_coefficients = pt_coefficients.as_slice().ok_or_else(|| { + Error::DefaultError("Plaintext coefficients are not contiguous".to_string()) + })?; + for (part_idx, ci) in ciphertext.iter().enumerate() { let ci_coefficients = ci.coefficients(); - for (mut accij, cij, pij) in izip!( - acci.outer_iter_mut(), - ci_coefficients.outer_iter(), - pt_coefficients.outer_iter() - ) { + let ci_coefficients = ci_coefficients.as_slice().ok_or_else(|| { + Error::DefaultError("Ciphertext coefficients are not contiguous".to_string()) + })?; + let acc_offset = part_idx * coeffs_len; + for modulus_idx in 0..moduli_len { + let coeff_offset = modulus_idx * degree; unsafe { fma( - accij.as_slice_mut().unwrap(), - 
cij.as_slice().unwrap(), - pij.as_slice().unwrap(), + &mut acc[acc_offset + coeff_offset..acc_offset + coeff_offset + degree], + &ci_coefficients[coeff_offset..coeff_offset + degree], + &pt_coefficients[coeff_offset..coeff_offset + degree], ) } } @@ -128,15 +135,18 @@ where // Reduce let mut c = Vec::with_capacity(ct_first.len()); - for acci in acc.outer_iter() { - let mut coeffs = Array2::zeros((ctx.moduli().len(), ct_first.par.degree())); - for (mut outij, accij, q) in izip!( - coeffs.outer_iter_mut(), - acci.outer_iter(), - ctx.moduli_operators() - ) { - for (outij_coeff, accij_coeff) in izip!(outij.iter_mut(), accij.iter()) { - unsafe { *outij_coeff = q.reduce_u128_vt(*accij_coeff) } + for part_idx in 0..ct_first.len() { + let mut coeffs = Array2::zeros((moduli_len, degree)); + let coeffs_slice = coeffs.as_slice_mut().ok_or_else(|| { + Error::DefaultError("Output coefficients are not contiguous".to_string()) + })?; + let acc_slice = &acc[part_idx * coeffs_len..(part_idx + 1) * coeffs_len]; + for (modulus_idx, q) in ctx.moduli_operators().iter().enumerate() { + let coeff_offset = modulus_idx * degree; + let out_slice = &mut coeffs_slice[coeff_offset..coeff_offset + degree]; + let acc_coeffs = &acc_slice[coeff_offset..coeff_offset + degree]; + for (out_coeff, acc_coeff) in out_slice.iter_mut().zip(acc_coeffs.iter()) { + unsafe { *out_coeff = q.reduce_u128_vt(*acc_coeff) } } } c.push(Poly::try_convert_from( @@ -201,4 +211,32 @@ mod tests { } Ok(()) } + + #[test] + fn test_dot_product_scalar_step_by() -> Result<(), Box<dyn Error>> { + let mut rng = rng(); + let params = BfvParameters::default_arc(1, 16); + let sk = SecretKey::random(&params, &mut rng); + let ct = (0..8) + .map(|_| { + let v = params.plaintext.random_vec(params.degree(), &mut rng); + let pt = Plaintext::try_encode(&v, Encoding::simd(), &params).unwrap(); + sk.try_encrypt(&pt, &mut rng).unwrap() + }) + .collect_vec(); + let pt = (0..8) + .map(|_| { + let v = params.plaintext.random_vec(params.degree(), &mut rng); +
Plaintext::try_encode(&v, Encoding::simd(), &params).unwrap() + }) + .collect_vec(); + + let r = dot_product_scalar(ct.iter().step_by(2), pt.iter().step_by(2))?; + + let mut expected = Ciphertext::zero(&params); + izip!(ct.iter().step_by(2), pt.iter().step_by(2)) + .for_each(|(cti, pti)| expected += &(cti * pti)); + assert_eq!(r, expected); + Ok(()) + } }