diff --git a/Cargo.lock b/Cargo.lock index 3f3766977f..60ffefb562 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -219,9 +219,9 @@ dependencies = [ [[package]] name = "alloy-eips" -version = "1.8.2" +version = "1.7.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7b97433ffdb356d11b6c89b08c69a787b9f55d787cdeee733c12fdf85d465ef1" +checksum = "b9f7ef09f21bd1e9cb8a686f168cb4a206646804567f0889eadb8dcc4c9288c8" dependencies = [ "alloy-eip2124", "alloy-eip2930", @@ -238,6 +238,7 @@ dependencies = [ "serde", "serde_with", "sha2", + "thiserror 2.0.18", ] [[package]] @@ -352,9 +353,9 @@ dependencies = [ [[package]] name = "alloy-serde" -version = "1.8.2" +version = "1.7.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8f1c2c0b5f024814f1c04ae76ff71862d06c836e7d67102daf8a557e5056be68" +checksum = "e2ce1e0dbf7720eee747700e300c99aac01b1a95bb93f493a01e78ee28bb1a37" dependencies = [ "alloy-primitives", "serde", @@ -482,22 +483,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "43d5b281e737544384e969a5ccad3f1cdd24b48086a0fc1b2a5262a26b8f4f4a" dependencies = [ "anstyle", - "anstyle-parse 0.2.7", - "anstyle-query", - "anstyle-wincon", - "colorchoice", - "is_terminal_polyfill", - "utf8parse", -] - -[[package]] -name = "anstream" -version = "1.0.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "824a212faf96e9acacdbd09febd34438f8f711fb84e09a8916013cd7815ca28d" -dependencies = [ - "anstyle", - "anstyle-parse 1.0.0", + "anstyle-parse", "anstyle-query", "anstyle-wincon", "colorchoice", @@ -520,22 +506,13 @@ dependencies = [ "utf8parse", ] -[[package]] -name = "anstyle-parse" -version = "1.0.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "52ce7f38b242319f7cabaa6813055467063ecdc9d355bbb4ce0c68908cd8130e" -dependencies = [ - "utf8parse", -] - [[package]] name = "anstyle-query" version = "1.1.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "40c48f72fd53cd289104fc64099abca73db4166ad86ea0b4341abe65af83dadc" dependencies = [ - "windows-sys 0.60.2", + "windows-sys 0.61.2", ] [[package]] @@ -546,7 +523,7 @@ checksum = "291e6a250ff86cd4a820112fb8898808a366d8f9f58ce16d1f538353ad55747d" dependencies = [ "anstyle", "once_cell_polyfill", - "windows-sys 0.60.2", + "windows-sys 0.61.2", ] [[package]] @@ -956,7 +933,7 @@ dependencies = [ "bitflags 2.11.0", "cexpr", "clang-sys", - "itertools 0.12.1", + "itertools 0.13.0", "proc-macro2", "quote", "regex", @@ -1288,7 +1265,7 @@ version = "4.5.60" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "24a241312cea5059b13574bb9b3861cabf758b879c15190b37b6d6fd63ab6876" dependencies = [ - "anstream 0.6.21", + "anstream", "anstyle", "clap_lex", "strsim", @@ -1887,11 +1864,11 @@ dependencies = [ [[package]] name = "env_logger" -version = "0.11.10" +version = "0.11.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0621c04f2196ac3f488dd583365b9c09be011a4ab8b9f37248ffcc8f6198b56a" +checksum = "b2daee4ea451f429a58296525ddf28b45a3b64f1acf6587e2067437bb11e218d" dependencies = [ - "anstream 1.0.0", + "anstream", "anstyle", "env_filter", "jiff", @@ -1931,7 +1908,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "39cab71617ae0d63f51a36d69f866391735b51691dbda63cf6f96d042b63efeb" dependencies = [ "libc", - "windows-sys 0.52.0", + "windows-sys 0.61.2", ] [[package]] @@ -2724,7 +2701,7 @@ checksum = "3640c1c38b8e4e43584d8df18be5fc6b0aa314ce6ebf51b53313d4306cca8e46" dependencies = [ "hermit-abi", "libc", - "windows-sys 0.52.0", + "windows-sys 0.61.2", ] [[package]] @@ -2886,7 +2863,6 @@ dependencies = [ "ark-ff 0.5.0", "jolt-inlines-sdk", "rand 0.8.5", - "tracer", ] [[package]] @@ -2897,7 +2873,6 @@ dependencies = [ "hex-literal", "jolt-inlines-sdk", "rand 0.8.5", - "tracer", ] [[package]] @@ -2907,7 +2882,6 @@ dependencies = [ "blake3", "jolt-inlines-sdk", "rand 0.8.5", - "tracer", ] [[package]] @@ -2938,6 +2912,7 @@ version = "0.1.0" dependencies = [ "inventory", "jolt-platform", + "rand 0.8.5", "tracer", ] @@ -3397,7 +3372,7 @@ version = "0.50.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7957b9740744892f114936ab4a57b3f487491bbeafaf8083688b16841a4240e5" dependencies = [ - "windows-sys 0.60.2", + "windows-sys 0.61.2", ] [[package]] @@ -4738,7 +4713,7 @@ dependencies = [ "errno", "libc", "linux-raw-sys", - "windows-sys 0.52.0", + "windows-sys 0.61.2", ] [[package]] @@ -5520,7 +5495,7 @@ dependencies = [ "getrandom 0.3.4", "once_cell", "rustix", - "windows-sys 0.52.0", + "windows-sys 0.61.2", ] [[package]] @@ -6145,7 +6120,7 @@ version = "0.1.11" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c2a7b1c03c876122aa43f3020e6c3c3ee5c05081c9a00739faf7503aeba10d22" dependencies = [ - "windows-sys 0.48.0", + "windows-sys 0.61.2", ] [[package]] diff --git a/jolt-inlines/bigint/Cargo.toml b/jolt-inlines/bigint/Cargo.toml index 47ebcfd419..e458c768bd 100644 --- a/jolt-inlines/bigint/Cargo.toml +++ b/jolt-inlines/bigint/Cargo.toml @@ -17,4 +17,3 @@ jolt-inlines-sdk = { workspace = true, optional = true } [dev-dependencies] ark-ff.workspace = true rand = { workspace = true, features = ["std", "std_rng"] } -tracer = { workspace = true, features = ["std", "test-utils"] } diff --git a/jolt-inlines/bigint/src/multiplication/exec.rs b/jolt-inlines/bigint/src/multiplication/exec.rs deleted file mode 100644 index d8af57fb30..0000000000 --- a/jolt-inlines/bigint/src/multiplication/exec.rs +++ /dev/null @@ -1,46 +0,0 @@ -use super::{INPUT_LIMBS, OUTPUT_LIMBS}; - -/// Execute n-bit × n-bit = 2n-bit multiplication -/// Input: u64 limbs for each operand (little-endian) -/// Output: u64 limbs for the product (little-endian) -pub fn bigint_mul(lhs: [u64; INPUT_LIMBS], rhs: [u64; INPUT_LIMBS]) -> [u64; OUTPUT_LIMBS] { - let mut result = [0u64; OUTPUT_LIMBS]; - - // Schoolbook multiplication: compute all partial products - // For each a[i] * b[j], add the 128-bit product to result[i+j] - for (i, &lhs_limb) in lhs.iter().enumerate() { - for (j, &rhs_limb) in rhs.iter().enumerate() { - // Compute 64×64 = 128-bit product - let product = (lhs_limb as u128) * (rhs_limb as u128); - let low = product as u64; - let high = (product >> 64) as u64; - - // Add to result[i+j] with carry propagation - let result_position = i + j; - - // Add low part - let (sum, carry1) = result[result_position].overflowing_add(low); - result[result_position] = sum; - - // Propagate carry through high part and beyond - let mut carry = carry1 as u64; - if high != 0 || carry != 0 { - // Add high part plus carry from low part - let (sum_with_hi, carry_hi) = result[result_position + 1].overflowing_add(high); - let (sum_with_carry, carry_carry) = sum_with_hi.overflowing_add(carry); - result[result_position + 1] = sum_with_carry; - carry = (carry_hi as u64) + (carry_carry as u64); - - // Continue propagating carry if needed - let mut carry_position = result_position + 2; - while carry != 0 && carry_position < OUTPUT_LIMBS { - let (sum, c) = result[carry_position].overflowing_add(carry); - result[carry_position] = sum; - carry = c as u64; - carry_position += 1; - } - } - } - } - result -} diff --git a/jolt-inlines/bigint/src/multiplication/mod.rs b/jolt-inlines/bigint/src/multiplication/mod.rs index ec327f0fad..ebd29465d5 100644 --- a/jolt-inlines/bigint/src/multiplication/mod.rs +++ b/jolt-inlines/bigint/src/multiplication/mod.rs @@ -10,12 +10,10 @@ const OUTPUT_LIMBS: usize = 2 * INPUT_LIMBS; pub mod sdk; pub use sdk::*; -#[cfg(feature = "host")] -pub mod exec; #[cfg(feature = "host")] pub mod sequence_builder; +#[cfg(feature = "host")] +pub mod spec; -#[cfg(all(test, feature = "host"))] -pub mod test_utils; #[cfg(all(test, feature = "host"))] pub mod tests; diff --git a/jolt-inlines/bigint/src/multiplication/sdk.rs b/jolt-inlines/bigint/src/multiplication/sdk.rs index f8bc8351b8..ad344d06b9 100644 --- a/jolt-inlines/bigint/src/multiplication/sdk.rs +++ b/jolt-inlines/bigint/src/multiplication/sdk.rs @@ -69,10 +69,11 @@ pub unsafe fn bigint256_mul_inline(_a: *const u64, _b: *const u64, _result: *mut /// - `result` must point to at least 64 bytes of writable memory #[cfg(feature = "host")] pub unsafe fn bigint256_mul_inline(a: *const u64, b: *const u64, result: *mut u64) { - use crate::multiplication::exec; + use crate::multiplication::sequence_builder::BigintMul256; + use jolt_inlines_sdk::spec::InlineSpec; let a_array = *(a as *const [u64; INPUT_LIMBS]); let b_array = *(b as *const [u64; INPUT_LIMBS]); - let result_array = exec::bigint_mul(a_array, b_array); + let result_array = BigintMul256::reference(&(a_array, b_array)); core::ptr::copy_nonoverlapping(result_array.as_ptr(), result, OUTPUT_LIMBS); } diff --git a/jolt-inlines/bigint/src/multiplication/sequence_builder.rs b/jolt-inlines/bigint/src/multiplication/sequence_builder.rs index b9917b548a..ccae605801 100644 --- a/jolt-inlines/bigint/src/multiplication/sequence_builder.rs +++ b/jolt-inlines/bigint/src/multiplication/sequence_builder.rs @@ -150,6 +150,7 @@ impl InlineOp for BigintMul256 { const FUNCT3: u32 = crate::BIGINT256_MUL_FUNCT3; const FUNCT7: u32 = crate::BIGINT256_MUL_FUNCT7; const NAME: &'static str = crate::BIGINT256_MUL_NAME; + type Advice = (); fn build_sequence(asm: InstrAssembler, operands: FormatInline) -> Vec { BigIntMulSequenceBuilder::new(asm, operands).build() diff --git a/jolt-inlines/bigint/src/multiplication/spec.rs b/jolt-inlines/bigint/src/multiplication/spec.rs new file mode 100644 index 0000000000..70c3057898 --- /dev/null +++ b/jolt-inlines/bigint/src/multiplication/spec.rs @@ -0,0 +1,69 @@ +use super::sequence_builder::BigintMul256; +use super::{INPUT_LIMBS, OUTPUT_LIMBS}; +use jolt_inlines_sdk::host::Xlen; +use jolt_inlines_sdk::spec::rand::rngs::StdRng; +use jolt_inlines_sdk::spec::rand::Rng; +use jolt_inlines_sdk::spec::{InlineMemoryLayout, InlineSpec, InlineTestHarness}; + +impl InlineSpec for BigintMul256 { + type Input = ([u64; INPUT_LIMBS], [u64; INPUT_LIMBS]); + type Output = [u64; OUTPUT_LIMBS]; + + fn random_input(rng: &mut StdRng) -> Self::Input { + ( + core::array::from_fn(|_| rng.gen()), + core::array::from_fn(|_| rng.gen()), + ) + } + + fn reference(input: &Self::Input) -> Self::Output { + let (lhs, rhs) = input; + let mut result = [0u64; OUTPUT_LIMBS]; + + for (i, &lhs_limb) in lhs.iter().enumerate() { + for (j, &rhs_limb) in rhs.iter().enumerate() { + let product = (lhs_limb as u128) * (rhs_limb as u128); + let low = product as u64; + let high = (product >> 64) as u64; + + let result_position = i + j; + + let (sum, carry1) = result[result_position].overflowing_add(low); + result[result_position] = sum; + + let mut carry = carry1 as u64; + if high != 0 || carry != 0 { + let (sum_with_hi, carry_hi) = result[result_position + 1].overflowing_add(high); + let (sum_with_carry, carry_carry) = sum_with_hi.overflowing_add(carry); + result[result_position + 1] = sum_with_carry; + carry = (carry_hi as u64) + (carry_carry as u64); + + let mut carry_position = result_position + 2; + while carry != 0 && carry_position < OUTPUT_LIMBS { + let (sum, c) = result[carry_position].overflowing_add(carry); + result[carry_position] = sum; + carry = c as u64; + carry_position += 1; + } + } + } + } + result + } + + fn create_harness() -> InlineTestHarness { + let layout = InlineMemoryLayout::two_inputs(32, 32, 64); + InlineTestHarness::new(layout, Xlen::Bit64) + } + + fn load(harness: &mut InlineTestHarness, input: &Self::Input) { + harness.setup_registers(); + harness.load_input64(&input.0); + harness.load_input2_64(&input.1); + } + + fn read(harness: &mut InlineTestHarness) -> Self::Output { + let vec = harness.read_output64(OUTPUT_LIMBS); + vec.try_into().unwrap() + } +} diff --git a/jolt-inlines/bigint/src/multiplication/test_utils.rs b/jolt-inlines/bigint/src/multiplication/test_utils.rs deleted file mode 100644 index 90a723d0b4..0000000000 --- a/jolt-inlines/bigint/src/multiplication/test_utils.rs +++ /dev/null @@ -1,41 +0,0 @@ -use super::{BIGINT256_MUL_FUNCT3, BIGINT256_MUL_FUNCT7, INLINE_OPCODE, INPUT_LIMBS, OUTPUT_LIMBS}; -use tracer::emulator::cpu::Xlen; -use tracer::utils::inline_test_harness::{InlineMemoryLayout, InlineTestHarness}; - -pub type BigIntInput = ([u64; INPUT_LIMBS], [u64; INPUT_LIMBS]); -pub type BigIntOutput = [u64; OUTPUT_LIMBS]; - -pub fn create_bigint_harness() -> InlineTestHarness { - let layout = InlineMemoryLayout::two_inputs(32, 32, 64); - InlineTestHarness::new(layout, Xlen::Bit64) -} - -pub fn instruction() -> tracer::instruction::inline::INLINE { - InlineTestHarness::create_default_instruction( - INLINE_OPCODE, - BIGINT256_MUL_FUNCT3, - BIGINT256_MUL_FUNCT7, - ) -} - -pub mod bigint_verify { - use super::*; - - pub fn assert_exec_trace_equiv( - lhs: &[u64; INPUT_LIMBS], - rhs: &[u64; INPUT_LIMBS], - expected: &[u64; OUTPUT_LIMBS], - ) { - let mut harness = create_bigint_harness(); - harness.setup_registers(); - harness.load_input64(lhs); - harness.load_input2_64(rhs); - harness.execute_inline(instruction()); - - let result_vec = harness.read_output64(OUTPUT_LIMBS); - let mut result = [0u64; OUTPUT_LIMBS]; - result.copy_from_slice(&result_vec); - - assert_eq!(&result, expected, "BigInt multiplication result mismatch"); - } -} diff --git a/jolt-inlines/bigint/src/multiplication/tests.rs b/jolt-inlines/bigint/src/multiplication/tests.rs index b4bd9cbcab..a94b35ca51 100644 --- a/jolt-inlines/bigint/src/multiplication/tests.rs +++ b/jolt-inlines/bigint/src/multiplication/tests.rs @@ -4,36 +4,29 @@ use super::{INPUT_LIMBS, OUTPUT_LIMBS}; mod bigint256_multiplication { use super::TestVectors; - use crate::test_utils::bigint_verify; + use crate::multiplication::sequence_builder::BigintMul256; + use jolt_inlines_sdk::spec; #[test] fn test_bigint256_mul_default() { - // Test with the default test vector - let (lhs, rhs, expected) = TestVectors::get_default_test(); - bigint_verify::assert_exec_trace_equiv(&lhs, &rhs, &expected); + let (lhs, rhs, _expected) = TestVectors::get_default_test(); + spec::verify::(&(lhs, rhs)); } #[test] fn test_bigint256_mul_random() { - // Test with 100 random inputs for _ in 0..100 { - let (lhs, rhs, expected) = TestVectors::generate_random_test(); - bigint_verify::assert_exec_trace_equiv(&lhs, &rhs, &expected); + let (lhs, rhs, _expected) = TestVectors::generate_random_test(); + spec::verify::(&(lhs, rhs)); } } #[test] fn test_bigint256_mul_edge_cases() { - println!("\n=== Testing BigInt256 multiplication edge cases ==="); - let edge_cases = TestVectors::get_edge_cases(); - - for (i, (lhs, rhs, expected, description)) in edge_cases.iter().enumerate() { - println!("\nEdge case #{}: {}", i + 1, description); - bigint_verify::assert_exec_trace_equiv(lhs, rhs, expected); + for (lhs, rhs, _expected, _description) in &edge_cases { + spec::verify::(&(*lhs, *rhs)); } - - println!("\n✅ All {} edge cases passed!\n", edge_cases.len()); } } diff --git a/jolt-inlines/blake2/Cargo.toml b/jolt-inlines/blake2/Cargo.toml index f43becd2e2..49f0215672 100644 --- a/jolt-inlines/blake2/Cargo.toml +++ b/jolt-inlines/blake2/Cargo.toml @@ -18,4 +18,3 @@ jolt-inlines-sdk = { workspace = true, optional = true } hex-literal.workspace = true blake2.workspace = true rand = { workspace = true, features = ["std", "std_rng"] } -tracer = { workspace = true, features = ["std", "test-utils"] } diff --git a/jolt-inlines/blake2/src/exec.rs b/jolt-inlines/blake2/src/exec.rs deleted file mode 100644 index 23e51b2387..0000000000 --- a/jolt-inlines/blake2/src/exec.rs +++ /dev/null @@ -1,108 +0,0 @@ -use crate::{IV, SIGMA}; - -/// Rust implementation of BLAKE2 compression on the host. -pub fn execute_blake2b_compression(state: &mut [u64; 8], message_words: &[u64; 18]) { - let mut v = [0u64; 16]; - v[0..8].copy_from_slice(state); - v[8..16].copy_from_slice(&IV); - - v[12] ^= message_words[16]; - // v[13] ^= counter.shr(64) as u64; // not used for 64-bit counter - - if message_words[17] != 0 { - v[14] = !v[14]; - } - - for s in SIGMA { - // Column step - g( - &mut v, - 0, - 4, - 8, - 12, - message_words[s[0]], - message_words[s[1]], - ); - g( - &mut v, - 1, - 5, - 9, - 13, - message_words[s[2]], - message_words[s[3]], - ); - g( - &mut v, - 2, - 6, - 10, - 14, - message_words[s[4]], - message_words[s[5]], - ); - g( - &mut v, - 3, - 7, - 11, - 15, - message_words[s[6]], - message_words[s[7]], - ); - - // Diagonal step - g( - &mut v, - 0, - 5, - 10, - 15, - message_words[s[8]], - message_words[s[9]], - ); - g( - &mut v, - 1, - 6, - 11, - 12, - message_words[s[10]], - message_words[s[11]], - ); - g( - &mut v, - 2, - 7, - 8, - 13, - message_words[s[12]], - message_words[s[13]], - ); - g( - &mut v, - 3, - 4, - 9, - 14, - message_words[s[14]], - message_words[s[15]], - ); - } - - for i in 0..8 { - state[i] ^= v[i] ^ v[i + 8]; - } -} - -fn g(v: &mut [u64; 16], a: usize, b: usize, c: usize, d: usize, x: u64, y: u64) { - v[a] = v[a].wrapping_add(v[b]).wrapping_add(x); - v[d] = (v[d] ^ v[a]).rotate_right(32); - v[c] = v[c].wrapping_add(v[d]); - v[b] = (v[b] ^ v[c]).rotate_right(24); - v[a] = v[a].wrapping_add(v[b]).wrapping_add(y); - v[d] = (v[d] ^ v[a]).rotate_right(16); - v[c] = v[c].wrapping_add(v[d]); - v[b] = (v[b] ^ v[c]).rotate_right(63); -} diff --git a/jolt-inlines/blake2/src/lib.rs b/jolt-inlines/blake2/src/lib.rs index 12dd544c7e..ad1d7ef6d1 100644 --- a/jolt-inlines/blake2/src/lib.rs +++ b/jolt-inlines/blake2/src/lib.rs @@ -10,19 +10,16 @@ pub const BLAKE2_NAME: &str = "BLAKE2_INLINE"; pub mod sdk; pub use sdk::*; -#[cfg(feature = "host")] -pub mod exec; #[cfg(feature = "host")] pub mod sequence_builder; +#[cfg(feature = "host")] +pub mod spec; #[cfg(feature = "host")] mod host; #[cfg(feature = "host")] pub use host::*; -#[cfg(all(test, feature = "host"))] -pub mod test_utils; - /// Blake2b initialization vector (IV). pub const IV: [u64; 8] = [ 0x6a09e667f3bcc908, diff --git a/jolt-inlines/blake2/src/sdk.rs b/jolt-inlines/blake2/src/sdk.rs index 3b0bd2861a..48a1df7269 100644 --- a/jolt-inlines/blake2/src/sdk.rs +++ b/jolt-inlines/blake2/src/sdk.rs @@ -401,18 +401,13 @@ pub(crate) unsafe fn blake2b_compress(state: *mut u64, message: *const u64) { /// - `message` must point to a valid array of 18 u64 values #[cfg(feature = "host")] pub(crate) unsafe fn blake2b_compress(state: *mut u64, message: *const u64) { - let state_slice = core::slice::from_raw_parts_mut(state, 8); - let message_slice = core::slice::from_raw_parts(message, 18); - - // Convert to arrays for type safety - let state_array: &mut [u64; 8] = state_slice - .try_into() - .expect("State pointer must reference exactly 8 u64 values"); - let message_array: [u64; 18] = message_slice - .try_into() - .expect("Message pointer must reference exactly 18 u64 values"); - - crate::exec::execute_blake2b_compression(state_array, &message_array); + use crate::sequence_builder::Blake2bCompression; + use jolt_inlines_sdk::spec::InlineSpec; + + let state_array = *(state as *const [u64; 8]); + let message_array = *(message as *const [u64; 18]); + let result = Blake2bCompression::reference(&(state_array, message_array)); + core::ptr::copy_nonoverlapping(result.as_ptr(), state, 8); } #[cfg(all( diff --git a/jolt-inlines/blake2/src/sequence_builder.rs b/jolt-inlines/blake2/src/sequence_builder.rs index d933210d64..5cba2ae618 100644 --- a/jolt-inlines/blake2/src/sequence_builder.rs +++ b/jolt-inlines/blake2/src/sequence_builder.rs @@ -14,11 +14,10 @@ use jolt_inlines_sdk::host::{ instruction::{ ld::LD, lui::LUI, - sd::SD, sub::SUB, virtual_xor_rot::{VirtualXORROT16, VirtualXORROT24, VirtualXORROT32, VirtualXORROT63}, }, - FormatInline, InlineOp, InstrAssembler, Instruction, + FormatInline, InlineOp, InstrAssembler, InstrAssemblerExt, Instruction, Value::{Imm, Reg}, VirtualRegisterGuard, }; @@ -79,21 +78,17 @@ impl Blake2SequenceBuilder { } fn load_hash_state(&mut self) { - self.load_data_range( - self.operands.rs1, - 0, - VR_HASH_STATE_START, - crate::STATE_VECTOR_LEN, - ); + let regs: Vec = (VR_HASH_STATE_START..VR_HASH_STATE_START + crate::STATE_VECTOR_LEN) + .map(|i| *self.vr[i]) + .collect(); + self.asm.load_u64_range(self.operands.rs1, 0, ®s); } fn load_message_blocks(&mut self) { - self.load_data_range( - self.operands.rs2, - 0, - VR_MESSAGE_BLOCK_START, - crate::MSG_BLOCK_LEN, - ); + let regs: Vec = (VR_MESSAGE_BLOCK_START..VR_MESSAGE_BLOCK_START + crate::MSG_BLOCK_LEN) + .map(|i| *self.vr[i]) + .collect(); + self.asm.load_u64_range(self.operands.rs2, 0, ®s); } fn load_counter_and_is_final(&mut self) { @@ -224,32 +219,11 @@ impl Blake2SequenceBuilder { } } - /// Store the final hash state fn store_state(&mut self) { - for i in 0..crate::STATE_VECTOR_LEN { - self.asm.emit_s::( - self.operands.rs1, - *self.vr[VR_HASH_STATE_START + i], - (i * 8) as i64, - ); - } - } - - /// Load data from memory into virtual registers starting at a given offset - fn load_data_range( - &mut self, - base_register: u8, - memory_offset_start: usize, - vr_start: usize, - count: usize, - ) { - (0..count).for_each(|i| { - self.asm.emit_ld::( - *self.vr[vr_start + i], - base_register, - ((memory_offset_start + i) * 8) as i64, - ); - }); + let regs: Vec = (VR_HASH_STATE_START..VR_HASH_STATE_START + crate::STATE_VECTOR_LEN) + .map(|i| *self.vr[i]) + .collect(); + self.asm.store_u64_range(self.operands.rs1, 0, ®s); } } @@ -260,6 +234,7 @@ impl InlineOp for Blake2bCompression { const FUNCT3: u32 = crate::BLAKE2_FUNCT3; const FUNCT7: u32 = crate::BLAKE2_FUNCT7; const NAME: &'static str = crate::BLAKE2_NAME; + type Advice = (); fn build_sequence(asm: InstrAssembler, operands: FormatInline) -> Vec { Blake2SequenceBuilder::new(asm, operands).build() @@ -268,13 +243,10 @@ impl InlineOp for Blake2bCompression { #[cfg(test)] mod tests { - use crate::{ - test_utils::{create_blake2_harness, instruction, load_blake2_data, read_state}, - IV, - }; + use crate::IV; + use jolt_inlines_sdk::spec::InlineSpec; fn generate_default_input() -> ([u64; crate::MSG_BLOCK_LEN], u64) { - // Message block with "abc" in little-endian let mut message = [0u64; crate::MSG_BLOCK_LEN]; message[0] = 0x0000000000636261u64; (message, 3) @@ -317,13 +289,18 @@ mod tests { counter: u64, is_final: bool, ) -> [u64; crate::STATE_VECTOR_LEN] { - let mut harness = create_blake2_harness(); - load_blake2_data(&mut harness, state, message, counter, is_final); - harness.execute_inline(instruction()); - read_state(&mut harness) + let mut combined = [0u64; 18]; + combined[..16].copy_from_slice(message); + combined[16] = counter; + combined[17] = is_final as u64; + + let input = (*state, combined); + let mut harness = super::Blake2bCompression::create_harness(); + super::Blake2bCompression::load(&mut harness, &input); + harness.execute_inline(super::Blake2bCompression::instruction()); + super::Blake2bCompression::read(&mut harness) } - /// Helper function to test blake2b compression with given input fn verify_blake2b_compression(message_words: [u64; crate::MSG_BLOCK_LEN], message_len: u64) { let mut initial_state = IV; initial_state[0] ^= 0x01010000 ^ 64u64; @@ -333,7 +310,7 @@ mod tests { assert_eq!( &expected_state, &trace_result, - "\n❌ BLAKE2b Trace Verification Failed!\n\ + "\nBLAKE2b Trace Verification Failed!\n\ Message: {message_words:016x?}" ); } @@ -346,7 +323,6 @@ mod tests { #[test] fn test_trace_result_with_random_inputs() { - // Test with multiple random inputs for _ in 0..10 { let input = generate_random_input(); verify_blake2b_compression(input.0, input.1); diff --git a/jolt-inlines/blake2/src/spec.rs b/jolt-inlines/blake2/src/spec.rs new file mode 100644 index 0000000000..043dabc81a --- /dev/null +++ b/jolt-inlines/blake2/src/spec.rs @@ -0,0 +1,76 @@ +use crate::sequence_builder::Blake2bCompression; +use crate::{IV, SIGMA}; +use jolt_inlines_sdk::host::Xlen; +use jolt_inlines_sdk::spec::rand::rngs::StdRng; +use jolt_inlines_sdk::spec::rand::Rng; +use jolt_inlines_sdk::spec::{InlineMemoryLayout, InlineSpec, InlineTestHarness}; + +fn g(v: &mut [u64; 16], a: usize, b: usize, c: usize, d: usize, x: u64, y: u64) { + v[a] = v[a].wrapping_add(v[b]).wrapping_add(x); + v[d] = (v[d] ^ v[a]).rotate_right(32); + v[c] = v[c].wrapping_add(v[d]); + v[b] = (v[b] ^ v[c]).rotate_right(24); + v[a] = v[a].wrapping_add(v[b]).wrapping_add(y); + v[d] = (v[d] ^ v[a]).rotate_right(16); + v[c] = v[c].wrapping_add(v[d]); + v[b] = (v[b] ^ v[c]).rotate_right(63); +} + +impl InlineSpec for Blake2bCompression { + type Input = ([u64; 8], [u64; 18]); + type Output = [u64; 8]; + + fn random_input(rng: &mut StdRng) -> Self::Input { + let state: [u64; 8] = core::array::from_fn(|_| rng.gen()); + let mut message: [u64; 18] = core::array::from_fn(|_| rng.gen()); + message[17] = rng.gen_range(0..=1); + (state, message) + } + + fn reference(input: &Self::Input) -> Self::Output { + let (state, w) = input; + let mut v = [0u64; 16]; + v[0..8].copy_from_slice(state); + v[8..16].copy_from_slice(&IV); + + v[12] ^= w[16]; + + if w[17] != 0 { + v[14] = !v[14]; + } + + for s in SIGMA { + g(&mut v, 0, 4, 8, 12, w[s[0]], w[s[1]]); + g(&mut v, 1, 5, 9, 13, w[s[2]], w[s[3]]); + g(&mut v, 2, 6, 10, 14, w[s[4]], w[s[5]]); + g(&mut v, 3, 7, 11, 15, w[s[6]], w[s[7]]); + + g(&mut v, 0, 5, 10, 15, w[s[8]], w[s[9]]); + g(&mut v, 1, 6, 11, 12, w[s[10]], w[s[11]]); + g(&mut v, 2, 7, 8, 13, w[s[12]], w[s[13]]); + g(&mut v, 3, 4, 9, 14, w[s[14]], w[s[15]]); + } + + let mut result = *state; + for i in 0..8 { + result[i] ^= v[i] ^ v[i + 8]; + } + result + } + + fn create_harness() -> InlineTestHarness { + let layout = InlineMemoryLayout::single_input(144, 64); + InlineTestHarness::new(layout, Xlen::Bit64) + } + + fn load(harness: &mut InlineTestHarness, input: &Self::Input) { + harness.setup_registers(); + harness.load_state64(&input.0); + harness.load_input64(&input.1); + } + + fn read(harness: &mut InlineTestHarness) -> Self::Output { + let vec = harness.read_output64(8); + vec.try_into().unwrap() + } +} diff --git a/jolt-inlines/blake2/src/test_utils.rs b/jolt-inlines/blake2/src/test_utils.rs deleted file mode 100644 index cec52c4154..0000000000 --- a/jolt-inlines/blake2/src/test_utils.rs +++ /dev/null @@ -1,41 +0,0 @@ -use crate::{BLAKE2_FUNCT3, BLAKE2_FUNCT7, INLINE_OPCODE}; -use tracer::emulator::cpu::Xlen; -use tracer::utils::inline_test_harness::{InlineMemoryLayout, InlineTestHarness}; - -pub fn create_blake2_harness() -> InlineTestHarness { - let layout = InlineMemoryLayout::single_input(144, 64); - InlineTestHarness::new(layout, Xlen::Bit64) -} - -pub fn load_blake2_data( - harness: &mut InlineTestHarness, - state: &[u64; crate::STATE_VECTOR_LEN], - message: &[u64; crate::MSG_BLOCK_LEN], - counter: u64, - is_final: bool, -) { - harness.setup_registers(); // RS1=state, RS2=message+params - harness.load_state64(state); - - // Blake2 expects message + counter + flag contiguously at rs2 - // Create combined input: message (16 u64s) + counter (1 u64) + flag (1 u64) - let mut combined_input = Vec::with_capacity(18); - combined_input.extend_from_slice(message); - combined_input.push(counter); - let flag_value = if is_final { 1u64 } else { 0u64 }; - combined_input.push(flag_value); - - // Load the combined input - harness.load_input64(&combined_input); -} - -pub fn read_state(harness: &mut InlineTestHarness) -> [u64; crate::STATE_VECTOR_LEN] { - let vec = harness.read_output64(crate::STATE_VECTOR_LEN); - let mut state = [0u64; crate::STATE_VECTOR_LEN]; - state.copy_from_slice(&vec); - state -} - -pub fn instruction() -> tracer::instruction::inline::INLINE { - InlineTestHarness::create_default_instruction(INLINE_OPCODE, BLAKE2_FUNCT3, BLAKE2_FUNCT7) -} diff --git a/jolt-inlines/blake3/Cargo.toml b/jolt-inlines/blake3/Cargo.toml index ac7977ecd6..e6bc8e4b17 100644 --- a/jolt-inlines/blake3/Cargo.toml +++ b/jolt-inlines/blake3/Cargo.toml @@ -17,4 +17,3 @@ jolt-inlines-sdk = { workspace = true, optional = true } [dev-dependencies] blake3.workspace = true rand.workspace = true -tracer = { workspace = true, features = ["std", "test-utils"] } diff --git a/jolt-inlines/blake3/src/exec.rs b/jolt-inlines/blake3/src/exec.rs deleted file mode 100644 index 492fc888ce..0000000000 --- a/jolt-inlines/blake3/src/exec.rs +++ /dev/null @@ -1,75 +0,0 @@ -use crate::IV; - -/// Rust implementation of BLAKE3 compression on the host. -/// The following code is obtained from reference BLAKE3 implementation (https://github.com/BLAKE3-team/BLAKE3/blob/master/reference_impl/reference_impl.rs) -pub fn execute_blake3_compression( - chaining_value: &mut [u32; 8], - block_words: &[u32; 16], - counter: &[u32; 2], - block_len: u32, - flags: u32, -) { - #[rustfmt::skip] - let mut state = [ - chaining_value[0], chaining_value[1], chaining_value[2], chaining_value[3], - chaining_value[4], chaining_value[5], chaining_value[6], chaining_value[7], - IV[0], IV[1], IV[2], IV[3], - counter[0], counter[1], block_len, flags, - ]; - let mut block = *block_words; - - round(&mut state, &block); // round 1 - permute(&mut block); - round(&mut state, &block); // round 2 - permute(&mut block); - round(&mut state, &block); // round 3 - permute(&mut block); - round(&mut state, &block); // round 4 - permute(&mut block); - round(&mut state, &block); // round 5 - permute(&mut block); - round(&mut state, &block); // round 6 - permute(&mut block); - round(&mut state, &block); // round 7 - - for i in 0..8 { - state[i] ^= state[i + 8]; - } - chaining_value.copy_from_slice(&state[..8]); -} - -/// The mixing function G, which mixes either a column or a diagonal in the state matrix. -/// This is the core operation of the BLAKE3 compression function. -fn g(state: &mut [u32; 16], a: usize, b: usize, c: usize, d: usize, mx: u32, my: u32) { - state[a] = state[a].wrapping_add(state[b]).wrapping_add(mx); - state[d] = (state[d] ^ state[a]).rotate_right(16); - state[c] = state[c].wrapping_add(state[d]); - state[b] = (state[b] ^ state[c]).rotate_right(12); - state[a] = state[a].wrapping_add(state[b]).wrapping_add(my); - state[d] = (state[d] ^ state[a]).rotate_right(8); - state[c] = state[c].wrapping_add(state[d]); - state[b] = (state[b] ^ state[c]).rotate_right(7); -} - -fn round(state: &mut [u32; 16], m: &[u32; 16]) { - // Mix the columns. - g(state, 0, 4, 8, 12, m[0], m[1]); - g(state, 1, 5, 9, 13, m[2], m[3]); - g(state, 2, 6, 10, 14, m[4], m[5]); - g(state, 3, 7, 11, 15, m[6], m[7]); - // Mix the diagonals. - g(state, 0, 5, 10, 15, m[8], m[9]); - g(state, 1, 6, 11, 12, m[10], m[11]); - g(state, 2, 7, 8, 13, m[12], m[13]); - g(state, 3, 4, 9, 14, m[14], m[15]); -} - -const MSG_PERMUTATION: [usize; 16] = [2, 6, 3, 10, 7, 0, 4, 13, 1, 11, 12, 5, 9, 14, 15, 8]; - -fn permute(m: &mut [u32; 16]) { - let mut permuted = [0; 16]; - for i in 0..16 { - permuted[i] = m[MSG_PERMUTATION[i]]; - } - *m = permuted; -} diff --git a/jolt-inlines/blake3/src/lib.rs b/jolt-inlines/blake3/src/lib.rs index 005e16948f..9582d9f5a6 100644 --- a/jolt-inlines/blake3/src/lib.rs +++ b/jolt-inlines/blake3/src/lib.rs @@ -12,19 +12,16 @@ pub const BLAKE3_KEYED64_NAME: &str = "BLAKE3_KEYED64_INLINE"; pub mod sdk; pub use sdk::*; -#[cfg(feature = "host")] -pub mod exec; #[cfg(feature = "host")] pub mod sequence_builder; +#[cfg(feature = "host")] +pub mod spec; #[cfg(feature = "host")] mod host; #[cfg(feature = "host")] pub use host::*; -#[cfg(all(test, feature = "host"))] -pub mod test_utils; - /// BLAKE3 initialization vector (IV) #[rustfmt::skip] pub const IV: [u32; 8] = [ diff --git a/jolt-inlines/blake3/src/sdk.rs b/jolt-inlines/blake3/src/sdk.rs index 59e55a25c8..6683c20256 100644 --- a/jolt-inlines/blake3/src/sdk.rs +++ b/jolt-inlines/blake3/src/sdk.rs @@ -1,6 +1,4 @@ //! This file provides high-level API to use BLAKE3 compression, both in host and guest mode. -#[cfg(feature = "host")] -use crate::FLAG_KEYED_HASH; use crate::{ BLOCK_INPUT_SIZE_IN_BYTES, CHAINING_VALUE_LEN, COUNTER_LEN, FLAG_CHUNK_END, FLAG_CHUNK_START, FLAG_ROOT, IV, MSG_BLOCK_LEN, OUTPUT_SIZE_IN_BYTES, @@ -366,14 +364,18 @@ pub(crate) unsafe fn blake3_compress(chaining_value: *mut u32, message: *const u // Extract flags (next u32 at offset 19) let flags = *message.add(19); - // On the host, we call our reference implementation from the exec module. - crate::exec::execute_blake3_compression( - &mut *(chaining_value as *mut [u32; 8]), - message_block, - &counter_array, + use crate::sequence_builder::Blake3Compression; + use jolt_inlines_sdk::spec::InlineSpec; + + let input = ( + *(chaining_value as *const [u32; 8]), + *message_block, + counter_array, block_len, flags, ); + let result = Blake3Compression::reference(&input); + core::ptr::copy_nonoverlapping(result.as_ptr(), chaining_value, 8); } #[cfg(all( @@ -412,21 +414,16 @@ unsafe fn blake3_keyed64_compress(left: *const u32, right: *const u32, iv: *mut #[cfg(feature = "host")] #[inline(always)] unsafe fn blake3_keyed64_compress(left: *const u32, right: *const u32, key: *mut u32) { - // Concatenate left || right as message - let mut message = [0u32; 16]; - core::ptr::copy_nonoverlapping(left, message.as_mut_ptr(), 8); - core::ptr::copy_nonoverlapping(right, message.as_mut_ptr().add(8), 8); - - let key_arr = &mut *(key as *mut [u32; 8]); - - // flags = CHUNK_START | CHUNK_END | ROOT | KEYED_HASH - crate::exec::execute_blake3_compression( - key_arr, - &message, - &[0, 0], - 64, - FLAG_CHUNK_START | FLAG_CHUNK_END | FLAG_ROOT | FLAG_KEYED_HASH, + use crate::sequence_builder::Blake3Keyed64Compression; + use jolt_inlines_sdk::spec::InlineSpec; + + let input = ( + *(left as *const [u32; 8]), + *(right as *const [u32; 8]), + *(key as *const [u32; 8]), ); + let result = Blake3Keyed64Compression::reference(&input); + core::ptr::copy_nonoverlapping(result.as_ptr(), key, 8); } #[cfg(all( @@ -483,7 +480,22 @@ pub const BLAKE3_IV: AlignedHash32 = AlignedHash32([ #[cfg(feature = "host")] mod tests { use super::Blake3; - use crate::{test_utils::helpers::*, BLOCK_INPUT_SIZE_IN_BYTES}; + use crate::BLOCK_INPUT_SIZE_IN_BYTES; + + fn generate_random_bytes(len: usize) -> Vec { + use rand::rngs::StdRng; + use rand::{RngCore, SeedableRng}; + let mut buf = vec![0u8; len]; + let mut rng = StdRng::seed_from_u64(12345); + rng.fill_bytes(&mut buf); + buf + } + + fn compute_expected_result(input: &[u8]) -> [u8; crate::OUTPUT_SIZE_IN_BYTES] { + blake3::hash(input).as_bytes()[0..crate::OUTPUT_SIZE_IN_BYTES] + .try_into() + .unwrap() + } fn random_partition(data: &[u8]) -> Vec<&[u8]> { use rand::rngs::StdRng; diff --git a/jolt-inlines/blake3/src/sequence_builder.rs b/jolt-inlines/blake3/src/sequence_builder.rs index 9fa600f4a3..38663f8e73 100644 --- a/jolt-inlines/blake3/src/sequence_builder.rs +++ b/jolt-inlines/blake3/src/sequence_builder.rs @@ -15,12 +15,9 @@ use crate::{ }; use jolt_inlines_sdk::host::{ instruction::{ - ld::LD, lui::LUI, lw::LW, - srli::SRLI, virtual_xor_rotw::{VirtualXORROTW12, VirtualXORROTW16, VirtualXORROTW7, VirtualXORROTW8}, - virtual_zero_extend_word::VirtualZeroExtendWord, }, FormatInline, InlineOp, InstrAssembler, InstrAssemblerExt, Instruction, Value::{Imm, Reg}, @@ -57,6 +54,51 @@ where g(3, 4, 9, 14, msg_schedule_round[14], msg_schedule_round[15]); } +/// BLAKE3 quarter-round G function. +/// When `temp` is Some, uses a scratch register for the first add to avoid clobbering `va`. +/// When `temp` is None, performs in-place adds (saves one register). +#[allow(clippy::too_many_arguments)] +fn blake3_g( + asm: &mut InstrAssembler, + va: u8, + vb: u8, + vc: u8, + vd: u8, + mx: u8, + my: u8, + temp: Option, +) { + match temp { + Some(t) => { + asm.add(Reg(va), Reg(vb), t); + asm.add(Reg(t), Reg(mx), va); + } + None => { + asm.add(Reg(va), Reg(vb), va); + asm.add(Reg(va), Reg(mx), va); + } + } + + asm.emit_r::(vd, vd, va); + asm.add(Reg(vc), Reg(vd), vc); + asm.emit_r::(vb, vb, vc); + + match temp { + Some(t) => { + asm.add(Reg(va), Reg(vb), t); + asm.add(Reg(t), Reg(my), va); + } + None => { + asm.add(Reg(va), Reg(vb), va); + asm.add(Reg(va), Reg(my), va); + } + } + + asm.emit_r::(vd, vd, va); + asm.add(Reg(vc), Reg(vd), vc); + asm.emit_r::(vb, vb, vc); +} + /// Virtual register layout: /// - vr[0..15]: Internal state `v` /// - vr[16..31]: Message block `m` @@ -182,39 +224,16 @@ impl Blake3SequenceBuilder { } fn g_function(&mut self, a: usize, b: usize, c: usize, d: usize, x: usize, y: usize) { - let va = *self.vr[a]; - let vb = *self.vr[b]; - let vc = *self.vr[c]; - let vd = *self.vr[d]; - let mx = *self.vr[MSG_BLOCK_START_VR + x]; - let my = *self.vr[MSG_BLOCK_START_VR + y]; - let temp1 = *self.vr[TEMP_VR]; - - // v[a] = v[a] + v[b] + m[x] - self.asm.add(Reg(va), Reg(vb), temp1); - self.asm.add(Reg(temp1), Reg(mx), va); - - // v[d] = rotr32(v[d] ^ v[a], 16) - self.asm.emit_r::(vd, vd, va); - - // v[c] = v[c] + v[d] - self.asm.add(Reg(vc), Reg(vd), vc); - - // v[b] = rotr32(v[b] ^ v[c], 12) - self.asm.emit_r::(vb, vb, vc); - - // v[a] = v[a] + v[b] + m[y] - self.asm.add(Reg(va), Reg(vb), temp1); - self.asm.add(Reg(temp1), Reg(my), va); - - // v[d] = rotr32(v[d] ^ v[a], 8) - self.asm.emit_r::(vd, vd, va); - - // v[c] = v[c] + v[d] - self.asm.add(Reg(vc), Reg(vd), vc); - - // v[b] = rotr32(v[b] ^ v[c], 7) - self.asm.emit_r::(vb, vb, vc); + blake3_g( + &mut self.asm, + *self.vr[a], + *self.vr[b], + *self.vr[c], + *self.vr[d], + *self.vr[MSG_BLOCK_START_VR + x], + *self.vr[MSG_BLOCK_START_VR + y], + Some(*self.vr[TEMP_VR]), + ); } fn finalize_state(&mut self) { @@ -227,73 +246,34 @@ impl Blake3SequenceBuilder { } fn store_state(&mut self) { - for i in 0..CHAINING_VALUE_LEN / 2 { - self.asm.store_paired_u32( - self.operands.rs1, - (i * 2) as i64 * 4, - *self.vr[CV_START_VR + i * 2], - *self.vr[CV_START_VR + i * 2 + 1], - ); - } - } - - fn load_data_range_paired( - &mut self, - base_register: u8, - memory_offset_start: usize, - vr_start: usize, - count: usize, - ) { - debug_assert!( - count.is_multiple_of(2), - "count must be even for paired loading" - ); - let temp = *self.vr[TEMP_VR]; - for i in 0..count / 2 { - self.asm.load_paired_u32( - temp, - base_register, - (memory_offset_start + i * 2) as i64 * 4, - *self.vr[vr_start + i * 2], - *self.vr[vr_start + i * 2 + 1], - ); - } + let regs: Vec = (CV_START_VR..CV_START_VR + CHAINING_VALUE_LEN) + .map(|i| *self.vr[i]) + .collect(); + self.asm.store_u32_range_paired(self.operands.rs1, 0, ®s); } - /// Load data from memory into virtual registers (non-paired, used for counter) - fn load_data_range( - &mut self, - base_register: u8, - memory_offset_start: usize, - vr_start: usize, - count: usize, - ) { - (0..count).for_each(|i| { - self.asm.emit_ld::( - *self.vr[vr_start + i], - base_register, - (memory_offset_start + i) as i64 * 4, - ); - }); + fn vr_slice(&self, start: usize, count: usize) -> Vec { + (start..start + count).map(|i| *self.vr[i]).collect() } fn load_chaining_value(&mut self) { - // Use paired loading for chaining value (8 u32 = 4 pairs) - self.load_data_range_paired(self.operands.rs1, 0, CV_START_VR, CHAINING_VALUE_LEN); + let temp = *self.vr[TEMP_VR]; + let regs = self.vr_slice(CV_START_VR, CHAINING_VALUE_LEN); + self.asm + .load_u32_range_paired(temp, self.operands.rs1, 0, ®s); } fn load_message_blocks(&mut self) { - // Use paired loading for message blocks (16 u32 = 8 pairs) - self.load_data_range_paired(self.operands.rs2, 0, MSG_BLOCK_START_VR, MSG_BLOCK_LEN); + let temp = *self.vr[TEMP_VR]; + let regs = self.vr_slice(MSG_BLOCK_START_VR, MSG_BLOCK_LEN); + self.asm + .load_u32_range_paired(temp, self.operands.rs2, 0, ®s); } fn load_counter(&mut self) { - self.load_data_range( - self.operands.rs2, - MSG_BLOCK_LEN, - COUNTER_START_VR, - COUNTER_LEN, - ); + let regs = self.vr_slice(COUNTER_START_VR, COUNTER_LEN); + self.asm + .load_u32_range(self.operands.rs2, MSG_BLOCK_LEN * 4, ®s); } fn load_input_len_and_flags(&mut self) { @@ -324,12 +304,12 @@ impl Blake3Keyed64SequenceBuilder { } fn build(mut self) -> Vec { - // Load key from rs3/rd directly into v[0..7] - self.load_data_range_paired(self.operands.rs3, 0, INTERNAL_STATE_VR_START, 8); - // Load left (32 bytes) from rs1 as message[0..7] - self.load_data_range_paired(self.operands.rs1, 0, MSG_BLOCK_START_VR, 8); - // Load right (32 bytes) from rs2 as message[8..15] - self.load_data_range_paired(self.operands.rs2, 0, MSG_BLOCK_START_VR + 8, 8); + let key_regs = self.vr_slice(INTERNAL_STATE_VR_START, 8); + self.load_u32_range_paired_no_temp(self.operands.rs3, 0, &key_regs); + let msg_lo_regs = self.vr_slice(MSG_BLOCK_START_VR, 8); + self.load_u32_range_paired_no_temp(self.operands.rs1, 0, &msg_lo_regs); + let msg_hi_regs = self.vr_slice(MSG_BLOCK_START_VR + 8, 8); + self.load_u32_range_paired_no_temp(self.operands.rs2, 0, &msg_hi_regs); self.initialize_internal_state(); @@ -345,14 +325,9 @@ impl Blake3Keyed64SequenceBuilder { self.asm.xor(Reg(vi), Reg(vi8), vi); } - for i in 0..CHAINING_VALUE_LEN / 2 { - self.asm.store_paired_u32( - self.operands.rs3, - (i * 2) as i64 * 4, - *self.vr[INTERNAL_STATE_VR_START + i * 2], - *self.vr[INTERNAL_STATE_VR_START + i * 2 + 1], - ); - } + let out_regs = self.vr_slice(INTERNAL_STATE_VR_START, CHAINING_VALUE_LEN); + self.asm + .store_u32_range_paired(self.operands.rs3, 0, &out_regs); drop(self.vr); self.asm.finalize_inline() @@ -386,72 +361,35 @@ impl Blake3Keyed64SequenceBuilder { blake3_apply_round_schedule(round, |a, b, c, d, x, y| self.g_function(a, b, c, d, x, y)); } - #[inline] fn g_function(&mut self, a: usize, b: usize, c: usize, d: usize, x: usize, y: usize) { - let va = *self.vr[a]; - let vb = *self.vr[b]; - let vc = *self.vr[c]; - let vd = *self.vr[d]; - let mx = *self.vr[MSG_BLOCK_START_VR + x]; - let my = *self.vr[MSG_BLOCK_START_VR + y]; - - // v[a] = v[a] + v[b] + m[x] - self.asm.add(Reg(va), Reg(vb), va); - self.asm.add(Reg(va), Reg(mx), va); - - // v[d] = rotr32(v[d] ^ v[a], 16) - self.asm.emit_r::(vd, vd, va); - - // v[c] = v[c] + v[d] - self.asm.add(Reg(vc), Reg(vd), vc); - - // v[b] = rotr32(v[b] ^ v[c], 12) - self.asm.emit_r::(vb, vb, vc); - - // v[a] = v[a] + v[b] + m[y] - self.asm.add(Reg(va), Reg(vb), va); - self.asm.add(Reg(va), Reg(my), va); - - // v[d] = rotr32(v[d] ^ v[a], 8) - self.asm.emit_r::(vd, vd, va); - - // v[c] = v[c] + v[d] - self.asm.add(Reg(vc), Reg(vd), vc); - - // v[b] = rotr32(v[b] ^ v[c], 7) - self.asm.emit_r::(vb, vb, vc); + blake3_g( + &mut self.asm, + *self.vr[a], + *self.vr[b], + *self.vr[c], + *self.vr[d], + *self.vr[MSG_BLOCK_START_VR + x], + *self.vr[MSG_BLOCK_START_VR + y], + None, + ); } - /// Load two u32 values from an 8-byte aligned address using a single LD. - /// Uses `vr_hi` as the temporary 64-bit container (no extra scratch register). - fn load_paired_u32(&mut self, base: u8, offset: i64, vr_lo: u8, vr_hi: u8) { - // Load 64 bits (2 x u32) into vr_hi temporarily. - self.asm.emit_ld::(vr_hi, base, offset); - - // Extract low 32 bits: zero-extend word. - self.asm.emit_i::(vr_lo, vr_hi, 0); - - // Extract high 32 bits: shift right by 32 (in place). - self.asm.emit_i::(vr_hi, vr_hi, 32); + fn vr_slice(&self, start: usize, count: usize) -> Vec { + (start..start + count).map(|i| *self.vr[i]).collect() } - fn load_data_range_paired( - &mut self, - base_register: u8, - memory_offset_start: usize, - vr_start: usize, - count: usize, - ) { - debug_assert!( - count.is_multiple_of(2), - "count must be even for paired loading" - ); - for i in 0..count / 2 { - self.load_paired_u32( - base_register, - (memory_offset_start + i * 2) as i64 * 4, - *self.vr[vr_start + i * 2], - *self.vr[vr_start + i * 2 + 1], + /// Load paired u32 values without requiring a separate temp register. + /// Uses each pair's hi register as the 64-bit load target. + fn load_u32_range_paired_no_temp(&mut self, base: u8, byte_offset: usize, regs: &[u8]) { + debug_assert!(regs.len().is_multiple_of(2)); + for (i, pair) in regs.chunks_exact(2).enumerate() { + // Use vr_hi as temp: load_paired_u32(temp=hi, base, offset, lo, hi) + self.asm.load_paired_u32( + pair[1], + base, + (byte_offset + i * 8) as i64, + pair[0], + pair[1], ); } } @@ -464,6 +402,7 @@ impl InlineOp for Blake3Compression { const FUNCT3: u32 = crate::BLAKE3_FUNCT3; const FUNCT7: u32 = crate::BLAKE3_FUNCT7; const NAME: &'static str = crate::BLAKE3_NAME; + type Advice = (); fn build_sequence(asm: InstrAssembler, operands: FormatInline) -> Vec { Blake3SequenceBuilder::new(asm, operands).build() @@ -477,6 +416,7 @@ impl InlineOp for Blake3Keyed64Compression { const FUNCT3: u32 = crate::BLAKE3_KEYED64_FUNCT3; const FUNCT7: u32 = crate::BLAKE3_FUNCT7; const NAME: &'static str = crate::BLAKE3_KEYED64_NAME; + type Advice = (); fn build_sequence(asm: InstrAssembler, operands: FormatInline) -> Vec { Blake3Keyed64SequenceBuilder::new(asm, operands).build() @@ -485,35 +425,59 @@ impl InlineOp for Blake3Keyed64Compression { #[cfg(test)] mod tests { - use crate::test_utils::{ - create_blake3_harness, create_blake3_keyed64_harness, helpers::*, instruction, - keyed64_instruction, load_blake3_data, load_blake3_keyed64_data, read_output, - ChainingValue, MessageBlock, - }; + use jolt_inlines_sdk::spec::InlineSpec; + use rand::rngs::StdRng; + use rand::{RngCore, SeedableRng}; + + fn generate_random_bytes(len: usize) -> Vec { + let mut buf = vec![0u8; len]; + let mut rng = StdRng::seed_from_u64(12345); + rng.fill_bytes(&mut buf); + buf + } + + fn bytes_to_u32_vec(bytes: &[u8]) -> Vec { + bytes + .chunks_exact(4) + .map(|chunk| u32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]])) + .collect() + } + + fn compute_expected_result(input: &[u8]) -> [u8; crate::OUTPUT_SIZE_IN_BYTES] { + blake3::hash(input).as_bytes()[0..crate::OUTPUT_SIZE_IN_BYTES] + .try_into() + .unwrap() + } + + fn compute_keyed_expected_result( + input: &[u8], + key: [u32; crate::CHAINING_VALUE_LEN], + ) -> [u8; crate::OUTPUT_SIZE_IN_BYTES] { + let mut key_bytes = [0u8; 32]; + for (i, word) in key.iter().enumerate() { + key_bytes[i * 4..(i + 1) * 4].copy_from_slice(&word.to_le_bytes()); + } + blake3::keyed_hash(&key_bytes, input).as_bytes()[0..crate::OUTPUT_SIZE_IN_BYTES] + .try_into() + .unwrap() + } fn generate_trace_result( - chaining_value: &ChainingValue, - message: &MessageBlock, + chaining_value: &[u32; crate::CHAINING_VALUE_LEN], + message: &[u32; crate::MSG_BLOCK_LEN], counter: &[u32; 2], block_len: u32, flags: u32, ) -> [u8; crate::OUTPUT_SIZE_IN_BYTES] { - let mut harness = create_blake3_harness(); - load_blake3_data( - &mut harness, - chaining_value, - message, - counter, - block_len, - flags, - ); - harness.execute_inline(instruction()); - let words = read_output(&mut harness); + let input = (*chaining_value, *message, *counter, block_len, flags); + let mut harness = super::Blake3Compression::create_harness(); + super::Blake3Compression::load(&mut harness, &input); + harness.execute_inline(super::Blake3Compression::instruction()); + let words = super::Blake3Compression::read(&mut harness); let mut bytes = [0u8; crate::OUTPUT_SIZE_IN_BYTES]; for (i, w) in words.iter().enumerate() { - let le = w.to_le_bytes(); - bytes[i * 4..(i + 1) * 4].copy_from_slice(&le); + bytes[i * 4..(i + 1) * 4].copy_from_slice(&w.to_le_bytes()); } bytes } @@ -522,7 +486,6 @@ mod tests { fn test_trace_result_equals_blake3_compress_reference() { for _ in 0..1000 { let message_bytes = generate_random_bytes(crate::MSG_BLOCK_LEN * 4); - // Convert bytes to message block (u32 words) assert_eq!( message_bytes.len(), crate::MSG_BLOCK_LEN * 4, @@ -549,21 +512,17 @@ mod tests { #[test] fn test_trace_result_equals_blake3_keyed_compress_reference() { for _ in 0..1000 { - // Generate random key let key_bytes = generate_random_bytes(crate::CHAINING_VALUE_LEN * 4); let mut key = [0u32; crate::CHAINING_VALUE_LEN]; key.copy_from_slice(&bytes_to_u32_vec(&key_bytes)); - // Generate random message let message_bytes = generate_random_bytes(crate::MSG_BLOCK_LEN * 4); let words_vec = bytes_to_u32_vec(&message_bytes); let mut message_words = [0u32; crate::MSG_BLOCK_LEN]; message_words.copy_from_slice(&words_vec); - // Compute expected result using keyed hash let expected_hash_bytes = compute_keyed_expected_result(&message_bytes, key); - // Generate trace result with keyed hash flag let counter = [0u32, 0u32]; let block_len = 64u32; let flags = crate::FLAG_CHUNK_START @@ -582,14 +541,12 @@ mod tests { #[test] fn test_trace_keyed64_matches_blake3_keyed_hash() { - // Test that sequence builder's Keyed64 mode matches blake3::keyed_hash for 64-byte input use rand::rngs::StdRng; use rand::{Rng, SeedableRng}; let mut rng = StdRng::seed_from_u64(88888); for _ in 0..100 { - // Generate random left, right, and key let mut left = [0u32; crate::CHAINING_VALUE_LEN]; let mut right = [0u32; crate::CHAINING_VALUE_LEN]; let mut key = [0u32; crate::CHAINING_VALUE_LEN]; @@ -599,20 +556,17 @@ mod tests { key[i] = rng.gen(); } - // Execute sequence builder with key as IV - let mut harness = create_blake3_keyed64_harness(); - load_blake3_keyed64_data(&mut harness, &left, &right, &key); - harness.execute_inline(keyed64_instruction()); - let result_words = read_output(&mut harness); + let input = (left, right, key); + let mut harness = super::Blake3Keyed64Compression::create_harness(); + super::Blake3Keyed64Compression::load(&mut harness, &input); + harness.execute_inline(super::Blake3Keyed64Compression::instruction()); + let result_words = super::Blake3Keyed64Compression::read(&mut harness); - // Convert result to bytes let mut result_bytes = [0u8; 32]; for (i, w) in result_words.iter().enumerate() { - let le = w.to_le_bytes(); - result_bytes[i * 4..(i + 1) * 4].copy_from_slice(&le); + result_bytes[i * 4..(i + 1) * 4].copy_from_slice(&w.to_le_bytes()); } - // Convert left/right/key to bytes for blake3 reference let mut left_bytes = [0u8; 32]; let mut right_bytes = [0u8; 32]; let mut key_bytes = [0u8; 32]; @@ -626,13 +580,11 @@ mod tests { key_bytes[i * 4..(i + 1) * 4].copy_from_slice(&w.to_le_bytes()); } - // Concatenate left || right as 64-byte input - let mut input = [0u8; 64]; - input[..32].copy_from_slice(&left_bytes); - input[32..].copy_from_slice(&right_bytes); + let mut input_bytes = [0u8; 64]; + input_bytes[..32].copy_from_slice(&left_bytes); + input_bytes[32..].copy_from_slice(&right_bytes); - // Compute expected using official blake3::keyed_hash - let expected = blake3::keyed_hash(&key_bytes, &input); + let expected = blake3::keyed_hash(&key_bytes, &input_bytes); assert_eq!( result_bytes, diff --git a/jolt-inlines/blake3/src/spec.rs b/jolt-inlines/blake3/src/spec.rs new file mode 100644 index 0000000000..6c38ece689 --- /dev/null +++ b/jolt-inlines/blake3/src/spec.rs @@ -0,0 +1,166 @@ +use crate::sequence_builder::{Blake3Compression, Blake3Keyed64Compression}; +use crate::{CHAINING_VALUE_LEN, FLAG_CHUNK_END, FLAG_CHUNK_START, FLAG_KEYED_HASH, FLAG_ROOT, IV}; +use jolt_inlines_sdk::host::Xlen; +use jolt_inlines_sdk::spec::rand::rngs::StdRng; +use jolt_inlines_sdk::spec::rand::Rng; +use jolt_inlines_sdk::spec::{InlineMemoryLayout, InlineSpec, InlineTestHarness}; + +const MSG_PERMUTATION: [usize; 16] = [2, 6, 3, 10, 7, 0, 4, 13, 1, 11, 12, 5, 9, 14, 15, 8]; + +fn g(state: &mut [u32; 16], a: usize, b: usize, c: usize, d: usize, mx: u32, my: u32) { + state[a] = state[a].wrapping_add(state[b]).wrapping_add(mx); + state[d] = (state[d] ^ state[a]).rotate_right(16); + state[c] = state[c].wrapping_add(state[d]); + state[b] = (state[b] ^ state[c]).rotate_right(12); + state[a] = state[a].wrapping_add(state[b]).wrapping_add(my); + state[d] = (state[d] ^ state[a]).rotate_right(8); + state[c] = state[c].wrapping_add(state[d]); + state[b] = (state[b] ^ state[c]).rotate_right(7); +} + +fn round(state: &mut [u32; 16], m: &[u32; 16]) { + g(state, 0, 4, 8, 12, m[0], m[1]); + g(state, 1, 5, 9, 13, m[2], m[3]); + g(state, 2, 6, 10, 14, m[4], m[5]); + g(state, 3, 7, 11, 15, m[6], m[7]); + g(state, 0, 5, 10, 15, m[8], m[9]); + g(state, 1, 6, 11, 12, m[10], m[11]); + g(state, 2, 7, 8, 13, m[12], m[13]); + g(state, 3, 4, 9, 14, m[14], m[15]); +} + +fn permute(m: &mut [u32; 16]) { + let mut permuted = [0; 16]; + for i in 0..16 { + permuted[i] = m[MSG_PERMUTATION[i]]; + } + *m = permuted; +} + +/// Reference BLAKE3 compression function. +/// See https://github.com/BLAKE3-team/BLAKE3/blob/master/reference_impl/reference_impl.rs +fn compress( + chaining_value: &[u32; 8], + block_words: &[u32; 16], + counter: &[u32; 2], + block_len: u32, + flags: u32, +) -> [u32; 8] { + #[rustfmt::skip] + let mut state = [ + chaining_value[0], chaining_value[1], chaining_value[2], chaining_value[3], + chaining_value[4], chaining_value[5], chaining_value[6], chaining_value[7], + IV[0], IV[1], IV[2], IV[3], + counter[0], counter[1], block_len, flags, + ]; + let mut block = *block_words; + + round(&mut state, &block); // round 1 + permute(&mut block); + round(&mut state, &block); // round 2 + permute(&mut block); + round(&mut state, &block); // round 3 + permute(&mut block); + round(&mut state, &block); // round 4 + permute(&mut block); + round(&mut state, &block); // round 5 + permute(&mut block); + round(&mut state, &block); // round 6 + permute(&mut block); + round(&mut state, &block); // round 7 + + let mut result = [0u32; 8]; + for i in 0..8 { + result[i] = state[i] ^ state[i + 8]; + } + result +} + +impl InlineSpec for Blake3Compression { + type Input = ([u32; 8], [u32; 16], [u32; 2], u32, u32); + type Output = [u32; CHAINING_VALUE_LEN]; + + fn random_input(rng: &mut StdRng) -> Self::Input { + ( + core::array::from_fn(|_| rng.gen()), + core::array::from_fn(|_| rng.gen()), + core::array::from_fn(|_| rng.gen()), + rng.gen(), + rng.gen(), + ) + } + + fn reference(input: &Self::Input) -> Self::Output { + compress(&input.0, &input.1, &input.2, input.3, input.4) + } + + fn create_harness() -> InlineTestHarness { + let layout = InlineMemoryLayout::single_input(80, 32); + InlineTestHarness::new(layout, Xlen::Bit64) + } + + fn load(harness: &mut InlineTestHarness, input: &Self::Input) { + harness.setup_registers(); + harness.load_state32(&input.0); + + let mut combined = Vec::with_capacity(20); + combined.extend_from_slice(&input.1); + combined.extend_from_slice(&input.2); + combined.push(input.3); + combined.push(input.4); + harness.load_input32(&combined); + } + + fn read(harness: &mut InlineTestHarness) -> Self::Output { + harness + .read_output32(CHAINING_VALUE_LEN) + .try_into() + .unwrap() + } +} + +impl InlineSpec for Blake3Keyed64Compression { + type Input = ([u32; 8], [u32; 8], [u32; 8]); + type Output = [u32; CHAINING_VALUE_LEN]; + + fn random_input(rng: &mut StdRng) -> Self::Input { + ( + core::array::from_fn(|_| rng.gen()), + core::array::from_fn(|_| rng.gen()), + core::array::from_fn(|_| rng.gen()), + ) + } + + fn reference(input: &Self::Input) -> Self::Output { + let mut message = [0u32; 16]; + message[..8].copy_from_slice(&input.0); + message[8..].copy_from_slice(&input.1); + + compress( + &input.2, + &message, + &[0, 0], + 64, + FLAG_CHUNK_START | FLAG_CHUNK_END | FLAG_ROOT | FLAG_KEYED_HASH, + ) + } + + fn create_harness() -> InlineTestHarness { + let layout = InlineMemoryLayout::two_inputs(32, 32, 32); + InlineTestHarness::new(layout, Xlen::Bit64) + } + + fn load(harness: &mut InlineTestHarness, input: &Self::Input) { + harness.setup_registers(); + harness.load_input32(&input.0); + harness.load_input2_32(&input.1); + harness.load_state32(&input.2); + } + + fn read(harness: &mut InlineTestHarness) -> Self::Output { + harness + .read_output32(CHAINING_VALUE_LEN) + .try_into() + .unwrap() + } +} diff --git a/jolt-inlines/blake3/src/test_utils.rs b/jolt-inlines/blake3/src/test_utils.rs deleted file mode 100644 index 8d7de9b7da..0000000000 --- a/jolt-inlines/blake3/src/test_utils.rs +++ /dev/null @@ -1,122 +0,0 @@ -use crate::{BLAKE3_FUNCT3, BLAKE3_FUNCT7, BLAKE3_KEYED64_FUNCT3, INLINE_OPCODE}; -use tracer::emulator::cpu::Xlen; -use tracer::utils::inline_test_harness::{InlineMemoryLayout, InlineTestHarness}; - -pub type ChainingValue = [u32; crate::CHAINING_VALUE_LEN]; -pub type MessageBlock = [u32; crate::MSG_BLOCK_LEN]; - -pub fn create_blake3_harness() -> InlineTestHarness { - // Blake3 needs message block (64 bytes) + params (16 bytes) contiguous at rs2 - // and state (32 bytes) at rs1 - let layout = InlineMemoryLayout::single_input(80, 32); // 80 bytes for message+params, 32-byte state - InlineTestHarness::new(layout, Xlen::Bit64) -} - -/// Create harness for Keyed64 instruction (Merkle tree merge) -/// ABI: rs1 = left, rs2 = right, rd = iv (in/out) -pub fn create_blake3_keyed64_harness() -> InlineTestHarness { - // Keyed64 needs: - // - rs1: left CV (32 bytes) -> input - // - rs2: right CV (32 bytes) -> input2 - // - rd: IV (32 bytes, in/out) -> output - let layout = InlineMemoryLayout::two_inputs(32, 32, 32); // left, right, iv - InlineTestHarness::new(layout, Xlen::Bit64) -} - -pub fn load_blake3_keyed64_data( - harness: &mut InlineTestHarness, - left: &ChainingValue, - right: &ChainingValue, - iv: &ChainingValue, -) { - harness.setup_registers(); - // Load left to rs1 location (input) - harness.load_input32(left); - // Load right to rs2 location (input2) - harness.load_input2_32(right); - // Load IV to rd/rs3 location (output) - harness.load_state32(iv); -} - -pub fn keyed64_instruction() -> tracer::instruction::inline::INLINE { - InlineTestHarness::create_default_instruction( - INLINE_OPCODE, - BLAKE3_KEYED64_FUNCT3, - BLAKE3_FUNCT7, - ) -} - -pub fn load_blake3_data( - harness: &mut InlineTestHarness, - chaining_value: &ChainingValue, - message: &MessageBlock, - counter: &[u32; 2], - block_len: u32, - flags: u32, -) { - harness.setup_registers(); - // Load chaining value to output location (rs1 points here) - harness.load_state32(chaining_value); - - // Blake3 expects message + parameters contiguously at rs2 - // Create combined input: message (16 u32s) + counter (2 u32s) + block_len (1 u32) + flags (1 u32) - let mut combined_input = Vec::with_capacity(20); - combined_input.extend_from_slice(message); - combined_input.extend_from_slice(counter); - combined_input.push(block_len); - combined_input.push(flags); - - // Load the combined input - harness.load_input32(&combined_input); -} - -pub fn read_output(harness: &mut InlineTestHarness) -> ChainingValue { - let vec = harness.read_output32(crate::CHAINING_VALUE_LEN); - let mut output = [0u32; crate::CHAINING_VALUE_LEN]; - output.copy_from_slice(&vec); - output -} - -pub fn instruction() -> tracer::instruction::inline::INLINE { - InlineTestHarness::create_default_instruction(INLINE_OPCODE, BLAKE3_FUNCT3, BLAKE3_FUNCT7) -} - -#[cfg(test)] -pub mod helpers { - pub fn generate_random_bytes(len: usize) -> Vec { - use rand::rngs::StdRng; - use rand::{RngCore, SeedableRng}; - - let mut buf = vec![0u8; len]; - // Use a fixed seed for deterministic test results - let mut rng = StdRng::seed_from_u64(12345); - rng.fill_bytes(&mut buf); - buf - } - - pub fn bytes_to_u32_vec(bytes: &[u8]) -> Vec { - bytes - .chunks_exact(4) - .map(|chunk| u32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]])) - .collect() - } - - pub fn compute_expected_result(input: &[u8]) -> [u8; crate::OUTPUT_SIZE_IN_BYTES] { - blake3::hash(input).as_bytes()[0..crate::OUTPUT_SIZE_IN_BYTES] - .try_into() - .unwrap() - } - - pub fn compute_keyed_expected_result( - input: &[u8], - key: [u32; crate::CHAINING_VALUE_LEN], - ) -> [u8; crate::OUTPUT_SIZE_IN_BYTES] { - let mut key_bytes = [0u8; 32]; - for (i, word) in key.iter().enumerate() { - key_bytes[i * 4..(i + 1) * 4].copy_from_slice(&word.to_le_bytes()); - } - blake3::keyed_hash(&key_bytes, input).as_bytes()[0..crate::OUTPUT_SIZE_IN_BYTES] - .try_into() - .unwrap() - } -} diff --git a/jolt-inlines/grumpkin/src/sdk.rs b/jolt-inlines/grumpkin/src/sdk.rs index 8fcf84b697..5516fb8ea4 100644 --- a/jolt-inlines/grumpkin/src/sdk.rs +++ b/jolt-inlines/grumpkin/src/sdk.rs @@ -2,11 +2,10 @@ use ark_ff::{AdditiveGroup, BigInt, Field, PrimeField, Zero}; use ark_grumpkin::{Fq, Fr}; +use core::marker::PhantomData; use serde::{Deserialize, Serialize}; -/// Returns `true` iff `x >= p` (Fq modulus), i.e., `x` is non-canonical. -/// Manually unrolled for performance. #[cfg(all( not(feature = "host"), any(target_arch = "riscv32", target_arch = "riscv64") @@ -30,8 +29,6 @@ fn is_fq_non_canonical(x: &[u64; 4]) -> bool { } } -/// Returns `true` iff `x >= n` (Fr modulus), i.e., `x` is non-canonical. -/// Manually unrolled for performance. #[cfg(all( not(feature = "host"), any(target_arch = "riscv32", target_arch = "riscv64") @@ -74,182 +71,422 @@ pub enum GrumpkinError { NotOnCurve, } -/// Wrapper around ark_grumpkin::Fq with inline-accelerated division -#[derive(Clone, PartialEq, Debug)] -pub struct GrumpkinFq { - e: ark_grumpkin::Fq, +/// Both `Fq` and `Fr` are `Fp`, so we can access `.0` (BigInt<4>) +/// and `.0.0` ([u64; 4]) uniformly. These helpers bridge the concrete `Fp` layout +/// through the `PrimeField` trait bound. +pub trait GrumpkinFieldConfig: 'static { + type ArkField: PrimeField + Field + Copy; + const DIV_FUNCT3: u32; + + fn from_bigint(repr: BigInt<4>) -> Option; + fn new_unchecked(repr: BigInt<4>) -> Self::ArkField; + /// Borrow the inner Montgomery limbs without copying. + fn limbs(e: &Self::ArkField) -> &[u64; 4]; + fn limbs_mut(e: &mut Self::ArkField) -> &mut [u64; 4]; + + #[cfg(all( + not(feature = "host"), + any(target_arch = "riscv32", target_arch = "riscv64") + ))] + fn is_non_canonical(limbs: &[u64; 4]) -> bool; + + fn invalid_element_error() -> GrumpkinError; } -impl GrumpkinFq { +pub enum GrumpkinFqConfig {} + +impl GrumpkinFieldConfig for GrumpkinFqConfig { + type ArkField = Fq; + const DIV_FUNCT3: u32 = crate::GRUMPKIN_DIVQ_ADV_FUNCT3; + #[inline(always)] - pub fn new(e: Fq) -> Self { - GrumpkinFq { e } + fn from_bigint(repr: BigInt<4>) -> Option { + Fq::from_bigint(repr) } - /// Converts from standard form to Montgomery. Returns error if >= modulus. #[inline(always)] - pub fn from_u64_arr(arr: &[u64; 4]) -> Result { - Fq::from_bigint(BigInt(*arr)) - .map(|e| GrumpkinFq { e }) - .ok_or(GrumpkinError::InvalidFqElement) + fn new_unchecked(repr: BigInt<4>) -> Fq { + Fq::new_unchecked(repr) } - /// SAFETY: input must be in canonical Montgomery form #[inline(always)] - pub fn from_u64_arr_unchecked(arr: &[u64; 4]) -> Self { - GrumpkinFq { - e: Fq::new_unchecked(BigInt(*arr)), + fn limbs(e: &Fq) -> &[u64; 4] { + &e.0 .0 + } + #[inline(always)] + fn limbs_mut(e: &mut Fq) -> &mut [u64; 4] { + &mut e.0 .0 + } + + #[cfg(all( + not(feature = "host"), + any(target_arch = "riscv32", target_arch = "riscv64") + ))] + #[inline(always)] + fn is_non_canonical(limbs: &[u64; 4]) -> bool { + is_fq_non_canonical(limbs) + } + + #[inline(always)] + fn invalid_element_error() -> GrumpkinError { + GrumpkinError::InvalidFqElement + } +} + +pub enum GrumpkinFrConfig {} + +impl GrumpkinFieldConfig for GrumpkinFrConfig { + type ArkField = Fr; + const DIV_FUNCT3: u32 = crate::GRUMPKIN_DIVR_ADV_FUNCT3; + + #[inline(always)] + fn from_bigint(repr: BigInt<4>) -> Option { + Fr::from_bigint(repr) + } + #[inline(always)] + fn new_unchecked(repr: BigInt<4>) -> Fr { + Fr::new_unchecked(repr) + } + #[inline(always)] + fn limbs(e: &Fr) -> &[u64; 4] { + &e.0 .0 + } + #[inline(always)] + fn limbs_mut(e: &mut Fr) -> &mut [u64; 4] { + &mut e.0 .0 + } + + #[cfg(all( + not(feature = "host"), + any(target_arch = "riscv32", target_arch = "riscv64") + ))] + #[inline(always)] + fn is_non_canonical(limbs: &[u64; 4]) -> bool { + is_fr_non_canonical(limbs) + } + + #[inline(always)] + fn invalid_element_error() -> GrumpkinError { + GrumpkinError::InvalidFrElement + } +} + +pub struct GrumpkinField { + e: C::ArkField, + _phantom: PhantomData, +} + +pub type GrumpkinFq = GrumpkinField; +pub type GrumpkinFr = GrumpkinField; + +impl Clone for GrumpkinField { + #[inline(always)] + fn clone(&self) -> Self { + Self { + e: self.e, + _phantom: PhantomData, } } +} + +impl PartialEq for GrumpkinField { #[inline(always)] - pub fn fq(&self) -> Fq { - self.e + fn eq(&self, other: &Self) -> bool { + self.e == other.e + } +} + +impl core::fmt::Debug for GrumpkinField { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + f.debug_struct("GrumpkinField").field("e", &self.e).finish() + } +} + +impl GrumpkinField { + #[inline(always)] + pub fn from_u64_arr(arr: &[u64; 4]) -> Result { + C::from_bigint(BigInt(*arr)) + .map(|e| Self { + e, + _phantom: PhantomData, + }) + .ok_or(C::invalid_element_error()) } + + /// SAFETY: input must be in canonical Montgomery form #[inline(always)] - pub fn zero() -> Self { - GrumpkinFq { e: Fq::zero() } + pub fn from_u64_arr_unchecked(arr: &[u64; 4]) -> Self { + Self { + e: C::new_unchecked(BigInt(*arr)), + _phantom: PhantomData, + } } - /// Precomputed -17 for curve equation y² = x³ - 17 + #[inline(always)] - pub fn negative_seventeen() -> Self { - GrumpkinFq { - e: Fq::new_unchecked(BigInt([ - 0xdd7056026000005a, - 0x223fa97acb319311, - 0xcc388229877910c0, - 0x34394632b724eaa, - ])), + pub fn zero() -> Self { + Self { + e: C::ArkField::zero(), + _phantom: PhantomData, } } + #[inline(always)] pub fn is_zero(&self) -> bool { self.e.is_zero() } + #[inline(always)] pub fn neg(&self) -> Self { - GrumpkinFq { e: -self.e } + Self { + e: -self.e, + _phantom: PhantomData, + } } + #[inline(always)] - pub fn add(&self, other: &GrumpkinFq) -> Self { - GrumpkinFq { + pub fn add(&self, other: &Self) -> Self { + Self { e: self.e + other.e, + _phantom: PhantomData, } } + #[inline(always)] - pub fn sub(&self, other: &GrumpkinFq) -> Self { - GrumpkinFq { + pub fn sub(&self, other: &Self) -> Self { + Self { e: self.e - other.e, + _phantom: PhantomData, } } + #[inline(always)] pub fn dbl(&self) -> Self { - GrumpkinFq { e: self.e.double() } + Self { + e: self.e.double(), + _phantom: PhantomData, + } } + #[inline(always)] pub fn tpl(&self) -> Self { - GrumpkinFq { + Self { e: self.e.double() + self.e, + _phantom: PhantomData, } } + #[inline(always)] - pub fn mul(&self, other: &GrumpkinFq) -> Self { - GrumpkinFq { + pub fn mul(&self, other: &Self) -> Self { + Self { e: self.e * other.e, + _phantom: PhantomData, } } + #[inline(always)] pub fn square(&self) -> Self { - GrumpkinFq { e: self.e.square() } + Self { + e: self.e.square(), + _phantom: PhantomData, + } } + /// SAFETY: caller must ensure other != 0 #[cfg(all( not(feature = "host"), any(target_arch = "riscv32", target_arch = "riscv64") ))] #[inline(always)] - pub fn div_assume_nonzero(&self, other: &GrumpkinFq) -> Self { - let mut c = GrumpkinFq::zero(); + pub fn div_assume_nonzero(&self, other: &Self) -> Self { + let mut c = Self::zero(); unsafe { - use crate::{GRUMPKIN_DIVQ_ADV_FUNCT3, GRUMPKIN_FUNCT7, INLINE_OPCODE}; + use crate::{GRUMPKIN_FUNCT7, INLINE_OPCODE}; core::arch::asm!( ".insn r {opcode}, {funct3}, {funct7}, {rd}, {rs1}, {rs2}", opcode = const INLINE_OPCODE, - funct3 = const GRUMPKIN_DIVQ_ADV_FUNCT3, + funct3 = const C::DIV_FUNCT3, funct7 = const GRUMPKIN_FUNCT7, - rd = in(reg) c.e.0.0.as_mut_ptr(), - rs1 = in(reg) self.e.0.0.as_ptr(), - rs2 = in(reg) other.e.0.0.as_ptr(), + rd = in(reg) C::limbs_mut(&mut c.e).as_mut_ptr(), + rs1 = in(reg) C::limbs(&self.e).as_ptr(), + rs2 = in(reg) C::limbs(&other.e).as_ptr(), options(nostack) ); } let tmp = other.mul(&c); - // Verify advice: c must be canonical and other * c == self - if is_fq_non_canonical(&c.e.0 .0) || is_not_equal(&tmp.e.0 .0, &self.e.0 .0) { - spoil_proof(); // Spoils proof - assert_eq! alone doesn't suffice + if C::is_non_canonical(C::limbs(&c.e)) || is_not_equal(C::limbs(&tmp.e), C::limbs(&self.e)) + { + spoil_proof(); } c } + #[cfg(all( not(feature = "host"), any(target_arch = "riscv32", target_arch = "riscv64") ))] #[inline(always)] - pub fn div(&self, other: &GrumpkinFq) -> Self { + pub fn div(&self, other: &Self) -> Self { if other.is_zero() { spoil_proof(); } self.div_assume_nonzero(other) } + #[cfg(all( not(feature = "host"), not(any(target_arch = "riscv32", target_arch = "riscv64")) ))] - pub fn div_assume_nonzero(&self, _other: &GrumpkinFq) -> Self { - panic!("GrumpkinFq::div_assume_nonzero called on non-RISC-V target without host feature"); + pub fn div_assume_nonzero(&self, _other: &Self) -> Self { + panic!( + "GrumpkinField::div_assume_nonzero called on non-RISC-V target without host feature" + ); } + #[cfg(all( not(feature = "host"), not(any(target_arch = "riscv32", target_arch = "riscv64")) ))] - pub fn div(&self, _other: &GrumpkinFq) -> Self { - panic!("GrumpkinFq::div called on non-RISC-V target without host feature"); + pub fn div(&self, _other: &Self) -> Self { + panic!("GrumpkinField::div called on non-RISC-V target without host feature"); } + #[cfg(feature = "host")] #[inline(always)] - pub fn div_assume_nonzero(&self, other: &GrumpkinFq) -> Self { - GrumpkinFq { + pub fn div_assume_nonzero(&self, other: &Self) -> Self { + Self { e: self.e / other.e, + _phantom: PhantomData, } } + #[cfg(feature = "host")] #[inline(always)] - pub fn div(&self, other: &GrumpkinFq) -> Self { + pub fn div(&self, other: &Self) -> Self { if other.is_zero() { - panic!("division by zero in GrumpkinFq::div"); + panic!("division by zero in GrumpkinField::div"); } self.div_assume_nonzero(other) } } -impl ECField for GrumpkinFq { - type Error = GrumpkinError; +// Operator impls for &GrumpkinField + +impl core::ops::Add<&GrumpkinField> for &GrumpkinField { + type Output = GrumpkinField; #[inline(always)] - fn zero() -> Self { - Self::zero() + fn add(self, rhs: &GrumpkinField) -> GrumpkinField { + GrumpkinField { + e: self.e + rhs.e, + _phantom: PhantomData, + } } +} + +impl core::ops::Sub<&GrumpkinField> for &GrumpkinField { + type Output = GrumpkinField; #[inline(always)] - fn is_zero(&self) -> bool { - self.is_zero() + fn sub(self, rhs: &GrumpkinField) -> GrumpkinField { + GrumpkinField { + e: self.e - rhs.e, + _phantom: PhantomData, + } } +} + +impl core::ops::Mul<&GrumpkinField> for &GrumpkinField { + type Output = GrumpkinField; #[inline(always)] - fn neg(&self) -> Self { - self.neg() + fn mul(self, rhs: &GrumpkinField) -> GrumpkinField { + GrumpkinField { + e: self.e * rhs.e, + _phantom: PhantomData, + } } +} + +impl core::ops::Neg for &GrumpkinField { + type Output = GrumpkinField; #[inline(always)] - fn add(&self, other: &Self) -> Self { - self.add(other) + fn neg(self) -> GrumpkinField { + GrumpkinField { + e: -self.e, + _phantom: PhantomData, + } } +} + +// Operator impls for GrumpkinField (delegate to &self) + +impl core::ops::Add<&GrumpkinField> for GrumpkinField { + type Output = GrumpkinField; #[inline(always)] - fn sub(&self, other: &Self) -> Self { - self.sub(other) + fn add(self, rhs: &GrumpkinField) -> GrumpkinField { + &self + rhs + } +} + +impl core::ops::Sub<&GrumpkinField> for GrumpkinField { + type Output = GrumpkinField; + #[inline(always)] + fn sub(self, rhs: &GrumpkinField) -> GrumpkinField { + &self - rhs + } +} + +impl core::ops::Mul<&GrumpkinField> for GrumpkinField { + type Output = GrumpkinField; + #[inline(always)] + fn mul(self, rhs: &GrumpkinField) -> GrumpkinField { + &self * rhs + } +} + +impl core::ops::Neg for GrumpkinField { + type Output = GrumpkinField; + #[inline(always)] + fn neg(self) -> GrumpkinField { + -&self + } +} + +// Fq-specific methods + +impl GrumpkinField { + #[inline(always)] + pub fn new(e: Fq) -> Self { + Self { + e, + _phantom: PhantomData, + } + } + + #[inline(always)] + pub fn fq(&self) -> Fq { + self.e + } + + /// Precomputed -17 for curve equation y^2 = x^3 - 17 + #[inline(always)] + pub fn negative_seventeen() -> Self { + Self { + e: Fq::new_unchecked(BigInt([ + 0xdd7056026000005a, + 0x223fa97acb319311, + 0xcc388229877910c0, + 0x34394632b724eaa, + ])), + _phantom: PhantomData, + } + } +} + +impl ECField for GrumpkinFq { + type Error = GrumpkinError; + #[inline(always)] + fn zero() -> Self { + Self::zero() + } + #[inline(always)] + fn is_zero(&self) -> bool { + self.is_zero() } #[inline(always)] fn dbl(&self) -> Self { @@ -260,10 +497,6 @@ impl ECField for GrumpkinFq { self.tpl() } #[inline(always)] - fn mul(&self, other: &Self) -> Self { - self.mul(other) - } - #[inline(always)] fn square(&self) -> Self { self.square() } @@ -277,7 +510,7 @@ impl ECField for GrumpkinFq { } #[inline(always)] fn to_u64_arr(&self) -> [u64; 4] { - self.e.0 .0 + *GrumpkinFqConfig::limbs(&self.e) } #[inline(always)] fn from_u64_arr(arr: &[u64; 4]) -> Result { @@ -289,139 +522,25 @@ impl ECField for GrumpkinFq { } } -/// Wrapper around ark_grumpkin::Fr with inline-accelerated division -#[derive(Clone, PartialEq, Debug)] -pub struct GrumpkinFr { - e: ark_grumpkin::Fr, -} +// Fr-specific methods -impl GrumpkinFr { +impl GrumpkinField { #[inline(always)] pub fn new(e: Fr) -> Self { - GrumpkinFr { e } - } - /// Converts from standard form to Montgomery. Returns error if >= modulus. - #[inline(always)] - pub fn from_u64_arr(arr: &[u64; 4]) -> Result { - Fr::from_bigint(BigInt(*arr)) - .map(|e| GrumpkinFr { e }) - .ok_or(GrumpkinError::InvalidFrElement) - } - /// SAFETY: input must be in canonical Montgomery form - #[inline(always)] - pub fn from_u64_arr_unchecked(arr: &[u64; 4]) -> Self { - GrumpkinFr { - e: Fr::new_unchecked(BigInt(*arr)), + Self { + e, + _phantom: PhantomData, } } + #[inline(always)] pub fn fr(&self) -> Fr { self.e } - #[inline(always)] - pub fn zero() -> Self { - GrumpkinFr { e: Fr::zero() } - } - #[inline(always)] - pub fn is_zero(&self) -> bool { - self.e.is_zero() - } - #[inline(always)] - pub fn neg(&self) -> Self { - GrumpkinFr { e: -self.e } - } - #[inline(always)] - pub fn add(&self, other: &GrumpkinFr) -> Self { - GrumpkinFr { - e: self.e + other.e, - } - } - #[inline(always)] - pub fn sub(&self, other: &GrumpkinFr) -> Self { - GrumpkinFr { - e: self.e - other.e, - } - } - #[inline(always)] - pub fn mul(&self, other: &GrumpkinFr) -> Self { - GrumpkinFr { - e: self.e * other.e, - } - } - #[inline(always)] - pub fn square(&self) -> Self { - GrumpkinFr { e: self.e.square() } - } - /// SAFETY: caller must ensure other != 0 - #[cfg(all( - not(feature = "host"), - any(target_arch = "riscv32", target_arch = "riscv64") - ))] - #[inline(always)] - pub fn div_assume_nonzero(&self, other: &GrumpkinFr) -> Self { - let mut c = GrumpkinFr::zero(); - unsafe { - use crate::{GRUMPKIN_DIVR_ADV_FUNCT3, GRUMPKIN_FUNCT7, INLINE_OPCODE}; - core::arch::asm!( - ".insn r {opcode}, {funct3}, {funct7}, {rd}, {rs1}, {rs2}", - opcode = const INLINE_OPCODE, - funct3 = const GRUMPKIN_DIVR_ADV_FUNCT3, - funct7 = const GRUMPKIN_FUNCT7, - rd = in(reg) c.e.0.0.as_mut_ptr(), - rs1 = in(reg) self.e.0.0.as_ptr(), - rs2 = in(reg) other.e.0.0.as_ptr(), - options(nostack) - ); - } - let tmp = other.mul(&c); - // Verify advice: c must be canonical and other * c == self - if is_fr_non_canonical(&c.e.0 .0) || is_not_equal(&tmp.e.0 .0, &self.e.0 .0) { - spoil_proof(); // Spoils proof - assert_eq! alone doesn't suffice - } - c - } - #[cfg(all( - not(feature = "host"), - any(target_arch = "riscv32", target_arch = "riscv64") - ))] - #[inline(always)] - pub fn div(&self, other: &GrumpkinFr) -> Self { - if other.is_zero() { - spoil_proof(); - } - self.div_assume_nonzero(other) - } - #[cfg(all( - not(feature = "host"), - not(any(target_arch = "riscv32", target_arch = "riscv64")) - ))] - pub fn div_assume_nonzero(&self, _other: &GrumpkinFr) -> Self { - panic!("GrumpkinFr::div_assume_nonzero called on non-RISC-V target without host feature"); - } - #[cfg(all( - not(feature = "host"), - not(any(target_arch = "riscv32", target_arch = "riscv64")) - ))] - pub fn div(&self, _other: &GrumpkinFr) -> Self { - panic!("GrumpkinFr::div called on non-RISC-V target without host feature"); - } - #[cfg(feature = "host")] - #[inline(always)] - pub fn div_assume_nonzero(&self, other: &GrumpkinFr) -> Self { - GrumpkinFr { - e: self.e / other.e, - } - } - #[cfg(feature = "host")] - #[inline(always)] - pub fn div(&self, other: &GrumpkinFr) -> Self { - if other.is_zero() { - panic!("division by zero in GrumpkinFr::div"); - } - self.div_assume_nonzero(other) - } } +// Curve definition + #[derive(Clone)] pub struct GrumpkinCurve; diff --git a/jolt-inlines/grumpkin/src/sequence_builder.rs b/jolt-inlines/grumpkin/src/sequence_builder.rs index b44002d06f..8628122cdc 100644 --- a/jolt-inlines/grumpkin/src/sequence_builder.rs +++ b/jolt-inlines/grumpkin/src/sequence_builder.rs @@ -61,8 +61,7 @@ impl GrumpkinDivAdv { .to_vec(), ) } - // inline sequence function - fn inline_sequence(mut self) -> Vec { + fn build(mut self) -> Vec { for i in 0..4 { self.asm.emit_j::(*self.vr, 0); self.asm @@ -80,17 +79,14 @@ impl InlineOp for GrumpkinDivQAdv { const FUNCT3: u32 = crate::GRUMPKIN_DIVQ_ADV_FUNCT3; const FUNCT7: u32 = crate::GRUMPKIN_FUNCT7; const NAME: &'static str = crate::GRUMPKIN_DIVQ_ADV_NAME; + type Advice = VecDeque; fn build_sequence(asm: InstrAssembler, operands: FormatInline) -> Vec { - GrumpkinDivAdv::new(asm, operands, true).inline_sequence() + GrumpkinDivAdv::new(asm, operands, true).build() } - fn build_advice( - asm: InstrAssembler, - operands: FormatInline, - cpu: &mut Cpu, - ) -> Option> { - Some(GrumpkinDivAdv::new(asm, operands, true).advice(cpu)) + fn build_advice(asm: InstrAssembler, operands: FormatInline, cpu: &mut Cpu) -> VecDeque { + GrumpkinDivAdv::new(asm, operands, true).advice(cpu) } } @@ -101,16 +97,13 @@ impl InlineOp for GrumpkinDivRAdv { const FUNCT3: u32 = crate::GRUMPKIN_DIVR_ADV_FUNCT3; const FUNCT7: u32 = crate::GRUMPKIN_FUNCT7; const NAME: &'static str = crate::GRUMPKIN_DIVR_ADV_NAME; + type Advice = VecDeque; fn build_sequence(asm: InstrAssembler, operands: FormatInline) -> Vec { - GrumpkinDivAdv::new(asm, operands, false).inline_sequence() + GrumpkinDivAdv::new(asm, operands, false).build() } - fn build_advice( - asm: InstrAssembler, - operands: FormatInline, - cpu: &mut Cpu, - ) -> Option> { - Some(GrumpkinDivAdv::new(asm, operands, false).advice(cpu)) + fn build_advice(asm: InstrAssembler, operands: FormatInline, cpu: &mut Cpu) -> VecDeque { + GrumpkinDivAdv::new(asm, operands, false).advice(cpu) } } diff --git a/jolt-inlines/keccak256/src/lib.rs b/jolt-inlines/keccak256/src/lib.rs index 1a8c86dc87..0a305b761b 100644 --- a/jolt-inlines/keccak256/src/lib.rs +++ b/jolt-inlines/keccak256/src/lib.rs @@ -14,10 +14,10 @@ pub type Keccak256State = [u64; NUM_LANES]; pub mod sdk; pub use sdk::*; -#[cfg(feature = "host")] -pub mod exec; #[cfg(feature = "host")] pub mod sequence_builder; +#[cfg(feature = "host")] +pub mod spec; #[cfg(feature = "host")] mod host; diff --git a/jolt-inlines/keccak256/src/sdk.rs b/jolt-inlines/keccak256/src/sdk.rs index 6e15ee93c3..cf1cc4b5d5 100644 --- a/jolt-inlines/keccak256/src/sdk.rs +++ b/jolt-inlines/keccak256/src/sdk.rs @@ -344,13 +344,11 @@ pub(crate) unsafe fn keccak_f(state: *mut u64) { /// * Passing an invalid pointer, misaligned pointer, or insufficiently sized /// memory region results in undefined behaviour. pub(crate) unsafe fn keccak_f(state: *mut u64) { - // On the host, we call our own reference implementation from the tracer crate. - let state_slice = core::slice::from_raw_parts_mut(state, 25); - crate::exec::execute_keccak_f( - state_slice - .try_into() - .expect("State slice was not 25 words"), - ); + use crate::sequence_builder::Keccak256Permutation; + use jolt_inlines_sdk::spec::InlineSpec; + let state_ref = &*(state as *const [u64; 25]); + let result = Keccak256Permutation::reference(state_ref); + core::ptr::copy_nonoverlapping(result.as_ptr(), state, 25); } #[cfg(all( diff --git a/jolt-inlines/keccak256/src/sequence_builder.rs b/jolt-inlines/keccak256/src/sequence_builder.rs index 1cd278cbcd..25f4f94ef8 100644 --- a/jolt-inlines/keccak256/src/sequence_builder.rs +++ b/jolt-inlines/keccak256/src/sequence_builder.rs @@ -70,30 +70,6 @@ struct Keccak256SequenceBuilder { operands: FormatInline, } -/// `Keccak256SequenceBuilder` is a helper struct for constructing the virtual instruction -/// sequence required to emulate the Keccak-256 hashing operation within the RISC-V -/// instruction set. This builder is responsible for generating the correct sequence of -/// `Instruction` instances that together perform the Keccak-256 permutation and -/// hashing steps, using a set of virtual registers to hold intermediate state. -/// -/// # Fields -/// - `address`: The starting program counter address for the sequence. -/// - `asm`: Builder for the vector of generated instructions representing the Keccak-256 operation. -/// - `round`: The current round of the Keccak permutation (0..24). -/// - `vr`: An array of virtual register indices used for state and temporary values. -/// - `operand_rs1`: The source register index for the first operand (input state pointer). -/// - `operand_rs2`: Unused. -/// -/// # Usage -/// Typically, you construct a `Keccak256SequenceBuilder` with the required register mapping -/// and operands, then call `.build()` to obtain the full instruction sequence for the -/// Keccak-256 operation. This is used to inline the Keccak-256 hash logic into the -/// RISC-V instruction stream for tracing or emulation purposes. -/// -/// # Note -/// The actual Keccak-256 logic is implemented in the `build` method, which generates -/// the appropriate instruction sequence. This struct is not intended for direct execution, -/// but rather for constructing instruction traces or emulation flows. impl Keccak256SequenceBuilder { fn new(asm: InstrAssembler, operands: FormatInline) -> Self { let vr = array::from_fn(|_| asm.allocator.allocate_for_inline()); @@ -258,6 +234,7 @@ impl InlineOp for Keccak256Permutation { const FUNCT3: u32 = crate::KECCAK256_FUNCT3; const FUNCT7: u32 = crate::KECCAK256_FUNCT7; const NAME: &'static str = crate::KECCAK256_NAME; + type Advice = (); fn build_sequence(asm: InstrAssembler, operands: FormatInline) -> Vec { Keccak256SequenceBuilder::new(asm, operands).build() diff --git a/jolt-inlines/keccak256/src/exec.rs b/jolt-inlines/keccak256/src/spec.rs similarity index 53% rename from jolt-inlines/keccak256/src/exec.rs rename to jolt-inlines/keccak256/src/spec.rs index 6c830286bb..b55cc75d05 100644 --- a/jolt-inlines/keccak256/src/exec.rs +++ b/jolt-inlines/keccak256/src/spec.rs @@ -1,19 +1,18 @@ -use crate::sequence_builder::{ROTATION_OFFSETS, ROUND_CONSTANTS}; +use crate::sequence_builder::{Keccak256Permutation, ROTATION_OFFSETS, ROUND_CONSTANTS}; use crate::{Keccak256State, NUM_LANES}; +use jolt_inlines_sdk::host::Xlen; +use jolt_inlines_sdk::spec::rand::rngs::StdRng; +use jolt_inlines_sdk::spec::rand::Rng; +use jolt_inlines_sdk::spec::{InlineMemoryLayout, InlineSpec, InlineTestHarness}; -// Host-side Keccak-256 implementation for reference and testing. #[cfg(all(test, feature = "host"))] pub(crate) fn execute_keccak256(msg: &[u8]) -> [u8; 32] { - // Keccak-256 parameters. - const RATE_IN_BYTES: usize = 136; // 1088-bit rate + const RATE_IN_BYTES: usize = 136; - // NUM_LANES × 64-bit state lanes initialised to zero. let mut state = [0u64; NUM_LANES]; - // 1. Absorb full RATE blocks. let mut offset = 0; while offset + RATE_IN_BYTES <= msg.len() { - // XOR message block into the state. for (i, lane_bytes) in msg[offset..offset + RATE_IN_BYTES] .chunks_exact(8) .enumerate() @@ -21,36 +20,29 @@ pub(crate) fn execute_keccak256(msg: &[u8]) -> [u8; 32] { state[i] ^= u64::from_le_bytes(lane_bytes.try_into().unwrap()); } - // Apply the Keccak-f permutation after each full block. - execute_keccak_f(&mut state); + state = Keccak256Permutation::reference(&state); offset += RATE_IN_BYTES; } - // 2. Absorb the final (possibly empty) partial block with padding. let mut block = [0u8; RATE_IN_BYTES]; let remaining = &msg[offset..]; block[..remaining.len()].copy_from_slice(remaining); - // Domain separation / padding (Keccak: 0x01 .. 0x80). - block[remaining.len()] ^= 0x01; // 0x01 delimiter after the message. - block[RATE_IN_BYTES - 1] ^= 0x80; // Final bit of padding. + block[remaining.len()] ^= 0x01; + block[RATE_IN_BYTES - 1] ^= 0x80; - // XOR padded block into the state and permute once more. for (i, lane_bytes) in block.chunks_exact(8).enumerate() { state[i] ^= u64::from_le_bytes(lane_bytes.try_into().unwrap()); } - execute_keccak_f(&mut state); + state = Keccak256Permutation::reference(&state); - // 3. Squeeze the first 32 bytes of the state as the hash output. let mut hash = [0u8; 32]; for (i, lane) in state.iter().take(4).enumerate() { - // 4 lanes * 8 bytes/lane = 32 bytes hash[i * 8..(i + 1) * 8].copy_from_slice(&lane.to_le_bytes()); } hash } -/// Executes the 24-round Keccak-f[1600] permutation. pub(crate) fn execute_keccak_f(state: &mut Keccak256State) { for rc in ROUND_CONSTANTS { execute_theta(state); @@ -60,20 +52,15 @@ pub(crate) fn execute_keccak_f(state: &mut Keccak256State) { } } -/// The `theta` step of the Keccak-f permutation mixes columns to provide diffusion. -/// This step XORs each bit in the state with the parities of two columns in the state array. pub(crate) fn execute_theta(state: &mut Keccak256State) { - // 1. Compute the parity of each of the 5 columns (an array `C` of 5 lanes). let mut c = [0u64; 5]; for x in 0..5 { c[x] = state[x] ^ state[x + 5] ^ state[x + 10] ^ state[x + 15] ^ state[x + 20]; } - // 2. Compute `D[x] = C[x-1] ^ rotl64(C[x+1], 1)` let mut d = [0u64; 5]; for x in 0..5 { d[x] = c[(x + 4) % 5] ^ c[(x + 1) % 5].rotate_left(1); } - // 3. XOR `D[x]` into each lane in column `x`. for x in 0..5 { for y in 0..5 { state[x + 5 * y] ^= d[x]; @@ -81,25 +68,19 @@ pub(crate) fn execute_theta(state: &mut Keccak256State) { } } -/// The `rho` and `pi` steps of the Keccak-f permutation shuffles the state to provide diffusion. -/// `rho` rotates each lane by a different fixed offset. `pi` permutes positions of the lanes. pub(crate) fn execute_rho_and_pi(state: &mut Keccak256State) { let mut b = [0u64; NUM_LANES]; for x in 0..5 { for y in 0..5 { let nx = y; let ny = (2 * x + 3 * y) % 5; - // Definitely [x][y] here. That behavior allows the test to pass. b[nx + 5 * ny] = state[x + 5 * y].rotate_left(ROTATION_OFFSETS[x][y]); } } state.copy_from_slice(&b); } -/// The `chi` step of the Keccak-f permutation introduces non-linearity (relationships between input and output). pub(crate) fn execute_chi(state: &mut Keccak256State) { - // For each row, it updates each lane as: lane[x] ^= (~lane[x+1] & lane[x+2]) - // This ensures output bit is a non-linear function of three input bits. for y in 0..5 { let mut row = [0u64; 5]; for x in 0..5 { @@ -111,7 +92,36 @@ pub(crate) fn execute_chi(state: &mut Keccak256State) { } } -/// The `iota` step of Keccak-f breaks the symmetry of the rounds by injecting a round constant into the first lane. pub(crate) fn execute_iota(state: &mut Keccak256State, round_constant: u64) { - state[0] ^= round_constant; // Inject round constant. + state[0] ^= round_constant; +} + +impl InlineSpec for Keccak256Permutation { + type Input = [u64; 25]; + type Output = [u64; 25]; + + fn random_input(rng: &mut StdRng) -> Self::Input { + core::array::from_fn(|_| rng.gen()) + } + + fn reference(input: &Self::Input) -> Self::Output { + let mut state = *input; + execute_keccak_f(&mut state); + state + } + + fn create_harness() -> InlineTestHarness { + let layout = InlineMemoryLayout::single_input(136, 200); + InlineTestHarness::new(layout, Xlen::Bit64) + } + + fn load(harness: &mut InlineTestHarness, input: &Self::Input) { + harness.setup_registers(); + harness.load_state64(input); + } + + fn read(harness: &mut InlineTestHarness) -> Self::Output { + let vec = harness.read_output64(25); + vec.try_into().unwrap() + } } diff --git a/jolt-inlines/keccak256/src/test_utils.rs b/jolt-inlines/keccak256/src/test_utils.rs index adf3f2d35e..8762a40fcf 100644 --- a/jolt-inlines/keccak256/src/test_utils.rs +++ b/jolt-inlines/keccak256/src/test_utils.rs @@ -1,50 +1,26 @@ -use crate::exec::execute_keccak_f; +use crate::sequence_builder::Keccak256Permutation; use crate::test_constants::{self, TestVectors}; use crate::Keccak256State; -use tracer::emulator::cpu::Xlen; -use tracer::instruction::inline::INLINE; -use tracer::utils::inline_test_harness::{InlineMemoryLayout, InlineTestHarness}; +use jolt_inlines_sdk::spec::InlineSpec; -/// Simple test case structure for Keccak tests pub struct KeccakTestCase { pub input: Keccak256State, pub expected: Keccak256State, pub description: &'static str, } -pub fn create_keccak_harness(xlen: Xlen) -> InlineTestHarness { - let layout = InlineMemoryLayout::single_input(136, 200); - InlineTestHarness::new(layout, xlen) -} - -pub fn instruction() -> INLINE { - InlineTestHarness::create_default_instruction( - crate::INLINE_OPCODE, - crate::KECCAK256_FUNCT3, - crate::KECCAK256_FUNCT7, - ) -} - -/// Create test cases for direct execution testing. pub fn keccak_test_vectors() -> Vec { vec![ - // Test case 1: All zeros input (standard test vector) KeccakTestCase { input: [0u64; 25], expected: test_constants::xkcp_vectors::AFTER_ONE_PERMUTATION, description: "All zeros input (XKCP test vector)", }, - // Test case 2: Simple pattern KeccakTestCase { input: TestVectors::create_simple_pattern(), - expected: { - let mut state = TestVectors::create_simple_pattern(); - execute_keccak_f(&mut state); - state - }, + expected: Keccak256Permutation::reference(&TestVectors::create_simple_pattern()), description: "Simple arithmetic pattern", }, - // Test case 3: Single bit set KeccakTestCase { input: { let mut state = [0u64; 25]; @@ -54,14 +30,13 @@ pub fn keccak_test_vectors() -> Vec { expected: { let mut state = [0u64; 25]; state[0] = 1; - execute_keccak_f(&mut state); - state + Keccak256Permutation::reference(&state) }, description: "Single bit in first lane", }, ] } -/// Print a Keccak state in hex format for debugging. + pub fn print_state_hex(state: &Keccak256State) { for (i, &lane) in state.iter().enumerate() { if i % 5 == 0 { @@ -72,24 +47,21 @@ pub fn print_state_hex(state: &Keccak256State) { println!(); } -/// Keccak-specific helpers for assertions. pub mod kverify { use super::*; - /// Assert two Keccak states are identical. pub fn assert_states_equal( expected: &Keccak256State, actual: &Keccak256State, test_name: &str, ) { if expected != actual { - println!("\n❌ {test_name} FAILED"); + println!("\n{test_name} FAILED"); println!("Expected state:"); print_state_hex(expected); println!("Actual state:"); print_state_hex(actual); - // Show first few mismatches let mut mismatch_count = 0; for i in 0..25 { if expected[i] != actual[i] { diff --git a/jolt-inlines/keccak256/src/tests.rs b/jolt-inlines/keccak256/src/tests.rs index 6e895b7ddc..8fb3e62eb9 100644 --- a/jolt-inlines/keccak256/src/tests.rs +++ b/jolt-inlines/keccak256/src/tests.rs @@ -1,20 +1,17 @@ #![cfg(all(test, feature = "host"))] mod exec { + use crate::sequence_builder::Keccak256Permutation; use crate::test_utils::*; - use tracer::emulator::cpu::Xlen; + use jolt_inlines_sdk::spec::InlineSpec; #[test] fn test_keccak256_direct_execution() { for (i, test_case) in keccak_test_vectors().iter().enumerate() { - let mut harness = create_keccak_harness(Xlen::Bit64); - harness.setup_registers(); - harness.load_state64(&test_case.input); - let instruction = instruction(); - harness.execute_inline(instruction); - let result_vec = harness.read_output64(25); - let mut result = [0u64; 25]; - result.copy_from_slice(&result_vec); + let mut harness = Keccak256Permutation::create_harness(); + Keccak256Permutation::load(&mut harness, &test_case.input); + harness.execute_inline(Keccak256Permutation::instruction()); + let result = Keccak256Permutation::read(&mut harness); assert_eq!( result, test_case.expected, "Keccak256 direct execution test case {} failed: {}\nInput: {:016x?}\nExpected: {:016x?}\nActual: {:016x?}", @@ -41,29 +38,25 @@ mod exec { ]; for (input, expected_hash) in e2e_vectors { - let hash = crate::exec::execute_keccak256(input); + let hash = crate::spec::execute_keccak256(input); assert_eq!(&hash, expected_hash); } } } mod exec_trace_equivalence { + use crate::sequence_builder::Keccak256Permutation; use crate::test_constants::*; - use crate::test_utils::*; - use tracer::emulator::cpu::Xlen; + use jolt_inlines_sdk::spec::InlineSpec; #[test] fn test_keccak_against_reference() { let initial_state = [0u64; 25]; let expected_final_state = xkcp_vectors::AFTER_ONE_PERMUTATION; - let mut harness = create_keccak_harness(Xlen::Bit64); - harness.setup_registers(); - harness.load_state64(&initial_state); - let instruction = instruction(); - harness.execute_inline(instruction); - let trace_result_vec = harness.read_output64(25); - let mut trace_result = [0u64; 25]; - trace_result.copy_from_slice(&trace_result_vec); + let mut harness = Keccak256Permutation::create_harness(); + Keccak256Permutation::load(&mut harness, &initial_state); + harness.execute_inline(Keccak256Permutation::instruction()); + let trace_result = Keccak256Permutation::read(&mut harness); for i in 0..25 { assert_eq!(trace_result[i], expected_final_state[i]); } @@ -71,14 +64,13 @@ mod exec_trace_equivalence { } mod exec_unit { - use crate::exec::{execute_chi, execute_iota, execute_rho_and_pi, execute_theta}; use crate::sequence_builder::ROUND_CONSTANTS; + use crate::spec::{execute_chi, execute_iota, execute_rho_and_pi, execute_theta}; use crate::test_constants::xkcp_vectors; use crate::NUM_LANES; #[test] fn test_execute_theta() { - // Patterned state to exercise column parities; theta should change the state. let mut state = [0u64; NUM_LANES]; state[0] = 1; state[5] = 2; @@ -93,7 +85,6 @@ mod exec_unit { #[test] fn test_execute_rho_and_pi() { - // Rho rotates lanes and Pi permutes positions; the state must change and lane [1] should move. let mut state = [0u64; NUM_LANES]; state[1] = 0xFF; let original_state = state; @@ -110,7 +101,6 @@ mod exec_unit { #[test] fn test_execute_chi() { - // Chi applies non-linearity: A[x] ^= (~A[x+1] & A[x+2]). Check one row cell explicitly. let mut state = [0u64; NUM_LANES]; state[0] = 0xFF; state[1] = 0xAA; @@ -126,7 +116,6 @@ mod exec_unit { #[test] fn test_execute_iota() { - // Iota xors the round constant into A[0,0]; all other lanes remain unchanged. let mut state = [0u64; NUM_LANES]; state[0] = 0x1234; execute_iota(&mut state, 0x5678); @@ -140,7 +129,6 @@ mod exec_unit { #[test] fn test_step_by_step_round_1() { - // Round-1 step-by-step: compare post-theta/rho+pi/chi states to XKCP expected snapshots. let mut state = [0u64; NUM_LANES]; state[0] = 0x0000000000000001; let round = 1; @@ -155,7 +143,6 @@ mod exec_unit { step_fn(&mut state); assert_eq!(state, expected, "round 1: mismatch after {name}"); } - // Iota has a different signature; apply it separately and check final snapshot. execute_iota(&mut state, ROUND_CONSTANTS[round]); assert_eq!(state, expected_states.iota, "round 1: mismatch after iota"); } diff --git a/jolt-inlines/sdk/Cargo.toml b/jolt-inlines/sdk/Cargo.toml index 42a46a8517..ca310c7d22 100644 --- a/jolt-inlines/sdk/Cargo.toml +++ b/jolt-inlines/sdk/Cargo.toml @@ -10,9 +10,10 @@ repository = "https://github.com/a16z/jolt" [features] default = [] elliptic-curve = [] -host = ["dep:tracer", "dep:inventory"] +host = ["dep:tracer", "dep:inventory", "dep:rand"] [dependencies] jolt-platform = { workspace = true } -tracer = { workspace = true, optional = true, features = ["std"] } +tracer = { workspace = true, optional = true, features = ["std", "test-utils"] } inventory = { workspace = true, optional = true } +rand = { workspace = true, optional = true, features = ["std", "std_rng"] } diff --git a/jolt-inlines/sdk/src/ec.rs b/jolt-inlines/sdk/src/ec.rs index c6df6a1ace..f5623185f5 100644 --- a/jolt-inlines/sdk/src/ec.rs +++ b/jolt-inlines/sdk/src/ec.rs @@ -1,17 +1,23 @@ use core::marker::PhantomData; +use core::ops::{Add, Mul, Neg, Sub}; /// Shared interface for field elements used in EC point arithmetic. -pub trait ECField: Clone + PartialEq + core::fmt::Debug + Sized { +pub trait ECField: + Clone + + PartialEq + + core::fmt::Debug + + Sized + + for<'a> Add<&'a Self, Output = Self> + + for<'a> Sub<&'a Self, Output = Self> + + for<'a> Mul<&'a Self, Output = Self> + + Neg +{ type Error; fn zero() -> Self; fn is_zero(&self) -> bool; - fn neg(&self) -> Self; - fn add(&self, other: &Self) -> Self; - fn sub(&self, other: &Self) -> Self; fn dbl(&self) -> Self; fn tpl(&self) -> Self; - fn mul(&self, other: &Self) -> Self; fn square(&self) -> Self; fn div(&self, other: &Self) -> Self; fn div_assume_nonzero(&self, other: &Self) -> Self; @@ -144,7 +150,7 @@ impl> AffinePoint { if self.is_infinity() { Self::infinity() } else { - Self::new_unchecked(self.x.clone(), self.y.neg()) + Self::new_unchecked(self.x.clone(), -self.y.clone()) } } @@ -159,8 +165,8 @@ impl> AffinePoint { None => num, }; let s = num.div_assume_nonzero(&self.y.dbl()); - let x2 = s.square().sub(&self.x.dbl()); - let y2 = s.mul(&self.x.sub(&x2)).sub(&self.y); + let x2 = s.square() - &self.x.dbl(); + let y2 = s * &(self.x.clone() - &x2) - &self.y; Self::new_unchecked(x2, y2) } } @@ -176,12 +182,9 @@ impl> AffinePoint { } else if self.x == other.x { Self::infinity() } else { - let s = self - .y - .sub(&other.y) - .div_assume_nonzero(&self.x.sub(&other.x)); - let x2 = s.square().sub(&self.x.add(&other.x)); - let y2 = s.mul(&self.x.sub(&x2)).sub(&self.y); + let s = (self.y.clone() - &other.y).div_assume_nonzero(&(self.x.clone() - &other.x)); + let x2 = s.square() - &(self.x.clone() + &other.x); + let y2 = s * &(self.x.clone() - &x2) - &self.y; Self::new_unchecked(x2, y2) } } @@ -198,18 +201,15 @@ impl> AffinePoint { } else if self.x == other.x && self.y != other.y { self.clone() } else { - let ns = self - .y - .sub(&other.y) - .div_assume_nonzero(&other.x.sub(&self.x)); - let nx2 = other.x.sub(&ns.square()); - let divisor = self.x.dbl().add(&nx2); + let ns = (self.y.clone() - &other.y).div_assume_nonzero(&(other.x.clone() - &self.x)); + let nx2 = other.x.clone() - &ns.square(); + let divisor = self.x.dbl() + &nx2; if C::DOUBLE_AND_ADD_DIVISOR_CHECK && divisor.is_zero() { return Self::infinity(); } - let t = self.y.dbl().div_assume_nonzero(&divisor).add(&ns); - let x3 = t.square().add(&nx2); - let y3 = t.mul(&self.x.sub(&x3)).sub(&self.y); + let t = self.y.dbl().div_assume_nonzero(&divisor) + &ns; + let x3 = t.square() + &nx2; + let y3 = t * &(self.x.clone() - &x3) - &self.y; Self::new_unchecked(x3, y3) } } diff --git a/jolt-inlines/sdk/src/host.rs b/jolt-inlines/sdk/src/host.rs index 9990cee944..e4dab09f96 100644 --- a/jolt-inlines/sdk/src/host.rs +++ b/jolt-inlines/sdk/src/host.rs @@ -12,6 +12,22 @@ pub use tracer::utils::inline_helpers::{InstrAssembler, Value}; pub use tracer::utils::inline_sequence_writer::AppendMode; pub use tracer::utils::virtual_registers::VirtualRegisterGuard; +pub trait InlineAdvice { + fn into_values(self) -> Option>; +} + +impl InlineAdvice for () { + fn into_values(self) -> Option> { + None + } +} + +impl InlineAdvice for VecDeque { + fn into_values(self) -> Option> { + Some(self) + } +} + /// Trait for declaring an inline operation's metadata and sequence builder. /// /// Implement this for each sub-inline (e.g. `Sha256Compression`, `Secp256k1MulQ`), @@ -22,14 +38,15 @@ pub trait InlineOp: Send + Sync { const FUNCT7: u32; const NAME: &'static str; + type Advice: InlineAdvice; + fn build_sequence(asm: InstrAssembler, operands: FormatInline) -> Vec; - fn build_advice( - _asm: InstrAssembler, - _operands: FormatInline, - _cpu: &mut Cpu, - ) -> Option> { - None + fn build_advice(_asm: InstrAssembler, _operands: FormatInline, _cpu: &mut Cpu) -> Self::Advice + where + Self::Advice: Default, + { + Self::Advice::default() } } @@ -49,6 +66,23 @@ pub trait InstrAssemblerExt { fn load_paired_u32(&mut self, temp: u8, base: u8, offset: i64, vr_lo: u8, vr_hi: u8); fn store_paired_u32(&mut self, base: u8, offset: i64, vr_lo: u8, vr_hi: u8); fn load_paired_u32_dirty(&mut self, base: u8, offset: i64, vr_lo: u8, vr_hi: u8); + + /// Load consecutive u64 words from `base + byte_offset` into `regs`. + fn load_u64_range(&mut self, base: u8, byte_offset: usize, regs: &[u8]); + + /// Store consecutive u64 words from `regs` to `base + byte_offset`. + fn store_u64_range(&mut self, base: u8, byte_offset: usize, regs: &[u8]); + + /// Load consecutive u32 words from `base + byte_offset` into `regs`. + fn load_u32_range(&mut self, base: u8, byte_offset: usize, regs: &[u8]); + + /// Load consecutive pairs of u32 from `base + byte_offset` into `regs` using paired LD+split. + /// `regs` length must be even. Uses `temp` as scratch for the 64-bit intermediate load. + fn load_u32_range_paired(&mut self, temp: u8, base: u8, byte_offset: usize, regs: &[u8]); + + /// Store consecutive pairs of u32 from `regs` to `base + byte_offset` using paired pack+SD. + /// `regs` length must be even. WARNING: clobbers register values. + fn store_u32_range_paired(&mut self, base: u8, byte_offset: usize, regs: &[u8]); } impl InstrAssemblerExt for InstrAssembler { @@ -86,6 +120,47 @@ impl InstrAssemblerExt for InstrAssembler { self.emit_ld::(vr_lo, base, offset); self.emit_i::(vr_hi, vr_lo, 32); } + + fn load_u64_range(&mut self, base: u8, byte_offset: usize, regs: &[u8]) { + use instruction::ld::LD; + for (i, ®) in regs.iter().enumerate() { + self.emit_ld::(reg, base, (byte_offset + i * 8) as i64); + } + } + + fn store_u64_range(&mut self, base: u8, byte_offset: usize, regs: &[u8]) { + use instruction::sd::SD; + for (i, ®) in regs.iter().enumerate() { + self.emit_s::(base, reg, (byte_offset + i * 8) as i64); + } + } + + fn load_u32_range(&mut self, base: u8, byte_offset: usize, regs: &[u8]) { + use instruction::lw::LW; + for (i, ®) in regs.iter().enumerate() { + self.emit_ld::(reg, base, (byte_offset + i * 4) as i64); + } + } + + fn load_u32_range_paired(&mut self, temp: u8, base: u8, byte_offset: usize, regs: &[u8]) { + debug_assert!( + regs.len().is_multiple_of(2), + "regs length must be even for paired loading" + ); + for (i, pair) in regs.chunks_exact(2).enumerate() { + self.load_paired_u32(temp, base, (byte_offset + i * 8) as i64, pair[0], pair[1]); + } + } + + fn store_u32_range_paired(&mut self, base: u8, byte_offset: usize, regs: &[u8]) { + debug_assert!( + regs.len().is_multiple_of(2), + "regs length must be even for paired storing" + ); + for (i, pair) in regs.chunks_exact(2).enumerate() { + self.store_paired_u32(base, (byte_offset + i * 8) as i64, pair[0], pair[1]); + } + } } /// Generate `store_inlines()` and submit `InlineRegistration` entries to `inventory`. @@ -143,7 +218,11 @@ macro_rules! __submit_inline_op { funct7: <$op as $crate::host::InlineOp>::FUNCT7, name: <$op as $crate::host::InlineOp>::NAME, build_sequence: <$op as $crate::host::InlineOp>::build_sequence, - build_advice: <$op as $crate::host::InlineOp>::build_advice, + build_advice: |asm, operands, cpu| { + $crate::host::InlineAdvice::into_values( + <$op as $crate::host::InlineOp>::build_advice(asm, operands, cpu) + ) + }, } } }; diff --git a/jolt-inlines/sdk/src/lib.rs b/jolt-inlines/sdk/src/lib.rs index cf8b3ffdd8..12e7d97ddb 100644 --- a/jolt-inlines/sdk/src/lib.rs +++ b/jolt-inlines/sdk/src/lib.rs @@ -7,3 +7,6 @@ pub mod ec; #[cfg(feature = "host")] pub mod host; + +#[cfg(feature = "host")] +pub mod spec; diff --git a/jolt-inlines/sdk/src/spec.rs b/jolt-inlines/sdk/src/spec.rs new file mode 100644 index 0000000000..12aa6b1a2f --- /dev/null +++ b/jolt-inlines/sdk/src/spec.rs @@ -0,0 +1,52 @@ +use crate::host::InlineOp; +use rand::rngs::StdRng; +use rand::SeedableRng; + +pub use rand; +pub use tracer::emulator::cpu::Xlen; +pub use tracer::instruction::inline::INLINE; +pub use tracer::utils::inline_test_harness::{InlineMemoryLayout, InlineTestHarness}; + +/// Formal specification of an inline's behavior. +/// Connects the mathematical reference implementation to the test harness. +/// Implemented on the same struct as `InlineOp`. +pub trait InlineSpec: InlineOp { + type Input; + type Output: PartialEq + core::fmt::Debug; + + /// Pure reference implementation — the formal verification target. + fn reference(input: &Self::Input) -> Self::Output; + + /// Generate a random input for property testing. + fn random_input(rng: &mut StdRng) -> Self::Input; + + fn create_harness() -> InlineTestHarness; + + fn instruction() -> INLINE { + InlineTestHarness::create_default_instruction(Self::OPCODE, Self::FUNCT3, Self::FUNCT7) + } + + /// Load typed input into harness memory and set up registers. + fn load(harness: &mut InlineTestHarness, input: &Self::Input); + + /// Read typed output from harness memory after execution. + fn read(harness: &mut InlineTestHarness) -> Self::Output; +} + +/// Verify sequence builder matches reference for a given input. +pub fn verify(input: &S::Input) { + let mut harness = S::create_harness(); + S::load(&mut harness, input); + harness.execute_inline(S::instruction()); + let actual = S::read(&mut harness); + let expected = S::reference(input); + assert_eq!(actual, expected); +} + +/// Verify sequence builder matches reference over `count` random inputs. +pub fn proptest(count: usize) { + let mut rng = StdRng::seed_from_u64(0); + for _ in 0..count { + verify::(&S::random_input(&mut rng)); + } +} diff --git a/jolt-inlines/secp256k1/src/sdk.rs b/jolt-inlines/secp256k1/src/sdk.rs index 909aaf9e2f..ae7f70e1aa 100644 --- a/jolt-inlines/secp256k1/src/sdk.rs +++ b/jolt-inlines/secp256k1/src/sdk.rs @@ -5,6 +5,7 @@ use ark_ff::AdditiveGroup; use ark_ff::Field; use ark_ff::{BigInt, PrimeField}; use ark_secp256k1::{Fq, Fr}; +use core::marker::PhantomData; use jolt_inlines_sdk::ec::{AffinePoint, CurveParams, ECField}; @@ -12,16 +13,11 @@ extern crate alloc; use alloc::vec::Vec; use serde::{Deserialize, Serialize}; -/// Returns `true` iff `x >= p` (Fq modulus), i.e., `x` is non-canonical. -/// Specialized: since p's upper 3 limbs are all u64::MAX, x >= p iff -/// all upper 3 limbs are MAX and limb[0] >= Fq::MODULUS.0[0]. #[inline(always)] fn is_fq_non_canonical(x: &[u64; 4]) -> bool { x[3] == u64::MAX && x[2] == u64::MAX && x[1] == u64::MAX && x[0] >= Fq::MODULUS.0[0] } -/// Returns `true` iff `x >= n` (Fr modulus), i.e., `x` is non-canonical. -/// Specialized: since n's limb[3] is u64::MAX, we short-circuit if x[3] < MAX. #[inline(always)] fn is_fr_non_canonical(x: &[u64; 4]) -> bool { if x[3] != u64::MAX { @@ -44,15 +40,14 @@ fn is_fr_non_canonical(x: &[u64; 4]) -> bool { pub use jolt_inlines_sdk::{spoil_proof, UnwrapOrSpoilProof}; -/// Error types for secp256k1 operations #[derive(Debug, Clone, Copy, Serialize, Deserialize)] pub enum Secp256k1Error { - InvalidFqElement, // input array does not correspond to a valid Fq element - InvalidFrElement, // input array does not correspond to a valid Fr element - NotOnCurve, // point is not on the secp256k1 curve - QAtInfinity, // public key is point at infinity - ROrSZero, // one of the signature components is zero - RxMismatch, // computed R.x does not match r + InvalidFqElement, + InvalidFrElement, + NotOnCurve, + QAtInfinity, + ROrSZero, + RxMismatch, InvalidGlvSignWord(u64), } @@ -70,141 +65,267 @@ pub(crate) fn decode_glv_sign_word(sign: u64) -> Result { } } -/// secp256k1 base field element -/// not in montgomery form -/// as a wrapper around 4 u64 limbs -/// so that various operations can be replaced with inlines -/// uses arkworks Fq for addition and subtraction even though -/// arkworks Fq is in montgomery form. This doesn't affect correctness -/// since addition and subtraction are the same in montgomery and -/// non-montgomery form. -/// uses arkworks Fq for host multiplication and division with appropriate conversions -/// defers to inlines for multiplication and division in guest builds -#[derive(Clone, PartialEq, Debug)] -pub struct Secp256k1Fq { - e: [u64; 4], +/// Configuration for a secp256k1 field (base or scalar). +/// +/// Limbs are stored in non-Montgomery form. Addition and subtraction reinterpret +/// raw limbs as Montgomery-form arkworks elements — this is correct because +/// modular add/sub is representation-independent. Multiplication and division +/// use proper Montgomery conversion via `from_limbs_to_mont` / `from_mont_to_limbs`. +pub trait Secp256k1FieldConfig: 'static { + const MUL_FUNCT3: u32; + const SQUARE_FUNCT3: u32; + const DIV_FUNCT3: u32; + + fn is_non_canonical(limbs: &[u64; 4]) -> bool; + fn invalid_element_error() -> Secp256k1Error; + + /// Reinterpret raw limbs as a Montgomery-form element (no conversion). + /// Used for add/sub/neg/dbl where Montgomery vs non-Montgomery is irrelevant. + fn add_raw(a: &[u64; 4], b: &[u64; 4]) -> [u64; 4]; + fn sub_raw(a: &[u64; 4], b: &[u64; 4]) -> [u64; 4]; + fn neg_raw(a: &[u64; 4]) -> [u64; 4]; + fn dbl_raw(a: &[u64; 4]) -> [u64; 4]; + + /// Convert raw limbs into Montgomery form, perform multiplication, convert back. + #[cfg(feature = "host")] + fn mul_host(a: &[u64; 4], b: &[u64; 4]) -> [u64; 4]; + #[cfg(feature = "host")] + fn square_host(a: &[u64; 4]) -> [u64; 4]; + #[cfg(feature = "host")] + fn div_host(a: &[u64; 4], b: &[u64; 4]) -> [u64; 4]; } -impl Secp256k1Fq { - /// creates a new Secp256k1Fq element from a [u64; 4] array - /// returns Err(Secp256k1Error) if the array does not correspond to a valid Fq element +pub enum FqConfig {} + +impl Secp256k1FieldConfig for FqConfig { + const MUL_FUNCT3: u32 = crate::SECP256K1_MULQ_FUNCT3; + const SQUARE_FUNCT3: u32 = crate::SECP256K1_SQUAREQ_FUNCT3; + const DIV_FUNCT3: u32 = crate::SECP256K1_DIVQ_FUNCT3; + #[inline(always)] - pub fn from_u64_arr(arr: &[u64; 4]) -> Result { - if is_fq_non_canonical(arr) { - return Err(Secp256k1Error::InvalidFqElement); - } - Ok(Secp256k1Fq { e: *arr }) + fn is_non_canonical(limbs: &[u64; 4]) -> bool { + is_fq_non_canonical(limbs) } - /// creates a new Secp256k1Fq element from a [u64; 4] array (unchecked) - /// the array is assumed to contain a value in the range [0, p) #[inline(always)] - pub fn from_u64_arr_unchecked(arr: &[u64; 4]) -> Self { - Secp256k1Fq { e: *arr } + fn invalid_element_error() -> Secp256k1Error { + Secp256k1Error::InvalidFqElement } - /// get limbs #[inline(always)] - pub fn e(&self) -> [u64; 4] { - self.e + fn add_raw(a: &[u64; 4], b: &[u64; 4]) -> [u64; 4] { + (Fq::new_unchecked(BigInt(*a)) + Fq::new_unchecked(BigInt(*b))) + .0 + .0 } - /// returns the additive identity element (0) #[inline(always)] - pub fn zero() -> Self { - Secp256k1Fq { e: [0u64; 4] } + fn sub_raw(a: &[u64; 4], b: &[u64; 4]) -> [u64; 4] { + (Fq::new_unchecked(BigInt(*a)) - Fq::new_unchecked(BigInt(*b))) + .0 + .0 } - /// returns seven #[inline(always)] - pub fn seven() -> Self { - Secp256k1Fq { - e: [7u64, 0u64, 0u64, 0u64], + fn neg_raw(a: &[u64; 4]) -> [u64; 4] { + (-Fq::new_unchecked(BigInt(*a))).0 .0 + } + #[inline(always)] + fn dbl_raw(a: &[u64; 4]) -> [u64; 4] { + Fq::new_unchecked(BigInt(*a)).double().0 .0 + } + #[cfg(feature = "host")] + #[inline(always)] + fn mul_host(a: &[u64; 4], b: &[u64; 4]) -> [u64; 4] { + (Fq::new(BigInt(*a)) * Fq::new(BigInt(*b))).into_bigint().0 + } + #[cfg(feature = "host")] + #[inline(always)] + fn square_host(a: &[u64; 4]) -> [u64; 4] { + Fq::new(BigInt(*a)).square().into_bigint().0 + } + #[cfg(feature = "host")] + #[inline(always)] + fn div_host(a: &[u64; 4], b: &[u64; 4]) -> [u64; 4] { + (Fq::new(BigInt(*a)) / Fq::new(BigInt(*b))).into_bigint().0 + } +} + +pub enum FrConfig {} + +impl Secp256k1FieldConfig for FrConfig { + const MUL_FUNCT3: u32 = crate::SECP256K1_MULR_FUNCT3; + const SQUARE_FUNCT3: u32 = crate::SECP256K1_SQUARER_FUNCT3; + const DIV_FUNCT3: u32 = crate::SECP256K1_DIVR_FUNCT3; + + #[inline(always)] + fn is_non_canonical(limbs: &[u64; 4]) -> bool { + is_fr_non_canonical(limbs) + } + #[inline(always)] + fn invalid_element_error() -> Secp256k1Error { + Secp256k1Error::InvalidFrElement + } + #[inline(always)] + fn add_raw(a: &[u64; 4], b: &[u64; 4]) -> [u64; 4] { + (Fr::new_unchecked(BigInt(*a)) + Fr::new_unchecked(BigInt(*b))) + .0 + .0 + } + #[inline(always)] + fn sub_raw(a: &[u64; 4], b: &[u64; 4]) -> [u64; 4] { + (Fr::new_unchecked(BigInt(*a)) - Fr::new_unchecked(BigInt(*b))) + .0 + .0 + } + #[inline(always)] + fn neg_raw(a: &[u64; 4]) -> [u64; 4] { + (-Fr::new_unchecked(BigInt(*a))).0 .0 + } + #[inline(always)] + fn dbl_raw(a: &[u64; 4]) -> [u64; 4] { + Fr::new_unchecked(BigInt(*a)).double().0 .0 + } + #[cfg(feature = "host")] + #[inline(always)] + fn mul_host(a: &[u64; 4], b: &[u64; 4]) -> [u64; 4] { + (Fr::new(BigInt(*a)) * Fr::new(BigInt(*b))).into_bigint().0 + } + #[cfg(feature = "host")] + #[inline(always)] + fn square_host(a: &[u64; 4]) -> [u64; 4] { + Fr::new(BigInt(*a)).square().into_bigint().0 + } + #[cfg(feature = "host")] + #[inline(always)] + fn div_host(a: &[u64; 4], b: &[u64; 4]) -> [u64; 4] { + (Fr::new(BigInt(*a)) / Fr::new(BigInt(*b))).into_bigint().0 + } +} + +pub struct Secp256k1Field { + e: [u64; 4], + _phantom: PhantomData, +} + +impl Clone for Secp256k1Field { + #[inline(always)] + fn clone(&self) -> Self { + Self { + e: self.e, + _phantom: PhantomData, } } - /// returns true if the element is zero +} + +impl PartialEq for Secp256k1Field { #[inline(always)] - pub fn is_zero(&self) -> bool { - self.e == [0u64; 4] + fn eq(&self, other: &Self) -> bool { + self.e == other.e + } +} + +impl core::fmt::Debug for Secp256k1Field { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + f.debug_struct("Secp256k1Field") + .field("e", &self.e) + .finish() } - /// returns -self +} + +pub type Secp256k1Fq = Secp256k1Field; +pub type Secp256k1Fr = Secp256k1Field; + +impl Secp256k1Field { #[inline(always)] - pub fn neg(&self) -> Self { - Secp256k1Fq { - e: (-Fq::new_unchecked(BigInt(self.e))).0 .0, + pub fn from_u64_arr(arr: &[u64; 4]) -> Result { + if C::is_non_canonical(arr) { + return Err(C::invalid_element_error()); } + Ok(Self { + e: *arr, + _phantom: PhantomData, + }) } - /// returns self + other + #[inline(always)] - pub fn add(&self, other: &Secp256k1Fq) -> Self { - Secp256k1Fq { - e: (Fq::new_unchecked(BigInt(self.e)) + Fq::new_unchecked(BigInt(other.e))) - .0 - .0, + pub fn from_u64_arr_unchecked(arr: &[u64; 4]) -> Self { + Self { + e: *arr, + _phantom: PhantomData, } } - /// returns self - other + #[inline(always)] - pub fn sub(&self, other: &Secp256k1Fq) -> Self { - Secp256k1Fq { - e: (Fq::new_unchecked(BigInt(self.e)) - Fq::new_unchecked(BigInt(other.e))) - .0 - .0, + pub fn e(&self) -> [u64; 4] { + self.e + } + + #[inline(always)] + pub fn zero() -> Self { + Self { + e: [0u64; 4], + _phantom: PhantomData, } } - /// returns 2*self + + #[inline(always)] + pub fn is_zero(&self) -> bool { + self.e == [0u64; 4] + } + #[inline(always)] pub fn dbl(&self) -> Self { - Secp256k1Fq { - e: (Fq::new_unchecked(BigInt(self.e)).double()).0 .0, + Self { + e: C::dbl_raw(&self.e), + _phantom: PhantomData, } } - /// returns 3*self + #[inline(always)] pub fn tpl(&self) -> Self { - self.dbl().add(self) + &self.dbl() + self } - /// returns self * other - /// uses custom inline for performance + #[cfg(all( not(feature = "host"), any(target_arch = "riscv32", target_arch = "riscv64") ))] #[inline(always)] - pub fn mul(&self, other: &Secp256k1Fq) -> Self { + pub fn mul(&self, other: &Self) -> Self { let mut e = [0u64; 4]; + // SAFETY: inline instruction writes exactly 4 u64s to `e` unsafe { - use crate::{INLINE_OPCODE, SECP256K1_FUNCT7, SECP256K1_MULQ_FUNCT3}; core::arch::asm!( ".insn r {opcode}, {funct3}, {funct7}, {rd}, {rs1}, {rs2}", - opcode = const INLINE_OPCODE, - funct3 = const SECP256K1_MULQ_FUNCT3, - funct7 = const SECP256K1_FUNCT7, + opcode = const crate::INLINE_OPCODE, + funct3 = const C::MUL_FUNCT3, + funct7 = const crate::SECP256K1_FUNCT7, rd = in(reg) e.as_mut_ptr(), rs1 = in(reg) self.e.as_ptr(), rs2 = in(reg) other.e.as_ptr(), options(nostack) ); } - if is_fq_non_canonical(&e) { + if C::is_non_canonical(&e) { spoil_proof(); } - Secp256k1Fq::from_u64_arr_unchecked(&e[0..4].try_into().unwrap()) + Self::from_u64_arr_unchecked(&e[0..4].try_into().unwrap()) } + #[cfg(all( not(feature = "host"), not(any(target_arch = "riscv32", target_arch = "riscv64")) ))] - pub fn mul(&self, _other: &Secp256k1Fq) -> Self { - panic!("Secp256k1Fq::mul called on non-RISC-V target without host feature"); + pub fn mul(&self, _other: &Self) -> Self { + panic!("Secp256k1Field::mul called on non-RISC-V target without host feature"); } + #[cfg(feature = "host")] #[inline(always)] - pub fn mul(&self, other: &Secp256k1Fq) -> Self { - Secp256k1Fq { - e: (Fq::new(BigInt(self.e)) * Fq::new(BigInt(other.e))) - .into_bigint() - .0, + pub fn mul(&self, other: &Self) -> Self { + Self { + e: C::mul_host(&self.e, &other.e), + _phantom: PhantomData, } } - /// returns self^2 - /// uses custom inline for performance + #[cfg(all( not(feature = "host"), any(target_arch = "riscv32", target_arch = "riscv64") @@ -212,136 +333,210 @@ impl Secp256k1Fq { #[inline(always)] pub fn square(&self) -> Self { let mut e = [0u64; 4]; + // SAFETY: inline instruction writes exactly 4 u64s to `e` unsafe { - use crate::{INLINE_OPCODE, SECP256K1_FUNCT7, SECP256K1_SQUAREQ_FUNCT3}; core::arch::asm!( ".insn r {opcode}, {funct3}, {funct7}, {rd}, {rs1}, x0", - opcode = const INLINE_OPCODE, - funct3 = const SECP256K1_SQUAREQ_FUNCT3, - funct7 = const SECP256K1_FUNCT7, + opcode = const crate::INLINE_OPCODE, + funct3 = const C::SQUARE_FUNCT3, + funct7 = const crate::SECP256K1_FUNCT7, rd = in(reg) e.as_mut_ptr(), rs1 = in(reg) self.e.as_ptr(), options(nostack) ); } - if is_fq_non_canonical(&e) { + if C::is_non_canonical(&e) { spoil_proof(); } - Secp256k1Fq::from_u64_arr_unchecked(&e[0..4].try_into().unwrap()) + Self::from_u64_arr_unchecked(&e[0..4].try_into().unwrap()) } + #[cfg(all( not(feature = "host"), not(any(target_arch = "riscv32", target_arch = "riscv64")) ))] pub fn square(&self) -> Self { - panic!("Secp256k1Fq::square called on non-RISC-V target without host feature"); + panic!("Secp256k1Field::square called on non-RISC-V target without host feature"); } + #[cfg(feature = "host")] #[inline(always)] pub fn square(&self) -> Self { - Secp256k1Fq { - e: Fq::new(BigInt(self.e)).square().into_bigint().0, + Self { + e: C::square_host(&self.e), + _phantom: PhantomData, } } - /// returns self / other - /// uses custom inline for performance - /// assumes that other is non-zero + #[cfg(all( not(feature = "host"), any(target_arch = "riscv32", target_arch = "riscv64") ))] #[inline(always)] - fn div_assume_nonzero(&self, other: &Secp256k1Fq) -> Self { + fn div_assume_nonzero(&self, other: &Self) -> Self { let mut e = [0u64; 4]; + // SAFETY: inline instruction writes exactly 4 u64s to `e` unsafe { - use crate::{INLINE_OPCODE, SECP256K1_DIVQ_FUNCT3, SECP256K1_FUNCT7}; core::arch::asm!( ".insn r {opcode}, {funct3}, {funct7}, {rd}, {rs1}, {rs2}", - opcode = const INLINE_OPCODE, - funct3 = const SECP256K1_DIVQ_FUNCT3, - funct7 = const SECP256K1_FUNCT7, + opcode = const crate::INLINE_OPCODE, + funct3 = const C::DIV_FUNCT3, + funct7 = const crate::SECP256K1_FUNCT7, rd = in(reg) e.as_mut_ptr(), rs1 = in(reg) self.e.as_ptr(), rs2 = in(reg) other.e.as_ptr(), options(nostack) ); } - if is_fq_non_canonical(&e) { + if C::is_non_canonical(&e) { spoil_proof(); } - Secp256k1Fq::from_u64_arr_unchecked(&e[0..4].try_into().unwrap()) + Self::from_u64_arr_unchecked(&e[0..4].try_into().unwrap()) } - /// panics and spoils the proof if other is zero - /// returns self / other - /// uses custom inline for performance + #[cfg(all( not(feature = "host"), any(target_arch = "riscv32", target_arch = "riscv64") ))] #[inline(always)] - pub fn div(&self, other: &Secp256k1Fq) -> Self { - // spoil proof if other == 0 + pub fn div(&self, other: &Self) -> Self { if other.is_zero() { spoil_proof(); } self.div_assume_nonzero(other) } + #[cfg(all( not(feature = "host"), not(any(target_arch = "riscv32", target_arch = "riscv64")) ))] - pub fn div_assume_nonzero(&self, _other: &Secp256k1Fq) -> Self { - panic!("Secp256k1Fq::div_assume_nonzero called on non-RISC-V target without host feature"); + pub fn div_assume_nonzero(&self, _other: &Self) -> Self { + panic!( + "Secp256k1Field::div_assume_nonzero called on non-RISC-V target without host feature" + ); } + #[cfg(all( not(feature = "host"), not(any(target_arch = "riscv32", target_arch = "riscv64")) ))] - pub fn div(&self, _other: &Secp256k1Fq) -> Self { - panic!("Secp256k1Fq::div called on non-RISC-V target without host feature"); + pub fn div(&self, _other: &Self) -> Self { + panic!("Secp256k1Field::div called on non-RISC-V target without host feature"); } - /// assumes other != 0 + #[cfg(feature = "host")] #[inline(always)] - pub fn div_assume_nonzero(&self, other: &Secp256k1Fq) -> Self { - Secp256k1Fq { - e: (Fq::new(BigInt(self.e)) / Fq::new(BigInt(other.e))) - .into_bigint() - .0, + pub fn div_assume_nonzero(&self, other: &Self) -> Self { + Self { + e: C::div_host(&self.e, &other.e), + _phantom: PhantomData, } } - /// checks other != 0 then calls div_assume_nonzero + #[cfg(feature = "host")] #[inline(always)] - pub fn div(&self, other: &Secp256k1Fq) -> Self { + pub fn div(&self, other: &Self) -> Self { if other.is_zero() { - panic!("division by zero in Secp256k1Fq::div"); + panic!("division by zero in Secp256k1Field::div"); } self.div_assume_nonzero(other) } } -impl ECField for Secp256k1Fq { - type Error = Secp256k1Error; +impl core::ops::Add<&Secp256k1Field> for &Secp256k1Field { + type Output = Secp256k1Field; #[inline(always)] - fn zero() -> Self { - Self::zero() + fn add(self, rhs: &Secp256k1Field) -> Secp256k1Field { + Secp256k1Field { + e: C::add_raw(&self.e, &rhs.e), + _phantom: PhantomData, + } } +} + +impl core::ops::Sub<&Secp256k1Field> for &Secp256k1Field { + type Output = Secp256k1Field; #[inline(always)] - fn is_zero(&self) -> bool { - self.is_zero() + fn sub(self, rhs: &Secp256k1Field) -> Secp256k1Field { + Secp256k1Field { + e: C::sub_raw(&self.e, &rhs.e), + _phantom: PhantomData, + } } +} + +impl core::ops::Mul<&Secp256k1Field> for &Secp256k1Field { + type Output = Secp256k1Field; + #[inline(always)] + fn mul(self, rhs: &Secp256k1Field) -> Secp256k1Field { + Secp256k1Field::mul(self, rhs) + } +} + +impl core::ops::Neg for &Secp256k1Field { + type Output = Secp256k1Field; + #[inline(always)] + fn neg(self) -> Secp256k1Field { + Secp256k1Field { + e: C::neg_raw(&self.e), + _phantom: PhantomData, + } + } +} + +impl core::ops::Add<&Secp256k1Field> for Secp256k1Field { + type Output = Secp256k1Field; + #[inline(always)] + fn add(self, rhs: &Secp256k1Field) -> Secp256k1Field { + &self + rhs + } +} + +impl core::ops::Sub<&Secp256k1Field> for Secp256k1Field { + type Output = Secp256k1Field; + #[inline(always)] + fn sub(self, rhs: &Secp256k1Field) -> Secp256k1Field { + &self - rhs + } +} + +impl core::ops::Mul<&Secp256k1Field> for Secp256k1Field { + type Output = Secp256k1Field; + #[inline(always)] + fn mul(self, rhs: &Secp256k1Field) -> Secp256k1Field { + Secp256k1Field::mul(&self, rhs) + } +} + +impl core::ops::Neg for Secp256k1Field { + type Output = Secp256k1Field; + #[inline(always)] + fn neg(self) -> Secp256k1Field { + -&self + } +} + +// Fq-specific methods + +impl Secp256k1Field { #[inline(always)] - fn neg(&self) -> Self { - self.neg() + pub fn seven() -> Self { + Self { + e: [7u64, 0u64, 0u64, 0u64], + _phantom: PhantomData, + } } +} + +impl ECField for Secp256k1Fq { + type Error = Secp256k1Error; #[inline(always)] - fn add(&self, other: &Self) -> Self { - self.add(other) + fn zero() -> Self { + Self::zero() } #[inline(always)] - fn sub(&self, other: &Self) -> Self { - self.sub(other) + fn is_zero(&self) -> bool { + self.is_zero() } #[inline(always)] fn dbl(&self) -> Self { @@ -352,10 +547,6 @@ impl ECField for Secp256k1Fq { self.tpl() } #[inline(always)] - fn mul(&self, other: &Self) -> Self { - self.mul(other) - } - #[inline(always)] fn square(&self) -> Self { self.square() } @@ -381,265 +572,24 @@ impl ECField for Secp256k1Fq { } } -/// secp256k1 scalar field element -/// not in montgomery form -/// as a wrapper around 4 u64 limbs -/// so that various operations can be replaced with inlines -/// uses arkworks Fr for addition and subtraction even though -/// arkworks Fr is in montgomery form. This doesn't affect correctness -/// since addition and subtraction are the same in montgomery and -/// non-montgomery form. -/// uses arkworks Fr for host multiplication and division with appropriate conversions -/// defers to inlines for multiplication and division in guest builds -#[derive(Clone, PartialEq, Debug)] -pub struct Secp256k1Fr { - e: [u64; 4], -} +// Fr-specific methods -impl Secp256k1Fr { - /// creates a new Secp256k1Fr element from a [u64; 4] array - /// returns Err(Secp256k1Error) if the array does not correspond to a valid Fr element - #[inline(always)] - pub fn from_u64_arr(arr: &[u64; 4]) -> Result { - if is_fr_non_canonical(arr) { - return Err(Secp256k1Error::InvalidFrElement); - } - Ok(Secp256k1Fr { e: *arr }) - } - /// creates a new Secp256k1Fr element from a [u64; 4] array (unchecked) - /// the array is assumed to contain a value in the range [0, p) - #[inline(always)] - pub fn from_u64_arr_unchecked(arr: &[u64; 4]) -> Self { - Secp256k1Fr { e: *arr } - } - /// get limbs - #[inline(always)] - pub fn e(&self) -> [u64; 4] { - self.e - } - /// as a pair of u128s (little-endian) +impl Secp256k1Field { #[inline(always)] pub fn as_u128_pair(&self) -> (u128, u128) { let low = self.e[0] as u128 + ((self.e[1] as u128) << 64); let high = self.e[2] as u128 + ((self.e[3] as u128) << 64); (low, high) } - /// returns the additive identity element (0) - #[inline(always)] - pub fn zero() -> Self { - Secp256k1Fr { e: [0u64; 4] } - } - /// returns true if the element is zero - #[inline(always)] - pub fn is_zero(&self) -> bool { - self.e == [0u64; 4] - } - /// returns -self - #[inline(always)] - pub fn neg(&self) -> Self { - Secp256k1Fr { - e: (-Fr::new_unchecked(BigInt(self.e))).0 .0, - } - } - /// returns self + other - #[inline(always)] - pub fn add(&self, other: &Secp256k1Fr) -> Self { - Secp256k1Fr { - e: (Fr::new_unchecked(BigInt(self.e)) + Fr::new_unchecked(BigInt(other.e))) - .0 - .0, - } - } - /// returns self - other - #[inline(always)] - pub fn sub(&self, other: &Secp256k1Fr) -> Self { - Secp256k1Fr { - e: (Fr::new_unchecked(BigInt(self.e)) - Fr::new_unchecked(BigInt(other.e))) - .0 - .0, - } - } - /// returns 2*self - #[inline(always)] - pub fn dbl(&self) -> Self { - Secp256k1Fr { - e: (Fr::new_unchecked(BigInt(self.e)).double()).0 .0, - } - } - /// returns 3*self - #[inline(always)] - pub fn tpl(&self) -> Self { - self.dbl().add(self) - } - /// returns self * other - /// uses custom inline for performance - #[cfg(all( - not(feature = "host"), - any(target_arch = "riscv32", target_arch = "riscv64") - ))] - #[inline(always)] - pub fn mul(&self, other: &Secp256k1Fr) -> Self { - let mut e = [0u64; 4]; - unsafe { - use crate::{INLINE_OPCODE, SECP256K1_FUNCT7, SECP256K1_MULR_FUNCT3}; - core::arch::asm!( - ".insn r {opcode}, {funct3}, {funct7}, {rd}, {rs1}, {rs2}", - opcode = const INLINE_OPCODE, - funct3 = const SECP256K1_MULR_FUNCT3, - funct7 = const SECP256K1_FUNCT7, - rd = in(reg) e.as_mut_ptr(), - rs1 = in(reg) self.e.as_ptr(), - rs2 = in(reg) other.e.as_ptr(), - options(nostack) - ); - } - if is_fr_non_canonical(&e) { - spoil_proof(); - } - Secp256k1Fr::from_u64_arr_unchecked(&e[0..4].try_into().unwrap()) - } - #[cfg(all( - not(feature = "host"), - not(any(target_arch = "riscv32", target_arch = "riscv64")) - ))] - pub fn mul(&self, _other: &Secp256k1Fr) -> Self { - panic!("Secp256k1Fr::mul called on non-RISC-V target without host feature"); - } - #[cfg(feature = "host")] - #[inline(always)] - pub fn mul(&self, other: &Secp256k1Fr) -> Self { - Secp256k1Fr { - e: (Fr::new(BigInt(self.e)) * Fr::new(BigInt(other.e))) - .into_bigint() - .0, - } - } - /// returns self^2 - /// uses custom inline for performance - #[cfg(all( - not(feature = "host"), - any(target_arch = "riscv32", target_arch = "riscv64") - ))] - #[inline(always)] - pub fn square(&self) -> Self { - let mut e = [0u64; 4]; - unsafe { - use crate::{INLINE_OPCODE, SECP256K1_FUNCT7, SECP256K1_SQUARER_FUNCT3}; - core::arch::asm!( - ".insn r {opcode}, {funct3}, {funct7}, {rd}, {rs1}, x0", - opcode = const INLINE_OPCODE, - funct3 = const SECP256K1_SQUARER_FUNCT3, - funct7 = const SECP256K1_FUNCT7, - rd = in(reg) e.as_mut_ptr(), - rs1 = in(reg) self.e.as_ptr(), - options(nostack) - ); - } - if is_fr_non_canonical(&e) { - spoil_proof(); - } - Secp256k1Fr::from_u64_arr_unchecked(&e[0..4].try_into().unwrap()) - } - #[cfg(all( - not(feature = "host"), - not(any(target_arch = "riscv32", target_arch = "riscv64")) - ))] - pub fn square(&self) -> Self { - panic!("Secp256k1Fr::square called on non-RISC-V target without host feature"); - } - #[cfg(feature = "host")] - #[inline(always)] - pub fn square(&self) -> Self { - Secp256k1Fr { - e: Fr::new(BigInt(self.e)).square().into_bigint().0, - } - } - /// returns self / other - /// uses custom inline for performance - /// assumes that other is non-zero - #[cfg(all( - not(feature = "host"), - any(target_arch = "riscv32", target_arch = "riscv64") - ))] - #[inline(always)] - fn div_assume_nonzero(&self, other: &Secp256k1Fr) -> Self { - let mut e = [0u64; 4]; - unsafe { - use crate::{INLINE_OPCODE, SECP256K1_DIVR_FUNCT3, SECP256K1_FUNCT7}; - core::arch::asm!( - ".insn r {opcode}, {funct3}, {funct7}, {rd}, {rs1}, {rs2}", - opcode = const INLINE_OPCODE, - funct3 = const SECP256K1_DIVR_FUNCT3, - funct7 = const SECP256K1_FUNCT7, - rd = in(reg) e.as_mut_ptr(), - rs1 = in(reg) self.e.as_ptr(), - rs2 = in(reg) other.e.as_ptr(), - options(nostack) - ); - } - if is_fr_non_canonical(&e) { - spoil_proof(); - } - Secp256k1Fr::from_u64_arr_unchecked(&e[0..4].try_into().unwrap()) - } - /// panics and spoils the proof if other is zero - /// returns self / other - /// uses custom inline for performance - #[cfg(all( - not(feature = "host"), - any(target_arch = "riscv32", target_arch = "riscv64") - ))] - #[inline(always)] - pub fn div(&self, other: &Secp256k1Fr) -> Self { - // spoil proof if other == 0 - if other.is_zero() { - spoil_proof(); - } - self.div_assume_nonzero(other) - } - #[cfg(all( - not(feature = "host"), - not(any(target_arch = "riscv32", target_arch = "riscv64")) - ))] - pub fn div_assume_nonzero(&self, _other: &Secp256k1Fr) -> Self { - panic!("Secp256k1Fr::div_assume_nonzero called on non-RISC-V target without host feature"); - } - #[cfg(all( - not(feature = "host"), - not(any(target_arch = "riscv32", target_arch = "riscv64")) - ))] - pub fn div(&self, _other: &Secp256k1Fr) -> Self { - panic!("Secp256k1Fr::div called on non-RISC-V target without host feature"); - } - /// assumes other != 0 - #[cfg(feature = "host")] - #[inline(always)] - pub fn div_assume_nonzero(&self, other: &Secp256k1Fr) -> Self { - Secp256k1Fr { - e: (Fr::new(BigInt(self.e)) / Fr::new(BigInt(other.e))) - .into_bigint() - .0, - } - } - /// checks other != 0 then calls div_assume_nonzero - #[cfg(feature = "host")] - #[inline(always)] - pub fn div(&self, other: &Secp256k1Fr) -> Self { - if other.is_zero() { - panic!("division by zero in Secp256k1Fr::div"); - } - self.div_assume_nonzero(other) - } - /// GLV scalar decomposition: returns (k1, k2) such that - /// self = k1 + k2 * lambda (mod r) and |k1|, |k2| < 2^128. - /// Each entry is (is_negative, abs_value). #[inline(always)] pub fn glv_decompose(&self) -> [(bool, u128); 2] { decompose_scalar_impl(self) } } +// Curve definition + #[derive(Clone)] pub struct Secp256k1Curve; @@ -724,8 +674,7 @@ impl Secp256k1PointExt for Secp256k1Point { ]) } - // returns lambda * self - // where lambda is 0x5363ad4cc05c30e0a5261c028812645a122e22ea20816678df02967c1b23bd72 + // lambda = 0x5363ad4cc05c30e0a5261c028812645a122e22ea20816678df02967c1b23bd72 #[inline(always)] fn endomorphism(&self) -> Secp256k1Point { if self.is_infinity() { @@ -737,17 +686,20 @@ impl Secp256k1PointExt for Secp256k1Point { 0x6e64479eac3434e9, 0x7ae96a2b657c0710, ]); - Self::new_unchecked(self.x().mul(&beta), self.y()) + Self::new_unchecked(self.x() * &beta, self.y()) } } } +// GLV scalar decomposition + #[cfg(all( not(feature = "host"), any(target_arch = "riscv32", target_arch = "riscv64") ))] fn decompose_scalar_impl(k: &Secp256k1Fr) -> [(bool, u128); 2] { let mut out = [0u64; 6]; + // SAFETY: inline instruction writes exactly 6 u64s to `out` unsafe { use crate::{INLINE_OPCODE, SECP256K1_FUNCT7, SECP256K1_GLVR_ADV_FUNCT3}; core::arch::asm!( @@ -770,13 +722,13 @@ fn decompose_scalar_impl(k: &Secp256k1Fr) -> [(bool, u128); 2] { ]); let mut k1 = Secp256k1Fr::from_u64_arr_unchecked(&[out[1], out[2], 0u64, 0u64]); if k1_sign { - k1 = k1.neg(); + k1 = -k1; } let mut k2 = Secp256k1Fr::from_u64_arr_unchecked(&[out[4], out[5], 0u64, 0u64]); if k2_sign { - k2 = k2.neg(); + k2 = -k2; } - let recomposed_k = k1.add(&k2.mul(&lambda)); + let recomposed_k = &k1 + &k2.mul(&lambda); if recomposed_k != *k { spoil_proof(); } @@ -800,7 +752,7 @@ fn decompose_scalar_impl(k: &Secp256k1Fr) -> [(bool, u128); 2] { crate::glv::decompose_scalar(k) } -// ECDSA signature verification function + helpers +// ECDSA signature verification #[inline(always)] fn scalars_to_index(scalars: &[u128; 4], bit_index: usize) -> usize { @@ -813,8 +765,6 @@ fn scalars_to_index(scalars: &[u128; 4], bit_index: usize) -> usize { idx } -// performs a 4x128-bit scalar multiplication -// first two points assumed to be generator and 2^128 * generator #[inline(always)] fn secp256k1_4x128_inner_scalar_mul( scalars: [u128; 4], @@ -845,7 +795,6 @@ fn secp256k1_4x128_inner_scalar_mul( res } -// if cond is true, negate x, otherwise return x unchanged #[inline(always)] fn conditional_negate(x: Secp256k1Point, cond: bool) -> Secp256k1Point { if cond { @@ -855,12 +804,6 @@ fn conditional_negate(x: Secp256k1Point, cond: bool) -> Secp256k1Point { } } -/// verify an ECDSA signature -/// z is the hash of the message being signed -/// r and s are the signature components -/// q is the uncompressed public key point -/// returns Ok(()) if the signature is valid -/// returns Err(Secp256k1Error) if at any point, the verification fails #[inline(always)] pub fn ecdsa_verify( z: Secp256k1Fr, @@ -868,40 +811,29 @@ pub fn ecdsa_verify( s: Secp256k1Fr, q: Secp256k1Point, ) -> Result<(), Secp256k1Error> { - // step 0: check that q is not infinity if q.is_infinity() { return Result::Err(Secp256k1Error::QAtInfinity); } - // step 1: check that r and s are in the correct range if r.is_zero() || s.is_zero() { return Result::Err(Secp256k1Error::ROrSZero); } - // step 2: compute u1 = z / s (mod r) and u2 = r / s (mod r) let u1 = z.div_assume_nonzero(&s); let u2 = r.div_assume_nonzero(&s); - // step 3: compute R = u1 * G + u2 * q - // 3.1: perform the glv scalar decomposition let decomp_u = u1.as_u128_pair(); let decomp_v = u2.glv_decompose(); - // 3.2: get decomposed scalars as a 4x128-bit array let scalars = [decomp_u.0, decomp_u.1, decomp_v[0].1, decomp_v[1].1]; - // 3.3: prepare Q, and lambda*Q, appropriately negated let points = [ conditional_negate(q.clone(), decomp_v[0].0), conditional_negate(q.endomorphism(), decomp_v[1].0), ]; - // 3.4: perform the 4x128-bit scalar multiplication let r_claim = secp256k1_4x128_inner_scalar_mul(scalars, points); - // step 4: check that r == R.x mod n. - // We implement the `mod n` as a single conditional subtraction on the bigint: - // for secp256k1, `0 <= x_R < p` and `p < 2n`, so `x_R mod n` is either `x_R` or `x_R - n`. + // x_R mod n: for secp256k1, 0 <= x_R < p and p < 2n, so at most one subtraction let mut rx = r_claim.x(); if is_fr_non_canonical(&rx.e()) { - rx = rx.sub(&Secp256k1Fq::from_u64_arr_unchecked(&Fr::MODULUS.0)); + rx = &rx - &Secp256k1Fq::from_u64_arr_unchecked(&Fr::MODULUS.0); } if rx.e() != r.e() { return Result::Err(Secp256k1Error::RxMismatch); } - // if all checks passed, return Ok(()) Result::Ok(()) } diff --git a/jolt-inlines/secp256k1/src/sequence_builder.rs b/jolt-inlines/secp256k1/src/sequence_builder.rs index 7d52e2265d..3297f16860 100644 --- a/jolt-inlines/secp256k1/src/sequence_builder.rs +++ b/jolt-inlines/secp256k1/src/sequence_builder.rs @@ -38,8 +38,7 @@ impl GlvrAdvBuilder { let result = crate::glv::decompose_scalar_to_u64s(k); VecDeque::from(result.to_vec()) } - // inline sequence function - fn inline_sequence(mut self) -> Vec { + fn build(mut self) -> Vec { for i in 0..6 { self.asm.emit_j::(*self.vr, 0); self.asm @@ -229,8 +228,7 @@ impl MulqBuilder { } } } - // inline sequence function - fn inline_sequence(mut self) -> Vec { + fn build(mut self) -> Vec { // load a, b, and w for i in 0..4 { match self.op_type { @@ -579,37 +577,113 @@ impl MulqBuilder { } } -macro_rules! secp256k1_mulq_op { - ($name:ident, funct3: $funct3:expr, name: $op_name:expr, mul_type: $mul_type:expr, is_scalar: $is_scalar:expr) => { - pub struct $name; +pub struct Secp256k1MulQ; - impl InlineOp for $name { - const OPCODE: u32 = crate::INLINE_OPCODE; - const FUNCT3: u32 = $funct3; - const FUNCT7: u32 = crate::SECP256K1_FUNCT7; - const NAME: &'static str = $op_name; +impl InlineOp for Secp256k1MulQ { + const OPCODE: u32 = crate::INLINE_OPCODE; + const FUNCT3: u32 = crate::SECP256K1_MULQ_FUNCT3; + const FUNCT7: u32 = crate::SECP256K1_FUNCT7; + const NAME: &'static str = crate::SECP256K1_MULQ_NAME; + type Advice = VecDeque; - fn build_sequence(asm: InstrAssembler, operands: FormatInline) -> Vec { - MulqBuilder::new(asm, operands, $mul_type, $is_scalar).inline_sequence() - } + fn build_sequence(asm: InstrAssembler, operands: FormatInline) -> Vec { + MulqBuilder::new(asm, operands, MulqType::Mul, false).build() + } - fn build_advice( - asm: InstrAssembler, - operands: FormatInline, - cpu: &mut Cpu, - ) -> Option> { - Some(MulqBuilder::new(asm, operands, $mul_type, $is_scalar).advice(cpu)) - } - } - }; + fn build_advice(asm: InstrAssembler, operands: FormatInline, cpu: &mut Cpu) -> VecDeque { + MulqBuilder::new(asm, operands, MulqType::Mul, false).advice(cpu) + } +} + +pub struct Secp256k1SquareQ; + +impl InlineOp for Secp256k1SquareQ { + const OPCODE: u32 = crate::INLINE_OPCODE; + const FUNCT3: u32 = crate::SECP256K1_SQUAREQ_FUNCT3; + const FUNCT7: u32 = crate::SECP256K1_FUNCT7; + const NAME: &'static str = crate::SECP256K1_SQUAREQ_NAME; + type Advice = VecDeque; + + fn build_sequence(asm: InstrAssembler, operands: FormatInline) -> Vec { + MulqBuilder::new(asm, operands, MulqType::Square, false).build() + } + + fn build_advice(asm: InstrAssembler, operands: FormatInline, cpu: &mut Cpu) -> VecDeque { + MulqBuilder::new(asm, operands, MulqType::Square, false).advice(cpu) + } } -secp256k1_mulq_op!(Secp256k1MulQ, funct3: crate::SECP256K1_MULQ_FUNCT3, name: crate::SECP256K1_MULQ_NAME, mul_type: MulqType::Mul, is_scalar: false); -secp256k1_mulq_op!(Secp256k1SquareQ, funct3: crate::SECP256K1_SQUAREQ_FUNCT3, name: crate::SECP256K1_SQUAREQ_NAME, mul_type: MulqType::Square, is_scalar: false); -secp256k1_mulq_op!(Secp256k1DivQ, funct3: crate::SECP256K1_DIVQ_FUNCT3, name: crate::SECP256K1_DIVQ_NAME, mul_type: MulqType::Div, is_scalar: false); -secp256k1_mulq_op!(Secp256k1MulR, funct3: crate::SECP256K1_MULR_FUNCT3, name: crate::SECP256K1_MULR_NAME, mul_type: MulqType::Mul, is_scalar: true); -secp256k1_mulq_op!(Secp256k1SquareR, funct3: crate::SECP256K1_SQUARER_FUNCT3, name: crate::SECP256K1_SQUARER_NAME, mul_type: MulqType::Square, is_scalar: true); -secp256k1_mulq_op!(Secp256k1DivR, funct3: crate::SECP256K1_DIVR_FUNCT3, name: crate::SECP256K1_DIVR_NAME, mul_type: MulqType::Div, is_scalar: true); +pub struct Secp256k1DivQ; + +impl InlineOp for Secp256k1DivQ { + const OPCODE: u32 = crate::INLINE_OPCODE; + const FUNCT3: u32 = crate::SECP256K1_DIVQ_FUNCT3; + const FUNCT7: u32 = crate::SECP256K1_FUNCT7; + const NAME: &'static str = crate::SECP256K1_DIVQ_NAME; + type Advice = VecDeque; + + fn build_sequence(asm: InstrAssembler, operands: FormatInline) -> Vec { + MulqBuilder::new(asm, operands, MulqType::Div, false).build() + } + + fn build_advice(asm: InstrAssembler, operands: FormatInline, cpu: &mut Cpu) -> VecDeque { + MulqBuilder::new(asm, operands, MulqType::Div, false).advice(cpu) + } +} + +pub struct Secp256k1MulR; + +impl InlineOp for Secp256k1MulR { + const OPCODE: u32 = crate::INLINE_OPCODE; + const FUNCT3: u32 = crate::SECP256K1_MULR_FUNCT3; + const FUNCT7: u32 = crate::SECP256K1_FUNCT7; + const NAME: &'static str = crate::SECP256K1_MULR_NAME; + type Advice = VecDeque; + + fn build_sequence(asm: InstrAssembler, operands: FormatInline) -> Vec { + MulqBuilder::new(asm, operands, MulqType::Mul, true).build() + } + + fn build_advice(asm: InstrAssembler, operands: FormatInline, cpu: &mut Cpu) -> VecDeque { + MulqBuilder::new(asm, operands, MulqType::Mul, true).advice(cpu) + } +} + +pub struct Secp256k1SquareR; + +impl InlineOp for Secp256k1SquareR { + const OPCODE: u32 = crate::INLINE_OPCODE; + const FUNCT3: u32 = crate::SECP256K1_SQUARER_FUNCT3; + const FUNCT7: u32 = crate::SECP256K1_FUNCT7; + const NAME: &'static str = crate::SECP256K1_SQUARER_NAME; + type Advice = VecDeque; + + fn build_sequence(asm: InstrAssembler, operands: FormatInline) -> Vec { + MulqBuilder::new(asm, operands, MulqType::Square, true).build() + } + + fn build_advice(asm: InstrAssembler, operands: FormatInline, cpu: &mut Cpu) -> VecDeque { + MulqBuilder::new(asm, operands, MulqType::Square, true).advice(cpu) + } +} + +pub struct Secp256k1DivR; + +impl InlineOp for Secp256k1DivR { + const OPCODE: u32 = crate::INLINE_OPCODE; + const FUNCT3: u32 = crate::SECP256K1_DIVR_FUNCT3; + const FUNCT7: u32 = crate::SECP256K1_FUNCT7; + const NAME: &'static str = crate::SECP256K1_DIVR_NAME; + type Advice = VecDeque; + + fn build_sequence(asm: InstrAssembler, operands: FormatInline) -> Vec { + MulqBuilder::new(asm, operands, MulqType::Div, true).build() + } + + fn build_advice(asm: InstrAssembler, operands: FormatInline, cpu: &mut Cpu) -> VecDeque { + MulqBuilder::new(asm, operands, MulqType::Div, true).advice(cpu) + } +} pub struct Secp256k1GlvrAdv; @@ -618,16 +692,13 @@ impl InlineOp for Secp256k1GlvrAdv { const FUNCT3: u32 = crate::SECP256K1_GLVR_ADV_FUNCT3; const FUNCT7: u32 = crate::SECP256K1_FUNCT7; const NAME: &'static str = crate::SECP256K1_GLVR_ADV_NAME; + type Advice = VecDeque; fn build_sequence(asm: InstrAssembler, operands: FormatInline) -> Vec { - GlvrAdvBuilder::new(asm, operands).inline_sequence() + GlvrAdvBuilder::new(asm, operands).build() } - fn build_advice( - asm: InstrAssembler, - operands: FormatInline, - cpu: &mut Cpu, - ) -> Option> { - Some(GlvrAdvBuilder::new(asm, operands).advice(cpu)) + fn build_advice(asm: InstrAssembler, operands: FormatInline, cpu: &mut Cpu) -> VecDeque { + GlvrAdvBuilder::new(asm, operands).advice(cpu) } } diff --git a/jolt-inlines/sha2/src/exec.rs b/jolt-inlines/sha2/src/exec.rs deleted file mode 100644 index 50b9f6c5a8..0000000000 --- a/jolt-inlines/sha2/src/exec.rs +++ /dev/null @@ -1,68 +0,0 @@ -use crate::sequence_builder::{BLOCK, K}; - -pub fn execute_sha256_compression(initial_state: [u32; 8], input: [u32; 16]) -> [u32; 8] { - let mut a = initial_state[0]; - let mut b = initial_state[1]; - let mut c = initial_state[2]; - let mut d = initial_state[3]; - let mut e = initial_state[4]; - let mut f = initial_state[5]; - let mut g = initial_state[6]; - let mut h = initial_state[7]; - - let mut w = [0u32; 64]; - - w[..16].copy_from_slice(&input); - - // Calculate word schedule - for i in 16..64 { - // σ₁(w[i-2]) + w[i-7] + σ₀(w[i-15]) + w[i-16] - let s0 = w[i - 15].rotate_right(7) ^ w[i - 15].rotate_right(18) ^ (w[i - 15] >> 3); - let s1 = w[i - 2].rotate_right(17) ^ w[i - 2].rotate_right(19) ^ (w[i - 2] >> 10); - w[i] = w[i - 16] - .wrapping_add(s0) - .wrapping_add(w[i - 7]) - .wrapping_add(s1); - } - - // Perform 64 rounds - for i in 0..64 { - let ch = (e & f) ^ ((!e) & g); - let maj = (a & b) ^ (a & c) ^ (b & c); - - let sigma0 = a.rotate_right(2) ^ a.rotate_right(13) ^ a.rotate_right(22); // Σ₀(a) - let sigma1 = e.rotate_right(6) ^ e.rotate_right(11) ^ e.rotate_right(25); // Σ₁(e) - - let t1 = h - .wrapping_add(sigma1) - .wrapping_add(ch) - .wrapping_add(K[i] as u32) - .wrapping_add(w[i]); - let t2 = sigma0.wrapping_add(maj); - - h = g; - g = f; - f = e; - e = d.wrapping_add(t1); - d = c; - c = b; - b = a; - a = t1.wrapping_add(t2); - } - - // Final IV addition - [ - initial_state[0].wrapping_add(a), - initial_state[1].wrapping_add(b), - initial_state[2].wrapping_add(c), - initial_state[3].wrapping_add(d), - initial_state[4].wrapping_add(e), - initial_state[5].wrapping_add(f), - initial_state[6].wrapping_add(g), - initial_state[7].wrapping_add(h), - ] -} - -pub fn execute_sha256_compression_initial(input: [u32; 16]) -> [u32; 8] { - execute_sha256_compression(BLOCK.map(|x| x as u32), input) -} diff --git a/jolt-inlines/sha2/src/lib.rs b/jolt-inlines/sha2/src/lib.rs index f1a8384fe7..cb9185f13e 100644 --- a/jolt-inlines/sha2/src/lib.rs +++ b/jolt-inlines/sha2/src/lib.rs @@ -18,10 +18,10 @@ pub const SHA256_INIT_NAME: &str = "SHA256_INIT_INLINE"; pub mod sdk; pub use sdk::*; -#[cfg(feature = "host")] -pub mod exec; #[cfg(feature = "host")] pub mod sequence_builder; +#[cfg(feature = "host")] +pub mod spec; #[cfg(feature = "host")] mod host; diff --git a/jolt-inlines/sha2/src/sdk.rs b/jolt-inlines/sha2/src/sdk.rs index ef661d40be..fabd51637b 100644 --- a/jolt-inlines/sha2/src/sdk.rs +++ b/jolt-inlines/sha2/src/sdk.rs @@ -356,11 +356,11 @@ pub(crate) unsafe fn sha256_compression(input: *const u32, state: *mut u32) { /// - The memory regions must not overlap #[cfg(feature = "host")] pub(crate) unsafe fn sha256_compression(input: *const u32, state: *mut u32) { - use crate::exec; - + use crate::sequence_builder::Sha256Compression; + use jolt_inlines_sdk::spec::InlineSpec; let input_array = *(input as *const [u32; 16]); let state_array = *(state as *const [u32; 8]); - let result = exec::execute_sha256_compression(state_array, input_array); + let result = Sha256Compression::reference(&(state_array, input_array)); std::ptr::copy_nonoverlapping(result.as_ptr(), state, 8) } @@ -417,10 +417,10 @@ pub(crate) unsafe fn sha256_compression_initial(input: *const u32, state: *mut u /// - The memory regions must not overlap #[cfg(feature = "host")] pub(crate) unsafe fn sha256_compression_initial(input: *const u32, state: *mut u32) { - use crate::exec; - - let input = *(input as *const [u32; 16]); - let result = exec::execute_sha256_compression_initial(input); + use crate::sequence_builder::Sha256CompressionInitial; + use jolt_inlines_sdk::spec::InlineSpec; + let input_array = *(input as *const [u32; 16]); + let result = Sha256CompressionInitial::reference(&input_array); std::ptr::copy_nonoverlapping(result.as_ptr(), state, 8) } diff --git a/jolt-inlines/sha2/src/sequence_builder.rs b/jolt-inlines/sha2/src/sequence_builder.rs index 5019a2a347..6c1633263a 100644 --- a/jolt-inlines/sha2/src/sequence_builder.rs +++ b/jolt-inlines/sha2/src/sequence_builder.rs @@ -349,6 +349,8 @@ impl InlineOp for Sha256Compression { const FUNCT7: u32 = crate::SHA256_FUNCT7; const NAME: &'static str = crate::SHA256_NAME; + type Advice = (); + fn build_sequence(asm: InstrAssembler, operands: FormatInline) -> Vec { Sha256SequenceBuilder::new(asm, operands, false).build() } @@ -362,6 +364,8 @@ impl InlineOp for Sha256CompressionInitial { const FUNCT7: u32 = crate::SHA256_INIT_FUNCT7; const NAME: &'static str = crate::SHA256_INIT_NAME; + type Advice = (); + fn build_sequence(asm: InstrAssembler, operands: FormatInline) -> Vec { Sha256SequenceBuilder::new(asm, operands, true).build() } diff --git a/jolt-inlines/sha2/src/spec.rs b/jolt-inlines/sha2/src/spec.rs new file mode 100644 index 0000000000..76443c53c9 --- /dev/null +++ b/jolt-inlines/sha2/src/spec.rs @@ -0,0 +1,123 @@ +use crate::sequence_builder::{Sha256Compression, Sha256CompressionInitial, BLOCK, K}; +use jolt_inlines_sdk::host::Xlen; +use jolt_inlines_sdk::spec::rand::rngs::StdRng; +use jolt_inlines_sdk::spec::rand::Rng; +use jolt_inlines_sdk::spec::{InlineMemoryLayout, InlineSpec, InlineTestHarness}; + +fn create_harness() -> InlineTestHarness { + let layout = InlineMemoryLayout::single_input(64, 32); + InlineTestHarness::new(layout, Xlen::Bit64) +} + +impl InlineSpec for Sha256Compression { + type Input = ([u32; 8], [u32; 16]); + type Output = [u32; 8]; + + fn random_input(rng: &mut StdRng) -> Self::Input { + ( + core::array::from_fn(|_| rng.gen()), + core::array::from_fn(|_| rng.gen()), + ) + } + + fn reference(input: &Self::Input) -> Self::Output { + let (initial_state, input) = input; + + let mut a = initial_state[0]; + let mut b = initial_state[1]; + let mut c = initial_state[2]; + let mut d = initial_state[3]; + let mut e = initial_state[4]; + let mut f = initial_state[5]; + let mut g = initial_state[6]; + let mut h = initial_state[7]; + + let mut w = [0u32; 64]; + + w[..16].copy_from_slice(input); + + for i in 16..64 { + let s0 = w[i - 15].rotate_right(7) ^ w[i - 15].rotate_right(18) ^ (w[i - 15] >> 3); + let s1 = w[i - 2].rotate_right(17) ^ w[i - 2].rotate_right(19) ^ (w[i - 2] >> 10); + w[i] = w[i - 16] + .wrapping_add(s0) + .wrapping_add(w[i - 7]) + .wrapping_add(s1); + } + + for i in 0..64 { + let ch = (e & f) ^ ((!e) & g); + let maj = (a & b) ^ (a & c) ^ (b & c); + + let sigma0 = a.rotate_right(2) ^ a.rotate_right(13) ^ a.rotate_right(22); + let sigma1 = e.rotate_right(6) ^ e.rotate_right(11) ^ e.rotate_right(25); + + let t1 = h + .wrapping_add(sigma1) + .wrapping_add(ch) + .wrapping_add(K[i] as u32) + .wrapping_add(w[i]); + let t2 = sigma0.wrapping_add(maj); + + h = g; + g = f; + f = e; + e = d.wrapping_add(t1); + d = c; + c = b; + b = a; + a = t1.wrapping_add(t2); + } + + [ + initial_state[0].wrapping_add(a), + initial_state[1].wrapping_add(b), + initial_state[2].wrapping_add(c), + initial_state[3].wrapping_add(d), + initial_state[4].wrapping_add(e), + initial_state[5].wrapping_add(f), + initial_state[6].wrapping_add(g), + initial_state[7].wrapping_add(h), + ] + } + + fn create_harness() -> InlineTestHarness { + create_harness() + } + + fn load(harness: &mut InlineTestHarness, input: &Self::Input) { + harness.setup_registers(); + harness.load_input32(&input.1); + harness.load_state32(&input.0); + } + + fn read(harness: &mut InlineTestHarness) -> Self::Output { + harness.read_output32(8).try_into().unwrap() + } +} + +impl InlineSpec for Sha256CompressionInitial { + type Input = [u32; 16]; + type Output = [u32; 8]; + + fn random_input(rng: &mut StdRng) -> Self::Input { + core::array::from_fn(|_| rng.gen()) + } + + fn reference(input: &Self::Input) -> Self::Output { + Sha256Compression::reference(&(BLOCK.map(|x| x as u32), *input)) + } + + fn create_harness() -> InlineTestHarness { + create_harness() + } + + fn load(harness: &mut InlineTestHarness, input: &Self::Input) { + harness.setup_registers(); + harness.load_input32(input); + } + + fn read(harness: &mut InlineTestHarness) -> Self::Output { + harness.read_output32(8).try_into().unwrap() + } +} diff --git a/jolt-inlines/sha2/src/tests.rs b/jolt-inlines/sha2/src/tests.rs index 561f41deae..ed31151d64 100644 --- a/jolt-inlines/sha2/src/tests.rs +++ b/jolt-inlines/sha2/src/tests.rs @@ -20,12 +20,11 @@ pub fn instruction_sha256init() -> tracer::instruction::inline::INLINE { } mod exec_functions { - use crate::exec::{execute_sha256_compression, execute_sha256_compression_initial}; - use crate::sequence_builder::BLOCK; + use crate::sequence_builder::{Sha256Compression, Sha256CompressionInitial, BLOCK}; + use jolt_inlines_sdk::spec::InlineSpec; #[test] fn test_exec_sha256_compression_function() { - // Test with standard test vectors let input = [ 0x61626380, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, @@ -34,9 +33,8 @@ mod exec_functions { let initial_state = BLOCK.map(|x| x as u32); - let result = execute_sha256_compression(initial_state, input); + let result = Sha256Compression::reference(&(initial_state, input)); - // Expected result for SHA-256("abc") let expected = [ 0xba7816bf, 0x8f01cfea, 0x414140de, 0x5dae2223, 0xb00361a3, 0x96177a9c, 0xb410ff61, 0xf20015ad, @@ -50,16 +48,14 @@ mod exec_functions { #[test] fn test_exec_sha256_compression_initial_function() { - // Test the initial compression function let input = [ 0x61626380, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000018, ]; - let result = execute_sha256_compression_initial(input); + let result = Sha256CompressionInitial::reference(&input); - // Expected result for SHA-256("abc") let expected = [ 0xba7816bf, 0x8f01cfea, 0x414140de, 0x5dae2223, 0xb00361a3, 0x96177a9c, 0xb410ff61, 0xf20015ad, @@ -73,26 +69,22 @@ mod exec_functions { #[test] fn test_exec_sha256_multi_block() { - // Test with a two-block message - // First block let input1 = [ 0x61626364, 0x62636465, 0x63646566, 0x64656667, 0x65666768, 0x66676869, 0x6768696a, 0x68696a6b, 0x696a6b6c, 0x6a6b6c6d, 0x6b6c6d6e, 0x6c6d6e6f, 0x6d6e6f70, 0x6e6f7071, 0x80000000, 0x00000000, ]; - let state1 = execute_sha256_compression_initial(input1); + let state1 = Sha256CompressionInitial::reference(&input1); - // Second block with padding and length let input2 = [ 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x000001c0, ]; - let result = execute_sha256_compression(state1, input2); + let result = Sha256Compression::reference(&(state1, input2)); - // Expected result for SHA-256("abcdbcdecdefdefgefghfghighijhijkijkljklmklmnlmnomnopnopq") let expected = [ 0x248d6a61, 0xd20638b8, 0xe5c02693, 0x0c3e6039, 0xa33ce459, 0x64ff2167, 0xf6ecedd4, 0x19db06c1,