diff --git a/Cargo.lock b/Cargo.lock index ebb5e43..a09d24d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -526,7 +526,7 @@ dependencies = [ [[package]] name = "ceno-examples" version = "0.1.0" -source = "git+https://github.com/scroll-tech/ceno.git?branch=build%2Fsmaller_field_support_plonky3_539bbc#1acd5edfd68727e431aad0daead3a1ef80c918f2" +source = "git+https://github.com/scroll-tech/ceno.git?branch=feat%2Fsmaller_field_support#6b24e309c2050b2adbe62ed5702147e36bc965a0" dependencies = [ "glob", ] @@ -543,6 +543,7 @@ dependencies = [ "ceno_emul", "ceno_zkvm", "ff_ext", + "gkr_iop", "itertools 0.13.0", "mpcs", "multilinear_extensions", @@ -577,7 +578,7 @@ dependencies = [ [[package]] name = "ceno_emul" version = "0.1.0" -source = "git+https://github.com/scroll-tech/ceno.git?branch=build%2Fsmaller_field_support_plonky3_539bbc#1acd5edfd68727e431aad0daead3a1ef80c918f2" +source = "git+https://github.com/scroll-tech/ceno.git?branch=feat%2Fsmaller_field_support#6b24e309c2050b2adbe62ed5702147e36bc965a0" dependencies = [ "anyhow", "ceno_rt", @@ -600,7 +601,7 @@ dependencies = [ [[package]] name = "ceno_host" version = "0.1.0" -source = "git+https://github.com/scroll-tech/ceno.git?branch=build%2Fsmaller_field_support_plonky3_539bbc#1acd5edfd68727e431aad0daead3a1ef80c918f2" +source = "git+https://github.com/scroll-tech/ceno.git?branch=feat%2Fsmaller_field_support#6b24e309c2050b2adbe62ed5702147e36bc965a0" dependencies = [ "anyhow", "ceno_emul", @@ -613,7 +614,7 @@ dependencies = [ [[package]] name = "ceno_rt" version = "0.1.0" -source = "git+https://github.com/scroll-tech/ceno.git?branch=build%2Fsmaller_field_support_plonky3_539bbc#1acd5edfd68727e431aad0daead3a1ef80c918f2" +source = "git+https://github.com/scroll-tech/ceno.git?branch=feat%2Fsmaller_field_support#6b24e309c2050b2adbe62ed5702147e36bc965a0" dependencies = [ "getrandom 0.2.16", "rkyv", @@ -622,7 +623,7 @@ dependencies = [ [[package]] name = "ceno_zkvm" version = "0.1.0" -source = "git+https://github.com/scroll-tech/ceno.git?branch=build%2Fsmaller_field_support_plonky3_539bbc#1acd5edfd68727e431aad0daead3a1ef80c918f2" +source = "git+https://github.com/scroll-tech/ceno.git?branch=feat%2Fsmaller_field_support#6b24e309c2050b2adbe62ed5702147e36bc965a0" dependencies = [ "base64", "bincode", @@ -1141,7 +1142,7 @@ dependencies = [ [[package]] name = "ff_ext" version = "0.1.0" -source = "git+https://github.com/scroll-tech/ceno.git?branch=build%2Fsmaller_field_support_plonky3_539bbc#1acd5edfd68727e431aad0daead3a1ef80c918f2" +source = "git+https://github.com/scroll-tech/ceno.git?branch=feat%2Fsmaller_field_support#6b24e309c2050b2adbe62ed5702147e36bc965a0" dependencies = [ "once_cell", "p3", @@ -1231,7 +1232,7 @@ checksum = "07e28edb80900c19c28f1072f2e8aeca7fa06b23cd4169cefe1af5aa3260783f" [[package]] name = "gkr_iop" version = "0.1.0" -source = "git+https://github.com/scroll-tech/ceno.git?branch=build%2Fsmaller_field_support_plonky3_539bbc#1acd5edfd68727e431aad0daead3a1ef80c918f2" +source = "git+https://github.com/scroll-tech/ceno.git?branch=feat%2Fsmaller_field_support#6b24e309c2050b2adbe62ed5702147e36bc965a0" dependencies = [ "ark-std 0.5.0", "bincode", @@ -1718,7 +1719,7 @@ dependencies = [ [[package]] name = "mpcs" version = "0.1.0" -source = "git+https://github.com/scroll-tech/ceno.git?branch=build%2Fsmaller_field_support_plonky3_539bbc#1acd5edfd68727e431aad0daead3a1ef80c918f2" +source = "git+https://github.com/scroll-tech/ceno.git?branch=feat%2Fsmaller_field_support#6b24e309c2050b2adbe62ed5702147e36bc965a0" dependencies = [ "aes", "bincode", @@ -1749,7 +1750,7 @@ dependencies = [ [[package]] name = "multilinear_extensions" version = "0.1.0" -source = "git+https://github.com/scroll-tech/ceno.git?branch=build%2Fsmaller_field_support_plonky3_539bbc#1acd5edfd68727e431aad0daead3a1ef80c918f2" +source = "git+https://github.com/scroll-tech/ceno.git?branch=feat%2Fsmaller_field_support#6b24e309c2050b2adbe62ed5702147e36bc965a0" dependencies = [ "either", "ff_ext", @@ -2311,7 +2312,7 @@ checksum = "b15813163c1d831bf4a13c3610c05c0d03b39feb07f7e09fa234dac9b15aaf39" [[package]] name = "p3" version = "0.1.0" -source = "git+https://github.com/scroll-tech/ceno.git?branch=build%2Fsmaller_field_support_plonky3_539bbc#1acd5edfd68727e431aad0daead3a1ef80c918f2" +source = "git+https://github.com/scroll-tech/ceno.git?branch=feat%2Fsmaller_field_support#6b24e309c2050b2adbe62ed5702147e36bc965a0" dependencies = [ "p3-baby-bear", "p3-challenger", @@ -2769,7 +2770,7 @@ dependencies = [ [[package]] name = "poseidon" version = "0.1.0" -source = "git+https://github.com/scroll-tech/ceno.git?branch=build%2Fsmaller_field_support_plonky3_539bbc#1acd5edfd68727e431aad0daead3a1ef80c918f2" +source = "git+https://github.com/scroll-tech/ceno.git?branch=feat%2Fsmaller_field_support#6b24e309c2050b2adbe62ed5702147e36bc965a0" dependencies = [ "criterion", "ff_ext", @@ -3378,7 +3379,7 @@ checksum = "13c2bddecc57b384dee18652358fb23172facb8a2c51ccc10d74c157bdea3292" [[package]] name = "sumcheck" version = "0.1.0" -source = "git+https://github.com/scroll-tech/ceno.git?branch=build%2Fsmaller_field_support_plonky3_539bbc#1acd5edfd68727e431aad0daead3a1ef80c918f2" +source = "git+https://github.com/scroll-tech/ceno.git?branch=feat%2Fsmaller_field_support#6b24e309c2050b2adbe62ed5702147e36bc965a0" dependencies = [ "crossbeam-channel", "either", @@ -3397,7 +3398,7 @@ dependencies = [ [[package]] name = "sumcheck_macro" version = "0.1.0" -source = "git+https://github.com/scroll-tech/ceno.git?branch=build%2Fsmaller_field_support_plonky3_539bbc#1acd5edfd68727e431aad0daead3a1ef80c918f2" +source = "git+https://github.com/scroll-tech/ceno.git?branch=feat%2Fsmaller_field_support#6b24e309c2050b2adbe62ed5702147e36bc965a0" dependencies = [ "itertools 0.13.0", "p3", @@ -3671,7 +3672,7 @@ dependencies = [ [[package]] name = "transcript" version = "0.1.0" -source = "git+https://github.com/scroll-tech/ceno.git?branch=build%2Fsmaller_field_support_plonky3_539bbc#1acd5edfd68727e431aad0daead3a1ef80c918f2" +source = "git+https://github.com/scroll-tech/ceno.git?branch=feat%2Fsmaller_field_support#6b24e309c2050b2adbe62ed5702147e36bc965a0" dependencies = [ "crossbeam-channel", "ff_ext", @@ -3849,7 +3850,7 @@ dependencies = [ [[package]] name = "whir" version = "0.1.0" -source = "git+https://github.com/scroll-tech/ceno.git?branch=build%2Fsmaller_field_support_plonky3_539bbc#1acd5edfd68727e431aad0daead3a1ef80c918f2" +source = "git+https://github.com/scroll-tech/ceno.git?branch=feat%2Fsmaller_field_support#6b24e309c2050b2adbe62ed5702147e36bc965a0" dependencies = [ "bincode", "blake2", @@ -4092,7 +4093,7 @@ dependencies = [ [[package]] name = "witness" version = "0.1.0" -source = "git+https://github.com/scroll-tech/ceno.git?branch=build%2Fsmaller_field_support_plonky3_539bbc#1acd5edfd68727e431aad0daead3a1ef80c918f2" +source = "git+https://github.com/scroll-tech/ceno.git?branch=feat%2Fsmaller_field_support#6b24e309c2050b2adbe62ed5702147e36bc965a0" dependencies = [ "ff_ext", "multilinear_extensions", diff --git a/Cargo.toml b/Cargo.toml index 5f8298b..647c29d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -38,14 +38,15 @@ ark-poly = "0.5" ark-serialize = "0.5" # Ceno -ceno_mle = { git = "https://github.com/scroll-tech/ceno.git", branch = "build/smaller_field_support_plonky3_539bbc", package = "multilinear_extensions" } -ceno_sumcheck = { git = "https://github.com/scroll-tech/ceno.git", branch = "build/smaller_field_support_plonky3_539bbc", package = "sumcheck" } -ceno_transcript = { git = "https://github.com/scroll-tech/ceno.git", branch = "build/smaller_field_support_plonky3_539bbc", package = "transcript" } -ceno_witness = { git = "https://github.com/scroll-tech/ceno.git", branch = "build/smaller_field_support_plonky3_539bbc", package = "witness" } -ceno_zkvm = { git = "https://github.com/scroll-tech/ceno.git", branch = "build/smaller_field_support_plonky3_539bbc" } -ceno_emul = { git = "https://github.com/scroll-tech/ceno.git", branch = "build/smaller_field_support_plonky3_539bbc" } -mpcs = { git = "https://github.com/scroll-tech/ceno.git", branch = "build/smaller_field_support_plonky3_539bbc" } -ff_ext = { git = "https://github.com/scroll-tech/ceno.git", branch = "build/smaller_field_support_plonky3_539bbc" } +ceno_mle = { git = "https://github.com/scroll-tech/ceno.git", branch = "feat/smaller_field_support", package = "multilinear_extensions" } +ceno_sumcheck = { git = "https://github.com/scroll-tech/ceno.git", branch = "feat/smaller_field_support", package = "sumcheck" } +ceno_transcript = { git = "https://github.com/scroll-tech/ceno.git", branch = "feat/smaller_field_support", package = "transcript" } +ceno_witness = { git = "https://github.com/scroll-tech/ceno.git", branch = "feat/smaller_field_support", package = "witness" } +ceno_zkvm = { git = "https://github.com/scroll-tech/ceno.git", branch = "feat/smaller_field_support" } +ceno_emul = { git = "https://github.com/scroll-tech/ceno.git", branch = "feat/smaller_field_support" } +gkr_iop = { git = "https://github.com/scroll-tech/ceno.git", branch = "feat/smaller_field_support" } +mpcs = { git = "https://github.com/scroll-tech/ceno.git", branch = "feat/smaller_field_support" } +ff_ext = { git = "https://github.com/scroll-tech/ceno.git", branch = "feat/smaller_field_support" } serde = { version = "1.0", features = ["derive"] } serde_json = "1.0" diff --git a/scripts/e2e_test.sh b/scripts/e2e_test.sh index 6d9c077..934bbde 100755 --- a/scripts/e2e_test.sh +++ b/scripts/e2e_test.sh @@ -9,11 +9,11 @@ if [ ! -d "$REPO_ROOT/build/ceno" ] || [ -z "$(ls -A "$REPO_ROOT/build/ceno" 2>/ fi # Enter the ceno directory -cd $REPO_ROOT/build/ceno && git checkout build/smaller_field_support_plonky3_539bbc +cd $REPO_ROOT/build/ceno && git checkout feat/smaller_field_support # Execute the ceno_zkvm e2e test RUST_LOG=info cargo run --release --package ceno_zkvm --bin e2e -- --platform=ceno \ - --hints=10 --public-io=4191 examples/target/riscv32im-ceno-zkvm-elf/release/examples/fibonacci \ + examples/target/riscv32im-ceno-zkvm-elf/release/examples/keccak_syscall \ --field=baby-bear mkdir -p $REPO_ROOT/src/e2e/encoded diff --git a/src/arithmetics/mod.rs b/src/arithmetics/mod.rs index c79494d..9457e95 100644 --- a/src/arithmetics/mod.rs +++ b/src/arithmetics/mod.rs @@ -7,30 +7,26 @@ use ff_ext::{BabyBearExt4, SmallField}; use itertools::Either; use openvm_native_compiler::prelude::*; use openvm_native_compiler_derive::iter_zip; -use openvm_native_recursion::challenger::ChallengerVariable; -use openvm_native_recursion::challenger::{ - duplex::DuplexChallengerVariable, CanObserveVariable, FeltChallenger, -}; -use p3_field::{FieldAlgebra, FieldExtensionAlgebra}; +use openvm_native_recursion::challenger::{duplex::DuplexChallengerVariable, FeltChallenger}; +use openvm_stark_backend::p3_field::{FieldAlgebra, FieldExtensionAlgebra}; type E = BabyBearExt4; -const HASH_RATE: usize = 8; const MAX_NUM_VARS: usize = 25; -pub fn print_ext_arr(builder: &mut Builder, arr: &Array>) { +pub fn _print_ext_arr(builder: &mut Builder, arr: &Array>) { iter_zip!(builder, arr).for_each(|ptr_vec, builder| { let e = builder.iter_ptr_get(arr, ptr_vec[0]); builder.print_e(e); }); } -pub fn print_felt_arr(builder: &mut Builder, arr: &Array>) { +pub fn _print_felt_arr(builder: &mut Builder, arr: &Array>) { iter_zip!(builder, arr).for_each(|ptr_vec, builder| { let f = builder.iter_ptr_get(arr, ptr_vec[0]); builder.print_f(f); }); } -pub fn print_usize_arr(builder: &mut Builder, arr: &Array>) { +pub fn _print_usize_arr(builder: &mut Builder, arr: &Array>) { iter_zip!(builder, arr).for_each(|ptr_vec, builder| { let n = builder.iter_ptr_get(arr, ptr_vec[0]); builder.print_v(n.get_var()); @@ -204,24 +200,6 @@ pub fn dot_product( acc } -pub fn dot_product_pt_n_eval( - builder: &mut Builder, - pt_and_eval: &Array>, - b: &Array>, -) -> Ext<::F, ::EF> { - let acc: Ext = builder.eval(C::F::ZERO); - - iter_zip!(builder, pt_and_eval, b).for_each(|idx_vec, builder| { - let ptr_a = idx_vec[0]; - let ptr_b = idx_vec[1]; - let v_a = builder.iter_ptr_get(&pt_and_eval, ptr_a); - let v_b = builder.iter_ptr_get(&b, ptr_b); - builder.assign(&acc, acc + v_a.eval * v_b); - }); - - acc -} - pub fn reverse>( builder: &mut Builder, arr: &Array, @@ -321,20 +299,6 @@ pub fn eq_eval_with_index( acc } -// Multiply all elements in the Array -pub fn product( - builder: &mut Builder, - arr: &Array>, -) -> Ext { - let acc = builder.constant(C::EF::ONE); - iter_zip!(builder, arr).for_each(|idx_vec, builder| { - let el = builder.iter_ptr_get(arr, idx_vec[0]); - builder.assign(&acc, acc * el); - }); - - acc -} - // Multiply all elements in a nested Array pub fn nested_product( builder: &mut Builder, @@ -353,47 +317,6 @@ pub fn nested_product( acc } -// Add all elements in the Array -pub fn sum( - builder: &mut Builder, - arr: &Array>, -) -> Ext { - let acc = builder.constant(C::EF::ZERO); - iter_zip!(builder, arr).for_each(|idx_vec, builder| { - let el = builder.iter_ptr_get(arr, idx_vec[0]); - builder.assign(&acc, acc + el); - }); - - acc -} - -// Join two arrays -pub fn join( - builder: &mut Builder, - a: &Array>, - b: &Array>, -) -> Array> { - let a_len = a.len(); - let b_len = b.len(); - let out_len = builder.eval_expr(a_len.clone() + b_len.clone()); - let out = builder.dyn_array(out_len); - - builder.range(0, a_len.clone()).for_each(|i_vec, builder| { - let i = i_vec[0]; - let a_val = builder.get(a, i); - builder.set(&out, i, a_val); - }); - - builder.range(0, b_len).for_each(|i_vec, builder| { - let b_i = i_vec[0]; - let i = builder.eval_expr(b_i + a_len.clone()); - let b_val = builder.get(b, b_i); - builder.set(&out, i, b_val); - }); - - out -} - // Generate alpha power challenges pub fn gen_alpha_pows( builder: &mut Builder, @@ -421,7 +344,6 @@ pub fn gen_alpha_pows( /// = \sum_{\mathbf{b}=0}^{max_idx} \prod_{i=0}^{n-1} (x_i y_i b_i + (1 - x_i)(1 - y_i)(1 - b_i)) pub fn eq_eval_less_or_equal_than( builder: &mut Builder, - _challenger: &mut DuplexChallengerVariable, opcode_proof: &ZKVMChipProofInputVariable, a: &Array>, b: &Array>, @@ -519,35 +441,6 @@ pub fn build_eq_x_r_vec_sequential( evals } -pub fn build_eq_x_r_vec_sequential_with_offset( - builder: &mut Builder, - r: &Array>, - offset: Usize, -) -> Array> { - // we build eq(x,r) from its evaluations - // we want to evaluate eq(x,r) over x \in {0, 1}^num_vars - // for example, with num_vars = 4, x is a binary vector of 4, then - // 0 0 0 0 -> (1-r0) * (1-r1) * (1-r2) * (1-r3) - // 1 0 0 0 -> r0 * (1-r1) * (1-r2) * (1-r3) - // 0 1 0 0 -> (1-r0) * r1 * (1-r2) * (1-r3) - // 1 1 0 0 -> r0 * r1 * (1-r2) * (1-r3) - // .... - // 1 1 1 1 -> r0 * r1 * r2 * r3 - // we will need 2^num_var evaluations - - let r_len: Var = builder.eval(r.len() - offset); - let evals_len: Felt = builder.constant(C::F::ONE); - let evals_len = builder.exp_power_of_2_v::>(evals_len, r_len); - let evals_len = builder.cast_felt_to_var(evals_len); - - let evals: Array> = builder.dyn_array(evals_len); - // _debug - // build_eq_x_r_helper_sequential_offset(r, &mut evals, E::ONE); - // unsafe { std::mem::transmute(evals) } - // FIXME: this function is not implemented yet - evals -} - pub fn ceil_log2(x: usize) -> usize { assert!(x > 0, "ceil_log2: x must be positive"); // Calculate the number of bits in usize diff --git a/src/basefold_verifier/verifier.rs b/src/basefold_verifier/verifier.rs index 4f643fc..e12184a 100644 --- a/src/basefold_verifier/verifier.rs +++ b/src/basefold_verifier/verifier.rs @@ -398,7 +398,6 @@ pub mod tests { let executor = VmExecutor::::new(config); executor.execute(program.clone(), witness.clone()).unwrap(); - // _debug let results = executor.execute_segments(program, witness).unwrap(); for seg in results { println!("=> cycle count: {:?}", seg.metrics.cycle_count); diff --git a/src/e2e/mod.rs b/src/e2e/mod.rs index cc9f8d0..203f293 100644 --- a/src/e2e/mod.rs +++ b/src/e2e/mod.rs @@ -1,22 +1,52 @@ use crate::basefold_verifier::basefold::BasefoldCommitment; +use crate::basefold_verifier::query_phase::QueryPhaseVerifierInput; use crate::tower_verifier::binding::IOPProverMessage; -use crate::zkvm_verifier::binding::{TowerProofInput, ZKVMChipProofInput, ZKVMProofInput, E, F}; -use crate::zkvm_verifier::verifier::verify_zkvm_proof; - -use ceno_zkvm::scheme::ZKVMProof; -use ceno_zkvm::structs::ZKVMVerifyingKey; +use crate::zkvm_verifier::binding::{ + GKRProofInput, LayerProofInput, SumcheckLayerProofInput, TowerProofInput, ZKVMChipProofInput, + ZKVMProofInput, E, F, +}; +use crate::zkvm_verifier::verifier::{verify_gkr_circuit, verify_zkvm_proof}; +use ceno_mle::util::ceil_log2; +use ff_ext::BabyBearExt4; +use gkr_iop::gkr::{ + layer::sumcheck_layer::{SumcheckLayer, SumcheckLayerProof}, + GKRCircuit, +}; +use itertools::Itertools; use mpcs::{Basefold, BasefoldRSParams}; - -use openvm_circuit::arch::instructions::program::Program; +use openvm_circuit::arch::{ + instructions::program::Program, verify_single, SystemConfig, VirtualMachine, VmExecutor, +}; +use openvm_native_circuit::{Native, NativeConfig}; use openvm_native_compiler::{ asm::AsmBuilder, conversion::{convert_program, CompilerOptions}, prelude::AsmCompiler, }; use openvm_native_recursion::hints::Hintable; +use openvm_stark_backend::config::StarkGenericConfig; +use openvm_stark_sdk::{ + config::{ + baby_bear_poseidon2::{BabyBearPoseidon2Config, BabyBearPoseidon2Engine}, + fri_params::standard_fri_params_with_100_bits_conjectured_security, + setup_tracing_with_log_level, FriParameters, + }, + engine::StarkFriEngine, + p3_baby_bear::BabyBear, +}; +use std::fs::File; + +type SC = BabyBearPoseidon2Config; +type EF = ::Challenge; + +use ceno_zkvm::{ + scheme::{verifier::ZKVMVerifier, ZKVMProof}, + structs::{ComposedConstrainSystem, ZKVMVerifyingKey}, +}; pub fn parse_zkvm_proof_import( - zkvm_proof: ZKVMProof>, + zkvm_proof: ZKVMProof>, + verifier: &ZKVMVerifier>, ) -> ZKVMProofInput { let raw_pi = zkvm_proof .raw_pi @@ -166,6 +196,91 @@ pub fn parse_zkvm_proof_import( fixed_in_evals.push(v_e); } + let circuit_name = &verifier.vk.circuit_index_to_name[chip_id]; + let circuit_vk = &verifier.vk.circuit_vks[circuit_name]; + + let composed_cs = circuit_vk.get_cs(); + let num_instances = chip_proof.num_instances; + let next_pow2_instance = num_instances.next_power_of_two().max(2); + let log2_num_instances = ceil_log2(next_pow2_instance); + let num_var_with_rotation = log2_num_instances + composed_cs.rotation_vars().unwrap_or(0); + + let has_gkr_proof = chip_proof.gkr_iop_proof.is_some(); + let mut gkr_iop_proof = GKRProofInput { + num_var_with_rotation, + num_instances, + layer_proofs: vec![], + }; + + if has_gkr_proof { + let gkr_proof = chip_proof.gkr_iop_proof.clone().unwrap(); + + for layer_proof in gkr_proof.0 { + // rotation + let (has_rotation, rotation): (usize, SumcheckLayerProofInput) = if let Some(p) = + layer_proof.rotation + { + let mut iop_messages: Vec = vec![]; + for m in p.proof.proofs { + let mut evaluations: Vec = vec![]; + for e in m.evaluations { + let v_e: E = + serde_json::from_value(serde_json::to_value(e.clone()).unwrap()) + .unwrap(); + evaluations.push(v_e); + } + iop_messages.push(IOPProverMessage { evaluations }); + } + let mut evals: Vec = vec![]; + for e in p.evals { + let v_e: E = + serde_json::from_value(serde_json::to_value(e.clone()).unwrap()) + .unwrap(); + evals.push(v_e); + } + ( + 1, + SumcheckLayerProofInput { + proof: iop_messages, + evals, + }, + ) + } else { + (0, SumcheckLayerProofInput::default()) + }; + + // main sumcheck + let mut iop_messages: Vec = vec![]; + let mut evals: Vec = vec![]; + for m in layer_proof.main.proof.proofs { + let mut evaluations: Vec = vec![]; + for e in m.evaluations { + let v_e: E = + serde_json::from_value(serde_json::to_value(e.clone()).unwrap()) + .unwrap(); + evaluations.push(v_e); + } + iop_messages.push(IOPProverMessage { evaluations }); + } + for e in layer_proof.main.evals { + let v_e: E = + serde_json::from_value(serde_json::to_value(e.clone()).unwrap()).unwrap(); + evals.push(v_e); + } + + let main = SumcheckLayerProofInput { + proof: iop_messages, + evals, + }; + + gkr_iop_proof.layer_proofs.push(LayerProofInput { + has_rotation, + rotation, + main, + }); + } + } + chip_proofs.push(ZKVMChipProofInput { idx: chip_id.clone(), num_instances: chip_proof.num_instances, @@ -179,10 +294,12 @@ pub fn parse_zkvm_proof_import( main_sumcheck_proofs, wits_in_evals, fixed_in_evals, + has_gkr_proof, + gkr_iop_proof, }); } - let witin_commit: mpcs::BasefoldCommitment = + let witin_commit: mpcs::BasefoldCommitment = serde_json::from_value(serde_json::to_value(zkvm_proof.witin_commit).unwrap()).unwrap(); let witin_commit: BasefoldCommitment = witin_commit.into(); @@ -197,132 +314,111 @@ pub fn parse_zkvm_proof_import( } } -/// Build Ceno's zkVM verifier program from vk in OpenVM's eDSL -pub fn build_zkvm_verifier_program( - vk: &ZKVMVerifyingKey>, -) -> Program { - let mut builder = AsmBuilder::::default(); - - let zkvm_proof_input_variables = ZKVMProofInput::read(&mut builder); - verify_zkvm_proof(&mut builder, zkvm_proof_input_variables, vk); - builder.halt(); - - // Compile program - #[cfg(feature = "bench-metrics")] - let options = CompilerOptions::default().with_cycle_tracker(); - #[cfg(not(feature = "bench-metrics"))] - let options = CompilerOptions::default(); - let mut compiler = AsmCompiler::new(options.word_size); - compiler.build(builder.operations); - let asm_code = compiler.code(); - - let program: Program = convert_program(asm_code, options); - program -} +pub fn inner_test_thread() { + setup_tracing_with_log_level(tracing::Level::WARN); -#[cfg(test)] -mod tests { - use crate::e2e::build_zkvm_verifier_program; - use crate::e2e::parse_zkvm_proof_import; - use crate::zkvm_verifier::binding::{E, F}; - use ceno_zkvm::scheme::ZKVMProof; - use ceno_zkvm::structs::ZKVMVerifyingKey; - use mpcs::{Basefold, BasefoldRSParams}; - use openvm_circuit::arch::verify_single; - use openvm_circuit::arch::VirtualMachine; - use openvm_circuit::arch::{SystemConfig, VmExecutor}; - use openvm_native_circuit::{Native, NativeConfig}; - use openvm_native_recursion::hints::Hintable; - use openvm_stark_sdk::config::{ - baby_bear_poseidon2::BabyBearPoseidon2Engine, - fri_params::standard_fri_params_with_100_bits_conjectured_security, - setup_tracing_with_log_level, FriParameters, - }; - use openvm_stark_sdk::engine::StarkFriEngine; - use std::fs::File; + let proof_path = "./src/e2e/encoded/proof.bin"; + let vk_path = "./src/e2e/encoded/vk.bin"; - pub fn inner_test_thread() { - setup_tracing_with_log_level(tracing::Level::WARN); + let zkvm_proof: ZKVMProof> = + bincode::deserialize_from(File::open(proof_path).expect("Failed to open proof file")) + .expect("Failed to deserialize proof file"); - let proof_path = "./src/e2e/encoded/proof.bin"; - let vk_path = "./src/e2e/encoded/vk.bin"; + let vk: ZKVMVerifyingKey> = + bincode::deserialize_from(File::open(vk_path).expect("Failed to open vk file")) + .expect("Failed to deserialize vk file"); - let zkvm_proof: ZKVMProof> = - bincode::deserialize_from(File::open(proof_path).expect("Failed to open proof file")) - .expect("Failed to deserialize proof file"); + let verifier = ZKVMVerifier::new(vk); + let zkvm_proof_input = parse_zkvm_proof_import(zkvm_proof, &verifier); - let vk: ZKVMVerifyingKey> = - bincode::deserialize_from(File::open(vk_path).expect("Failed to open vk file")) - .expect("Failed to deserialize vk file"); + // OpenVM DSL + let mut builder = AsmBuilder::::default(); - let program = build_zkvm_verifier_program(&vk); + // Obtain witness inputs + let zkvm_proof_input_variables = ZKVMProofInput::read(&mut builder); + verify_zkvm_proof(&mut builder, zkvm_proof_input_variables, &verifier); + builder.halt(); - // Construct zkvm proof input - let zkvm_proof_input = parse_zkvm_proof_import(zkvm_proof); + // Pass in witness stream + let mut witness_stream: Vec< + Vec>, + > = Vec::new(); - // Pass in witness stream - let mut witness_stream: Vec> = Vec::new(); - witness_stream.extend(zkvm_proof_input.write()); + witness_stream.extend(zkvm_proof_input.write()); - let mut system_config = SystemConfig::default() - .with_public_values(4) - .with_max_segment_len((1 << 25) - 100); - system_config.profiling = true; - let config = NativeConfig::new(system_config, Native); + // Compile program + let options = CompilerOptions::default().with_cycle_tracker(); + let mut compiler = AsmCompiler::new(options.word_size); + compiler.build(builder.operations); + let asm_code = compiler.code(); - let executor = VmExecutor::::new(config); + // _debug: print out assembly + /* + println!("=> AssemblyCode:"); + println!("{asm_code}"); + return (); + */ + + let program: Program< + p3_monty_31::MontyField31, + > = convert_program(asm_code, options); + let mut system_config = SystemConfig::default() + .with_public_values(4) + .with_max_segment_len((1 << 25) - 100); + system_config.profiling = true; + let config = NativeConfig::new(system_config, Native); + + let executor = VmExecutor::::new(config); + + let res = executor + .execute_and_then( + program.clone(), + witness_stream.clone(), + |_, seg| Ok(seg), + |err| err, + ) + .unwrap(); + + for (i, seg) in res.iter().enumerate() { + println!("=> segment {:?} metrics: {:?}", i, seg.metrics); + } - let res = executor - .execute_and_then( - program.clone(), - witness_stream.clone(), - |_, seg| Ok(seg), - |err| err, - ) - .unwrap(); + let poseidon2_max_constraint_degree = 3; + let log_blowup = 1; - for (i, seg) in res.iter().enumerate() { - println!("=> segment {:?} metrics: {:?}", i, seg.metrics); + let fri_params = if matches!(std::env::var("OPENVM_FAST_TEST"), Ok(x) if &x == "1") { + FriParameters { + log_blowup, + log_final_poly_len: 0, + num_queries: 10, + proof_of_work_bits: 0, } + } else { + standard_fri_params_with_100_bits_conjectured_security(log_blowup) + }; - let poseidon2_max_constraint_degree = 3; - // TODO: use log_blowup = 1 when native multi_observe chip reduces max constraint degree to 3 - let log_blowup = 2; - - let fri_params = if matches!(std::env::var("OPENVM_FAST_TEST"), Ok(x) if &x == "1") { - FriParameters { - log_blowup, - log_final_poly_len: 0, - num_queries: 10, - proof_of_work_bits: 0, - } - } else { - standard_fri_params_with_100_bits_conjectured_security(log_blowup) - }; - - let engine = BabyBearPoseidon2Engine::new(fri_params); - let mut config = NativeConfig::aggregation(0, poseidon2_max_constraint_degree); - config.system.memory_config.max_access_adapter_n = 16; + let engine = BabyBearPoseidon2Engine::new(fri_params); + let mut config = NativeConfig::aggregation(0, poseidon2_max_constraint_degree); + config.system.memory_config.max_access_adapter_n = 16; - let vm = VirtualMachine::new(engine, config); + let vm = VirtualMachine::new(engine, config); - let pk = vm.keygen(); - let result = vm.execute_and_generate(program, witness_stream).unwrap(); - let proofs = vm.prove(&pk, result); - for proof in proofs { - verify_single(&vm.engine, &pk.get_vk(), &proof).expect("Verification failed"); - } + let pk = vm.keygen(); + let result = vm.execute_and_generate(program, witness_stream).unwrap(); + let proofs = vm.prove(&pk, result); + for proof in proofs { + verify_single(&vm.engine, &pk.get_vk(), &proof).expect("Verification failed"); } +} - #[test] - pub fn test_zkvm_verifier() { - let stack_size = 64 * 1024 * 1024; // 64 MB +#[test] +pub fn test_zkvm_verifier() { + let stack_size = 64 * 1024 * 1024; // 64 MB - let handler = std::thread::Builder::new() - .stack_size(stack_size) - .spawn(inner_test_thread) - .expect("Failed to spawn thread"); + let handler = std::thread::Builder::new() + .stack_size(stack_size) + .spawn(inner_test_thread) + .expect("Failed to spawn thread"); - handler.join().expect("Thread panicked"); - } + handler.join().expect("Thread panicked"); } diff --git a/src/extensions/mod.rs b/src/extensions/mod.rs index 347f556..60ec658 100644 --- a/src/extensions/mod.rs +++ b/src/extensions/mod.rs @@ -1,17 +1,7 @@ #[cfg(test)] mod tests { - use crate::arithmetics::{challenger_multi_observe, exts_to_felts}; - - use crate::zkvm_verifier::binding::{E, F}; - use ceno_mle::expression::StructuralWitIn; - use ceno_zkvm::{circuit_builder::SetTableSpec, scheme::verifier::ZKVMVerifier}; - use ff_ext::BabyBearExt4; - use itertools::interleave; - use itertools::max; - use itertools::Itertools; - use mpcs::BasefoldCommitment; - use mpcs::{Basefold, BasefoldRSParams}; + use crate::zkvm_verifier::binding::F; use openvm_circuit::arch::SystemConfig; use openvm_circuit::arch::VmExecutor; use openvm_native_circuit::Native; @@ -19,22 +9,15 @@ mod tests { use openvm_native_compiler::conversion::convert_program; use openvm_native_compiler::prelude::*; use openvm_native_compiler::{asm::AsmBuilder, conversion::CompilerOptions}; - use openvm_native_compiler_derive::iter_zip; - use openvm_native_recursion::challenger::{self, CanSampleVariable}; + use openvm_native_recursion::challenger::CanSampleVariable; use openvm_native_recursion::challenger::{ - duplex::DuplexChallengerVariable, CanObserveVariable, FeltChallenger, + duplex::DuplexChallengerVariable, CanObserveVariable, }; - use openvm_native_recursion::hints::Hintable; use openvm_stark_backend::config::StarkGenericConfig; + use openvm_stark_backend::p3_field::{Field, FieldAlgebra}; use openvm_stark_sdk::{ config::baby_bear_poseidon2::BabyBearPoseidon2Config, p3_baby_bear::BabyBear, }; - use p3_field::{Field, FieldAlgebra, FieldExtensionAlgebra}; - - type Pcs = Basefold; - const NUM_FANIN: usize = 2; - const MAINCONSTRAIN_SUMCHECK_BATCH_SIZE: usize = 3; // read/write/lookup - const SEL_DEGREE: usize = 2; type SC = BabyBearPoseidon2Config; type EF = ::Challenge; diff --git a/src/tower_verifier/binding.rs b/src/tower_verifier/binding.rs index 26c888a..2e4749f 100644 --- a/src/tower_verifier/binding.rs +++ b/src/tower_verifier/binding.rs @@ -3,15 +3,14 @@ use openvm_native_compiler::{ ir::{Array, Builder, Config}, prelude::*, }; -use openvm_native_compiler_derive::iter_zip; use openvm_native_recursion::hints::{Hintable, VecAutoHintable}; pub type F = BabyBear; pub type E = BinomialExtensionField; pub type InnerConfig = AsmConfig; +use openvm_stark_backend::p3_field::extension::BinomialExtensionField; use openvm_stark_sdk::p3_baby_bear::BabyBear; -use p3_field::extension::BinomialExtensionField; -use serde::{Deserialize, Serialize}; +use serde::Deserialize; #[derive(DslVariable, Clone)] pub struct PointVariable { @@ -102,20 +101,3 @@ impl Hintable for IOPProverMessage { } } impl VecAutoHintable for IOPProverMessage {} - -pub struct TowerVerifierInput { - pub prod_out_evals: Vec>, - pub logup_out_evals: Vec>, - pub num_variables: Vec, - pub num_fanin: usize, - - // TowerProof - pub num_proofs: usize, - pub num_prod_specs: usize, - pub num_logup_specs: usize, - pub _max_num_variables: usize, - - pub proofs: Vec>, - pub prod_specs_eval: Vec>>, - pub logup_specs_eval: Vec>>, -} diff --git a/src/tower_verifier/program.rs b/src/tower_verifier/program.rs index 85d097b..c5f08a8 100644 --- a/src/tower_verifier/program.rs +++ b/src/tower_verifier/program.rs @@ -1,19 +1,17 @@ use super::binding::{IOPProverMessageVariable, PointAndEvalVariable, PointVariable}; use crate::arithmetics::{ - challenger_multi_observe, dot_product, eq_eval, evaluate_at_point_degree_1, extend, - exts_to_felts, fixed_dot_product, gen_alpha_pows, is_smaller_than, print_ext_arr, reverse, - UniPolyExtrapolator, + challenger_multi_observe, eq_eval, evaluate_at_point_degree_1, extend, exts_to_felts, + fixed_dot_product, reverse, UniPolyExtrapolator, }; use crate::transcript::transcript_observe_label; use crate::zkvm_verifier::binding::TowerProofInputVariable; use ceno_zkvm::scheme::constants::NUM_FANIN; use openvm_native_compiler::prelude::*; use openvm_native_compiler_derive::iter_zip; -use openvm_native_recursion::challenger::ChallengerVariable; use openvm_native_recursion::challenger::{ duplex::DuplexChallengerVariable, CanObserveVariable, FeltChallenger, }; -use p3_field::FieldAlgebra; +use openvm_stark_backend::p3_field::FieldAlgebra; pub(crate) fn interpolate_uni_poly( builder: &mut Builder, @@ -86,51 +84,6 @@ pub(crate) fn interpolate_uni_poly( res } -// Interpolate a uni-variate degree-`p_i.len()-1` polynomial and evaluate this -// polynomial at `eval_at`: -// -// \sum_{i=0}^len p_i * (\prod_{j!=i} (eval_at - j)/(i-j) ) -// -pub(crate) fn interpolate_uni_poly_with_weights( - builder: &mut Builder, - p_i: &Array>, - eval_at: Ext, - interpolation_weights: &Array>>, -) -> Ext { - // \prod_i (eval_at - i) - let weights_idx: Usize = builder.eval(p_i.len() - Usize::from(2)); - let weights = builder.get(interpolation_weights, weights_idx); - let num_points = p_i.len().get_var(); - - let one: Ext = builder.constant(C::EF::ONE); - let zero: Ext = builder.constant(C::EF::ZERO); - let mut iter_i: Ext = builder.eval(zero + zero); // 0 + 0 to take advantage of AddE - let prod: Ext = builder.eval(one + zero); // 1 + 0 to take advantage of AddE - builder.range(0, num_points).for_each(|_, builder| { - builder.assign(&prod, prod * (eval_at - iter_i)); - builder.assign(&iter_i, iter_i + one); - }); - - iter_i = builder.eval(zero + zero); // reset to 0 - let result = zero; // take ownership - iter_zip!(builder, p_i, weights).for_each(|ptr_vec, builder| { - let pi_ptr = ptr_vec[0]; - let w_ptr = ptr_vec[1]; - - let p_i_val = builder.iter_ptr_get(p_i, pi_ptr); - let weight = builder.iter_ptr_get(&weights, w_ptr); - - // weight_i = \prod_{j!=i} 1/(i-j) - // \sum_{i=0}^len p_i * weight_i * prod / (eval_at-i) - let e: Ext = builder.eval(eval_at - iter_i); - let term = p_i_val * weight * prod / e; // TODO: how to handle e = 0 - builder.assign(&iter_i, iter_i + one); - builder.assign(&result, result + term); - }); - - result -} - pub fn iop_verifier_state_verify( builder: &mut Builder, challenger: &mut DuplexChallengerVariable, @@ -164,6 +117,7 @@ pub fn iop_verifier_state_verify( .range(0, max_num_variables_usize.clone()) .for_each(|i_vec, builder| { let i = i_vec[0]; + // TODO: this takes 7 cycles, can we optimize it? let prover_msg = builder.get(&prover_messages, i); @@ -178,6 +132,7 @@ pub fn iop_verifier_state_verify( let e1 = builder.get(&prover_msg.evaluations, 0); let e2 = builder.get(&prover_msg.evaluations, 1); let target: Ext<::F, ::EF> = builder.eval(e1 + e2); + builder.assert_ext_eq(expected, target); let p_r = unipoly_extrapolator.extrapolate_uni_poly( @@ -349,7 +304,7 @@ pub fn verify_tower_proof( builder.set(&interleaved_point_n_eval, q_i, q); }); - let mut initial_claim: Ext = builder.eval(zero + zero); + let initial_claim: Ext = builder.eval(zero + zero); iter_zip!(builder, prod_spec_point_n_eval).for_each(|ptr_vec, builder| { let ptr = ptr_vec[0]; @@ -371,11 +326,9 @@ pub fn verify_tower_proof( let op_range: RVar = builder.eval_expr(max_num_variables - Usize::from(1)); let round: Felt = builder.constant(C::F::ZERO); - let mut next_rt = PointAndEvalVariable { - point: PointVariable { - fs: builder.dyn_array(1), - }, - eval: builder.constant(C::EF::ZERO), + let next_rt = PointAndEvalVariable { + point: PointVariable { fs: initial_rt }, + eval: initial_claim, }; builder diff --git a/src/transcript/mod.rs b/src/transcript/mod.rs index 45af2e2..0d8792c 100644 --- a/src/transcript/mod.rs +++ b/src/transcript/mod.rs @@ -1,10 +1,8 @@ use ff_ext::{BabyBearExt4, ExtensionField as CenoExtensionField, SmallField}; use openvm_native_compiler::prelude::*; -use openvm_native_recursion::challenger::{ - duplex::DuplexChallengerVariable, CanObserveVariable, FeltChallenger, -}; -use openvm_native_recursion::challenger::{CanSampleBitsVariable, ChallengerVariable}; -use p3_field::FieldAlgebra; +use openvm_native_recursion::challenger::CanSampleBitsVariable; +use openvm_native_recursion::challenger::{duplex::DuplexChallengerVariable, CanObserveVariable}; +use openvm_stark_backend::p3_field::FieldAlgebra; pub fn transcript_observe_label( builder: &mut Builder, diff --git a/src/zkvm_verifier/binding.rs b/src/zkvm_verifier/binding.rs index 3645f86..fdf7b72 100644 --- a/src/zkvm_verifier/binding.rs +++ b/src/zkvm_verifier/binding.rs @@ -2,15 +2,10 @@ use crate::arithmetics::next_pow2_instance_padding; use crate::basefold_verifier::basefold::{ BasefoldCommitment, BasefoldCommitmentVariable, BasefoldProof, BasefoldProofVariable, }; -use crate::basefold_verifier::query_phase::{ - QueryPhaseVerifierInput, QueryPhaseVerifierInputVariable, -}; use crate::{ arithmetics::ceil_log2, - tower_verifier::binding::{IOPProverMessage, IOPProverMessageVariable}, + tower_verifier::binding::{IOPProverMessage, IOPProverMessageVariable, PointVariable}, }; -use ark_std::iterable::Iterable; -use ff_ext::BabyBearExt4; use itertools::Itertools; use openvm_native_compiler::{ asm::AsmConfig, @@ -19,8 +14,8 @@ use openvm_native_compiler::{ }; use openvm_native_compiler_derive::iter_zip; use openvm_native_recursion::hints::{Hintable, VecAutoHintable}; +use openvm_stark_backend::p3_field::{extension::BinomialExtensionField, FieldAlgebra}; use openvm_stark_sdk::p3_baby_bear::BabyBear; -use p3_field::{extension::BinomialExtensionField, FieldAlgebra}; pub type F = BabyBear; pub type E = BinomialExtensionField; @@ -71,6 +66,9 @@ pub struct ZKVMChipProofInputVariable { pub main_sel_sumcheck_proofs: Array>, pub wits_in_evals: Array>, pub fixed_in_evals: Array>, + + pub has_gkr_proof: Usize, + pub gkr_iop_proof: GKRProofVariable, } pub(crate) struct ZKVMProofInput { @@ -269,6 +267,10 @@ pub struct ZKVMChipProofInput { pub main_sumcheck_proofs: Vec, pub wits_in_evals: Vec, pub fixed_in_evals: Vec, + + // gkr proof + pub has_gkr_proof: bool, + pub gkr_iop_proof: GKRProofInput, } impl VecAutoHintable for ZKVMChipProofInput {} @@ -296,6 +298,9 @@ impl Hintable for ZKVMChipProofInput { let wits_in_evals = Vec::::read(builder); let fixed_in_evals = Vec::::read(builder); + let has_gkr_proof = Usize::Var(usize::read(builder)); + let gkr_iop_proof = GKRProofInput::read(builder); + ZKVMChipProofInputVariable { idx, idx_felt, @@ -312,6 +317,8 @@ impl Hintable for ZKVMChipProofInput { main_sel_sumcheck_proofs, wits_in_evals, fixed_in_evals, + has_gkr_proof, + gkr_iop_proof, } } @@ -353,7 +360,148 @@ impl Hintable for ZKVMChipProofInput { stream.extend(self.main_sumcheck_proofs.write()); stream.extend(self.wits_in_evals.write()); stream.extend(self.fixed_in_evals.write()); + if self.has_gkr_proof { + stream.extend(>::write(&1)); + } else { + stream.extend(>::write(&0)); + } + stream.extend(self.gkr_iop_proof.write()); + + stream + } +} + +#[derive(Default)] +pub struct SumcheckLayerProofInput { + pub proof: Vec, + pub evals: Vec, +} +#[derive(DslVariable, Clone)] +pub struct SumcheckLayerProofVariable { + pub proof: Array>, + pub evals: Array>, + pub evals_len_div_3: Var, +} +impl VecAutoHintable for SumcheckLayerProofInput {} +impl Hintable for SumcheckLayerProofInput { + type HintVariable = SumcheckLayerProofVariable; + + fn read(builder: &mut Builder) -> Self::HintVariable { + let proof = Vec::::read(builder); + let evals = Vec::::read(builder); + let evals_len_div_3 = usize::read(builder); + + Self::HintVariable { + proof, + evals, + evals_len_div_3, + } + } + fn write(&self) -> Vec::N>> { + let mut stream = Vec::new(); + stream.extend(self.proof.write()); + stream.extend(self.evals.write()); + let evals_len_div_3 = self.evals.len() / 3; + stream.extend(>::write(&evals_len_div_3)); + stream + } +} +pub struct LayerProofInput { + pub has_rotation: usize, + pub rotation: SumcheckLayerProofInput, + pub main: SumcheckLayerProofInput, +} +#[derive(DslVariable, Clone)] +pub struct LayerProofVariable { + pub has_rotation: Usize, + pub rotation: SumcheckLayerProofVariable, + pub main: SumcheckLayerProofVariable, +} +impl VecAutoHintable for LayerProofInput {} +impl Hintable for LayerProofInput { + type HintVariable = LayerProofVariable; + + fn read(builder: &mut Builder) -> Self::HintVariable { + let has_rotation = Usize::Var(usize::read(builder)); + let rotation = SumcheckLayerProofInput::read(builder); + let main = SumcheckLayerProofInput::read(builder); + + Self::HintVariable { + has_rotation, + rotation, + main, + } + } + fn write(&self) -> Vec::N>> { + let mut stream = Vec::new(); + stream.extend(>::write(&self.has_rotation)); + stream.extend(self.rotation.write()); + stream.extend(self.main.write()); + stream + } +} +#[derive(Default)] +pub struct GKRProofInput { + pub num_var_with_rotation: usize, + pub num_instances: usize, + pub layer_proofs: Vec, +} +#[derive(DslVariable, Clone)] +pub struct GKRProofVariable { + pub num_var_with_rotation: Usize, + pub num_instances_minus_one_bit_decomposition: Array>, + pub layer_proofs: Array>, +} +impl Hintable for GKRProofInput { + type HintVariable = GKRProofVariable; + + fn read(builder: &mut Builder) -> Self::HintVariable { + let num_var_with_rotation = Usize::Var(usize::read(builder)); + let num_instances_minus_one_bit_decomposition = Vec::::read(builder); + let layer_proofs = Vec::::read(builder); + Self::HintVariable { + num_var_with_rotation, + num_instances_minus_one_bit_decomposition, + layer_proofs, + } + } + fn write(&self) -> Vec::N>> { + let mut stream = Vec::new(); + stream.extend(>::write( + &self.num_var_with_rotation, + )); + let eq_instance = self.num_instances - 1; + let mut bit_decomp: Vec = vec![]; + for i in 0..32usize { + bit_decomp.push(F::from_canonical_usize((eq_instance >> i) & 1)); + } + stream.extend(bit_decomp.write()); + stream.extend(self.layer_proofs.write()); stream } } + +#[derive(DslVariable, Clone)] +pub struct ClaimAndPoint { + pub evals: Array>, + pub has_point: Usize, + pub point: PointVariable, +} + +#[derive(DslVariable, Clone)] +pub struct RotationClaim { + pub left_evals: Array>, + pub right_evals: Array>, + pub target_evals: Array>, + pub left_point: Array>, + pub right_point: Array>, + pub origin_point: Array>, +} + +#[derive(DslVariable, Clone)] +pub struct GKRClaimEvaluation { + pub value: Ext, + pub point: PointVariable, + pub poly: Usize, +} diff --git a/src/zkvm_verifier/verifier.rs b/src/zkvm_verifier/verifier.rs index c5e5776..2a9286b 100644 --- a/src/zkvm_verifier/verifier.rs +++ b/src/zkvm_verifier/verifier.rs @@ -1,7 +1,10 @@ -use super::binding::{ZKVMChipProofInputVariable, ZKVMProofInputVariable}; +use super::binding::{ + ClaimAndPoint, GKRClaimEvaluation, RotationClaim, ZKVMChipProofInputVariable, + ZKVMProofInputVariable, +}; use crate::arithmetics::{ - challenger_multi_observe, eval_ceno_expr_with_instance, print_ext_arr, print_felt_arr, - PolyEvaluator, UniPolyExtrapolator, + challenger_multi_observe, eq_eval, eval_ceno_expr_with_instance, PolyEvaluator, + UniPolyExtrapolator, }; use crate::basefold_verifier::basefold::{ BasefoldCommitmentVariable, RoundOpeningVariable, RoundVariable, @@ -9,40 +12,50 @@ use crate::basefold_verifier::basefold::{ use crate::basefold_verifier::mmcs::MmcsCommitmentVariable; use crate::basefold_verifier::query_phase::PointAndEvalsVariable; use crate::basefold_verifier::utils::pow_2; -use crate::basefold_verifier::verifier::batch_verify; +// use crate::basefold_verifier::verifier::batch_verify; use crate::tower_verifier::program::verify_tower_proof; use crate::transcript::transcript_observe_label; +use crate::zkvm_verifier::binding::{ + GKRProofVariable, LayerProofVariable, SumcheckLayerProofVariable, +}; use crate::{ arithmetics::{ build_eq_x_r_vec_sequential, ceil_log2, concat, dot_product as ext_dot_product, eq_eval_less_or_equal_than, eval_wellform_address_vec, gen_alpha_pows, max_usize_arr, - max_usize_vec, nested_product, next_pow2_instance_padding, product, sum as ext_sum, + max_usize_vec, nested_product, + }, + tower_verifier::{ + binding::{PointAndEvalVariable, PointVariable}, + program::iop_verifier_state_verify, }, - tower_verifier::{binding::PointVariable, program::iop_verifier_state_verify}, }; -use ceno_mle::expression::{Instance, StructuralWitIn}; -use ceno_zkvm::e2e::B; -use ceno_zkvm::structs::{VerifyingKey, ZKVMVerifyingKey}; -use ceno_zkvm::{circuit_builder::SetTableSpec, scheme::verifier::ZKVMVerifier}; +use ceno_mle::expression::{Expression, Instance, StructuralWitIn}; +use ceno_zkvm::structs::VerifyingKey; +use ceno_zkvm::{ + circuit_builder::SetTableSpec, scheme::verifier::ZKVMVerifier, structs::ComposedConstrainSystem, +}; use ff_ext::BabyBearExt4; -use itertools::max; -use itertools::{interleave, Itertools}; +use gkr_iop::gkr::layer::ROTATION_OPENING_COUNT; +use gkr_iop::{ + evaluation::EvalExpression, + gkr::{booleanhypercube::BooleanHypercube, layer::Layer, GKRCircuit}, + selector::SelectorType, +}; +use itertools::{interleave, izip, Itertools}; use mpcs::{Basefold, BasefoldRSParams}; use openvm_native_compiler::prelude::*; use openvm_native_compiler_derive::iter_zip; use openvm_native_recursion::challenger::{ duplex::DuplexChallengerVariable, CanObserveVariable, FeltChallenger, }; +use openvm_stark_backend::p3_field::FieldAlgebra; use p3_baby_bear::BabyBear; -use p3_field::{Field, FieldAlgebra}; type F = BabyBear; type E = BabyBearExt4; type Pcs = Basefold; const NUM_FANIN: usize = 2; -const MAINCONSTRAIN_SUMCHECK_BATCH_SIZE: usize = 3; // read/write/lookup -const SEL_DEGREE: usize = 2; pub fn transcript_group_observe_label( builder: &mut Builder, @@ -80,7 +93,7 @@ pub fn transcript_group_sample_ext( pub fn verify_zkvm_proof>( builder: &mut Builder, zkvm_proof_input: ZKVMProofInputVariable, - vk: &ZKVMVerifyingKey, + vk: &ZKVMVerifier, ) { let mut challenger = DuplexChallengerVariable::new(builder); transcript_observe_label(builder, &mut challenger, b"riscv"); @@ -108,7 +121,7 @@ pub fn verify_zkvm_proof>( }, ); - let fixed_commit = if let Some(fixed_commit) = vk.fixed_commit.as_ref() { + let fixed_commit = if let Some(fixed_commit) = vk.vk.fixed_commit.as_ref() { let commit: crate::basefold_verifier::hash::Hash = fixed_commit.commit().into(); let commit_array: Array> = builder.dyn_array(commit.value.len()); commit.value.into_iter().enumerate().for_each(|(i, v)| { @@ -180,14 +193,17 @@ pub fn verify_zkvm_proof>( let dummy_table_item_multiplicity: Var = builder.constant(C::N::ZERO); let num_fixed_opening = vk + .vk .circuit_vks .values() .filter(|c| c.get_cs().num_fixed() > 0) .count(); + let witin_openings: Array> = builder.dyn_array(zkvm_proof_input.chip_proofs.len()); let fixed_openings: Array> = builder.dyn_array(Usize::from(num_fixed_opening)); + let num_chips_verified: Usize = builder.eval(C::N::ZERO); let num_chips_have_fixed: Usize = builder.eval(C::N::ZERO); @@ -200,16 +216,65 @@ pub fn verify_zkvm_proof>( builder.set(&chip_indices, i, chip_proof.idx); }); - // iterate over all chips - for (i, chip_vk) in vk.circuit_vks.values().enumerate() { + for (i, (circuit_name, chip_vk)) in vk.vk.circuit_vks.iter().enumerate() { let chip_id: Var = builder.get(&chip_indices, num_chips_verified.get_var()); + builder.if_eq(chip_id, RVar::from(i)).then(|builder| { let chip_proof = builder.get(&zkvm_proof_input.chip_proofs, num_chips_verified.get_var()); + let circuit_vk = &vk.vk.circuit_vks[circuit_name]; + + builder.assert_usize_eq( + chip_proof.wits_in_evals.len(), + Usize::from(circuit_vk.get_cs().num_witin()), + ); + builder.assert_usize_eq( + chip_proof.fixed_in_evals.len(), + Usize::from(circuit_vk.get_cs().num_fixed()), + ); + builder.assert_usize_eq( + chip_proof.record_r_out_evals.len(), + Usize::from(circuit_vk.get_cs().num_reads()), + ); + builder.assert_usize_eq( + chip_proof.record_w_out_evals.len(), + Usize::from(circuit_vk.get_cs().num_writes()), + ); + builder.assert_usize_eq( + chip_proof.record_lk_out_evals.len(), + Usize::from(circuit_vk.get_cs().num_lks()), + ); + + let chip_logup_sum: Ext = builder.constant(C::EF::ZERO); + iter_zip!(builder, chip_proof.record_lk_out_evals).for_each(|ptr_vec, builder| { + let evals = builder.iter_ptr_get(&chip_proof.record_lk_out_evals, ptr_vec[0]); + let p1 = builder.get(&evals, 0); + let p2 = builder.get(&evals, 1); + let q1 = builder.get(&evals, 2); + let q2 = builder.get(&evals, 3); + + builder.assign(&chip_logup_sum, chip_logup_sum + p1 * q1.inverse()); + builder.assign(&chip_logup_sum, chip_logup_sum + p2 * q2.inverse()); + }); challenger.observe(builder, chip_proof.idx_felt); builder.cycle_tracker_start("Verify chip proof"); let input_opening_point = if chip_vk.get_cs().is_opcode_circuit() { + // getting the number of dummy padding item that we used in this opcode circuit + let num_lks = chip_vk.get_cs().num_lks(); + // FIXME: use builder to compute this + let num_instances = pow_2(builder, chip_proof.log2_num_instances.get_var()); + let num_padded_instance: Var = + builder.eval(num_instances - chip_proof.num_instances.clone()); + + let new_multiplicity: Usize = + builder.eval(Usize::from(num_lks) * Usize::from(num_padded_instance)); + builder.assign( + &dummy_table_item_multiplicity, + dummy_table_item_multiplicity + new_multiplicity, + ); + + builder.assign(&logup_sum, logup_sum + chip_logup_sum); verify_opcode_proof( builder, &mut challenger, @@ -220,62 +285,22 @@ pub fn verify_zkvm_proof>( &mut unipoly_extrapolator, ) } else { + builder.assign(&logup_sum, logup_sum - chip_logup_sum); verify_table_proof( builder, &mut challenger, &chip_proof, + &zkvm_proof_input.raw_pi, + &zkvm_proof_input.raw_pi_num_variables, &zkvm_proof_input.pi_evals, &challenges, &chip_vk, &mut unipoly_extrapolator, + &mut poly_evaluator, ) }; builder.cycle_tracker_end("Verify chip proof"); - // getting the number of dummy padding item that we used in this opcode circuit - if chip_vk.get_cs().is_opcode_circuit() { - let num_lks = chip_vk.get_cs().num_lks(); - // FIXME: use builder to compute this - let num_instances = pow_2(builder, chip_proof.log2_num_instances.get_var()); - let num_padded_instance: Var = - builder.eval(num_instances - chip_proof.num_instances); - - let new_multiplicity: Usize = - builder.eval(Usize::from(num_lks) * Usize::from(num_padded_instance)); - builder.assign( - &dummy_table_item_multiplicity, - dummy_table_item_multiplicity + new_multiplicity, - ); - } - - let record_r_out_evals_prod = nested_product(builder, &chip_proof.record_r_out_evals); - builder.assign(&prod_r, prod_r * record_r_out_evals_prod); - - let record_w_out_evals_prod = nested_product(builder, &chip_proof.record_w_out_evals); - builder.assign(&prod_w, prod_w * record_w_out_evals_prod); - - let sign: Ext = if chip_vk.get_cs().is_opcode_circuit() { - builder.constant(C::EF::ONE) - } else { - builder.constant(-C::EF::ONE) - }; - - iter_zip!(builder, chip_proof.record_lk_out_evals).for_each(|ptr_vec, builder| { - let evals = builder.iter_ptr_get(&chip_proof.record_lk_out_evals, ptr_vec[0]); - let p1 = builder.get(&evals, 0); - let p2 = builder.get(&evals, 1); - let q1 = builder.get(&evals, 2); - let q2 = builder.get(&evals, 3); - - builder.assign(&logup_sum, logup_sum + sign * p1 * q1.inverse()); - builder.assign(&logup_sum, logup_sum + sign * p2 * q2.inverse()); - }); - - builder.assert_usize_eq( - chip_proof.log2_num_instances.clone(), - input_opening_point.len(), - ); - let witin_round: RoundOpeningVariable = builder.eval(RoundOpeningVariable { num_var: chip_proof.log2_num_instances.get_var(), point_and_evals: PointAndEvalsVariable { @@ -303,9 +328,16 @@ pub fn verify_zkvm_proof>( builder.inc(&num_chips_have_fixed); } + let record_r_out_evals_prod = nested_product(builder, &chip_proof.record_r_out_evals); + builder.assign(&prod_r, prod_r * record_r_out_evals_prod); + + let record_w_out_evals_prod = nested_product(builder, &chip_proof.record_w_out_evals); + builder.assign(&prod_w, prod_w * record_w_out_evals_prod); + builder.inc(&num_chips_verified); }); } + builder.assert_usize_eq(num_chips_have_fixed, Usize::from(num_fixed_opening)); builder.assert_eq::>(num_chips_verified, chip_indices.len()); @@ -316,7 +348,7 @@ pub fn verify_zkvm_proof>( logup_sum - dummy_table_item_multiplicity * dummy_table_item.inverse(), ); - let rounds = if num_fixed_opening > 0 { + let rounds: Array> = if num_fixed_opening > 0 { builder.dyn_array(2) } else { builder.dyn_array(1) @@ -330,6 +362,7 @@ pub fn verify_zkvm_proof>( perm: zkvm_proof_input.witin_perm.clone(), }, ); + if num_fixed_opening > 0 { builder.set( &rounds, @@ -342,16 +375,15 @@ pub fn verify_zkvm_proof>( ); } - builder.cycle_tracker_start("Basefold verify"); + /* _debug batch_verify( builder, zkvm_proof_input.max_num_var, - zkvm_proof_input.max_width, rounds, zkvm_proof_input.pcs_proof, &mut challenger, ); - builder.cycle_tracker_end("Basefold verify"); + */ let empty_arr: Array> = builder.dyn_array(0); let initial_global_state = eval_ceno_expr_with_instance( @@ -361,7 +393,7 @@ pub fn verify_zkvm_proof>( &empty_arr, &zkvm_proof_input.pi_evals, &challenges, - &vk.initial_global_state_expr, + &vk.vk.initial_global_state_expr, ); builder.assign(&prod_w, prod_w * initial_global_state); @@ -372,7 +404,7 @@ pub fn verify_zkvm_proof>( &empty_arr, &zkvm_proof_input.pi_evals, &challenges, - &vk.finalize_global_state_expr, + &vk.vk.finalize_global_state_expr, ); builder.assign(&prod_r, prod_r * finalize_global_state); @@ -392,33 +424,42 @@ pub fn verify_opcode_proof( ) -> Array> { let cs = vk.get_cs(); let one: Ext = builder.constant(C::EF::ONE); - let zero: Ext = builder.constant(C::EF::ZERO); let r_len = cs.zkvm_v1_css.r_expressions.len(); let w_len = cs.zkvm_v1_css.w_expressions.len(); let lk_len = cs.zkvm_v1_css.lk_expressions.len(); let num_batched = r_len + w_len + lk_len; - let chip_record_alpha: Ext = builder.get(challenges, 0); let r_counts_per_instance: Usize = Usize::from(r_len); let w_counts_per_instance: Usize = Usize::from(w_len); let lk_counts_per_instance: Usize = Usize::from(lk_len); let num_batched: Usize = Usize::from(num_batched); - let log2_r_count: Usize = Usize::from(ceil_log2(r_len)); - let log2_w_count: Usize = Usize::from(ceil_log2(w_len)); - let log2_lk_count: Usize = Usize::from(ceil_log2(lk_len)); - let log2_num_instances = opcode_proof.log2_num_instances.clone(); + let num_var_with_rotation: Usize = Usize::Var(Var::uninit(builder)); + builder + .if_eq(opcode_proof.has_gkr_proof.clone(), Usize::from(1)) + .then_or_else( + |builder| { + builder.assign( + &num_var_with_rotation, + opcode_proof.gkr_iop_proof.num_var_with_rotation.clone(), + ); + }, + |builder| { + builder.assign(&num_var_with_rotation, log2_num_instances.clone()); + }, + ); + let tower_proof = &opcode_proof.tower_proof; let num_variables: Array> = builder.dyn_array(num_batched); builder .range(0, num_variables.len()) .for_each(|idx_vec, builder| { - builder.set(&num_variables, idx_vec[0], log2_num_instances.clone()); + builder.set(&num_variables, idx_vec[0], num_var_with_rotation.clone()); }); let prod_out_evals: Array>> = concat( @@ -428,191 +469,707 @@ pub fn verify_opcode_proof( ); let num_fanin: Usize = Usize::from(NUM_FANIN); - let max_expr_len = *max([r_len, w_len, lk_len].iter()).unwrap(); builder.cycle_tracker_start("verify tower proof for opcode"); - let (rt_tower, record_evals, logup_p_evals, logup_q_evals) = verify_tower_proof( + let (_, record_evals, logup_p_evals, logup_q_evals) = verify_tower_proof( builder, challenger, prod_out_evals, &opcode_proof.record_lk_out_evals, num_variables, num_fanin, - log2_num_instances.clone(), + num_var_with_rotation.clone(), tower_proof, unipoly_extrapolator, ); builder.cycle_tracker_end("verify tower proof for opcode"); - // verify LogUp witness nominator p(x) ?= constant vector 1 - iter_zip!(builder, logup_p_evals).for_each(|ptr_vec, builder| { - let logup_p_eval = builder.iter_ptr_get(&logup_p_evals, ptr_vec[0]).eval; - builder.assert_ext_eq(logup_p_eval, one); - }); + let logup_p_eval = builder.get(&logup_p_evals, 0).eval; + builder.assert_ext_eq(logup_p_eval, one); // verify zero statement (degree > 1) + sel sumcheck - let rt = builder.get(&record_evals, 0); + let _rt = builder.get(&record_evals, 0); let num_rw_records: Usize = builder.eval(r_counts_per_instance + w_counts_per_instance); builder.assert_usize_eq(record_evals.len(), num_rw_records.clone()); + builder.assert_usize_eq(logup_p_evals.len(), lk_counts_per_instance.clone()); + builder.assert_usize_eq(logup_q_evals.len(), lk_counts_per_instance.clone()); + + let composed_cs = vk.get_cs(); + let ComposedConstrainSystem { + zkvm_v1_css: _, + gkr_circuit, + } = &composed_cs; + let gkr_circuit = gkr_circuit.clone().unwrap(); + + let out_evals_len: Usize = builder.eval(record_evals.len() + logup_q_evals.len()); + let out_evals: Array> = builder.dyn_array(out_evals_len.clone()); + builder + .range(0, record_evals.len()) + .for_each(|idx_vec, builder| { + let cpt = builder.get(&record_evals, idx_vec[0]); + builder.set(&out_evals, idx_vec[0], cpt); + }); + let q_slice = out_evals.slice(builder, record_evals.len(), out_evals_len); + builder + .range(0, logup_q_evals.len()) + .for_each(|idx_vec, builder| { + let cpt = builder.get(&logup_q_evals, idx_vec[0]); + builder.set(&q_slice, idx_vec[0], cpt); + }); - let alpha_len = builder.eval( - num_rw_records.clone() - + lk_counts_per_instance - + Usize::from(cs.zkvm_v1_css.assert_zero_sumcheck_expressions.len()), - ); - transcript_observe_label(builder, challenger, b"combine subset evals"); - let alpha_pow = gen_alpha_pows(builder, challenger, alpha_len); - - // alpha_read * (out_r[rt] - 1) + alpha_write * (out_w[rt] - 1) + alpha_lk * (out_lk_q - chip_record_alpha) - // + 0 // 0 come from zero check - let claim_sum: Ext = builder.constant(C::EF::ZERO); - let rw_logup_len: Usize = builder.eval(num_rw_records.clone() + logup_q_evals.len()); - let alpha_rw_slice = alpha_pow.slice(builder, 0, num_rw_records.clone()); - iter_zip!(builder, alpha_rw_slice, record_evals).for_each(|ptr_vec, builder| { - let alpha = builder.iter_ptr_get(&alpha_rw_slice, ptr_vec[0]); - let eval = builder.iter_ptr_get(&record_evals, ptr_vec[1]); - - builder.assign(&claim_sum, claim_sum + alpha * (eval.eval - one)); - }); - let alpha_logup_slice = alpha_pow.slice(builder, num_rw_records.clone(), rw_logup_len); - iter_zip!(builder, alpha_logup_slice, logup_q_evals).for_each(|ptr_vec, builder| { - let alpha = builder.iter_ptr_get(&alpha_logup_slice, ptr_vec[0]); - let eval = builder.iter_ptr_get(&logup_q_evals, ptr_vec[1]); - builder.assign( - &claim_sum, - claim_sum + alpha * (eval.eval - chip_record_alpha), - ); - }); - - let log2_num_instances_var: Var = RVar::from(log2_num_instances.clone()).variable(); - let log2_num_instances_f: Felt = builder.unsafe_cast_var_to_felt(log2_num_instances_var); - let max_non_lc_degree: usize = cs.zkvm_v1_css.max_non_lc_degree; - let main_sel_subclaim_max_degree: Felt = builder.constant(C::F::from_canonical_u32( - SEL_DEGREE.max(max_non_lc_degree + 1) as u32, - )); - builder.cycle_tracker_start("main sumcheck"); - let (input_opening_point, expected_evaluation) = iop_verifier_state_verify( + let opening_evaluations = verify_gkr_circuit( builder, challenger, - &claim_sum, - &opcode_proof.main_sel_sumcheck_proofs, - log2_num_instances_f, - main_sel_subclaim_max_degree, + gkr_circuit, + &opcode_proof.gkr_iop_proof, + challenges, + pi_evals, + &out_evals, + opcode_proof, unipoly_extrapolator, ); - builder.cycle_tracker_end("main sumcheck"); - // sel(rt, t) - let sel = eq_eval_less_or_equal_than( - builder, - challenger, - opcode_proof, - &input_opening_point, - &rt.point.fs, - ); + opening_evaluations[0].point.fs.clone() +} - // derive r_records, w_records, lk_records from witness's evaluations - let alpha_idx: Var = builder.uninit(); - builder.assign(&alpha_idx, Usize::from(0)); - let empty_arr: Array> = builder.dyn_array(0); +pub fn verify_gkr_circuit( + builder: &mut Builder, + challenger: &mut DuplexChallengerVariable, + gkr_circuit: GKRCircuit, + gkr_proof: &GKRProofVariable, + challenges: &Array>, + pub_io_evals: &Array>, + claims: &Array>, + opcode_proof: &ZKVMChipProofInputVariable, + unipoly_extrapolator: &mut UniPolyExtrapolator, +) -> Vec> { + for (i, layer) in gkr_circuit.layers.iter().enumerate() { + let layer_proof = builder.get(&gkr_proof.layer_proofs, i); + let layer_challenges: Array> = + generate_layer_challenges(builder, challenger, &challenges, layer.n_challenges); + let eval_and_dedup_points: Array> = extract_claim_and_point( + builder, + layer, + &claims, + &layer_challenges, + &layer_proof.has_rotation, + ); - let rw_expressions_sum: Ext = builder.constant(C::EF::ZERO); - cs.zkvm_v1_css - .r_expressions - .iter() - .chain(cs.zkvm_v1_css.w_expressions.iter()) - .for_each(|expr| { - let e = eval_ceno_expr_with_instance( + // ZeroCheckLayer verification (might include other layer types in the future) + let LayerProofVariable { + main: + SumcheckLayerProofVariable { + proof, + evals: main_evals, + evals_len_div_3: _main_evals_len_div_3, + }, + rotation: rotation_proof, + has_rotation, + } = layer_proof; + + builder.if_eq(has_rotation, Usize::from(1)).then(|builder| { + let first = builder.get(&eval_and_dedup_points, 0); + builder.assert_usize_eq(first.has_point, Usize::from(1)); // Rotation proof should have at least one point + let rt = first.point.fs.clone(); + + let RotationClaim { + left_evals, + right_evals, + target_evals, + left_point, + right_point, + origin_point, + } = verify_rotation( builder, - &empty_arr, - &opcode_proof.wits_in_evals, - &empty_arr, - pi_evals, - challenges, - expr, + gkr_proof.num_var_with_rotation.clone(), + &rotation_proof, + layer.rotation_cyclic_subgroup_size, + layer.rotation_cyclic_group_log2, + rt, + challenger, + unipoly_extrapolator, + ); + + let last_idx: Usize = builder.eval(eval_and_dedup_points.len() - Usize::from(1)); + builder.set( + &eval_and_dedup_points, + last_idx.clone(), + ClaimAndPoint { + evals: target_evals, + has_point: Usize::from(1), + point: PointVariable { fs: origin_point }, + }, + ); + + builder.assign(&last_idx, last_idx.clone() - Usize::from(1)); + builder.set( + &eval_and_dedup_points, + last_idx.clone(), + ClaimAndPoint { + evals: right_evals, + has_point: Usize::from(1), + point: PointVariable { fs: right_point }, + }, + ); + + builder.assign(&last_idx, last_idx.clone() - Usize::from(1)); + builder.set( + &eval_and_dedup_points, + last_idx.clone(), + ClaimAndPoint { + evals: left_evals, + has_point: Usize::from(1), + point: PointVariable { fs: left_point }, + }, ); - let alpha = builder.get(&alpha_pow, alpha_idx); - builder.assign(&alpha_idx, alpha_idx + Usize::from(1)); - builder.assign(&rw_expressions_sum, rw_expressions_sum + alpha * (e - one)) }); - builder.assign(&rw_expressions_sum, rw_expressions_sum * sel); - let lk_expressions_sum: Ext = builder.constant(C::EF::ZERO); - cs.zkvm_v1_css.lk_expressions.iter().for_each(|expr| { - let e = eval_ceno_expr_with_instance( + let rotation_exprs_len = layer.rotation_exprs.1.len(); + transcript_observe_label(builder, challenger, b"combine subset evals"); + let alpha_pows = gen_alpha_pows( + builder, + challenger, + Usize::from(layer.exprs.len() + rotation_exprs_len * ROTATION_OPENING_COUNT), + ); + + let sigma: Ext = builder.constant(C::EF::ZERO); + let alpha_idx: Usize = Usize::Var(Var::uninit(builder)); + builder.assign(&alpha_idx, C::N::from_canonical_usize(0)); + + builder + .range(0, eval_and_dedup_points.len()) + .for_each(|idx_vec, builder| { + let ClaimAndPoint { + evals, + has_point: _, + point: _, + } = builder.get(&eval_and_dedup_points, idx_vec[0]); + let end_idx: Usize = builder.eval(alpha_idx.clone() + evals.len()); + let alpha_slice: Array::F, ::EF>> = + alpha_pows.slice(builder, alpha_idx.clone(), end_idx.clone()); + + let sub_sum = ext_dot_product(builder, &evals, &alpha_slice); + builder.assign(&sigma, sigma.clone() + sub_sum); + builder.assign(&alpha_idx, end_idx); + }); + let max_degree = builder.constant(C::F::from_canonical_usize(layer.max_expr_degree + 1)); + let max_num_variables = + builder.unsafe_cast_var_to_felt(gkr_proof.num_var_with_rotation.get_var()); + + let (in_point, expected_evaluation) = iop_verifier_state_verify( + builder, + challenger, + &sigma, + &proof, + max_num_variables, + max_degree, + unipoly_extrapolator, + ); + + layer + .out_sel_and_eval_exprs + .iter() + .enumerate() + .for_each(|(idx, (sel_type, _))| { + let out_point = builder.get(&eval_and_dedup_points, idx).point.fs; + evaluate_selector( + builder, + sel_type, + &main_evals, + &out_point, + &in_point, + opcode_proof, + layer.n_witin, + ); + }); + + let main_sumcheck_challenges_len: Usize = + builder.eval(alpha_pows.len() + Usize::from(2)); + let main_sumcheck_challenges: Array> = + builder.dyn_array(main_sumcheck_challenges_len.clone()); + let alpha = builder.get(&challenges, 0); + let beta = builder.get(&challenges, 1); + builder.set(&main_sumcheck_challenges, 0, alpha); + builder.set(&main_sumcheck_challenges, 1, beta); + let challenge_slice = + main_sumcheck_challenges.slice(builder, 2, main_sumcheck_challenges_len); + builder + .range(0, alpha_pows.len()) + .for_each(|idx_vec, builder| { + let alpha = builder.get(&alpha_pows, idx_vec[0]); + builder.set(&challenge_slice, idx_vec[0], alpha); + }); + + let empty_arr: Array> = builder.dyn_array(0); + let got_claim = eval_ceno_expr_with_instance( builder, &empty_arr, - &opcode_proof.wits_in_evals, + &main_evals, &empty_arr, - pi_evals, - challenges, - expr, + &pub_io_evals, + &main_sumcheck_challenges, + layer.main_sumcheck_expression.as_ref().unwrap(), ); - let alpha = builder.get(&alpha_pow, alpha_idx); - builder.assign(&alpha_idx, alpha_idx + Usize::from(1)); - builder.assign( - &lk_expressions_sum, - lk_expressions_sum + alpha * (e - chip_record_alpha), - ) - }); - builder.assign(&lk_expressions_sum, lk_expressions_sum * sel); - let zero_expressions_sum: Ext = builder.constant(C::EF::ZERO); - cs.zkvm_v1_css - .assert_zero_sumcheck_expressions + builder.assert_ext_eq(got_claim, expected_evaluation); + + // Update claim + layer + .in_eval_expr + .iter() + .enumerate() + .for_each(|(idx, pos)| { + let val = builder.get(&main_evals, idx); + builder.set( + &claims, + *pos, + PointAndEvalVariable { + point: PointVariable { + fs: in_point.clone(), + }, + eval: val, + }, + ); + }); + } + + // GKR Claim + let input_layer = gkr_circuit.layers.last().unwrap(); + input_layer + .in_eval_expr .iter() - .for_each(|expr| { - // evaluate zero expression by all wits_in_evals because they share the unique input_opening_point opening - let e = eval_ceno_expr_with_instance( - builder, - &empty_arr, - &opcode_proof.wits_in_evals, - &empty_arr, - pi_evals, - challenges, - expr, + .enumerate() + .map(|(poly, eval)| { + let PointAndEvalVariable { point, eval } = builder.get(&claims, *eval); + + GKRClaimEvaluation { + value: eval, + point, + poly: Usize::from(poly), + } + }) + .collect_vec() +} + +pub fn verify_rotation( + builder: &mut Builder, + max_num_variables: Usize, + rotation_proof: &SumcheckLayerProofVariable, + rotation_cyclic_subgroup_size: usize, + rotation_cyclic_group_log2: usize, + rt: Array>, + challenger: &mut DuplexChallengerVariable, + unipoly_extrapolator: &mut UniPolyExtrapolator, +) -> RotationClaim { + let SumcheckLayerProofVariable { + proof, + evals, + evals_len_div_3: rotation_expr_len, + } = rotation_proof; + + let rotation_expr_len = Usize::Var(rotation_expr_len.clone()); + transcript_observe_label(builder, challenger, b"combine subset evals"); + let rotation_alpha_pows = gen_alpha_pows(builder, challenger, rotation_expr_len.clone()); + let sigma: Ext = builder.constant(C::EF::ZERO); + + let max_num_variables = builder.unsafe_cast_var_to_felt(max_num_variables.get_var()); + let max_degree: Felt = builder.constant(C::F::TWO); + + let (origin_point, expected_evaluation) = iop_verifier_state_verify( + builder, + challenger, + &sigma, + proof, + max_num_variables, + max_degree, + unipoly_extrapolator, + ); + + // compute the selector evaluation + let selector_eval = rotation_selector_eval( + builder, + &rt, + &origin_point, + rotation_cyclic_subgroup_size, + rotation_cyclic_group_log2, + ); + + // check the final evaluations. + let left_evals: Array> = builder.dyn_array(rotation_expr_len.clone()); + let right_evals: Array> = builder.dyn_array(rotation_expr_len.clone()); + let target_evals: Array> = builder.dyn_array(rotation_expr_len); + + let got_claim: Ext = builder.constant(C::EF::ZERO); + let one: Ext = builder.constant(C::EF::ONE); + let last_origin = if rotation_cyclic_group_log2 > 0 { + builder.get(&origin_point, rotation_cyclic_group_log2 - 1) + } else { + one.clone() + }; + + builder + .range(0, rotation_alpha_pows.len()) + .for_each(|idx_vec, builder| { + let alpha = builder.get(&rotation_alpha_pows, idx_vec[0]); + + let rvar3 = RVar::from(3); + let left_idx: Var = builder.eval(idx_vec[0] * rvar3); + let right_idx: Var = builder.eval(idx_vec[0] * rvar3 + RVar::from(1)); + let target_idx: Var = builder.eval(idx_vec[0] * rvar3 + RVar::from(2)); + + let left = builder.get(&evals, left_idx); + let right = builder.get(&evals, right_idx); + let target = builder.get(&evals, target_idx); + + builder.set(&left_evals, idx_vec[0], left); + builder.set(&right_evals, idx_vec[0], right); + builder.set(&target_evals, idx_vec[0], target); + + builder.assign( + &got_claim, + got_claim + alpha * ((one - last_origin) * left + last_origin * right - target), ); - let alpha = builder.get(&alpha_pow, alpha_idx); - builder.assign(&alpha_idx, alpha_idx + Usize::from(1)); - builder.assign(&zero_expressions_sum, zero_expressions_sum + alpha * e); }); - builder.assign(&zero_expressions_sum, zero_expressions_sum * sel); + builder.assign(&got_claim, got_claim * selector_eval); + builder.assert_ext_eq(got_claim, expected_evaluation); + + let (left_point, right_point) = + get_rotation_points(builder, rotation_cyclic_group_log2, &origin_point); + + RotationClaim { + left_evals, + right_evals, + target_evals, + left_point, + right_point, + origin_point, + } +} - let computed_eval: Ext = - builder.eval(rw_expressions_sum + lk_expressions_sum + zero_expressions_sum); - builder.assert_ext_eq(computed_eval, expected_evaluation); +/// sel(rx) +/// = (\sum_{b = 0}^{cyclic_subgroup_size - 1} eq(out_point[..cyclic_group_log2_size], b) * eq(in_point[..cyclic_group_log2_size], b)) +/// * \prod_{k = cyclic_group_log2_size}^{n - 1} eq(out_point[k], in_point[k]) +pub fn rotation_selector_eval( + builder: &mut Builder, + out_point: &Array>, + in_point: &Array>, + rotation_cyclic_subgroup_size: usize, + cyclic_group_log2_size: usize, +) -> Ext { + let bh = BooleanHypercube::new(5); + let eval: Ext = builder.constant(C::EF::ZERO); + let rotation_index = bh + .into_iter() + .take(rotation_cyclic_subgroup_size) + .collect_vec(); + + let out_subgroup = out_point.slice(builder, 0, cyclic_group_log2_size); + let in_subgroup = in_point.slice(builder, 0, cyclic_group_log2_size); + let out_subgroup_eq = build_eq_x_r_vec_sequential(builder, &out_subgroup); + let in_subgroup_eq = build_eq_x_r_vec_sequential(builder, &in_subgroup); + + for b in rotation_index { + let out_v = builder.get(&out_subgroup_eq, b as usize); + let in_v = builder.get(&in_subgroup_eq, b as usize); + builder.assign(&eval, eval + in_v * out_v); + } - // verify zero expression (degree = 1) statement, thus no sumcheck - cs.zkvm_v1_css - .assert_zero_expressions - .iter() - .for_each(|expr| { - let e = eval_ceno_expr_with_instance( + let out_subgroup = out_point.slice(builder, cyclic_group_log2_size, out_point.len()); + let in_subgroup = in_point.slice(builder, cyclic_group_log2_size, in_point.len()); + + let one: Ext = builder.constant(C::EF::ONE); + let zero: Ext = builder.constant(C::EF::ZERO); + + let eq_eval = eq_eval(builder, &out_subgroup, &in_subgroup, one, zero); + builder.assign(&eval, eval * eq_eval); + + eval +} + +pub fn evaluate_selector( + builder: &mut Builder, + sel_type: &SelectorType, + evals: &Array>, + out_point: &Array>, + in_point: &Array>, + opcode_proof: &ZKVMChipProofInputVariable, + offset_eq_id: usize, +) { + let (expr, eval) = match sel_type { + SelectorType::None => return, + SelectorType::Whole(expr) => { + let one = builder.constant(C::EF::ONE); + let zero = builder.constant(C::EF::ZERO); + (expr, eq_eval(builder, out_point, in_point, one, zero)) + } + SelectorType::Prefix(_, expr) => ( + expr, + eq_eval_less_or_equal_than(builder, opcode_proof, out_point, in_point), + ), + SelectorType::OrderedSparse32 { + indices, + expression, + } => { + let out_point_slice = out_point.slice(builder, 0, 5); + let in_point_slice = in_point.slice(builder, 0, 5); + let out_subgroup_eq = build_eq_x_r_vec_sequential(builder, &out_point_slice); + let in_subgroup_eq = build_eq_x_r_vec_sequential(builder, &in_point_slice); + + let eval: Ext = builder.constant(C::EF::ZERO); + for idx in indices { + let out_val = builder.get(&out_subgroup_eq, *idx); + let in_val = builder.get(&in_subgroup_eq, *idx); + builder.assign(&eval, eval + out_val * in_val); + } + + let out_point_slice = out_point.slice(builder, 5, out_point.len()); + let in_point_slice = in_point.slice(builder, 5, in_point.len()); + + let sel = eq_eval_less_or_equal_than( builder, - &empty_arr, - &opcode_proof.wits_in_evals, - &empty_arr, - pi_evals, - challenges, - expr, + opcode_proof, + &out_point_slice, + &in_point_slice, + ); + builder.assign(&eval, eval * sel); + + (expression, eval) + } + }; + + let Expression::StructuralWitIn(wit_id, _, _, _) = expr else { + panic!("Wrong selector expression format"); + }; + let wit_id = *wit_id as usize + offset_eq_id; + builder.set(evals, wit_id, eval); +} + +pub fn get_rotation_points( + builder: &mut Builder, + _num_vars: usize, + point: &Array>, +) -> (Array>, Array>) { + let left: Array> = builder.dyn_array(point.len()); + let right: Array> = builder.dyn_array(point.len()); + builder.range(0, 4).for_each(|idx_vec, builder| { + let e = builder.get(point, idx_vec[0]); + let dest_idx: Var = builder.eval(idx_vec[0] + RVar::from(1)); + builder.set(&left, dest_idx, e); + builder.set(&right, dest_idx, e); + }); + + let one: Ext = builder.constant(C::EF::ONE); + builder.set(&right, 0, one); + let r1 = builder.get(&right, 2); + builder.set(&right, 2, one - r1); + + builder.range(5, point.len()).for_each(|idx_vec, builder| { + let e = builder.get(point, idx_vec[0]); + builder.set(&left, idx_vec[0], e); + builder.set(&right, idx_vec[0], e); + }); + + (left, right) +} + +pub fn evaluate_gkr_expression( + builder: &mut Builder, + expr: &EvalExpression, + claims: &Array>, + challenges: &Array>, +) -> PointAndEvalVariable { + match expr { + EvalExpression::Zero => { + let point = builder.get(claims, 0).point.clone(); + let eval: Ext = builder.constant(C::EF::ZERO); + PointAndEvalVariable { point, eval } + } + EvalExpression::Single(i) => builder.get(claims, *i).clone(), + EvalExpression::Linear(i, c0, c1) => { + let point = builder.get(claims, *i); + + let eval = point.eval.clone(); + let point = point.point.clone(); + + let empty_arr: Array> = builder.dyn_array(0); + let c0_eval = eval_ceno_expr_with_instance( + builder, &empty_arr, &empty_arr, &empty_arr, &empty_arr, challenges, c0, + ); + let c1_eval = eval_ceno_expr_with_instance( + builder, &empty_arr, &empty_arr, &empty_arr, &empty_arr, challenges, c1, ); - builder.assert_ext_eq(e, zero); + + builder.assign(&eval, eval * c0_eval + c1_eval); + + PointAndEvalVariable { point, eval } + } + EvalExpression::Partition(parts, indices) => { + assert!(izip!(indices.iter(), indices.iter().skip(1)).all(|(a, b)| a.0 < b.0)); + let empty_arr: Array> = builder.dyn_array(0); + let vars = indices + .iter() + .map(|(_, c)| { + eval_ceno_expr_with_instance( + builder, &empty_arr, &empty_arr, &empty_arr, &empty_arr, challenges, c, + ) + }) + .collect_vec(); + let vars_arr: Array> = builder.dyn_array(vars.len()); + for (i, e) in vars.iter().enumerate() { + builder.set(&vars_arr, i, *e); + } + let parts = parts + .iter() + .map(|part| evaluate_gkr_expression(builder, part, claims, challenges)) + .collect_vec(); + + assert_eq!(parts.len(), 1 << indices.len()); + + // _debug + // assert!(parts.iter().all(|part| part.point == parts[0].point)); + + let mut new_point: Vec> = vec![]; + builder + .range(0, parts[0].point.fs.len()) + .for_each(|idx_vec, builder| { + let e = builder.get(&parts[0].point.fs, idx_vec[0]); + new_point.push(e); + }); + for (index_in_point, c) in indices { + let eval = eval_ceno_expr_with_instance( + builder, &empty_arr, &empty_arr, &empty_arr, &empty_arr, challenges, c, + ); + new_point.insert(*index_in_point, eval); + } + + let new_point_arr: Array> = builder.dyn_array(new_point.len()); + for (i, e) in new_point.iter().enumerate() { + builder.set(&new_point_arr, i, *e); + } + let eq = build_eq_x_r_vec_sequential(builder, &vars_arr); + + let parts_arr: Array> = builder.dyn_array(parts.len()); + for (i, pt) in parts.iter().enumerate() { + builder.set(&parts_arr, i, pt.clone()); + } + + let acc: Ext = builder.constant(C::EF::ZERO); + iter_zip!(builder, parts_arr, eq).for_each(|ptr_vec, builder| { + let prt = builder.iter_ptr_get(&parts_arr, ptr_vec[0]); + let eq_v = builder.iter_ptr_get(&eq, ptr_vec[1]); + builder.assign(&acc, acc + prt.eval * eq_v); + }); + + PointAndEvalVariable { + point: PointVariable { fs: new_point_arr }, + eval: acc, + } + } + } +} + +pub fn extract_claim_and_point( + builder: &mut Builder, + layer: &Layer, + claims: &Array>, + challenges: &Array>, + has_rotation: &Usize, +) -> Array> { + let r_len: Usize = Usize::Var(Var::uninit(builder)); + builder.assign( + &r_len, + has_rotation.clone() * Usize::from(3) + Usize::from(layer.out_sel_and_eval_exprs.len()), + ); + + let r = builder.dyn_array(r_len); + + layer + .out_sel_and_eval_exprs + .iter() + .enumerate() + .for_each(|(i, (_, out_evals))| { + let evals = out_evals + .iter() + .map(|out_eval| { + let r = evaluate_gkr_expression(builder, out_eval, claims, challenges); + r.eval + }) + .collect_vec(); + let evals_arr: Array> = builder.dyn_array(evals.len()); + for (j, e) in evals.iter().enumerate() { + builder.set(&evals_arr, j, *e); + } + let point = out_evals.first().map(|out_eval| { + let r = evaluate_gkr_expression(builder, out_eval, claims, challenges); + r.point + }); + + if point.is_some() { + builder.set( + &r, + i, + ClaimAndPoint { + evals: evals_arr, + has_point: Usize::from(1), + point: point.unwrap(), + }, + ); + } else { + let pt = PointVariable { + fs: builder.dyn_array(0), + }; + builder.set( + &r, + i, + ClaimAndPoint { + evals: evals_arr, + has_point: Usize::from(0), + point: pt, + }, + ); + } }); - input_opening_point + r +} + +pub fn generate_layer_challenges( + builder: &mut Builder, + challenger: &mut DuplexChallengerVariable, + challenges: &Array>, + n_challenges: usize, +) -> Array> { + let r = builder.dyn_array(n_challenges + 2); + + let alpha = builder.get(challenges, 0); + let beta = builder.get(challenges, 1); + + builder.set(&r, 0, alpha); + builder.set(&r, 1, beta); + + transcript_observe_label(builder, challenger, b"layer challenge"); + let c = gen_alpha_pows(builder, challenger, Usize::from(n_challenges)); + + for i in 0..n_challenges { + let idx = i + 2; + let e = builder.get(&c, i); + builder.set(&r, idx, e); + } + + r } pub fn verify_table_proof( builder: &mut Builder, challenger: &mut DuplexChallengerVariable, table_proof: &ZKVMChipProofInputVariable, - // raw_pi: &Array>>, - // raw_pi_num_variables: &Array>, + raw_pi: &Array>>, + raw_pi_num_variables: &Array>, pi_evals: &Array>, challenges: &Array>, vk: &VerifyingKey, unipoly_extrapolator: &mut UniPolyExtrapolator, - // poly_evaluator: &mut PolyEvaluator, + poly_evaluator: &mut PolyEvaluator, ) -> Array> { let cs = vk.get_cs(); let tower_proof: &super::binding::TowerProofInputVariable = &table_proof.tower_proof; @@ -669,6 +1226,7 @@ pub fn verify_table_proof( }); let expected_rounds = concat(builder, &r_expected_rounds, &lk_expected_rounds); let max_expected_rounds = max_usize_arr(builder, &expected_rounds); + let num_fanin: Usize = Usize::from(NUM_FANIN); let max_num_variables: Usize = Usize::from(max_expected_rounds); let prod_out_evals: Array>> = concat( @@ -799,7 +1357,7 @@ pub fn verify_table_proof( }); // verify records (degree = 1) statement, thus no sumcheck - interleave( + let expected_evals_vec: Vec> = interleave( &cs.zkvm_v1_css.r_table_expressions, // r &cs.zkvm_v1_css.w_table_expressions, // w ) @@ -811,8 +1369,8 @@ pub fn verify_table_proof( .flat_map(|lk| vec![&lk.multiplicity, &lk.values]), // p, q ) .enumerate() - .for_each(|(idx, expr)| { - let e = eval_ceno_expr_with_instance( + .map(|(_, expr)| { + eval_ceno_expr_with_instance( builder, &table_proof.fixed_in_evals, &table_proof.wits_in_evals, @@ -820,13 +1378,24 @@ pub fn verify_table_proof( pi_evals, challenges, expr, - ); + ) + }) + .collect_vec(); - let expected_evals = builder.get(&in_evals, idx); - builder.assert_ext_eq(e, expected_evals); + let expected_evals: Array> = builder.dyn_array(expected_evals_vec.len()); + expected_evals_vec + .into_iter() + .enumerate() + .for_each(|(idx, e)| { + builder.set(&expected_evals, idx, e); + }); + + iter_zip!(builder, in_evals, expected_evals).for_each(|ptr_vec, builder| { + let eval = builder.iter_ptr_get(&in_evals, ptr_vec[0]); + let expected = builder.iter_ptr_get(&expected_evals, ptr_vec[1]); + builder.assert_ext_eq(eval, expected); }); - /* TODO: enable this // assume public io is tiny vector, so we evaluate it directly without PCS for &Instance(idx) in cs.instance_name_map().keys() { let poly = builder.get(raw_pi, idx); @@ -834,9 +1403,9 @@ pub fn verify_table_proof( let eval_point = rt_tower.fs.slice(builder, 0, poly_num_vars); let expected_eval = poly_evaluator.evaluate_base_poly_at_point(builder, &poly, &eval_point); let eval = builder.get(&pi_evals, idx); + builder.assert_ext_eq(eval, expected_eval); } - */ rt_tower.fs }