diff --git a/Cargo.lock b/Cargo.lock index a13a33c7d..10c1d11f9 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2235,7 +2235,7 @@ dependencies = [ [[package]] name = "ff_ext" version = "0.1.0" -source = "git+https://github.com/scroll-tech/gkr-backend.git?tag=v1.0.0-alpha.20#4f5a2b99af8edc9e8d808a6ab7c73e4088d61e39" +source = "git+https://github.com/scroll-tech/gkr-backend.git?rev=8ce1577d61af241606d30aefe9ac896e7502809a#8ce1577d61af241606d30aefe9ac896e7502809a" dependencies = [ "once_cell", "p3", @@ -3240,7 +3240,7 @@ dependencies = [ [[package]] name = "mpcs" version = "0.1.0" -source = "git+https://github.com/scroll-tech/gkr-backend.git?tag=v1.0.0-alpha.20#4f5a2b99af8edc9e8d808a6ab7c73e4088d61e39" +source = "git+https://github.com/scroll-tech/gkr-backend.git?rev=8ce1577d61af241606d30aefe9ac896e7502809a#8ce1577d61af241606d30aefe9ac896e7502809a" dependencies = [ "bincode 1.3.3", "clap", @@ -3264,7 +3264,7 @@ dependencies = [ [[package]] name = "multilinear_extensions" version = "0.1.0" -source = "git+https://github.com/scroll-tech/gkr-backend.git?tag=v1.0.0-alpha.20#4f5a2b99af8edc9e8d808a6ab7c73e4088d61e39" +source = "git+https://github.com/scroll-tech/gkr-backend.git?rev=8ce1577d61af241606d30aefe9ac896e7502809a#8ce1577d61af241606d30aefe9ac896e7502809a" dependencies = [ "either", "ff_ext", @@ -4552,7 +4552,7 @@ dependencies = [ [[package]] name = "p3" version = "0.1.0" -source = "git+https://github.com/scroll-tech/gkr-backend.git?tag=v1.0.0-alpha.20#4f5a2b99af8edc9e8d808a6ab7c73e4088d61e39" +source = "git+https://github.com/scroll-tech/gkr-backend.git?rev=8ce1577d61af241606d30aefe9ac896e7502809a#8ce1577d61af241606d30aefe9ac896e7502809a" dependencies = [ "p3-air", "p3-baby-bear", @@ -5120,7 +5120,7 @@ dependencies = [ [[package]] name = "poseidon" version = "0.1.0" -source = "git+https://github.com/scroll-tech/gkr-backend.git?tag=v1.0.0-alpha.20#4f5a2b99af8edc9e8d808a6ab7c73e4088d61e39" +source = "git+https://github.com/scroll-tech/gkr-backend.git?rev=8ce1577d61af241606d30aefe9ac896e7502809a#8ce1577d61af241606d30aefe9ac896e7502809a" dependencies = [ "ff_ext", "p3", @@ -6077,7 +6077,7 @@ dependencies = [ [[package]] name = "sp1-curves" version = "0.1.0" -source = "git+https://github.com/scroll-tech/gkr-backend.git?tag=v1.0.0-alpha.20#4f5a2b99af8edc9e8d808a6ab7c73e4088d61e39" +source = "git+https://github.com/scroll-tech/gkr-backend.git?rev=8ce1577d61af241606d30aefe9ac896e7502809a#8ce1577d61af241606d30aefe9ac896e7502809a" dependencies = [ "cfg-if", "dashu", @@ -6202,7 +6202,7 @@ checksum = "13c2bddecc57b384dee18652358fb23172facb8a2c51ccc10d74c157bdea3292" [[package]] name = "sumcheck" version = "0.1.0" -source = "git+https://github.com/scroll-tech/gkr-backend.git?tag=v1.0.0-alpha.20#4f5a2b99af8edc9e8d808a6ab7c73e4088d61e39" +source = "git+https://github.com/scroll-tech/gkr-backend.git?rev=8ce1577d61af241606d30aefe9ac896e7502809a#8ce1577d61af241606d30aefe9ac896e7502809a" dependencies = [ "either", "ff_ext", @@ -6220,7 +6220,7 @@ dependencies = [ [[package]] name = "sumcheck_macro" version = "0.1.0" -source = "git+https://github.com/scroll-tech/gkr-backend.git?tag=v1.0.0-alpha.20#4f5a2b99af8edc9e8d808a6ab7c73e4088d61e39" +source = "git+https://github.com/scroll-tech/gkr-backend.git?rev=8ce1577d61af241606d30aefe9ac896e7502809a#8ce1577d61af241606d30aefe9ac896e7502809a" dependencies = [ "itertools 0.13.0", "p3", @@ -6627,7 +6627,7 @@ dependencies = [ [[package]] name = "transcript" version = "0.1.0" -source = "git+https://github.com/scroll-tech/gkr-backend.git?tag=v1.0.0-alpha.20#4f5a2b99af8edc9e8d808a6ab7c73e4088d61e39" +source = "git+https://github.com/scroll-tech/gkr-backend.git?rev=8ce1577d61af241606d30aefe9ac896e7502809a#8ce1577d61af241606d30aefe9ac896e7502809a" dependencies = [ "ff_ext", "itertools 0.13.0", @@ -6921,7 +6921,7 @@ dependencies = [ [[package]] name = "whir" version = "0.1.0" -source = "git+https://github.com/scroll-tech/gkr-backend.git?tag=v1.0.0-alpha.20#4f5a2b99af8edc9e8d808a6ab7c73e4088d61e39" +source = "git+https://github.com/scroll-tech/gkr-backend.git?rev=8ce1577d61af241606d30aefe9ac896e7502809a#8ce1577d61af241606d30aefe9ac896e7502809a" dependencies = [ "bincode 1.3.3", "clap", @@ -7208,7 +7208,7 @@ dependencies = [ [[package]] name = "witness" version = "0.1.0" -source = "git+https://github.com/scroll-tech/gkr-backend.git?tag=v1.0.0-alpha.20#4f5a2b99af8edc9e8d808a6ab7c73e4088d61e39" +source = "git+https://github.com/scroll-tech/gkr-backend.git?rev=8ce1577d61af241606d30aefe9ac896e7502809a#8ce1577d61af241606d30aefe9ac896e7502809a" dependencies = [ "ff_ext", "multilinear_extensions", diff --git a/Cargo.toml b/Cargo.toml index 0a75e8c40..7e7049d89 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -27,16 +27,16 @@ version = "0.1.0" ceno_crypto_primitives = { git = "https://github.com/scroll-tech/ceno-patch.git", package = "ceno_crypto_primitives", branch = "main" } ceno_syscall = { git = "https://github.com/scroll-tech/ceno-patch.git", package = "ceno_syscall", branch = "main" } -ff_ext = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "ff_ext", tag = "v1.0.0-alpha.20" } -mpcs = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "mpcs", tag = "v1.0.0-alpha.20" } -multilinear_extensions = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "multilinear_extensions", tag = "v1.0.0-alpha.20" } -p3 = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "p3", tag = "v1.0.0-alpha.20" } -poseidon = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "poseidon", tag = "v1.0.0-alpha.20" } -sp1-curves = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "sp1-curves", tag = "v1.0.0-alpha.20" } -sumcheck = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "sumcheck", tag = "v1.0.0-alpha.20" } -transcript = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "transcript", tag = "v1.0.0-alpha.20" } -whir = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "whir", tag = "v1.0.0-alpha.20" } -witness = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "witness", tag = "v1.0.0-alpha.20" } +ff_ext = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "ff_ext", rev = "8ce1577d61af241606d30aefe9ac896e7502809a" } +mpcs = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "mpcs", rev = "8ce1577d61af241606d30aefe9ac896e7502809a" } +multilinear_extensions = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "multilinear_extensions", rev = "8ce1577d61af241606d30aefe9ac896e7502809a" } +p3 = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "p3", rev = "8ce1577d61af241606d30aefe9ac896e7502809a" } +poseidon = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "poseidon", rev = "8ce1577d61af241606d30aefe9ac896e7502809a" } +sp1-curves = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "sp1-curves", rev = "8ce1577d61af241606d30aefe9ac896e7502809a" } +sumcheck = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "sumcheck", rev = "8ce1577d61af241606d30aefe9ac896e7502809a" } +transcript = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "transcript", rev = "8ce1577d61af241606d30aefe9ac896e7502809a" } +whir = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "whir", rev = "8ce1577d61af241606d30aefe9ac896e7502809a" } +witness = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "witness", rev = "8ce1577d61af241606d30aefe9ac896e7502809a" } anyhow = { version = "1.0", default-features = false } bincode = "1" diff --git a/ceno_recursion/src/arithmetics/mod.rs b/ceno_recursion/src/arithmetics/mod.rs index 4d78d0315..7c5d97e81 100644 --- a/ceno_recursion/src/arithmetics/mod.rs +++ b/ceno_recursion/src/arithmetics/mod.rs @@ -866,35 +866,51 @@ impl UniPolyExtrapolator { } pub fn extrapolate_uni_poly( - &mut self, + &self, builder: &mut Builder, + p_0: Ext, p_i: &Array>, eval_at: Ext, ) -> Ext { let res: Ext = builder.constant(C::EF::ZERO); + let length: RVar<_> = builder.eval_expr(p_i.len() + Usize::from(1)); - builder.if_eq(p_i.len(), Usize::from(4)).then_or_else( + builder.if_eq(length, Usize::from(4)).then_or_else( |builder| { - let ext = self.extrapolate_uni_poly_deg_3(builder, p_i, eval_at); + let p_i_1: Ext = builder.get(p_i, 0); + let p_i_2: Ext = builder.get(p_i, 1); + let p_i_3: Ext = builder.get(p_i, 2); + let ext = + self.extrapolate_uni_poly_deg_3(builder, p_0, p_i_1, p_i_2, p_i_3, eval_at); builder.assign(&res, ext); }, |builder| { - builder.if_eq(p_i.len(), Usize::from(3)).then_or_else( + builder.if_eq(length, Usize::from(3)).then_or_else( |builder| { - let ext = self.extrapolate_uni_poly_deg_2(builder, p_i, eval_at); + let p_i_1: Ext = builder.get(p_i, 0); + let p_i_2: Ext = builder.get(p_i, 1); + let ext = + self.extrapolate_uni_poly_deg_2(builder, p_0, p_i_1, p_i_2, eval_at); builder.assign(&res, ext); }, |builder| { - builder.if_eq(p_i.len(), Usize::from(2)).then_or_else( + builder.if_eq(length, Usize::from(2)).then_or_else( |builder| { - let ext = self.extrapolate_uni_poly_deg_1(builder, p_i, eval_at); + let p_i_1: Ext = builder.get(p_i, 0); + let ext = + self.extrapolate_uni_poly_deg_1(builder, p_0, p_i_1, eval_at); builder.assign(&res, ext); }, |builder| { - builder.if_eq(p_i.len(), Usize::from(5)).then_or_else( + builder.if_eq(length, Usize::from(5)).then_or_else( |builder| { - let ext = - self.extrapolate_uni_poly_deg_4(builder, p_i, eval_at); + let p_i_1: Ext = builder.get(p_i, 0); + let p_i_2: Ext = builder.get(p_i, 1); + let p_i_3: Ext = builder.get(p_i, 2); + let p_i_4: Ext = builder.get(p_i, 3); + let ext = self.extrapolate_uni_poly_deg_4( + builder, p_0, p_i_1, p_i_2, p_i_3, p_i_4, eval_at, + ); builder.assign(&res, ext); }, |builder| { @@ -914,7 +930,8 @@ impl UniPolyExtrapolator { fn extrapolate_uni_poly_deg_1( &self, builder: &mut Builder, - p_i: &Array>, + p_i_0: Ext, + p_i_1: Ext, eval_at: Ext, ) -> Ext { // w0 = 1 / (0−1) = -1 @@ -923,9 +940,6 @@ impl UniPolyExtrapolator { let d1: Ext = builder.eval(eval_at - self.constants[1]); let l: Ext = builder.eval(d0 * d1); - let p_i_0 = builder.get(p_i, 0); - let p_i_1 = builder.get(p_i, 1); - let t0: Ext = builder.eval(self.constants[5] * p_i_0 * d0.inverse()); let t1: Ext = builder.eval(self.constants[1] * p_i_1 * d1.inverse()); @@ -935,7 +949,9 @@ impl UniPolyExtrapolator { fn extrapolate_uni_poly_deg_2( &self, builder: &mut Builder, - p_i: &Array>, + p_i_0: Ext, + p_i_1: Ext, + p_i_2: Ext, eval_at: Ext, ) -> Ext { // w0 = 1 / ((0−1)(0−2)) = 1/2 @@ -947,10 +963,6 @@ impl UniPolyExtrapolator { let l: Ext = builder.eval(d0 * d1 * d2); - let p_i_0: Ext = builder.get(p_i, 0); - let p_i_1: Ext = builder.get(p_i, 1); - let p_i_2: Ext = builder.get(p_i, 2); - let t0: Ext = builder.eval(self.constants[6] * p_i_0 * d0.inverse()); let t1: Ext = builder.eval(self.constants[5] * p_i_1 * d1.inverse()); let t2: Ext = builder.eval(self.constants[6] * p_i_2 * d2.inverse()); @@ -961,7 +973,10 @@ impl UniPolyExtrapolator { fn extrapolate_uni_poly_deg_3( &self, builder: &mut Builder, - p_i: &Array>, + p_i_0: Ext, + p_i_1: Ext, + p_i_2: Ext, + p_i_3: Ext, eval_at: Ext, ) -> Ext { // w0 = 1 / ((0−1)(0−2)(0−3)) = -1/6 @@ -975,11 +990,6 @@ impl UniPolyExtrapolator { let l: Ext = builder.eval(d0 * d1 * d2 * d3); - let p_i_0: Ext = builder.get(p_i, 0); - let p_i_1: Ext = builder.get(p_i, 1); - let p_i_2: Ext = builder.get(p_i, 2); - let p_i_3: Ext = builder.get(p_i, 3); - let t0: Ext = builder.eval(self.constants[9] * p_i_0 * d0.inverse()); let t1: Ext = builder.eval(self.constants[6] * p_i_1 * d1.inverse()); let t2: Ext = builder.eval(self.constants[7] * p_i_2 * d2.inverse()); @@ -991,7 +1001,11 @@ impl UniPolyExtrapolator { fn extrapolate_uni_poly_deg_4( &self, builder: &mut Builder, - p_i: &Array>, + p_i_0: Ext, + p_i_1: Ext, + p_i_2: Ext, + p_i_3: Ext, + p_i_4: Ext, eval_at: Ext, ) -> Ext { // w0 = 1 / ((0−1)(0−2)(0−3)(0−4)) = 1/24 @@ -1007,12 +1021,6 @@ impl UniPolyExtrapolator { let l: Ext = builder.eval(d0 * d1 * d2 * d3 * d4); - let p_i_0: Ext = builder.get(p_i, 0); - let p_i_1: Ext = builder.get(p_i, 1); - let p_i_2: Ext = builder.get(p_i, 2); - let p_i_3: Ext = builder.get(p_i, 3); - let p_i_4: Ext = builder.get(p_i, 4); - let t0: Ext = builder.eval(self.constants[11] * p_i_0 * d0.inverse()); let t1: Ext = builder.eval(self.constants[9] * p_i_1 * d1.inverse()); let t2: Ext = builder.eval(self.constants[10] * p_i_2 * d2.inverse()); diff --git a/ceno_recursion/src/basefold_verifier/query_phase.rs b/ceno_recursion/src/basefold_verifier/query_phase.rs index 6bf40113f..7499207d0 100644 --- a/ceno_recursion/src/basefold_verifier/query_phase.rs +++ b/ceno_recursion/src/basefold_verifier/query_phase.rs @@ -14,15 +14,13 @@ use p3::{ use serde::Deserialize; use super::{basefold::*, extension_mmcs::*, mmcs::*, rs::*, utils::*}; -use crate::{ - arithmetics::eq_eval_with_index, - tower_verifier::{binding::*, program::interpolate_uni_poly}, -}; +use crate::{arithmetics::eq_eval_with_index, tower_verifier::binding::*}; pub type F = BabyBear; pub type E = BabyBearExt4; pub type InnerConfig = AsmConfig; +use crate::arithmetics::UniPolyExtrapolator; use p3::fri::{ BatchOpening as InnerBatchOpening, CommitPhaseProofStep as InnerCommitPhaseProofStep, }; @@ -318,6 +316,7 @@ pub struct RoundContextVariable { pub(crate) fn batch_verifier_query_phase( builder: &mut Builder, input: QueryPhaseVerifierInputVariable, + unipoly_extrapolator: &UniPolyExtrapolator, ) { let inv_2 = builder.constant(C::F::from_canonical_u32(0x3c000001)); let two_adic_generators_inverses: Array> = builder.dyn_array(28); @@ -667,7 +666,7 @@ pub(crate) fn batch_verifier_query_phase( ); // 1. check initial claim match with first round sumcheck value let batch_coeffs_offset: Var = builder.constant(C::N::ZERO); - let expected_sum: Ext = builder.constant(C::EF::ZERO); + let expected_claim: Ext = builder.constant(C::EF::ZERO); iter_zip!(builder, input.rounds).for_each(|ptr_vec, builder| { let round = builder.iter_ptr_get(&input.rounds, ptr_vec[0]); iter_zip!(builder, round.openings).for_each(|ptr_vec, builder| { @@ -680,45 +679,27 @@ pub(crate) fn batch_verifier_query_phase( let eval = builder.iter_ptr_get(&opening.point_and_evals.evals, ptr_vec[0]); let coeff = builder.get(&input.batch_coeffs, batch_coeffs_offset); let val: Ext = builder.eval(eval * coeff * scalar); - builder.assign(&expected_sum, expected_sum + val); + builder.assign(&expected_claim, expected_claim + val); builder.assign(&batch_coeffs_offset, batch_coeffs_offset + Usize::from(1)); }); }); }); - let sum: Ext = { - let sumcheck_evals = builder.get(&input.proof.sumcheck_proof, 0).evaluations; - let eval0 = builder.get(&sumcheck_evals, 0); - let eval1 = builder.get(&sumcheck_evals, 1); - builder.eval(eval0 + eval1) - }; - builder.assert_eq::>(expected_sum, sum); - - // 2. check every round of sumcheck match with prev claims - let fold_len_minus_one: Var = builder.eval(input.fold_challenges.len() - Usize::from(1)); + // check every round of sumcheck match with prev claims builder - .range(0, fold_len_minus_one) + .range(0, input.fold_challenges.len()) .for_each(|i_vec, builder| { let i = i_vec[0]; let evals = builder.get(&input.proof.sumcheck_proof, i).evaluations; + let eval1 = builder.get(&evals, 0); + let eval0 = builder.eval(expected_claim - eval1); let challenge = builder.get(&input.fold_challenges, i); - let left = interpolate_uni_poly(builder, &evals, challenge); - let i_plus_one = builder.eval_expr(i + Usize::from(1)); - let next_evals = builder - .get(&input.proof.sumcheck_proof, i_plus_one) - .evaluations; - let eval0 = builder.get(&next_evals, 0); - let eval1 = builder.get(&next_evals, 1); - let right: Ext = builder.eval(eval0 + eval1); - builder.assert_eq::>(left, right); + let next_claim = + unipoly_extrapolator.extrapolate_uni_poly(builder, eval0, &evals, challenge); + builder.assign(&expected_claim, next_claim); }); - // 3. check final evaluation are correct - let final_evals = builder - .get(&input.proof.sumcheck_proof, fold_len_minus_one) - .evaluations; - let final_challenge = builder.get(&input.fold_challenges, fold_len_minus_one); - let left = interpolate_uni_poly(builder, &final_evals, final_challenge); - let right: Ext = builder.constant(C::EF::ZERO); + // check final evaluation are correct + let eval_claims: Ext = builder.constant(C::EF::ZERO); let one: Var = builder.constant(C::N::ONE); let j: Var = builder.constant(C::N::ZERO); // \sum_i eq(p, [r,i]) * f(r,i) @@ -752,13 +733,13 @@ pub(crate) fn batch_verifier_query_phase( builder.assert_eq::>(final_message.len(), one); let final_message = builder.get(&final_message, 0); let dot_prod: Ext = builder.eval(final_message * coeff); - builder.assign(&right, right + dot_prod); + builder.assign(&eval_claims, eval_claims + dot_prod); builder.assign(&j, j + Usize::from(1)); }); }); builder.assert_eq::>(j, input.proof.final_message.len()); - builder.assert_eq::>(left, right); + builder.assert_eq::>(expected_claim, eval_claims); } #[cfg(test)] @@ -792,7 +773,9 @@ pub mod tests { type E = BabyBearExt4; type Pcs = BasefoldDefault; + use super::{QueryPhaseVerifierInput, batch_verifier_query_phase}; use crate::{ + arithmetics::UniPolyExtrapolator, basefold_verifier::{ basefold::{Round, RoundOpening}, mmcs::MmcsCommitment, @@ -801,15 +784,14 @@ pub mod tests { tower_verifier::binding::Point, }; - use super::{QueryPhaseVerifierInput, batch_verifier_query_phase}; - pub fn build_batch_verifier_query_phase( input: QueryPhaseVerifierInput, ) -> (Program, Vec>) { // build test program let mut builder = AsmBuilder::::default(); + let unipoly_extrapolator = UniPolyExtrapolator::new(&mut builder); let query_phase_input = QueryPhaseVerifierInput::read(&mut builder); - batch_verifier_query_phase(&mut builder, query_phase_input); + batch_verifier_query_phase(&mut builder, query_phase_input, &unipoly_extrapolator); builder.halt(); let program = builder.compile_isa(); diff --git a/ceno_recursion/src/basefold_verifier/verifier.rs b/ceno_recursion/src/basefold_verifier/verifier.rs index c343660d0..7b6ad8d2f 100644 --- a/ceno_recursion/src/basefold_verifier/verifier.rs +++ b/ceno_recursion/src/basefold_verifier/verifier.rs @@ -4,6 +4,7 @@ use crate::{ }; use super::{basefold::*, rs::*, utils::*}; +use crate::arithmetics::UniPolyExtrapolator; use ff_ext::BabyBearExt4; use openvm_native_compiler::{asm::AsmConfig, ir::FromConstant, prelude::*}; use openvm_native_compiler_derive::iter_zip; @@ -24,6 +25,7 @@ pub fn batch_verify( max_width: Var, rounds: Array>, proof: BasefoldProofVariable, + unipoly_extrapolator: &UniPolyExtrapolator, challenger: &mut DuplexChallengerVariable, ) { builder.cycle_tracker_start("prior query phase"); @@ -153,7 +155,7 @@ pub fn batch_verify( }; builder.cycle_tracker_end("prior query phase"); builder.cycle_tracker_start("query phase"); - batch_verifier_query_phase(builder, input); + batch_verifier_query_phase(builder, input, unipoly_extrapolator); builder.cycle_tracker_end("query phase"); } @@ -185,6 +187,7 @@ pub mod tests { use super::{BasefoldProof, BasefoldProofVariable, InnerConfig, RoundVariable, batch_verify}; use crate::{ + arithmetics::UniPolyExtrapolator, basefold_verifier::{ basefold::{BasefoldCommitment, Round, RoundOpening}, query_phase::{BatchOpening, CommitPhaseProofStep, PointAndEvals, QueryOpeningProof}, @@ -244,12 +247,14 @@ pub mod tests { let mut challenger = DuplexChallengerVariable::new(&mut builder); let verifier_input = VerifierInput::read(&mut builder); builder.cycle_tracker_end("Prepare data"); + let unipoly_extrapolator = UniPolyExtrapolator::new(&mut builder); batch_verify( &mut builder, verifier_input.max_num_var, verifier_input.max_width, verifier_input.rounds, verifier_input.proof, + &unipoly_extrapolator, &mut challenger, ); builder.halt(); diff --git a/ceno_recursion/src/tower_verifier/program.rs b/ceno_recursion/src/tower_verifier/program.rs index 56a400008..2bf1e412d 100644 --- a/ceno_recursion/src/tower_verifier/program.rs +++ b/ceno_recursion/src/tower_verifier/program.rs @@ -2,7 +2,7 @@ use super::binding::{PointAndEvalVariable, PointVariable}; use crate::{ arithmetics::{ UniPolyExtrapolator, challenger_multi_observe, eq_eval, evaluate_at_point_degree_1, extend, - exts_to_felts, reverse, + exts_to_felts, }, tower_verifier::binding::IOPProverMessageVecVariable, transcript::transcript_observe_label, @@ -16,77 +16,6 @@ use openvm_native_recursion::challenger::{ use openvm_stark_backend::p3_field::FieldAlgebra; const NATIVE_SUMCHECK_CTX_LEN: usize = 9; -pub(crate) fn interpolate_uni_poly( - builder: &mut Builder, - p_i: &Array>, - eval_at: Ext, -) -> Ext { - let len = p_i.len(); - let evals: Array> = builder.dyn_array(len.clone()); - let prod: Ext = builder.eval(eval_at); - - builder.set(&evals, 0, eval_at); - - // `prod = \prod_{j} (eval_at - j)` - let e: Ext = builder.constant(C::EF::ONE); - let one: Ext = builder.constant(C::EF::ONE); - builder.range(1, len.clone()).for_each(|i_vec, builder| { - let i = i_vec[0]; - let tmp: Ext = builder.constant(C::EF::ONE); - builder.assign(&tmp, eval_at - e); - builder.set(&evals, i, tmp); - builder.assign(&prod, prod * tmp); - builder.assign(&e, e + one); - }); - - let denom_up: Ext = builder.constant(C::EF::ONE); - let i: Ext = builder.constant(C::EF::ONE); - builder.assign(&i, i + one); - builder.range(2, len.clone()).for_each(|_i_vec, builder| { - builder.assign(&denom_up, denom_up * i); - builder.assign(&i, i + one); - }); - let denom_down: Ext = builder.constant(C::EF::ONE); - - let idx_vec_len: RVar = builder.eval_expr(len.clone() - RVar::from(1)); - let idx_vec: Array> = builder.dyn_array(idx_vec_len); - let idx_val: Ext = builder.constant(C::EF::ONE); - builder.range(0, idx_vec.len()).for_each(|i_vec, builder| { - builder.set(&idx_vec, i_vec[0], idx_val); - builder.assign(&idx_val, idx_val + one); - }); - let idx_rev = reverse(builder, &idx_vec); - let res = builder.constant(C::EF::ZERO); - - let len_f = idx_val; - let neg_one: Ext = builder.constant(C::EF::NEG_ONE); - let evals_rev = reverse(builder, &evals); - let p_i_rev = reverse(builder, p_i); - - let mut idx_pos: RVar = builder.eval_expr(len.clone() - RVar::from(1)); - iter_zip!(builder, idx_rev, evals_rev, p_i_rev).for_each(|ptr_vec, builder| { - let idx = builder.iter_ptr_get(&idx_rev, ptr_vec[0]); - let eval = builder.iter_ptr_get(&evals_rev, ptr_vec[1]); - let up_eval_inv: Ext = builder.eval(denom_up * eval); - builder.assign(&up_eval_inv, up_eval_inv.inverse()); - let p = builder.iter_ptr_get(&p_i_rev, ptr_vec[2]); - - builder.assign(&res, res + p * prod * denom_down * up_eval_inv); - builder.assign(&denom_up, denom_up * (len_f - idx) * neg_one); - builder.assign(&denom_down, denom_down * idx); - - idx_pos = builder.eval_expr(idx_pos - RVar::from(1)); - }); - - let p_i_0 = builder.get(p_i, 0); - let eval_0 = builder.get(&evals, 0); - let up_eval_inv: Ext = builder.eval(denom_up * eval_0); - builder.assign(&up_eval_inv, up_eval_inv.inverse()); - builder.assign(&res, res + p_i_0 * prod * denom_down * up_eval_inv); - - res -} - pub fn iop_verifier_state_verify( builder: &mut Builder, challenger: &mut DuplexChallengerVariable, @@ -94,7 +23,7 @@ pub fn iop_verifier_state_verify( prover_messages: &IOPProverMessageVecVariable, max_num_variables: Felt, max_degree: Felt, - unipoly_extrapolator: &mut UniPolyExtrapolator, + unipoly_extrapolator: &UniPolyExtrapolator, ) -> ( Array::F, ::EF>>, Ext<::F, ::EF>, @@ -133,12 +62,9 @@ pub fn iop_verifier_state_verify( let challenge = challenger.sample_ext(builder); let e1 = builder.get(&prover_msg, 0); - let e2 = builder.get(&prover_msg, 1); - let target: Ext<::F, ::EF> = builder.eval(e1 + e2); - - builder.assert_ext_eq(expected, target); - - let p_r = unipoly_extrapolator.extrapolate_uni_poly(builder, &prover_msg, challenge); + let e0 = builder.eval(expected - e1); + let p_r = + unipoly_extrapolator.extrapolate_uni_poly(builder, e0, &prover_msg, challenge); builder.assign(&expected, p_r + zero); builder.set_value(&challenges, i, challenge); @@ -160,7 +86,7 @@ pub fn verify_tower_proof( max_num_variables: Usize, proof: &TowerProofInputVariable, - unipoly_extrapolator: &mut UniPolyExtrapolator, + unipoly_extrapolator: &UniPolyExtrapolator, ) -> ( PointVariable, Array>, diff --git a/ceno_recursion/src/zkvm_verifier/verifier.rs b/ceno_recursion/src/zkvm_verifier/verifier.rs index 6883cdd45..7fe8d7eec 100644 --- a/ceno_recursion/src/zkvm_verifier/verifier.rs +++ b/ceno_recursion/src/zkvm_verifier/verifier.rs @@ -217,7 +217,7 @@ pub fn verify_zkvm_proof>( .filter(|c| c.get_cs().num_fixed() > 0) .count(); - let mut unipoly_extrapolator = UniPolyExtrapolator::new(builder); + let unipoly_extrapolator = UniPolyExtrapolator::new(builder); let mut poly_evaluator = PolyEvaluator::new(builder); let dummy_table_item = alpha; @@ -351,7 +351,7 @@ pub fn verify_zkvm_proof>( &zkvm_proof_input.raw_pi_num_variables, &challenges, chip_vk, - &mut unipoly_extrapolator, + &unipoly_extrapolator, &mut poly_evaluator, ); builder.cycle_tracker_end("Verify chip proof"); @@ -483,6 +483,7 @@ pub fn verify_zkvm_proof>( zkvm_proof_input.max_width, rounds, zkvm_proof_input.pcs_proof, + &unipoly_extrapolator, &mut challenger, ); @@ -529,7 +530,7 @@ pub fn verify_chip_proof( raw_pi_num_variables: &Array>, challenges: &Array>, vk: &VerifyingKey, - unipoly_extrapolator: &mut UniPolyExtrapolator, + unipoly_extrapolator: &UniPolyExtrapolator, poly_evaluator: &mut PolyEvaluator, ) -> (Array>, SepticPointVariable) { let composed_cs = vk.get_cs(); @@ -785,7 +786,7 @@ pub fn verify_gkr_circuit( claims: &Array>, _chip_proof: &ZKVMChipProofInputVariable, selector_ctxs: Vec>, - unipoly_extrapolator: &mut UniPolyExtrapolator, + unipoly_extrapolator: &UniPolyExtrapolator, poly_evaluator: &mut PolyEvaluator, ) -> PointVariable { let rt = PointVariable { @@ -1168,7 +1169,7 @@ pub fn verify_rotation( rotation_cyclic_group_log2: usize, rt: Array>, challenges: &Array>, - unipoly_extrapolator: &mut UniPolyExtrapolator, + unipoly_extrapolator: &UniPolyExtrapolator, ) -> RotationClaim { builder.cycle_tracker_start("Verify rotation"); let SumcheckLayerProofVariable { @@ -1712,7 +1713,7 @@ pub fn verify_ecc_proof( builder: &mut Builder, challenger: &mut DuplexChallengerVariable, proof: &EccQuarkProofVariable, - unipoly_extrapolator: &mut UniPolyExtrapolator, + unipoly_extrapolator: &UniPolyExtrapolator, ) { let num_vars = proof.num_vars.clone(); diff --git a/ceno_zkvm/src/scheme/cpu/mod.rs b/ceno_zkvm/src/scheme/cpu/mod.rs index 4007f7fd7..8f16bbe9a 100644 --- a/ceno_zkvm/src/scheme/cpu/mod.rs +++ b/ceno_zkvm/src/scheme/cpu/mod.rs @@ -295,7 +295,6 @@ impl CpuEccProver { } } let final_sum = SepticPoint::from_affine(final_sum_x, final_sum_y); - assert_eq!(zerocheck_proof.extract_sum(), E::ZERO); EccQuarkProof { zerocheck_proof, diff --git a/ceno_zkvm/src/scheme/gpu/mod.rs b/ceno_zkvm/src/scheme/gpu/mod.rs index a8731a14c..5fb3380fc 100644 --- a/ceno_zkvm/src/scheme/gpu/mod.rs +++ b/ceno_zkvm/src/scheme/gpu/mod.rs @@ -958,7 +958,6 @@ impl> EccQuarkProver