From 469475fbbcfda63843a4d039e9a11cb842b35837 Mon Sep 17 00:00:00 2001 From: "sm.wu" Date: Fri, 30 Jan 2026 16:12:34 +0800 Subject: [PATCH] optimize and skip steprecord copy --- ceno_emul/src/lib.rs | 6 +- ceno_emul/src/test_utils.rs | 3 +- ceno_emul/src/tracer.rs | 234 ++++++++++++------ ceno_emul/src/vm_state.rs | 4 + ceno_emul/tests/test_vm_trace.rs | 3 +- ceno_host/src/lib.rs | 5 +- ceno_host/tests/test_elf.rs | 3 +- ceno_zkvm/src/e2e.rs | 138 +++++++++-- ceno_zkvm/src/instructions.rs | 53 +++- ceno_zkvm/src/instructions/riscv/arith.rs | 2 +- ceno_zkvm/src/instructions/riscv/arith_imm.rs | 2 +- ceno_zkvm/src/instructions/riscv/auipc.rs | 2 +- .../src/instructions/riscv/branch/test.rs | 12 +- ceno_zkvm/src/instructions/riscv/div.rs | 2 +- .../src/instructions/riscv/dummy/test.rs | 2 +- .../instructions/riscv/ecall/fptower_fp.rs | 32 ++- .../riscv/ecall/fptower_fp2_add.rs | 37 ++- .../riscv/ecall/fptower_fp2_mul.rs | 37 ++- .../src/instructions/riscv/ecall/keccak.rs | 31 ++- .../instructions/riscv/ecall/sha_extend.rs | 23 +- .../src/instructions/riscv/ecall/uint256.rs | 48 ++-- .../riscv/ecall/weierstrass_add.rs | 25 +- .../riscv/ecall/weierstrass_decompress.rs | 27 +- .../riscv/ecall/weierstrass_double.rs | 25 +- ceno_zkvm/src/instructions/riscv/jump/test.rs | 4 +- .../src/instructions/riscv/logic/test.rs | 6 +- .../riscv/logic_imm/logic_imm_circuit.rs | 2 +- .../src/instructions/riscv/logic_imm/test.rs | 2 +- ceno_zkvm/src/instructions/riscv/lui.rs | 2 +- .../src/instructions/riscv/memory/test.rs | 4 +- ceno_zkvm/src/instructions/riscv/mulh.rs | 6 +- ceno_zkvm/src/instructions/riscv/rv32im.rs | 27 +- ceno_zkvm/src/instructions/riscv/shift.rs | 2 +- ceno_zkvm/src/instructions/riscv/shift_imm.rs | 2 +- ceno_zkvm/src/instructions/riscv/slt.rs | 2 +- ceno_zkvm/src/instructions/riscv/slti.rs | 2 +- ceno_zkvm/src/scheme/tests.rs | 18 +- ceno_zkvm/src/structs.rs | 8 +- 38 files changed, 557 insertions(+), 286 deletions(-) diff --git a/ceno_emul/src/lib.rs b/ceno_emul/src/lib.rs index 1e49ced29..6b16a3587 100644 --- a/ceno_emul/src/lib.rs +++ b/ceno_emul/src/lib.rs @@ -10,9 +10,9 @@ pub use platform::{CENO_PLATFORM, Platform}; mod tracer; pub use tracer::{ - Change, FullTracer, LatestAccesses, MemOp, NextAccessPair, NextCycleAccess, PreflightTracer, - PreflightTracerConfig, ReadOp, ShardPlanBuilder, StepCellExtractor, StepRecord, Tracer, - WriteOp, + Change, FullTracer, FullTracerConfig, LatestAccesses, MemOp, NextAccessPair, NextCycleAccess, + PreflightTracer, PreflightTracerConfig, ReadOp, ShardPlanBuilder, StepCellExtractor, StepIndex, + StepRecord, Tracer, WriteOp, }; mod vm_state; diff --git a/ceno_emul/src/test_utils.rs b/ceno_emul/src/test_utils.rs index 185ec7aad..39577c13c 100644 --- a/ceno_emul/src/test_utils.rs +++ b/ceno_emul/src/test_utils.rs @@ -24,7 +24,8 @@ pub fn keccak_step() -> (StepRecord, Vec) { Default::default(), ); let mut vm = VMState::new(CENO_PLATFORM.clone(), program.into()); - let steps = vm.iter_until_halt().collect::>>().unwrap(); + vm.iter_until_halt().collect::>>().unwrap(); + let steps = vm.tracer().recorded_steps(); (steps[2].clone(), instructions) } diff --git a/ceno_emul/src/tracer.rs b/ceno_emul/src/tracer.rs index 64b7aa746..0aedf02dd 100644 --- a/ceno_emul/src/tracer.rs +++ b/ceno_emul/src/tracer.rs @@ -6,9 +6,10 @@ use crate::{ syscalls::{SyscallEffects, SyscallWitness}, }; use ceno_rt::WORD_SIZE; +use rayon::prelude::*; use rustc_hash::FxHashMap; use smallvec::SmallVec; -use std::{collections::BTreeMap, fmt, mem, sync::Arc}; +use std::{collections::BTreeMap, fmt, sync::Arc}; /// An instruction and its context in an execution trace. That is concrete values of registers and memory. /// @@ -39,6 +40,8 @@ pub struct StepRecord { syscall: Option, } +pub type StepIndex = usize; + pub trait StepCellExtractor { fn cells_for_kind(&self, kind: InsnKind, rs1_value: Option) -> u64; @@ -257,6 +260,8 @@ pub struct ShardPlanBuilder { current_shard_start_cycle: Cycle, cur_cells: u64, cur_cycle_in_shard: Cycle, + cur_step_count: usize, + max_step_shard: usize, shard_id: usize, finalized: bool, } @@ -272,6 +277,8 @@ impl ShardPlanBuilder { current_shard_start_cycle: initial_cycle, cur_cells: 0, cur_cycle_in_shard: 0, + cur_step_count: 0, + max_step_shard: 0, shard_id: 0, finalized: false, } @@ -293,6 +300,10 @@ impl ShardPlanBuilder { .expect("shard boundaries must contain at least one entry") } + pub fn max_step_shard(&self) -> usize { + self.max_step_shard + } + pub fn into_cycle_boundaries(self) -> Vec { assert!(self.finalized, "shard plan not finalized yet"); self.shard_cycle_boundaries @@ -314,6 +325,7 @@ impl ShardPlanBuilder { self.cur_cycle_in_shard = self .cur_cycle_in_shard .saturating_add(FullTracer::SUBCYCLES_PER_INSN); + self.cur_step_count = self.cur_step_count.saturating_add(1); let cycle_limit_hit = self.cur_cycle_in_shard >= self.max_cycle_per_shard; let should_split = self.cur_cells >= target || cycle_limit_hit; @@ -328,6 +340,8 @@ impl ShardPlanBuilder { self.current_shard_start_cycle = next_shard_cycle; self.cur_cells = 0; self.cur_cycle_in_shard = 0; + self.max_step_shard = self.max_step_shard.max(self.cur_step_count); + self.cur_step_count = 0; } } @@ -336,6 +350,8 @@ impl ShardPlanBuilder { !self.finalized, "shard plan cannot be finalized multiple times" ); + self.max_step_shard = self.max_step_shard.max(self.cur_step_count); + self.cur_step_count = 0; self.push_boundary(max_cycle); self.finalized = true; } @@ -654,9 +670,19 @@ impl StepRecord { } } +#[derive(Clone, Copy, Debug, Default)] +pub struct FullTracerConfig { + /// Maximum number of completed steps per shard. Internally, `FullTracer` + /// reserves one extra slot to hold the pending (in-progress) record. + pub max_step_shard: usize, +} + #[derive(Debug)] pub struct FullTracer { - record: StepRecord, + records: Vec, + len: usize, + pending_index: usize, + pending_cycle: Cycle, // record each section max access address // (start_addr -> (start_addr, end_addr, min_access_addr, max_access_addr)) @@ -676,72 +702,133 @@ impl FullTracer { pub const SUBCYCLE_MEM: Cycle = ::SUBCYCLE_MEM; pub const SUBCYCLES_PER_INSN: Cycle = ::SUBCYCLES_PER_INSN; - pub fn new(platform: &Platform) -> FullTracer { + pub fn new(platform: &Platform, config: FullTracerConfig) -> FullTracer { let mmio_max_access = init_mmio_min_max_access(platform); - - FullTracer { + // Always reserve one extra slot for the pending/in-progress record. Without + // this, a shard that executes exactly `max_step_shard` steps would panic + // when `advance()` tries to reset the next (non-existent) slot. + let capacity = config.max_step_shard.saturating_add(1); + let mut records = if capacity > 0 { + (0..capacity) + .into_par_iter() + .map(|_| StepRecord::default()) + .collect::>() + } else { + Vec::new() + }; + if records.is_empty() { + records.push(StepRecord::default()); + } + let mut tracer = FullTracer { + records, + len: 0, + pending_index: 0, + pending_cycle: Self::SUBCYCLES_PER_INSN, mmio_min_max_access: Some(mmio_max_access), - record: StepRecord { - cycle: Self::SUBCYCLES_PER_INSN, - ..StepRecord::default() - }, platform: platform.clone(), latest_accesses: LatestAccesses::new(platform), max_heap_addr_access: ByteAddr::from(platform.heap.start), max_hint_addr_access: ByteAddr::from(platform.hints.start), + }; + tracer.reset_pending_slot(); + tracer + } + + /// Prepare the slot for the next step; panics if the preallocated capacity + /// (from `FullTracerConfig::max_step_shard`) is exceeded. + #[inline(always)] + fn reset_pending_slot(&mut self) { + if self.pending_index >= self.records.len() { + panic!( + "FullTracer step buffer exhausted: recorded {} steps with capacity {}", + self.pending_index, + self.records.len() + ); } + self.records[self.pending_index] = StepRecord { + cycle: self.pending_cycle, + ..StepRecord::default() + }; + } + + pub fn reset_step_buffer(&mut self) { + self.len = 0; + self.pending_index = 0; + self.reset_pending_slot(); + } + + pub fn recorded_steps(&self) -> &[StepRecord] { + &self.records[..self.len] + } + + #[inline(always)] + pub fn step_record(&self, index: StepIndex) -> &StepRecord { + assert!( + index < self.len, + "step index {index} out of bounds {}", + self.len + ); + &self.records[index] } /// Return the completed step and advance to the next cycle. #[inline(always)] - pub fn advance(&mut self) -> StepRecord { - let next_cycle = self.record.cycle + Self::SUBCYCLES_PER_INSN; - mem::replace( - &mut self.record, - StepRecord { - cycle: next_cycle, - ..StepRecord::default() - }, - ) + pub fn advance(&mut self) -> StepIndex { + let idx = self.pending_index; + let next_cycle = self.records[self.pending_index].cycle + Self::SUBCYCLES_PER_INSN; + self.len = idx + 1; + self.pending_cycle = next_cycle; + self.pending_index += 1; + self.reset_pending_slot(); + idx } #[inline(always)] pub fn store_pc(&mut self, pc: ByteAddr) { - self.record.pc.after = pc; + self.records[self.pending_index].pc.after = pc; } #[inline(always)] pub fn fetch(&mut self, pc: WordAddr, value: Instruction) { - self.record.pc.before = pc.baddr(); - self.record.insn = value; + let record = &mut self.records[self.pending_index]; + record.pc.before = pc.baddr(); + record.insn = value; } #[inline(always)] pub fn track_mmu_maxtouch_before(&mut self) { - self.record.heap_maxtouch_addr.before = self.max_heap_addr_access; - self.record.hint_maxtouch_addr.before = self.max_hint_addr_access; + let heap_access = self.max_heap_addr_access; + let hint_access = self.max_hint_addr_access; + let record = &mut self.records[self.pending_index]; + record.heap_maxtouch_addr.before = heap_access; + record.hint_maxtouch_addr.before = hint_access; } #[inline(always)] pub fn track_mmu_maxtouch_after(&mut self) { - self.record.heap_maxtouch_addr.after = self.max_heap_addr_access; - self.record.hint_maxtouch_addr.after = self.max_hint_addr_access; + let heap_access = self.max_heap_addr_access; + let hint_access = self.max_hint_addr_access; + let record = &mut self.records[self.pending_index]; + record.heap_maxtouch_addr.after = heap_access; + record.hint_maxtouch_addr.after = hint_access; } #[inline(always)] pub fn load_register(&mut self, idx: RegIdx, value: Word) { let addr = Platform::register_vma(idx).into(); - - match (&self.record.rs1, &self.record.rs2) { + match ( + self.records[self.pending_index].rs1.as_ref(), + self.records[self.pending_index].rs2.as_ref(), + ) { (None, None) => { - self.record.rs1 = Some(ReadOp { + self.records[self.pending_index].rs1 = Some(ReadOp { addr, value, previous_cycle: self.track_access(addr, Self::SUBCYCLE_RS1), }); } (Some(_), None) => { - self.record.rs2 = Some(ReadOp { + self.records[self.pending_index].rs2 = Some(ReadOp { addr, value, previous_cycle: self.track_access(addr, Self::SUBCYCLE_RS2), @@ -753,15 +840,16 @@ impl FullTracer { #[inline(always)] pub fn store_register(&mut self, idx: RegIdx, value: Change) { - if self.record.rd.is_some() { + if self.records[self.pending_index].rd.is_some() { unimplemented!("Only one register write is supported"); } let addr = Platform::register_vma(idx).into(); - self.record.rd = Some(WriteOp { + let previous_cycle = self.track_access(addr, Self::SUBCYCLE_RD); + self.records[self.pending_index].rd = Some(WriteOp { addr, value, - previous_cycle: self.track_access(addr, Self::SUBCYCLE_RD), + previous_cycle, }); } @@ -772,45 +860,40 @@ impl FullTracer { #[inline(always)] pub fn store_memory(&mut self, addr: WordAddr, value: Change) { - if self.record.memory_op.is_some() { + if self.records[self.pending_index].memory_op.is_some() { unimplemented!("Only one memory access is supported"); } - // update min/max mmio access + // Update the tracked min/max MMIO bounds so later phases only materialize + // the actually touched address range for heap / hint regions. if let Some((start_addr, (_, end_addr, min_addr, max_addr))) = self .mmio_min_max_access .as_mut() - // find the MMIO region whose start address is less than or equal to the target address .and_then(|mmio_max_access| mmio_max_access.range_mut(..=addr).next_back()) + && addr < *end_addr { - // skip if the target address is not within the range tracked by this MMIO region - // this condition ensures the address is within the MMIO region's end address - if addr < *end_addr { - // expand the max bound if the address exceeds the current max - if addr >= *max_addr { - *max_addr = addr + WordAddr::from(WORD_SIZE as u32); // end is exclusive - } - // shrink the min bound if the address is below the current min - if addr < *min_addr { - *min_addr = addr; // start is inclusive + if addr >= *max_addr { + *max_addr = addr + WordAddr::from(WORD_SIZE as u32); + } + if addr < *min_addr { + *min_addr = addr; + } + if start_addr.baddr().0 == self.platform.heap.start { + let access_end = addr + WordAddr::from(WORD_SIZE as u32); + let access_end_baddr = access_end.baddr(); + if access_end_baddr > self.max_heap_addr_access { + self.max_heap_addr_access = access_end_baddr; } - if start_addr.baddr().0 == self.platform.heap.start { - let access_end = addr + WordAddr::from(WORD_SIZE as u32); - let access_end_baddr = access_end.baddr(); - if access_end_baddr > self.max_heap_addr_access { - self.max_heap_addr_access = access_end_baddr; - } - } else if start_addr.baddr().0 == self.platform.hints.start { - let access_end = addr + WordAddr::from(WORD_SIZE as u32); - let access_end_baddr = access_end.baddr(); - if access_end_baddr > self.max_hint_addr_access { - self.max_hint_addr_access = access_end_baddr; - } + } else if start_addr.baddr().0 == self.platform.hints.start { + let access_end = addr + WordAddr::from(WORD_SIZE as u32); + let access_end_baddr = access_end.baddr(); + if access_end_baddr > self.max_hint_addr_access { + self.max_hint_addr_access = access_end_baddr; } } } - self.record.memory_op = Some(WriteOp { + self.records[self.pending_index].memory_op = Some(WriteOp { addr, value, previous_cycle: self.track_access(addr, Self::SUBCYCLE_MEM), @@ -820,18 +903,17 @@ impl FullTracer { #[inline(always)] pub fn track_syscall(&mut self, effects: SyscallEffects) { let witness = effects.finalize(self); - - assert!(self.record.syscall.is_none(), "Only one syscall per step"); - self.record.syscall = Some(witness); + let record = &mut self.records[self.pending_index]; + assert!(record.syscall.is_none(), "Only one syscall per step"); + record.syscall = Some(witness); } - /// - Return the cycle when an address was last accessed. - /// - Return 0 if this is the first access. - /// - Record the current instruction as the origin of the latest access. - /// - Accesses within the same instruction are distinguished by `subcycle ∈ [0, 3]`. #[inline(always)] pub fn track_access(&mut self, addr: WordAddr, subcycle: Cycle) -> Cycle { - let cur_cycle = self.record.cycle + subcycle; + // Returns the previous access cycle. Accesses within the same instruction + // are distinguished via `subcycle ∈ [0, 3]`; the first touch of an address + // yields `0`. + let cur_cycle = self.records[self.pending_index].cycle + subcycle; self.latest_accesses.track(addr, cur_cycle) } @@ -839,21 +921,19 @@ impl FullTracer { &self.latest_accesses } - /// Return the cycle of the pending instruction (after the last completed step). pub fn cycle(&self) -> Cycle { - self.record.cycle + self.pending_cycle } - /// Return the number of instruction executed til this moment - /// minus 1 since cycle start from Self::SUBCYCLES_PER_INSN + /// Number of executed instructions so far (discounting the init slot that + /// starts at `SUBCYCLES_PER_INSN`). pub fn executed_insts(&self) -> usize { - (self.record.cycle / Self::SUBCYCLES_PER_INSN) + (self.pending_cycle / Self::SUBCYCLES_PER_INSN) .saturating_sub(1) .try_into() .unwrap() } - /// giving a start address, return (min, max) accessed address within section pub fn probe_min_max_address_by_start_addr( &self, start_addr: WordAddr, @@ -1195,11 +1275,11 @@ impl Tracer for PreflightTracer { } impl Tracer for FullTracer { - type Record = StepRecord; - type Config = (); + type Record = StepIndex; + type Config = FullTracerConfig; - fn new(platform: &Platform, _config: Self::Config) -> Self { - FullTracer::new(platform) + fn new(platform: &Platform, config: Self::Config) -> Self { + FullTracer::new(platform, config) } #[inline(always)] @@ -1209,7 +1289,7 @@ impl Tracer for FullTracer { #[inline(always)] fn is_busy_loop(&self, record: &Self::Record) -> bool { - record.is_busy_loop() + self.step_record(*record).is_busy_loop() } #[inline(always)] diff --git a/ceno_emul/src/vm_state.rs b/ceno_emul/src/vm_state.rs index 514d6435b..d65844682 100644 --- a/ceno_emul/src/vm_state.rs +++ b/ceno_emul/src/vm_state.rs @@ -104,6 +104,10 @@ impl VMState { &self.tracer } + pub fn tracer_mut(&mut self) -> &mut T { + &mut self.tracer + } + pub fn take_tracer(self) -> T { self.tracer } diff --git a/ceno_emul/tests/test_vm_trace.rs b/ceno_emul/tests/test_vm_trace.rs index 4d29d8414..ed0022d3d 100644 --- a/ceno_emul/tests/test_vm_trace.rs +++ b/ceno_emul/tests/test_vm_trace.rs @@ -60,7 +60,8 @@ fn test_empty_program() -> Result<()> { } fn run(state: &mut VMState) -> Result> { - state.iter_until_halt().collect() + state.iter_until_halt().collect::>>()?; + Ok(state.tracer().recorded_steps().to_vec()) } /// Example in RISC-V bytecode and assembly. diff --git a/ceno_host/src/lib.rs b/ceno_host/src/lib.rs index 0de4ee9e0..719288ece 100644 --- a/ceno_host/src/lib.rs +++ b/ceno_host/src/lib.rs @@ -138,11 +138,12 @@ pub fn run( state.init_memory(addr.into(), value); } - let steps = state + state .iter_until_halt() .collect::>>() .expect("Failed to run the program"); - eprintln!("Emulator ran for {} steps.", steps.len()); + let step_count = state.tracer().recorded_steps().len(); + eprintln!("Emulator ran for {} steps.", step_count); read_all_messages(&state) } diff --git a/ceno_host/tests/test_elf.rs b/ceno_host/tests/test_elf.rs index c5c35993d..ff752267d 100644 --- a/ceno_host/tests/test_elf.rs +++ b/ceno_host/tests/test_elf.rs @@ -806,7 +806,8 @@ fn messages_to_strings(messages: &[Vec]) -> Vec { } fn run(state: &mut VMState) -> Result> { - let steps = state.iter_until_halt().collect::>>()?; + state.iter_until_halt().collect::>>()?; + let steps = state.tracer().recorded_steps().to_vec(); eprintln!("Emulator ran for {} steps.", steps.len()); Ok(steps) } diff --git a/ceno_zkvm/src/e2e.rs b/ceno_zkvm/src/e2e.rs index 1625fb106..215cbf7b6 100644 --- a/ceno_zkvm/src/e2e.rs +++ b/ceno_zkvm/src/e2e.rs @@ -23,9 +23,10 @@ use crate::{ }, }; use ceno_emul::{ - Addr, ByteAddr, CENO_PLATFORM, Cycle, EmuContext, FullTracer, IterAddresses, NextCycleAccess, - Platform, PreflightTracer, PreflightTracerConfig, Program, StepCellExtractor, StepRecord, - Tracer, VM_REG_COUNT, VMState, WORD_SIZE, Word, WordAddr, host_utils::read_all_messages, + Addr, ByteAddr, CENO_PLATFORM, Cycle, EmuContext, FullTracer, FullTracerConfig, IterAddresses, + NextCycleAccess, Platform, PreflightTracer, PreflightTracerConfig, Program, StepCellExtractor, + StepIndex, StepRecord, Tracer, VM_REG_COUNT, VMState, WORD_SIZE, Word, WordAddr, + host_utils::read_all_messages, }; use clap::ValueEnum; use either::Either; @@ -119,6 +120,7 @@ pub struct EmulationResult<'a> { pub shard_ctx_builder: ShardContextBuilder, pub shard_cycle_boundaries: Arc>, pub executed_steps: usize, + pub max_step_shard: usize, pub phantom: PhantomData<&'a ()>, // pub shard_ctxs: Vec>, } @@ -655,23 +657,28 @@ impl ShardContextBuilder { self.shard_cycle_boundaries.len().saturating_sub(1) } - pub fn position_next_shard<'a>( + pub fn position_next_shard<'a, S>( &mut self, - steps_iter: &mut impl Iterator, - mut on_step: impl FnMut(StepRecord), - ) -> Option<(ShardContext<'a>, ShardStepSummary)> { + steps_iter: &mut S, + mut on_step: impl FnMut(StepIndex, &StepRecord), + ) -> Option<(ShardContext<'a>, ShardStepSummary)> + where + S: StepSource, + { if self.cur_shard_id >= self.total_shards() { return None; } + steps_iter.start_new_shard(); let expected_end_cycle = self .shard_cycle_boundaries .get(self.cur_shard_id + 1) .copied() .expect("missing shard boundary for shard"); let mut summary = ShardStepSummary::default(); - for step in steps_iter.by_ref() { - summary.update(&step); - on_step(step); + while let Some(step_idx) = steps_iter.next() { + let record = steps_iter.step_record(step_idx); + summary.update(record); + on_step(step_idx, record); if summary.last_cycle + FullTracer::SUBCYCLES_PER_INSN == expected_end_cycle { break; } @@ -739,11 +746,17 @@ impl ShardContextBuilder { } } +pub trait StepSource: Iterator { + fn start_new_shard(&mut self); + fn shard_steps(&self) -> &[StepRecord]; + fn step_record(&self, idx: StepIndex) -> &StepRecord; +} + /// Lazily replays `StepRecord`s by re-running the VM up to the number of steps /// recorded during the preflight execution. This keeps shard generation memory /// usage bounded without storing the entire trace. struct StepReplay { - vm: VMState, + vm: VMState, remaining_steps: usize, } @@ -753,8 +766,10 @@ impl StepReplay { program: Arc, init_mem_state: &InitMemState, remaining_steps: usize, + max_step_shard: usize, ) -> Self { - let mut vm = VMState::new(platform, program); + let mut vm = + VMState::new_with_tracer_config(platform, program, FullTracerConfig { max_step_shard }); for record in chain!(init_mem_state.hints.iter(), init_mem_state.io.iter()) { vm.init_memory(record.addr.into(), record.value); } @@ -763,10 +778,18 @@ impl StepReplay { remaining_steps, } } + + fn reset_current_shard(&mut self) { + self.vm.tracer_mut().reset_step_buffer(); + } + + fn current_shard_steps(&self) -> &[StepRecord] { + self.vm.tracer().recorded_steps() + } } impl Iterator for StepReplay { - type Item = StepRecord; + type Item = StepIndex; fn next(&mut self) -> Option { if self.remaining_steps == 0 { @@ -786,6 +809,21 @@ impl Iterator for StepReplay { } } +impl StepSource for StepReplay { + fn start_new_shard(&mut self) { + self.reset_current_shard(); + } + + fn shard_steps(&self) -> &[StepRecord] { + self.current_shard_steps() + } + + #[inline(always)] + fn step_record(&self, idx: StepIndex) -> &StepRecord { + self.vm.tracer().step_record(idx) + } +} + pub fn emulate_program<'a>( program: Arc, max_steps: usize, @@ -810,7 +848,7 @@ pub fn emulate_program<'a>( ) .with_step_cell_extractor(step_cell_extractor); let mut vm: VMState = info_span!("[ceno] emulator.new-preflight-tracer") - .in_scope(|| { + .in_scope(move || { VMState::new_with_tracer_config(platform.clone(), program.clone(), tracer_config) }); @@ -1002,6 +1040,7 @@ pub fn emulate_program<'a>( let tracer = vm.take_tracer(); let (plan_builder, next_accesses) = tracer.into_shard_plan(); + let max_step_shard = plan_builder.max_step_shard(); let shard_cycle_boundaries = Arc::new(plan_builder.into_cycle_boundaries()); let shard_ctx_builder = ShardContextBuilder::from_plan( multi_prover, @@ -1023,6 +1062,7 @@ pub fn emulate_program<'a>( shard_ctx_builder, shard_cycle_boundaries: shard_cycle_boundaries.clone(), executed_steps: insts, + max_step_shard, final_mem_state: FinalMemState { reg: reg_final, io: io_final, @@ -1222,6 +1262,7 @@ pub fn generate_witness<'a, E: ExtensionField>( program.clone(), init_mem_state, emul_result.executed_steps, + emul_result.max_step_shard, ); std::iter::from_fn(move || { info_span!( @@ -1229,14 +1270,18 @@ pub fn generate_witness<'a, E: ExtensionField>( shard_id = shard_ctx_builder.cur_shard_id ) .in_scope(|| { + let time = std::time::Instant::now(); instrunction_dispatch_ctx.begin_shard(); - let (mut shard_ctx, shard_summary) = match shard_ctx_builder.position_next_shard( - &mut step_iter, - |step| instrunction_dispatch_ctx.ingest_step(step), - ) { - Some(result) => result, - None => return None, - }; + let (mut shard_ctx, shard_summary) = + match shard_ctx_builder.position_next_shard( + &mut step_iter, + |idx, record| instrunction_dispatch_ctx.ingest_step(idx, record), + ) { + Some(result) => result, + None => return None, + }; + tracing::debug!("position_next_shard finish in {:?}", time.elapsed()); + let shard_steps = step_iter.shard_steps(); let mut zkvm_witness = ZKVMWitnesses::default(); let mut pi = pi_template.clone(); @@ -1292,6 +1337,7 @@ pub fn generate_witness<'a, E: ExtensionField>( &system_config.zkvm_cs, &mut shard_ctx, &mut instrunction_dispatch_ctx, + shard_steps, &mut zkvm_witness, ) .unwrap(); @@ -1303,6 +1349,7 @@ pub fn generate_witness<'a, E: ExtensionField>( &system_config.zkvm_cs, &mut shard_ctx, &instrunction_dispatch_ctx, + shard_steps, &mut zkvm_witness, ) .unwrap(); @@ -2075,7 +2122,7 @@ pub fn verify + serde::Ser #[cfg(test)] mod tests { use crate::e2e::{MultiProver, ShardContextBuilder}; - use ceno_emul::{CENO_PLATFORM, Cycle, FullTracer, NextCycleAccess, StepRecord}; + use ceno_emul::{CENO_PLATFORM, Cycle, FullTracer, NextCycleAccess, StepIndex, StepRecord}; use itertools::Itertools; use std::sync::Arc; @@ -2136,10 +2183,53 @@ mod tests { max_cycle, NextCycleAccess::default(), ); - let mut steps_iter = steps.into_iter(); + struct TestReplay { + steps: Vec, + cursor: usize, + shard_start: usize, + } + + impl TestReplay { + fn new(steps: Vec) -> Self { + Self { + steps, + cursor: 0, + shard_start: 0, + } + } + } + + impl Iterator for TestReplay { + type Item = StepIndex; + + fn next(&mut self) -> Option { + if self.cursor >= self.steps.len() { + return None; + } + let local_idx = self.cursor - self.shard_start; + self.cursor += 1; + Some(local_idx) + } + } + + impl super::StepSource for TestReplay { + fn start_new_shard(&mut self) { + self.shard_start = self.cursor; + } + + fn shard_steps(&self) -> &[StepRecord] { + &self.steps[self.shard_start..self.cursor] + } + + fn step_record(&self, idx: StepIndex) -> &StepRecord { + &self.steps[self.shard_start + idx] + } + } + + let mut steps_iter = TestReplay::new(steps); let shard_ctx = std::iter::from_fn(|| { shard_ctx_builder - .position_next_shard(&mut steps_iter, |_| {}) + .position_next_shard(&mut steps_iter, |_, _| {}) .map(|(ctx, _)| ctx) }) .collect_vec(); diff --git a/ceno_zkvm/src/instructions.rs b/ceno_zkvm/src/instructions.rs index df2c24ff9..9dd99ef92 100644 --- a/ceno_zkvm/src/instructions.rs +++ b/ceno_zkvm/src/instructions.rs @@ -2,7 +2,7 @@ use crate::{ circuit_builder::CircuitBuilder, e2e::ShardContext, error::ZKVMError, structs::ProgramParams, tables::RMMCollections, witness::LkMultiplicity, }; -use ceno_emul::StepRecord; +use ceno_emul::{StepIndex, StepRecord}; use ff_ext::ExtensionField; use gkr_iop::{ chip::Chip, @@ -101,7 +101,8 @@ pub trait Instruction { shard_ctx: &mut ShardContext, num_witin: usize, num_structural_witin: usize, - steps: &[StepRecord], + shard_steps: &[StepRecord], + step_indices: &[StepIndex], ) -> Result<(RMMCollections, Multiplicity), ZKVMError> { // TODO: selector is the only structural witness // this is workaround, as call `construct_circuit` will not initialized selector @@ -110,17 +111,21 @@ pub trait Instruction { let num_structural_witin = num_structural_witin.max(1); let nthreads = max_usable_threads(); - let num_instance_per_batch = if steps.len() > 256 { - steps.len().div_ceil(nthreads) + let total_instances = step_indices.len(); + let num_instance_per_batch = if total_instances > 256 { + total_instances.div_ceil(nthreads) } else { - steps.len() + total_instances } .max(1); let lk_multiplicity = LkMultiplicity::default(); - let mut raw_witin = - RowMajorMatrix::::new(steps.len(), num_witin, Self::padding_strategy()); + let mut raw_witin = RowMajorMatrix::::new( + total_instances, + num_witin, + Self::padding_strategy(), + ); let mut raw_structual_witin = RowMajorMatrix::::new( - steps.len(), + total_instances, num_structural_witin, Self::padding_strategy(), ); @@ -131,23 +136,23 @@ pub trait Instruction { raw_witin_iter .zip_eq(raw_structual_witin_iter) - .zip_eq(steps.par_chunks(num_instance_per_batch)) + .zip_eq(step_indices.par_chunks(num_instance_per_batch)) .zip(shard_ctx_vec) .flat_map( - |(((instances, structural_instance), steps), mut shard_ctx)| { + |(((instances, structural_instance), indices), mut shard_ctx)| { let mut lk_multiplicity = lk_multiplicity.clone(); instances .chunks_mut(num_witin) .zip_eq(structural_instance.chunks_mut(num_structural_witin)) - .zip_eq(steps) - .map(|((instance, structural_instance), step)| { + .zip_eq(indices.iter().copied()) + .map(|((instance, structural_instance), step_idx)| { *structural_instance.last_mut().unwrap() = E::BaseField::ONE; Self::assign_instance( config, &mut shard_ctx, instance, &mut lk_multiplicity, - step, + &shard_steps[step_idx], ) }) .collect::>() @@ -162,4 +167,26 @@ pub trait Instruction { lk_multiplicity.into_finalize_result(), )) } + + fn assign_instances_from_steps( + config: &Self::InstructionConfig, + shard_ctx: &mut ShardContext, + num_witin: usize, + num_structural_witin: usize, + steps: &[StepRecord], + ) -> Result<(RMMCollections, Multiplicity), ZKVMError> { + let indices = full_step_indices(steps); + Self::assign_instances( + config, + shard_ctx, + num_witin, + num_structural_witin, + steps, + &indices, + ) + } +} + +pub fn full_step_indices(steps: &[StepRecord]) -> Vec { + (0..steps.len()).collect() } diff --git a/ceno_zkvm/src/instructions/riscv/arith.rs b/ceno_zkvm/src/instructions/riscv/arith.rs index 260245931..a5f6e006f 100644 --- a/ceno_zkvm/src/instructions/riscv/arith.rs +++ b/ceno_zkvm/src/instructions/riscv/arith.rs @@ -190,7 +190,7 @@ mod test { // values assignment let insn_code = encode_rv32(I::INST_KIND, 2, 3, 4, 0); - let (raw_witin, lkm) = ArithInstruction::::assign_instances( + let (raw_witin, lkm) = ArithInstruction::::assign_instances_from_steps( &config, &mut ShardContext::default(), cb.cs.num_witin as usize, diff --git a/ceno_zkvm/src/instructions/riscv/arith_imm.rs b/ceno_zkvm/src/instructions/riscv/arith_imm.rs index a1c1d4403..f41832719 100644 --- a/ceno_zkvm/src/instructions/riscv/arith_imm.rs +++ b/ceno_zkvm/src/instructions/riscv/arith_imm.rs @@ -62,7 +62,7 @@ mod test { .unwrap(); let insn_code = encode_rv32(InsnKind::ADDI, 2, 0, 4, imm); - let (raw_witin, lkm) = AddiInstruction::::assign_instances( + let (raw_witin, lkm) = AddiInstruction::::assign_instances_from_steps( &config, &mut ShardContext::default(), cb.cs.num_witin as usize, diff --git a/ceno_zkvm/src/instructions/riscv/auipc.rs b/ceno_zkvm/src/instructions/riscv/auipc.rs index ce7c64a95..6311fc2aa 100644 --- a/ceno_zkvm/src/instructions/riscv/auipc.rs +++ b/ceno_zkvm/src/instructions/riscv/auipc.rs @@ -245,7 +245,7 @@ mod tests { .unwrap(); let insn_code = encode_rv32(InsnKind::AUIPC, 0, 0, 4, imm); - let (raw_witin, lkm) = AuipcInstruction::::assign_instances( + let (raw_witin, lkm) = AuipcInstruction::::assign_instances_from_steps( &config, &mut ShardContext::default(), cb.cs.num_witin as usize, diff --git a/ceno_zkvm/src/instructions/riscv/branch/test.rs b/ceno_zkvm/src/instructions/riscv/branch/test.rs index 67f098ff0..5ca85da4e 100644 --- a/ceno_zkvm/src/instructions/riscv/branch/test.rs +++ b/ceno_zkvm/src/instructions/riscv/branch/test.rs @@ -36,7 +36,7 @@ fn impl_opcode_beq(take_branch: bool, a: u32, b: u32) { let insn_code = encode_rv32(InsnKind::BEQ, 2, 3, 0, 8); let pc_offset = if take_branch { 8 } else { PC_STEP_SIZE }; - let (raw_witin, lkm) = BeqInstruction::assign_instances( + let (raw_witin, lkm) = BeqInstruction::assign_instances_from_steps( &config, &mut ShardContext::default(), cb.cs.num_witin as usize, @@ -78,7 +78,7 @@ fn impl_opcode_bne(take_branch: bool, a: u32, b: u32) { let insn_code = encode_rv32(InsnKind::BNE, 2, 3, 0, 8); let pc_offset = if take_branch { 8 } else { PC_STEP_SIZE }; - let (raw_witin, lkm) = BneInstruction::assign_instances( + let (raw_witin, lkm) = BneInstruction::assign_instances_from_steps( &config, &mut ShardContext::default(), cb.cs.num_witin as usize, @@ -122,7 +122,7 @@ fn impl_bltu_circuit(taken: bool, a: u32, b: u32) -> Result<(), ZKVMError> { }; let insn_code = encode_rv32(InsnKind::BLTU, 2, 3, 0, -8); - let (raw_witin, lkm) = BltuInstruction::assign_instances( + let (raw_witin, lkm) = BltuInstruction::assign_instances_from_steps( &config, &mut ShardContext::default(), circuit_builder.cs.num_witin as usize, @@ -167,7 +167,7 @@ fn impl_bgeu_circuit(taken: bool, a: u32, b: u32) -> Result<(), ZKVMError> { }; let insn_code = encode_rv32(InsnKind::BGEU, 2, 3, 0, -8); - let (raw_witin, lkm) = BgeuInstruction::assign_instances( + let (raw_witin, lkm) = BgeuInstruction::assign_instances_from_steps( &config, &mut ShardContext::default(), circuit_builder.cs.num_witin as usize, @@ -219,7 +219,7 @@ fn impl_blt_circuit(taken: bool, a: i32, b: i32) -> Result<() }; let insn_code = encode_rv32(InsnKind::BLT, 2, 3, 0, -8); - let (raw_witin, lkm) = BltInstruction::assign_instances( + let (raw_witin, lkm) = BltInstruction::assign_instances_from_steps( &config, &mut ShardContext::default(), circuit_builder.cs.num_witin as usize, @@ -271,7 +271,7 @@ fn impl_bge_circuit(taken: bool, a: i32, b: i32) -> Result<() }; let insn_code = encode_rv32(InsnKind::BGE, 2, 3, 0, -8); - let (raw_witin, lkm) = BgeInstruction::assign_instances( + let (raw_witin, lkm) = BgeInstruction::assign_instances_from_steps( &config, &mut ShardContext::default(), circuit_builder.cs.num_witin as usize, diff --git a/ceno_zkvm/src/instructions/riscv/div.rs b/ceno_zkvm/src/instructions/riscv/div.rs index 85718ad24..981995452 100644 --- a/ceno_zkvm/src/instructions/riscv/div.rs +++ b/ceno_zkvm/src/instructions/riscv/div.rs @@ -180,7 +180,7 @@ mod test { .expect("instruction must declare at least one InsnKind"); let insn_code = encode_rv32(insn_kind, 2, 3, 4, 0); // values assignment - let ([raw_witin, _], lkm) = Insn::assign_instances( + let ([raw_witin, _], lkm) = Insn::assign_instances_from_steps( &config, &mut ShardContext::default(), cb.cs.num_witin as usize, diff --git a/ceno_zkvm/src/instructions/riscv/dummy/test.rs b/ceno_zkvm/src/instructions/riscv/dummy/test.rs index 8cccc8d49..b74c8ca39 100644 --- a/ceno_zkvm/src/instructions/riscv/dummy/test.rs +++ b/ceno_zkvm/src/instructions/riscv/dummy/test.rs @@ -19,7 +19,7 @@ fn test_large_ecall_dummy_keccak() { let config = KeccakDummy::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); let (step, program) = ceno_emul::test_utils::keccak_step(); - let (raw_witin, lkm) = KeccakDummy::assign_instances( + let (raw_witin, lkm) = KeccakDummy::assign_instances_from_steps( &config, &mut ShardContext::default(), cb.cs.num_witin as usize, diff --git a/ceno_zkvm/src/instructions/riscv/ecall/fptower_fp.rs b/ceno_zkvm/src/instructions/riscv/ecall/fptower_fp.rs index 7eeba1ab0..57d824b01 100644 --- a/ceno_zkvm/src/instructions/riscv/ecall/fptower_fp.rs +++ b/ceno_zkvm/src/instructions/riscv/ecall/fptower_fp.rs @@ -1,8 +1,8 @@ use std::marker::PhantomData; use ceno_emul::{ - BN254_FP_ADD, BN254_FP_MUL, ByteAddr, Change, InsnKind, Platform, StepRecord, WORD_SIZE, - WriteOp, + BN254_FP_ADD, BN254_FP_MUL, ByteAddr, Change, InsnKind, Platform, StepIndex, StepRecord, + WORD_SIZE, WriteOp, }; use ff_ext::ExtensionField; use generic_array::typenum::Unsigned; @@ -126,6 +126,7 @@ impl Instruction num_witin: usize, num_structural_witin: usize, steps: &[StepRecord], + step_indices: &[StepIndex], ) -> Result<(RMMCollections, Multiplicity), ZKVMError> { assign_fp_op_instances::( config, @@ -133,6 +134,7 @@ impl Instruction num_witin, num_structural_witin, steps, + step_indices, P::SYSCALL_CODE, FieldOperation::Add, ) @@ -194,6 +196,7 @@ impl Instruction num_witin: usize, num_structural_witin: usize, steps: &[StepRecord], + step_indices: &[StepIndex], ) -> Result<(RMMCollections, Multiplicity), ZKVMError> { assign_fp_op_instances::( config, @@ -201,6 +204,7 @@ impl Instruction num_witin, num_structural_witin, steps, + step_indices, P::SYSCALL_CODE, FieldOperation::Mul, ) @@ -305,17 +309,19 @@ fn build_fp_op_circuit( )) } +#[allow(clippy::too_many_arguments)] fn assign_fp_op_instances( config: &EcallFpOpConfig, shard_ctx: &mut ShardContext, num_witin: usize, num_structural_witin: usize, steps: &[StepRecord], + step_indices: &[StepIndex], syscall_code: u32, op: FieldOperation, ) -> Result<(RMMCollections, Multiplicity), ZKVMError> { let mut lk_multiplicity = LkMultiplicity::default(); - if steps.is_empty() { + if step_indices.is_empty() { return Ok(( [ RowMajorMatrix::new(0, num_witin, InstancePaddingStrategy::Default), @@ -326,15 +332,15 @@ fn assign_fp_op_instances( } let nthreads = max_usable_threads(); - let num_instance_per_batch = steps.len().div_ceil(nthreads).max(1); + let num_instance_per_batch = step_indices.len().div_ceil(nthreads).max(1); let mut raw_witin = RowMajorMatrix::::new( - steps.len(), + step_indices.len(), num_witin, InstancePaddingStrategy::Default, ); let mut raw_structural_witin = RowMajorMatrix::::new( - steps.len(), + step_indices.len(), num_structural_witin, InstancePaddingStrategy::Default, ); @@ -343,14 +349,15 @@ fn assign_fp_op_instances( let shard_ctx_vec = shard_ctx.get_forked(); raw_witin_iter - .zip_eq(steps.par_chunks(num_instance_per_batch)) + .zip_eq(step_indices.par_chunks(num_instance_per_batch)) .zip(shard_ctx_vec) - .flat_map(|((instances, steps), mut shard_ctx)| { + .flat_map(|((instances, indices), mut shard_ctx)| { let mut lk_multiplicity = lk_multiplicity.clone(); instances .chunks_mut(num_witin) - .zip_eq(steps) - .map(|(instance, step)| { + .zip_eq(indices.iter().copied()) + .map(|(instance, idx)| { + let step = &steps[idx]; let ops = &step.syscall().expect("syscall step"); config .vm_state @@ -407,9 +414,10 @@ fn assign_fp_op_instances( .collect::>()?; let words =

::WordsFieldElement::USIZE; - let instances: Vec> = steps + let instances: Vec> = step_indices .par_iter() - .map(|step| { + .map(|&idx| { + let step = &steps[idx]; let values: Vec = step .syscall() .unwrap() diff --git a/ceno_zkvm/src/instructions/riscv/ecall/fptower_fp2_add.rs b/ceno_zkvm/src/instructions/riscv/ecall/fptower_fp2_add.rs index 6552b6241..6715a0b74 100644 --- a/ceno_zkvm/src/instructions/riscv/ecall/fptower_fp2_add.rs +++ b/ceno_zkvm/src/instructions/riscv/ecall/fptower_fp2_add.rs @@ -1,7 +1,7 @@ use std::marker::PhantomData; use ceno_emul::{ - BN254_FP2_ADD, ByteAddr, Change, InsnKind, Platform, StepRecord, WORD_SIZE, WriteOp, + BN254_FP2_ADD, ByteAddr, Change, InsnKind, Platform, StepIndex, StepRecord, WORD_SIZE, WriteOp, }; use ff_ext::ExtensionField; use generic_array::typenum::Unsigned; @@ -116,9 +116,17 @@ impl Instruction shard_ctx: &mut ShardContext, num_witin: usize, num_structural_witin: usize, - steps: &[StepRecord], + shard_steps: &[StepRecord], + step_indices: &[StepIndex], ) -> Result<(RMMCollections, Multiplicity), ZKVMError> { - assign_fp2_add_instances::(config, shard_ctx, num_witin, num_structural_witin, steps) + assign_fp2_add_instances::( + config, + shard_ctx, + num_witin, + num_structural_witin, + shard_steps, + step_indices, + ) } } @@ -225,9 +233,10 @@ fn assign_fp2_add_instances Result<(RMMCollections, Multiplicity), ZKVMError> { let mut lk_multiplicity = LkMultiplicity::default(); - if steps.is_empty() { + if step_indices.is_empty() { return Ok(( [ RowMajorMatrix::new(0, num_witin, InstancePaddingStrategy::Default), @@ -238,15 +247,15 @@ fn assign_fp2_add_instances::new( - steps.len(), + step_indices.len(), num_witin, InstancePaddingStrategy::Default, ); let mut raw_structural_witin = RowMajorMatrix::::new( - steps.len(), + step_indices.len(), num_structural_witin, InstancePaddingStrategy::Default, ); @@ -255,14 +264,15 @@ fn assign_fp2_add_instances::WordsFieldElement::USIZE; let words_fp2 =

::WordsCurvePoint::USIZE; - let instances: Vec> = steps + let instances: Vec> = step_indices .par_iter() - .map(|step| { + .map(|&idx| { + let step = &steps[idx]; let values: Vec = step .syscall() .unwrap() diff --git a/ceno_zkvm/src/instructions/riscv/ecall/fptower_fp2_mul.rs b/ceno_zkvm/src/instructions/riscv/ecall/fptower_fp2_mul.rs index 709e9734d..7537d31ca 100644 --- a/ceno_zkvm/src/instructions/riscv/ecall/fptower_fp2_mul.rs +++ b/ceno_zkvm/src/instructions/riscv/ecall/fptower_fp2_mul.rs @@ -1,7 +1,7 @@ use std::marker::PhantomData; use ceno_emul::{ - BN254_FP2_MUL, ByteAddr, Change, InsnKind, Platform, StepRecord, WORD_SIZE, WriteOp, + BN254_FP2_MUL, ByteAddr, Change, InsnKind, Platform, StepIndex, StepRecord, WORD_SIZE, WriteOp, }; use ff_ext::ExtensionField; use generic_array::typenum::Unsigned; @@ -115,9 +115,17 @@ impl Instruction shard_ctx: &mut ShardContext, num_witin: usize, num_structural_witin: usize, - steps: &[StepRecord], + shard_steps: &[StepRecord], + step_indices: &[StepIndex], ) -> Result<(RMMCollections, Multiplicity), ZKVMError> { - assign_fp2_mul_instances::(config, shard_ctx, num_witin, num_structural_witin, steps) + assign_fp2_mul_instances::( + config, + shard_ctx, + num_witin, + num_structural_witin, + shard_steps, + step_indices, + ) } } @@ -223,9 +231,10 @@ fn assign_fp2_mul_instances Result<(RMMCollections, Multiplicity), ZKVMError> { let mut lk_multiplicity = LkMultiplicity::default(); - if steps.is_empty() { + if step_indices.is_empty() { return Ok(( [ RowMajorMatrix::new(0, num_witin, InstancePaddingStrategy::Default), @@ -236,15 +245,15 @@ fn assign_fp2_mul_instances::new( - steps.len(), + step_indices.len(), num_witin, InstancePaddingStrategy::Default, ); let mut raw_structural_witin = RowMajorMatrix::::new( - steps.len(), + step_indices.len(), num_structural_witin, InstancePaddingStrategy::Default, ); @@ -253,14 +262,15 @@ fn assign_fp2_mul_instances::WordsFieldElement::USIZE; let words_fp2 =

::WordsCurvePoint::USIZE; - let instances: Vec> = steps + let instances: Vec> = step_indices .par_iter() - .map(|step| { + .map(|&idx| { + let step = &steps[idx]; let values: Vec = step .syscall() .unwrap() diff --git a/ceno_zkvm/src/instructions/riscv/ecall/keccak.rs b/ceno_zkvm/src/instructions/riscv/ecall/keccak.rs index 51568b56a..e088cc0cc 100644 --- a/ceno_zkvm/src/instructions/riscv/ecall/keccak.rs +++ b/ceno_zkvm/src/instructions/riscv/ecall/keccak.rs @@ -1,7 +1,8 @@ use std::marker::PhantomData; use ceno_emul::{ - ByteAddr, Change, Cycle, InsnKind, KECCAK_PERMUTE, Platform, StepRecord, WORD_SIZE, WriteOp, + ByteAddr, Change, Cycle, InsnKind, KECCAK_PERMUTE, Platform, StepIndex, StepRecord, WORD_SIZE, + WriteOp, }; use ff_ext::ExtensionField; use gkr_iop::{ @@ -175,9 +176,10 @@ impl Instruction for KeccakInstruction { num_witin: usize, num_structural_witin: usize, steps: &[StepRecord], + step_indices: &[StepIndex], ) -> Result<(RMMCollections, Multiplicity), ZKVMError> { let mut lk_multiplicity = LkMultiplicity::default(); - if steps.is_empty() { + if step_indices.is_empty() { return Ok(( [ RowMajorMatrix::new(0, num_witin, InstancePaddingStrategy::Default), @@ -187,17 +189,18 @@ impl Instruction for KeccakInstruction { )); } let nthreads = max_usable_threads(); - let num_instance_per_batch = steps.len().div_ceil(nthreads).max(1); + let num_instance_per_batch = step_indices.len().div_ceil(nthreads).max(1); + let rotation = KECCAK_ROUNDS.next_power_of_two().ilog2() as usize; let mut raw_witin = RowMajorMatrix::::new_by_rotation( - steps.len(), - KECCAK_ROUNDS.next_power_of_two().ilog2() as usize, + step_indices.len(), + rotation, num_witin, InstancePaddingStrategy::Default, ); let mut raw_structural_witin = RowMajorMatrix::::new_by_rotation( - steps.len(), - KECCAK_ROUNDS.next_power_of_two().ilog2() as usize, + step_indices.len(), + rotation, num_structural_witin, InstancePaddingStrategy::Default, ); @@ -208,15 +211,16 @@ impl Instruction for KeccakInstruction { // 1st pass: assign witness outside of gkr-iop scope raw_witin_iter - .zip_eq(steps.par_chunks(num_instance_per_batch)) + .zip_eq(step_indices.par_chunks(num_instance_per_batch)) .zip(shard_ctx_vec) - .flat_map(|((instances, steps), mut shard_ctx)| { + .flat_map(|((instances, indices), mut shard_ctx)| { let mut lk_multiplicity = lk_multiplicity.clone(); instances .chunks_mut(num_witin * KECCAK_ROUNDS.next_power_of_two()) - .zip_eq(steps) - .map(|(instance_with_rotation, step)| { + .zip_eq(indices.iter().copied()) + .map(|(instance_with_rotation, idx)| { + let step = &steps[idx]; let ops = &step.syscall().expect("syscall step"); let bh = BooleanHypercube::new(KECCAK_ROUNDS_CEIL_LOG2); @@ -276,9 +280,10 @@ impl Instruction for KeccakInstruction { .collect::>()?; // second pass - let instances: Vec = steps + let instances: Vec = step_indices .iter() - .map(|step| -> KeccakInstance { + .map(|&idx| -> KeccakInstance { + let step = &steps[idx]; let (instance, prev_ts): (Vec, Vec) = step .syscall() .unwrap() diff --git a/ceno_zkvm/src/instructions/riscv/ecall/sha_extend.rs b/ceno_zkvm/src/instructions/riscv/ecall/sha_extend.rs index 7f8513dc1..b61673e4a 100644 --- a/ceno_zkvm/src/instructions/riscv/ecall/sha_extend.rs +++ b/ceno_zkvm/src/instructions/riscv/ecall/sha_extend.rs @@ -1,6 +1,8 @@ use std::{array, marker::PhantomData}; -use ceno_emul::{Change, InsnKind, Platform, SHA_EXTEND, StepRecord, WORD_SIZE, WriteOp}; +use ceno_emul::{ + Change, InsnKind, Platform, SHA_EXTEND, StepIndex, StepRecord, WORD_SIZE, WriteOp, +}; use ff_ext::{ExtensionField, FieldInto}; use gkr_iop::{ ProtocolBuilder, ProtocolWitnessGenerator, @@ -173,10 +175,11 @@ impl Instruction for ShaExtendInstruction { num_witin: usize, num_structural_witin: usize, steps: &[StepRecord], + step_indices: &[StepIndex], ) -> Result<(RMMCollections, Multiplicity), ZKVMError> { let mut lk_multiplicity = LkMultiplicity::default(); let num_structural_witin = config.layout.n_structural_witin.max(num_structural_witin); - if steps.is_empty() { + if step_indices.is_empty() { return Ok(( [ RowMajorMatrix::new(0, num_witin, InstancePaddingStrategy::Default), @@ -186,7 +189,7 @@ impl Instruction for ShaExtendInstruction { )); } - let num_instances = steps.len(); + let num_instances = step_indices.len(); let nthreads = max_usable_threads(); let num_instance_per_batch = num_instances.div_ceil(nthreads).max(1); @@ -205,15 +208,16 @@ impl Instruction for ShaExtendInstruction { let shard_ctx_vec = shard_ctx.get_forked(); raw_witin_iter - .zip_eq(steps.par_chunks(num_instance_per_batch)) + .zip_eq(step_indices.par_chunks(num_instance_per_batch)) .zip(shard_ctx_vec) - .flat_map(|((instances, steps), mut shard_ctx)| { + .flat_map(|((instances, indices), mut shard_ctx)| { let mut lk_multiplicity = lk_multiplicity.clone(); instances .chunks_mut(num_witin) - .zip_eq(steps) - .map(|(instance, step)| { + .zip_eq(indices.iter().copied()) + .map(|(instance, idx)| { + let step = &steps[idx]; let ops = step.syscall().expect("syscall step"); // vm_state @@ -277,9 +281,10 @@ impl Instruction for ShaExtendInstruction { }) .collect::>()?; - let instances = steps + let instances = step_indices .iter() - .map(|step| -> ShaExtendInstance { + .map(|&idx| -> ShaExtendInstance { + let step = &steps[idx]; let ops = step.syscall().expect("syscall step"); let w_i_minus_2 = ops.mem_ops[0].value.before; let w_i_minus_7 = ops.mem_ops[1].value.before; diff --git a/ceno_zkvm/src/instructions/riscv/ecall/uint256.rs b/ceno_zkvm/src/instructions/riscv/ecall/uint256.rs index 8b38df7e1..f3a39093f 100644 --- a/ceno_zkvm/src/instructions/riscv/ecall/uint256.rs +++ b/ceno_zkvm/src/instructions/riscv/ecall/uint256.rs @@ -2,7 +2,7 @@ use std::marker::PhantomData; use ceno_emul::{ ByteAddr, Change, Cycle, InsnKind, Platform, SECP256K1_SCALAR_INVERT, SECP256R1_SCALAR_INVERT, - StepRecord, UINT256_MUL, WORD_SIZE, WriteOp, + StepIndex, StepRecord, UINT256_MUL, WORD_SIZE, WriteOp, }; use ff_ext::ExtensionField; use generic_array::typenum::Unsigned; @@ -227,11 +227,12 @@ impl Instruction for Uint256MulInstruction { num_witin: usize, num_structural_witin: usize, steps: &[StepRecord], + step_indices: &[StepIndex], ) -> Result<(RMMCollections, Multiplicity), ZKVMError> { let syscall_code = UINT256_MUL; let mut lk_multiplicity = LkMultiplicity::default(); - if steps.is_empty() { + if step_indices.is_empty() { return Ok(( [ RowMajorMatrix::new(0, num_witin, InstancePaddingStrategy::Default), @@ -241,15 +242,15 @@ impl Instruction for Uint256MulInstruction { )); } let nthreads = max_usable_threads(); - let num_instance_per_batch = steps.len().div_ceil(nthreads).max(1); + let num_instance_per_batch = step_indices.len().div_ceil(nthreads).max(1); let mut raw_witin = RowMajorMatrix::::new( - steps.len(), + step_indices.len(), num_witin, InstancePaddingStrategy::Default, ); let mut raw_structural_witin = RowMajorMatrix::::new( - steps.len(), + step_indices.len(), num_structural_witin, InstancePaddingStrategy::Default, ); @@ -259,15 +260,16 @@ impl Instruction for Uint256MulInstruction { // 1st pass: assign witness outside of gkr-iop scope raw_witin_iter - .zip_eq(steps.par_chunks(num_instance_per_batch)) + .zip_eq(step_indices.par_chunks(num_instance_per_batch)) .zip(shard_ctx_vec) - .flat_map(|((instances, steps), mut shard_ctx)| { + .flat_map(|((instances, indices), mut shard_ctx)| { let mut lk_multiplicity = lk_multiplicity.clone(); instances .chunks_mut(num_witin) - .zip_eq(steps) - .map(|(instance, step)| { + .zip_eq(indices.iter().copied()) + .map(|(instance, idx)| { + let step = &steps[idx]; let ops = &step.syscall().expect("syscall step"); // vm_state @@ -329,9 +331,10 @@ impl Instruction for Uint256MulInstruction { .collect::>()?; // second pass - let instances: Vec = steps + let instances: Vec = step_indices .par_iter() - .map(|step| { + .map(|&idx| { + let step = &steps[idx]; let (instance, _prev_ts): (Vec, Vec) = step .syscall() .unwrap() @@ -547,11 +550,12 @@ impl Instruction for Uint256InvInstr num_witin: usize, num_structural_witin: usize, steps: &[StepRecord], + step_indices: &[StepIndex], ) -> Result<(RMMCollections, Multiplicity), ZKVMError> { let syscall_code = Spec::syscall(); let mut lk_multiplicity = LkMultiplicity::default(); - if steps.is_empty() { + if step_indices.is_empty() { return Ok(( [ RowMajorMatrix::new(0, num_witin, InstancePaddingStrategy::Default), @@ -561,15 +565,15 @@ impl Instruction for Uint256InvInstr )); } let nthreads = max_usable_threads(); - let num_instance_per_batch = steps.len().div_ceil(nthreads).max(1); + let num_instance_per_batch = step_indices.len().div_ceil(nthreads).max(1); let mut raw_witin = RowMajorMatrix::::new( - steps.len(), + step_indices.len(), num_witin, InstancePaddingStrategy::Default, ); let mut raw_structural_witin = RowMajorMatrix::::new( - steps.len(), + step_indices.len(), num_structural_witin, InstancePaddingStrategy::Default, ); @@ -579,15 +583,16 @@ impl Instruction for Uint256InvInstr // 1st pass: assign witness outside of gkr-iop scope raw_witin_iter - .zip_eq(steps.par_chunks(num_instance_per_batch)) + .zip_eq(step_indices.par_chunks(num_instance_per_batch)) .zip(shard_ctx_vec) - .flat_map(|((instances, steps), mut shard_ctx)| { + .flat_map(|((instances, indices), mut shard_ctx)| { let mut lk_multiplicity = lk_multiplicity.clone(); instances .chunks_mut(num_witin) - .zip_eq(steps) - .map(|(instance, step)| { + .zip_eq(indices.iter().copied()) + .map(|(instance, idx)| { + let step = &steps[idx]; let ops = &step.syscall().expect("syscall step"); // vm_state @@ -636,9 +641,10 @@ impl Instruction for Uint256InvInstr .collect::>()?; // second pass - let instances: Vec = steps + let instances: Vec = step_indices .par_iter() - .map(|step| { + .map(|&idx| { + let step = &steps[idx]; let (instance, _): (Vec, Vec) = step .syscall() .unwrap() diff --git a/ceno_zkvm/src/instructions/riscv/ecall/weierstrass_add.rs b/ceno_zkvm/src/instructions/riscv/ecall/weierstrass_add.rs index 05a91cd97..80c85ef7a 100644 --- a/ceno_zkvm/src/instructions/riscv/ecall/weierstrass_add.rs +++ b/ceno_zkvm/src/instructions/riscv/ecall/weierstrass_add.rs @@ -2,7 +2,7 @@ use std::marker::PhantomData; use ceno_emul::{ BLS12381_ADD, BN254_ADD, ByteAddr, Change, Cycle, InsnKind, Platform, SECP256K1_ADD, - SECP256R1_ADD, StepRecord, WORD_SIZE, WriteOp, + SECP256R1_ADD, StepIndex, StepRecord, WORD_SIZE, WriteOp, }; use ff_ext::ExtensionField; use generic_array::{GenericArray, typenum::Unsigned}; @@ -227,6 +227,7 @@ impl Instruction num_witin: usize, num_structural_witin: usize, steps: &[StepRecord], + step_indices: &[StepIndex], ) -> Result<(RMMCollections, Multiplicity), ZKVMError> { let syscall_code = match EC::CURVE_TYPE { CurveType::Secp256k1 => SECP256K1_ADD, @@ -239,7 +240,7 @@ impl Instruction }; let mut lk_multiplicity = LkMultiplicity::default(); - if steps.is_empty() { + if step_indices.is_empty() { return Ok(( [ RowMajorMatrix::new(0, num_witin, InstancePaddingStrategy::Default), @@ -249,15 +250,15 @@ impl Instruction )); } let nthreads = max_usable_threads(); - let num_instance_per_batch = steps.len().div_ceil(nthreads).max(1); + let num_instance_per_batch = step_indices.len().div_ceil(nthreads).max(1); let mut raw_witin = RowMajorMatrix::::new( - steps.len(), + step_indices.len(), num_witin, InstancePaddingStrategy::Default, ); let mut raw_structural_witin = RowMajorMatrix::::new( - steps.len(), + step_indices.len(), num_structural_witin, InstancePaddingStrategy::Default, ); @@ -267,15 +268,16 @@ impl Instruction // 1st pass: assign witness outside of gkr-iop scope raw_witin_iter - .zip_eq(steps.par_chunks(num_instance_per_batch)) + .zip_eq(step_indices.par_chunks(num_instance_per_batch)) .zip(shard_ctx_vec) - .flat_map(|((instances, steps), mut shard_ctx)| { + .flat_map(|((instances, indices), mut shard_ctx)| { let mut lk_multiplicity = lk_multiplicity.clone(); instances .chunks_mut(num_witin) - .zip_eq(steps) - .map(|(instance, step)| { + .zip_eq(indices.iter().copied()) + .map(|(instance, idx)| { + let step = &steps[idx]; let ops = &step.syscall().expect("syscall step"); // vm_state @@ -338,9 +340,10 @@ impl Instruction .collect::>()?; // second pass - let instances: Vec> = steps + let instances: Vec> = step_indices .par_iter() - .map(|step| { + .map(|&idx| { + let step = &steps[idx]; let (instance, _prev_ts): (Vec, Vec) = step .syscall() .unwrap() diff --git a/ceno_zkvm/src/instructions/riscv/ecall/weierstrass_decompress.rs b/ceno_zkvm/src/instructions/riscv/ecall/weierstrass_decompress.rs index 6d9a7470b..a07fc00b2 100644 --- a/ceno_zkvm/src/instructions/riscv/ecall/weierstrass_decompress.rs +++ b/ceno_zkvm/src/instructions/riscv/ecall/weierstrass_decompress.rs @@ -1,8 +1,8 @@ use std::marker::PhantomData; use ceno_emul::{ - Change, Cycle, InsnKind, Platform, SECP256K1_DECOMPRESS, SECP256R1_DECOMPRESS, StepRecord, - WriteOp, + Change, Cycle, InsnKind, Platform, SECP256K1_DECOMPRESS, SECP256R1_DECOMPRESS, StepIndex, + StepRecord, WriteOp, }; use ff_ext::ExtensionField; use generic_array::{GenericArray, typenum::Unsigned}; @@ -228,6 +228,7 @@ impl Instruction Result<(RMMCollections, Multiplicity), ZKVMError> { let syscall_code = match EC::CURVE_TYPE { CurveType::Secp256k1 => SECP256K1_DECOMPRESS, @@ -238,7 +239,7 @@ impl Instruction Instruction::new( - steps.len(), + step_indices.len(), num_witin, InstancePaddingStrategy::Default, ); let mut raw_structural_witin = RowMajorMatrix::::new( - steps.len(), + step_indices.len(), num_structural_witin, InstancePaddingStrategy::Default, ); @@ -267,15 +268,16 @@ impl Instruction::WordsFieldElement::USIZE; // 1st pass: assign witness outside of gkr-iop scope let sign_bit_and_y_words = raw_witin_iter - .zip_eq(steps.par_chunks(num_instance_per_batch)) + .zip_eq(step_indices.par_chunks(num_instance_per_batch)) .zip(shard_ctx_vec) - .flat_map(|((instances, steps), mut shard_ctx)| { + .flat_map(|((instances, indices), mut shard_ctx)| { let mut lk_multiplicity = lk_multiplicity.clone(); instances .chunks_mut(num_witin) - .zip_eq(steps) - .map(|(instance, step)| { + .zip_eq(indices.iter().copied()) + .map(|(instance, idx)| { + let step = &steps[idx]; let ops = &step.syscall().expect("syscall step"); // vm_state @@ -342,10 +344,11 @@ impl Instruction, ZKVMError>>()?; // second pass - let instances = steps + let instances = step_indices .par_iter() .zip(sign_bit_and_y_words.into_par_iter()) - .map(|(step, (sign_bit, old_output32))| { + .map(|(idx, (sign_bit, old_output32))| { + let step = &steps[*idx]; let (instance, _prev_ts): (Vec, Vec) = step .syscall() .unwrap() diff --git a/ceno_zkvm/src/instructions/riscv/ecall/weierstrass_double.rs b/ceno_zkvm/src/instructions/riscv/ecall/weierstrass_double.rs index 4b9a2aeb6..72f5f71d8 100644 --- a/ceno_zkvm/src/instructions/riscv/ecall/weierstrass_double.rs +++ b/ceno_zkvm/src/instructions/riscv/ecall/weierstrass_double.rs @@ -2,7 +2,7 @@ use std::marker::PhantomData; use ceno_emul::{ BLS12381_DOUBLE, BN254_DOUBLE, ByteAddr, Change, Cycle, InsnKind, Platform, SECP256K1_DOUBLE, - SECP256R1_DOUBLE, StepRecord, WORD_SIZE, WriteOp, + SECP256R1_DOUBLE, StepIndex, StepRecord, WORD_SIZE, WriteOp, }; use ff_ext::ExtensionField; use generic_array::{GenericArray, typenum::Unsigned}; @@ -199,6 +199,7 @@ impl Instruction Result<(RMMCollections, Multiplicity), ZKVMError> { let syscall_code = match EC::CURVE_TYPE { CurveType::Secp256k1 => SECP256K1_DOUBLE, @@ -211,7 +212,7 @@ impl Instruction Instruction::new( - steps.len(), + step_indices.len(), num_witin, InstancePaddingStrategy::Default, ); let mut raw_structural_witin = RowMajorMatrix::::new( - steps.len(), + step_indices.len(), num_structural_witin, InstancePaddingStrategy::Default, ); @@ -239,15 +240,16 @@ impl Instruction Instruction>()?; // second pass - let instances: Vec> = steps + let instances: Vec> = step_indices .par_iter() - .map(|step| { + .map(|&idx| { + let step = &steps[idx]; let (instance, _prev_ts): (Vec, Vec) = step .syscall() .unwrap() diff --git a/ceno_zkvm/src/instructions/riscv/jump/test.rs b/ceno_zkvm/src/instructions/riscv/jump/test.rs index 355dad511..fe126466e 100644 --- a/ceno_zkvm/src/instructions/riscv/jump/test.rs +++ b/ceno_zkvm/src/instructions/riscv/jump/test.rs @@ -41,7 +41,7 @@ fn verify_test_opcode_jal(pc_offset: i32) { let new_pc: ByteAddr = ByteAddr(MOCK_PC_START.0.wrapping_add_signed(pc_offset)); let insn_code = encode_rv32(InsnKind::JAL, 0, 0, 4, pc_offset); - let (raw_witin, lkm) = JalInstruction::::assign_instances( + let (raw_witin, lkm) = JalInstruction::::assign_instances_from_steps( &config, &mut ShardContext::default(), cb.cs.num_witin as usize, @@ -117,7 +117,7 @@ fn verify_test_opcode_jalr(rs1_read: Word, imm: i32) { ) .unwrap(); - let (raw_witin, lkm) = JalrInstruction::::assign_instances( + let (raw_witin, lkm) = JalrInstruction::::assign_instances_from_steps( &config, &mut ShardContext::default(), cb.cs.num_witin as usize, diff --git a/ceno_zkvm/src/instructions/riscv/logic/test.rs b/ceno_zkvm/src/instructions/riscv/logic/test.rs index 6bade9c0f..715eb0b38 100644 --- a/ceno_zkvm/src/instructions/riscv/logic/test.rs +++ b/ceno_zkvm/src/instructions/riscv/logic/test.rs @@ -30,7 +30,7 @@ fn test_opcode_and() { .unwrap(); let insn_code = encode_rv32(InsnKind::AND, 2, 3, 4, 0); - let (raw_witin, lkm) = AndInstruction::assign_instances( + let (raw_witin, lkm) = AndInstruction::assign_instances_from_steps( &config, &mut ShardContext::default(), cb.cs.num_witin as usize, @@ -73,7 +73,7 @@ fn test_opcode_or() { .unwrap(); let insn_code = encode_rv32(InsnKind::OR, 2, 3, 4, 0); - let (raw_witin, lkm) = OrInstruction::assign_instances( + let (raw_witin, lkm) = OrInstruction::assign_instances_from_steps( &config, &mut ShardContext::default(), cb.cs.num_witin as usize, @@ -116,7 +116,7 @@ fn test_opcode_xor() { .unwrap(); let insn_code = encode_rv32(InsnKind::XOR, 2, 3, 4, 0); - let (raw_witin, lkm) = XorInstruction::assign_instances( + let (raw_witin, lkm) = XorInstruction::assign_instances_from_steps( &config, &mut ShardContext::default(), cb.cs.num_witin as usize, diff --git a/ceno_zkvm/src/instructions/riscv/logic_imm/logic_imm_circuit.rs b/ceno_zkvm/src/instructions/riscv/logic_imm/logic_imm_circuit.rs index fea2b03df..20ca51424 100644 --- a/ceno_zkvm/src/instructions/riscv/logic_imm/logic_imm_circuit.rs +++ b/ceno_zkvm/src/instructions/riscv/logic_imm/logic_imm_circuit.rs @@ -233,7 +233,7 @@ mod test { .unwrap(); let insn_code = encode_rv32u(I::INST_KIND, 2, 0, 4, imm); - let (raw_witin, lkm) = LogicInstruction::::assign_instances( + let (raw_witin, lkm) = LogicInstruction::::assign_instances_from_steps( &config, cb.cs.num_witin as usize, cb.cs.num_structural_witin as usize, diff --git a/ceno_zkvm/src/instructions/riscv/logic_imm/test.rs b/ceno_zkvm/src/instructions/riscv/logic_imm/test.rs index 3a003777b..c2437c5e2 100644 --- a/ceno_zkvm/src/instructions/riscv/logic_imm/test.rs +++ b/ceno_zkvm/src/instructions/riscv/logic_imm/test.rs @@ -69,7 +69,7 @@ fn verify(name: &'static str, rs1_read: u32, imm: u32, expected_rd_w .unwrap(); let insn_code = encode_rv32u(I::INST_KIND, 2, 0, 4, imm); - let (raw_witin, lkm) = LogicInstruction::::assign_instances( + let (raw_witin, lkm) = LogicInstruction::::assign_instances_from_steps( &config, &mut ShardContext::default(), cb.cs.num_witin as usize, diff --git a/ceno_zkvm/src/instructions/riscv/lui.rs b/ceno_zkvm/src/instructions/riscv/lui.rs index 93d24c4ef..deb7b5736 100644 --- a/ceno_zkvm/src/instructions/riscv/lui.rs +++ b/ceno_zkvm/src/instructions/riscv/lui.rs @@ -159,7 +159,7 @@ mod tests { .unwrap(); let insn_code = encode_rv32(InsnKind::LUI, 0, 0, 4, imm); - let (raw_witin, lkm) = LuiInstruction::::assign_instances( + let (raw_witin, lkm) = LuiInstruction::::assign_instances_from_steps( &config, &mut ShardContext::default(), cb.cs.num_witin as usize, diff --git a/ceno_zkvm/src/instructions/riscv/memory/test.rs b/ceno_zkvm/src/instructions/riscv/memory/test.rs index f6b0fa153..b242a1363 100644 --- a/ceno_zkvm/src/instructions/riscv/memory/test.rs +++ b/ceno_zkvm/src/instructions/riscv/memory/test.rs @@ -101,7 +101,7 @@ fn impl_opcode_store sw(prev_mem_value, rs2_word), x => unreachable!("{:?} is not store instruction", x), }; - let (raw_witin, lkm) = Inst::assign_instances( + let (raw_witin, lkm) = Inst::assign_instances_from_steps( &config, &mut ShardContext::default(), cb.cs.num_witin as usize, @@ -163,7 +163,7 @@ fn impl_opcode_load::assign_instances( + let (raw_witin, lkm) = MulhInstructionBase::::assign_instances_from_steps( &config, &mut ShardContext::default(), cb.cs.num_witin as usize, @@ -217,7 +217,7 @@ mod test { // values assignment let insn_code = encode_rv32(InsnKind::MULH, 2, 3, 4, 0); - let (raw_witin, lkm) = MulhInstruction::assign_instances( + let (raw_witin, lkm) = MulhInstruction::assign_instances_from_steps( &config, &mut ShardContext::default(), cb.cs.num_witin as usize, @@ -301,7 +301,7 @@ mod test { // values assignment let insn_code = encode_rv32(InsnKind::MULHSU, 2, 3, 4, 0); - let (raw_witin, lkm) = MulhsuInstruction::assign_instances( + let (raw_witin, lkm) = MulhsuInstruction::assign_instances_from_steps( &config, &mut ShardContext::default(), cb.cs.num_witin as usize, diff --git a/ceno_zkvm/src/instructions/riscv/rv32im.rs b/ceno_zkvm/src/instructions/riscv/rv32im.rs index ea21c4dcb..091ce3000 100644 --- a/ceno_zkvm/src/instructions/riscv/rv32im.rs +++ b/ceno_zkvm/src/instructions/riscv/rv32im.rs @@ -47,8 +47,8 @@ use ceno_emul::{ InsnKind::{self, *}, KeccakSpec, LogPcCycleSpec, Platform, Secp256k1AddSpec, Secp256k1DecompressSpec, Secp256k1DoubleSpec, Secp256k1ScalarInvertSpec, Secp256r1AddSpec, Secp256r1DoubleSpec, - Secp256r1ScalarInvertSpec, Sha256ExtendSpec, StepCellExtractor, StepRecord, SyscallSpec, - Uint256MulSpec, Word, + Secp256r1ScalarInvertSpec, Sha256ExtendSpec, StepCellExtractor, StepIndex, StepRecord, + SyscallSpec, Uint256MulSpec, Word, }; use dummy::LargeEcallDummy; use ff_ext::ExtensionField; @@ -634,6 +634,7 @@ impl Rv32imConfig { cs: &ZKVMConstraintSystem, shard_ctx: &mut ShardContext, instrunction_dispatch_ctx: &mut InstructionDispatchCtx, + shard_steps: &[StepRecord], witness: &mut ZKVMWitnesses, ) -> Result<(), ZKVMError> { instrunction_dispatch_ctx.trace_opcode_stats(); @@ -684,6 +685,7 @@ impl Rv32imConfig { cs, shard_ctx, &self.$config, + shard_steps, records, )?; }}; @@ -698,6 +700,7 @@ impl Rv32imConfig { cs, shard_ctx, &self.$config, + shard_steps, records, )?; }}; @@ -868,9 +871,9 @@ pub struct InstructionDispatchCtx { insn_to_record_buffer: Vec>, type_to_record_buffer: HashMap, insn_kinds: Vec, - circuit_record_buffers: Vec>, - fallback_record_buffers: Vec>, - ecall_record_buffers: BTreeMap>, + circuit_record_buffers: Vec>, + fallback_record_buffers: Vec>, + ecall_record_buffers: BTreeMap>, } impl InstructionDispatchCtx { @@ -894,7 +897,7 @@ impl InstructionDispatchCtx { } #[inline(always)] - pub fn ingest_step(&mut self, step: StepRecord) { + pub fn ingest_step(&mut self, step_idx: StepIndex, step: &StepRecord) { let kind = step.insn.kind; if kind == InsnKind::ECALL { let code = step @@ -904,11 +907,11 @@ impl InstructionDispatchCtx { self.ecall_record_buffers .entry(code) .or_default() - .push(step); + .push(step_idx); } else if let Some(record_buffer_idx) = self.insn_to_record_buffer[kind as usize] { - self.circuit_record_buffers[record_buffer_idx].push(step); + self.circuit_record_buffers[record_buffer_idx].push(step_idx); } else { - self.fallback_record_buffers[kind as usize].push(step); + self.fallback_record_buffers[kind as usize].push(step_idx); } } @@ -960,7 +963,7 @@ impl InstructionDispatchCtx { fn records_for_kinds + 'static>( &self, - ) -> Option<&[StepRecord]> { + ) -> Option<&[StepIndex]> { let record_buffer_id = self .type_to_record_buffer .get(&TypeId::of::()) @@ -970,7 +973,7 @@ impl InstructionDispatchCtx { .map(|records| records.as_slice()) } - fn records_for_ecall_code(&self, code: u32) -> Option<&[StepRecord]> { + fn records_for_ecall_code(&self, code: u32) -> Option<&[StepIndex]> { self.ecall_record_buffers .get(&code) .map(|records| records.as_slice()) @@ -1007,6 +1010,7 @@ impl DummyExtraConfig { cs: &ZKVMConstraintSystem, shard_ctx: &mut ShardContext, instrunction_dispatch_ctx: &InstructionDispatchCtx, + shard_steps: &[StepRecord], witness: &mut ZKVMWitnesses, ) -> Result<(), ZKVMError> { let phantom_log_pc_cycle_records = instrunction_dispatch_ctx @@ -1016,6 +1020,7 @@ impl DummyExtraConfig { cs, shard_ctx, &self.phantom_log_pc_cycle, + shard_steps, phantom_log_pc_cycle_records, )?; Ok(()) diff --git a/ceno_zkvm/src/instructions/riscv/shift.rs b/ceno_zkvm/src/instructions/riscv/shift.rs index ea082e3c6..6bbfc418d 100644 --- a/ceno_zkvm/src/instructions/riscv/shift.rs +++ b/ceno_zkvm/src/instructions/riscv/shift.rs @@ -172,7 +172,7 @@ mod tests { ) .unwrap(); - let (raw_witin, lkm) = ShiftLogicalInstruction::::assign_instances( + let (raw_witin, lkm) = ShiftLogicalInstruction::::assign_instances_from_steps( &config, &mut ShardContext::default(), cb.cs.num_witin as usize, diff --git a/ceno_zkvm/src/instructions/riscv/shift_imm.rs b/ceno_zkvm/src/instructions/riscv/shift_imm.rs index d97a0b09e..c626b07d2 100644 --- a/ceno_zkvm/src/instructions/riscv/shift_imm.rs +++ b/ceno_zkvm/src/instructions/riscv/shift_imm.rs @@ -169,7 +169,7 @@ mod test { ) .unwrap(); - let (raw_witin, lkm) = ShiftImmInstruction::::assign_instances( + let (raw_witin, lkm) = ShiftImmInstruction::::assign_instances_from_steps( &config, &mut ShardContext::default(), cb.cs.num_witin as usize, diff --git a/ceno_zkvm/src/instructions/riscv/slt.rs b/ceno_zkvm/src/instructions/riscv/slt.rs index 629354e41..a0dc51bd6 100644 --- a/ceno_zkvm/src/instructions/riscv/slt.rs +++ b/ceno_zkvm/src/instructions/riscv/slt.rs @@ -71,7 +71,7 @@ mod test { .unwrap(); let insn_code = encode_rv32(I::INST_KIND, 2, 3, 4, 0); - let (raw_witin, lkm) = SetLessThanInstruction::<_, I>::assign_instances( + let (raw_witin, lkm) = SetLessThanInstruction::<_, I>::assign_instances_from_steps( &config, &mut ShardContext::default(), cb.cs.num_witin as usize, diff --git a/ceno_zkvm/src/instructions/riscv/slti.rs b/ceno_zkvm/src/instructions/riscv/slti.rs index 620d6ff3d..90dcb8448 100644 --- a/ceno_zkvm/src/instructions/riscv/slti.rs +++ b/ceno_zkvm/src/instructions/riscv/slti.rs @@ -184,7 +184,7 @@ mod test { .unwrap() .unwrap(); - let (raw_witin, lkm) = SetLessThanImmInstruction::::assign_instances( + let (raw_witin, lkm) = SetLessThanImmInstruction::::assign_instances_from_steps( &config, &mut ShardContext::default(), cb.cs.num_witin as usize, diff --git a/ceno_zkvm/src/scheme/tests.rs b/ceno_zkvm/src/scheme/tests.rs index dfb6c35ef..d38f8fca5 100644 --- a/ceno_zkvm/src/scheme/tests.rs +++ b/ceno_zkvm/src/scheme/tests.rs @@ -18,7 +18,7 @@ use crate::{ use ceno_emul::{ CENO_PLATFORM, InsnKind, InsnKind::{ADD, ECALL}, - Platform, Program, StepRecord, VMState, encode_rv32, + Platform, Program, StepIndex, StepRecord, VMState, encode_rv32, }; use ff_ext::{ExtensionField, FieldInto, FromUniformBytes, GoldilocksExt2}; use gkr_iop::cpu::default_backend_config; @@ -147,12 +147,14 @@ fn test_rw_lk_expression_combination() { let num_instances = 1 << 8; let mut zkvm_witness = ZKVMWitnesses::default(); let steps = vec![StepRecord::default(); num_instances]; + let step_indices: Vec = (0..steps.len()).collect(); zkvm_witness .assign_opcode_circuit::>( &zkvm_cs, &mut shard_ctx, &config, &steps, + &step_indices, ) .unwrap(); @@ -327,12 +329,10 @@ fn test_single_add_instance_e2e() { // single instance let mut vm = VMState::new(CENO_PLATFORM.clone(), program.clone().into()); - let all_records = vm - .iter_until_halt() - .collect::, _>>() - .expect("vm exec failed") - .into_iter() - .collect::>(); + vm.iter_until_halt() + .collect::, _>>() + .expect("vm exec failed"); + let all_records = vm.tracer().recorded_steps().to_vec(); let mut add_records = vec![]; let mut halt_records = vec![]; all_records.into_iter().for_each(|record| { @@ -358,12 +358,15 @@ fn test_single_add_instance_e2e() { let verifier = ZKVMVerifier::new(vk); let mut zkvm_witness = ZKVMWitnesses::default(); // assign opcode circuits + let add_indices: Vec = (0..add_records.len()).collect(); + let halt_indices: Vec = (0..halt_records.len()).collect(); zkvm_witness .assign_opcode_circuit::>( &zkvm_cs, &mut shard_ctx, &add_config, &add_records, + &add_indices, ) .unwrap(); zkvm_witness @@ -372,6 +375,7 @@ fn test_single_add_instance_e2e() { &mut shard_ctx, &halt_config, &halt_records, + &halt_indices, ) .unwrap(); zkvm_witness.finalize_lk_multiplicities(); diff --git a/ceno_zkvm/src/structs.rs b/ceno_zkvm/src/structs.rs index 76a2d9334..1f6847140 100644 --- a/ceno_zkvm/src/structs.rs +++ b/ceno_zkvm/src/structs.rs @@ -10,7 +10,7 @@ use crate::{ TableCircuit, }, }; -use ceno_emul::{Addr, CENO_PLATFORM, Platform, RegIdx, StepRecord, WordAddr}; +use ceno_emul::{Addr, CENO_PLATFORM, Platform, RegIdx, StepIndex, StepRecord, WordAddr}; use ff_ext::{ExtensionField, PoseidonField}; use gkr_iop::{gkr::GKRCircuit, tables::LookupTable, utils::lk_multiplicity::Multiplicity}; use itertools::Itertools; @@ -368,7 +368,8 @@ impl ZKVMWitnesses { cs: &ZKVMConstraintSystem, shard_ctx: &mut ShardContext, config: &OC::InstructionConfig, - records: &[StepRecord], + shard_steps: &[StepRecord], + indices: &[StepIndex], ) -> Result<(), ZKVMError> { assert!(self.combined_lk_mlt.is_none()); @@ -378,7 +379,8 @@ impl ZKVMWitnesses { shard_ctx, cs.zkvm_v1_css.num_witin as usize, cs.zkvm_v1_css.num_structural_witin as usize, - records, + shard_steps, + indices, )?; let num_instances = vec![witness[0].num_instances()]; let input = ChipInput::new(