Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 59 additions & 2 deletions jolt-core/src/poly/split_eq_poly.rs
Original file line number Diff line number Diff line change
Expand Up @@ -410,6 +410,65 @@ impl<F: JoltField> GruenSplitEqPolynomial<F> {
])
}

/// Compute the quartic polynomial s(X) = l(X) * q(X), where l(X) is the
/// current (linear) eq polynomial and q(X) = c + dX + eX^2 + fX^3, given:
/// - c, the constant term of q (i.e. q(0))
/// - q(2), the evaluation of q at 2
/// - f, the leading (cubic) coefficient of q
/// - the previous round claim, s(0) + s(1)
pub fn gruen_poly_deg_4(
&self,
q_constant: F,
q_at_2: F,
q_cubic_coeff: F,
s_0_plus_s_1: F,
) -> UniPoly<F> {
// l(X) = a + bX (linear eq polynomial)
// q(X) = c + dX + eX^2 + fX^3 (cubic inner polynomial)
// s(X) = l(X)*q(X) is degree 4, needs evaluations at {0, 1, 2, 3, 4}.

let eq_eval_1 = self.current_scalar
* match self.binding_order {
BindingOrder::LowToHigh => self.w[self.current_index - 1],
BindingOrder::HighToLow => self.w[self.current_index],
};
let eq_eval_0 = self.current_scalar - eq_eval_1;
let eq_m = eq_eval_1 - eq_eval_0;
let eq_eval_2 = eq_eval_1 + eq_m;
let eq_eval_3 = eq_eval_2 + eq_m;
let eq_eval_4 = eq_eval_3 + eq_m;

let c = q_constant;
let f = q_cubic_coeff;

// Recover q(1) from the sumcheck identity: s(0) + s(1) = claim
let quartic_eval_0 = eq_eval_0 * c;
let quartic_eval_1 = s_0_plus_s_1 - quartic_eval_0;
let q_1 = quartic_eval_1 / eq_eval_1;

// q(0) = c, q(1) = c+d+e+f, q(2) = c+2d+4e+8f
// Forward differences: Δ0 = q(1)-q(0) = d+e+f, Δ1 = q(2)-q(1)
// Second differences: Δ²0 = Δ1-Δ0 = 2e+6f
// Third differences: Δ³ = 6f (constant for a cubic)
let delta_0 = q_1 - c;
let delta_1 = q_at_2 - q_1;
let delta2_0 = delta_1 - delta_0;
let f6 = f + f + f + f + f + f; // 6f
// q(3) = q(2) + Δ1 + Δ²0 + Δ³ = q(2) + (Δ1 + Δ²0 + 6f)
let delta_2 = delta_1 + delta2_0 + f6;
let q_3 = q_at_2 + delta_2;
// q(4) = q(3) + Δ2 + Δ²0 + 2*Δ³ = q(3) + (delta_2 + delta2_0 + 6f + 6f)
let q_4 = q_3 + delta_2 + delta2_0 + f6 + f6;

UniPoly::from_evals(&[
quartic_eval_0,
quartic_eval_1,
eq_eval_2 * q_at_2,
eq_eval_3 * q_3,
eq_eval_4 * q_4,
])
}

/// Compute the quadratic polynomial s(X) = l(X) * q(X), where l(X) is the
/// current (linear) Dao-Thaler eq polynomial and q(X) = c + dx
/// - c, the constant term of q
Expand Down Expand Up @@ -795,8 +854,6 @@ mod tests {
/// Verify that evals_cached returns [1] at index 0 (eq over 0 vars).
#[test]
fn evals_cached_starts_with_one() {
use crate::poly::eq_poly::EqPolynomial;

let mut rng = test_rng();
for num_vars in 1..=10 {
let w: Vec<<Fr as JoltField>::Challenge> =
Expand Down
8 changes: 4 additions & 4 deletions jolt-core/src/subprotocols/streaming_sumcheck.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ pub trait StreamingSumcheckWindow<F: JoltField>: Sized + MaybeAllocative + Send
fn initialize(shared: &mut Self::Shared, window_size: usize) -> Self;

fn compute_message(
&self,
&mut self,
shared: &Self::Shared,
window_size: usize,
previous_claim: F,
Expand All @@ -39,7 +39,7 @@ pub trait LinearSumcheckStage<F: JoltField>: Sized + MaybeAllocative + Send + Sy
fn next_window(&mut self, shared: &mut Self::Shared, window_size: usize);

fn compute_message(
&self,
&mut self,
shared: &Self::Shared,
window_size: usize,
previous_claim: F,
Expand Down Expand Up @@ -170,9 +170,9 @@ where
}
}

if let Some(streaming) = &mut self.streaming {
if let Some(streaming) = self.streaming.as_mut() {
streaming.compute_message(&self.shared, num_unbound_vars, previous_claim)
} else if let Some(linear) = &mut self.linear {
} else if let Some(linear) = self.linear.as_mut() {
linear.compute_message(&self.shared, num_unbound_vars, previous_claim)
} else {
unreachable!()
Expand Down
22 changes: 22 additions & 0 deletions jolt-core/src/utils/accumulation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -583,6 +583,16 @@ impl<F: JoltField> FMAdd<F, S192> for WideAccumS<F> {
}
}

impl<F: JoltField> FMAdd<F, u64> for WideAccumS<F> {
#[inline(always)]
fn fmadd(&mut self, field: &F, other: &u64) {
if *other == 0 {
return;
}
self.pos += (*field).mul_u64_unreduced(*other);
}
}

impl<F: JoltField> BarrettReduce<F> for WideAccumS<F> {
#[inline(always)]
fn barrett_reduce(&self) -> F {
Expand Down Expand Up @@ -817,6 +827,18 @@ pub struct S192Sum {
pub sum: S192,
}

impl FMAdd<i32, u64> for S192Sum {
#[inline(always)]
fn fmadd(&mut self, c: &i32, term: &u64) {
if *term == 0 {
return;
}
let c_s64 = S64::from(*c as i64);
let v_s64 = S64::from(*term);
self.sum += c_s64.mul_trunc::<1, 3>(&v_s64);
}
}

impl FMAdd<i32, S64> for S192Sum {
#[inline(always)]
fn fmadd(&mut self, c: &i32, term: &S64) {
Expand Down
Loading