diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 7aa166de..5715a7f5 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -20,7 +20,7 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@v4 - - uses: dtolnay/rust-toolchain@stable + - uses: actions-rust-lang/setup-rust-toolchain@v1 with: components: rustfmt - name: Check formatting @@ -31,20 +31,20 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@v4 - - uses: dtolnay/rust-toolchain@stable + - uses: actions-rust-lang/setup-rust-toolchain@v1 with: components: clippy - name: Clippy (all features) - run: cargo clippy -q --message-format=short --all-features --all-targets -- -D warnings + run: cargo clippy --all --all-targets --all-features -- -D warnings - name: Clippy (no default features) - run: cargo clippy -q --message-format=short --no-default-features --lib -- -D warnings + run: cargo clippy --all --all-targets --no-default-features -- -D warnings doc: name: Documentation runs-on: ubuntu-latest steps: - uses: actions/checkout@v4 - - uses: dtolnay/rust-toolchain@stable + - uses: actions-rust-lang/setup-rust-toolchain@v1 - name: Build documentation run: cargo doc -q --no-deps --all-features env: @@ -55,8 +55,8 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@v4 - - uses: dtolnay/rust-toolchain@stable + - uses: actions-rust-lang/setup-rust-toolchain@v1 - name: Install cargo-nextest uses: taiki-e/install-action@nextest - name: Run tests - run: cargo nextest run -q --all-features + run: cargo nextest run --all-features diff --git a/AGENTS.md b/AGENTS.md new file mode 100644 index 00000000..68f667fa --- /dev/null +++ b/AGENTS.md @@ -0,0 +1,35 @@ +# AGENTS.md + +**Compatibility notice (explicit): This repo makes NO backward-compatibility guarantees. Breaking changes are allowed and expected.** + +## Project Overview + +Hachi is a lattice-based polynomial commitment scheme (PCS) with transparent setup and post-quantum security. Built in Rust. Intended to replace Dory in Jolt. + +## Essential Commands + +```bash +cargo clippy --all --message-format=short -q -- -D warnings +cargo fmt -q +cargo test # no nextest yet +``` + +## Crate Structure + +Two workspace members: `hachi-pcs` (root) and `derive` (proc macros). + +- `src/primitives/` — Core traits: `FieldCore`, `Module`, `MultilinearLagrange`, `Transcript`, serialization +- `src/algebra/` — Concrete backends: prime fields, extension fields, cyclotomic rings, NTT, domains +- `src/protocol/` — Protocol layer: commitment, prover, verifier, opening (ring-switch), challenges, transcript +- `src/error.rs` — Error types + +## Key Abstractions + +- `CommitmentScheme` / `StreamingCommitmentScheme` — top-level PCS traits +- `FieldCore` + `PseudoMersenneField` + `Module` — arithmetic over lattice-friendly fields and rings +- `MultilinearLagrange` — multilinear polynomial in Lagrange basis +- `Transcript` — Fiat-Shamir + +## Feature Flags + +- `parallel` — Rayon parallelization diff --git a/CLAUDE.md b/CLAUDE.md new file mode 120000 index 00000000..47dc3e3d --- /dev/null +++ b/CLAUDE.md @@ -0,0 +1 @@ +AGENTS.md \ No newline at end of file diff --git a/CONSTANT_TIME_NOTES.md b/CONSTANT_TIME_NOTES.md new file mode 100644 index 00000000..dac24600 --- /dev/null +++ b/CONSTANT_TIME_NOTES.md @@ -0,0 +1,42 @@ +# Constant-Time Review Notes (Phase 0/1 Algebra) + +This note tracks timing-sensitive implementation decisions for the current +algebra and ring stack. + +## Reviewed Components + +- `src/algebra/fields/fp32.rs` +- `src/algebra/fields/fp64.rs` +- `src/algebra/fields/fp128.rs` +- `src/algebra/ntt/prime.rs` +- `src/algebra/ntt/butterfly.rs` +- `src/algebra/ring/cyclotomic.rs` +- `src/algebra/ring/crt_ntt_repr.rs` + +## Current State + +- Branchless primitives are in place for: + - `Fp32/Fp64/Fp128` add/sub/neg raw helpers. + - `Fp128` multiplication reduction (`reduce_u256`) with branchless conditional subtract. + - `Fp32/Fp64` multiplication reduction (division-free fixed-iteration paths). + - NTT helper operations `csubp`, `caddp`, and `center`. +- NTT butterfly arithmetic runs in fixed loop structure independent of data. +- Ring multiplication (`CyclotomicRing`) is fixed-structure schoolbook over `D`. +- CRT reconstruction inner accumulation now uses fixed-trip, branchless + modular add/mul-by-small-factor helpers. +- Prime fields now expose `Invertible::inv_or_zero()` for secret-bearing + inversion use-cases without input-dependent branching on zero. +- CRT reconstruction final projection now uses a division-free fixed-iteration + reducer (`reduce_u128_divfree`) instead of `% q`. + +## Known Timing Risks / Follow-ups + +- `FieldCore::inv()` still returns `Option` and therefore branches on zero; + treat that API as public-value oriented. Use `Invertible::inv_or_zero()` + in secret-dependent paths. + +## Action Items Before Production-Critical Use + +1. Wire secret-bearing call sites to `Invertible::inv_or_zero()` as + protocol code matures. +2. Add dedicated CT review tests/checklists for any arithmetic subsystem changes. diff --git a/Cargo.lock b/Cargo.lock index a505ce5c..e184f5cc 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2,6 +2,18 @@ # It is not intended for manual editing. version = 3 +[[package]] +name = "ahash" +version = "0.8.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5a15f179cd60c4584b8a8c596927aadc462e27f2ca70c04e0071964a73ba7a75" +dependencies = [ + "cfg-if", + "once_cell", + "version_check", + "zerocopy", +] + [[package]] name = "aho-corasick" version = "1.1.4" @@ -11,6 +23,33 @@ dependencies = [ "memchr", ] +[[package]] +name = "allocative" +version = "0.3.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8fac2ce611db8b8cee9b2aa886ca03c924e9da5e5295d0dbd0526e5d0b0710f7" +dependencies = [ + "allocative_derive", + "ctor", +] + +[[package]] +name = "allocative_derive" +version = "0.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fe233a377643e0fc1a56421d7c90acdec45c291b30345eb9f08e8d0ddce5a4ab" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.110", +] + +[[package]] +name = "allocator-api2" +version = "0.2.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "683d7910e743518b0e34f1186f92494becacb047c7b6bf616c96772180fef923" + [[package]] name = "anes" version = "0.1.6" @@ -23,12 +62,154 @@ version = "1.0.13" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5192cca8006f1fd4f7237516f40fa183bb07f8fbdfedaa0036de5ea9b0b45e78" +[[package]] +name = "ark-bn254" +version = "0.5.0" +source = "git+https://github.com/a16z/arkworks-algebra?branch=dev/twist-shout#76bb3a4518928f1ff7f15875f940d614bb9845e6" +dependencies = [ + "ark-ec", + "ark-ff", + "ark-serialize", + "ark-std", +] + +[[package]] +name = "ark-ec" +version = "0.5.0" +source = "git+https://github.com/a16z/arkworks-algebra?branch=dev/twist-shout#76bb3a4518928f1ff7f15875f940d614bb9845e6" +dependencies = [ + "ahash", + "ark-ff", + "ark-poly", + "ark-serialize", + "ark-std", + "educe", + "fnv", + "hashbrown", + "itertools 0.13.0", + "num-bigint", + "num-integer", + "num-traits", + "zeroize", +] + +[[package]] +name = "ark-ff" +version = "0.5.0" +source = "git+https://github.com/a16z/arkworks-algebra?branch=dev/twist-shout#76bb3a4518928f1ff7f15875f940d614bb9845e6" +dependencies = [ + "allocative", + "ark-ff-asm", + "ark-ff-macros", + "ark-serialize", + "ark-std", + "arrayvec", + "digest", + "educe", + "itertools 0.13.0", + "num-bigint", + "num-traits", + "paste", + "zeroize", +] + +[[package]] +name = "ark-ff-asm" +version = "0.5.0" +source = "git+https://github.com/a16z/arkworks-algebra?branch=dev/twist-shout#76bb3a4518928f1ff7f15875f940d614bb9845e6" +dependencies = [ + "quote", + "syn 2.0.110", +] + +[[package]] +name = "ark-ff-macros" +version = "0.5.0" +source = "git+https://github.com/a16z/arkworks-algebra?branch=dev/twist-shout#76bb3a4518928f1ff7f15875f940d614bb9845e6" +dependencies = [ + "num-bigint", + "num-traits", + "proc-macro2", + "quote", + "syn 2.0.110", +] + +[[package]] +name = "ark-poly" +version = "0.5.0" +source = "git+https://github.com/a16z/arkworks-algebra?branch=dev/twist-shout#76bb3a4518928f1ff7f15875f940d614bb9845e6" +dependencies = [ + "ahash", + "ark-ff", + "ark-serialize", + "ark-std", + "educe", + "fnv", + "hashbrown", +] + +[[package]] +name = "ark-serialize" +version = "0.5.0" +source = "git+https://github.com/a16z/arkworks-algebra?branch=dev/twist-shout#76bb3a4518928f1ff7f15875f940d614bb9845e6" +dependencies = [ + "ark-serialize-derive", + "ark-std", + "arrayvec", + "digest", + "num-bigint", +] + +[[package]] +name = "ark-serialize-derive" +version = "0.5.0" +source = "git+https://github.com/a16z/arkworks-algebra?branch=dev/twist-shout#76bb3a4518928f1ff7f15875f940d614bb9845e6" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.110", +] + +[[package]] +name = "ark-std" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "246a225cc6131e9ee4f24619af0f19d67761fff15d7ccc22e42b80846e69449a" +dependencies = [ + "num-traits", + "rand", +] + +[[package]] +name = "arrayvec" +version = "0.7.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7c02d123df017efcdfbd739ef81735b36c5ba83ec3c59c80a9d7ecc718f92e50" + [[package]] name = "autocfg" version = "1.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c08606f8c3cbf4ce6ec8e28fb0014a2c086708fe954eaa885384a6165172e7e8" +[[package]] +name = "blake2" +version = "0.10.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "46502ad458c9a52b69d4d4d32775c788b7a1b85e8bc9d482d92250fc0e3f8efe" +dependencies = [ + "digest", +] + +[[package]] +name = "block-buffer" +version = "0.10.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3078c7629b62d3f0439517fa394996acacc5cbc91c5a20d8c658e77abd503a71" +dependencies = [ + "generic-array", +] + [[package]] name = "bumpalo" version = "3.19.0" @@ -99,6 +280,15 @@ version = "0.7.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a1d728cc89cf3aee9ff92b05e62b19ee65a02b5702cff7d5a377e32c6ae29d8d" +[[package]] +name = "cpufeatures" +version = "0.2.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "59ed5838eebb26a2bb2e58f6d5b5316989ae9d08bab10e0e6d103e656d1b0280" +dependencies = [ + "libc", +] + [[package]] name = "criterion" version = "0.5.1" @@ -111,7 +301,7 @@ dependencies = [ "clap", "criterion-plot", "is-terminal", - "itertools", + "itertools 0.10.5", "num-traits", "once_cell", "oorandom", @@ -132,7 +322,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6b50826342786a51a89e2da3a28f1c32b06e387201bc2d19791f622c673706b1" dependencies = [ "cast", - "itertools", + "itertools 0.10.5", ] [[package]] @@ -166,12 +356,91 @@ version = "0.2.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "460fbee9c2c2f33933d720630a6a0bac33ba7053db5344fac858d4b8952d77d5" +[[package]] +name = "crypto-common" +version = "0.1.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "78c8292055d1c1df0cce5d180393dc8cce0abec0a7102adb6c7b1eef6016d60a" +dependencies = [ + "generic-array", + "typenum", +] + +[[package]] +name = "ctor" +version = "0.1.26" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6d2301688392eb071b0bf1a37be05c469d3cc4dbbd95df672fe28ab021e6a096" +dependencies = [ + "quote", + "syn 1.0.109", +] + +[[package]] +name = "digest" +version = "0.10.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9ed9a281f7bc9b7576e61468ba615a66a5c8cfdff42420a70aa82701a3b1e292" +dependencies = [ + "block-buffer", + "crypto-common", + "subtle", +] + +[[package]] +name = "educe" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1d7bc049e1bd8cdeb31b68bbd586a9464ecf9f3944af3958a7a9d0f8b9799417" +dependencies = [ + "enum-ordinalize", + "proc-macro2", + "quote", + "syn 2.0.110", +] + [[package]] name = "either" version = "1.15.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "48c757948c5ede0e46177b7add2e67155f70e33c07fea8284df6576da70b3719" +[[package]] +name = "enum-ordinalize" +version = "4.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4a1091a7bb1f8f2c4b28f1fe2cef4980ca2d410a3d727d67ecc3178c9b0800f0" +dependencies = [ + "enum-ordinalize-derive", +] + +[[package]] +name = "enum-ordinalize-derive" +version = "4.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8ca9601fb2d62598ee17836250842873a413586e5d7ed88b356e38ddbb0ec631" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.110", +] + +[[package]] +name = "fnv" +version = "1.0.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1" + +[[package]] +name = "generic-array" +version = "0.14.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "85649ca51fd72272d7821adaf274ad91c288277713d9c18820d8499a7ff69e9a" +dependencies = [ + "typenum", + "version_check", +] + [[package]] name = "getrandom" version = "0.2.16" @@ -189,18 +458,23 @@ version = "0.1.0" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.110", ] [[package]] name = "hachi-pcs" version = "0.1.0" dependencies = [ + "ark-bn254", + "ark-ff", + "blake2", "criterion", "hachi-derive", + "num-bigint", "rand", "rand_core", "rayon", + "sha3", "thiserror", "tracing", ] @@ -216,6 +490,15 @@ dependencies = [ "zerocopy", ] +[[package]] +name = "hashbrown" +version = "0.15.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9229cfe53dfd69f0609a49f65461bd93001ea1ef889cd5529dd176593f5338a1" +dependencies = [ + "allocator-api2", +] + [[package]] name = "hermit-abi" version = "0.5.2" @@ -242,6 +525,15 @@ dependencies = [ "either", ] +[[package]] +name = "itertools" +version = "0.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "413ee7dfc52ee1a4949ceeb7dbc8a33f2d6c088194d9f922fb8318faf1f01186" +dependencies = [ + "either", +] + [[package]] name = "itoa" version = "1.0.15" @@ -258,6 +550,15 @@ dependencies = [ "wasm-bindgen", ] +[[package]] +name = "keccak" +version = "0.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cb26cec98cce3a3d96cbb7bced3c4b16e3d13f27ec56dbd62cbc8f39cfb9d653" +dependencies = [ + "cpufeatures", +] + [[package]] name = "libc" version = "0.2.177" @@ -270,6 +571,25 @@ version = "2.7.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f52b00d39961fc5b2736ea853c9cc86238e165017a493d1d5c8eac6bdc4cc273" +[[package]] +name = "num-bigint" +version = "0.4.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a5e44f723f1133c9deac646763579fdb3ac745e418f2a7af9cd0c431da1f20b9" +dependencies = [ + "num-integer", + "num-traits", +] + +[[package]] +name = "num-integer" +version = "0.1.46" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7969661fd2958a5cb096e56c8e1ad0444ac2bbcd0061bd28660485a44879858f" +dependencies = [ + "num-traits", +] + [[package]] name = "num-traits" version = "0.2.19" @@ -291,6 +611,12 @@ version = "11.1.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d6790f58c7ff633d8771f42965289203411a5e5c68388703c06e14f24770b41e" +[[package]] +name = "paste" +version = "1.0.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "57c0d7b74b563b49d38dae00a0c37d4d6de9b432382b2892f0574ddcae73fd0a" + [[package]] name = "pin-project-lite" version = "0.2.16" @@ -479,7 +805,7 @@ checksum = "d540f220d3187173da220f885ab66608367b6574e925011a9353e4badda91d79" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.110", ] [[package]] @@ -495,6 +821,33 @@ dependencies = [ "serde_core", ] +[[package]] +name = "sha3" +version = "0.10.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "75872d278a8f37ef87fa0ddbda7802605cb18344497949862c0d4dcb291eba60" +dependencies = [ + "digest", + "keccak", +] + +[[package]] +name = "subtle" +version = "2.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "13c2bddecc57b384dee18652358fb23172facb8a2c51ccc10d74c157bdea3292" + +[[package]] +name = "syn" +version = "1.0.109" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "72b64191b275b66ffe2469e8af2c1cfe3bafa67b529ead792a6d0160888b4237" +dependencies = [ + "proc-macro2", + "quote", + "unicode-ident", +] + [[package]] name = "syn" version = "2.0.110" @@ -523,7 +876,7 @@ checksum = "3ff15c8ecd7de3849db632e14d18d2571fa09dfc5ed93479bc4485c7a517c913" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.110", ] [[package]] @@ -555,7 +908,7 @@ checksum = "81383ab64e72a7a8b8e13130c49e3dab29def6d0c7d76a03087b3cf71c5c6903" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.110", ] [[package]] @@ -567,12 +920,24 @@ dependencies = [ "once_cell", ] +[[package]] +name = "typenum" +version = "1.19.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "562d481066bde0658276a35467c4af00bdc6ee726305698a55b86e61d7ad82bb" + [[package]] name = "unicode-ident" version = "1.0.22" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9312f7c4f6ff9069b165498234ce8be658059c6728633667c526e27dc2cf1df5" +[[package]] +name = "version_check" +version = "0.9.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0b928f33d975fc6ad9f86c8f283853ad26bdd5b10b7f1542aa2fa15e2289105a" + [[package]] name = "walkdir" version = "2.5.0" @@ -621,7 +986,7 @@ dependencies = [ "bumpalo", "proc-macro2", "quote", - "syn", + "syn 2.0.110", "wasm-bindgen-shared", ] @@ -685,5 +1050,25 @@ checksum = "88d2b8d9c68ad2b9e4340d7832716a4d21a22a1154777ad56ea55c51a9cf3831" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.110", +] + +[[package]] +name = "zeroize" +version = "1.8.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b97154e67e32c85465826e8bcc1c59429aaaf107c1e4a9e53c8d8ccd5eff88d0" +dependencies = [ + "zeroize_derive", +] + +[[package]] +name = "zeroize_derive" +version = "1.4.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "85a5b4158499876c763cb03bc4e49185d3cccbabb15b33c627f7884f43db852e" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.110", ] diff --git a/Cargo.toml b/Cargo.toml index 47abb4d8..ab91eead 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -6,7 +6,7 @@ resolver = "2" name = "hachi-pcs" version = "0.1.0" edition = "2021" -rust-version = "1.75" +rust-version = "1.88" authors = [ "Markos Georghiades ", ] @@ -32,19 +32,44 @@ include = [ all-features = true [features] -default = [] +default = ["parallel"] parallel = ["dep:rayon"] [dependencies] thiserror = "2.0" -rand_core = "0.6" +rand_core = { version = "0.6", features = ["getrandom"] } hachi-derive = { version = "0.1.0", path = "derive" } tracing = "0.1" rayon = { version = "1.10", optional = true } +blake2 = "0.10.6" +sha3 = "0.10.8" [dev-dependencies] rand = "0.8" criterion = { version = "0.5", features = ["html_reports"] } +num-bigint = "0.4.6" +ark-bn254 = { git = "https://github.com/a16z/arkworks-algebra", branch = "dev/twist-shout", features = ["scalar_field"] } +ark-ff = { git = "https://github.com/a16z/arkworks-algebra", branch = "dev/twist-shout" } + +[[bench]] +name = "ring_ntt" +harness = false + +[[bench]] +name = "field_arith" +harness = false + +[[bench]] +name = "fp64_reduce_probe" +harness = false + +[[bench]] +name = "hachi_e2e" +harness = false + +[[bench]] +name = "norm_sumcheck" +harness = false [lints.rust] missing_docs = "warn" diff --git a/HACHI_PROGRESS.md b/HACHI_PROGRESS.md new file mode 100644 index 00000000..afca63e6 --- /dev/null +++ b/HACHI_PROGRESS.md @@ -0,0 +1,200 @@ +## Hachi PCS implementation progress + +This file is the **single source of truth** for implementation status and near-term priorities. + +### Goals (project-level) + +- **Production-ready implementation**: correctness, security, maintainability, and performance are first-class goals. +- **Standalone codebase**: implementation and comments should stand on their own; external acknowledgements live in `README.md`. +- **Constant-time cryptographic core**: arithmetic and protocol-critical paths must be constant-time with respect to secret data. +- **No shortcuts / no fallback design**: avoid temporary or degraded code paths in the core implementation. + +### Non-negotiable requirements + +- **Constant-time discipline** + - No secret-dependent branches or memory access patterns in cryptographic hot paths. + - No secret-indexed table lookups; table access patterns must be independent of secret data. + - Keep data representations and reductions explicit and auditable for timing behavior. + - Add targeted tests/reviews for constant-time-sensitive code as features land. +- **Code quality bar** + - Clear naming, explicit invariants, small cohesive modules, and API docs for public interfaces. + - No placeholder crypto logic in mainline code (no "temporary" arithmetic shortcuts). + - Tests are required for correctness-critical arithmetic before dependent protocol code is built. + - No section-banner comments (e.g., `// ---- Section ----`, `// === ... ===`). Let the code and doc-comments speak for themselves. +- **Standalone implementation policy** + - Do not mention external inspirations/ports in core code comments. + - Keep terminology and structure internally coherent and project-native. + - Keep external attribution limited to dedicated docs (for now: `README.md` acknowledgements). +- **Git discipline** + - Do not commit or push without explicit user approval. + +### Implementation workflow (cautious + approval-driven) + +- Before each major subsystem, present implementation options with trade-offs. +- Seek explicit approval before proceeding with a selected option. +- Pause at milestone boundaries for review and feedback before continuing. +- Prefer slow, verifiable progress over rapid, high-risk changes. +- Ask for user input frequently when requirements are ambiguous or involve design trade-offs. + +### Definition of Done (all crypto-critical work) + +- **Security / constant-time** + - Secret-independent control flow and memory access in cryptographic paths. + - Constant-time review notes included for non-trivial arithmetic/ring changes. +- **Correctness** + - Unit tests for edge cases and algebraic identities. + - Cross-check vectors/reference checks added where practical. +- **Code quality** + - Clear naming, explicit invariants, and no placeholder logic in core paths. + - Public interfaces documented sufficiently for safe usage. +- **Performance** + - Hot-path performance impact evaluated (benchmark or measured rationale). +- **Tooling + CI** + - `cargo fmt --all --check` passes. + - `cargo clippy --all --all-targets --all-features` passes. + - `cargo test` (or targeted suite for touched modules) passes. +- **Process** + - Implementation options reviewed with user before major subsystem changes. + - Milestone update recorded in this file. + +### Scope (current) + +- **Implemented so far (Phase 0 + Phase 1 functional core)**: prime fields (32/64/128-bit representations), extension fields, cyclotomic `R_q = Z_q[X]/(X^d + 1)`, CRT+NTT representation, backend/domain layering, ring automorphisms, and functional gadget decomposition. +- **Phase 2+ protocol status**: interface scaffold plus ring-native §4.1 commitment core are present (`Transcript`, Blake2b/Keccak backends, phase-grounded labels, `RingCommitmentScheme`, config layer, and setup/commit implementation). Sumcheck core building blocks (univariate messages + transcript-driving prover/verifier driver) are now implemented, with tests. Open-check prover/verifier paths remain stubbed. +- **Deferred future phase**: integration into Jolt (replacement of Dory with Hachi) is intentionally out of current execution scope; cross-repo analysis is design input only. + +### Critical review snapshot (2026-02-13) + +- **Phase 1 functional milestone appears complete** + - Ring/gadget components listed in Phase 1 are implemented and currently checked off. + - Conversion and arithmetic paths in coefficient and CRT+NTT domains are exercised by passing tests. +- **Not yet "production-ready" despite functional completion** + - Constant-time hardening follow-ups narrowed: secret-bearing call-sites still need to migrate from `FieldCore::inv()` to `Invertible::inv_or_zero()` as protocol code lands (see `CONSTANT_TIME_NOTES.md`). + - Current ring multiplication in coefficient form remains `O(D^2)` schoolbook (`src/algebra/ring/cyclotomic.rs`), with CRT+NTT available as the faster domain path. +- **Tooling/quality gate status (current branch snapshot)** + - `cargo test` passes, including protocol transcript/label/commitment contract tests and new ring-commitment core/config/stub tests. + - `cargo fmt --all --check` passes. + - `cargo clippy --all --all-targets --all-features` passes. +- **Phase 2 scaffold + commitment core landed; proof-system work still pending** + - `src/protocol/*` now provides transcript + commitment abstraction boundaries with `Transcript` naming. + - Two transcript backends are wired (`Blake2bTranscript`, `KeccakTranscript`) with deterministic replay/order/reset tests. + - Hachi-native labels are now calibrated to paper-stage phases (§4.1, §4.2, §4.3, §4.5). + - Commitment absorption is label-directed at call sites (`AppendToTranscript` no longer hardcodes commitment labels). + - Ring-native commitment setup/commit flow for §4.1 is implemented in `src/protocol/commitment/commit.rs` behind `RingCommitmentScheme`. + - Sumcheck core module landed (`src/protocol/sumcheck.rs`) with unit/integration tests (`tests/sumcheck_core.rs`, `tests/sumcheck_prover_driver.rs`). + - Prover/verifier split folders are wired with explicit stubs (`src/protocol/prover/stub.rs`, `src/protocol/verifier/stub.rs`) for future open-check implementation. +- **Conclusion** + - Treat **Phase 1 as functionally complete**. + - Treat **Phase 2 as active/in-progress** (commitment core implemented; prove/verify and later reductions still open). + - Remaining strict CT follow-ups stay tracked in `CONSTANT_TIME_NOTES.md`. + +### Status board + +#### Phase 0 — Algebra + +- [x] Prime field `Fp32` (u32 storage; u64 mul) implementing `FieldCore + CanonicalField` (`src/algebra/fields/fp32.rs`) +- [x] Prime field `Fp64` (u64 storage; u128 mul) implementing `FieldCore + CanonicalField` (`src/algebra/fields/fp64.rs`) +- [x] Prime field `Fp128` (u128 storage; 256-bit intermediate) implementing `FieldCore + CanonicalField` (`src/algebra/fields/fp128.rs`, `src/algebra/fields/u256.rs`) +- [x] Branchless constant-time `add_raw`, `sub_raw`, `neg` for all field types +- [x] Constant-time inversion helper for prime fields: `Invertible::inv_or_zero()` (`src/primitives/arithmetic.rs`, `src/algebra/fields/fp*.rs`) +- [x] Division-free fixed-iteration reduction for `Fp32/Fp64` multiplication paths +- [x] Division-free fixed-iteration CRT final projection (replaced `% q` in scalar reconstruction path) +- [x] Rejection-sampled `FieldSampling::sample()` for all field types (no modular bias) +- [x] Pow2Offset pseudo-Mersenne registry + aliases (`q = 2^k - offset`, bounded `k <= 128`, `q % 8 == 5`) (`src/algebra/fields/pseudo_mersenne.rs`) +- [x] Constant-time review notes for current algebra/ring paths (`CONSTANT_TIME_NOTES.md`) +- [x] Deterministic parameter presets + - [x] `q = 2^32 - 99` constants scaffold (`src/algebra/ntt/tables.rs`) + - [x] `Pow2Offset` presets selected for 64/128-bit path: + - `q = 2^64 - 59` (`POW2_OFFSET_MODULUS_64`) + - `q = 2^128 - 275` (`POW2_OFFSET_MODULUS_128`) + - source: `src/algebra/fields/pseudo_mersenne.rs` +- [x] `Module` implementations: + - [x] `VectorModule` (fixed-length vectors; `Module` via scalar*vector mul) (`src/algebra/module.rs`) + - [x] `PolyModule` removed from current scope (not needed for near-term Hachi milestones) +- [ ] Extension fields: + - [x] `Fp2` quadratic extension (`src/algebra/fields/ext.rs`) + - [x] `Fp4` tower extension (`src/algebra/fields/ext.rs`) +- [x] Serialization for algebra types (`HachiSerialize` / `HachiDeserialize`) (+ `u128/i128` primitives in `src/primitives/serialization.rs`) +- [x] NTT small-prime arithmetic: Montgomery-like `fpmul`, Barrett-like `fpred`, branchless `csubq`/`caddq`/`center` (`src/algebra/ntt/prime.rs`) +- [x] CRT limb arithmetic: `LimbQ`, `QData` (`src/algebra/ntt/crt.rs`) +- [x] Tests (49 total in `tests/algebra.rs`): + - [x] field arithmetic, identities, distributivity (Fp32/Fp64/Fp128) + - [x] zero inversion returns None + - [x] serialization round-trips (all field types, extensions, Poly, VectorModule) + - [x] Fp2 conjugate, norm, distributivity + - [x] U256 wide multiply and bit access + - [x] LimbQ round-trip, add/sub inverse, QData consistency + - [x] NTT normalize range, fpmul commutativity + - [x] Poly add/sub/neg + - [x] Cyclotomic ring identities and serialization (D=4, D=64) + - [x] NTT forward/inverse round-trips (single prime and all Q32 primes) + - [x] Cyclotomic CRT+NTT full round-trip (`from_ring` -> `to_ring`) + - [x] Scalar backend path equivalence (`*_with_backend` vs default path) + - [x] Pow2Offset profile invariants (`q = 2^k - offset`, `q % 8 == 5`) + - [x] `FieldSampling::sample()` output bound checks + - [x] Checked deserialization rejects non-canonical field encodings + - [x] Galois automorphism checks (`sigma` composition + multiplicativity) + - [x] Functional gadget decompose/recompose round-trip checks + - [x] Sparse `+/-1` challenge support checks (`hamming_weight = omega`) +- [x] Dedicated Pow2Offset primality regression tests (`tests/primality.rs`) + - [x] Miller-Rabin probable-prime checks for all registered Pow2Offset moduli + - [x] Composite sanity rejection checks + +#### Phase 1 — Ring + gadgets (functional core) + +- [x] Cyclotomic ring `Rq` with `X^D = -1` (`src/algebra/ring/cyclotomic.rs`) +- [x] CRT+NTT-domain ring representation + CRT conversion (`src/algebra/ring/crt_ntt_repr.rs`) +- [x] Backend/domain layering for ring execution (`src/algebra/backend/*`, `src/algebra/domains/*`) +- [x] Galois automorphisms `sigma_i: X ↦ X^i` (odd `i`) +- [x] Functional gadget decomposition/recomposition (`G^{-1}` / `G` behavior) for base-`2^d` digits, without materializing dense gadget matrices +- [x] sparse short challenges (paper: `||c||_1 ≤ ω`, sparse ±1) + +#### Phase 2+ — Protocol (later) + +- [x] Protocol module scaffold (`src/protocol/*`) and top-level re-exports +- [x] Transcript interface (`Transcript`) plus Blake2b/Keccak implementations +- [x] Hachi-native transcript label schedule aligned to paper phases (§4.1/§4.2/§4.3/§4.5) +- [x] Commitment trait surface + streaming trait surface + contract tests +- [x] Label-directed transcript absorption for commitments (`AppendToTranscript` takes label at call site) +- [x] ring-native commitment core (`RingCommitmentScheme`, `commit.rs`, config wiring) for §4.1 setup/commit +- [x] protocol prover/verifier folder split with explicit stubs (`prover/stub.rs`, `verifier/stub.rs`) +- [x] ring-commitment tests (`ring_commitment_core`, `ring_commitment_config`, `prover_verifier_stub_contract`) +- [x] sumcheck core building blocks (univariate messages + transcript-driving prover/verifier driver) (`src/protocol/sumcheck.rs`) +- [x] sumcheck core tests (`tests/sumcheck_core.rs`, `tests/sumcheck_prover_driver.rs`) +- [ ] commitment open-check prove/verify implementation (currently stubs) +- [ ] evaluation → linear relation (paper §4.2) +- [ ] ring-switching constraints as sumcheck instances (paper §4.3, Fig. 4–7) +- [ ] recursion / “stop condition” + optional Greyhound composition (§4.5) + +#### Phase 3 — Integration into Jolt (deferred; not active now) + +- [ ] Define compatibility boundary document (what must match Jolt/Dory behavior vs what can remain Hachi-native) +- [ ] Provide Jolt-facing transcript adapter design (`Jolt` transcript pattern ↔ Hachi transcript object) +- [ ] Provide Jolt-facing PCS shim design (`CommitmentScheme`/`StreamingCommitmentScheme` mapping) +- [ ] Add transcript/commitment compatibility tests for integration-readiness (without wiring into Jolt yet) + +### Conventions + +- **Correctness first**: lock arithmetic with tests before touching protocol code. +- **Security first**: enforce constant-time behavior for secret-dependent operations. +- **Lean deps**: avoid heavyweight crypto crates until there is a clear need. +- **Explicit parameter sets**: each field/ring preset lives in code with a clear name and rationale. + +### Module layout + +``` +src/algebra/ +├── backend/ Backend execution traits + scalar backend +├── domains/ Domain-level aliases (coefficient / CRT+NTT) +├── fields/ Prime fields, pseudo-mersenne registry, u256, and extensions +├── ntt/ NTT kernels (butterfly), prime kernels (prime), CRT helpers (crt), presets (tables) +├── module.rs VectorModule +├── poly.rs Poly container +└── ring/ Cyclotomic ring and CRT+NTT representation +``` + +### References + +- Hachi paper: `paper/hachi.pdf` +- Core traits: `src/primitives/arithmetic.rs`, `src/primitives/serialization.rs` + diff --git a/NTT_PRIME_ANALYSIS.md b/NTT_PRIME_ANALYSIS.md new file mode 100644 index 00000000..e6b9c85c --- /dev/null +++ b/NTT_PRIME_ANALYSIS.md @@ -0,0 +1,146 @@ +# NTT Prime Analysis (Pow2Offset / Solinas Context) + +This note records the current analysis for small NTT primes and CRT coverage targets. + +## References + +- NIST ML-KEM: `paper/standards/NIST.FIPS.203.pdf` +- NIST ML-DSA: `paper/standards/NIST.FIPS.204.pdf` +- Current small-prime table: `src/algebra/ntt/tables.rs` +- Labrador generator heuristic: `../labrador/data.py` + +## Why does `2D` divide `p - 1`? + +For negacyclic NTT on `Z_p[X]/(X^D + 1)`, we need a primitive `2D`-th root `psi` such that: + +- `psi^D = -1 (mod p)` +- `psi^(2D) = 1 (mod p)` + +Over prime fields, `F_p^*` is cyclic of size `p - 1`, so an element of order `2D` exists iff: + +- `2D | (p - 1)` + +So yes, the `128 | (p - 1)` condition is directly tied to `D = 64`. + +## What if `D = 1024`? + +Then requirement becomes: + +- `2D = 2048`, so `2048 | (p - 1)`. + +Under the current "small prime" cap (`p < 2^14`), this is extremely restrictive. + +## Why `< 2^14` in current code? + +This is a backend implementation constraint, not a hard NTT math requirement: + +- Current small-prime NTT backend stores modulus/coefficients in signed 16-bit lanes (`i16`). +- It relies on centered signed arithmetic and butterfly add/sub before full normalization. +- Keeping `p < 2^14` leaves practical headroom in those 16-bit operations. +- Current CRT limb code is also radix-`2^14`, matching this design style. + +So the `2^14` cap is about the present `i16` scalar kernel design. If we introduce an `i32` backend, this cap can be raised substantially. + +## Exhaustive counts (for `p < 2^14`) + +We classify exact Solinas as: + +- `p = 2^x - 2^y + 1`. + +Results: + +- `D = 64` (`128 | p-1`) + - all small NTT primes: **31** + - exact Solinas NTT primes: **6** + - all-prime set: + - `257, 641, 769, 1153, 1409, 2689, 3329, 3457, 4481, 4993, 6529, 7297, 7681, 7937, 9473, 9601, 9857, 10369, 10753, 11393, 11777, 12161, 12289, 13313, 13441, 13697, 14081, 14593, 15233, 15361, 16001` + - Solinas set: `257, 769, 7681, 7937, 12289, 15361` +- `D = 256` (`512 | p-1`) + - all small NTT primes: **6** + - exact Solinas NTT primes: **3** + - Solinas set: `7681, 12289, 15361` +- `D = 1024` (`2048 | p-1`) + - all small NTT primes: **1** + - exact Solinas NTT primes: **1** + - Solinas set: `12289` + +Conclusion: for higher `D`, the small-prime pool shrinks rapidly. + +## 30-bit exploration (`p < 2^30`) with NTT constraints + +To assess a larger-prime backend direction, we scanned for primes under `2^30` with: + +- `p ≡ 1 (mod 2D)`. + +Below are the **full outputs of the bounded search run** (top 30 largest primes found by descending scan): + +### `D = 64` (`2D = 128`) + +- Top-30 list: + - `1073741441, 1073739649, 1073738753, 1073736449, 1073735297, 1073734913, 1073732993, 1073732609, 1073731201, 1073731073, 1073730817, 1073728897, 1073727617, 1073726977, 1073722753, 1073719681, 1073717377, 1073716993, 1073713409, 1073712769, 1073712257, 1073710721, 1073708929, 1073707009, 1073703809, 1073702657, 1073702401, 1073698817, 1073696257, 1073693441` +- Coverage for `q = 2^128 - 275`: + - `P > q`: **5** limbs + - `P > 128*q^2`: **9** limbs + +### `D = 1024` (`2D = 2048`) + +- Top-30 list: + - `1073707009, 1073698817, 1073692673, 1073682433, 1073668097, 1073655809, 1073651713, 1073643521, 1073620993, 1073600513, 1073569793, 1073563649, 1073551361, 1073539073, 1073522689, 1073510401, 1073508353, 1073479681, 1073453057, 1073442817, 1073440769, 1073430529, 1073412097, 1073391617, 1073385473, 1073354753, 1073350657, 1073330177, 1073299457, 1073268737` +- Coverage for `q = 2^128 - 275`: + - `P > q`: **5** limbs + - `P > 128*q^2`: **9** limbs + +### Bit-estimate sanity check + +- `ceil(128 / 30) = 5` +- `ceil(263 / 30) = 9` (for `128*q^2 ~ 2^263`) + +This matches the concrete product counts above. + +## CRT size targets for `q = 2^128 - 275` + +Two common thresholds: + +1. Minimal uniqueness target: + - `P = prod(p_i) > q` +2. Labrador conservative heuristic: + - `P > 128 * q^2` (from `data.py`, with `FIXME` comment) + +### Limb counts at `D = 64` with current small-prime pool + +- Using all small NTT primes (`31` available): + - `P > q` achievable with **10** limbs + - `P > 128*q^2` achievable with **20** limbs +- Using only exact Solinas NTT primes (`6` available): + - `P > q`: **not achievable** + - `P > 128*q^2`: **not achievable** + - total product is only about `2^70` + +### Limb counts at `D = 1024` with current small-prime pool + +- Only one qualifying prime (`12289`) under `p < 2^14`, so neither threshold is achievable. + +## What is Labrador's safety margin doing? + +In Labrador code, prime selection stops at: + +- `P > 128 * q^2` + +Interpretation: + +- `q^2` tracks product-scale growth, +- extra factor `128` gives additional headroom (for `N=64`, this is `2N`), +- but their own `FIXME` comment indicates this is a conservative engineering bound, not a tight proof. + +So treat this as a robust heuristic rather than a formal minimum. + +## Practical implication for Hachi + +- If we stay with `D=64` and small i16-ish primes, we need non-Solinas primes in the CRT set. +- If we push to `D=1024`, we must either: + - lift prime size beyond `<2^14`, or + - change CRT strategy (fewer larger limbs / different backend), or + - avoid strict small-prime CRT-NTT at that degree. +- A mixed backend model is sensible: + - keep the current `i16` backend for small-prime kernels, + - add an `i32`/wider backend for larger-prime kernels (e.g., up to ~30-bit). diff --git a/README.md b/README.md index 492b4d12..8036344f 100644 --- a/README.md +++ b/README.md @@ -3,3 +3,7 @@ A high performance and modular implementation of the Hachi polynomial commitment scheme. Hachi is a lattice-based polynomial commitment scheme with transparent setup and post-quantum security. + +## Acknowledgements + +The CRT/NTT and small-prime arithmetic design in this repository is informed by the Labrador/Greyhound C implementation family. In particular, the current pseudo-Mersenne profile uses moduli of the form `q = 2^k - offset` (smallest prime below `2^k` with `q % 8 == 5`). Hachi provides a Rust-native architecture and APIs, while drawing algorithmic inspiration from those implementations. diff --git a/benches/field_arith.rs b/benches/field_arith.rs new file mode 100644 index 00000000..c3b04578 --- /dev/null +++ b/benches/field_arith.rs @@ -0,0 +1,1448 @@ +#![allow(missing_docs)] + +use ark_bn254::Fr as BN254Fr; +use ark_ff::{AdditiveGroup, Field}; +use criterion::{black_box, criterion_group, criterion_main, Criterion, Throughput}; +use hachi_pcs::algebra::fields::fp128::{Prime128M18M0, Prime128M54P0}; +use hachi_pcs::algebra::fields::fp32::Fp32; +use hachi_pcs::algebra::{HasPacking, PackedField, PackedValue, Prime128M13M4P0, Prime128M8M4M1M0}; +use hachi_pcs::algebra::{ + Pow2Offset24Field, Pow2Offset30Field, Pow2Offset31Field, Pow2Offset32Field, Pow2Offset40Field, + Pow2Offset48Field, Pow2Offset56Field, Pow2Offset64Field, +}; +use hachi_pcs::{CanonicalField, FieldCore, FieldSampling, FromSmallInt, Invertible}; +use rand::{rngs::StdRng, RngCore, SeedableRng}; +use std::env; +#[cfg(feature = "parallel")] +use std::thread; + +#[cfg(feature = "parallel")] +use rayon::prelude::*; +#[cfg(feature = "parallel")] +use rayon::ThreadPoolBuilder; + +fn rand_u128(rng: &mut R) -> u128 { + let lo = rng.next_u64() as u128; + let hi = rng.next_u64() as u128; + lo | (hi << 64) +} + +fn env_usize(name: &str, default: usize) -> usize { + env::var(name) + .ok() + .and_then(|v| v.parse::().ok()) + .unwrap_or(default) +} + +fn bench_mul(c: &mut Criterion) { + type F13 = Prime128M13M4P0; + type F275 = Prime128M8M4M1M0; + type F2p18p1 = Prime128M18M0; + type F2p54m1 = Prime128M54P0; + + let mut rng = StdRng::seed_from_u64(0x5eed); + let inputs_u128: Vec = (0..2048).map(|_| rand_u128(&mut rng)).collect(); + + let inputs_f13: Vec = inputs_u128 + .iter() + .copied() + .map(F13::from_canonical_u128_reduced) + .collect(); + + let inputs_f275: Vec = inputs_u128 + .iter() + .copied() + .map(F275::from_canonical_u128_reduced) + .collect(); + let inputs_f2p18p1: Vec = inputs_u128 + .iter() + .copied() + .map(F2p18p1::from_canonical_u128_reduced) + .collect(); + let inputs_f2p54m1: Vec = inputs_u128 + .iter() + .copied() + .map(F2p54m1::from_canonical_u128_reduced) + .collect(); + + let mut group = c.benchmark_group("field_mul"); + + group.bench_function("fp128_prime128m13m4p0", |b| { + b.iter(|| { + let mut acc = F13::one(); + for x in inputs_f13.iter() { + acc = acc * *x + acc; + } + black_box(acc) + }) + }); + + group.bench_function("fp128_prime128m8m4m1m0", |b| { + b.iter(|| { + let mut acc = F275::one(); + for x in inputs_f275.iter() { + acc = acc * *x + acc; + } + black_box(acc) + }) + }); + + group.bench_function("fp128_prime128m18m0_shift_special", |b| { + b.iter(|| { + let mut acc = F2p18p1::one(); + for x in inputs_f2p18p1.iter() { + acc = acc * *x + acc; + } + black_box(acc) + }) + }); + + group.bench_function("fp128_prime128m54p0_shift_special", |b| { + b.iter(|| { + let mut acc = F2p54m1::one(); + for x in inputs_f2p54m1.iter() { + acc = acc * *x + acc; + } + black_box(acc) + }) + }); + + group.finish(); +} + +fn bench_mul_only(c: &mut Criterion) { + type F13 = Prime128M13M4P0; + type F2p18p1 = Prime128M18M0; + type F2p54m1 = Prime128M54P0; + + let mut rng = StdRng::seed_from_u64(0x5eed); + let inputs_f13: Vec = (0..2048) + .map(|_| F13::from_canonical_u128_reduced(rand_u128(&mut rng))) + .collect(); + let inputs_f2p18p1: Vec = (0..2048) + .map(|_| F2p18p1::from_canonical_u128_reduced(rand_u128(&mut rng))) + .collect(); + let inputs_f2p54m1: Vec = (0..2048) + .map(|_| F2p54m1::from_canonical_u128_reduced(rand_u128(&mut rng))) + .collect(); + + let mut group = c.benchmark_group("field_mul_only"); + + group.bench_function("mul_chain_2048", |b| { + b.iter(|| { + let mut acc = F13::one(); + for x in inputs_f13.iter() { + acc = acc * *x; + } + black_box(acc) + }) + }); + + group.bench_function("mul_chain_16384", |b| { + b.iter(|| { + let mut acc = F13::one(); + for _ in 0..8 { + for x in inputs_f13.iter() { + acc = acc * *x; + } + } + black_box(acc) + }) + }); + + group.bench_function("mul_parallel_1024", |b| { + b.iter(|| { + let mut sum = F13::zero(); + for pair in inputs_f13.chunks_exact(2) { + sum = sum + pair[0] * pair[1]; + } + black_box(sum) + }) + }); + + group.bench_function("mul_chain_2048_special_m18m0", |b| { + b.iter(|| { + let mut acc = F2p18p1::one(); + for x in inputs_f2p18p1.iter() { + acc = acc * *x; + } + black_box(acc) + }) + }); + + group.bench_function("mul_chain_2048_special_m54p0", |b| { + b.iter(|| { + let mut acc = F2p54m1::one(); + for x in inputs_f2p54m1.iter() { + acc = acc * *x; + } + black_box(acc) + }) + }); + + group.finish(); +} + +fn bench_mul_isolated(c: &mut Criterion) { + use ark_ff::UniformRand; + + type F13 = Prime128M13M4P0; + + let mut rng = StdRng::seed_from_u64(0x5eed); + let a_fp128 = F13::from_canonical_u128_reduced(rand_u128(&mut rng)); + let b_fp128 = F13::from_canonical_u128_reduced(rand_u128(&mut rng)); + let a_bn254 = BN254Fr::rand(&mut rng); + let b_bn254 = BN254Fr::rand(&mut rng); + + let mut group = c.benchmark_group("field_mul_isolated"); + + group.bench_function("fp128_black_box_only", |b| b.iter(|| black_box(a_fp128))); + + group.bench_function("bn254_black_box_only", |b| b.iter(|| black_box(a_bn254))); + + group.bench_function("fp128_pair_passthrough", |b| { + b.iter(|| { + let x = black_box(a_fp128); + let y = black_box(b_fp128); + black_box((x, y)) + }) + }); + + group.bench_function("bn254_pair_passthrough", |b| { + b.iter(|| { + let x = black_box(a_bn254); + let y = black_box(b_bn254); + black_box((x, y)) + }) + }); + + group.bench_function("fp128_mul_single", |b| { + b.iter(|| { + let x = black_box(a_fp128); + let y = black_box(b_fp128); + black_box(x * y) + }) + }); + + group.bench_function("bn254_mul_single", |b| { + b.iter(|| { + let x = black_box(a_bn254); + let y = black_box(b_bn254); + black_box(x * y) + }) + }); + + let lanes_fp128: [(F13, F13); 8] = std::array::from_fn(|_| { + ( + F13::from_canonical_u128_reduced(rand_u128(&mut rng)), + F13::from_canonical_u128_reduced(rand_u128(&mut rng)), + ) + }); + let lanes_bn254: [(BN254Fr, BN254Fr); 8] = + std::array::from_fn(|_| (BN254Fr::rand(&mut rng), BN254Fr::rand(&mut rng))); + + group.bench_function("fp128_mul_8way_independent", |b| { + b.iter(|| { + let lanes = black_box(&lanes_fp128); + let p0 = lanes[0].0 * lanes[0].1; + let p1 = lanes[1].0 * lanes[1].1; + let p2 = lanes[2].0 * lanes[2].1; + let p3 = lanes[3].0 * lanes[3].1; + let p4 = lanes[4].0 * lanes[4].1; + let p5 = lanes[5].0 * lanes[5].1; + let p6 = lanes[6].0 * lanes[6].1; + let p7 = lanes[7].0 * lanes[7].1; + black_box([p0, p1, p2, p3, p4, p5, p6, p7]) + }) + }); + + group.bench_function("fp128_8way_passthrough", |b| { + b.iter(|| { + let lanes = black_box(&lanes_fp128); + let p0 = lanes[0].0; + let p1 = lanes[1].0; + let p2 = lanes[2].0; + let p3 = lanes[3].0; + let p4 = lanes[4].0; + let p5 = lanes[5].0; + let p6 = lanes[6].0; + let p7 = lanes[7].0; + black_box([p0, p1, p2, p3, p4, p5, p6, p7]) + }) + }); + + group.bench_function("bn254_mul_8way_independent", |b| { + b.iter(|| { + let lanes = black_box(&lanes_bn254); + let p0 = lanes[0].0 * lanes[0].1; + let p1 = lanes[1].0 * lanes[1].1; + let p2 = lanes[2].0 * lanes[2].1; + let p3 = lanes[3].0 * lanes[3].1; + let p4 = lanes[4].0 * lanes[4].1; + let p5 = lanes[5].0 * lanes[5].1; + let p6 = lanes[6].0 * lanes[6].1; + let p7 = lanes[7].0 * lanes[7].1; + black_box([p0, p1, p2, p3, p4, p5, p6, p7]) + }) + }); + + group.bench_function("bn254_8way_passthrough", |b| { + b.iter(|| { + let lanes = black_box(&lanes_bn254); + let p0 = lanes[0].0; + let p1 = lanes[1].0; + let p2 = lanes[2].0; + let p3 = lanes[3].0; + let p4 = lanes[4].0; + let p5 = lanes[5].0; + let p6 = lanes[6].0; + let p7 = lanes[7].0; + black_box([p0, p1, p2, p3, p4, p5, p6, p7]) + }) + }); + + group.finish(); +} + +fn bench_sqr(c: &mut Criterion) { + type F13 = Prime128M13M4P0; + + let mut rng = StdRng::seed_from_u64(0x5eed); + let start = F13::from_canonical_u128_reduced(rand_u128(&mut rng)); + + let mut group = c.benchmark_group("field_sqr"); + + group.bench_function("sqr_chain_2048", |b| { + b.iter(|| { + let mut acc = start; + for _ in 0..2048 { + acc = acc.square(); + } + black_box(acc) + }) + }); + + group.bench_function("mul_self_chain_2048", |b| { + b.iter(|| { + let mut acc = start; + for _ in 0..2048 { + acc = acc * acc; + } + black_box(acc) + }) + }); + + group.finish(); +} + +fn bench_inv(c: &mut Criterion) { + type F13 = Prime128M13M4P0; + + let mut rng = StdRng::seed_from_u64(0x1a2b_3c4d_5e6f_7788); + let inputs: Vec = (0..256) + .map(|_| F13::from_canonical_u128_reduced(rand_u128(&mut rng))) + .collect(); + + c.bench_function("fp128_inv_or_zero_prime128m13m4p0", |b| { + b.iter(|| { + let mut acc = F13::one(); + for x in inputs.iter() { + acc = acc * x.inv_or_zero(); + } + black_box(acc) + }) + }); +} + +fn bench_bn254(c: &mut Criterion) { + use ark_ff::UniformRand; + + let mut rng = StdRng::seed_from_u64(0x5eed); + let inputs: Vec = (0..2048).map(|_| BN254Fr::rand(&mut rng)).collect(); + + let mut group = c.benchmark_group("bn254_fr"); + + group.bench_function("mul_add_chain_2048", |b| { + b.iter(|| { + let mut acc = BN254Fr::ONE; + for x in inputs.iter() { + acc = acc * x + acc; + } + black_box(acc) + }) + }); + + group.bench_function("mul_chain_2048", |b| { + b.iter(|| { + let mut acc = BN254Fr::ONE; + for x in inputs.iter() { + acc *= x; + } + black_box(acc) + }) + }); + + group.bench_function("mul_chain_16384", |b| { + b.iter(|| { + let mut acc = BN254Fr::ONE; + for _ in 0..8 { + for x in inputs.iter() { + acc *= x; + } + } + black_box(acc) + }) + }); + + group.bench_function("mul_parallel_1024", |b| { + b.iter(|| { + let mut sum = BN254Fr::ZERO; + for pair in inputs.chunks_exact(2) { + sum += pair[0] * pair[1]; + } + black_box(sum) + }) + }); + + group.bench_function("sqr_chain_2048", |b| { + b.iter(|| { + let mut acc = inputs[0]; + for _ in 0..2048 { + acc.square_in_place(); + } + black_box(acc) + }) + }); + + group.bench_function("inv_256", |b| { + b.iter(|| { + let mut acc = BN254Fr::ONE; + for x in inputs[..256].iter() { + acc *= x.inverse().unwrap_or(BN254Fr::ZERO); + } + black_box(acc) + }) + }); + + group.finish(); +} + +fn bench_packed_fp128_backend(c: &mut Criterion) { + type F = Prime128M13M4P0; + type PF = ::Packing; + let packed_streams = env_usize("HACHI_BENCH_PACKED_STREAMS", 8); + let latency_iters = env_usize("HACHI_BENCH_LATENCY_ITERS", 4096); + let throughput_iters = env_usize("HACHI_BENCH_THROUGHPUT_ITERS", 256); + let stream_iters = env_usize("HACHI_BENCH_STREAM_ITERS", 2048); + let mix_iters = env_usize("HACHI_BENCH_MIX_ITERS", 256); + let mix_muls = env_usize("HACHI_BENCH_MIX_MULS", 3); + let mix_adds = env_usize("HACHI_BENCH_MIX_ADDS", 1); + let mix_subs = env_usize("HACHI_BENCH_MIX_SUBS", 1); + + assert!(packed_streams > 0, "HACHI_BENCH_PACKED_STREAMS must be > 0"); + assert!(latency_iters > 0, "HACHI_BENCH_LATENCY_ITERS must be > 0"); + assert!( + throughput_iters > 0, + "HACHI_BENCH_THROUGHPUT_ITERS must be > 0" + ); + assert!(stream_iters > 0, "HACHI_BENCH_STREAM_ITERS must be > 0"); + assert!(mix_iters > 0, "HACHI_BENCH_MIX_ITERS must be > 0"); + + let muls_per_stream = throughput_iters + 1; + let mix_ops = mix_muls + mix_adds + mix_subs; + assert!(mix_ops > 0, "at least one mix operation must be enabled"); + + let backend = if cfg!(all(target_arch = "aarch64", target_feature = "neon")) { + "aarch64_neon" + } else { + "scalar_fallback" + }; + let mut group = c.benchmark_group(format!("field_packed_backend/{backend}/w{}", PF::WIDTH)); + + let mut rng = StdRng::seed_from_u64(0xd00d_f00d_1122_3344); + let scalar_stream_len = PF::WIDTH * stream_iters; + let lhs: Vec = (0..scalar_stream_len) + .map(|_| F::from_canonical_u128_reduced(rand_u128(&mut rng))) + .collect(); + let rhs: Vec = (0..scalar_stream_len) + .map(|_| F::from_canonical_u128_reduced(rand_u128(&mut rng))) + .collect(); + + let packed_lhs: Vec = PF::pack_slice(&lhs); + let packed_rhs: Vec = PF::pack_slice(&rhs); + let scalar_latency_inputs: Vec = (0..latency_iters) + .map(|_| F::from_canonical_u128_reduced(rand_u128(&mut rng))) + .collect(); + let packed_latency_inputs: Vec = (0..latency_iters) + .map(|_| PF::from_fn(|_| F::from_canonical_u128_reduced(rand_u128(&mut rng)))) + .collect(); + + let scalar_streams = packed_streams * PF::WIDTH; + let scalar_lanes: Vec<(F, F)> = (0..scalar_streams) + .map(|_| { + ( + F::from_canonical_u128_reduced(rand_u128(&mut rng)), + F::from_canonical_u128_reduced(rand_u128(&mut rng)), + ) + }) + .collect(); + let packed_lanes: Vec<(PF, PF)> = (0..packed_streams) + .map(|_| { + ( + PF::from_fn(|_| F::from_canonical_u128_reduced(rand_u128(&mut rng))), + PF::from_fn(|_| F::from_canonical_u128_reduced(rand_u128(&mut rng))), + ) + }) + .collect(); + + group.throughput(Throughput::Elements(scalar_stream_len as u64)); + group.bench_function("scalar_add_stream", |b| { + let mut out = lhs.clone(); + b.iter(|| { + for (dst, src) in out.iter_mut().zip(rhs.iter()) { + *dst = *dst + *src; + } + black_box(out[0]) + }) + }); + + group.throughput(Throughput::Elements(scalar_stream_len as u64)); + group.bench_function("packed_add_stream", |b| { + let mut out = packed_lhs.clone(); + b.iter(|| { + for (dst, src) in out.iter_mut().zip(packed_rhs.iter()) { + *dst = *dst + *src; + } + black_box(out[0].extract(0)) + }) + }); + + group.throughput(Throughput::Elements(latency_iters as u64)); + group.bench_function("scalar_mul_latency_chain", |b| { + b.iter(|| { + let mut acc = F::one(); + for x in scalar_latency_inputs.iter() { + acc = acc * *x; + } + black_box(acc) + }) + }); + + group.throughput(Throughput::Elements((latency_iters * PF::WIDTH) as u64)); + group.bench_function("packed_mul_latency_chain", |b| { + b.iter(|| { + let mut acc = PF::broadcast(F::one()); + for x in packed_latency_inputs.iter() { + acc = acc * *x; + } + black_box(acc.extract(0)) + }) + }); + + group.throughput(Throughput::Elements( + (scalar_streams * muls_per_stream) as u64, + )); + group.bench_function("scalar_mul_throughput_8way", |b| { + b.iter(|| { + let lanes = black_box(&scalar_lanes); + let mut acc: Vec = lanes.iter().map(|(a, b)| *a * *b).collect(); + for _ in 0..throughput_iters { + for (acc_i, lane) in acc.iter_mut().zip(lanes.iter()) { + *acc_i = *acc_i * lane.0; + } + } + black_box(acc[0]) + }) + }); + + group.throughput(Throughput::Elements( + (packed_streams * muls_per_stream * PF::WIDTH) as u64, + )); + group.bench_function("packed_mul_throughput_8way", |b| { + b.iter(|| { + let lanes = black_box(&packed_lanes); + let mut acc: Vec = lanes.iter().map(|(a, b)| *a * *b).collect(); + for _ in 0..throughput_iters { + for (acc_i, lane) in acc.iter_mut().zip(lanes.iter()) { + *acc_i = *acc_i * lane.0; + } + } + black_box(acc[0].extract(0)) + }) + }); + + group.throughput(Throughput::Elements( + (scalar_streams * mix_iters * mix_ops) as u64, + )); + group.bench_function("scalar_mix_sumcheck_like", |b| { + b.iter(|| { + let lanes = black_box(&scalar_lanes); + let mut acc: Vec = lanes.iter().map(|(a, b)| *a + *b).collect(); + for _ in 0..mix_iters { + for (acc_i, lane) in acc.iter_mut().zip(lanes.iter()) { + let (x, y) = *lane; + for _ in 0..mix_muls { + *acc_i = *acc_i * x; + } + for _ in 0..mix_adds { + *acc_i = *acc_i + y; + } + for _ in 0..mix_subs { + *acc_i = *acc_i - x; + } + } + } + black_box(acc[0]) + }) + }); + + group.throughput(Throughput::Elements( + (packed_streams * PF::WIDTH * mix_iters * mix_ops) as u64, + )); + group.bench_function("packed_mix_sumcheck_like", |b| { + b.iter(|| { + let lanes = black_box(&packed_lanes); + let mut acc: Vec = lanes.iter().map(|(a, b)| *a + *b).collect(); + for _ in 0..mix_iters { + for (acc_i, lane) in acc.iter_mut().zip(lanes.iter()) { + let (x, y) = *lane; + for _ in 0..mix_muls { + *acc_i = *acc_i * x; + } + for _ in 0..mix_adds { + *acc_i = *acc_i + y; + } + for _ in 0..mix_subs { + *acc_i = *acc_i - x; + } + } + } + black_box(acc[0].extract(0)) + }) + }); + + group.finish(); +} + +fn bench_fp32_fp64_mul(c: &mut Criterion) { + let mut rng = StdRng::seed_from_u64(0x3264_3264); + let n = 2048; + + let inputs_24: Vec = + (0..n).map(|_| FieldSampling::sample(&mut rng)).collect(); + let inputs_30: Vec = + (0..n).map(|_| FieldSampling::sample(&mut rng)).collect(); + let inputs_31: Vec = + (0..n).map(|_| FieldSampling::sample(&mut rng)).collect(); + let inputs_32: Vec = + (0..n).map(|_| FieldSampling::sample(&mut rng)).collect(); + let inputs_40: Vec = + (0..n).map(|_| FieldSampling::sample(&mut rng)).collect(); + let inputs_64: Vec = + (0..n).map(|_| FieldSampling::sample(&mut rng)).collect(); + + let mut group = c.benchmark_group("fp32_fp64_mul"); + + macro_rules! chain_bench { + ($name:expr, $ty:ty, $inputs:expr) => { + group.bench_function(concat!($name, "_mul_chain_2048"), |b| { + b.iter(|| { + let mut acc = <$ty>::one(); + for x in $inputs.iter() { + acc = acc * *x; + } + black_box(acc) + }) + }); + group.bench_function(concat!($name, "_mul_add_chain_2048"), |b| { + b.iter(|| { + let mut acc = <$ty>::one(); + for x in $inputs.iter() { + acc = acc * *x + acc; + } + black_box(acc) + }) + }); + }; + } + + chain_bench!("fp32_2pow24m3", Pow2Offset24Field, inputs_24); + chain_bench!("fp32_2pow30m35", Pow2Offset30Field, inputs_30); + chain_bench!("fp32_2pow31m19", Pow2Offset31Field, inputs_31); + chain_bench!("fp32_2pow32m99", Pow2Offset32Field, inputs_32); + chain_bench!("fp64_2pow40m195", Pow2Offset40Field, inputs_40); + chain_bench!("fp64_2pow64m59", Pow2Offset64Field, inputs_64); + + group.finish(); +} + +fn bench_widening_ops(c: &mut Criterion) { + type F = Prime128M8M4M1M0; + + let mut rng = StdRng::seed_from_u64(0x01de_be0c_0001); + let a = F::from_canonical_u128_reduced(rand_u128(&mut rng)); + let b = F::from_canonical_u128_reduced(rand_u128(&mut rng)); + let b_u64 = rng.next_u64(); + + let mut group = c.benchmark_group("widening_ops"); + + group.bench_function("mul_wide_u64_only", |bench| { + bench.iter(|| black_box(black_box(a).mul_wide_u64(black_box(b_u64)))) + }); + + group.bench_function("mul_wide_only", |bench| { + bench.iter(|| black_box(black_box(a).mul_wide(black_box(b)))) + }); + + let limbs3 = [rng.next_u64(), rng.next_u64(), rng.next_u64()]; + let limbs4 = [ + rng.next_u64(), + rng.next_u64(), + rng.next_u64(), + rng.next_u64(), + ]; + + group.bench_function("mul_wide_limbs_3_to_5_only", |bench| { + bench.iter(|| black_box(black_box(a).mul_wide_limbs::<3, 5>(black_box(limbs3)))) + }); + group.bench_function("mul_wide_limbs_3_to_4_only", |bench| { + bench.iter(|| black_box(black_box(a).mul_wide_limbs::<3, 4>(black_box(limbs3)))) + }); + group.bench_function("mul_wide_limbs_4_to_5_only", |bench| { + bench.iter(|| black_box(black_box(a).mul_wide_limbs::<4, 5>(black_box(limbs4)))) + }); + group.bench_function("mul_wide_limbs_4_to_4_only", |bench| { + bench.iter(|| black_box(black_box(a).mul_wide_limbs::<4, 4>(black_box(limbs4)))) + }); + + group.bench_function("full_mul_u64_reduce", |bench| { + bench.iter(|| black_box(black_box(a) * F::from_u64(black_box(b_u64)))) + }); + + group.bench_function("full_mul_reduce", |bench| { + bench.iter(|| black_box(black_box(a) * black_box(b))) + }); + + let wide3 = a.mul_wide_u64(b_u64); + let wide4 = a.mul_wide(b); + let wide5 = { + let mut l = [0u64; 5]; + l[..3].copy_from_slice(&wide3); + l[4] = rng.next_u64() & 0xFF; + l + }; + + group.bench_function("solinas_reduce_3_limbs", |bench| { + bench.iter(|| black_box(F::solinas_reduce(black_box(&wide3)))) + }); + + group.bench_function("solinas_reduce_4_limbs", |bench| { + bench.iter(|| black_box(F::solinas_reduce(black_box(&wide4)))) + }); + + group.bench_function("solinas_reduce_5_limbs", |bench| { + bench.iter(|| black_box(F::solinas_reduce(black_box(&wide5)))) + }); + + group.bench_function("mul_wide_u64_roundtrip", |bench| { + bench.iter(|| { + let x = black_box(a); + let y = black_box(b_u64); + black_box(F::solinas_reduce(&x.mul_wide_u64(y))) + }) + }); + + group.bench_function("mul_wide_roundtrip", |bench| { + bench.iter(|| { + let x = black_box(a); + let y = black_box(b); + black_box(F::solinas_reduce(&x.mul_wide(y))) + }) + }); + + group.bench_function("mul_wide_limbs_3_to_5_roundtrip", |bench| { + bench.iter(|| { + let x = black_box(a); + let m = black_box(limbs3); + black_box(F::solinas_reduce(&x.mul_wide_limbs::<3, 5>(m))) + }) + }); + group.bench_function("mul_wide_limbs_3_to_4_roundtrip", |bench| { + bench.iter(|| { + let x = black_box(a); + let m = black_box(limbs3); + black_box(F::solinas_reduce(&x.mul_wide_limbs::<3, 4>(m))) + }) + }); + group.bench_function("mul_wide_limbs_4_to_5_roundtrip", |bench| { + bench.iter(|| { + let x = black_box(a); + let m = black_box(limbs4); + black_box(F::solinas_reduce(&x.mul_wide_limbs::<4, 5>(m))) + }) + }); + group.bench_function("mul_wide_limbs_4_to_4_roundtrip", |bench| { + bench.iter(|| { + let x = black_box(a); + let m = black_box(limbs4); + black_box(F::solinas_reduce(&x.mul_wide_limbs::<4, 4>(m))) + }) + }); + + group.finish(); +} + +fn bench_accumulator_pattern(c: &mut Criterion) { + type F = Prime128M8M4M1M0; + + let mut rng = StdRng::seed_from_u64(0xacc0_1a70_0002); + let inputs_a: Vec = (0..256) + .map(|_| F::from_canonical_u128_reduced(rand_u128(&mut rng))) + .collect(); + let inputs_b_u64: Vec = (0..256).map(|_| rng.next_u64()).collect(); + let inputs_b_f: Vec = (0..256) + .map(|_| F::from_canonical_u128_reduced(rand_u128(&mut rng))) + .collect(); + + let mut group = c.benchmark_group("accumulator_pattern"); + + for &n in &[16, 64, 256] { + group.bench_function(format!("eager_mul_u64_{n}"), |bench| { + bench.iter(|| { + let a_s = black_box(&inputs_a[..n]); + let b_s = black_box(&inputs_b_u64[..n]); + let mut acc = F::zero(); + for i in 0..n { + acc = acc + a_s[i] * F::from_u64(b_s[i]); + } + black_box(acc) + }) + }); + + group.bench_function(format!("widening_accum_u64_{n}"), |bench| { + bench.iter(|| { + let a_s = black_box(&inputs_a[..n]); + let b_s = black_box(&inputs_b_u64[..n]); + let mut acc = [0u64; 5]; + for i in 0..n { + let wide = a_s[i].mul_wide_u64(b_s[i]); + let mut carry: u64 = 0; + for j in 0..3 { + let sum = acc[j] as u128 + wide[j] as u128 + carry as u128; + acc[j] = sum as u64; + carry = (sum >> 64) as u64; + } + for item in &mut acc[3..5] { + let sum = *item as u128 + carry as u128; + *item = sum as u64; + carry = (sum >> 64) as u64; + } + } + black_box(F::solinas_reduce(&acc)) + }) + }); + + group.bench_function(format!("eager_mul_full_{n}"), |bench| { + bench.iter(|| { + let a_s = black_box(&inputs_a[..n]); + let b_s = black_box(&inputs_b_f[..n]); + let mut acc = F::zero(); + for i in 0..n { + acc = acc + a_s[i] * b_s[i]; + } + black_box(acc) + }) + }); + + group.bench_function(format!("widening_accum_full_{n}"), |bench| { + bench.iter(|| { + let a_s = black_box(&inputs_a[..n]); + let b_s = black_box(&inputs_b_f[..n]); + let mut acc = [0u64; 6]; + for i in 0..n { + let wide = a_s[i].mul_wide(b_s[i]); + let mut carry: u64 = 0; + for j in 0..4 { + let sum = acc[j] as u128 + wide[j] as u128 + carry as u128; + acc[j] = sum as u64; + carry = (sum >> 64) as u64; + } + for item in &mut acc[4..6] { + let sum = *item as u128 + carry as u128; + *item = sum as u64; + carry = (sum >> 64) as u64; + } + } + black_box(F::solinas_reduce(&acc)) + }) + }); + } + + group.finish(); +} + +fn bench_throughput(c: &mut Criterion) { + let n = 4096u64; + let mut rng = StdRng::seed_from_u64(0xdead_cafe); + + type M31 = Fp32<{ (1u32 << 31) - 1 }>; + + let a24: Vec = (0..n).map(|_| FieldSampling::sample(&mut rng)).collect(); + let b24: Vec = (0..n).map(|_| FieldSampling::sample(&mut rng)).collect(); + let a30: Vec = (0..n).map(|_| FieldSampling::sample(&mut rng)).collect(); + let b30: Vec = (0..n).map(|_| FieldSampling::sample(&mut rng)).collect(); + let a31: Vec = (0..n).map(|_| FieldSampling::sample(&mut rng)).collect(); + let b31: Vec = (0..n).map(|_| FieldSampling::sample(&mut rng)).collect(); + let am31: Vec = (0..n).map(|_| FieldSampling::sample(&mut rng)).collect(); + let bm31: Vec = (0..n).map(|_| FieldSampling::sample(&mut rng)).collect(); + let a32: Vec = (0..n).map(|_| FieldSampling::sample(&mut rng)).collect(); + let b32: Vec = (0..n).map(|_| FieldSampling::sample(&mut rng)).collect(); + let a40: Vec = (0..n).map(|_| FieldSampling::sample(&mut rng)).collect(); + let b40: Vec = (0..n).map(|_| FieldSampling::sample(&mut rng)).collect(); + let a48: Vec = (0..n).map(|_| FieldSampling::sample(&mut rng)).collect(); + let b48: Vec = (0..n).map(|_| FieldSampling::sample(&mut rng)).collect(); + let a56: Vec = (0..n).map(|_| FieldSampling::sample(&mut rng)).collect(); + let b56: Vec = (0..n).map(|_| FieldSampling::sample(&mut rng)).collect(); + let a64: Vec = (0..n).map(|_| FieldSampling::sample(&mut rng)).collect(); + let b64: Vec = (0..n).map(|_| FieldSampling::sample(&mut rng)).collect(); + let a128: Vec = (0..n) + .map(|_| Prime128M8M4M1M0::from_canonical_u128_reduced(rand_u128(&mut rng))) + .collect(); + let b128: Vec = (0..n) + .map(|_| Prime128M8M4M1M0::from_canonical_u128_reduced(rand_u128(&mut rng))) + .collect(); + + let mut out24 = vec![Pow2Offset24Field::zero(); n as usize]; + let mut out30 = vec![Pow2Offset30Field::zero(); n as usize]; + let mut out31 = vec![Pow2Offset31Field::zero(); n as usize]; + let mut outm31 = vec![M31::zero(); n as usize]; + let mut out32 = vec![Pow2Offset32Field::zero(); n as usize]; + let mut out40 = vec![Pow2Offset40Field::zero(); n as usize]; + let mut out48 = vec![Pow2Offset48Field::zero(); n as usize]; + let mut out56 = vec![Pow2Offset56Field::zero(); n as usize]; + let mut out64 = vec![Pow2Offset64Field::zero(); n as usize]; + let mut out128 = vec![Prime128M8M4M1M0::zero(); n as usize]; + + let mut group = c.benchmark_group("throughput"); + group.throughput(Throughput::Elements(n)); + + macro_rules! bench_op { + ($name:expr, $a:expr, $b:expr, $out:expr, $op:tt) => { + group.bench_function($name, |bench| { + bench.iter(|| { + let a = black_box(&$a); + let b = black_box(&$b); + let out = &mut $out; + for i in 0..n as usize { + out[i] = a[i] $op b[i]; + } + }) + }); + }; + } + + bench_op!("fp32_24b_mul", a24, b24, out24, *); + bench_op!("fp32_24b_add", a24, b24, out24, +); + bench_op!("fp32_30b_mul", a30, b30, out30, *); + bench_op!("fp32_30b_add", a30, b30, out30, +); + bench_op!("fp32_31b_mul", a31, b31, out31, *); + bench_op!("fp32_31b_add", a31, b31, out31, +); + bench_op!("fp32_m31_mul", am31, bm31, outm31, *); + bench_op!("fp32_m31_add", am31, bm31, outm31, +); + bench_op!("fp32_32b_mul", a32, b32, out32, *); + bench_op!("fp32_32b_add", a32, b32, out32, +); + bench_op!("fp64_40b_mul", a40, b40, out40, *); + bench_op!("fp64_40b_add", a40, b40, out40, +); + bench_op!("fp64_48b_mul", a48, b48, out48, *); + bench_op!("fp64_48b_add", a48, b48, out48, +); + bench_op!("fp64_56b_mul", a56, b56, out56, *); + bench_op!("fp64_56b_add", a56, b56, out56, +); + bench_op!("fp64_64b_mul", a64, b64, out64, *); + bench_op!("fp64_64b_add", a64, b64, out64, +); + bench_op!("fp128_mul", a128, b128, out128, *); + bench_op!("fp128_add", a128, b128, out128, +); + + group.finish(); +} + +fn bench_packed_throughput(c: &mut Criterion) { + use hachi_pcs::algebra::{Fp128Packing, Fp32Packing, Fp64Packing}; + + let n = 4096u64; + let mut rng = StdRng::seed_from_u64(0xbeef_cafe); + + macro_rules! packed_bench { + ($group:expr, $label:expr, $field:ty, $packing:ty, $rng:expr, $n:expr) => {{ + let lhs: Vec<$field> = (0..$n).map(|_| FieldSampling::sample($rng)).collect(); + let rhs: Vec<$field> = (0..$n).map(|_| FieldSampling::sample($rng)).collect(); + let lhs_p = <$packing>::pack_slice(&lhs); + let rhs_p = <$packing>::pack_slice(&rhs); + let mut out_p = vec![<$packing>::broadcast(<$field>::zero()); lhs_p.len()]; + + $group.bench_function(concat!($label, "_packed_mul"), |b| { + b.iter(|| { + let a = black_box(&lhs_p); + let b_v = black_box(&rhs_p); + let out = &mut out_p; + for i in 0..out.len() { + out[i] = a[i] * b_v[i]; + } + }) + }); + $group.bench_function(concat!($label, "_packed_add"), |b| { + b.iter(|| { + let a = black_box(&lhs_p); + let b_v = black_box(&rhs_p); + let out = &mut out_p; + for i in 0..out.len() { + out[i] = a[i] + b_v[i]; + } + }) + }); + $group.bench_function(concat!($label, "_packed_sub"), |b| { + b.iter(|| { + let a = black_box(&lhs_p); + let b_v = black_box(&rhs_p); + let out = &mut out_p; + for i in 0..out.len() { + out[i] = a[i] - b_v[i]; + } + }) + }); + }}; + } + + let mut group = c.benchmark_group("packed_throughput"); + group.throughput(Throughput::Elements(n)); + + use hachi_pcs::algebra::fields::pseudo_mersenne::*; + type M31 = Fp32<{ (1u32 << 31) - 1 }>; + + type P24 = Fp32Packing<{ POW2_OFFSET_MODULUS_24 }>; + type P30 = Fp32Packing<{ POW2_OFFSET_MODULUS_30 }>; + type P31 = Fp32Packing<{ POW2_OFFSET_MODULUS_31 }>; + type PM31 = Fp32Packing<{ (1u32 << 31) - 1 }>; + type P32 = Fp32Packing<{ POW2_OFFSET_MODULUS_32 }>; + type P40 = Fp64Packing<{ POW2_OFFSET_MODULUS_40 }>; + type P48 = Fp64Packing<{ POW2_OFFSET_MODULUS_48 }>; + type P56 = Fp64Packing<{ POW2_OFFSET_MODULUS_56 }>; + type P64 = Fp64Packing<{ POW2_OFFSET_MODULUS_64 }>; + type P128 = Fp128Packing<{ POW2_OFFSET_MODULUS_128 }>; + + packed_bench!(group, "fp32_24b", Pow2Offset24Field, P24, &mut rng, n); + packed_bench!(group, "fp32_30b", Pow2Offset30Field, P30, &mut rng, n); + packed_bench!(group, "fp32_31b", Pow2Offset31Field, P31, &mut rng, n); + packed_bench!(group, "fp32_m31", M31, PM31, &mut rng, n); + packed_bench!(group, "fp32_32b", Pow2Offset32Field, P32, &mut rng, n); + packed_bench!(group, "fp64_40b", Pow2Offset40Field, P40, &mut rng, n); + packed_bench!(group, "fp64_48b", Pow2Offset48Field, P48, &mut rng, n); + packed_bench!(group, "fp64_56b", Pow2Offset56Field, P56, &mut rng, n); + packed_bench!(group, "fp64_64b", Pow2Offset64Field, P64, &mut rng, n); + packed_bench!(group, "fp128", Prime128M8M4M1M0, P128, &mut rng, n); + + group.finish(); +} + +#[cfg(feature = "parallel")] +fn bench_parallel_throughput(c: &mut Criterion) { + use hachi_pcs::algebra::{Fp32Packing, Fp64Packing}; + + let profile = env::var("HACHI_BENCH_PAR_PROFILE").unwrap_or_else(|_| "dev".to_string()); + let default_n = match profile.as_str() { + "scale" | "large" => 1 << 20, + "xlarge" => 1 << 22, + _ => 1 << 15, + }; + let n = env_usize("HACHI_BENCH_PAR_N", default_n); + let default_chunk = match profile.as_str() { + "scale" | "large" => 1 << 14, + "xlarge" => 1 << 15, + _ => 1 << 12, + }; + let chunk = env_usize("HACHI_BENCH_PAR_CHUNK", default_chunk); + let threads = env_usize( + "HACHI_BENCH_PAR_THREADS", + thread::available_parallelism() + .map(|v| v.get()) + .unwrap_or(1), + ); + + assert!(threads > 0, "HACHI_BENCH_PAR_THREADS must be > 0"); + assert!(n > 0, "HACHI_BENCH_PAR_N must be > 0"); + assert!(chunk > 0, "HACHI_BENCH_PAR_CHUNK must be > 0"); + assert!(n % 4 == 0, "HACHI_BENCH_PAR_N must be divisible by 4"); + + let pool = ThreadPoolBuilder::new() + .num_threads(threads) + .build() + .expect("failed to build rayon pool"); + + let mut rng = StdRng::seed_from_u64(0xfeed_face); + + let lhs31: Vec = (0..n).map(|_| FieldSampling::sample(&mut rng)).collect(); + let rhs31: Vec = (0..n).map(|_| FieldSampling::sample(&mut rng)).collect(); + let lhs64: Vec = (0..n).map(|_| FieldSampling::sample(&mut rng)).collect(); + let rhs64: Vec = (0..n).map(|_| FieldSampling::sample(&mut rng)).collect(); + let lhs128: Vec = (0..n) + .map(|_| Prime128M13M4P0::from_canonical_u128_reduced(rand_u128(&mut rng))) + .collect(); + let rhs128: Vec = (0..n) + .map(|_| Prime128M13M4P0::from_canonical_u128_reduced(rand_u128(&mut rng))) + .collect(); + + type P31 = Fp32Packing<{ hachi_pcs::algebra::fields::pseudo_mersenne::POW2_OFFSET_MODULUS_31 }>; + type P64 = Fp64Packing<{ hachi_pcs::algebra::fields::pseudo_mersenne::POW2_OFFSET_MODULUS_64 }>; + type F128 = Prime128M13M4P0; + type P128 = ::Packing; + let chunk31_p = (chunk / P31::WIDTH).max(1); + let chunk64_p = (chunk / P64::WIDTH).max(1); + let chunk128_p = (chunk / P128::WIDTH).max(1); + + let lhs31_p = P31::pack_slice(&lhs31); + let rhs31_p = P31::pack_slice(&rhs31); + let lhs64_p = P64::pack_slice(&lhs64); + let rhs64_p = P64::pack_slice(&rhs64); + let lhs128_p = P128::pack_slice(&lhs128); + let rhs128_p = P128::pack_slice(&rhs128); + + let mut out31 = vec![Pow2Offset31Field::zero(); n]; + let mut out64 = vec![Pow2Offset64Field::zero(); n]; + let mut out128 = vec![F128::zero(); n]; + let mut out31_p = vec![P31::broadcast(Pow2Offset31Field::zero()); lhs31_p.len()]; + let mut out64_p = vec![P64::broadcast(Pow2Offset64Field::zero()); lhs64_p.len()]; + let mut out128_p = vec![P128::broadcast(F128::zero()); lhs128_p.len()]; + + let mut group = c.benchmark_group(format!( + "parallel_throughput/{profile}/t{threads}/n{n}/c{chunk}" + )); + group.throughput(Throughput::Elements(n as u64)); + + group.bench_function("fp32_31b_mul_seq", |b| { + b.iter(|| { + let a = black_box(&lhs31); + let b_v = black_box(&rhs31); + let out = &mut out31; + for i in 0..out.len() { + out[i] = a[i] * b_v[i]; + } + black_box(out[0]) + }) + }); + + group.bench_function("fp32_31b_mul_par_zip", |b| { + b.iter(|| { + let a = black_box(&lhs31); + let b_v = black_box(&rhs31); + let out = &mut out31; + pool.install(|| { + out.par_iter_mut() + .zip(a.par_iter()) + .zip(b_v.par_iter()) + .for_each(|((dst, lhs), rhs)| *dst = *lhs * *rhs); + }); + black_box(out[0]) + }) + }); + + group.bench_function("fp32_31b_mul_par_chunked", |b| { + b.iter(|| { + let a = black_box(&lhs31); + let b_v = black_box(&rhs31); + let out = &mut out31; + pool.install(|| { + out.par_chunks_mut(chunk) + .zip(a.par_chunks(chunk)) + .zip(b_v.par_chunks(chunk)) + .for_each(|((dst, lhs), rhs)| { + for i in 0..dst.len() { + dst[i] = lhs[i] * rhs[i]; + } + }); + }); + black_box(out[0]) + }) + }); + + group.bench_function("fp32_31b_packed_mul_seq", |b| { + b.iter(|| { + let a = black_box(&lhs31_p); + let b_v = black_box(&rhs31_p); + let out = &mut out31_p; + for i in 0..out.len() { + out[i] = a[i] * b_v[i]; + } + black_box(out[0].extract(0)) + }) + }); + + group.bench_function("fp32_31b_packed_mul_par_zip", |b| { + b.iter(|| { + let a = black_box(&lhs31_p); + let b_v = black_box(&rhs31_p); + let out = &mut out31_p; + pool.install(|| { + out.par_iter_mut() + .zip(a.par_iter()) + .zip(b_v.par_iter()) + .for_each(|((dst, lhs), rhs)| *dst = *lhs * *rhs); + }); + black_box(out[0].extract(0)) + }) + }); + + group.bench_function("fp32_31b_packed_mul_par_chunked", |b| { + b.iter(|| { + let a = black_box(&lhs31_p); + let b_v = black_box(&rhs31_p); + let out = &mut out31_p; + pool.install(|| { + out.par_chunks_mut(chunk31_p) + .zip(a.par_chunks(chunk31_p)) + .zip(b_v.par_chunks(chunk31_p)) + .for_each(|((dst, lhs), rhs)| { + for i in 0..dst.len() { + dst[i] = lhs[i] * rhs[i]; + } + }); + }); + black_box(out[0].extract(0)) + }) + }); + + group.bench_function("fp64_64b_mul_seq", |b| { + b.iter(|| { + let a = black_box(&lhs64); + let b_v = black_box(&rhs64); + let out = &mut out64; + for i in 0..out.len() { + out[i] = a[i] * b_v[i]; + } + black_box(out[0]) + }) + }); + + group.bench_function("fp64_64b_mul_par_zip", |b| { + b.iter(|| { + let a = black_box(&lhs64); + let b_v = black_box(&rhs64); + let out = &mut out64; + pool.install(|| { + out.par_iter_mut() + .zip(a.par_iter()) + .zip(b_v.par_iter()) + .for_each(|((dst, lhs), rhs)| *dst = *lhs * *rhs); + }); + black_box(out[0]) + }) + }); + + group.bench_function("fp64_64b_mul_par_chunked", |b| { + b.iter(|| { + let a = black_box(&lhs64); + let b_v = black_box(&rhs64); + let out = &mut out64; + pool.install(|| { + out.par_chunks_mut(chunk) + .zip(a.par_chunks(chunk)) + .zip(b_v.par_chunks(chunk)) + .for_each(|((dst, lhs), rhs)| { + for i in 0..dst.len() { + dst[i] = lhs[i] * rhs[i]; + } + }); + }); + black_box(out[0]) + }) + }); + + group.bench_function("fp64_64b_packed_mul_seq", |b| { + b.iter(|| { + let a = black_box(&lhs64_p); + let b_v = black_box(&rhs64_p); + let out = &mut out64_p; + for i in 0..out.len() { + out[i] = a[i] * b_v[i]; + } + black_box(out[0].extract(0)) + }) + }); + + group.bench_function("fp64_64b_packed_mul_par_zip", |b| { + b.iter(|| { + let a = black_box(&lhs64_p); + let b_v = black_box(&rhs64_p); + let out = &mut out64_p; + pool.install(|| { + out.par_iter_mut() + .zip(a.par_iter()) + .zip(b_v.par_iter()) + .for_each(|((dst, lhs), rhs)| *dst = *lhs * *rhs); + }); + black_box(out[0].extract(0)) + }) + }); + + group.bench_function("fp64_64b_packed_mul_par_chunked", |b| { + b.iter(|| { + let a = black_box(&lhs64_p); + let b_v = black_box(&rhs64_p); + let out = &mut out64_p; + pool.install(|| { + out.par_chunks_mut(chunk64_p) + .zip(a.par_chunks(chunk64_p)) + .zip(b_v.par_chunks(chunk64_p)) + .for_each(|((dst, lhs), rhs)| { + for i in 0..dst.len() { + dst[i] = lhs[i] * rhs[i]; + } + }); + }); + black_box(out[0].extract(0)) + }) + }); + + group.bench_function("fp128_mul_seq", |b| { + b.iter(|| { + let a = black_box(&lhs128); + let b_v = black_box(&rhs128); + let out = &mut out128; + for i in 0..out.len() { + out[i] = a[i] * b_v[i]; + } + black_box(out[0]) + }) + }); + + group.bench_function("fp128_mul_par_chunked", |b| { + b.iter(|| { + let a = black_box(&lhs128); + let b_v = black_box(&rhs128); + let out = &mut out128; + pool.install(|| { + out.par_chunks_mut(chunk) + .zip(a.par_chunks(chunk)) + .zip(b_v.par_chunks(chunk)) + .for_each(|((dst, lhs), rhs)| { + for i in 0..dst.len() { + dst[i] = lhs[i] * rhs[i]; + } + }); + }); + black_box(out[0]) + }) + }); + + group.bench_function("fp128_packed_mul_seq", |b| { + b.iter(|| { + let a = black_box(&lhs128_p); + let b_v = black_box(&rhs128_p); + let out = &mut out128_p; + for i in 0..out.len() { + out[i] = a[i] * b_v[i]; + } + black_box(out[0].extract(0)) + }) + }); + + group.bench_function("fp128_packed_mul_par_chunked", |b| { + b.iter(|| { + let a = black_box(&lhs128_p); + let b_v = black_box(&rhs128_p); + let out = &mut out128_p; + pool.install(|| { + out.par_chunks_mut(chunk128_p) + .zip(a.par_chunks(chunk128_p)) + .zip(b_v.par_chunks(chunk128_p)) + .for_each(|((dst, lhs), rhs)| { + for i in 0..dst.len() { + dst[i] = lhs[i] * rhs[i]; + } + }); + }); + black_box(out[0].extract(0)) + }) + }); + + group.finish(); +} + +#[cfg(not(feature = "parallel"))] +fn bench_parallel_throughput(_: &mut Criterion) {} + +fn bench_packed_sumcheck_mix(c: &mut Criterion) { + use hachi_pcs::algebra::{Fp128Packing, Fp32Packing, Fp64Packing}; + + let n = 4096u64; + let mut rng = StdRng::seed_from_u64(0xface_bead); + + macro_rules! sumcheck_bench { + ($group:expr, $label:expr, $field:ty, $packing:ty, $rng:expr, $n:expr) => {{ + let eq: Vec<$field> = (0..$n).map(|_| FieldSampling::sample($rng)).collect(); + let poly: Vec<$field> = (0..$n).map(|_| FieldSampling::sample($rng)).collect(); + let eq_p = <$packing>::pack_slice(&eq); + let poly_p = <$packing>::pack_slice(&poly); + let mut acc = <$packing>::broadcast(<$field>::zero()); + + $group.bench_function(concat!($label, "_packed_macc"), |b| { + b.iter(|| { + let e = black_box(&eq_p); + let p_v = black_box(&poly_p); + acc = <$packing>::broadcast(<$field>::zero()); + for i in 0..e.len() { + acc = acc + e[i] * p_v[i]; + } + black_box(acc) + }) + }); + }}; + } + + let mut group = c.benchmark_group("packed_sumcheck_mix"); + group.throughput(Throughput::Elements(n)); + + use hachi_pcs::algebra::fields::pseudo_mersenne::*; + type M31 = Fp32<{ (1u32 << 31) - 1 }>; + + type P24 = Fp32Packing<{ POW2_OFFSET_MODULUS_24 }>; + type P30 = Fp32Packing<{ POW2_OFFSET_MODULUS_30 }>; + type P31 = Fp32Packing<{ POW2_OFFSET_MODULUS_31 }>; + type PM31 = Fp32Packing<{ (1u32 << 31) - 1 }>; + type P32 = Fp32Packing<{ POW2_OFFSET_MODULUS_32 }>; + type P40 = Fp64Packing<{ POW2_OFFSET_MODULUS_40 }>; + type P48 = Fp64Packing<{ POW2_OFFSET_MODULUS_48 }>; + type P56 = Fp64Packing<{ POW2_OFFSET_MODULUS_56 }>; + type P64 = Fp64Packing<{ POW2_OFFSET_MODULUS_64 }>; + type P128 = Fp128Packing<{ POW2_OFFSET_MODULUS_128 }>; + + sumcheck_bench!(group, "fp32_24b", Pow2Offset24Field, P24, &mut rng, n); + sumcheck_bench!(group, "fp32_30b", Pow2Offset30Field, P30, &mut rng, n); + sumcheck_bench!(group, "fp32_31b", Pow2Offset31Field, P31, &mut rng, n); + sumcheck_bench!(group, "fp32_m31", M31, PM31, &mut rng, n); + sumcheck_bench!(group, "fp32_32b", Pow2Offset32Field, P32, &mut rng, n); + sumcheck_bench!(group, "fp64_40b", Pow2Offset40Field, P40, &mut rng, n); + sumcheck_bench!(group, "fp64_48b", Pow2Offset48Field, P48, &mut rng, n); + sumcheck_bench!(group, "fp64_56b", Pow2Offset56Field, P56, &mut rng, n); + sumcheck_bench!(group, "fp64_64b", Pow2Offset64Field, P64, &mut rng, n); + sumcheck_bench!(group, "fp128", Prime128M8M4M1M0, P128, &mut rng, n); + + group.finish(); +} + +criterion_group!( + field_arith, + bench_mul, + bench_mul_only, + bench_mul_isolated, + bench_sqr, + bench_inv, + bench_packed_fp128_backend, + bench_bn254, + bench_fp32_fp64_mul, + bench_widening_ops, + bench_accumulator_pattern, + bench_throughput, + bench_packed_throughput, + bench_packed_sumcheck_mix, + bench_parallel_throughput +); +criterion_main!(field_arith); diff --git a/benches/fp64_reduce_probe.rs b/benches/fp64_reduce_probe.rs new file mode 100644 index 00000000..073544a9 --- /dev/null +++ b/benches/fp64_reduce_probe.rs @@ -0,0 +1,118 @@ +#![allow(missing_docs)] + +use criterion::{black_box, criterion_group, criterion_main, Criterion, Throughput}; + +const P40: u64 = hachi_pcs::algebra::fields::pseudo_mersenne::POW2_OFFSET_MODULUS_40; +const P64: u64 = hachi_pcs::algebra::fields::pseudo_mersenne::POW2_OFFSET_MODULUS_64; +const C40: u64 = (1u64 << 40) - P40; // 195 +const C64: u64 = 0u64.wrapping_sub(P64); // 59 +const MASK40: u64 = (1u64 << 40) - 1; +const MASK64_U128: u128 = u64::MAX as u128; + +#[inline(always)] +fn mul_c40_split(x: u64) -> u64 { + let c = C40 as u32; + let x_lo = x as u32; + let x_hi = (x >> 32) as u32; + (c as u64 * x_lo as u64).wrapping_add((c as u64 * x_hi as u64) << 32) +} + +#[inline(always)] +fn reduce40_split(lo: u64, hi: u64) -> u64 { + let high = (lo >> 40) | (hi << 24); + let f1 = (lo & MASK40).wrapping_add(mul_c40_split(high)); + let f2 = (f1 & MASK40).wrapping_add(mul_c40_split(f1 >> 40)); + let reduced = f2.wrapping_sub(P40); + let borrow = reduced >> 63; + reduced.wrapping_add(borrow.wrapping_neg() & P40) +} + +#[inline(always)] +fn reduce40_direct(lo: u64, hi: u64) -> u64 { + let high = (lo >> 40) | (hi << 24); + let f1 = (lo & MASK40).wrapping_add(C40.wrapping_mul(high)); + let f2 = (f1 & MASK40).wrapping_add(C40.wrapping_mul(f1 >> 40)); + let reduced = f2.wrapping_sub(P40); + let borrow = reduced >> 63; + reduced.wrapping_add(borrow.wrapping_neg() & P40) +} + +#[inline(always)] +fn reduce64(lo: u64, hi: u64) -> u64 { + let f1 = (lo as u128) + (C64 as u128) * (hi as u128); + let f2 = (f1 & MASK64_U128) + (C64 as u128) * ((f1 >> 64) as u64 as u128); + let reduced = f2.wrapping_sub(P64 as u128); + let borrow = reduced >> 127; + reduced.wrapping_add(borrow.wrapping_neg() & (P64 as u128)) as u64 +} + +#[inline(always)] +fn next_u64(state: &mut u64) -> u64 { + let mut x = *state; + x ^= x << 13; + x ^= x >> 7; + x ^= x << 17; + *state = x; + x +} + +fn bench_fp64_reduce_probe(c: &mut Criterion) { + let n = 1 << 13; + let mut seed = 0x9e37_79b9_7f4a_7c15u64; + + let mut pairs40 = Vec::with_capacity(n); + let mut pairs64 = Vec::with_capacity(n); + for _ in 0..n { + let a40 = next_u64(&mut seed) % P40; + let b40 = next_u64(&mut seed) % P40; + let x40 = (a40 as u128) * (b40 as u128); + pairs40.push((x40 as u64, (x40 >> 64) as u64)); + + let a64 = next_u64(&mut seed); + let b64 = next_u64(&mut seed); + let x64 = (a64 as u128) * (b64 as u128); + pairs64.push((x64 as u64, (x64 >> 64) as u64)); + } + + for &(lo, hi) in &pairs40 { + assert_eq!(reduce40_split(lo, hi), reduce40_direct(lo, hi)); + } + + let mut group = c.benchmark_group("fp64_reduce_probe"); + group.throughput(Throughput::Elements(n as u64)); + + group.bench_function("reduce40_split", |b| { + b.iter(|| { + let mut acc = 0u64; + for &(lo, hi) in black_box(&pairs40) { + acc ^= reduce40_split(lo, hi); + } + black_box(acc) + }) + }); + + group.bench_function("reduce40_direct", |b| { + b.iter(|| { + let mut acc = 0u64; + for &(lo, hi) in black_box(&pairs40) { + acc ^= reduce40_direct(lo, hi); + } + black_box(acc) + }) + }); + + group.bench_function("reduce64", |b| { + b.iter(|| { + let mut acc = 0u64; + for &(lo, hi) in black_box(&pairs64) { + acc ^= reduce64(lo, hi); + } + black_box(acc) + }) + }); + + group.finish(); +} + +criterion_group!(fp64_reduce_probe, bench_fp64_reduce_probe); +criterion_main!(fp64_reduce_probe); diff --git a/benches/hachi_e2e.rs b/benches/hachi_e2e.rs new file mode 100644 index 00000000..c9825245 --- /dev/null +++ b/benches/hachi_e2e.rs @@ -0,0 +1,516 @@ +#![allow(missing_docs)] + +use criterion::{black_box, criterion_group, criterion_main, BenchmarkId, Criterion}; +use hachi_pcs::algebra::{CyclotomicRing, Fp128}; +use hachi_pcs::error::HachiError; +use hachi_pcs::primitives::multilinear_evals::DenseMultilinearEvals; +use hachi_pcs::protocol::commitment::{ + HachiCommitmentCore, HachiCommitmentLayout, HachiProverSetup, HachiVerifierSetup, + MegaPolyBlock, ProductionFp128CommitmentConfig, RingCommitment, SparseBlockEntry, +}; +use hachi_pcs::protocol::commitment_scheme::{commit_onehot, HachiCommitmentScheme}; +use hachi_pcs::protocol::proof::HachiCommitmentHint; +use hachi_pcs::protocol::transcript::Blake2bTranscript; +use hachi_pcs::protocol::{CommitmentConfig, HachiProof}; +use hachi_pcs::{CommitmentScheme, FromSmallInt, Polynomial, Transcript}; +use std::time::Duration; + +type F = Fp128<0xfffffffffffffffffffffffffffffeed>; + +const D: usize = ProductionFp128CommitmentConfig::D; + +macro_rules! bench_config { + ($name:ident, M = $m:expr, R = $r:expr) => { + #[derive(Clone, Copy, Debug)] + struct $name; + impl CommitmentConfig for $name { + const D: usize = D; + const N_A: usize = ProductionFp128CommitmentConfig::N_A; + const N_B: usize = ProductionFp128CommitmentConfig::N_B; + const N_D: usize = ProductionFp128CommitmentConfig::N_D; + const LOG_BASIS: u32 = ProductionFp128CommitmentConfig::LOG_BASIS; + const DELTA: usize = ProductionFp128CommitmentConfig::DELTA; + const TAU: usize = ProductionFp128CommitmentConfig::TAU; + const CHALLENGE_WEIGHT: usize = ProductionFp128CommitmentConfig::CHALLENGE_WEIGHT; + + fn commitment_layout( + _max_num_vars: usize, + ) -> Result { + HachiCommitmentLayout::new::($m, $r) + } + } + }; +} + +bench_config!(CfgNv10, M = 4, R = 2); +bench_config!(CfgNv14, M = 6, R = 4); +bench_config!(CfgNv18, M = 8, R = 6); +bench_config!(CfgNv20, M = 8, R = 8); + +fn num_vars() -> usize { + let alpha = Cfg::D.trailing_zeros() as usize; + let layout = Cfg::commitment_layout(0).expect("benchmark layout"); + layout.m_vars + layout.r_vars + alpha +} + +fn make_poly(nv: usize) -> DenseMultilinearEvals { + let len = 1usize << nv; + let evals: Vec = (0..len).map(|i| F::from_u64(i as u64)).collect(); + DenseMultilinearEvals::new_padded(evals) +} + +fn opening_point(nv: usize) -> Vec { + (0..nv).map(|i| F::from_u64((i + 2) as u64)).collect() +} + +fn bench_phases(c: &mut Criterion, label: &str) +where + HachiCommitmentScheme: CommitmentScheme, +{ + type S = HachiCommitmentScheme; + let nv = num_vars::(); + let poly = make_poly(nv); + let pt = opening_point(nv); + + let mut group = c.benchmark_group(format!("hachi/{label}/nv{nv}")); + if nv >= 18 { + group.sample_size(10); + group.measurement_time(Duration::from_secs(30)); + } + + group.bench_function("setup", |b| { + b.iter(|| black_box( as CommitmentScheme>::setup_prover(black_box(nv)))) + }); + + let setup = as CommitmentScheme>::setup_prover(nv); + + group.bench_function("commit", |b| { + b.iter(|| { + black_box( + as CommitmentScheme>::commit(black_box(&poly), black_box(&setup)) + .unwrap(), + ) + }) + }); + + let (commitment, hint) = as CommitmentScheme>::commit(&poly, &setup).unwrap(); + + group.bench_function("prove", |b| { + b.iter(|| { + let mut transcript = Blake2bTranscript::::new(b"bench"); + black_box( + as CommitmentScheme>::prove( + black_box(&setup), + black_box(&poly), + black_box(&pt), + Some(hint.clone()), + &mut transcript, + black_box(&commitment), + ) + .unwrap(), + ) + }) + }); + + let verifier_setup = as CommitmentScheme>::setup_verifier(&setup); + let opening = poly.evaluate(&pt); + let mut prover_transcript = Blake2bTranscript::::new(b"bench"); + let proof = as CommitmentScheme>::prove( + &setup, + &poly, + &pt, + Some(hint), + &mut prover_transcript, + &commitment, + ) + .unwrap(); + + group.bench_function("verify", |b| { + b.iter(|| { + let mut transcript = Blake2bTranscript::::new(b"bench"); + as CommitmentScheme>::verify( + black_box(&proof), + black_box(&verifier_setup), + &mut transcript, + black_box(&pt), + black_box(&opening), + black_box(&commitment), + ) + .unwrap(); + }) + }); + + group.bench_function(BenchmarkId::new("e2e", nv), |b| { + b.iter(|| { + let (cm, h) = as CommitmentScheme>::commit(&poly, &setup).unwrap(); + let mut pt_tr = Blake2bTranscript::::new(b"bench"); + let pf = as CommitmentScheme>::prove( + &setup, + &poly, + &pt, + Some(h), + &mut pt_tr, + &cm, + ) + .unwrap(); + let mut vt_tr = Blake2bTranscript::::new(b"bench"); + as CommitmentScheme>::verify( + &pf, + &verifier_setup, + &mut vt_tr, + &pt, + &opening, + &cm, + ) + .unwrap(); + black_box(()) + }) + }); + + group.finish(); +} + +fn bench_onehot_phases(c: &mut Criterion, label: &str) +where + HachiCommitmentScheme: CommitmentScheme< + F, + ProverSetup = HachiProverSetup, + VerifierSetup = HachiVerifierSetup, + Commitment = RingCommitment, + Proof = HachiProof, + OpeningProofHint = HachiCommitmentHint, + >, +{ + type S = HachiCommitmentScheme; + let nv = num_vars::(); + let total_elems = 1usize << nv; + let onehot_k = D; + let num_chunks = total_elems / onehot_k; + + let indices: Vec> = (0..num_chunks).map(|i| Some(i % onehot_k)).collect(); + + let mut evals = vec![F::from_u64(0); total_elems]; + for (ci, opt_idx) in indices.iter().enumerate() { + if let Some(idx) = opt_idx { + evals[ci * onehot_k + idx] = F::from_u64(1); + } + } + let poly = DenseMultilinearEvals::new_padded(evals); + let pt = opening_point(nv); + + let setup = as CommitmentScheme>::setup_prover(nv); + + let mut group = c.benchmark_group(format!("hachi_onehot/{label}/nv{nv}")); + if nv >= 18 { + group.sample_size(10); + group.measurement_time(Duration::from_secs(30)); + } + + group.bench_function("commit", |b| { + b.iter(|| { + black_box( + commit_onehot::( + black_box(onehot_k), + black_box(&indices), + black_box(&setup), + ) + .unwrap(), + ) + }) + }); + + let (commitment, hint) = commit_onehot::(onehot_k, &indices, &setup).unwrap(); + + group.bench_function("prove", |b| { + b.iter(|| { + let mut transcript = Blake2bTranscript::::new(b"bench"); + black_box( + as CommitmentScheme>::prove( + black_box(&setup), + black_box(&poly), + black_box(&pt), + Some(hint.clone()), + &mut transcript, + black_box(&commitment), + ) + .unwrap(), + ) + }) + }); + + let verifier_setup = as CommitmentScheme>::setup_verifier(&setup); + let opening = poly.evaluate(&pt); + let mut prover_transcript = Blake2bTranscript::::new(b"bench"); + let proof = as CommitmentScheme>::prove( + &setup, + &poly, + &pt, + Some(hint.clone()), + &mut prover_transcript, + &commitment, + ) + .unwrap(); + + group.bench_function("verify", |b| { + b.iter(|| { + let mut transcript = Blake2bTranscript::::new(b"bench"); + as CommitmentScheme>::verify( + black_box(&proof), + black_box(&verifier_setup), + &mut transcript, + black_box(&pt), + black_box(&opening), + black_box(&commitment), + ) + .unwrap(); + }) + }); + + group.bench_function(BenchmarkId::new("e2e", nv), |b| { + b.iter(|| { + let (cm, h) = commit_onehot::(onehot_k, &indices, &setup).unwrap(); + let mut pt_tr = Blake2bTranscript::::new(b"bench"); + let pf = as CommitmentScheme>::prove( + &setup, + &poly, + &pt, + Some(h), + &mut pt_tr, + &cm, + ) + .unwrap(); + let mut vt_tr = Blake2bTranscript::::new(b"bench"); + as CommitmentScheme>::verify( + &pf, + &verifier_setup, + &mut vt_tr, + &pt, + &opening, + &cm, + ) + .unwrap(); + black_box(()) + }) + }); + + group.finish(); +} + +fn bench_mixed_phases(c: &mut Criterion, label: &str) +where + HachiCommitmentScheme: CommitmentScheme< + F, + ProverSetup = HachiProverSetup, + VerifierSetup = HachiVerifierSetup, + Commitment = RingCommitment, + Proof = HachiProof, + OpeningProofHint = HachiCommitmentHint, + >, +{ + type S = HachiCommitmentScheme; + let nv = num_vars::(); + let layout = Cfg::commitment_layout(0).expect("benchmark layout"); + let block_len = layout.block_len; + let num_blocks = layout.num_blocks; + let dense_blocks = num_blocks / 2; + + let mut ring_coeffs: Vec> = Vec::with_capacity(num_blocks * block_len); + + for i in 0..(dense_blocks * block_len) { + ring_coeffs.push(CyclotomicRing::from_coefficients(std::array::from_fn( + |j| F::from_u64((i * D + j + 1) as u64), + ))); + } + + let mut sparse_per_block: Vec> = Vec::new(); + for bi in 0..(num_blocks - dense_blocks) { + let mut entries = Vec::new(); + for ri in 0..block_len { + let idx = (bi * block_len + ri) % D; + let mut coeffs = [F::from_u64(0); D]; + coeffs[idx] = F::from_u64(1); + ring_coeffs.push(CyclotomicRing::from_coefficients(coeffs)); + entries.push(SparseBlockEntry { + pos_in_block: ri, + nonzero_coeffs: vec![idx], + }); + } + sparse_per_block.push(entries); + } + + let evals: Vec = ring_coeffs + .iter() + .flat_map(|r| r.coefficients().iter().copied()) + .collect(); + let poly = DenseMultilinearEvals::new_padded(evals); + let pt = opening_point(nv); + + let setup = as CommitmentScheme>::setup_prover(nv); + + let mut group = c.benchmark_group(format!("hachi_mixed/{label}/nv{nv}")); + if nv >= 18 { + group.sample_size(10); + group.measurement_time(Duration::from_secs(30)); + } + + group.bench_function("commit", |b| { + b.iter(|| { + let blocks: Vec> = (0..num_blocks) + .map(|i| { + if i < dense_blocks { + let start = i * block_len; + let end = start + block_len; + MegaPolyBlock::Dense(&ring_coeffs[start..end]) + } else { + MegaPolyBlock::OneHot(&sparse_per_block[i - dense_blocks]) + } + }) + .collect(); + black_box( + HachiCommitmentCore::commit_mixed::( + black_box(&blocks), + black_box(&setup), + ) + .unwrap(), + ) + }) + }); + + let blocks: Vec> = (0..num_blocks) + .map(|i| { + if i < dense_blocks { + let start = i * block_len; + let end = start + block_len; + MegaPolyBlock::Dense(&ring_coeffs[start..end]) + } else { + MegaPolyBlock::OneHot(&sparse_per_block[i - dense_blocks]) + } + }) + .collect(); + let w = HachiCommitmentCore::commit_mixed::(&blocks, &setup).unwrap(); + let commitment = w.commitment; + let hint = HachiCommitmentHint { + t_hat: w.t_hat, + ring_coeffs: ring_coeffs.clone(), + }; + + group.bench_function("prove", |b| { + b.iter(|| { + let mut transcript = Blake2bTranscript::::new(b"bench"); + black_box( + as CommitmentScheme>::prove( + black_box(&setup), + black_box(&poly), + black_box(&pt), + Some(hint.clone()), + &mut transcript, + black_box(&commitment), + ) + .unwrap(), + ) + }) + }); + + let verifier_setup = as CommitmentScheme>::setup_verifier(&setup); + let opening = poly.evaluate(&pt); + let mut prover_transcript = Blake2bTranscript::::new(b"bench"); + let proof = as CommitmentScheme>::prove( + &setup, + &poly, + &pt, + Some(hint.clone()), + &mut prover_transcript, + &commitment, + ) + .unwrap(); + + group.bench_function("verify", |b| { + b.iter(|| { + let mut transcript = Blake2bTranscript::::new(b"bench"); + as CommitmentScheme>::verify( + black_box(&proof), + black_box(&verifier_setup), + &mut transcript, + black_box(&pt), + black_box(&opening), + black_box(&commitment), + ) + .unwrap(); + }) + }); + + group.bench_function(BenchmarkId::new("e2e", nv), |b| { + b.iter(|| { + let blocks: Vec> = (0..num_blocks) + .map(|i| { + if i < dense_blocks { + let start = i * block_len; + let end = start + block_len; + MegaPolyBlock::Dense(&ring_coeffs[start..end]) + } else { + MegaPolyBlock::OneHot(&sparse_per_block[i - dense_blocks]) + } + }) + .collect(); + let w = HachiCommitmentCore::commit_mixed::(&blocks, &setup).unwrap(); + let cm = w.commitment; + let h = HachiCommitmentHint { + t_hat: w.t_hat, + ring_coeffs: ring_coeffs.clone(), + }; + let mut pt_tr = Blake2bTranscript::::new(b"bench"); + let pf = as CommitmentScheme>::prove( + &setup, + &poly, + &pt, + Some(h), + &mut pt_tr, + &cm, + ) + .unwrap(); + let mut vt_tr = Blake2bTranscript::::new(b"bench"); + as CommitmentScheme>::verify( + &pf, + &verifier_setup, + &mut vt_tr, + &pt, + &opening, + &cm, + ) + .unwrap(); + black_box(()) + }) + }); + + group.finish(); +} + +fn bench_nv10(c: &mut Criterion) { + bench_phases::(c, "fp128_p275"); +} +fn bench_nv14(c: &mut Criterion) { + bench_phases::(c, "fp128_p275"); +} +fn bench_nv18(c: &mut Criterion) { + bench_phases::(c, "fp128_p275"); +} +fn bench_nv20(c: &mut Criterion) { + bench_phases::(c, "fp128_p275"); +} +fn bench_onehot_nv14(c: &mut Criterion) { + bench_onehot_phases::(c, "fp128_p275"); +} +fn bench_mixed_nv14(c: &mut Criterion) { + bench_mixed_phases::(c, "fp128_p275"); +} + +criterion_group!( + hachi_benches, + bench_nv10, + bench_nv14, + bench_nv18, + bench_nv20, + bench_onehot_nv14, + bench_mixed_nv14, +); +criterion_main!(hachi_benches); diff --git a/benches/norm_sumcheck.rs b/benches/norm_sumcheck.rs new file mode 100644 index 00000000..655d6cab --- /dev/null +++ b/benches/norm_sumcheck.rs @@ -0,0 +1,203 @@ +#![allow(missing_docs)] + +use criterion::{black_box, criterion_group, criterion_main, BatchSize, BenchmarkId, Criterion}; +use hachi_pcs::algebra::Fp128; +use hachi_pcs::protocol::sumcheck::norm_sumcheck::NormSumcheckProver; +use hachi_pcs::protocol::sumcheck::split_eq::GruenSplitEq; +use hachi_pcs::protocol::sumcheck::{ + fold_evals_in_place, prove_sumcheck, range_check_eval, SumcheckInstanceProver, UniPoly, +}; +use hachi_pcs::protocol::transcript::labels; +use hachi_pcs::protocol::Blake2bTranscript; +use hachi_pcs::{FieldCore, FromSmallInt, Transcript}; +use rand::rngs::StdRng; +use rand::{Rng, SeedableRng}; +#[cfg(feature = "parallel")] +use rayon::prelude::*; +use std::time::Duration; + +type F = Fp128<0xfffffffffffffffffffffffffffffeed>; + +/// Baseline prover keeps the pre-dispatch point-eval kernel for apples-to-apples benchmarks. +/// It is intentionally local to this bench and should not be used in production code. +struct BaselineNormSumcheckProver { + split_eq: GruenSplitEq, + w_table: Vec, + num_vars: usize, + b: usize, +} + +impl BaselineNormSumcheckProver { + fn new(tau: &[E], w_evals: Vec, b: usize) -> Self { + let num_vars = tau.len(); + assert_eq!(w_evals.len(), 1 << num_vars); + Self { + split_eq: GruenSplitEq::new(tau), + w_table: w_evals, + num_vars, + b, + } + } +} + +impl SumcheckInstanceProver for BaselineNormSumcheckProver { + fn num_rounds(&self) -> usize { + self.num_vars + } + + fn degree_bound(&self) -> usize { + 2 * self.b + } + + fn input_claim(&self) -> E { + E::zero() + } + + fn compute_round_univariate(&mut self, _round: usize, _previous_claim: E) -> UniPoly { + let half = self.w_table.len() / 2; + let degree_q = 2 * self.b - 1; + let num_points_q = degree_q + 1; + + let (e_first, e_second) = self.split_eq.remaining_eq_tables(); + let num_first = e_first.len(); + let first_bits = num_first.trailing_zeros(); + let b = self.b; + + #[cfg(feature = "parallel")] + let q_evals = { + (0..half) + .into_par_iter() + .fold( + || vec![E::zero(); num_points_q], + |mut evals, j| { + let j_low = j & (num_first - 1); + let j_high = j >> first_bits; + let eq_rem = e_first[j_low] * e_second[j_high]; + let w_0 = self.w_table[2 * j]; + let w_1 = self.w_table[2 * j + 1]; + for (t, eval) in evals.iter_mut().enumerate() { + let t_e = E::from_u64(t as u64); + let w_t = w_0 + t_e * (w_1 - w_0); + *eval = *eval + eq_rem * range_check_eval(w_t, b); + } + evals + }, + ) + .reduce( + || vec![E::zero(); num_points_q], + |mut a, b_vec| { + for (ai, bi) in a.iter_mut().zip(b_vec.iter()) { + *ai = *ai + *bi; + } + a + }, + ) + }; + #[cfg(not(feature = "parallel"))] + let q_evals = { + let mut q_evals = vec![E::zero(); num_points_q]; + for j in 0..half { + let j_low = j & (num_first - 1); + let j_high = j >> first_bits; + let eq_rem = e_first[j_low] * e_second[j_high]; + let w_0 = self.w_table[2 * j]; + let w_1 = self.w_table[2 * j + 1]; + for (t, eval) in q_evals.iter_mut().enumerate() { + let t_e = E::from_u64(t as u64); + let w_t = w_0 + t_e * (w_1 - w_0); + *eval = *eval + eq_rem * range_check_eval(w_t, b); + } + } + q_evals + }; + + let q_poly = UniPoly::from_evals(&q_evals); + self.split_eq.gruen_mul(&q_poly) + } + + fn ingest_challenge(&mut self, _round: usize, r: E) { + self.split_eq.bind(r); + fold_evals_in_place(&mut self.w_table, r); + } +} + +#[derive(Clone)] +struct NormCase { + num_vars: usize, + b: usize, + tau: Vec, + w_evals: Vec, +} + +fn build_case(num_vars: usize, b: usize, seed: u64) -> NormCase { + let mut rng = StdRng::seed_from_u64(seed); + let n = 1usize << num_vars; + let tau: Vec = (0..num_vars) + .map(|_| F::from_u64(rng.gen_range(0u64..(1u64 << 24)))) + .collect(); + let w_evals: Vec = (0..n) + .map(|_| F::from_u64(rng.gen_range(0u64..(1u64 << 24)))) + .collect(); + NormCase { + num_vars, + b, + tau, + w_evals, + } +} + +fn bench_norm_sumcheck(c: &mut Criterion) { + let cases = [ + build_case(10, 4, 0xA11CE001), + build_case(10, 8, 0xA11CE002), + build_case(14, 4, 0xA11CE003), + build_case(14, 8, 0xA11CE004), + build_case(14, 16, 0xA11CE005), + build_case(18, 8, 0xA11CE006), + ]; + + let mut group = c.benchmark_group("norm_sumcheck"); + group.warm_up_time(Duration::from_secs(8)); + group.measurement_time(Duration::from_secs(24)); + group.sample_size(35); + + for case in &cases { + let case_tag = format!("nv{}_b{}", case.num_vars, case.b); + group.bench_function(BenchmarkId::new("baseline", &case_tag), |bencher| { + bencher.iter_batched( + || BaselineNormSumcheckProver::new(&case.tau, case.w_evals.clone(), case.b), + |mut prover| { + let mut transcript = Blake2bTranscript::::new(labels::DOMAIN_HACHI_PROTOCOL); + black_box( + prove_sumcheck::(&mut prover, &mut transcript, |tr| { + tr.challenge_scalar(labels::CHALLENGE_SUMCHECK_ROUND) + }) + .unwrap(), + ) + }, + BatchSize::SmallInput, + ); + }); + + group.bench_function(BenchmarkId::new("dispatched", &case_tag), |bencher| { + bencher.iter_batched( + || NormSumcheckProver::new(&case.tau, case.w_evals.clone(), case.b), + |mut prover| { + let mut transcript = Blake2bTranscript::::new(labels::DOMAIN_HACHI_PROTOCOL); + black_box( + prove_sumcheck::(&mut prover, &mut transcript, |tr| { + tr.challenge_scalar(labels::CHALLENGE_SUMCHECK_ROUND) + }) + .unwrap(), + ) + }, + BatchSize::SmallInput, + ); + }); + } + + group.finish(); +} + +criterion_group!(benches, bench_norm_sumcheck); +criterion_main!(benches); diff --git a/benches/ring_ntt.rs b/benches/ring_ntt.rs new file mode 100644 index 00000000..aa2ac9ff --- /dev/null +++ b/benches/ring_ntt.rs @@ -0,0 +1,68 @@ +#![allow(missing_docs)] + +use criterion::{black_box, criterion_group, criterion_main, Criterion}; +use hachi_pcs::algebra::ntt::butterfly::{forward_ntt, inverse_ntt, NttTwiddles}; +use hachi_pcs::algebra::tables::{q32_garner, Q32_MODULUS, Q32_NUM_PRIMES, Q32_PRIMES}; +use hachi_pcs::algebra::{CyclotomicCrtNtt, CyclotomicRing, Fp64, MontCoeff}; +use hachi_pcs::FromSmallInt; + +type F = Fp64<{ Q32_MODULUS }>; +type R = CyclotomicRing; +type N = CyclotomicCrtNtt; + +fn sample_ring(seed: u64) -> R { + let coeffs = std::array::from_fn(|i| { + let x = seed + .wrapping_mul(31) + .wrapping_add((i as u64).wrapping_mul(17)); + F::from_u64(x % Q32_MODULUS) + }); + R::from_coefficients(coeffs) +} + +fn bench_ring_schoolbook_mul(c: &mut Criterion) { + let lhs = sample_ring(3); + let rhs = sample_ring(11); + c.bench_function("ring_schoolbook_mul_d64", |b| { + b.iter(|| black_box(lhs) * black_box(rhs)) + }); +} + +fn bench_ntt_single_prime_round_trip(c: &mut Criterion) { + let prime = Q32_PRIMES[0]; + let tw = NttTwiddles::::compute(prime); + let base: [MontCoeff; 64] = + std::array::from_fn(|i| prime.from_canonical(((i * 5 + 7) as i16) % prime.p)); + + c.bench_function("ntt_single_prime_forward_inverse_d64", |b| { + b.iter(|| { + let mut a = base; + forward_ntt(&mut a, prime, &tw); + inverse_ntt(&mut a, prime, &tw); + black_box(a) + }) + }); +} + +fn bench_crt_round_trip(c: &mut Criterion) { + let ring = sample_ring(19); + let twiddles: [NttTwiddles; Q32_NUM_PRIMES] = + std::array::from_fn(|k| NttTwiddles::compute(Q32_PRIMES[k])); + let garner = q32_garner(); + + c.bench_function("ring_ntt_crt_round_trip_d64_k6", |b| { + b.iter(|| { + let ntt = N::from_ring(black_box(&ring), &Q32_PRIMES, &twiddles); + let back: R = ntt.to_ring(&Q32_PRIMES, &twiddles, &garner); + black_box(back) + }) + }); +} + +criterion_group!( + ring_ntt, + bench_ring_schoolbook_mul, + bench_ntt_single_prime_round_trip, + bench_crt_round_trip +); +criterion_main!(ring_ntt); diff --git a/docs/FIELD_EMBEDDINGS_SUPERNEO_VS_HACHI.md b/docs/FIELD_EMBEDDINGS_SUPERNEO_VS_HACHI.md new file mode 100644 index 00000000..90b631c2 --- /dev/null +++ b/docs/FIELD_EMBEDDINGS_SUPERNEO_VS_HACHI.md @@ -0,0 +1,698 @@ +# Field embeddings in SuperNeo vs Hachi (math-first notes) + +This note focuses **only** on the math in the two papers: + +- `docs/superneo.pdf` (“Neo and SuperNeo: Post-quantum folding with pay-per-bit costs over small fields”) +- `paper/hachi.pdf` (“Hachi: Efficient Lattice-Based Multilinear Polynomial Commitments over Extension Fields”) + +The shared theme is: lattice commitments naturally live over **cyclotomic rings**, but we want the *interactive proof logic* (sum-check, norm checks, etc.) to live over a **(small) field or a small extension field**. Both works build *embeddings/reductions* that let you: + +- commit in a ring/module world (Ajtai/Module-SIS commitments), while +- proving the needed algebraic statements using field arithmetic, and +- keeping norms under control (for binding) and enabling linear-combination/folding operations. + +--- + +## 1) Common background and notation (as used in SuperNeo) + +SuperNeo sets up a base field \(F = \mathbb{F}_q\), an extension field \(K/F\) of minimal degree such that \(1/|K| = \mathrm{negl}(\lambda)\), and a cyclotomic ring + +\[ +R_F := F[X]/(\Phi(X)), \quad R_K := K[X]/(\Phi(X)), +\] + +where \(\Phi(X)\) is an \(\eta\)-th cyclotomic polynomial of degree \(d\). It explicitly treats + +\[ +F \subseteq R_F \subseteq R_K, \qquad F \subseteq K +\] + +as nested substructures (SuperNeo, Def. 1; `docs/superneo.pdf`, p. 19–20: “-- 19 of 60 --”, “-- 20 of 60 --”). + +Two coefficient maps show up everywhere: + +- coefficient vector: \(\mathrm{cf}(a)\in F^d\) for \(a\in R_F\), +- constant term: \(\mathrm{ct}(a)\in F\) for \(a\in R_F\), + +and similarly over \(R_K\) (SuperNeo, Def. 2; `docs/superneo.pdf`, p. 20: “-- 20 of 60 --”). + +### 1.1 What I mean by the “Gram operator” in this note + +Fix: + +- an \(F\)-basis \(\{e_0,\dots,e_{d-1}\}\) of the \(F\)-vector space \(R_F\) (most often \(e_i = X^i\), the coefficient basis), and +- a bilinear form \(B: R_F\times R_F \to F\). + +Typical examples in these papers are: + +- \(B(u,v)=\mathrm{ct}(u\cdot v)\) (SuperNeo’s constant-term functional), possibly with an automorphism inserted, or +- \(B(u,v)=\mathrm{Tr}_H(u\cdot \sigma_{-1}(v))\) (Hachi’s trace-to-subfield functional). + +Then the **Gram matrix** of \(B\) in the basis \(\{e_i\}\) is the \(d\times d\) matrix +\[ +G_{ij} := B(e_i,e_j). +\] +This matrix encodes the pairing: +if \(u=\sum_i a_i e_i\) and \(v=\sum_j b_j e_j\), then +\[ +B(u,v) = a^\top\, G\, b. +\] + +The corresponding **Gram operator** is just the linear map \(g: F^d\to F^d\) given by +\[ +g(b) := G\,b, +\] +so that \(B(u,v)=a^\top g(b)\). + +When SuperNeo says “there exists a linear transform \(T\) such that \(\mathrm{ct}(T(a)\cdot b)=\langle a,b\rangle\)” (Thm. 3), +one way to interpret it is: choose a bilinear form \(B_0(u,v)=\mathrm{ct}(u\cdot v)\) (or a close variant), write down its Gram matrix \(G\) in the coefficient basis, and take \(T = G^{-{\top}}\). This makes the pairing become the standard dot product in coordinates. + +--- + +## 2) What “field embedding” means in SuperNeo (the core problem) + +### 2.1 The Ajtai-commitment mismatch + +Ajtai/Module-SIS style commitments are *ring-module* commitments: + +- Commit to \(z \in R_F^n\) via a linear map \(c = A z\) over \(R_F\). + +But CCS witnesses (and CCS arithmetic) are naturally vectors over \(F\). So you need a map: + +\[ +\iota: F^{N} \longrightarrow R_F^n +\] + +that is compatible with: + +- **norm constraints** (binding only holds for “small-norm” openings), +- **field constraint checking** via sum-check over \(F\) or \(K\), +- **folding** (random linear combinations of commitments and claims). + +SuperNeo frames this as “embed field vectors (CCS witnesses) into the ring vectors that Ajtai commitments operate over” and calls out that the embedding must preserve **norm bounds** and an **evaluation homomorphism** needed for sum-check-based folding (SuperNeo, §1.2; `docs/superneo.pdf`, p. 5–6: “-- 5 of 60 --”, “-- 6 of 60 --”). + +### 2.2 What went wrong before (NTT embedding) + +Prior lattice folding used an NTT/SIMD isomorphism that maps ring elements into a product of extension fields; this makes field-constraint checking look “ring-native”, but: + +- the NTT map is **not norm-preserving**, so small bit-width witnesses become arbitrary-norm ring elements, +- the commitment must then decompose regardless of bit-width ⇒ **no pay-per-bit**, +- packing efficiency is limited by the factor \(t\) in \(F_{q^t}\) (SuperNeo, §1.2.1; `docs/superneo.pdf`, p. 6–7: “-- 6 of 60 --”, “-- 7 of 60 --”). + +--- + +## 3) SuperNeo’s key innovation: *norm-preserving embeddings + evaluation homomorphism* + +SuperNeo’s abstract summarizes the core as: + +> “two new norm-preserving embeddings of field vectors into ring vectors that respect an evaluation homomorphism required for folding” +(`docs/superneo.pdf`, p. 1: “-- 1 of 60 --”). + +The paper describes both a “Neo embedding” (SIMD-friendly) and a “SuperNeo embedding” (general, non-SIMD), and then focuses the rest of the paper on SuperNeo (see `docs/superneo.pdf`, p. 10–11: “-- 10 of 60 --”, “-- 11 of 60 --”). + +### 3.1 Neo embedding (high level: “coefficients-as-SIMD lanes”) + +Neo’s embedding idea: pack **\(d\) field vectors** \(z^{(1)},\dots,z^{(d)}\in F^n\) into the **coefficient slots** of a ring vector \(z \in R_F^n\), so the coefficient matrix \(\mathrm{cf}(z)\in F^{d\times n}\) literally equals those \(d\) vectors (SuperNeo, §1.2.2 “Contribution 1”; `docs/superneo.pdf`, p. 8–10: “-- 8 of 60 --”, “-- 9 of 60 --”, “-- 10 of 60 --”). + +Key consequences: + +- **norm-preserving**: small field entries ⇒ small ring coefficients ⇒ binding is aligned with bit-width. +- **optimal SIMD packing**: achieves \(d\cdot n\) field elements per length-\(n\) ring vector (under a SIMD constraint system). +- **evaluation homomorphism for folding**: if you fold commitments with a *short ring* challenge \(\delta\in R_F\), the embedded evaluation claims can be folded consistently by embedding the \(d\) field evaluations as a ring element \(y=\sum_i y^{(i)}X^{i-1}\) (SuperNeo, p. 9: “-- 9 of 60 --”). + +But Neo still needs SIMD constraints: “the same constraint system must be applied to all \(d\) underlying field vectors” (SuperNeo, p. 10: “-- 10 of 60 --”). + +### 3.2 SuperNeo embedding (formal: coefficient embedding of **one** length-\(d n_R\) vector) + +SuperNeo removes SIMD by embedding a **single** long field vector \(z\in F^{n_F}\) where \(n_F = d\cdot n_R\), by chunking it into \(n_R\) blocks of length \(d\), and mapping each block to one ring element’s coefficient vector. + +This is defined formally as “Coefficient Embedding” (SuperNeo, §5, Def. 7; `docs/superneo.pdf`, p. 23: “-- 23 of 60 --”): + +- element: \(v\in F^d \mapsto \mathbf{v}\in R_F\) with \(\mathrm{cf}(\mathbf{v})=v\) +- vector: \(z\in F^{d n_R}\mapsto \mathbf{z}\in R_F^{n_R}\) by splitting into \(d\)-sized blocks +- matrix: \(M\in F^{m\times d n_R}\mapsto \mathbf{M}\in R_F^{m\times n_R}\) row-wise. + +**Why this matters:** + +- **optimal packing without SIMD**: you pack \(d\cdot n_R\) field elements into \(n_R\) ring elements. +- **norm-preserving**: the committed object’s coefficients are exactly the witness entries. +- **field-native checks become possible**: you can write constraints directly over the underlying field vector \(z\) and use sum-check over \(F\) or \(K\). + +### 3.3 The nontrivial part: lifting *field* products to *ring* products while keeping folding linear + +SuperNeo’s obstacle is: commitments/folding live over the ring, but sum-check outputs **field** multilinear evaluation claims like + +\[ +M z \;\widetilde{}\; (r) \in K +\] + +for some random \(r\), and folding wants to take ring-linear combinations \(z'' = z + \delta z'\) with \(\delta \in R_F\) and have the **claims** fold “the same way”. + +SuperNeo’s main embedding tool is Section 5 “Embedding products with evaluation homomorphism”: + +1. Use a cyclotomic **inner-product automorphism trick** to turn coefficient inner products into ring constant terms. + +SuperNeo states an “Inner Product Transform” (Thm. 3; `docs/superneo.pdf`, p. 23: “-- 23 of 60 --”): + +> there exists a linear transform \(\bar{\cdot}: F^d \to F^d\) such that for all \(a,b\in F^d\), +> \[ +> \mathrm{ct}(\overline{a}\cdot \mathbf{b}) = \langle a,b\rangle. +> \] + +Conceptually: cyclotomic rings have many Galois automorphisms (e.g. “conjugation” \(X\mapsto X^{-1}\)), and by applying an appropriate automorphism/linear transform to one operand, the constant coefficient of a ring product recovers a dot product of coefficient vectors. (SuperNeo explicitly attributes this to “(Galois, conjugation, or inner product) automorphism trick” in §1.2.2; `docs/superneo.pdf`, p. 10–11: “-- 10 of 60 --”, “-- 11 of 60 --”.) + +2. Extend this transform blockwise to vectors/matrices (Def. 8; `docs/superneo.pdf`, p. 23: “-- 23 of 60 --”). + +3. Obtain a **matrix-vector product transform** (Thm. 4; `docs/superneo.pdf`, p. 23–24: “-- 23 of 60 --”, “-- 24 of 60 --”): + +> For \(M\in F^{m\times n_F}\), \(z\in F^{n_F}\), +> \[ +> M z = \mathrm{ct}(\overline{M}\,\mathbf{z}), +> \] +> i.e. the field product equals the vector of constant terms of a ring product. + +4. Lift this to evaluation claims and prove the **evaluation homomorphism** (Thm. 5; `docs/superneo.pdf`, p. 24: “-- 24 of 60 --”): + +Roughly: if you linearly combine committed ring vectors with ring scalars \(\rho_i\in R_F\), then the lifted ring-evaluation objects combine linearly as well, and constant terms track the underlying field evaluations. + +This is the formal engine that makes “field-native sum-check + ring-linear folding” composable. + +### 3.3.1 Explicit inner-product transforms for two cyclotomics you care about + +SuperNeo’s Theorem 3 is an *existence* statement: there is a linear map \(T: F^d \to F^d\) (write \(T(a)=\bar a\)) such that for all \(a,b\in F^d\), + +\[ +\mathrm{ct}(\mathbf{\bar a}\cdot \mathbf{b}) = \langle a,b\rangle +\] + +where \(\mathbf{v}\in R_F\) denotes the coefficient embedding of \(v\in F^d\). + +Below are **concrete closed forms** for \(T\) in two important special cases. + +#### (A) Power-of-two cyclotomic: \(\Phi(X)=X^d+1\) (negacyclic ring) + +Let \(R_F = F[X]/(X^d+1)\), and write +\(a(X)=\sum_{i=0}^{d-1} a_i X^i\), \(b(X)=\sum_{i=0}^{d-1} b_i X^i\). + +Define \(\bar a(X)\) by the coefficient rule: + +- \(\bar a_0 := a_0\) +- for \(i=1,\dots,d-1\): \(\bar a_i := -a_{d-i}\) + +Equivalently, +\[ +\bar a(X)=a_0 - \sum_{i=1}^{d-1} a_i X^{d-i}. +\] + +Then in \(R_F\), +\[ +\mathrm{ct}(\bar a(X)\,b(X)) = \sum_{i=0}^{d-1} a_i b_i. +\] + +Reason (one-line): the term \((-a_i X^{d-i})(b_i X^i)=-a_i b_i X^d\) contributes \(+a_i b_i\) to the constant term since \(X^d=-1\); other cross-terms cannot reduce to constants without leaving a nonzero power of \(X\). + +This is exactly the classical “conjugation/inversion automorphism” trick specialized to \(X^d+1\). + +#### (B) Trinomial cyclotomic: \(\Phi_{81}(X)=X^{54}+X^{27}+1\) + +Let \(R_F = F[X]/(X^{54}+X^{27}+1)\) (so \(d=54\)). Write +\(a(X)=\sum_{i=0}^{53} a_i X^i\), \(b(X)=\sum_{i=0}^{53} b_i X^i\). + +One valid “inner product transform” \(T(a)=\bar a\) with +\(\mathrm{ct}(\bar a(X)\,b(X))=\sum_i a_i b_i\) +is: + +- \(\bar a_0 := a_0\) +- for \(i=1,\dots,26\): \(\bar a_i := -(a_{27-i} + a_{54-i})\) +- for \(i=27,\dots,53\): \(\bar a_i := -a_{54-i}\) + +(indices are in \(\{0,\dots,53\}\)). + +Notable features: + +- **extremely sparse**: each output coefficient depends on at most 2 input coefficients, +- **\(O(d)\)** time with only adds + sign flips, matching SuperNeo’s efficiency remark for power-of-two / trinomial cyclotomics. + +--- + +### 3.4 Why this is the “field-embedding innovation” (in one sentence) + +SuperNeo’s innovation on the embedding side is: + +- **embed a single field witness vector into ring coefficients in a norm-preserving way**, and +- **systematically lift field matrix products/evaluations to ring expressions whose constant terms recover the field values**, so that +- **ring-linear folding preserves the field evaluation claims** (evaluation homomorphism), + +thereby enabling a HyperNova-like folding architecture where *sum-check and norm checks run over \(K\)* rather than over the ring, while commitments still live over \(R_F\). + +--- + +## 4) What “field embedding” means in Hachi (the PCS perspective) + +Hachi is not a folding scheme; it is a **multilinear PCS**. But the verification bottleneck is similar: classic lattice PCS machinery lives in cyclotomic rings \(R_q = \mathbb{Z}_q[X]/(X^d+1)\), whereas sum-check is naturally over a field. + +Hachi’s abstract states its two embedding/reduction ideas (Hachi, Abstract; `paper/hachi.pdf`, p. 1: “-- 1 of 33 --”): + +1. **Ring-switching + sum-check**: integrate Greyhound with ring-switching so the verifier avoids ring multiplication. +2. **Generic reduction (extension field → ring)**: convert evaluation proofs over \( \mathbb{F}_{q^k}\) into statements over cyclotomic rings \(R_q\). + +### 4.1 Embedding extension fields inside cyclotomic rings (fixed rings under automorphisms) + +In Hachi’s technical overview (Hachi, §1.3; `paper/hachi.pdf`, p. 4–5: “-- 4 of 33 --”, “-- 5 of 33 --”), it identifies finite fields \( \mathbb{F}_{q^k}\) *inside* \(R_q\) using fixed subrings under a subgroup of the Galois group: + +- Let \(R = \mathbb{Z}[X]/(X^d+1)\) and \(R_q = R/(q)\) with \(d=2^\alpha\). +- For automorphisms \(\sigma_i: X\mapsto X^i\), define the fixed ring + \[ + R_q^H := \{x\in R_q : \forall \sigma\in H,\;\sigma(x)=x\}. + \] + +Then (Lemma 1, informal; `paper/hachi.pdf`, p. 5: “-- 5 of 33 --”): + +> for suitable \(q\) (notably \(q\equiv 5\pmod 8\)) and \(k\mid d/2\), there exists \(H\) such that \(R_q^H\) is a **subfield** of \(R_q\) isomorphic to \(\mathbb{F}_{q^k}\). + +This is a literal *field embedding into the ring*: \(\mathbb{F}_{q^k}\hookrightarrow R_q\) realized as a fixed subring. + +### 4.2 Inner products via trace + automorphisms (a close cousin of SuperNeo’s transform) + +Hachi then uses a trace map \(\mathrm{Tr}_H:R_q\to R_q^H\) and an automorphism (notably \(\sigma_{-1}\)) to turn ring products into **field inner products** (Theorem 1, informal; `paper/hachi.pdf`, p. 5: “-- 5 of 33 --”): + +> there exists a bijection \(\psi:(R_q^H)^{d/k}\to R_q\) such that +> \[ +> \mathrm{Tr}_H\big(\psi(a)\cdot \sigma_{-1}(\psi(b))\big) = (d/k)\cdot \langle a,b\rangle. +> \] + +This is structurally very similar to SuperNeo’s “inner product transform → constant term” idea, except Hachi uses: + +- a **trace to a subfield** \(R_q^H \cong \mathbb{F}_{q^k}\), +- whereas SuperNeo phrases it as a **linear transform** on coefficient vectors whose ring-product constant term recovers the dot product. + +### 4.3 Ring-switching: from ring equations to extension-field equations (so the verifier stays field-native) + +Hachi’s “ring switching and sum-check over extension fields” overview (Hachi, §1.3; `paper/hachi.pdf`, p. 6–7: “-- 6 of 33 --”, “-- 7 of 33 --”) sketches: + +- lift a relation over \(R_q\) to an identity over \(\mathbb{Z}_q[X]\) with an explicit multiple of \((X^d+1)\), +- sample \(\alpha \leftarrow \mathbb{F}_{q^k}\) and substitute \(X=\alpha\), +- reducing the ring relation to a **field inner product / sum-check-type claim** over \(\mathbb{F}_{q^k}\), +- then run sum-check over the field and recurse. + +This is the PCS analogue of “avoid ring operations during sum-check”: Hachi’s verifier can avoid cyclotomic ring multiplications even though the underlying assumption/commitments are lattice/ring-based. + +--- + +## 5) Same vs different (embedding viewpoint) + +### 5.1 The same (high-level mathematical pattern) + +- **Cyclotomic rings + automorphisms are the bridge.** + - SuperNeo: an automorphism/linear transform makes \(\mathrm{ct}(\bar{a}\cdot b)\) become \(\langle a,b\rangle\) (SuperNeo Thm. 3; `docs/superneo.pdf`, p. 23: “-- 23 of 60 --”). + - Hachi: trace + automorphisms make \(\mathrm{Tr}_H(\psi(a)\cdot \sigma_{-1}(\psi(b)))\) become \(\langle a,b\rangle\) (Hachi Thm. 1; `paper/hachi.pdf`, p. 5: “-- 5 of 33 --”). + +- **Field-native sum-check is the goal.** + - SuperNeo explicitly targets “field-native arithmetic” where “sum-check and norm checks run purely over a small field” (SuperNeo Abstract; `docs/superneo.pdf`, p. 1: “-- 1 of 60 --”). + - Hachi’s verifier similarly reduces to sum-check over \(\mathbb{F}_{q^k}\) after ring switching (Hachi §1.3; `paper/hachi.pdf`, p. 6–7: “-- 6 of 33 --”, “-- 7 of 33 --”). + +- **Linearity matters.** + Both constructions rely on the fact that the commitment operation is linear in the ring/module, and they build embeddings/reductions so the *claimed evaluations* transform linearly under the same combinations (SuperNeo Thm. 5; `docs/superneo.pdf`, p. 24: “-- 24 of 60 --”). + +### 5.2 The different (what each paper is optimizing for) + +- **Direction of “embedding”:** + - **SuperNeo**: embeds *field witnesses* into *ring vectors* to make Ajtai commitments “pay-per-bit” and folding-friendly. + - **Hachi**: embeds *extension-field evaluation statements* into *ring statements* (and back), to make PCS verification fast via sum-check. + +- **Primary object being preserved:** + - **SuperNeo**: preserves **norm** (for binding/pay-per-bit) *and* preserves **evaluation homomorphism** (for folding). + - **Hachi**: preserves the **truth of evaluation claims** (over \(\mathbb{F}_{q^k}\)) when translated into ring relations; norm constraints are handled via sum-check after ring switching. + +- **Ring family emphasis:** + - **Hachi**’s core ring is explicitly power-of-two cyclotomic \(X^d+1\) (Hachi Abstract; `paper/hachi.pdf`, p. 1: “-- 1 of 33 --”). + - **SuperNeo** broadens to more general cyclotomics to support fields like Goldilocks without “full splitting” issues, and explicitly mentions supporting trinomials in parameter sets (SuperNeo intro + parameters discussion; `docs/superneo.pdf`, p. 1–4 and later “concrete parameters” sections). + +- **Protocol role:** + - **SuperNeo** needs an embedding compatible with **folding** (random linear combination of instances/commitments). + - **Hachi** needs a reduction compatible with **PCS recursion and verifier-time reduction** (ring switching + sum-check). + +### 5.3 Are the embeddings “the same”, mathematically? + +It depends what you mean by “same”. There are two distinct layers: + +- **Layer 1: the raw embedding map(s) \(F^{dn}\to R_F^n\) vs \(F_{q^k}\to R_q\)** + These are **not the same functions**. + - SuperNeo’s core embedding is literally the **coefficient embedding** (Def. 7 in their §5), i.e. *place field coordinates into ring coefficients* (`docs/superneo.pdf`, “Definition 7”; see the quote range in this repo at lines 1222–1238 of the extracted text). + - Hachi’s key “field embedding” is to realize \(F_{q^k}\) as a **fixed subfield** \(R_q^H \subseteq R_q\), then use a **basis-dependent bijection** \(\psi:(R_q^H)^{d/k}\to R_q\) (`paper/hachi.pdf`, lines 290–300 in the extracted text). + +- **Layer 2: the algebraic *mechanism* (automorphisms/trace giving inner products, and linearity of evaluations)** + At this layer, they are **the same underlying idea**: both are exploiting a canonical bilinear pairing on cyclotomic rings derived from Galois automorphisms (e.g. \(\sigma_{-1}\)) plus a linear functional (constant term or trace). + - SuperNeo packages it as “there exists a linear transform \(T\) so that \(\mathrm{ct}(T(a)\cdot b)=\langle a,b\rangle\)” (Thm. 3). + - Hachi packages it as “there exists a bijection \(\psi\) so that \(\mathrm{Tr}_H(\psi(a)\cdot\sigma_{-1}(\psi(b)))=(d/k)\langle a,b\rangle\)” (Thm. 1). + +Concretely: SuperNeo’s \(T\) is (mathematically) the **inverse Gram operator** of the ring’s trace/constant-term pairing *expressed in the coefficient basis*. Hachi’s \(\psi\) is a **choice of basis** that identifies \(R_q\) as a free module over a subfield \(R_q^H\cong F_{q^k}\), so that the same pairing looks like a scaled dot product over that subfield. + +So the correct crisp statement is: + +- **Not identical as maps** (different domains/codomains and different basis choices), but +- **equivalent in the sense that both instantiate the same cyclotomic “automorphism + linear functional = inner product” backbone**, just presented in different coordinates. + +--- + +## 6) Practical mental model (how to read both papers through one lens) + +Both papers can be read as building “interfaces” between three layers: + +1. **Ring/module commitment layer** (Ajtai / Module-SIS commitments) +2. **Field arithmetic layer** (sum-check, equality tests, low-degree checks) +3. **Embedding/reduction layer** (the math glue) + +SuperNeo’s embedding layer is primarily: + +- coefficient embedding \(F^{d n_R} \leftrightarrow R_F^{n_R}\), +- inner-product transform + constant-term extraction, +- evaluation homomorphism for ring-linear combinations. + +Hachi’s embedding layer is primarily: + +- fixed-field embedding \( \mathbb{F}_{q^k} \cong R_q^H \subseteq R_q\), +- trace + automorphisms for inner products, +- ring-switching by evaluation \(X=\alpha\) to move verifier work into \( \mathbb{F}_{q^k}\). + +If you want a single phrase: + +> **SuperNeo**: “make folding happen over small fields even though commitments live over rings.” +> **Hachi**: “make PCS verification happen over extension fields even though commitments live over rings.” + + +--- + +## 7) SuperNeo's bilinear form, explicitly identified + +SuperNeo's Theorem 3 states existence of a linear transform \(T\) but never identifies it as a Galois automorphism. Here we close that gap. + +### 7.1 Claim: \(T = \sigma_{-1}\) for \(\Phi(X) = X^d + 1\) + +In \(R_q = F_q[X]/(X^d+1)\), we have \(X^d = -1\), so \(X^{-1} = -X^{d-1}\). The automorphism \(\sigma_{-1}: X \mapsto X^{-1}\) acts on monomials as: + +\[ +\sigma_{-1}(X^0) = 1, \qquad \sigma_{-1}(X^i) = X^{-i} = -X^{d-i} \;\text{ for } 1 \le i \le d-1. +\] + +So on a polynomial \(a(X) = \sum_i a_i X^i\): + +\[ +\sigma_{-1}(a) = a_0 - \sum_{i=1}^{d-1} a_i\, X^{d-i}. +\] + +This is *exactly* the transform \(\bar{a}(X)\) from §3.3.1(A). + +### 7.2 Why the Gram matrix is the identity + +Consider the bilinear form \(B(a,b) = \mathrm{ct}(\sigma_{-1}(a)\cdot b)\) in the monomial basis \(\{1, X, \dots, X^{d-1}\}\). The Gram matrix is: + +\[ +G_{ij} = \mathrm{ct}\big(X^{j-i} \bmod (X^d+1)\big). +\] + +For \(j = i\): \(\mathrm{ct}(1) = 1\). + +For \(j \ne i\): the exponent \(j-i\) satisfies \(-(d-1) \le j-i \le d-1\) and \(j-i \ne 0\). +- If \(j > i\): \(\mathrm{ct}(X^{j-i}) = 0\) since \(1 \le j-i \le d-1\). +- If \(j < i\): \(X^{j-i} = X^{j-i+d}\cdot X^{-d} = -X^{d+j-i}\) where \(1 \le d+j-i \le d-1\), so \(\mathrm{ct}(-X^{d+j-i}) = 0\). + +Therefore \(G = I\) (the identity matrix). This means: + +\[ +\mathrm{ct}(\sigma_{-1}(a)\cdot b) = \mathrm{cf}(a)^\top\, I\, \mathrm{cf}(b) = \langle \mathrm{cf}(a),\,\mathrm{cf}(b)\rangle. +\] + +No Gram correction is needed. The pairing \((\sigma_{-1}, \mathrm{ct})\) gives the standard dot product directly. + +### 7.3 For the trinomial \(\Phi_{81}(X) = X^{54}+X^{27}+1\), the Gram matrix is NOT the identity + +Here \(X^{81} = 1\) but \(X^{27} \ne 1\) in \(R_F\). So \(X^{-27} = X^{54} = -X^{27}-1\), which gives: + +\[ +\mathrm{ct}(X^{-27}) = \mathrm{ct}(-X^{27}-1) = -1. +\] + +The Gram matrix \(G_{ij} = \mathrm{ct}(X^{j-i} \bmod \Phi_{81})\) therefore has off-diagonal entries \(G_{i,\,i+27} = -1\) (and their transposes). The inner-product transform for the trinomial (§3.3.1(B)) is \(T = G^{-1}\circ \sigma_{-1}\), which is why its formula involves sums like \(-(a_{27-i}+a_{54-i})\) rather than a simple sign flip. + +### 7.4 Bottom line + +SuperNeo's bilinear form is: + +\[ +B_{\mathrm{SuperNeo}}(a,b) = \mathrm{ct}\big(\sigma_{-1}(a)\cdot b\big) \in F_q. +\] + +Hachi's bilinear form is: + +\[ +B_{\mathrm{Hachi}}(a,b) = \mathrm{Tr}_H\big(a\cdot \sigma_{-1}(b)\big) \in R_q^H \cong \mathbb{F}_{q^k}. +\] + +Same involution \(\sigma_{-1}\). Different linear functional (\(\mathrm{ct}\) vs \(\mathrm{Tr}_H\)). Different target (\(F_q\) vs \(\mathbb{F}_{q^k}\)). + +--- + +## 8) Classification of non-degenerate bilinear forms on cyclotomic rings + +### 8.1 The general template + +For a cyclotomic ring \(R_q = F_q[X]/(\Phi(X))\) of degree \(d\), with Galois group \(\mathrm{Gal} = (\mathbb{Z}/\eta\mathbb{Z})^\times\) acting by \(\sigma_i: X\mapsto X^i\), the general bilinear form from an automorphism + linear functional is: + +\[ +B_{\sigma,\lambda}(a,b) = \lambda\big(a\cdot \sigma(b)\big) +\] + +where \(\sigma \in \mathrm{Gal}\) and \(\lambda: R_q \to T\) is a linear functional into some target \(T\). + +### 8.2 When is this non-degenerate? (Frobenius algebra theory) + +\(R_q = F_q[X]/(\Phi(X))\) is a **Frobenius algebra** over \(F_q\) (since cyclotomic polynomials are squarefree). This gives a clean classification: + +**Theorem (Frobenius classification)**: A linear functional \(\lambda: R_q \to F_q\) makes the bilinear form \(B(a,b) = \lambda(a\cdot b)\) non-degenerate if and only if \(\lambda\) is a **generating functional**, meaning the induced map \(R_q \to R_q^\vee\) (sending \(a\) to the functional \(b\mapsto \lambda(ab)\)) is an isomorphism. + +Moreover: **the set of all generating functionals is \(\{\lambda(g\,\cdot\,{-}) : g \in R_q^\times\}\)**. Once you have one non-degenerate form, all others come from pre-multiplying by a unit. + +For the form \(B_{\sigma,\lambda}(a,b)=\lambda(a\cdot\sigma(b))\), non-degeneracy reduces to: \(\lambda\circ(\text{mult by }\sigma(\cdot))\) is generating. Since \(\sigma\) is an automorphism (hence maps units to units), \(B_{\sigma,\lambda}\) is non-degenerate iff \(\lambda\) is generating. + +### 8.3 Complete menu of linear functionals + +The dual space \(R_q^\vee = \mathrm{Hom}_{F_q}(R_q, F_q)\) is \(d\)-dimensional. Here are the "natural" families: + +#### (A) Coefficient extraction: \(\mathrm{ct}_j(a) = [X^j](a)\) + +Target: \(F_q\). There are \(d\) such functionals (one per coefficient index \(j\)). + +\(\mathrm{ct} = \mathrm{ct}_0\) is the constant term. For \(\Phi = X^d+1\) with \(\sigma_{-1}\), the Gram matrix of \(\mathrm{ct}_0\) is the identity (§7.2). Other \(\mathrm{ct}_j\) give non-degenerate forms too, but with permuted/signed Gram matrices. + +*Generating?* Yes for many cyclotomics including \(X^d+1\). Can fail for pathological factorizations. + +#### (B) Absolute trace: \(\mathrm{Tr}_{R_q/F_q}(a) = \sum_{\sigma\in\mathrm{Gal}} \sigma(a)\) + +Target: \(F_q\). This is the canonical choice from algebraic number theory. + +*Generating?* Always yes for separable algebras (which cyclotomic rings are). + +Relation to \(\mathrm{ct}\): since both are generating, \(\mathrm{Tr}(a) = \mathrm{ct}(g\cdot a)\) for some explicit unit \(g\in R_q^\times\) (related to the "different ideal" of the cyclotomic extension). + +#### (C) Relative/partial trace: \(\mathrm{Tr}_H(a) = \sum_{\sigma\in H}\sigma(a) \in R_q^H\) + +Target: \(R_q^H \cong \mathbb{F}_{q^k}\) (the fixed subfield under a subgroup \(H \subseteq \mathrm{Gal}\)). + +This is what Hachi uses. The key upgrade: the target is an extension field \(\mathbb{F}_{q^k}\), not just \(F_q\). You get an \(\mathbb{F}_{q^k}\)-valued bilinear form, which means inner products live natively in the sumcheck domain. + +*Generating?* As a map \(R_q \to R_q^H\), yes: the induced \(R_q^H\)-bilinear form on \(R_q\) (viewed as a free \(R_q^H\)-module of rank \(d/k\)) is non-degenerate. + +Trace tower: \(\mathrm{Tr}_{R_q/F_q} = \mathrm{Tr}_{R_q^H/F_q}\circ \mathrm{Tr}_H\). So the relative trace "refines" the absolute trace. + +#### (D) Evaluation at a root: \(\mathrm{ev}_\zeta(a) = a(\zeta)\) for \(\zeta\) a root of \(\Phi\) in some \(\mathbb{F}_{q^t}\) + +Target: \(\mathbb{F}_{q^t}\). + +*Generating?* If \(\Phi\) is irreducible over \(F_q\) (so \(R_q \cong \mathbb{F}_{q^d}\)), every nonzero functional is generating. If \(\Phi\) factors, \(\mathrm{ev}_\zeta\) projects onto one CRT slot and annihilates the others, so it is **not** generating. + +**Fatal for norms**: the evaluation map is equivalent to NTT, which is not norm-preserving. This is exactly the problem SuperNeo identifies with prior approaches (§2.2). + +#### (E) CRT-slot projection: \(\pi_s: R_q \twoheadrightarrow \mathbb{F}_{q^t}\) + +When \(\Phi = \prod_i \phi_i\) factors over \(F_q\) (each \(\phi_i\) irreducible of degree \(t\)), \(R_q \cong \prod_i \mathbb{F}_{q^t}\). The projection \(\pi_s\) onto one factor is: + +Target: \(\mathbb{F}_{q^t}\). + +*Generating?* **No** --- it annihilates the other factors. + +**Same norm problem as (D)**: the CRT isomorphism is an NTT-like map. + +#### (F) Arbitrary \(F_q\)-linear combinations of the above + +Any \(\lambda = \sum_j c_j \,\mathrm{ct}_j\) for scalars \(c_j \in F_q\) gives a functional. It is generating iff the associated element \(g = \sum_j c_j X^j \in R_q\) is a unit. + +### 8.4 Practical takeaway + +If you want **norm preservation**, you must stay in "coefficient-land" (functionals A, B, C, F). Evaluation-based functionals (D, E) destroy norms. + +If you want **extension-field-valued inner products** (for native sumcheck over \(\mathbb{F}_{q^k}\)), the only natural coefficient-land option is **(C) the relative trace** \(\mathrm{Tr}_H\). + +--- + +## 9) The evaluation homomorphism (SuperNeo Theorem 5), spelled out + +This is the key property that makes SuperNeo's folding work. The paper states it abstractly; here is the fully explicit version. + +### 9.1 Setup + +Fix: +- Base field \(F = \mathbb{F}_q\), extension \(K = \mathbb{F}_{q^2}\) (or any extension with \(1/|K| = \mathrm{negl}(\lambda)\)). +- Cyclotomic ring \(R_F = F[X]/(\Phi(X))\), \(R_K = K[X]/(\Phi(X))\), degree \(d\). +- A CCS matrix \(M \in F^{m\times n_F}\) with \(n_F = d\cdot n_R\). +- Witness vectors \(z_1, \dots, z_\ell \in F^{n_F}\), coefficient-embedded as \(\mathbf{z}_i \in R_F^{n_R}\). +- Folding challenges \(\rho_1, \dots, \rho_\ell \in C \subset R_F\) (short ring elements). +- An evaluation point \(r \in K^{\log m}\) (from a prior sumcheck round). +- Any \(R_F\)-module homomorphisms \(L: R_F^{n_R}\to C\) (commitment) and \(\mathrm{Lin}: R_F^{n_R}\to R_F^{n_{R,\mathrm{in}}}\) (input extraction). + +### 9.2 Per-instance "lifted" claims + +For each instance \(i\in[\ell]\), define three objects: + +- **Commitment**: \(c_i := L(\mathbf{z}_i) \in C\) (Ajtai commitment of the ring vector). +- **Input**: \(x_i := \mathrm{Lin}(\mathbf{z}_i) \in R_F^{n_{R,\mathrm{in}}}\) (public input part). +- **Lifted evaluation**: \(y_i := \overline{M}\,\widetilde{\mathbf{z}_i}(r) \in R_K\) (apply the transformed matrix \(\overline{M}\) to the multilinear extension of the ring vector \(\mathbf{z}_i\), evaluated at \(r\)). + +The lifted evaluation \(y_i\) is a **ring element in \(R_K\)** (not a field element in \(K\)). Its constant term recovers the field-level evaluation (by Theorem 4 / Remark 2): + +\[ +\mathrm{ct}(y_i) = M\,\widetilde{z_i}(r) \in K. +\] + +In other words: \(y_i\) has \(d\) coefficients in \(K\); the constant coefficient is the "real" field evaluation; the other \(d-1\) coefficients carry extra information that is needed for folding consistency. + +### 9.3 The evaluation homomorphism (Theorem 5) + +**Statement**: Define the folded objects: + +\[ +\mathbf{z} := \sum_{i\in[\ell]} \rho_i\,\mathbf{z}_i \in R_F^{n_R}, \quad +c := \sum_{i\in[\ell]} \rho_i\, c_i \in C, \quad +x := \sum_{i\in[\ell]} \rho_i\, x_i \in R_F^{n_{R,\mathrm{in}}}, +\quad +y := \sum_{i\in[\ell]} \rho_i\, y_i \in R_K. +\] + +Then **all three of the following hold simultaneously**: + +1. **Commitment homomorphism**: \(c = L(\mathbf{z})\). +2. **Lifted-evaluation homomorphism**: \(y = \overline{M}\,\widetilde{\mathbf{z}}(r)\). +3. **Field-evaluation consistency**: \(\mathrm{ct}(y) = M\,\widetilde{z}(r) \in K\) where \(z = \sum_i \rho_i z_i \in F^{n_F}\). + +(1) holds because \(L\) is \(R_F\)-linear. (2) holds because multilinear extension is \(K\)-linear and matrix multiplication is \(R_K\)-linear. (3) follows from applying \(\mathrm{ct}\) to (2) and invoking Theorem 4. + +### 9.4 Why ring-level claims are essential (the subtle point) + +The folding challenge \(\rho_i \in R_F\) is a **ring** scalar, not a field scalar. At the field level, multiplying by \(\rho_i\) acts on the coefficient vector \(z_i \in F^{n_F}\) by **convolution** (polynomial multiplication scrambles coefficients), not by coordinate-wise scaling. + +This means: +- **At the field level**: after folding, the individual CCS constraints \(Mz = 0\) do NOT hold in any simple sense. The "meaning" of the folded witness as a CCS satisfier is lost. +- **At the ring level**: the commitment, lifted evaluation claims, and input extraction all fold cleanly. The accumulator stores \((c, x, r, y)\) as ring-level objects. + +This is why SuperNeo's accumulator relation is **CE (CCS Evaluation)**, not CCS itself: it stores ring-level evaluation claims \(y_j \in R_K\), and the decider checks them by opening the commitment and verifying ring-level equalities. The \(\mathrm{ct}\) extraction (field-level check) happens only once, at decision time. + +### 9.5 Contrast with Hachi + +Hachi does not need an "evaluation homomorphism" in SuperNeo's sense because Hachi does not fold multiple instances. Instead: + +- Hachi's protocol reduces a *single* opening claim by splitting the coefficient table into blocks, folding with a sparse ring challenge \(c\), and then using ring-switching + sumcheck. +- The analogue of "linearity under ring multiplication" in Hachi is the fold relations (Eq. (18)--(19)): \(a^\top G_{2^m} z = (c^\top\otimes G_1)\hat{w}\) and \(Az = (c^\top\otimes G_{n_A})\hat{t}\), which are linear in the sparse challenge \(c\). + +Both papers exploit the same algebraic fact (ring multiplication is \(R_F\)-linear), but for different protocol purposes: folding (SuperNeo) vs opening reduction (Hachi). + +--- + +## 10) Unified embedding framework (best of both worlds) + +### 10.1 The core observation + +The two bilinear forms use the **same involution** \(\sigma_{-1}\) and differ only in the "readout functional": + +| | SuperNeo | Hachi | +|---|---|---| +| Form | \(\mathrm{ct}(\sigma_{-1}(a)\cdot b)\) | \(\mathrm{Tr}_H(\sigma_{-1}(a)\cdot b)\) | +| Target | \(F_q\) | \(\mathbb{F}_{q^k}\) | +| Gram (for \(X^d+1\)) | \(I\) (identity) | Non-trivial (absorbed into \(\psi\)) | + +These are **compatible**: they use the same product \(\sigma_{-1}(a)\cdot b\) in \(R_q\) and differ only in what you project onto. The trace tower connects them: + +\[ +\mathrm{Tr}_{R_q/F_q} = \mathrm{Tr}_{\mathbb{F}_{q^k}/F_q} \circ \mathrm{Tr}_H. +\] + +### 10.2 The unified pairing abstraction + +**Definition (Cyclotomic Pairing).** Given: +- cyclotomic ring \(R_q\) of degree \(d\), +- a subfield \(S \cong \mathbb{F}_{q^k}\) realized as \(R_q^H \subseteq R_q\) (or \(S = F_q\) as the trivial case \(k=1\)), +- the involution \(\sigma_{-1} \in \mathrm{Aut}(R_q)\), +- the relative trace \(\mathrm{Tr}_H: R_q \to S\), + +define the **level-\(k\) pairing**: + +\[ +B_k(a,b) := \mathrm{Tr}_H\big(\sigma_{-1}(a)\cdot b\big) \in S \cong \mathbb{F}_{q^k}. +\] + +Special cases: +- **\(k = 1\)** (\(H = \mathrm{Gal}\), \(S = F_q\)): \(B_1\) is (up to unit scaling) the absolute-trace pairing. Closely related to SuperNeo's \(\mathrm{ct}\)-based form. +- **\(k = d\)** (\(H = \{1\}\), \(S = R_q\)): \(B_d(a,b) = \sigma_{-1}(a)\cdot b\) (the ring product itself, no projection). +- **\(1 < k < d\)** (\(H\) a proper subgroup): Hachi's setting. \(B_k\) is an \(\mathbb{F}_{q^k}\)-valued non-degenerate bilinear form on the free \(\mathbb{F}_{q^k}\)-module \(R_q \cong \mathbb{F}_{q^k}^{d/k}\). + +### 10.3 Combined embedding pipeline + +``` +Field witness z ∈ F_q^{d·n_R} + │ coefficient embedding (SuperNeo Def. 7, norm-preserving) + ▼ +Ring vector z ∈ R_q^{n_R} ← ‖z‖_∞ preserved + │ Ajtai commit: c = Az + ▼ +Commitment c ∈ R_q^κ + │ + ├── MODE A (SuperNeo-style folding): + │ Use B_1: ct(σ_{-1}(a)·b) = ⟨cf(a), cf(b)⟩ ∈ F_q + │ → sumcheck over K = F_{q²}, fold with ρ ∈ R_F, track ring-level claims y ∈ R_K + │ + └── MODE B (Hachi-style PCS opening): + Use B_k: Tr_H(σ_{-1}(a)·b) ∈ F_{q^k} + → ring-switch at α ∈ F_{q^k}, sumcheck over F_{q^k}, recurse on smaller instance +``` + +The committed object \(\mathbf{z}\in R_q^{n_R}\) is the same in both modes. The choice of functional only affects how you read out inner products and run the interactive proof. + +### 10.4 What you gain + +| Property | SuperNeo alone | Hachi alone | Unified | +|---|---|---|---| +| Norm-preserving embedding | Yes | N/A (PCS, not CCS) | Yes | +| Pay-per-bit commitments | Yes | No | Yes | +| Extension-field inner products | No (\(\mathrm{ct}\to F_q\)) | Yes (\(\mathrm{Tr}_H\to \mathbb{F}_{q^k}\)) | Yes | +| Sumcheck natively over \(\mathbb{F}_{q^k}\) | No (needs \(K\) externally) | Yes | Yes | +| Power-of-two cyclotomics | Supported | Required | Supported | +| General cyclotomics (trinomials) | Supported | Not discussed | Supported (Gram \(\ne I\)) | + +### 10.5 The Gram cost of switching from \(\mathrm{ct}\) to \(\mathrm{Tr}_H\) + +For \(\Phi = X^d+1\) with \(\sigma_{-1}\): +- \(\mathrm{ct}\): Gram = \(I\). No correction needed. Beautifully simple. +- \(\mathrm{Tr}_H\): Gram \(\ne I\). Hachi absorbs this into the packing map \(\psi:(R_q^H)^{d/k}\to R_q\) (Theorem 2, Eq. (8)). + +If you use \(\mathrm{Tr}_H\) instead of \(\mathrm{ct}\), you need: +- the packing map \(\psi\) (one \(O(d)\) linear map per block), +- its inverse \(\psi^{-1}\) (for verification). + +Both are explicit, \(O(d)\)-time, and defined once per parameter set. The Galois subgroup \(H\) and the map \(\psi\) are static public data. + +### 10.6 When you would want each mode + +- **Folding** (SuperNeo-style IVC) with pay-per-bit: use \(\mathrm{ct}\) mode. The evaluation homomorphism (§9) works cleanly because the Gram is trivial. +- **PCS opening** (Hachi-style recursive reduction): switch to \(\mathrm{Tr}_H\) mode when you need ring-switching at \(\alpha \in \mathbb{F}_{q^k}\). +- **Hybrid fold-then-open**: fold in \(\mathrm{ct}\) mode (accumulating ring-level claims), then at the end, open the accumulated commitment using \(\mathrm{Tr}_H\) mode for efficient PCS verification. + +The handoff between modes is seamless because the committed object \(\mathbf{z}\in R_q^{n_R}\) is the same either way. diff --git a/docs/FIELD_OPS_PERF.md b/docs/FIELD_OPS_PERF.md new file mode 100644 index 00000000..13a90846 --- /dev/null +++ b/docs/FIELD_OPS_PERF.md @@ -0,0 +1,237 @@ +# Field Operations Performance + +Benchmark results for scalar and packed (SIMD) field arithmetic across +platforms. All numbers are **element throughput** (median) reported by +`criterion` over 4096-element arrays (`cargo bench --bench field_arith`). + +## Prime selection + +All primes (except M31) are pseudo-Mersenne: `q = 2^k − c` where `c` is +the **smallest positive offset** such that `q` is prime and `q ≡ 5 +(mod 8)`. The congruence `q ≡ 5 (mod 8)` is required so that the +cyclotomic ring `Z_q[X]/(X^d + 1)` splits fully via NTT when `d` is a +power of two (equivalently, `−1` is a quadratic residue but not a quartic +residue mod `q`). + +The constraint `q ≡ 5 (mod 8)` forces `c ≡ 2^k − 5 (mod 8)`: + +| k mod 8 | required c mod 8 | examples | +|---------|------------------|----------| +| 0 | 3 | k=24,32,40,48,56,64,128 | +| 6 | 1 | k=30 | +| 7 | 2 | k=31 | + +For each `k`, smaller candidates were checked and found composite. For +instance at `k = 31` (`c ≡ 3 mod 8`): `2^31 − 3 = 5 × 429496729`, +`2^31 − 11 = 3 × 715827879`, so the first prime is `2^31 − 19`. + +**M31** (`2^31 − 1`, Mersenne prime, `q ≡ 7 mod 8`) is included for +comparison with plonky3 even though it does not satisfy `q ≡ 5 (mod 8)`. + +## Primes benchmarked + +| Label | Modulus | Offset | Rust type | SIMD width | +|-------|---------|--------|-----------|------------| +| fp32_24b | `2^24 − 3` | 3 | `Fp32` | AVX-512: 16, AVX2: 8, NEON: 4 | +| fp32_30b | `2^30 − 35` | 35 | `Fp32` | AVX-512: 16, AVX2: 8, NEON: 4 | +| fp32_31b | `2^31 − 19` | 19 | `Fp32` | AVX-512: 16, AVX2: 8, NEON: 4 | +| fp32_m31 | `2^31 − 1` | 1 | `Fp32` | AVX-512: 16, AVX2: 8, NEON: 4 | +| fp32_32b | `2^32 − 99` | 99 | `Fp32` | AVX-512: 16, AVX2: 8, NEON: 4 | +| fp64_40b | `2^40 − 195` | 195 | `Fp64` | AVX-512: 8, AVX2: 4, NEON: 2 | +| fp64_48b | `2^48 − 59` | 59 | `Fp64` | AVX-512: 8, AVX2: 4, NEON: 2 | +| fp64_56b | `2^56 − 27` | 27 | `Fp64` | AVX-512: 8, AVX2: 4, NEON: 2 | +| fp64_64b | `2^64 − 59` | 59 | `Fp64` | AVX-512: 8, AVX2: 4, NEON: 2 | +| fp128 | `2^128 − 275` | 275 | `Fp128` | AVX-512: 8 (SoA), AVX2: 4 (SoA), NEON: 2 (SoA) | + +--- + +## AMD Zen 5 (Ryzen 9950X / leopard) + +Backend: **AVX-512** (16-wide Fp32, 8-wide Fp64, 8-wide Fp128 SoA with +vectorized add/sub and scalar-per-lane mul). +`RUSTFLAGS='-C target-cpu=native'`, nightly toolchain. + +### Scalar (`throughput/`) + +| Field | mul | add | +|-------|-----|-----| +| fp32_24b | 1.224 Gelem/s | 2.050 Gelem/s | +| fp32_30b | 1.220 Gelem/s | 2.026 Gelem/s | +| fp32_31b | 1.212 Gelem/s | 1.866 Gelem/s | +| fp32_m31 | 1.355 Gelem/s | 1.993 Gelem/s | +| fp32_32b | 1.219 Gelem/s | 1.955 Gelem/s | +| fp64_40b | 1.018 Gelem/s | 2.074 Gelem/s | +| fp64_48b | 1.021 Gelem/s | 2.073 Gelem/s | +| fp64_56b | 0.945 Gelem/s | 2.060 Gelem/s | +| fp64_64b | 0.927 Gelem/s | 1.840 Gelem/s | +| fp128 | 0.452 Gelem/s | 1.127 Gelem/s | + +### Packed (`packed_throughput/`) + +| Field | mul | add | sub | +|-------|-----|-----|-----| +| fp32_24b | 5.362 Gelem/s | 12.76 Gelem/s | 12.74 Gelem/s | +| fp32_30b | 6.145 Gelem/s | 13.53 Gelem/s | 13.55 Gelem/s | +| fp32_31b | 6.187 Gelem/s | 13.53 Gelem/s | 13.54 Gelem/s | +| fp32_m31 | 6.943 Gelem/s | 13.56 Gelem/s | 13.50 Gelem/s | +| fp32_32b | 6.785 Gelem/s | 13.02 Gelem/s | 12.66 Gelem/s | +| fp64_40b | 1.961 Gelem/s | 5.847 Gelem/s | 5.861 Gelem/s | +| fp64_48b | 1.942 Gelem/s | 5.852 Gelem/s | 5.853 Gelem/s | +| fp64_56b | 1.937 Gelem/s | 5.847 Gelem/s | 5.796 Gelem/s | +| fp64_64b | 1.742 Gelem/s | 5.278 Gelem/s | 5.760 Gelem/s | +| fp128 | 0.284 Gelem/s | 2.314 Gelem/s | 3.175 Gelem/s | + +### Packed speedup over scalar + +| Field | mul | add | +|-------|-----|-----| +| fp32_24b | **4.4x** | **6.2x** | +| fp32_30b | **5.0x** | **6.7x** | +| fp32_31b | **5.1x** | **7.3x** | +| fp32_m31 | **5.1x** | **6.8x** | +| fp32_32b | **5.6x** | **6.7x** | +| fp64_40b | **1.9x** | **2.8x** | +| fp64_48b | **1.9x** | **2.8x** | +| fp64_56b | **2.0x** | **2.8x** | +| fp64_64b | **1.9x** | **2.9x** | +| fp128 | **0.6x** | **2.1x** | + +### Sumcheck MACC (`packed_sumcheck_mix/`) + +`acc += eq[i] * poly[i]` loop (dominant inner loop in sumcheck provers). + +| Field | MACC | % of pure mul | +|-------|------|---------------| +| fp32_24b | 4.764 Gelem/s | 89% | +| fp32_30b | 5.352 Gelem/s | 87% | +| fp32_31b | 5.351 Gelem/s | 86% | +| fp32_m31 | 6.097 Gelem/s | 88% | +| fp32_32b | 3.409 Gelem/s | 50% | +| fp64_40b | 1.488 Gelem/s | 76% | +| fp64_48b | 1.492 Gelem/s | 77% | +| fp64_56b | 1.491 Gelem/s | 77% | +| fp64_64b | 1.141 Gelem/s | 65% | +| fp128 | 0.323 Gelem/s | 114% | + +--- + +## Apple M4 Pro (macOS / aarch64) + +Backend: **NEON** (4-wide Fp32, 2-wide Fp64, 2-wide Fp128 SoA). +`RUSTFLAGS='-C target-cpu=native'`, nightly toolchain. + +### Scalar (`throughput/`) + +| Field | mul | add | +|-------|-----|-----| +| fp32_24b | 1.129 Gelem/s | 1.426 Gelem/s | +| fp32_30b | 1.133 Gelem/s | 1.425 Gelem/s | +| fp32_31b | 1.043 Gelem/s | 1.433 Gelem/s | +| fp32_m31 | 1.319 Gelem/s | 1.435 Gelem/s | +| fp32_32b | 1.135 Gelem/s | 1.423 Gelem/s | +| fp64_40b | 0.871 Gelem/s | 1.446 Gelem/s | +| fp64_48b | 0.886 Gelem/s | 1.385 Gelem/s | +| fp64_56b | 0.891 Gelem/s | 1.442 Gelem/s | +| fp64_64b | 0.923 Gelem/s | 1.443 Gelem/s | +| fp128 | 0.444 Gelem/s | 0.938 Gelem/s | + +### Packed (`packed_throughput/`) + +| Field | mul | add | sub | +|-------|-----|-----|-----| +| fp32_24b | 3.717 Gelem/s | 5.272 Gelem/s | 5.278 Gelem/s | +| fp32_30b | 3.719 Gelem/s | 5.281 Gelem/s | 5.275 Gelem/s | +| fp32_31b | 3.719 Gelem/s | 5.283 Gelem/s | 5.268 Gelem/s | +| fp32_m31 | 3.720 Gelem/s | 5.263 Gelem/s | 5.263 Gelem/s | +| fp32_32b | 2.524 Gelem/s | 5.296 Gelem/s | 5.253 Gelem/s | +| fp64_40b | 1.253 Gelem/s | 2.648 Gelem/s | 2.645 Gelem/s | +| fp64_48b | 1.254 Gelem/s | 2.650 Gelem/s | 2.643 Gelem/s | +| fp64_56b | 1.255 Gelem/s | 2.632 Gelem/s | 2.650 Gelem/s | +| fp64_64b | 1.399 Gelem/s | 2.639 Gelem/s | 2.602 Gelem/s | +| fp128 | 0.480 Gelem/s | 1.724 Gelem/s | 2.107 Gelem/s | + +### Packed speedup over scalar + +| Field | mul | add | +|-------|-----|-----| +| fp32_24b | **3.3x** | **3.7x** | +| fp32_30b | **3.3x** | **3.7x** | +| fp32_31b | **3.6x** | **3.7x** | +| fp32_m31 | **2.8x** | **3.7x** | +| fp32_32b | **2.2x** | **3.7x** | +| fp64_40b | **1.4x** | **1.8x** | +| fp64_48b | **1.4x** | **1.9x** | +| fp64_56b | **1.4x** | **1.8x** | +| fp64_64b | **1.5x** | **1.8x** | +| fp128 | **1.1x** | **1.8x** | + +### Sumcheck MACC (`packed_sumcheck_mix/`) + +| Field | MACC | % of pure mul | +|-------|------|---------------| +| fp32_24b | 2.652 Gelem/s | 71% | +| fp32_30b | 2.660 Gelem/s | 72% | +| fp32_31b | 2.662 Gelem/s | 72% | +| fp32_m31 | 2.661 Gelem/s | 72% | +| fp32_32b | 1.991 Gelem/s | 79% | +| fp64_40b | 0.990 Gelem/s | 79% | +| fp64_48b | 0.991 Gelem/s | 79% | +| fp64_56b | 0.993 Gelem/s | 79% | +| fp64_64b | 0.795 Gelem/s | 57% | +| fp128 | 0.450 Gelem/s | 94% | + +--- + +## Notes + +### Zen 5 AVX-512 observations + +- **Fp32 add/sub** saturate at ~13–13.5 Gelem/s, close to 1 cycle per + 16-wide vector at 5 GHz. +- **M31 is the fastest Fp32 prime** for packed mul (6.94 Gelem/s) and + sumcheck MACC (6.10 Gelem/s), because `C = 1` minimizes the reduction + chain. +- **Fp32 mul** is latency-bound by the 2-fold Solinas reduction chain + (~18 cycles per vector). +- **Fp64 sub-word** primes (40b, 48b, 56b) show nearly identical packed + mul throughput (~1.94 Gelem/s), since the vectorized schoolbook + multiply + Solinas reduction dominates regardless of bit-width. +- **Fp64 64b** is slower than sub-word variants due to multi-stage + overflow tracking in the Solinas reduction. +- **Fp32 32b sumcheck MACC** drops to 50% of pure mul, the worst ratio, + because the carry-based add correction creates additional dependencies + in the `acc += eq * poly` loop. +- **Fp128** packed backend now uses SoA layout (8-wide) with vectorized + add/sub via `__m512i`. Add improved **2.1x** (1.13 → 2.31 Gelem/s), + sub improved **2.4x** (1.34 → 3.18 Gelem/s). Mul remains scalar + per-lane and regressed 0.6x (0.45 → 0.28 Gelem/s) due to SoA + pack/unpack overhead. Sumcheck MACC is -7% (0.35 → 0.32 Gelem/s); + MACC exceeds pure-mul throughput (114%) because the accumulation loop + avoids the SoA store overhead that the throughput benchmark incurs. + +### M4 Pro NEON observations + +- **NEON is 4-wide for Fp32, 2-wide for Fp64/Fp128**, so maximum + speedup is 4x and 2x respectively (vs 16x/8x for AVX-512). +- **Fp32 packed mul** is uniform at ~3.72 Gelem/s for all sub-word + primes (24b–31b, including M31), unlike Zen 5 where M31 is notably + faster. The 4-wide NEON `vmull_u32` + reduction is the bottleneck. +- **Fp32 32b packed mul** drops to 2.52 Gelem/s (carry-based path). +- **Fp64 packed mul** ~1.25 Gelem/s for all sub-word primes, ~1.40 + for 64b — the 64b prime is *faster* on NEON, opposite to Zen 5. +- **Fp128 packed add** 1.72 Gelem/s = **1.8x** scalar speedup; sub + 2.11 Gelem/s = **2.2x**, both close to the theoretical 2x from + 2-wide `uint64x2_t`. Mul 0.48 Gelem/s ≈ **1.1x** (scalar per-lane). +- **Sumcheck MACC** is 71–72% of pure mul for sub-word Fp32, 79% for + Fp64 sub-word, and 94% for Fp128 — higher than Zen 5 ratios. + +### Reduction strategy by field width + +| Field | Reduction method | +|-------|-----------------| +| Fp32, BITS ≤ 31 | `min(t, t−P)` — single unsigned compare + blend | +| Fp32, BITS = 32 | carry-based: detect overflow, conditionally add `C`, then subtract `P` if `≥ P` | +| Fp64, BITS ≤ 62 | 2-fold Solinas in u64: `(lo & mask) + c * (lo >> k)`, repeat | +| Fp64, BITS = 64 | vectorized schoolbook 64×64→128 + multi-stage Solinas with overflow tracking | +| Fp128 add/sub | vectorized 128-bit add/sub with carry/borrow propagation via `__m512i` | +| Fp128 mul | scalar per-lane: 9-limb Solinas via u64 decomposition | diff --git a/docs/HACHI_DIGEST.md b/docs/HACHI_DIGEST.md new file mode 100644 index 00000000..c0d0328b --- /dev/null +++ b/docs/HACHI_DIGEST.md @@ -0,0 +1,1362 @@ +# Hachi Digest (for side-by-side comparison with SuperNeo) + +This file captures the parts of `paper/hachi.pdf` that are most actionable for understanding Hachi’s **parameterization** and **protocol shape**, in a compact AI-readable format. + +Primary source: `paper/hachi.pdf` (“Hachi: Efficient Lattice-Based Multilinear Polynomial Commitments over Extension Fields”). + +## Canonical parameter tuples (from paper) + +The paper gives a concrete benchmark parameter set (ℓ = 30) in Figure 9 and uses it to estimate proof size (~55KB) in §5.2. + +```yaml +hachi_parameter_tuples: + - id: l30_benchmark_first_round + witness_num_vars_ell: 30 + base_field: + q: 4294967197 # ~2^32, prime (paper Fig. 9) + ring_coeffs: "Z_q (aka F_q)" + extension_field_for_sumcheck: + k: 4 + field: "F_{q^4}" # paper §5.1, §5.4 + cyclotomic_ring: + phi: "X^d + 1" + alpha: 10 + d: 1024 + ring: "R_q = Z_q[X]/(X^d + 1)" + split_and_fold_params: + m: 10 + r: 10 + commitment_matrices_heights: + nA: 1 + nB: 1 + nD: 1 + decomposition_and_challenges: + decomposition_base_b: 16 + delta: 8 # decomposition length (Fig. 9) + tau: 4 # expansion factor for decomposing z (Fig. 9) + omega_L1: 16 + c_nonzero_coeffs_in_challenge: 16 + norm_bounds: + z_Linf_bound: 30583 + next_witness_Linf_bound: 8 + next_round_witness: + size: 226 + proof_size_estimate: + first_round_sumcheck: "~7.3KB" + adaptation_overhead: "~4.8KB" + greyhound_subproof: "~43KB" + total: "~55.1KB" + timings_reported: + verify_ms_server: 227 # paper §5.2 narrative (first round), plus Fig. 8 context + verify_ms_server_greyhound: 130 # cited as Greyhound estimate in §5.2 narrative +``` + +Notes: + +- The paper’s 55.1KB estimate is explicitly derived in §5.2 (“To conclude, the total evaluation proof can be estimated to be 7.3 + 4.8 + 43 KB = 55.1KB.”). +- This tuple is the “Hachi + compose with Greyhound” estimate for ℓ = 30, not “Hachi alone forever”; Hachi’s design explicitly allows switching to Greyhound/LaBRADOR at small witness sizes. + +## Exact counts (from the paper’s concrete section) + +- Total explicit concrete parameter tables in the paper: **1** (Figure 9; ℓ = 30). +- Concrete benchmark variable counts shown: **3** (ℓ ∈ {26, 28, 30} in Figure 8 timings), but only ℓ = 30 is fully parameterized in Figure 9. +- Unique cyclotomic family used: **1** (`X^d + 1` with `d` power of two). + +## What Hachi suggests (for its purpose) + +High-level message (from the abstract + technical overview): + +- Use **sum-check** to get fast verification, but avoid doing sum-check “over the ring”. +- Use **ring switching** (evaluate at a random α in an extension field) so that the verifier’s checks are field-native and do not require expensive \(R_q\) multiplication. +- Use a **generic reduction** to convert evaluation proofs over extension fields \(F_{q^k}\) into equivalent ring statements over \(R_q\), enabling extension-field evaluation support for lattice PCS. + +Concrete implication (paper §5.3–§5.4): + +- Hachi can pick **larger ring dimensions** (e.g. \(d=1024\)) than Greyhound’s typical \(d=64\) and still keep verification efficient; larger \(d\) helps commitment time (fewer ring mults, NTT-friendly structure) and enables very sparse challenges. + +## Fit into this repo design (easy vs harder) + +- Easy fit (already aligned): + - Hachi’s core ring family is power-of-two cyclotomic \(R_q = Z_q[X]/(X^d+1)\), which matches this repo’s existing algebra/ring direction. +- Harder pieces (protocol-level, not yet fully implemented here): + - Ring switching pipeline (lifting ring equations to \(Z_q[X]\), evaluate at random α in \(F_{q^k}\), sum-check over \(F_{q^k}\)). + - The “compose with Greyhound” handoff (treating the reduced witness as a short ring instance for an existing PCS/proof system). + +## Notation glossary (from `paper/hachi.pdf`) + +Hachi’s paper reuses common lattice-proof-system symbols; here are the ones that tend to confuse on first read. + +### Base objects + +- **\(q\)**: prime modulus; base ring/field is \(Z_q\) (paper treats \(Z_q\) and \(F_q\) interchangeably). +- **\(d = 2^\alpha\)**: cyclotomic ring dimension (power of two). +- **\(\alpha\)**: shorthand for \(\log_2 d\) (used heavily in §3’s “variable count after transformation” formulas). +- **\(R\)**: integer cyclotomic ring \(Z[X]/(X^d+1)\). +- **\(R_q\)**: cyclotomic ring mod \(q\): \(Z_q[X]/(X^d+1)\). +- **\(F_{q^k}\)**: extension field used to run sum-check with negligible soundness error. +- **\(\kappa\)**: shorthand for \(\log_2 k\) when \(k\) is a power of two (this is the \(\kappa\) used in §3.2 and in the “\(\ell-\alpha+\kappa\)” variable-count formulas). +- **\(\sigma_i\)**: Galois automorphism \(X \mapsto X^i\) on \(R\) / \(R_q\) (with \(i\in (\mathbb{Z}/2d\mathbb{Z})^\times\)). +- **\(R_q^H\)**: fixed ring under a subgroup \(H\) of automorphisms; becomes a subfield isomorphic to \(F_{q^k}\) under conditions (Lemma 1 informal in §1.3). +- **\(\mathrm{Tr}_H\)**: trace map \(R_q \to R_q^H\). +- **\(\psi\)**: an efficiently computable bijection \((R_q^H)^{d/k} \to R_q\) used to turn trace-of-product into inner products (Theorem 1 informal in §1.3). + +### Multilinear polynomials / sizes + +- **\(\ell\)**: number of variables of the multilinear polynomial (so witness length is \(2^\ell\)). +- **\(L := 2^\ell\)**: number of coefficients / evaluation-table length. + +### Split-and-fold / decomposition parameters (Figure 9) + +The paper uses **\(m, r\)** as “folding parameters” (they are *not* the same \(m\) as “#constraints” in SuperNeo’s CCS section). + +- **\(m, r\)**: split-and-fold parameters controlling the shape of the quadratic relation after one reduction step. +- **\(b\)**: decomposition base (e.g. 16 in Fig. 9). +- **\(\delta\)**: decomposition length for base witness (e.g. 8 in Fig. 9; essentially \(\lceil \log_b q\rceil\) in spirit). +- **\(\tau\)**: expansion factor / extra length parameter for decomposing intermediate vectors (e.g. 4 in Fig. 9). +- **\(\omega\)**: \(\ell_1\)-norm bound on a sparse challenge (Fig. 9). +- **\(c\)**: number of nonzero coefficients in a sparse challenge polynomial (Fig. 9). + +### Commitment matrix “heights” (Figure 9) + +Hachi uses commitment matrices \(A,B,D\) (not the same “Ajtai A” notation as in SuperNeo). + +- **\(n_A, n_B, n_D\)**: heights (row counts) of the corresponding commitment matrices in the composed relation (Fig. 9 uses all 1). + +## Protocol overview (what is proved, and what is sent) + +This is an end-to-end walkthrough of the **full Hachi protocol**, in the order it’s built in `paper/hachi.pdf`. + +### What Hachi is (PCS statement + design goal, **field-first**) + +The digest previously jumped straight to the §4 “ring PCS statement”. That is **not the full picture**: the paper’s *natural* PCS interface is the usual one where the **witness coefficients are in the base field** \(Z_q \cong F_q\) but the **evaluation point is in an extension field** \(F_{q^k}\) (because sumcheck / batching wants negligible soundness error). + +Concretely, the “true” opening statement the paper starts from (Intro + §3.2) is: + +- **Witness polynomial (true witness)**: \(f \in Z_q^{\le 1}[X_1,\dots,X_\ell]\) (equivalently \(F_q^{\le1}[\cdot]\)), with coefficient table \((f_i)_{i\in\{0,1\}^\ell}\subset Z_q\). +- **Claim (extension-field point/value)**: for a public point \(x=(x_1,\dots,x_\ell)\in F_{q^k}^\ell\), prove \(f(x)=y\in F_{q^k}\). + +Hachi is engineered so that the verifier’s heavy checking runs as **sumcheck over \(F_{q^k}\)** (fast), while commitments / Module-SIS structure remain over the cyclotomic ring \(R_q = Z_q[X]/(X^d+1)\). + +### Step 0 (paper §3.2 → §3.1): the missing “embedding” step (from **\(F_q\)-witness @ \(F_{q^k}\)-point** to a ring statement) + +The paper’s §3 is exactly the bridge from the “true” statement above to the §4 ring PCS statement. + +#### 0.A (paper §3.2): reduce **\(f\in Z_q[\cdot]\)** at **\(x\in F_{q^k}^\ell\)** to one evaluation over \(F_{q^k}\) + +Assume \(k\) is a power of two and write \(k = 2^\kappa\) (this \(\kappa\) is the one used in §3). Split variables into the first \(\kappa\) and the remaining \(\ell-\kappa\). The evaluation can be rewritten (paper Eq. (11)) as: + +\[ +y \;=\;\sum_{i\in\{0,1\}^\kappa}\Big(\prod_{t=1}^{\kappa} x_t^{i_t}\Big)\cdot y_i, +\quad\text{where}\quad +y_i \;:=\; f_{i}(\,x_{\kappa+1},\dots,x_\ell\,)\in F_{q^k}. +\] + +So the prover can send the \(k=2^\kappa\) *partial evaluations* \((y_i)_i\) (and the verifier checks the recombination); what remains is to prove each \(y_i\) is well-formed. + +Paper detail (§3.2, right after Eq. (11)): the verifier can compute \(y_{0\ldots 0}\) from the claimed \(y\) and the other \(y_i\), so in principle only **\(k-1\)** of the partial evaluations need to be transmitted. + +To make that “prove all \(y_i\)” look like **one** extension-field evaluation, §3.2 defines \(F_{q^k} := F_q[Z]/\varphi(Z)\) and builds an \((\ell-\kappa)\)-variate multilinear polynomial \(f' \in F_{q^k}[X_{\kappa+1},\dots,X_\ell]\) by *embedding* the \(k\) slices \((f_i)_i\) into the \(F_q\)-basis \(1,Z,Z^2,\dots,Z^{k-1}\) (paper §3.2): + +\[ +f'(X_{\kappa+1},\dots,X_\ell) +:=\sum_{i\in\{0,1\}^\kappa} +f_i(X_{\kappa+1},\dots,X_\ell)\cdot Z^{\sum_{t=1}^{\kappa} i_t 2^{t-1}}. +\] + +Then \(f'(x_{\kappa+1},\dots,x_\ell)=\sum_i y_i\cdot Z^{(\cdot)}\) holds as an algebraic identity. + +**Critical caveat (security / binding):** it is **not** generally sound to claim that “proving the single packed value \(f'(x_{\kappa+1},\dots,x_\ell)\)” suffices to prove that *each* \(y_i\) is correct, because the coefficients \(y_i\) live in \(F_{q^k}\) (not in the ground field \(F_q\)). + +Concretely, the linear map + +\[ +(y_i)_{i\in\{0,1\}^\kappa}\in (F_{q^k})^{2^\kappa} +\;\longmapsto\; +\sum_i y_i \cdot Z^{(\cdot)} \in F_{q^k} +\] + +is **\(F_{q^k}\)-linear** and therefore has a large kernel whenever \(2^\kappa>1\). So many distinct tuples \((y_i)_i\) produce the *same* packed sum. A toy example for \(k=2\) (basis \(\{1,Z\}\)): the two different pairs \((y_0,y_1)\) and \((y_0+Z,\; y_1-1)\) satisfy + +\[ +(y_0+Z) + (y_1-1)\cdot Z \;=\; y_0 + y_1\cdot Z, +\] + +so the packed value alone does not pin down \(y_0,y_1\). + +This is exactly the “basis is only independent over the ground field” pitfall called out in `paper/fri-binius.pdf`, where they explain that a basis \((\beta_v)\) is linearly independent over \(K\) but **not** over its extension \(L\); hence basis-combining \(L\)-valued claims is insecure. See the strawman discussion around Figure 1 in: + +- `paper/fri-binius.pdf` §1.3 “Ring-Switching”, “A strawman approach”, Figure 1, and the paragraph beginning “While this protocol is complete, it’s not secure.” (pages 4–6 in this PDF copy). + +**How to fix (high-level, sumcheck-style):** you must reduce the “\(y_i\) are well-formed” constraints into **ground-field (\(F_q\)) constraints** before basis-combining / packing them. + +The standard way (as in Fri-Binius ring-switching) is: + +1. **Basis-decompose each \(y_i\) over \(F_q\)**: write \(y_i = \sum_{u=0}^{k-1} y_{u,i}\,Z^u\) with \(y_{u,i}\in F_q\). +2. Also basis-decompose the extension-field weights (the equality-polynomial weights / monomial weights) into \(F_q\) coefficients. +3. Check the resulting family of \(F_q\)-valued equalities “slice-wise” (over the \(u\) index), and then **batch them with an additional sumcheck** so the verifier only pays polylog overhead. + +After this extra sumcheck layer pins down the \(F_q\)-slices, packing becomes injective again (because it is now combining **\(F_q\)-vectors** against an \(F_q\)-basis), and the reduction from “many partial claims” to “one packed claim” becomes sound. + +##### 0.A.1 Concrete “extra sumcheck” shape (Fri-Binius Eq. (12) style), and how many rounds it costs + +This subsection spells out the *exact* reason the extra sumcheck has \(\ell\) rounds (and how it fits the objects already in §3.2). + +Let \(K:=F_q\) and \(L:=F_{q^k}\), and assume \(k=2^\kappa\) for some \(\kappa\) (as in §3.2). Fix a \(K\)-basis \((\beta_u)_{u\in\{0,1\}^\kappa}\) of \(L\) (in §3.2 the paper chooses \(1,Z,\dots,Z^{k-1}\), which is just one such basis). + +The goal is to prove the family of claims (paper after Eq. (11)): + +\[ +\forall v\in\{0,1\}^\kappa:\quad +y_v \stackrel{?}{=} f_v(x_{\kappa+1},\dots,x_\ell)\in L, +\] + +where \(f_v\in K^{\le1}[X_{\kappa+1},\dots,X_\ell]\) is the “slice” of \(f\) with the first \(\kappa\) variables fixed to \(v\). + +The unsafe step is to basis-combine these \(L\)-valued equalities directly. The safe replacement is to basis-decompose everything so the equalities become \(K\)-valued first. + +1) **Basis-decompose the prover’s \(L\)-claims**: for each \(v\in\{0,1\}^\kappa\), write + +\[ +y_v = \sum_{u\in\{0,1\}^\kappa} y_{u,v}\,\beta_u, +\quad\text{with } y_{u,v}\in K. +\] + +2) **Expand \(f_v(x_{\kappa+1},\dots,x_\ell)\) as a \(K\)-weighted sum.** Let \(\ell':=\ell-\kappa\) and index \(w\in\{0,1\}^{\ell'}\). Then + +\[ +f_v(x_{\kappa+1},\dots,x_\ell) += +\sum_{w\in\{0,1\}^{\ell'}} +\mathrm{eq}(x_{\kappa+1},\dots,x_\ell;\,w)\cdot f(v,w), +\] + +where \(f(v,w)\in K\) is the \((v,w)\) Lagrange coefficient of \(f\), and \(\mathrm{eq}(\cdot;\,w)\in L\) is the multilinear equality indicator value. + +3) **Basis-decompose the (public) weights**: for each \(w\), decompose the \(L\)-element \(\mathrm{eq}(x_{\kappa+1},\dots,x_\ell;\,w)\) in the same basis: + +\[ +\mathrm{eq}(x_{\kappa+1},\dots,x_\ell;\,w) += +\sum_{u\in\{0,1\}^\kappa} A_{w,u}\,\beta_u, +\quad\text{with } A_{w,u}\in K, +\] + +where the \(A_{w,u}\) are deterministically computable from the public point \((x_{\kappa+1},\dots,x_\ell)\) and the chosen basis. + +4) **Now each coordinate is a \(K\)-statement**: equating coefficients of \(\beta_u\) yields the family of \(K\)-equalities + +\[ +\forall u,v\in\{0,1\}^\kappa:\quad +y_{u,v} \stackrel{?}{=} \sum_{w\in\{0,1\}^{\ell'}} A_{w,u}\cdot f(v,w). +\] + +5) **Pack the \(f(v,w)\) table into one \(L\)-multilinear** (this is the same “packing” idea as §3.2, but applied to a \(K\)-table so it is information-preserving): + +\[ +f'(w) := \sum_{v\in\{0,1\}^\kappa} f(v,w)\,\beta_v \in L, +\] + +so \(f'\) has \(\ell'=\ell-\kappa\) variables over \(L\). + +6) **Batch and sumcheck.** Define the combined \(L\)-valued claims (these are the secure analog of the strawman’s linear combination step): + +\[ +\hat y_u := \sum_{v\in\{0,1\}^\kappa} y_{u,v}\,\beta_v \in L. +\] + +Then the \(K\)-equalities above imply the \(L\)-equalities + +\[ +\forall u\in\{0,1\}^\kappa:\quad +\hat y_u \stackrel{?}{=} \sum_{w\in\{0,1\}^{\ell'}} A_{w,u}\cdot f'(w). +\] + +Finally, batch over \(u\) with a random point \(r''\in L^\kappa\) and apply sumcheck to the identity + +\[ +\sum_{u\in\{0,1\}^\kappa} \mathrm{eq}(u;\,r'')\cdot \hat y_u +\stackrel{?}{=} +\sum_{w\in\{0,1\}^{\ell'}} +\Big(\sum_{u\in\{0,1\}^\kappa}\mathrm{eq}(u;\,r'')\cdot A_{w,u}\Big)\cdot f'(w). +\] + +This is exactly the structural form of Fri-Binius Eq. (12). The sum ranges over \((u,w)\in\{0,1\}^\kappa\times\{0,1\}^{\ell'}\), so the sumcheck has **\(\kappa+\ell'=\ell\) rounds**, i.e. it costs **\(+\kappa=\log_2 k\)** more rounds than a hypothetical scheme that only needed to open \(f'\) (which has \(\ell'\) variables). + +Asymptotically, prover time for this added sumcheck is linear in the domain size \(2^{\kappa+\ell'}=2^\ell\) (times poly\((\ell,k)\) factors), i.e. \(\Theta(2^\ell)\) field operations in \(L\) for the sumcheck layer, plus one opening of the \(L\)-PCS on \(f'\) at the final point. + +#### 0.B (paper §3.1): embed \(F_{q^k}\) inside \(R_q\) and turn extension-field inner products into trace statements over \(R_q\) + +Now treat the remaining claim as an evaluation over \(F_{q^k}\) (equivalently over the subfield \(R_q^H \cong F_{q^k}\) from Lemma 5). §3.1 then provides: + +- **A subfield of \(R_q\)**: for \(q \equiv 5 \pmod 8\) and \(k\mid d/2\), define \(H=\langle\sigma^{-1},\sigma^{4k+1}\rangle\subset\mathrm{Aut}(R)\). Then the fixed ring \(R_q^H\) is a field and \(R_q^H \cong F_{q^k}\). (Lemma 5.) +- **A packing bijection** \(\psi:(R_q^H)^{d/k}\to R_q\). (Eq. (8), Theorem 2.) +- **A trace/inner-product identity** (Theorem 2): + + \[ + \mathrm{Tr}_H\big(\psi(a)\cdot \sigma^{-1}(\psi(b))\big) \;=\; \frac{d}{k}\cdot \langle a,b\rangle, + \quad a,b \in (R_q^H)^{d/k}. + \] + +Concretely (paper around Eq. (10)), pick \(\alpha:=\log_2 d\) and split the \(\ell\) variables into an “outer” prefix of length \(\ell-\alpha+\kappa\) and an “inner” suffix of length \(\alpha-\kappa\). Write indices as \(i\in\{0,1\}^{\ell-\alpha+\kappa}\) and \(j\in\{0,1\}^{\alpha-\kappa}\). Define: + +- **Packed coefficient blocks**: \(F_i := \psi\big((f_{i\parallel j})_{j\in\{0,1\}^{\alpha-\kappa}}\big)\in R_q\). +- **Packed monomial block at the suffix**: + + \[ + v := \psi\Big(\big(\prod_{t=1}^{\alpha-\kappa} x_{\ell-\alpha+\kappa+t}^{\,j_t}\big)_{j\in\{0,1\}^{\alpha-\kappa}}\Big)\in R_q. + \] + +Then the extension-field evaluation is reduced to checking a **single trace equation** in \(R_q\) involving: + +\[ +Y \;:=\; \sum_{i\in\{0,1\}^{\ell-\alpha+\kappa}} +\Big(\prod_{t=1}^{\ell-\alpha+\kappa} x_t^{i_t}\Big)\cdot F_i +\;\in\; R_q, +\] + +namely \(\mathrm{Tr}_H\big(Y\cdot \sigma^{-1}(v)\big)=\tfrac{d}{k}\,y\). The prover sends this single ring element \(Y\). + +What remains is to prove that \(Y\) is well-formed, i.e. that it is *exactly* the evaluation of the \((\ell-\alpha+\kappa)\)-variate ring polynomial \(F := (F_i)_i \in R_q^{\le 1}[X_1,\dots,X_{\ell-\alpha+\kappa}]\) at the point \((x_1,\dots,x_{\ell-\alpha+\kappa})\) (viewing those \(x_t\) as elements of the subfield \(R_q^H\subset R_q\)). This is the “smaller multilinear evaluation claim over \(R_q\)” that becomes the input to the §4 core PCS. + +### Step 1 (paper §4): the internal ring PCS statement Hachi actually proves + +After §3’s transformation, Hachi reduces extension-field evaluation claims to the §4 “core” ring PCS statement: + +- **Witness polynomial (ring core)**: \(f \in R_q^{\le 1}[X_1,\dots,X_{\ell'}]\) with coefficients in \(R_q\) (this \(f\) is the reduced ring polynomial constructed in Step 0.B; i.e., it is the \(F := (F_i)_i\) from above, just returning to the paper’s §4 notation), +- **Claim**: for a public point \(x\in R_q^{\ell'}\) (in the reduction flow, typically \(x\in (R_q^H)^{\ell'}\subset R_q^{\ell'}\)), prove \(f(x)=u\in R_q\), + +where (per §3.2) \(\ell' = \ell - \alpha\) in the important special case “coefficients in \(Z_q\), point in \(F_{q^k}\)”, with \(d=2^\alpha\). + +From here on, we are *inside* §4. The paper continues to write the ring instance as \(\ell\)-variate; in digest notation, you can read the \(\ell\) used in Steps A/B/C below as \(\ell'\) (the post-§3 reduced variable count). + +### Step A (paper §4.1): Commit to the coefficient table (inner + outer commitments) + +Hachi’s commitment structure is Greyhound-style: it commits to the coefficient table in blocks using two Ajtai commitments. + +Split \(\ell = m + r\) with \(m \approx r\). Define the block slices \(f_i^\top := (f_{i\|j})_{j\in\{0,1\}^m}\in R_q^{2^m}\) for each outer index \(i\in\{0,1\}^r\). (This is the “\(f_i\)” notation right below Equation (12).) + +Commitment construction (Equations (13)–(14)): + +- **Decompose each slice**: \(s_i := G^{-1}_{2^m}(f_i)\in R_q^{2^m\delta}\), where \(\delta=\lceil \log_b q\rceil\). (Eq. (13).) +- **Inner Ajtai commit**: \(t_i := A s_i \in R_q^{n_A}\) for all \(i\in[2^r]\). +- **Decompose the inner commit**: \(\hat t_i := G^{-1}_{n_A}(t_i)\). +- **Outer Ajtai commit**: stack all \(\hat t_i\) and commit: + - \(u := B[\hat t_1;\dots;\hat t_{2^r}] \in R_q^{n_B}\). (Eq. (14).) + +So: + +- **Commitment** is \(u \in R_q^{n_B}\). +- **Opening witness** (conceptually) is \((s_i,\hat t_i)_i\). +- Knowledge/binding is phrased in terms of “weak openings” (definition right after Eq. (14), Lemma 7). + +### Step B (paper §4.2): Reduce “open \(f(x)=u\)” to “prove knowledge of short vectors satisfying a public linear system” + +This is the core reduction pipeline of Hachi’s opening proof. + +#### B.0. Descriptive names for the “hat” variables (implementation-minded) + +The paper uses “hats” (\(\hat{\cdot}\)) for **digit-decomposed** objects (small coefficients), and uses a few single-letter vectors that are easy to lose track of. Here is a naming map that matches the paper’s definitions in §4.1–§4.2 and is meant to be used as docstring text during implementation. + +- **\(s_i\)** = **block_digits** (digits of the \(i\)-th coefficient block): + - \(s_i = G^{-1}_{2^m}(f_i)\). (Eq. (13).) + - Shape: vector of ring elements, length \(2^m\cdot\delta\). + +- **\(t_i\)** = **inner_commitment_block_i** (Ajtai commit of the \(i\)-th block digits): + - \(t_i := A s_i \in R_q^{n_A}\). + +- **\(\hat t_i\)** = **inner_commitment_digits_block_i** (digits of \(t_i\)): + - \(\hat t_i := G^{-1}_{n_A}(t_i)\). + +- **\(u\)** = **outer_commitment** (the actual PCS commitment output): + - \(u := B[\hat t_1;\dots;\hat t_{2^r}] \in R_q^{n_B}\). (Eq. (14).) + +- **\(w_i\)** = **block_partial_eval_i** (block \(i\)’s contribution after plugging the “\(a\)” half of the opening point): + - \(w_i := a^\top G_{2^m} s_i \in R_q\). (Defined right before Eq. (16).) + +- **\(\hat w_i\)** = **block_partial_eval_digits_i** (digits of \(w_i\)): + - \(\hat w_i := G^{-1}_1(w_i) \in R_q^\delta\). + +- **\(v\)** = **aux_commitment_to_block_partials** (commitment to all \(\hat w_i\)): + - \(v := D\hat w \in R_q^{n_D}\). (Eq. (16).) + +- **\(c\)** = **fold_challenge_vector** (short/sparse ring challenge used to fold blocks): + - \(c=(c_1,\dots,c_{2^r})\in C^{2^r}\), with \(\|c_i\|_1\le \omega\). + +- **\(z\)** = **folded_block_digits** (folded witness over block digits): + - \(z := \sum_{i=1}^{2^r} c_i s_i\). (Eq. (18)/(19) discussion.) + - This is the key “compress all blocks into one witness” object. + +- **\(\hat z\)** = **folded_block_digits_redecomposed** (extra decomposition of \(z\) after coefficient growth): + - \(\hat z := J^{-1}_{2^m}(z)\), where \(J\) is the gadget matrix sized for \(\tau\) digits (so \(\hat z\) has length \(2^m\cdot\delta\cdot\tau\)). (Right after Eq. (20).) + +- **\(r\)** = **modulus_quotient_witness** / **slack_witness** (ring-switching quotient): + - \(Mz = y + (X^d+1)\cdot r\) over \(Z_q[X]\). (§4.3.) + - In the “real” protocol, \(r\) is digit-decomposed as \(r=\sum_u b^u r_u\) and the prover commits to the digit vectors \(r_u\). (See §4.3 and our C.2.1.) + +#### B.1. Rewrite evaluation as a bilinear form (Eq. (12)) + +Define the monomial vectors from the evaluation point \(x\): + +- \(b^\top := (x_1^{i_1}\cdots x_r^{i_r})_{i\in\{0,1\}^r}\in R_q^{2^r}\) +- \(a^\top := (x_{r+1}^{j_1}\cdots x_\ell^{j_m})_{j\in\{0,1\}^m}\in R_q^{2^m}\) + +Then the evaluation can be written as Equation (12), which motivates everything that follows. + +#### B.2. Introduce intermediate values \(w_i\) and commit to them (Eq. (16)) + +Using the decomposed slices \(s_i\), Equation (15) rewrites the opening equation using \(s_i\). + +Define the intermediate ring elements: + +- \(w_i := a^\top G_{2^m} s_i \in R_q\). (Defined right before Eq. (16).) + +Then: + +- \(u = b^\top w\). (Eq. (17).) + +The prover also commits to \(w\) using another Ajtai-style commitment (Eq. (16)): + +- decompose \(w_i\) to \(\hat w_i := G^{-1}_1(w_i)\in R_q^\delta\), +- stack \(\hat w := (\hat w_1,\dots,\hat w_{2^r})\), +- compute **\(v := D \hat w \in R_q^{n_D}\)**. (Eq. (16).) + +This is the first prover message in Figure 3: + +- **P → V**: send \(v\). + +#### B.3. Fold all the \(s_i\) using a short challenge vector \(c\) (Eqs. (18)–(19)) + +The verifier samples a short/sparse challenge vector: + +- **V → P**: \(c=(c_1,\dots,c_{2^r}) \leftarrow C^{2^r}\), where \(C \subset \{c\in R_q : \|c\|_1 \le \omega\}\). (See the paragraph introducing \(c\).) + +The prover folds the witness: + +- \(z := \sum_{i=1}^{2^r} c_i s_i \in R_q^{2^m\delta}\). (Immediately after defining \(c\).) + +Two crucial linear identities then hold: + +- \(a^\top G_{2^m} z = (c^\top \otimes G_1)\hat w\). (Eq. (18).) +- \(A z = (c^\top \otimes G_{n_A})\hat t\). (Eq. (19).) + +These relate the folded witness \(z\) to the already-committed objects \(\hat w,\hat t\). + +#### B.4. Decompose \(z\) further and form one big “unstructured linear relation” (Eq. (20)) + +Because coefficients of \(z\) are larger, the prover further decomposes \(z\) using another gadget matrix \(J\): + +- pick a bound \(\beta\) on \(\|z\|_\infty\), define \(\tau:=\lceil \log_b \beta\rceil\), +- compute \(\hat z := J^{-1}_{2^m}(z)\in R_q^{2^m\delta\tau}\). + +At this point, the prover’s remaining task becomes: + +> prove knowledge of a short vector \((\hat w,\hat t,\hat z)\) satisfying a public linear system over \(R_q\). (Eq. (20).) + +Equation (20) is exactly that public linear system: + +- it includes the matrices \(D,B,A\), +- it includes the evaluation-derived vectors \(a,b\), +- it includes the verifier challenge \(c\), +- and it enforces: (i) consistency with the commitments \(u,v\), (ii) the opening equation, and (iii) the fold relations (18)–(19). + +##### B.4.1 Eq. (20) rewritten as a list of constraints (no hats, descriptive meaning) + +Mentally, Eq. (20) is just a *bundling* of several checks into one linear system \(M_{\text{big}}\cdot \text{witness} = \text{statement}\). Written as explicit constraints, the prover is proving existence of **small** objects: + +- **`block_partial_eval_digits`** = \(\hat w\) +- **`inner_commitment_digits`** = \(\hat t\) +- **`folded_block_digits_redecomposed`** = \(\hat z\) (and implicitly \(z = J\hat z\)) + +such that all of the following hold: + +1. **Aux-commitment consistency (ties \(\hat w\) to the sent commitment \(v\))**: + - \(D\hat w = v\). (Eq. (16).) + +2. **Main commitment consistency (ties \(\hat t\) to the sent commitment \(u\))**: + - \(B\hat t = u\). (Eq. (14).) + +3. **Evaluation equation (ties the claimed opening value \(u\) to the partials \(w\))**: + - If \(w := G_{2^r}\hat w\) then \(b^\top w = u\). (Eq. (17).) + +4. **Fold-consistency: partial-eval folding matches the folded witness \(z\)**: + - Let \(z := J\hat z\) (so \(z\) is the “recomposed” folded block-digits witness). + - Then \(a^\top G_{2^m} z = (c^\top\otimes G_1)\hat w\). (Eq. (18).) + +5. **Fold-consistency: inner-commitment folding matches the same folded witness \(z\)**: + - \(A z = (c^\top\otimes G_{n_A})\hat t\). (Eq. (19).) + +6. **Smallness/range constraints (coefficient bounds)**: + - Coefficients of \(\hat w,\hat t,\hat z\) lie in the intended digit ranges (bounded by the gadget decomposition), and the protocol’s range-check machinery in §4.3 ultimately enforces these via the \(H_0\) constraint. + +Figure 3 shows this “conceptual protocol”, ending with the prover *sending* \((\hat w,\hat t,\hat z)\), but the paper immediately notes: + +- in the **final scheme**, the prover does not send these in the clear; instead it proves knowledge of them. + +### Step C (paper §4.3): Prove the unstructured linear relation + coefficient smallness via ring switching + sumcheck over \(F_{q^k}\) + +This is the main verifier-efficiency trick: transform ring relations into field relations (via evaluation at a random \(\alpha\)), then apply sumcheck over \(F_{q^k}\). + +#### C.1. The generic linear relation they want to prove + +They define: + +- \(R^{\mathrm{lin}}_{q,d,n,\mu,b}\): given public \(M\in R_q^{n\times \mu}\), \(y\in R_q^n\), prove knowledge of \(z\in R_q^\mu\) such that \(Mz=y\) and \(\|z\|_\infty \le b-1\). + +Eq. (20) is an instance of this relation (with \(z=(\hat w,\hat t,\hat z)\)). + +#### C.2. Ring switching (Figure 4): lift equality in \(R_q\) to an identity in \(Z_q[X]\), then evaluate at random \(\alpha\) + +Because \(R_q = Z_q[X]/(X^d+1)\), the ring equation \(Mz=y\) holds iff there exists a “slack” polynomial vector \(r\) such that (paper §4.3, “Ring switching.”): + +- \(Mz = y + (X^d+1)\cdot r\) over \(Z_q[X]\). + +##### C.2.1 Critical “hidden detail”: \(r\) is gadget-decomposed, and the protocol commits to the digits + +The paper explicitly notes that **both \(z\) and \(r\)** are prover witnesses, and while the verifier ultimately needs a check involving \(r\), the prover *does not* want to commit to a large-\(q\) object and then range-check it “as-is”. + +So, for “notation and implementation simplicity”, Hachi performs a base-\(b\) gadget decomposition of the quotient witness \(r\) (paper §4.3): + +- \(r = \sum_u b^u \cdot r_u\), +- the prover commits to \((z, r_1, \dots, r_{\log_b(q)})\) instead of \((z,r)\), +- and the prover proves \(\|r_u\|_\infty \le b-1\) for every digit-vector \(r_u\). + +The paper then says it **omits the subscript \(u\)** from this point on, and that to incorporate this in the sumcheck view, “we can modify the multilinear extension correspondingly.” + +Concretely, the “real” linear constraint is: + +\[ +Mz = y + (X^d+1)\cdot \sum_u b^u r_u \quad\text{over } Z_q[X], +\] + +and the “real” smallness constraints are: + +\[ +\|z\|_\infty \le b-1 +\quad\text{and}\quad +\|r_u\|_\infty \le b-1 \ \forall u. +\] + +Protocol idea (Figure 4): + +- **P → V**: commit to \((z,r)\): \(t := \mathrm{Com}(z,r)\). (In the “real” protocol: commit to \((z,(r_u)_u)\) as above; the paper keeps writing \((z,r)\) after omitting the digit index.) +- **V → P**: sample \(\alpha \leftarrow F_{q^k}\). +- Reduce to checking the *field* equations: + - \(M(\alpha)z(\alpha) = y(\alpha) + (\alpha^d+1)r(\alpha)\) over \(F_{q^k}\). (Figure 4.) + +Soundness of this “evaluate at random \(\alpha\)” step is formalized as \(2d\)-special soundness (Lemma 9), reflecting degree \(\le 2d-1\). + +They also need to enforce that the witness coefficients are genuinely in \(Z_q\) and small (not arbitrary in \(F_{q^k}\)), and they fold those checks into the next “sumcheck view”. + +#### C.3. Represent the constraints as multilinear polynomials and batch them (Eqs. (21)–(23), Figure 5) + +This is where the **full constraint system** (including the range check) is spelled out in the paper. + +##### C.3.1 The witness polynomial \(\tilde w\) (Eq. (21)) + +The prover’s committed witness for sumcheck is a multilinear polynomial \(\tilde w\) that encodes coefficient tables. + +**Important:** because of §4.3’s hidden gadget decomposition \(r=\sum_u b^u r_u\), the *real* committed witness should be thought of as encoding coefficient tables of: + +- the vector of polynomials \(z\), and +- all digit-vectors \(r_u\), + +not just a single undigitized \(r\). The paper immediately omits the digit index \(u\) and keeps writing \((z,r)\); this is why Eq. (21) below only has one \(r\). + +The paper defines a multilinear polynomial \(e_w\) (same role as \(\tilde w\) in our prose) as: + +- \(e_w(u,\ell) = z_{u,\ell}\) if \(u \le \mu\), +- \(e_w(u,\ell) = r_{u-\mu,\ell}\) if \(\mu < u \le \mu+n\). (Eq. (21).) + +Here \(u\) and \(\ell\) are treated as binary strings indexing \([\,\mu+n\,]\) and \([\,d\,]\). + +To “incorporate the gadget decomposition” in the exact same shape, one natural flattened encoding is: + +- let \(\delta := \lceil \log_b(q)\rceil\) be the number of digits, +- treat the committed table as indexed by \([\,\mu + n\cdot \delta\,]\times[\,d\,]\), +- keep \(e_w(u,\ell)=z_{u,\ell}\) for \(u\in[\mu]\), +- and set \(e_w(\mu + u'\cdot n + i,\ell) := (r_{u'})_{i,\ell}\) for digit index \(u'\in[\delta]\) and row index \(i\in[n]\). + +With this flattening, the “range check” constraints apply to **every** coordinate of \(z\) and **every** coordinate of every digit vector \(r_{u'}\). + +##### C.3.1.1 Digit-aware \(e_{M_\alpha}\): the fully expanded constraint coefficient function + +The paper’s simplified definition (after Eq. (21)) encodes the linear constraint at \(\alpha\) by defining a public function \(e_{M_\alpha}(i,u)\) that (informally) gives the coefficient multiplying the polynomial \(w_u(\alpha)\) inside row \(i\). + +Once you flatten the digit vectors \((r_{u'})_{u'\in[\delta]}\) into the witness index \(u\in[\,\mu+n\delta\,]\), the *literal* digit-aware version is: + +- Let \(u\in[\,\mu+n\delta\,]\). +- Define the “decoded” witness coordinate \(W_u(\alpha)\) by: + - for \(1 \le u \le \mu\): \(W_u(\alpha) := z_u(\alpha)\), + - for \(\mu + (u'-1)n + i\) with \(u'\in[\delta]\) and \(i\in[n]\): \(W_{\mu + (u'-1)n + i}(\alpha) := (r_{u'})_i(\alpha)\). + +Then define: + +\[ +e^{\text{dig}}_{M_\alpha}(i,u) := +\begin{cases} +M_{i,u}(\alpha) & \text{if } 1 \le u \le \mu, \\ +-\,b^{u'-1}\cdot(\alpha^d+1) & \text{if } u = \mu + (u'-1)n + i \text{ for some } u'\in[\delta], \\ +0 & \text{otherwise.} +\end{cases} +\] + +With this explicit \(e^{\text{dig}}_{M_\alpha}\), the digitized ring-switch check for each row \(i\in[n]\) is exactly: + +\[ +\sum_{u=1}^{\mu+n\delta} e^{\text{dig}}_{M_\alpha}(i,u)\cdot W_u(\alpha) \;=\; y_i(\alpha). +\] + +This is just the statement: + +\[ +\sum_{j=1}^{\mu} M_{i,j}(\alpha)\,z_j(\alpha)\;-\;(\alpha^d+1)\sum_{u'=1}^{\delta} b^{u'-1}\,(r_{u'})_i(\alpha) \;=\; y_i(\alpha), +\] + +which is equivalent to the “real” lifted identity \(Mz = y + (X^d+1)\sum_{u'}b^{u'-1}r_{u'}\) after evaluation at \(X=\alpha\). + +##### C.3.2 The two *literal* constraints before batching (paper right before Eq. (22)/(23)) + +For a fixed ring-switch challenge \(\alpha\in F_{q^k}\), the verifier wants to enforce: + +1. **Linear constraints (ring switching at \(\alpha\))**: for each row \(i\in[n]\), + + \[ + \sum_{u=1}^{\mu+n} e_{M_\alpha}(i,u)\cdot \sum_{\ell} e_w(u,\ell)\cdot e_\alpha(\ell)\;=\;y_i(\alpha), + \] + + where \(e_\alpha(\ell)=\alpha^\ell\) and \(e_{M_\alpha}(i,u)\) is the public multilinear encoding of \(M(\alpha)\) plus the extra \(-(\alpha^d+1)\) term on the \(r\)-coordinates. (This is the first bullet list item under Figure 5 in the paper.) + + With the hidden gadget decomposition \(r=\sum_u b^u r_u\), the same constraint is enforced except the right-hand side becomes \(y_i(\alpha) + (\alpha^d+1)\sum_u b^u r_{i,u}(\alpha)\). In the “\(e_{M_\alpha}\cdot e_w\)” encoding, this is handled by: + + - expanding \(e_w\) to include all digit-vectors \((r_u)_u\) (as described in C.3.1), and + - modifying the \(-(\alpha^d+1)\) part of \(e_{M_\alpha}\) to include the appropriate digit weights \(-b^u(\alpha^d+1)\) on the digit blocks. + +2. **Smallness / range constraints (this is the “range check”)**: for *every* coordinate \((u,\ell)\), + + \[ + P_b\big(e_w(u,\ell)\big) = 0 + \quad\text{where}\quad + P_b(T) := \prod_{t=-(b-1)}^{b-1}(T-t). + \] + + The paper writes this explicitly as the vanishing product: + \(e_w(u,\ell)\cdot(e_w(u,\ell)-1)\cdot(e_w(u,\ell)+1)\cdots(e_w(u,\ell)-b+1)\cdot(e_w(u,\ell)+b-1)=0\). + +This is **not optional**: it is exactly how Hachi enforces that the coefficients the prover is claiming for \(z\) (and the \(r\)-side witness they commit) are small integers (embedded into \(F_{q^k}\)), i.e. a range/membership proof via a root-check polynomial. + +##### C.3.3 Batching with equality polynomials: the exact \(H_\alpha\) and \(H_0\) (Eqs. (22)–(23)) + +Let the multilinear equality polynomial be: + +- \(e_{eq}(t,i) = \prod_j (t_j i_j + (1-t_j)(1-i_j))\). + +Then the paper defines: + +- **Linear constraint batch** \(H_\alpha(t)\) (Eq. (22)): + + \[ + H_\alpha(t) := + \sum_{i\in[n]} e_{eq}(t,i)\cdot + \Big(\sum_{u,\ell} e_{M_\alpha}(i,u)\cdot e_w(u,\ell)\cdot e_\alpha(\ell) - y_i(\alpha)\Big). + \] + +- **Smallness/range batch** \(H_0(t)\) (Eq. (23)): + + \[ + H_0(t) := + \sum_{u,\ell} e_{eq}\big(t,(u,\ell)\big)\cdot P_b\big(e_w(u,\ell)\big). + \] + +The goal is to prove both are **identically zero polynomials**, which they reduce (Figure 5) to random-point checks: + +- **V → P**: send random points \(\tau_0,\tau_1\), +- prove \(H_0(\tau_0)=0\) and \(H_\alpha(\tau_1)=0\). + +#### C.4. Use sumcheck to prove the remaining batched sums (Figure 6) + +After fixing \(\tau_0,\tau_1\), they rewrite \(H_0(\tau_0)\) and \(H_\alpha(\tau_1)\) as sums over \((u,\ell)\) of polynomials \(F_{0,\tau_0}\) and \(F_{\alpha,\tau_1}\), and then apply **sumcheck** over \(F_{q^k}\) (discussion right after Eq. (23)). + +Figure 6 gives the “single-round view” of sumcheck: prover sends univariate \(g_i\), verifier sends random challenge scalars \(a_i\), and in the end the verifier reduces everything to checking an evaluation of the witness polynomial \(\tilde w\) at one final random point. + +Crucially, the sumcheck ends in: + +- a **final evaluation claim** of the committed \(\tilde w\) at a random point \(r^\*\), +- plus the requirement to prove this evaluation is consistent with the commitment \(t\). + +That “prove the evaluation claim for \(\tilde w\)” is exactly where recursion happens: you invoke the PCS again on the smaller committed object. + +##### C.4.1 Full message flow (Figure 7), spelled out + +Figure 7 in the paper is the full composition of Figures 4, 5, and 6. Written as an explicit transcript skeleton: + +- **P → V**: \(t := \mathrm{Com}(z,r)\). (In the “real” protocol: \(t := \mathrm{Com}(z,(r_u)_u)\) with \(r=\sum_u b^u r_u\).) +- **V → P**: sample challenges: + - \(\alpha \leftarrow F_{q^k}\), + - \(\tau_0 \leftarrow F_{q^k}^{\log(\mu)+\log(d)}\), + - \(\tau_1 \leftarrow F_{q^k}^{\log(n)}\). +- **P ↔ V (sumcheck)**: run sumcheck to prove \(H_0(\tau_0)=0\) and \(H_\alpha(\tau_1)=0\): + - in each sumcheck round, **P → V** sends a univariate polynomial \(g_i(X_i)\), + - and **V → P** responds with a random scalar challenge \(a_i \leftarrow F_{q^k}\). (Figure 6.) +- **P → V (final opening)**: open the commitment \(t\) at the final point determined by \((a_1,\dots,a_\ell)\) to provide the needed value(s) of \(\tilde w\). +- **V (final checks)**: + - evaluate the public multilinear extensions (notably \(e_{M_\alpha}\) / \(e_\alpha\)) at the same final point, + - and check the final sumcheck identities (Figure 6’s last-round checks). + +### Step D (paper §3 + §5): recursion shape and concrete proof-size accounting + +Recursion shape (high-level): + +- One invocation reduces the big opening proof into opening proofs for **smaller** committed objects (smaller \(\ell\), smaller coefficient domains after decomposition, and evaluation points living in \(F_{q^k}\)). +- Eventually, Hachi suggests switching to different sub-proofs when the witness is small (paper discusses both switching to LaBRADOR/JL and composing with Greyhound). + +Concrete proof-size estimate for \(\ell=30\) (paper §5.2): + +- Sumcheck for the first round: ~7.3KB. +- “Adaptation + Greyhound subproof” gives total ~55.1KB for an evaluation proof. + +This is the explicit estimate in §5.2: + +- first-round sumcheck: \(7.3\)KB, +- plus preparation/adaptation \(4.8\)KB, +- plus Greyhound evaluation subproof \(43\)KB, +- total \(7.3 + 4.8 + 43 = 55.1\)KB. + +#### D.1 What the “adaptation to Greyhound” actually means (paper §4.5, §5.2) + +When the witness becomes small enough, Hachi can stop running its own §4.3 ring-switch + sumcheck recursion and instead reduce the remaining claim into a **Greyhound-native** opening proof. + +The key shape (paper §4.5 and the concrete instantiation in §5.2) is: + +- you end up with an evaluation claim where the verifier has reduced everything to checking that a committed multilinear object \(\tilde w\) evaluates correctly at a random point; +- the prover groups / “packs” the needed coefficients into extension-field elements, and then uses the §3 embedding machinery (the \(\psi\) map and trace identity) to turn the remaining check into a ring statement supported by Greyhound. + +In the concrete accounting, the additional prover communication for this handoff is bounded as (Equation (28) in the paper excerpted in §4.5): + +- \((k-1)\cdot k \cdot \log q\) bits for the sent partial evaluations \((y_i)\), plus +- \(d'\cdot \log q\) bits for sending one ring element \(p\in R'_{q}\) (where \(d'\) is the Greyhound ring dimension, e.g. \(d'=64\) in §5.2). + +This is why §5.2 reports the “adaptation overhead” as small (~0.3KB for the non-commitment part), and then adds the cost to **commit** to the new Greyhound witness element(s). + +### Message flow summary (first / dominant round) + +If you want a “wire-format mental model” aligned to Figures 3–6, the dominant first round contains: + +- **Ring commitments**: the main commitment \(u\in R_q^{n_B}\) (Eq. (14)) and the auxiliary commitment \(v\in R_q^{n_D}\) (Eq. (16)). +- **Ring challenge**: short/sparse \(c\in C^{2^r}\) used to fold the opening witness (Eq. (18)–(19)). +- **Field challenge**: \(\alpha\in F_{q^k}\) for ring switching (Figure 4), and random points \(\tau_0,\tau_1\) (Figure 5). +- **Sumcheck transcript**: univariate polynomials \(g_i\) and challenge scalars \(a_i\) (Figure 6), ending in a point \(r^\*\) and an evaluation value \(\tilde w(r^\*)\). +- **Opening proof(s)**: recursive PCS openings that prove the claimed evaluations of the committed \(\tilde w\) values match the commitments. + +## Implementation-oriented PCS flow spec (what we need before coding) + +This section re-states the end-to-end PCS as a **software spec**: what the prover/verifier compute, what objects exist (and their “witness” roles), and what must be pinned down *before* we implement §4.3 “witness table embedding” or the ring-switch sumcheck instances. + +### E.0 The key mental model: “stacked linear relation” → “ring switch” → “sumcheck” → “new opening claim” + +The paper explains §4.3 in the simplified setting \(Mz=y\) for \(z\in R_q^\mu\). In the actual PCS opening proof, §4.2 produces exactly such a relation, but with: + +- a **stacked witness vector** \(z\) that bundles multiple unknowns (notably \(\hat w,\hat t,\hat z\) from Eq. (20)), and +- a stacked statement \(y\) and matrix \(M\) derived from public matrices \(A,B,D\), the opening point \(x\) (via \(a,b\)), the claimed opening value, and the verifier challenge \(c\). + +Once the prover and verifier agree on that concrete \((M,y)\), §4.3 proceeds as: + +1. **Ring switching introduces a quotient witness** \(r\) such that: + + \[ + Mz = y + (X^d+1)\cdot r \quad \text{over } Z_q[X]. + \] + +2. **Digitize the quotient witness** \(r = \sum_{u'=0}^{\delta-1} b^{u'} r_{u'}\) and range-check each digit block. +3. Encode the (digitized) coefficient tables of \((z,(r_{u'})_{u'})\) into one multilinear object \(\tilde w\) (paper Eq. (21), with the digit index omitted in the paper’s notation). +4. Prove the linear and range constraints by sumcheck, which ends in an evaluation claim \(\tilde w(r^\*)\) at a random point \(r^\*\). +5. **Recursion boundary**: the next PCS subproblem is “open the commitment to \(\tilde w\) at point \(r^\*\)”. + +### E.1 What the prover’s “witness” is after the first interaction rounds (ring switching + sumcheck) + +If by “after the first rounds” you mean “after the verifier samples \(\alpha,\tau_0,\tau_1\) and sumcheck begins”, then the prover’s relevant witnesses are: + +- **The stacked linear-relation witness** \(z\) from §4.2 (Eq. (20)’s unknown vector; in the paper’s naming, this is \((\hat w,\hat t,\hat z)\) with the implicit recomposition \(z := J\hat z\)). +- **The ring-switch quotient witness digits** \((r_{u'})_{u'\in[\delta]}\) satisfying: + + \[ + Mz - y = (X^d+1)\sum_{u'} b^{u'} r_{u'}. + \] + +These are not sent directly. Instead they are **committed** and then accessed *only* through: + +- evaluations of the committed multilinear object \(\tilde w\) during sumcheck, and finally +- one PCS opening of \(\tilde w\) at the final sumcheck point \(r^\*\). + +So the short answer is: + +> After ring switching starts, the “PCS witness” (for the next recursion layer) is the committed multilinear polynomial \(\tilde w\) that encodes the coefficient table of the stacked witness \(z\) and the digitized quotient witness \((r_{u'})_{u'}\). Sumcheck reduces everything to opening \(\tilde w\) at one random point \(r^\*\). + +### E.2 What “witness table embedding” (Eq. (21)) really must encode in the full PCS (not the simplified \(Mz=y\) story) + +Eq. (21) defines \(e_w(u,\ell)\) in the simplified \((z,r)\) notation. For implementation, the important points are: + +- The “\(z\)” rows are not just “some vector”; in the PCS they are the **stacked unknowns** of Eq. (20), i.e. the prover’s hidden objects that tie together: + - the main commitment \(u\), + - the aux commitment \(v\), + - the opening equation, + - the fold identities (18)–(19), + - and the redecomposition via \(J\). +- The “\(r\)” rows are not optional: they are the **quotient witness** for the lifted equation over \(Z_q[X]\), and in the real protocol are digitized as \((r_{u'})_{u'}\) and range-checked. + +In other words, the witness-table encoder we build should take as input: + +- the stacked witness vector (call it `linear_relation_witness` instead of paper’s `z`), and +- the digitized quotient witness blocks `quotient_digits` (instead of paper’s `r`), + +and output a padded, evaluation-form multilinear object \(\tilde w\). + +### E.3 Why the PoC’s “next witness” contains “extra quotient terms” (and why they don’t contradict the paper) + +The PoC does not implement §4.3 for a single abstract relation \(Mz=y\). It already constructs (a variant of) Eq. (20)’s **stacked linear system**, which bundles multiple constraints. + +Each lifted constraint row contributes its own quotient polynomial(s) when you rewrite it as: + +\[ +\text{(row LHS)} - \text{(row RHS)} = (X^d+1)\cdot r_i(X), +\] + +so the PoC ends up with several quotient vectors (for several different stacked sub-constraints), and it concatenates them into the “next witness table” before padding to a power of two. + +This is an implementation manifestation of the same principle the paper uses: + +- §4.2 stacks many checks into one linear system (Eq. (20)), +- §4.3 introduces quotient witnesses for that system when lifted to \(Z_q[X]\), +- §4.3 then commits to a single coefficient-table object \(\tilde w\) representing *all* witness coordinates and *all* quotient-digit coordinates. + +So: seeing “more quotient chunks” in a prototype is expected whenever you expand the simplified \(Mz=y\) exposition into the concrete PCS relation. + +### E.4 Concrete “before we code” checklist (MVP up through §4.3) + +To avoid implementing §4.3 machinery “in the dark”, we should pin down the following concrete spec items first. + +#### E.4.1 Fix the exact stacked linear relation produced by §4.2 + +We need a concrete `LinearRelationInstance` spec with: + +- **Public statement**: + - the stacked matrix \(M\in R_q^{n\times \mu}\), + - the stacked right-hand side \(y\in R_q^n\), + - and the coefficient smallness parameters (base \(b\), digit length \(\delta\), and any redecomposition \(\tau\) where relevant). +- **Witness semantics**: + - what each coordinate of the stacked witness vector means (e.g., slices corresponding to \(\hat w,\hat t,\hat z\)), + - and which coordinates are subject to the range-check polynomial \(P_b\). + +Without this, we cannot correctly define how many “\(z\) rows” Eq. (21) has (the paper’s \(\mu\)). + +#### E.4.2 Fix the quotient witness structure introduced by ring switching + +For the chosen \((M,y)\), ring switching requires: + +- the quotient witness vector \(r\in (Z_q[X]_{ idx\) to become “evaluations on the hypercube”. + +This is the “witness table embedding” deliverable: it is the object whose commitment is opened at the end of sumcheck (Figure 6 / Figure 7). + +#### E.4.4 Fix the public multilinear encodings needed by §4.3 constraints + +To define \(H_\alpha\) and \(H_0\) (Eqs. (22)–(23)), we need: + +- `AlphaPowers`: the table \(e_\alpha(\ell)=\alpha^\ell\), +- `LinearCoeffEncoding`: the digit-aware coefficient function \(e^{dig}_{M_\alpha}(i,u)\) (i.e., how \(M(\alpha)\) and \(-(\alpha^d+1)\cdot b^{u'}\) are encoded), +- and the equality polynomials \(e_{eq}\) used for batching. + +#### E.4.5 Fix how sumcheck’s final oracle check becomes a PCS opening claim + +Our sumcheck core deliberately stops at “here is the final point \(r^\*\)”; the ring-switch module must: + +- compute the expected final value using the public parts, and +- reduce to a **single opening claim** of the committed \(\tilde w\) at \(r^\*\). + +This is exactly where the PCS prover/verifier “open-check” logic plugs in (currently stubbed in this repo). + +## Modulus switching / cross-prime sumcheck (Jolt-motivated extension; not in the Hachi paper) + +This section sketches how to adapt the Hachi “ring switch \(\to\) sumcheck” pipeline to the setting where: + +- **Commitments** are over a *small* prime field / ring modulus \(q\) (e.g. \(\approx 2^{32}\)), because that makes commitment-time arithmetic and NTT/CRT layouts fast. +- **Sumcheck / arithmetization** must run over a *large* prime field \(F_{q'}\) (e.g. 128-bit prime), because the application (e.g. Jolt) requires characteristic large enough to avoid wrap-around in \(u64\cdot u64\) accumulation. + +This is *similar in spirit* to §3’s extension-field story, but **strictly harder**: there is no field embedding \(F_q \hookrightarrow F_{q'}\) that preserves addition/multiplication mod the prime, so we must explicitly control an **integer lift** via range checks. + +### F.0 Target statement (“foreign-field opening”) + +Let \(q\) be a small prime, \(q'\) a large prime, and let \(f\) be an \(\ell\)-variate multilinear with **small coefficients**, ideally bits: + +\[ +f \in F_q^{\le 1}[X_1,\dots,X_\ell],\quad f_i \in \{0,1\}\subset F_q. +\] + +We commit to the coefficient table \((f_i)_{i\in\{0,1\}^\ell}\) using the Hachi/Greyhound-style commitment core over \(R_q\) (or over \(F_q\) as the \(d=1\) special case). + +The opening claim we want to support is over the **large prime field**: + +\[ +\text{given } x\in F_{q'}^\ell,\ y\in F_{q'},\ \text{prove } f(x)=y\ \text{(interpreting the coefficients as small integers in }F_{q'}\text{).} +\] + +Because the coefficients are in \(\{0,1\}\), there is a canonical injection \(\iota:\{0,1\}\to F_{q'}\) (map to the same integers). The only remaining job is to enforce that the committed coefficients are indeed in \(\{0,1\}\) (bitness) and that every algebraic check is performed with respect to this integer lift. + +#### F.0.1 Important clarification: this is **not** “digit-decompose \(x,y\)” + +It is tempting to think “we have an evaluation claim \(f(x)=y\) in the large field \(F_{q'}\), so we should decompose \(x\) and \(y\) into base-\(b\) digits and prove digit-wise subclaims.” That is **not** what the modulus-switching reduction does, and it generally does not work (polynomial evaluation does not decompose into independent digit evaluations). + +Instead: + +- The evaluation point \(x\in F_{q'}^\ell\) and claimed value \(y\in F_{q'}\) are already *native* to the field where Jolt runs sumcheck. There is typically no reason to decompose them. +- The lift issue arises because the **committed objects** live over the *small* modulus \(q\) (i.e. values are only defined modulo \(q\)), while the verifier wants to check equations **inside \(F_{q'}\)**. + +So the purpose of \(\mathrm{lift}_q\) is to assign a **canonical integer meaning** to committed mod-\(q\) values (made unambiguous by range constraints like “bitness”), so those values can be interpreted inside \(F_{q'}\). + +Warm-up (scalar version). For integers \(\tilde A,\tilde B\in Z\), the statement “\(A=B\) in \(Z_q\)” is exactly: + +\[ +\tilde A \equiv \tilde B \pmod q +\quad\Longleftrightarrow\quad +\exists s\in Z:\ \tilde A - \tilde B = q\cdot s. +\] + +The \(q\cdot s\) term is the “modulus switching slack”. The ring/polynomial version in F.1 is the same idea applied coefficient-wise (and with an additional cyclotomic slack \((X^d+1)\cdot r\) for the ring quotient). + +#### F.0.2 How this plugs into Hachi’s six core constraints (B.4.1) in the Jolt Stage-8 setting + +Jolt Stage 8 asks the PCS to prove openings of the form “\(P(\mathbf r)=v\)” where both the point \(\mathbf r\) and value \(v\) live in the **Jolt field** \(F\) (in your target case, \(F=F_{q'}\)). See: + +- `../jolt/jolt-core/src/poly/opening_proof.rs`: `pub type Opening = (OpeningPoint<..., F>, F);` +- `../jolt/jolt-core/src/zkvm/prover.rs`: Stage 8 calls `PCS::prove(..., &opening_point.r, ...)` where `opening_point.r` is a vector of field challenges. + +Hachi’s Step B constraints (the six items in B.4.1) are written as equalities over the **ring** \(R_q\). To use a Hachi-style PCS under Jolt’s interface, we keep the *same witness objects* \((\hat w,\hat t,\hat z,\dots)\) and the *same logical constraints*, but we do **not** expect the verifier to check them natively in \(R_q\) (and we cannot even form the “\(a,b\in R_q\) from the opening point” parts when \(\mathbf r\in F_{q'}^\ell\)). + +Instead, the PCS opening proof checks these constraints **after ring switching**, i.e. after applying an evaluation map + +\[ +\mathrm{ev}_\alpha: R_q \to F_{q'} +\] + +at a random \(\alpha\in F_{q'}\) (and including the modulus-switch slack \(q\cdot s\) to make “mod \(q\)” equalities meaningful in \(F_{q'}\)). + +Concretely: + +- Constraints **(1), (2), and the purely ring-linear parts** of (4),(5) are still “ring equations”, but they are verified in \(F_{q'}\) by checking \(\mathrm{ev}_\alpha(\text{LHS}-\text{RHS})=0\) (with quotient witnesses for \((X^d+1)\) and \(q\) as in F.1). +- Constraints **(3)–(5)** are the ones that *mention the evaluation point* (via the monomial vectors \(a,b\)) and therefore must be interpreted in the **field**: + - compute \(a,b\) directly from the Stage-8 opening point \(\mathbf r\in F_{q'}^\ell\), + - treat unknown ring elements like \(w_i\in R_q\) only through their field images \(w_i(\alpha):=\mathrm{ev}_\alpha(w_i)\in F_{q'}\), + - and enforce the same algebraic equalities (Eq. (17)–(19)) in \(F_{q'}\). + +So the “generalization to account for \(q'\)” is not “change a few constraints while keeping the rest in \(q\)”; it is: + +> keep the constraint *structure*, but **move their verification domain** from \(R_q\) to \(F_{q'}\) via ring switching (plus a \(q\cdot s\) slack), because Stage 8’s statement lives in \(F_{q'}\). + +#### F.0.3 Why you cannot keep the *entire* opening proof “purely in \(F_q/R_q\)” under Jolt’s interface + +Under Jolt, the opening point \(\mathbf r\) is sampled as transcript challenges in the **same field as the sumchecks**, i.e. \(\mathbf r\in F_{q'}^m\) (see `OpeningPoint<..., F>` in `../jolt/jolt-core/src/poly/opening_proof.rs`). + +Any PCS that plugs into Stage 8 must therefore convince the verifier of a statement that is *parameterized by* these \(F_{q'}\) elements: + +\[ +P(\mathbf r)=v\quad\text{for }\mathbf r,v\in F_{q'}. +\] + +If a verifier refuses to do any \(F_{q'}\) arithmetic, it cannot even *evaluate the public weights* (equality polynomials) at \(\mathbf r\), nor can it check the final identity that defines “evaluation at \(\mathbf r\)”. The only way around this would be to represent every \(F_{q'}\) element (including \(\mathbf r\), \(v\), and all derived weights) as **non-native data** over \(F_q\) (e.g. base-\(b\) limbs) and then prove that these limbs satisfy the mod-\(q'\) arithmetic via additional quotient/carry constraints. + +That “all-\(F_q\) verification” route is a different, much heavier design: it is essentially a SNARK for \(F_{q'}\) arithmetic implemented over \(F_q\), and it introduces digit/carry constraints for every multiplication/addition in the opening protocol. This is *not* what Hachi’s ring switching is optimizing for. + +So the realistic design space is: + +- commitments / witness representation over \(F_q\) / \(R_q\) for performance, +- but verifier-side checking (sumcheck challenges, equality weights, and final identities) in \(F_{q'}\). + +#### F.0.4 Recursion / “next folding steps” in the cross-prime setting + +In Hachi, the “recursion boundary” is: sumcheck reduces a large set of constraints to **one new opening claim** for a committed multilinear object (the \(\tilde w\) table). The next layer repeats the same pattern on a smaller witness. + +In the cross-prime Stage-8 adaptation: + +- the **commitment domain stays** \(F_q/R_q\) at every layer (you keep committing to digitized tables over the small modulus), +- the **opening points stay** in \(F_{q'}\) at every layer, because they are derived from sumcheck challenges / transcript in \(F_{q'}\), +- and the verifier continues to check constraints in \(F_{q'}\) via ring switching (evaluation at \(\alpha\in F_{q'}\) plus the \((X^d+1)\cdot r\) and \(q\cdot s\) quotient witnesses). + +So “continuing in \(F_{q'}\)” does not mean “switch commitments to \(q'\)”; it means the *interactive checking algebra* (sumcheck, batching, evaluation weights) remains in the Jolt field where the statement is expressed. + +#### F.0.5 One full “folding/recursion step” of the adapted PCS (clean, end-to-end) + +This subsection gives a clean one-layer view: how we go from **one opening claim at point \(\mathbf r\in F_{q'}^m\)** to **a smaller opening claim at a new point \(\mathbf r^\*\in F_{q'}^{m'}\)**, and why that output is amenable to repeating the same procedure. + +We describe this as a PCS protocol (what Stage 8 wants), independent of Jolt’s earlier stages. + +##### Inputs (statement) and commitment domain + +- **Commitment domain** (small): commitments are to coefficient tables over \(F_q\) (or ring elements over \(R_q\)) using Hachi’s §4.1 Ajtai-style structure. +- **Opening statement domain** (large): the opening point \(\mathbf r\) and claimed value \(v\) are in \(F_{q'}\) (Jolt’s `JoltField`). + +So the opening statement is: + +\[ +\text{Given commitment } C \text{ to a multilinear table } P,\ \text{and public } \mathbf r\in F_{q'}^m,\ v\in F_{q'},\ \text{prove } P(\mathbf r)=v. +\] + +Here \(P(\mathbf r)\) is defined by interpreting the committed coefficients as integers (via \(\mathrm{lift}_q\)) and reducing them into \(F_{q'}\); for bit/one-hot tables, this interpretation is canonical. + +##### Step 1 (Hachi Step B “split-and-fold”): reduce evaluation to a structured linear relation witness + +This step is unchanged in *spirit* from Hachi: we rewrite “\(P(\mathbf r)=v\)” into a small set of algebraic constraints involving intermediate witnesses (partial evaluations, decomposed digits, and folded combinations). + +The key point in the two-field setting is typing: + +- any time the original §4.2 constraints multiply “point-derived scalars” into witness quantities, those scalars are in \(F_{q'}\) and we interpret the witness quantities through their images in \(F_{q'}\) (via ring switching in Step 2), rather than trying to multiply \(F_{q'}\) scalars by \(R_q\) elements directly. + +Operationally, the prover still constructs the same witness objects (\(\hat w,\hat t,\hat z\), folded \(z\), etc.) and commits to the same ring elements (\(u,v\), etc.). The verifier will not check the ring equations directly; it will check their *ring-switched* images in Step 2. + +##### Step 2 (Hachi Step C generalized): ring switching + modulus switching to get field-native constraints + +The verifier samples a random \(\alpha\leftarrow F_{q'}\) and defines the evaluation map: + +\[ +\mathrm{ev}_\alpha: R_q \to F_{q'} +\] + +as “lift coefficients to integers (per chosen \(\mathrm{lift}_q\)), then evaluate the polynomial at \(X=\alpha\) in \(F_{q'}\)”. + +Now take each ring equation from the Step B constraint set (the six constraints in B.4.1) and convert it into a field equation at \(\alpha\) by: + +1. lifting it from \(R_q\) to \(Z[X]\) with a cyclotomic quotient witness \(r\) (as in the paper), and +2. adding a modulus quotient witness \(s\) so that “mod \(q\) equality” becomes an *integer* equality plus a \(q\cdot s\) slack (Section F.1). + +After evaluating at \(X=\alpha\), every check becomes an identity in \(F_{q'}\) involving: + +- public scalars derived from \(\mathbf r\) (equality weights / monomial vectors), +- the prover’s unknowns only through \(\mathrm{ev}_\alpha(\cdot)\), +- and the quotient witnesses \(r(\alpha), s(\alpha)\). + +##### Step 3 (Hachi §4.3 sumcheck): batch all constraints and reduce to one oracle evaluation + +Exactly as in Hachi, we encode the relevant coefficient tables into one committed multilinear object \(\tilde w\) (“witness table embedding”), except that in the cross-prime setting \(\tilde w\) must encode the extra modulus quotient witness \(s\) as well as \(r\). + +Then we define batched constraint polynomials (the analogs of \(H_\alpha\) and \(H_0\)), and run sumcheck over **\(F_{q'}\)** (not over \(R_q\)). The sumcheck transcript produces a final random point: + +\[ +\mathbf r^\* \in F_{q'}^{m'} +\] + +and reduces verification to a **single evaluation claim** of the committed witness-table multilinear: + +\[ +\tilde w(\mathbf r^\*) = v^\*\in F_{q'}. +\] + +##### Step 4 (recursion boundary): output a smaller opening claim of the same type + +This is the crucial “amenable to further folding” point: + +- The new claim “\(\tilde w(\mathbf r^\*)=v^\*\)” has the **same shape** as the original opening claim, just on a different committed multilinear and a different point. +- The commitment to \(\tilde w\) is again over the **small modulus domain** (it is built from digitized tables over \(F_q/R_q\)). +- The opening point \(\mathbf r^\*\) is again in the **large field** \(F_{q'}\), because it is derived from the sumcheck challenges (and thus from the Jolt transcript). + +Therefore, we can repeat the same 4-step pipeline on \(\tilde w\): + +\[ +P(\mathbf r)=v +\ \leadsto\ +\tilde w(\mathbf r^\*)=v^\* +\ \leadsto\ +\tilde w^{(2)}(\mathbf r^{\*\*})=v^{\*\*} +\ \leadsto\ \cdots +\] + +and eventually hand off to a base PCS / small-instance prover once the witness table is small enough (exactly the same “stop recursion when small” design choice as Hachi’s §5 composition discussion). + +#### F.0.6 Option B (for comparison): decompose the \(F_{q'}\) opening point/value and prove everything over \(F_q/R_q\) + +This subsection works out the alternative you asked for: **avoid doing any verifier arithmetic in \(F_{q'}\)** by representing all \(F_{q'}\) elements (the opening point \(\mathbf r\) and claimed value \(v\), plus all derived weights) as *digits/limbs* over the small field. + +This is not “ring switching”; it is **non-native (foreign-field) arithmetic**: we simulate mod-\(q'\) arithmetic inside a proof system whose native arithmetic is mod-\(q\). + +##### Why this is a different interface than Jolt Stage 8 + +In Jolt, Stage 8’s opening statement is parameterized by an actual \(\mathbf r\in F_{q'}^m\) sampled from the transcript (see `OpeningPoint<..., F>`). + +Option B instead treats \(\mathbf r\) and \(v\) as **public limb vectors over \(F_q\)**. To plug this into Jolt unchanged, you would need to either: + +- change Jolt’s transcript/challenges to live in the small field (not compatible with the stated “char \(> u64\cdot u64\)” requirement), or +- keep \(\mathbf r\in F_{q'}\) as usual but additionally provide a limb decomposition of \(\mathbf r\) and then **prove inside the PCS** that those limbs reconstruct the same \(\mathbf r\) (which reintroduces \(F_{q'}\) operations unless you also non-natively model the transcript). + +So Option B is best viewed as a “theoretical comparison point”, not a drop-in Stage-8 replacement. + +##### Representation: limbs for \(F_{q'}\) elements + +Fix a radix \(B=2^t\) (e.g. \(B=2^{16}\) so limbs fit comfortably under a 32-bit-ish \(q\)), and let + +\[ +L := \left\lceil \log_B(q') \right\rceil +\] + +be the limb count (for a 128-bit prime and \(B=2^{16}\), \(L\approx 8\)). + +We represent a field element \(a\in F_{q'}\) by an integer representative \(\tilde a\in[0,q')\) (canonical lift mod \(q'\)) and a limb decomposition: + +\[ +\tilde a = \sum_{j=0}^{L-1} a_j B^j,\quad a_j\in\{0,1,\dots,B-1\}. +\] + +All limbs \(a_j\) are then encoded as small elements of \(F_q\) (and range-checked to lie in \([0,B)\)). + +##### Non-native arithmetic constraints (core gadget) + +To simulate \(F_{q'}\) arithmetic, every operation becomes an integer identity plus a \(q'\)-multiple slack: + +- **Addition**: to enforce \(c \equiv a+b \pmod{q'}\), prove + + \[ + \tilde a + \tilde b - \tilde c = q'\cdot k + \] + + for some integer \(k\) (with \(k\in\{0,1\}\) if \(\tilde a,\tilde b,\tilde c\in[0,q')\)). + +- **Multiplication**: to enforce \(c \equiv a\cdot b \pmod{q'}\), prove + + \[ + \tilde a\cdot \tilde b - \tilde c = q'\cdot k + \] + + for some integer \(k\) with \(0 \le k < q'\). + +In practice, you do not materialize \(\tilde a\) as a single huge integer; you enforce these identities **in base \(B\) with carries**: + +- introduce carry variables \(u_t\) so that the schoolbook convolution of limbs matches the target limbs, +- and range-check carries so that “equality mod \(q\)” implies “equality over integers” (no wrap-around) at the limb level. + +Cost model: one non-native multiplication of two \(L\)-limb numbers costs + +\[ +\Theta(L^2)\ \text{small-field multiplications} \quad + \quad \Theta(L)\ \text{carry/range constraints}. +\] + +##### How to prove an opening claim \(P(\mathbf r)=v\) using Option B + +Let \(P\) be a multilinear polynomial committed over \(F_q\) (or \(R_q\)) whose coefficients are small integers (bits/one-hot is the cleanest case). + +Public input for Option B: + +- limb decompositions of each coordinate \(r_i\in F_{q'}\) (so \(\tilde r_i=\sum_j r_{i,j}B^j\)), +- limb decomposition of \(v\in F_{q'}\), +- and the modulus \(q'\) itself (public constant). + +Prover witness: + +- limb decompositions for all intermediate \(F_{q'}\) values needed to evaluate \(P\) at \(\mathbf r\), +- plus all carry/quotient witnesses for the modular reduction constraints above. + +Then enforce, inside the proof system over \(F_q/R_q\): + +1. **Range**: every limb is in \([0,B)\) (bitness is the special case \(B=2\)). +2. **Recomposition**: the limb vectors correspond to integers in \([0,q')\) (often implemented by providing a quotient \(t\) such that \(\sum_j a_jB^j = \tilde a + q' t\) and constraining \(\tilde a CE(b,L)^(K+k) -- + | + Π_RLC + | + CE(B,L) + | + Π_DEC + | + next accumulator: CE(b,L)^k +``` + +### What the prover sends in each sub-protocol + +Below is the communication “payload” you should have in mind (Fiat–Shamir makes verifier randomness transcript-derived, but sizes are the same). + +#### Π_CCS (paper §7.3) + +Prover payload: + +- **Sum-check transcript over \(K\)** for \(\log m\) rounds (one univariate polynomial per round; degree is small in typical CCS, e.g. R1CS-like \(u=2\)). +- **Oracle answers after sum-check**: for every claim \(i\in[K+k]\) and every matrix \(j\in[t]\), send + \[ + y'_{i,j} = \bar M_j \tilde z_i(r') \in R_K. + \] + +Rule of thumb: the \(y'_{i,j}\) payload dominates. + +#### Π_RLC (paper §7.4) + +This is just random linear combination using \(\rho_i\in C\): + +\[ +c := \sum_i \rho_i c_i,\quad y_j := \sum_i \rho_i y_{i,j},\quad z := \sum_i \rho_i z_i. +\] + +Prover payload: essentially **nothing new** (the reduction is “algebraic bookkeeping”; in FS the \(\rho_i\) come from the transcript). + +#### Π_DEC (paper §7.5) + +Prover payload: + +- Decompose \(z\) into digits \((z_1,\dots,z_k)=\mathrm{split}_b(z)\). +- Send **\(k\) new commitments** \(c_i = L(z_i)\). +- Send **\(k\cdot t\)** new lifted evaluations \(y_{i,j} = \bar M_j \tilde z_i(r)\in R_K\). + +### Proof size accounting (symbolic, per fold step) + +Let: + +- commitment output be \(c\in R_F^\kappa\) (Ajtai: \(\kappa\) ring elements over \(R_F\)), +- each lifted evaluation be one ring element in \(R_K\) (degree \(d\) over \(K\)), +- \(\ell := \log m\) be the sum-check arity. + +Then per fold step, prover sends approximately: + +- **Π\_CCS**: + - sum-check: \(O(\ell)\) field elements in \(K\), + - evaluations: \((K+k)\cdot t\) ring elements in \(R_K\). +- **Π\_DEC**: + - commitments: \(k\cdot \kappa\) ring elements in \(R_F\), + - evaluations: \(k\cdot t\) ring elements in \(R_K\). + +Total (dominant terms): + +\[ +\text{#}(R_K\text{-elements}) \approx (K+k)t + kt = (K+2k)t, +\] +\[ +\text{#}(R_F\text{-elements}) \approx k\kappa, +\] +plus a small \(\tilde O(\log m)\) number of \(K\)-field elements from sum-check. + +To convert to bytes for appendix parameter sets where \(K=\mathbb{F}_{q^2}\) and \(q\) is 61–64 bits: + +- 1 \(F\)-element ≈ 8 bytes +- 1 \(K\)-element ≈ 16 bytes +- 1 \(R_F\)-element ≈ \(d\cdot 8\) bytes +- 1 \(R_K\)-element ≈ \(d\cdot 16\) bytes +- 1 commitment \(c\in R_F^\kappa\) ≈ \(\kappa\cdot d\cdot 8\) bytes + +So a back-of-the-envelope fold proof byte size is: + +\[ +\approx (K+2k)t\cdot(d\cdot 16)\;+\;k\kappa\cdot(d\cdot 8)\;+\;\tilde O(\log m)\cdot 16. +\] + +## Exact counts + +- Total concrete field/cyclotomic tuples in SuperNeo appendix: **3** +- Unique base fields in those tuples: **3** +- Unique cyclotomic polynomials in those tuples: **2** + - `X^64 + 1` + - `X^54 + X^27 + 1` +- Strict Solinas-only tuples (`q = 2^x - 2^y + 1`): **2 / 3** + - Goldilocks (`2^64 - 2^32 + 1`) + - Mersenne61 (`2^61 - 1 = 2^61 - 2^1 + 1`) +- Non-Solinas tuple: **1 / 3** + - Almost-Goldilocks (`2^64 - 2^32 - 31`) + +## What SuperNeo suggests (for their purpose) + +- Keep sum-check and norm-check over a small field (or small extension), not over ring arithmetic. +- Do not restrict to only power-of-two cyclotomics (`X^d + 1`); using broader cyclotomics can unlock better field compatibility. +- For Goldilocks/M61, the paper uses a trinomial cyclotomic (`X^54 + X^27 + 1`) instead of `X^d + 1`. +- They explicitly discuss why power-of-two cyclotomics are problematic for some small fields (full splitting / security issues in prior designs). + +## Fit into Hachi design (easy vs harder) + +- Easy fit now: + - Almost-Goldilocks-like setup with `d = 64` and `X^64 + 1` is structurally close to Hachi's current ring form. +- Requires refactor: + - Goldilocks and M61 with `X^54 + X^27 + 1` do not match the current hardcoded negacyclic `X^D + 1` shape. + +## Minimal integration plan into current architecture + +1. Add a cyclotomic profile abstraction (`CyclotomicProfile`) that defines modulus polynomial behavior. +2. Keep current profile as `Negacyclic` (`X^D + 1`) for existing NTT/CRT path. +3. Add a `Trinomial54` profile (`X^54 + X^27 + 1`) for SuperNeo-style field/ring experiments. +4. Keep backend split: + - current CRT+NTT backend for `Negacyclic` only, + - coefficient-domain backend first for `Trinomial54` (no forced NTT dependency). +5. Add explicit domain aliases for profile-bound rings so APIs remain clear (`CoeffDomain`). + +## Practical implication if we insist on Solinas-only fields + +- Within SuperNeo's concrete options, we still have **2 viable Solinas choices**: + - Goldilocks + `X^54 + X^27 + 1` + - M61 + `X^54 + X^27 + 1` +- If we also insist on power-of-two cyclotomics only (`X^d + 1`), SuperNeo's concrete Solinas count drops to **0**. diff --git a/docs/TRANSCRIPT_COMMITMENT_COMPAT_SPEC.md b/docs/TRANSCRIPT_COMMITMENT_COMPAT_SPEC.md new file mode 100644 index 00000000..03fceb7a --- /dev/null +++ b/docs/TRANSCRIPT_COMMITMENT_COMPAT_SPEC.md @@ -0,0 +1,106 @@ +# Transcript and Commitment Compatibility Spec (Hachi Core) + +This document specifies Hachi's protocol-layer transcript and commitment interfaces. + +## Scope + +- Applies to Hachi core (`src/protocol/*`). +- Uses **Hachi-native** transcript labels and ordering. +- Does **not** wire Hachi into Jolt in this phase. +- Any future cross-system interop (for example Jolt-facing adaptation) must be handled by an adapter layer outside core label definitions. + +## Transcript Contract + +Hachi protocol transcripts implement: + +- `new(domain_label)` +- `append_bytes(label, bytes)` +- `append_field(label, x)` +- `append_serde(label, s)` +- `challenge_scalar(label)` +- `reset(domain_label)` + +Current core implementations: + +- `Blake2bTranscript` (Blake2b-512) +- `KeccakTranscript` (Keccak-256, matching Jolt's `sha3` crate usage) + +### Byte Framing + +All absorbed bytes use deterministic framing: + +- `label || len_le64 || bytes` + +This framing is applied uniformly for raw bytes, fields, and serializable protocol objects. + +### Field Encoding + +- Field elements are encoded through canonical representatives (little-endian `u128` bytes). +- Challenge derivation maps transcript digest bytes into field elements via canonical reduction. + +### Label Namespace + +All labels are defined in `src/protocol/transcript/labels.rs` and are Hachi-native. + +Reserved core labels include: + +- Domain label: `hachi/protocol` +- Commitment phase (§4.1): commitment +- Reduction phase (§4.2): evaluation-claims + linear-relation challenge +- Ring-switch phase (§4.3): ring-switch-message + ring-switch challenge +- Sumcheck phase (§4.3): sumcheck-round + sumcheck-round challenge +- Recursion stop phase (§4.5): stop-condition + stop-condition challenge + +Forbidden in Hachi core transcript constants: + +- Dory label literals (for example `vmv_c`, `beta`, `alpha`, `gamma`, `final_e1`, `final_e2`, `d`) + +## Commitment Contract + +Hachi protocol commitment interfaces include: + +- `CommitmentScheme` +- `StreamingCommitmentScheme` +- `AppendToTranscript` + +The commitment layer defines: + +- setup split (`setup_prover`, `setup_verifier`) +- commitment/opening APIs (`commit`, `prove`, `verify`) +- homomorphic combination APIs (`combine_commitments`, `combine_hints`) +- optional streaming/chunked path for large inputs +- label-directed transcript absorption (`AppendToTranscript` call sites choose event labels) + +## Determinism Requirements + +- Prover and verifier must absorb the same labeled byte sequence in the same order. +- Transcript challenges must be reproducible for identical input schedules. +- Commitment/proof objects absorbed via `append_serde` must use deterministic `HachiSerialize` encoding. + +## Test Requirements + +Tests should enforce: + +- Transcript replay determinism (same schedule => same challenges). +- Label/order sensitivity (different labels/order => diverging challenges). +- Framing stability. +- No Dory-label leakage in Hachi label constants/schedules. +- Commitment/hint combination algebraic sanity. + +## Deferred Integration Note + +Integration into Jolt is a separate, deferred phase tracked in `HACHI_PROGRESS.md`. +When started, an adapter should translate between external transcript conventions and Hachi core interfaces without changing Hachi-native core labels. + +## Deferred Adapter Contract (Design Only) + +`JoltToHachiTranscript` is deferred, but its expected behavior is fixed now: + +- Owns a mutable reference to a Jolt transcript object. +- Implements Hachi `Transcript` by forwarding absorption/challenge calls. +- Performs label translation at the boundary (Jolt-side naming to Hachi-side API events). +- Never mutates or extends Hachi core label constants. +- Maintains deterministic call ordering: prover and verifier adapter paths must replay identical absorb/challenge sequences. +- Supports domain initialization and explicit reset semantics. + +This adapter lives outside Hachi core protocol modules and is not part of this phase's implementation. diff --git a/examples/codegen_probe_special.rs b/examples/codegen_probe_special.rs new file mode 100644 index 00000000..bc6d5b50 --- /dev/null +++ b/examples/codegen_probe_special.rs @@ -0,0 +1,143 @@ +#![allow(missing_docs)] + +//! Codegen probe for packed/scalar Fp64 multiply kernels. +//! +//! Build with: +//! `cargo rustc --example codegen_probe_special --release -- --emit=asm` + +use hachi_pcs::algebra::fields::pseudo_mersenne::{POW2_OFFSET_MODULUS_40, POW2_OFFSET_MODULUS_64}; +use hachi_pcs::algebra::{Fp64, Fp64Packing, PackedValue}; +use hachi_pcs::CanonicalField; + +const MASK40: u64 = (1u64 << 40) - 1; +const P40: u64 = POW2_OFFSET_MODULUS_40; +const C40: u64 = (1u64 << 40) - P40; // 195 +const P64: u64 = POW2_OFFSET_MODULUS_64; +const C64: u64 = 0u64.wrapping_sub(P64); // 59 + +#[inline(always)] +fn mul_c40_split(x: u64) -> u64 { + let c = C40 as u32; + let x_lo = x as u32; + let x_hi = (x >> 32) as u32; + (c as u64 * x_lo as u64).wrapping_add((c as u64 * x_hi as u64) << 32) +} + +#[inline(always)] +fn mul_c40_shiftadd(x: u64) -> u64 { + // 195x = (128 + 64 + 2 + 1) * x + (x << 7) + .wrapping_add(x << 6) + .wrapping_add(x << 1) + .wrapping_add(x) +} + +#[inline(always)] +fn reduce40_with_mulc(lo: u64, hi: u64, mulc: fn(u64) -> u64) -> u64 { + let high = (lo >> 40) | (hi << 24); + let f1 = (lo & MASK40).wrapping_add(mulc(high)); + let f2 = (f1 & MASK40).wrapping_add(mulc(f1 >> 40)); + let reduced = f2.wrapping_sub(P40); + let borrow = reduced >> 63; + reduced.wrapping_add(borrow.wrapping_neg() & P40) +} + +#[inline(always)] +fn reduce64(lo: u64, hi: u64) -> u64 { + let f1 = (lo as u128) + (C64 as u128) * (hi as u128); + let f2 = (f1 as u64 as u128) + (C64 as u128) * ((f1 >> 64) as u64 as u128); + let reduced = f2.wrapping_sub(P64 as u128); + let borrow = reduced >> 127; + reduced.wrapping_add(borrow.wrapping_neg() & (P64 as u128)) as u64 +} + +#[inline(never)] +#[no_mangle] +pub extern "C" fn probe_reduce40_split(lo: u64, hi: u64) -> u64 { + reduce40_with_mulc(lo, hi, mul_c40_split) +} + +#[inline(never)] +#[no_mangle] +pub extern "C" fn probe_reduce40_shiftadd(lo: u64, hi: u64) -> u64 { + reduce40_with_mulc(lo, hi, mul_c40_shiftadd) +} + +#[inline(never)] +#[no_mangle] +pub extern "C" fn probe_reduce64(lo: u64, hi: u64) -> u64 { + reduce64(lo, hi) +} + +#[inline(never)] +#[no_mangle] +pub extern "C" fn probe_packed_fp64_40_mul(a0: u64, a1: u64, b0: u64, b1: u64) -> u64 { + type F = Fp64<{ POW2_OFFSET_MODULUS_40 }>; + type PF = Fp64Packing<{ POW2_OFFSET_MODULUS_40 }>; + + let a = PF::from_fn(|i| { + if i == 0 { + F::from_canonical_u64(a0) + } else { + F::from_canonical_u64(a1) + } + }); + let b = PF::from_fn(|i| { + if i == 0 { + F::from_canonical_u64(b0) + } else { + F::from_canonical_u64(b1) + } + }); + let c = a * b; + (c.extract(0).to_canonical_u128() as u64) ^ (c.extract(1).to_canonical_u128() as u64) +} + +#[inline(never)] +#[no_mangle] +pub extern "C" fn probe_packed_fp64_64_mul(a0: u64, a1: u64, b0: u64, b1: u64) -> u64 { + type F = Fp64<{ POW2_OFFSET_MODULUS_64 }>; + type PF = Fp64Packing<{ POW2_OFFSET_MODULUS_64 }>; + + let a = PF::from_fn(|i| { + if i == 0 { + F::from_canonical_u64(a0) + } else { + F::from_canonical_u64(a1) + } + }); + let b = PF::from_fn(|i| { + if i == 0 { + F::from_canonical_u64(b0) + } else { + F::from_canonical_u64(b1) + } + }); + let c = a * b; + (c.extract(0).to_canonical_u128() as u64) ^ (c.extract(1).to_canonical_u128() as u64) +} + +#[inline(never)] +#[no_mangle] +pub extern "C" fn probe_scalar_fp64_40_mul(a: u64, b: u64) -> u64 { + type F = Fp64<{ POW2_OFFSET_MODULUS_40 }>; + (F::from_canonical_u64(a) * F::from_canonical_u64(b)).to_canonical_u128() as u64 +} + +#[inline(never)] +#[no_mangle] +pub extern "C" fn probe_scalar_fp64_64_mul(a: u64, b: u64) -> u64 { + type F = Fp64<{ POW2_OFFSET_MODULUS_64 }>; + (F::from_canonical_u64(a) * F::from_canonical_u64(b)).to_canonical_u128() as u64 +} + +fn main() { + let x = probe_packed_fp64_40_mul(1, 2, 3, 4) + ^ probe_packed_fp64_64_mul(5, 6, 7, 8) + ^ probe_scalar_fp64_40_mul(9, 10) + ^ probe_scalar_fp64_64_mul(11, 12) + ^ probe_reduce40_split(13, 14) + ^ probe_reduce40_shiftadd(15, 16) + ^ probe_reduce64(17, 18); + std::hint::black_box(x); +} diff --git a/paper/fri-binius.pdf b/paper/fri-binius.pdf new file mode 100644 index 00000000..9f19074d Binary files /dev/null and b/paper/fri-binius.pdf differ diff --git a/paper/greyhound.pdf b/paper/greyhound.pdf new file mode 100644 index 00000000..75d13b41 Binary files /dev/null and b/paper/greyhound.pdf differ diff --git a/paper/labrador.pdf b/paper/labrador.pdf new file mode 100644 index 00000000..eb93a89a Binary files /dev/null and b/paper/labrador.pdf differ diff --git a/paper/standards/NIST.FIPS.203.pdf b/paper/standards/NIST.FIPS.203.pdf new file mode 100644 index 00000000..a97b548e Binary files /dev/null and b/paper/standards/NIST.FIPS.203.pdf differ diff --git a/paper/standards/NIST.FIPS.204.pdf b/paper/standards/NIST.FIPS.204.pdf new file mode 100644 index 00000000..33368b88 Binary files /dev/null and b/paper/standards/NIST.FIPS.204.pdf differ diff --git a/paper/superneo.pdf b/paper/superneo.pdf new file mode 100644 index 00000000..39cf953e Binary files /dev/null and b/paper/superneo.pdf differ diff --git a/rust-toolchain.toml b/rust-toolchain.toml new file mode 100644 index 00000000..c5b9f7f3 --- /dev/null +++ b/rust-toolchain.toml @@ -0,0 +1,4 @@ +[toolchain] +channel = "1.88" +profile = "minimal" +components = ["cargo", "rustc", "clippy", "rustfmt"] diff --git a/rustfmt.toml b/rustfmt.toml new file mode 100644 index 00000000..541c50b7 --- /dev/null +++ b/rustfmt.toml @@ -0,0 +1,5 @@ +edition = "2021" +tab_spaces = 4 +newline_style = "Unix" +use_try_shorthand = true +use_field_init_shorthand = true diff --git a/src/algebra/backend/mod.rs b/src/algebra/backend/mod.rs new file mode 100644 index 00000000..3b885ce1 --- /dev/null +++ b/src/algebra/backend/mod.rs @@ -0,0 +1,7 @@ +//! Backend contracts and concrete backend implementations. + +pub mod scalar; +pub mod traits; + +pub use scalar::ScalarBackend; +pub use traits::{CrtReconstruct, NttPrimeOps, NttTransform, RingBackend}; diff --git a/src/algebra/backend/scalar.rs b/src/algebra/backend/scalar.rs new file mode 100644 index 00000000..1ef4c7da --- /dev/null +++ b/src/algebra/backend/scalar.rs @@ -0,0 +1,94 @@ +//! Default scalar backend: delegates to NTT kernels and uses Garner's +//! algorithm for CRT reconstruction. + +use super::traits::{CrtReconstruct, NttPrimeOps, NttTransform}; +use crate::algebra::ntt::butterfly::{forward_ntt, inverse_ntt, NttTwiddles}; +use crate::algebra::ntt::crt::GarnerData; +use crate::algebra::ntt::prime::{MontCoeff, NttPrime, PrimeWidth}; +use crate::algebra::ring::CrtNttConvertibleField; + +/// Default scalar backend implementation. +#[derive(Debug, Clone, Copy, Default)] +pub struct ScalarBackend; + +impl NttPrimeOps for ScalarBackend { + #[inline] + fn from_canonical(prime: NttPrime, value: W) -> MontCoeff { + prime.from_canonical(value) + } + + #[inline] + fn to_canonical(prime: NttPrime, value: MontCoeff) -> W { + prime.to_canonical(value) + } + + #[inline] + fn reduce_range(prime: NttPrime, value: MontCoeff) -> MontCoeff { + prime.reduce_range(value) + } + + #[inline] + fn pointwise_mul( + prime: NttPrime, + out: &mut [MontCoeff; D], + lhs: &[MontCoeff; D], + rhs: &[MontCoeff; D], + ) { + prime.pointwise_mul(out, lhs, rhs); + } +} + +impl NttTransform for ScalarBackend { + #[inline] + fn forward_ntt(limb: &mut [MontCoeff; D], prime: NttPrime, twiddles: &NttTwiddles) { + forward_ntt(limb, prime, twiddles); + } + + #[inline] + fn inverse_ntt(limb: &mut [MontCoeff; D], prime: NttPrime, twiddles: &NttTwiddles) { + inverse_ntt(limb, prime, twiddles); + } +} + +impl CrtReconstruct for ScalarBackend { + fn reconstruct( + primes: &[NttPrime; K], + canonical: &[[W; D]; K], + garner: &GarnerData, + ) -> [F; D] { + let mut coeffs = [F::zero(); D]; + for (d, coeff) in coeffs.iter_mut().enumerate() { + // Garner mixed-radix decomposition (all arithmetic in i64, mod p_i). + let mut v = [0i64; K]; + v[0] = canonical[0][d].to_i64(); + for i in 1..K { + let pi = primes[i].p.to_i64(); + let mut temp = canonical[i][d].to_i64(); + #[allow(clippy::needless_range_loop)] + for j in 0..i { + temp -= v[j]; + temp = ((temp % pi) + pi) % pi; + temp = (temp * garner.gamma[i][j].to_i64()) % pi; + } + // Center the mixed-radix digit to keep the final reconstruction + // in a small signed range when inputs are centered. + if temp > pi / 2 { + temp -= pi; + } + v[i] = temp; + } + + // Horner accumulation in the target field F. + let mut result = F::from_i64(v[0]); + let mut partial_prod = F::from_i64(primes[0].p.to_i64()); + for i in 1..K { + result = result + F::from_i64(v[i]) * partial_prod; + if i + 1 < K { + partial_prod = partial_prod * F::from_i64(primes[i].p.to_i64()); + } + } + *coeff = result; + } + coeffs + } +} diff --git a/src/algebra/backend/traits.rs b/src/algebra/backend/traits.rs new file mode 100644 index 00000000..118a5f1f --- /dev/null +++ b/src/algebra/backend/traits.rs @@ -0,0 +1,59 @@ +//! Backend traits for CRT+NTT execution semantics. +//! +//! All traits are generic over `W: PrimeWidth` to support both +//! `i16` (primes < 2^14) and `i32` (primes < 2^30) NTT backends. + +use crate::algebra::ntt::butterfly::NttTwiddles; +use crate::algebra::ntt::crt::GarnerData; +use crate::algebra::ntt::prime::{MontCoeff, NttPrime, PrimeWidth}; +use crate::algebra::ring::CrtNttConvertibleField; + +/// Per-prime arithmetic primitives used by CRT+NTT domains. +pub trait NttPrimeOps { + /// Convert canonical coefficient to backend prime representation. + fn from_canonical(prime: NttPrime, value: W) -> MontCoeff; + + /// Convert backend prime representation back to canonical coefficient. + fn to_canonical(prime: NttPrime, value: MontCoeff) -> W; + + /// Range-reduce one coefficient from `(-2p, 2p)` to `(-p, p)`. + fn reduce_range(prime: NttPrime, value: MontCoeff) -> MontCoeff; + + /// Pointwise multiplication in backend prime representation. + fn pointwise_mul( + prime: NttPrime, + out: &mut [MontCoeff; D], + lhs: &[MontCoeff; D], + rhs: &[MontCoeff; D], + ); +} + +/// Forward/inverse transform kernels for one NTT limb. +pub trait NttTransform { + /// Forward transform from coefficient limb to NTT limb. + fn forward_ntt(limb: &mut [MontCoeff; D], prime: NttPrime, twiddles: &NttTwiddles); + + /// Inverse transform from NTT limb to coefficient limb. + fn inverse_ntt(limb: &mut [MontCoeff; D], prime: NttPrime, twiddles: &NttTwiddles); +} + +/// CRT reconstruction from per-prime canonical coefficients via Garner's algorithm. +pub trait CrtReconstruct { + /// Reconstruct coefficient-domain values from canonical CRT residues. + fn reconstruct( + primes: &[NttPrime; K], + canonical_limbs: &[[W; D]; K], + garner: &GarnerData, + ) -> [F; D]; +} + +/// Convenience composition trait for full ring backend capability. +pub trait RingBackend: + NttPrimeOps + NttTransform + CrtReconstruct +{ +} + +impl RingBackend for T where + T: NttPrimeOps + NttTransform + CrtReconstruct +{ +} diff --git a/src/algebra/fields/ext.rs b/src/algebra/fields/ext.rs new file mode 100644 index 00000000..c341b6fa --- /dev/null +++ b/src/algebra/fields/ext.rs @@ -0,0 +1,704 @@ +//! Quadratic and quartic extension fields. + +use crate::algebra::module::VectorModule; +use crate::primitives::serialization::{ + Compress, HachiDeserialize, HachiSerialize, SerializationError, Valid, Validate, +}; +use crate::{FieldCore, FieldSampling, FromSmallInt}; + +/// `Fp2Config` with non-residue = -1. +/// +/// Valid when `p ≡ 3 (mod 4)`, i.e. -1 is a quadratic non-residue. +pub struct NegOneNr; + +impl Fp2Config for NegOneNr { + const IS_NEG_ONE: bool = true; + + fn non_residue() -> F { + -F::one() + } +} + +/// `Fp2Config` with non-residue = 2. +/// +/// Valid when `p ≡ 5 (mod 8)`, i.e. 2 is a quadratic non-residue. +/// All Hachi pseudo-Mersenne primes (`2^k - c` with `c ≡ 3 mod 8`) +/// satisfy this. +pub struct TwoNr; + +impl Fp2Config for TwoNr { + fn non_residue() -> F { + F::from_u64(2) + } +} +use rand_core::RngCore; +use std::io::{Read, Write}; +use std::marker::PhantomData; +use std::ops::{Add, Mul, Neg, Sub}; + +/// Parameters for an `Fp2` quadratic extension over base field `F`. +pub trait Fp2Config { + /// Whether the non-residue is -1. + /// + /// When `true`, multiplication by the non-residue is a free negation and + /// the Karatsuba/squaring routines can avoid a base-field multiply. + const IS_NEG_ONE: bool = false; + + /// Non-residue `NR` such that `u^2 = NR`. + fn non_residue() -> F; +} + +/// Quadratic extension element `c0 + c1 * u` with `u^2 = NR`. +pub struct Fp2> { + /// Constant term. + pub c0: F, + /// Coefficient of `u`. + pub c1: F, + _cfg: PhantomData C>, +} + +impl> Fp2 { + /// Construct `c0 + c1 * u`. + #[inline] + pub fn new(c0: F, c1: F) -> Self { + Self { + c0, + c1, + _cfg: PhantomData, + } + } + + /// Multiply a base-field element by the non-residue. + /// + /// When `IS_NEG_ONE` is true this is just a negation (no multiply). + #[inline(always)] + fn mul_nr(x: F) -> F { + if C::IS_NEG_ONE { + -x + } else { + C::non_residue() * x + } + } + + /// Return the conjugate `c0 - c1 * u`. + #[inline] + pub fn conjugate(self) -> Self { + Self::new(self.c0, -self.c1) + } + + /// Return the norm in the base field: `c0^2 - NR * c1^2`. + #[inline] + pub fn norm(self) -> F { + (self.c0 * self.c0) - Self::mul_nr(self.c1 * self.c1) + } +} + +impl> std::fmt::Debug for Fp2 { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("Fp2") + .field("c0", &self.c0) + .field("c1", &self.c1) + .finish() + } +} + +impl> Clone for Fp2 { + fn clone(&self) -> Self { + *self + } +} + +impl> Copy for Fp2 {} + +impl> Default for Fp2 { + fn default() -> Self { + Self::new(F::zero(), F::zero()) + } +} + +impl> PartialEq for Fp2 { + fn eq(&self, other: &Self) -> bool { + self.c0 == other.c0 && self.c1 == other.c1 + } +} + +impl> Eq for Fp2 {} + +impl> Add for Fp2 { + type Output = Self; + fn add(self, rhs: Self) -> Self::Output { + Self::new(self.c0 + rhs.c0, self.c1 + rhs.c1) + } +} +impl> Sub for Fp2 { + type Output = Self; + fn sub(self, rhs: Self) -> Self::Output { + Self::new(self.c0 - rhs.c0, self.c1 - rhs.c1) + } +} +impl> Neg for Fp2 { + type Output = Self; + fn neg(self) -> Self::Output { + Self::new(-self.c0, -self.c1) + } +} +impl> Mul for Fp2 { + type Output = Self; + fn mul(self, rhs: Self) -> Self::Output { + let v0 = self.c0 * rhs.c0; + let v1 = self.c1 * rhs.c1; + Self::new( + v0 + Self::mul_nr(v1), + (self.c0 + self.c1) * (rhs.c0 + rhs.c1) - v0 - v1, + ) + } +} + +impl<'a, F: FieldCore, C: Fp2Config> Add<&'a Self> for Fp2 { + type Output = Self; + fn add(self, rhs: &'a Self) -> Self::Output { + self + *rhs + } +} +impl<'a, F: FieldCore, C: Fp2Config> Sub<&'a Self> for Fp2 { + type Output = Self; + fn sub(self, rhs: &'a Self) -> Self::Output { + self - *rhs + } +} +impl<'a, F: FieldCore, C: Fp2Config> Mul<&'a Self> for Fp2 { + type Output = Self; + fn mul(self, rhs: &'a Self) -> Self::Output { + self * *rhs + } +} + +impl> Valid for Fp2 { + fn check(&self) -> Result<(), SerializationError> { + self.c0.check()?; + self.c1.check()?; + Ok(()) + } +} + +impl> HachiSerialize for Fp2 { + fn serialize_with_mode( + &self, + mut writer: W, + compress: Compress, + ) -> Result<(), SerializationError> { + self.c0.serialize_with_mode(&mut writer, compress)?; + self.c1.serialize_with_mode(&mut writer, compress)?; + Ok(()) + } + + fn serialized_size(&self, compress: Compress) -> usize { + self.c0.serialized_size(compress) + self.c1.serialized_size(compress) + } +} + +impl> HachiDeserialize for Fp2 { + fn deserialize_with_mode( + mut reader: R, + compress: Compress, + validate: Validate, + ) -> Result { + let c0 = F::deserialize_with_mode(&mut reader, compress, validate)?; + let c1 = F::deserialize_with_mode(&mut reader, compress, validate)?; + let out = Self::new(c0, c1); + if matches!(validate, Validate::Yes) { + out.check()?; + } + Ok(out) + } +} + +impl> FieldCore for Fp2 { + fn zero() -> Self { + Self::new(F::zero(), F::zero()) + } + + fn one() -> Self { + Self::new(F::one(), F::zero()) + } + + fn is_zero(&self) -> bool { + self.c0.is_zero() && self.c1.is_zero() + } + + /// Specialized squaring: 2 base-field multiplications instead of 3. + /// + /// `(c0 + c1·u)^2 = (c0^2 + NR·c1^2) + (2·c0·c1)·u` + fn square(&self) -> Self { + let v0 = self.c0 * self.c0; + let v1 = self.c1 * self.c1; + Self::new(v0 + Self::mul_nr(v1), (self.c0 + self.c0) * self.c1) + } + + fn inv(self) -> Option { + if self.is_zero() { + return None; + } + let inv_n = self.norm().inv()?; + Some(Self::new(self.c0 * inv_n, (-self.c1) * inv_n)) + } +} + +impl> FieldSampling for Fp2 { + fn sample(rng: &mut R) -> Self { + Self::new(F::sample(rng), F::sample(rng)) + } +} + +impl> FromSmallInt for Fp2 { + fn from_u64(val: u64) -> Self { + Self::new(F::from_u64(val), F::zero()) + } + + fn from_i64(val: i64) -> Self { + Self::new(F::from_i64(val), F::zero()) + } +} + +/// Parameters for an `Fp4` quadratic extension over `Fp2`. +pub trait Fp4Config> { + /// Non-residue `NR2` in `Fp2` such that `v^2 = NR2`. + fn non_residue() -> Fp2; +} + +/// `Fp4Config` with non-residue `u ∈ Fp2` (the element `(0, 1)`). +/// +/// This is the standard tower choice: `Fp4 = Fp2[v] / (v^2 - u)`. +pub struct UnitNr; + +impl> Fp4Config for UnitNr { + fn non_residue() -> Fp2 { + Fp2::new(F::zero(), F::one()) + } +} + +/// Quartic extension element `c0 + c1 * v` over `Fp2`, where `v^2 = NR2`. +pub struct Fp4, C4: Fp4Config> { + /// Constant term. + pub c0: Fp2, + /// Coefficient of `v`. + pub c1: Fp2, + _cfg: PhantomData C4>, +} + +impl, C4: Fp4Config> Fp4 { + /// Construct `c0 + c1 * v`. + #[inline] + pub fn new(c0: Fp2, c1: Fp2) -> Self { + Self { + c0, + c1, + _cfg: PhantomData, + } + } + + /// Return the norm in `Fp2`: `c0^2 - NR2 * c1^2`. + #[inline] + pub fn norm(self) -> Fp2 { + let nr2 = C4::non_residue(); + (self.c0 * self.c0) - (nr2 * (self.c1 * self.c1)) + } +} + +impl, C4: Fp4Config> std::fmt::Debug + for Fp4 +{ + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("Fp4") + .field("c0", &self.c0) + .field("c1", &self.c1) + .finish() + } +} + +impl, C4: Fp4Config> Clone for Fp4 { + fn clone(&self) -> Self { + *self + } +} + +impl, C4: Fp4Config> Copy for Fp4 {} + +impl, C4: Fp4Config> Default for Fp4 { + fn default() -> Self { + Self::new( + Fp2::new(F::zero(), F::zero()), + Fp2::new(F::zero(), F::zero()), + ) + } +} + +impl, C4: Fp4Config> PartialEq for Fp4 { + fn eq(&self, other: &Self) -> bool { + self.c0 == other.c0 && self.c1 == other.c1 + } +} + +impl, C4: Fp4Config> Eq for Fp4 {} + +impl, C4: Fp4Config> Add for Fp4 { + type Output = Self; + fn add(self, rhs: Self) -> Self::Output { + Self::new(self.c0 + rhs.c0, self.c1 + rhs.c1) + } +} +impl, C4: Fp4Config> Sub for Fp4 { + type Output = Self; + fn sub(self, rhs: Self) -> Self::Output { + Self::new(self.c0 - rhs.c0, self.c1 - rhs.c1) + } +} +impl, C4: Fp4Config> Neg for Fp4 { + type Output = Self; + fn neg(self) -> Self::Output { + Self::new(-self.c0, -self.c1) + } +} +impl, C4: Fp4Config> Mul for Fp4 { + type Output = Self; + fn mul(self, rhs: Self) -> Self::Output { + let nr2 = C4::non_residue(); + let v0 = self.c0 * rhs.c0; + let v1 = self.c1 * rhs.c1; + Self::new( + v0 + (nr2 * v1), + (self.c0 + self.c1) * (rhs.c0 + rhs.c1) - v0 - v1, + ) + } +} + +impl<'a, F: FieldCore, C2: Fp2Config, C4: Fp4Config> Add<&'a Self> for Fp4 { + type Output = Self; + fn add(self, rhs: &'a Self) -> Self::Output { + self + *rhs + } +} +impl<'a, F: FieldCore, C2: Fp2Config, C4: Fp4Config> Sub<&'a Self> for Fp4 { + type Output = Self; + fn sub(self, rhs: &'a Self) -> Self::Output { + self - *rhs + } +} +impl<'a, F: FieldCore, C2: Fp2Config, C4: Fp4Config> Mul<&'a Self> for Fp4 { + type Output = Self; + fn mul(self, rhs: &'a Self) -> Self::Output { + self * *rhs + } +} + +impl, C4: Fp4Config> Valid for Fp4 { + fn check(&self) -> Result<(), SerializationError> { + self.c0.check()?; + self.c1.check()?; + Ok(()) + } +} + +impl, C4: Fp4Config> HachiSerialize for Fp4 { + fn serialize_with_mode( + &self, + mut writer: W, + compress: Compress, + ) -> Result<(), SerializationError> { + self.c0.serialize_with_mode(&mut writer, compress)?; + self.c1.serialize_with_mode(&mut writer, compress)?; + Ok(()) + } + + fn serialized_size(&self, compress: Compress) -> usize { + self.c0.serialized_size(compress) + self.c1.serialized_size(compress) + } +} + +impl, C4: Fp4Config> HachiDeserialize + for Fp4 +{ + fn deserialize_with_mode( + mut reader: R, + compress: Compress, + validate: Validate, + ) -> Result { + let c0 = Fp2::::deserialize_with_mode(&mut reader, compress, validate)?; + let c1 = Fp2::::deserialize_with_mode(&mut reader, compress, validate)?; + let out = Self::new(c0, c1); + if matches!(validate, Validate::Yes) { + out.check()?; + } + Ok(out) + } +} + +impl, C4: Fp4Config> FieldCore for Fp4 { + fn zero() -> Self { + Self::new(Fp2::zero(), Fp2::zero()) + } + + fn one() -> Self { + Self::new(Fp2::one(), Fp2::zero()) + } + + fn is_zero(&self) -> bool { + self.c0.is_zero() && self.c1.is_zero() + } + + fn square(&self) -> Self { + let nr2 = C4::non_residue(); + let v0 = self.c0.square(); + let v1 = self.c1.square(); + Self::new(v0 + nr2 * v1, (self.c0 + self.c0) * self.c1) + } + + fn inv(self) -> Option { + if self.is_zero() { + return None; + } + let inv_n = self.norm().inv()?; + Some(Self::new(self.c0 * inv_n, (-self.c1) * inv_n)) + } +} + +impl, C4: Fp4Config> FieldSampling + for Fp4 +{ + fn sample(rng: &mut R) -> Self { + Self::new(Fp2::sample(rng), Fp2::sample(rng)) + } +} + +impl, C4: Fp4Config> FromSmallInt + for Fp4 +{ + fn from_u64(val: u64) -> Self { + Self::new(Fp2::from_u64(val), Fp2::zero()) + } + + fn from_i64(val: i64) -> Self { + Self::new(Fp2::from_i64(val), Fp2::zero()) + } +} + +// Scalar * VectorModule impls for extension scalars. + +impl Mul, N>> for Fp2 +where + F: FieldCore + Valid, + C: Fp2Config, +{ + type Output = VectorModule, N>; + fn mul(self, rhs: VectorModule, N>) -> Self::Output { + let mut out = rhs.0; + for coeff in &mut out { + *coeff = self * *coeff; + } + VectorModule(out) + } +} + +impl<'a, F, C, const N: usize> Mul<&'a VectorModule, N>> for Fp2 +where + F: FieldCore + Valid, + C: Fp2Config, +{ + type Output = VectorModule, N>; + fn mul(self, rhs: &'a VectorModule, N>) -> Self::Output { + self * *rhs + } +} + +impl Mul, N>> for Fp4 +where + F: FieldCore + Valid, + C2: Fp2Config, + C4: Fp4Config, +{ + type Output = VectorModule, N>; + fn mul(self, rhs: VectorModule, N>) -> Self::Output { + let mut out = rhs.0; + for coeff in &mut out { + *coeff = self * *coeff; + } + VectorModule(out) + } +} + +impl<'a, F, C2, C4, const N: usize> Mul<&'a VectorModule, N>> for Fp4 +where + F: FieldCore + Valid, + C2: Fp2Config, + C4: Fp4Config, +{ + type Output = VectorModule, N>; + fn mul(self, rhs: &'a VectorModule, N>) -> Self::Output { + self * *rhs + } +} + +// Convenience aliases for extension fields with NR = 2 (valid for all Hachi +// pseudo-Mersenne primes where p ≡ 5 mod 8). + +/// Quadratic extension over any `F` with non-residue 2. +pub type Ext2 = Fp2; + +/// Quartic extension as tower `Ext2[v]/(v^2 - u)`. +pub type Ext4 = Fp4; + +#[cfg(test)] +mod tests { + use super::*; + use crate::algebra::fields::lift::ExtField; + use crate::algebra::Fp64; + use crate::{FieldCore, FieldSampling, FromSmallInt}; + use rand::rngs::StdRng; + use rand::SeedableRng; + + type F = Fp64<4294967197>; + type E2 = Ext2; + type E4 = Ext4; + + #[test] + fn fp2_add_sub_identity() { + let a = E2::new(F::from_u64(3), F::from_u64(5)); + let b = E2::new(F::from_u64(7), F::from_u64(11)); + let c = a + b; + assert_eq!(c - b, a); + assert_eq!(c - a, b); + } + + #[test] + fn fp2_mul_one() { + let a = E2::new(F::from_u64(42), F::from_u64(13)); + assert_eq!(a * E2::one(), a); + assert_eq!(E2::one() * a, a); + } + + #[test] + fn fp2_mul_commutativity() { + let mut rng = StdRng::seed_from_u64(1234); + let a = E2::sample(&mut rng); + let b = E2::sample(&mut rng); + assert_eq!(a * b, b * a); + } + + #[test] + fn fp2_karatsuba_matches_schoolbook() { + let mut rng = StdRng::seed_from_u64(5678); + for _ in 0..100 { + let a = E2::sample(&mut rng); + let b = E2::sample(&mut rng); + let nr = >::non_residue(); + let expected = E2::new( + (a.c0 * b.c0) + (nr * (a.c1 * b.c1)), + (a.c0 * b.c1) + (a.c1 * b.c0), + ); + assert_eq!(a * b, expected); + } + } + + #[test] + fn fp2_square_matches_mul() { + let mut rng = StdRng::seed_from_u64(9012); + for _ in 0..100 { + let a = E2::sample(&mut rng); + assert_eq!(a.square(), a * a, "square mismatch for {a:?}"); + } + } + + #[test] + fn fp2_inv() { + let mut rng = StdRng::seed_from_u64(3456); + for _ in 0..50 { + let a = E2::sample(&mut rng); + if !a.is_zero() { + let inv = a.inv().unwrap(); + assert_eq!(a * inv, E2::one()); + } + } + } + + #[test] + fn fp4_mul_commutativity() { + let mut rng = StdRng::seed_from_u64(7890); + let a = E4::sample(&mut rng); + let b = E4::sample(&mut rng); + assert_eq!(a * b, b * a); + } + + #[test] + fn fp4_square_matches_mul() { + let mut rng = StdRng::seed_from_u64(1111); + for _ in 0..50 { + let a = E4::sample(&mut rng); + assert_eq!(a.square(), a * a); + } + } + + #[test] + fn fp4_inv() { + let mut rng = StdRng::seed_from_u64(2222); + for _ in 0..50 { + let a = E4::sample(&mut rng); + if !a.is_zero() { + let inv = a.inv().unwrap(); + assert_eq!(a * inv, E4::one()); + } + } + } + + #[test] + fn from_small_int_fp2() { + let a = E2::from_u64(42); + assert_eq!(a, E2::new(F::from_u64(42), F::zero())); + + let b = E2::from_i64(-3); + assert_eq!(b, E2::new(F::from_i64(-3), F::zero())); + + let c = E2::from_u8(7); + assert_eq!(c, E2::from_u64(7)); + + let d = E2::from_u32(100_000); + assert_eq!(d, E2::from_u64(100_000)); + } + + #[test] + fn from_small_int_fp4() { + let a = E4::from_u64(42); + assert_eq!(a, E4::new(E2::from_u64(42), E2::zero())); + + let b = E4::from_i64(-7); + assert_eq!(b, E4::new(E2::from_i64(-7), E2::zero())); + } + + #[test] + fn ext_field_degree() { + assert_eq!(>::EXT_DEGREE, 1); + assert_eq!(>::EXT_DEGREE, 2); + assert_eq!(>::EXT_DEGREE, 4); + } + + #[test] + fn ext_field_from_base_slice() { + let c0 = F::from_u64(3); + let c1 = F::from_u64(5); + let e2 = E2::from_base_slice(&[c0, c1]); + assert_eq!(e2, E2::new(c0, c1)); + + let c2 = F::from_u64(7); + let c3 = F::from_u64(11); + let e4 = E4::from_base_slice(&[c0, c1, c2, c3]); + assert_eq!(e4, E4::new(E2::new(c0, c1), E2::new(c2, c3))); + } + + #[test] + fn eq_impl() { + let a = E2::new(F::from_u64(1), F::from_u64(2)); + let b = E2::new(F::from_u64(1), F::from_u64(2)); + let c = E2::new(F::from_u64(1), F::from_u64(3)); + assert_eq!(a, b); + assert_ne!(a, c); + } +} diff --git a/src/algebra/fields/fp128.rs b/src/algebra/fields/fp128.rs new file mode 100644 index 00000000..99f3fc5b --- /dev/null +++ b/src/algebra/fields/fp128.rs @@ -0,0 +1,987 @@ +//! 128-bit prime field for primes of the form `p = 2^128 − c` with `c < 2^64`. +//! +//! Uses Solinas-style two-fold reduction: no Montgomery form, ~23 cycles/mul +//! on both AArch64 and x86-64. The offset `c` is computed at compile time +//! from the const-generic modulus `P`. +//! +//! ## Naming convention for built-in primes +//! +//! The built-in type names encode the **signed terms as they appear in the +//! modulus `p`** (excluding the leading `+2^128` term). For example, +//! `Prime128M13M4P0` denotes `p = 2^128 − 2^13 − 2^4 + 2^0`. + +use std::io::{Read, Write}; +use std::ops::{Add, Mul, Neg, Sub}; + +use rand_core::RngCore; + +use crate::primitives::serialization::{ + Compress, HachiDeserialize, HachiSerialize, SerializationError, Valid, Validate, +}; +use crate::{ + CanonicalField, FieldCore, FieldSampling, FromSmallInt, Invertible, PseudoMersenneField, +}; + +/// Pack two u64 limbs into `[lo, hi]`. +#[inline(always)] +const fn pack(lo: u64, hi: u64) -> [u64; 2] { + [lo, hi] +} + +/// Convert `u128` → `[u64; 2]`. +#[inline(always)] +const fn from_u128(x: u128) -> [u64; 2] { + [x as u64, (x >> 64) as u64] +} + +/// Convert `[u64; 2]` → `u128`. +#[inline(always)] +const fn to_u128(x: [u64; 2]) -> u128 { + x[0] as u128 | (x[1] as u128) << 64 +} + +use super::util::{is_pow2_u64, log2_pow2_u64, mul64_wide}; + +/// 128-bit prime field element for primes `p = 2^128 − c` with `c < 2^64`. +/// +/// Stored as `[u64; 2]` (lo, hi) for 8-byte alignment and direct limb access. +/// +/// The offset `c = 2^128 − p` and all derived constants are computed at +/// compile time from the const-generic `P`. Instantiating `Fp128` with a +/// modulus that is not of this form is a compile-time error. +#[derive(Debug, Clone, Copy, Default)] +pub struct Fp128(pub(crate) [u64; 2]); + +impl PartialEq for Fp128

{ + fn eq(&self, other: &Self) -> bool { + self.0 == other.0 + } +} + +impl Eq for Fp128

{} + +impl Fp128

{ + /// Offset `c = 2^128 − p`. Validated at compile time. + const C: u128 = { + let c = 0u128.wrapping_sub(P); + assert!(P != 0, "modulus must be nonzero"); + assert!(P & 1 == 1, "modulus must be odd"); + assert!(c < (1u128 << 64), "P must be 2^128 - c with c < 2^64"); + // Fused overflow+canonicalize requires C(C+1) < P. + assert!( + c * (c + 1) < P, + "C(C+1) < P required for fused canonicalize" + ); + c + }; + const C_LO: u64 = Self::C as u64; + /// +1 means `C = 2^a + 1`, -1 means `C = 2^a - 1`, 0 means generic. + const C_SHIFT_KIND: i8 = { + let c = Self::C_LO; + if c > 1 && is_pow2_u64(c - 1) { + 1 + } else if c == u64::MAX || is_pow2_u64(c + 1) { + -1 + } else { + 0 + } + }; + const C_SHIFT: u32 = { + let c = Self::C_LO; + if Self::C_SHIFT_KIND == 1 { + log2_pow2_u64(c - 1) + } else if Self::C_SHIFT_KIND == -1 { + if c == u64::MAX { + 64 + } else { + log2_pow2_u64(c + 1) + } + } else { + 0 + } + }; + + /// Multiply by `C = 2^128 - P`. For `C = 2^a ± 1`, this is shift/add or + /// shift/sub only; otherwise it falls back to generic widening multiply. + #[inline(always)] + fn mul_c_wide(x: u64) -> (u64, u64) { + if Self::C_SHIFT_KIND == 1 { + let v = ((x as u128) << Self::C_SHIFT) + x as u128; + (v as u64, (v >> 64) as u64) + } else if Self::C_SHIFT_KIND == -1 { + let v = ((x as u128) << Self::C_SHIFT) - x as u128; + (v as u64, (v >> 64) as u64) + } else { + mul64_wide(Self::C_LO, x) + } + } + + /// Create from a canonical representative in `[0, p)`. + #[inline] + pub fn from_canonical_u128(x: u128) -> Self { + debug_assert!(x < P); + Self(from_u128(x)) + } + + /// Return the canonical representative in `[0, p)`. + #[inline] + pub fn to_canonical_u128(self) -> u128 { + to_u128(self.0) + } + + #[inline(always)] + fn add_raw(a: [u64; 2], b: [u64; 2]) -> [u64; 2] { + let (s, carry) = to_u128(a).overflowing_add(to_u128(b)); + let (reduced, borrow) = s.overflowing_sub(P); + from_u128(if carry | !borrow { + reduced + } else { + reduced.wrapping_add(P) + }) + } + + #[inline(always)] + fn sub_raw(a: [u64; 2], b: [u64; 2]) -> [u64; 2] { + let (diff, borrow) = to_u128(a).overflowing_sub(to_u128(b)); + from_u128(if borrow { diff.wrapping_add(P) } else { diff }) + } + + /// Fold 2 + canonicalize: reduce `[t0, t1] + t2·2^128` into `[0, p)`. + /// + /// Correctness argument for the fused overflow+canonicalize: + /// + /// Let `v = base + C·t2` (mathematical, not mod 2^128). + /// From the fold-1 mac chain, `t2 ≤ C`, so `C·t2 ≤ C²`. + /// + /// - **No overflow** (`v < 2^128`): `s = v`, and the standard + /// canonicalize applies — `s + C` carries iff `s ≥ P`. + /// - **Overflow** (`v ≥ 2^128`): `s = v − 2^128`, so `s < C·t2 ≤ C²`. + /// The correct reduced value is `s + C` (since `2^128 ≡ C mod P`). + /// Because `s + C < C² + C = C(C+1)` and `C(C+1) < P` for all + /// `C < 2^64`, the value `s + C` is already in `[0, P)` — no + /// further canonicalization is needed, and `s + C < 2^128` so the + /// add does NOT carry. + /// + /// Therefore `if (overflow | carry) { s + C } else { s }` is correct + /// in both cases, fusing the overflow correction with canonicalization. + #[inline(always)] + fn fold2_canonicalize(t0: u64, t1: u64, t2: u64) -> [u64; 2] { + let (ct2_lo, ct2_hi) = Self::mul_c_wide(t2); + + let (s0, carry0) = t0.overflowing_add(ct2_lo); + let (s1a, carry1a) = t1.overflowing_add(ct2_hi); + let (s1, carry1b) = s1a.overflowing_add(carry0 as u64); + let overflow = carry1a | carry1b; + + let (r0, carry2) = s0.overflowing_add(Self::C_LO); + let (r1, carry3) = s1.overflowing_add(carry2 as u64); + + pack( + if overflow | carry3 { r0 } else { s0 }, + if overflow | carry3 { r1 } else { s1 }, + ) + } + + /// Solinas fold for exactly 4 limbs: `[r0,r1] + C·[r2,r3]` → 3 limbs, + /// then `fold2_canonicalize`. + #[inline(always)] + fn reduce_4(r0: u64, r1: u64, r2: u64, r3: u64) -> [u64; 2] { + let (cr2_lo, cr2_hi) = Self::mul_c_wide(r2); + let (cr3_lo, cr3_hi) = Self::mul_c_wide(r3); + + let t0_sum = r0 as u128 + cr2_lo as u128; + let t0 = t0_sum as u64; + let carryf = (t0_sum >> 64) as u64; + + let t1_sum = r1 as u128 + cr2_hi as u128 + cr3_lo as u128 + carryf as u128; + let t1 = t1_sum as u64; + + let t2_sum = cr3_hi as u128 + (t1_sum >> 64); + let t2 = t2_sum as u64; + debug_assert_eq!(t2_sum >> 64, 0); + + Self::fold2_canonicalize(t0, t1, t2) + } + + #[inline(always)] + fn mul_raw(a: [u64; 2], b: [u64; 2]) -> [u64; 2] { + let [r0, r1, r2, r3] = Self(a).mul_wide(Self(b)); + Self::reduce_4(r0, r1, r2, r3) + } + + #[inline(always)] + fn sqr_raw(a: [u64; 2]) -> [u64; 2] { + Self::mul_raw(a, a) + } + + /// Squaring, equivalent to `self * self`. + #[inline(always)] + pub fn square(self) -> Self { + Self(Self::sqr_raw(self.0)) + } + + fn pow_u128(self, mut exp: u128) -> Self { + let mut base = self; + let mut acc = Self::one(); + while exp > 0 { + if (exp & 1) == 1 { + acc = acc * base; + } + base = Self(Self::sqr_raw(base.0)); + exp >>= 1; + } + acc + } + + /// Extract the canonical `[lo, hi]` limb representation. + #[inline(always)] + pub fn to_limbs(self) -> [u64; 2] { + self.0 + } + + /// 128×64 → 192-bit widening multiply, **no reduction**. + /// + /// Returns `[lo, mid, hi]` representing `self · other` as a 192-bit + /// integer. Cost: 2 widening `mul64`. + #[inline(always)] + pub fn mul_wide_u64(self, other: u64) -> [u64; 3] { + let (a0, a1) = (self.0[0], self.0[1]); + let (p0_lo, p0_hi) = mul64_wide(a0, other); + let (p1_lo, p1_hi) = mul64_wide(a1, other); + let mid = p0_hi as u128 + p1_lo as u128; + let hi = p1_hi + (mid >> 64) as u64; + [p0_lo, mid as u64, hi] + } + + /// 128×128 → 256-bit widening multiply, **no reduction**. + /// + /// Returns `[r0, r1, r2, r3]` representing `self · other` as a 256-bit + /// integer. This is the schoolbook 2×2 portion of the Solinas multiply, + /// without the reduction fold. Cost: 4 widening `mul64`. + #[inline(always)] + pub fn mul_wide(self, other: Self) -> [u64; 4] { + let (a0, a1) = (self.0[0], self.0[1]); + let (b0, b1) = (other.0[0], other.0[1]); + let (p00_lo, p00_hi) = mul64_wide(a0, b0); + let (p01_lo, p01_hi) = mul64_wide(a0, b1); + let (p10_lo, p10_hi) = mul64_wide(a1, b0); + let (p11_lo, p11_hi) = mul64_wide(a1, b1); + + let row1 = p00_hi as u128 + p01_lo as u128 + p10_lo as u128; + let r0 = p00_lo; + let r1 = row1 as u64; + let carry1 = (row1 >> 64) as u64; + + let row2 = p01_hi as u128 + p10_hi as u128 + p11_lo as u128 + carry1 as u128; + let r2 = row2 as u64; + let carry2 = (row2 >> 64) as u64; + + let row3 = p11_hi as u128 + carry2 as u128; + let r3 = row3 as u64; + debug_assert_eq!(row3 >> 64, 0); + + [r0, r1, r2, r3] + } + + /// 128×128 → 256-bit widening multiply with a raw `u128` operand, + /// **no reduction**. + #[inline(always)] + pub fn mul_wide_u128(self, other: u128) -> [u64; 4] { + self.mul_wide(Self(from_u128(other))) + } + + /// 128×(64*M) → (64*OUT) widening multiply, **no reduction**. + /// + /// Multiplies a canonical Fp128 value (`[u64; 2]`) by an arbitrary + /// little-endian limb array and returns the little-endian product + /// truncated/extended to `OUT` limbs. + #[inline(always)] + pub fn mul_wide_limbs(self, other: [u64; M]) -> [u64; OUT] { + let (a0, a1) = (self.0[0], self.0[1]); + + // Hot-path specializations used by Jolt (M in {3,4}, OUT in {4,5}). + // These avoid loop/control-flow overhead in tight sumcheck FMAs. + if M == 3 && OUT == 5 { + let b0 = other[0]; + let b1 = other[1]; + let b2 = other[2]; + + let (p00_lo, p00_hi) = mul64_wide(a0, b0); + let (p01_lo, p01_hi) = mul64_wide(a0, b1); + let (p02_lo, p02_hi) = mul64_wide(a0, b2); + let (p10_lo, p10_hi) = mul64_wide(a1, b0); + let (p11_lo, p11_hi) = mul64_wide(a1, b1); + let (p12_lo, p12_hi) = mul64_wide(a1, b2); + + let r0 = p00_lo; + + let row1 = p00_hi as u128 + p01_lo as u128 + p10_lo as u128; + let r1 = row1 as u64; + let carry1 = row1 >> 64; + + let row2 = p01_hi as u128 + p02_lo as u128 + p10_hi as u128 + p11_lo as u128 + carry1; + let r2 = row2 as u64; + let carry2 = row2 >> 64; + + let row3 = p02_hi as u128 + p11_hi as u128 + p12_lo as u128 + carry2; + let r3 = row3 as u64; + let carry3 = row3 >> 64; + + let row4 = p12_hi as u128 + carry3; + let r4 = row4 as u64; + debug_assert_eq!(row4 >> 64, 0); + + let mut out = [0u64; OUT]; + out[0] = r0; + out[1] = r1; + out[2] = r2; + out[3] = r3; + out[4] = r4; + return out; + } + if M == 3 && OUT == 4 { + let b0 = other[0]; + let b1 = other[1]; + let b2 = other[2]; + + let (p00_lo, p00_hi) = mul64_wide(a0, b0); + let (p01_lo, p01_hi) = mul64_wide(a0, b1); + let (p02_lo, p02_hi) = mul64_wide(a0, b2); + let (p10_lo, p10_hi) = mul64_wide(a1, b0); + let (p11_lo, p11_hi) = mul64_wide(a1, b1); + let p12_lo = a1.wrapping_mul(b2); + + let r0 = p00_lo; + + let row1 = p00_hi as u128 + p01_lo as u128 + p10_lo as u128; + let r1 = row1 as u64; + let carry1 = row1 >> 64; + + let row2 = p01_hi as u128 + p02_lo as u128 + p10_hi as u128 + p11_lo as u128 + carry1; + let r2 = row2 as u64; + let carry2 = row2 >> 64; + + let row3 = p02_hi as u128 + p11_hi as u128 + p12_lo as u128 + carry2; + let r3 = row3 as u64; + + let mut out = [0u64; OUT]; + out[0] = r0; + out[1] = r1; + out[2] = r2; + out[3] = r3; + return out; + } + if M == 4 && OUT == 6 { + let b0 = other[0]; + let b1 = other[1]; + let b2 = other[2]; + let b3 = other[3]; + + let (p00_lo, p00_hi) = mul64_wide(a0, b0); + let (p01_lo, p01_hi) = mul64_wide(a0, b1); + let (p02_lo, p02_hi) = mul64_wide(a0, b2); + let (p03_lo, p03_hi) = mul64_wide(a0, b3); + let (p10_lo, p10_hi) = mul64_wide(a1, b0); + let (p11_lo, p11_hi) = mul64_wide(a1, b1); + let (p12_lo, p12_hi) = mul64_wide(a1, b2); + let (p13_lo, p13_hi) = mul64_wide(a1, b3); + + let r0 = p00_lo; + + let row1 = p00_hi as u128 + p01_lo as u128 + p10_lo as u128; + let r1 = row1 as u64; + let carry1 = row1 >> 64; + + let row2 = p01_hi as u128 + p02_lo as u128 + p10_hi as u128 + p11_lo as u128 + carry1; + let r2 = row2 as u64; + let carry2 = row2 >> 64; + + let row3 = p02_hi as u128 + p03_lo as u128 + p11_hi as u128 + p12_lo as u128 + carry2; + let r3 = row3 as u64; + let carry3 = row3 >> 64; + + let row4 = p03_hi as u128 + p12_hi as u128 + p13_lo as u128 + carry3; + let r4 = row4 as u64; + let carry4 = row4 >> 64; + + let row5 = p13_hi as u128 + carry4; + let r5 = row5 as u64; + debug_assert_eq!(row5 >> 64, 0); + + let mut out = [0u64; OUT]; + out[0] = r0; + out[1] = r1; + out[2] = r2; + out[3] = r3; + out[4] = r4; + out[5] = r5; + return out; + } + if M == 4 && OUT == 5 { + let b0 = other[0]; + let b1 = other[1]; + let b2 = other[2]; + let b3 = other[3]; + + let (p00_lo, p00_hi) = mul64_wide(a0, b0); + let (p01_lo, p01_hi) = mul64_wide(a0, b1); + let (p02_lo, p02_hi) = mul64_wide(a0, b2); + let (p03_lo, p03_hi) = mul64_wide(a0, b3); + let (p10_lo, p10_hi) = mul64_wide(a1, b0); + let (p11_lo, p11_hi) = mul64_wide(a1, b1); + let (p12_lo, p12_hi) = mul64_wide(a1, b2); + let p13_lo = a1.wrapping_mul(b3); + + let r0 = p00_lo; + + let row1 = p00_hi as u128 + p01_lo as u128 + p10_lo as u128; + let r1 = row1 as u64; + let carry1 = row1 >> 64; + + let row2 = p01_hi as u128 + p02_lo as u128 + p10_hi as u128 + p11_lo as u128 + carry1; + let r2 = row2 as u64; + let carry2 = row2 >> 64; + + let row3 = p02_hi as u128 + p03_lo as u128 + p11_hi as u128 + p12_lo as u128 + carry2; + let r3 = row3 as u64; + let carry3 = row3 >> 64; + + let row4 = p03_hi as u128 + p12_hi as u128 + p13_lo as u128 + carry3; + let r4 = row4 as u64; + + let mut out = [0u64; OUT]; + out[0] = r0; + out[1] = r1; + out[2] = r2; + out[3] = r3; + out[4] = r4; + return out; + } + if M == 4 && OUT == 4 { + let b0 = other[0]; + let b1 = other[1]; + let b2 = other[2]; + let b3 = other[3]; + + let (p00_lo, p00_hi) = mul64_wide(a0, b0); + let (p01_lo, p01_hi) = mul64_wide(a0, b1); + let (p02_lo, p02_hi) = mul64_wide(a0, b2); + let p03_lo = a0.wrapping_mul(b3); + let (p10_lo, p10_hi) = mul64_wide(a1, b0); + let (p11_lo, p11_hi) = mul64_wide(a1, b1); + let p12_lo = a1.wrapping_mul(b2); + + let r0 = p00_lo; + + let row1 = p00_hi as u128 + p01_lo as u128 + p10_lo as u128; + let r1 = row1 as u64; + let carry1 = row1 >> 64; + + let row2 = p01_hi as u128 + p02_lo as u128 + p10_hi as u128 + p11_lo as u128 + carry1; + let r2 = row2 as u64; + let carry2 = row2 >> 64; + + let row3 = p02_hi as u128 + p03_lo as u128 + p11_hi as u128 + p12_lo as u128 + carry2; + let r3 = row3 as u64; + + let mut out = [0u64; OUT]; + out[0] = r0; + out[1] = r1; + out[2] = r2; + out[3] = r3; + return out; + } + + let mut out = [0u64; OUT]; + + for (i, &b) in other.iter().enumerate() { + if i >= OUT { + break; + } + + let (p0_lo, p0_hi) = mul64_wide(a0, b); + let (p1_lo, p1_hi) = mul64_wide(a1, b); + + let s0 = out[i] as u128 + p0_lo as u128; + out[i] = s0 as u64; + let mut carry = s0 >> 64; + + if i + 1 >= OUT { + continue; + } + let s1 = out[i + 1] as u128 + p0_hi as u128 + p1_lo as u128 + carry; + out[i + 1] = s1 as u64; + carry = s1 >> 64; + + if i + 2 >= OUT { + continue; + } + let s2 = out[i + 2] as u128 + p1_hi as u128 + carry; + out[i + 2] = s2 as u64; + + let mut carry_hi = s2 >> 64; + let mut j = i + 3; + while carry_hi != 0 && j < OUT { + let sj = out[j] as u128 + carry_hi; + out[j] = sj as u64; + carry_hi = sj >> 64; + j += 1; + } + } + + out + } + + /// Reduce an arbitrary-width little-endian limb array to a canonical + /// field element via iterated Solinas folding. + /// + /// Each fold splits at the 128-bit boundary and replaces + /// `hi · 2^128` with `hi · C`, reducing width by one limb per + /// iteration. Supports 0–10 input limbs (up to 640 bits). + /// + /// # Panics + /// + /// Panics if `limbs.len() > 10`. + #[inline(always)] + pub fn solinas_reduce(limbs: &[u64]) -> Self { + match limbs.len() { + 0 => Self::zero(), + 1 => Self(pack(limbs[0], 0)), + 2 => Self::from_canonical_u128_reduced(to_u128([limbs[0], limbs[1]])), + 3 => Self(Self::fold2_canonicalize(limbs[0], limbs[1], limbs[2])), + 4 => Self(Self::reduce_4(limbs[0], limbs[1], limbs[2], limbs[3])), + 5 => { + let (l0, l1, l2, l3, l4) = (limbs[0], limbs[1], limbs[2], limbs[3], limbs[4]); + let (c2_lo, c2_hi) = Self::mul_c_wide(l2); + let (c3_lo, c3_hi) = Self::mul_c_wide(l3); + let (c4_lo, c4_hi) = Self::mul_c_wide(l4); + + let s0 = l0 as u128 + c2_lo as u128; + let s1 = l1 as u128 + c2_hi as u128 + c3_lo as u128 + (s0 >> 64); + let s2 = c3_hi as u128 + c4_lo as u128 + (s1 >> 64); + let s3 = c4_hi as u128 + (s2 >> 64); + debug_assert_eq!(s3 >> 64, 0); + + Self(Self::reduce_4(s0 as u64, s1 as u64, s2 as u64, s3 as u64)) + } + n => { + assert!(n <= 10, "solinas_reduce supports at most 10 limbs"); + let mut buf = [0u64; 11]; + buf[..n].copy_from_slice(limbs); + let mut len = n; + let c = Self::C_LO; + + while len > 5 { + let high_len = len - 2; + let mut next = [0u64; 11]; + + let mut carry: u64 = 0; + for i in 0..high_len { + let wide = c as u128 * buf[i + 2] as u128 + carry as u128; + next[i] = wide as u64; + carry = (wide >> 64) as u64; + } + next[high_len] = carry; + + let s0 = next[0] as u128 + buf[0] as u128; + next[0] = s0 as u64; + let s1 = next[1] as u128 + buf[1] as u128 + (s0 >> 64); + next[1] = s1 as u64; + let mut c_out = (s1 >> 64) as u64; + for limb in &mut next[2..=high_len] { + if c_out == 0 { + break; + } + let s = *limb as u128 + c_out as u128; + *limb = s as u64; + c_out = (s >> 64) as u64; + } + debug_assert_eq!(c_out, 0); + + buf = next; + len -= 1; + while len > 5 && buf[len - 1] == 0 { + len -= 1; + } + } + + Self::solinas_reduce(&buf[..len]) + } + } + } +} + +impl Add for Fp128

{ + type Output = Self; + #[inline] + fn add(self, rhs: Self) -> Self::Output { + Self(Self::add_raw(self.0, rhs.0)) + } +} + +impl Sub for Fp128

{ + type Output = Self; + #[inline] + fn sub(self, rhs: Self) -> Self::Output { + Self(Self::sub_raw(self.0, rhs.0)) + } +} + +impl Mul for Fp128

{ + type Output = Self; + #[inline] + fn mul(self, rhs: Self) -> Self::Output { + Self(Self::mul_raw(self.0, rhs.0)) + } +} + +impl Neg for Fp128

{ + type Output = Self; + #[inline] + fn neg(self) -> Self::Output { + Self(Self::sub_raw(pack(0, 0), self.0)) + } +} + +impl<'a, const P: u128> Add<&'a Self> for Fp128

{ + type Output = Self; + #[inline] + fn add(self, rhs: &'a Self) -> Self::Output { + self + *rhs + } +} + +impl<'a, const P: u128> Sub<&'a Self> for Fp128

{ + type Output = Self; + #[inline] + fn sub(self, rhs: &'a Self) -> Self::Output { + self - *rhs + } +} + +impl<'a, const P: u128> Mul<&'a Self> for Fp128

{ + type Output = Self; + #[inline] + fn mul(self, rhs: &'a Self) -> Self::Output { + self * *rhs + } +} + +impl Valid for Fp128

{ + fn check(&self) -> Result<(), SerializationError> { + if to_u128(self.0) < P { + Ok(()) + } else { + Err(SerializationError::InvalidData("Fp128 out of range".into())) + } + } +} + +impl HachiSerialize for Fp128

{ + fn serialize_with_mode( + &self, + mut writer: W, + _compress: Compress, + ) -> Result<(), SerializationError> { + to_u128(self.0).serialize_with_mode(&mut writer, Compress::No)?; + Ok(()) + } + + fn serialized_size(&self, _compress: Compress) -> usize { + 16 + } +} + +impl HachiDeserialize for Fp128

{ + fn deserialize_with_mode( + mut reader: R, + _compress: Compress, + validate: Validate, + ) -> Result { + let x = u128::deserialize_with_mode(&mut reader, Compress::No, validate)?; + if matches!(validate, Validate::Yes) && x >= P { + return Err(SerializationError::InvalidData( + "Fp128 out of range".to_string(), + )); + } + + // Without validation, reduce without division. + // For `p = 2^128 − c` with `c < 2^64` we have `p > 2^127`, + // so any `u128` is in `[0, 2p)` and one conditional subtract suffices. + let out = if matches!(validate, Validate::Yes) { + x + } else { + let (sub, borrow) = x.overflowing_sub(P); + if borrow { + x + } else { + sub + } + }; + Ok(Self(from_u128(out))) + } +} + +impl FieldCore for Fp128

{ + fn zero() -> Self { + Self(pack(0, 0)) + } + + fn one() -> Self { + Self(pack(1, 0)) + } + + fn is_zero(&self) -> bool { + self.0 == [0, 0] + } + + fn inv(self) -> Option { + let inv = self.inv_or_zero(); + if self.is_zero() { + None + } else { + Some(inv) + } + } +} + +impl Invertible for Fp128

{ + fn inv_or_zero(self) -> Self { + let candidate = self.pow_u128(P.wrapping_sub(2)); + let v = to_u128(self.0); + let nz = ((v | v.wrapping_neg()) >> 127) & 1; + let mask = 0u128.wrapping_sub(nz); + let masked = to_u128(candidate.0) & mask; + Self(from_u128(masked)) + } +} + +impl FieldSampling for Fp128

{ + fn sample(rng: &mut R) -> Self { + loop { + let lo = rng.next_u64(); + let hi = rng.next_u64(); + let x = lo as u128 | (hi as u128) << 64; + if x < P { + return Self(pack(lo, hi)); + } + } + } +} + +impl FromSmallInt for Fp128

{ + fn from_u64(val: u64) -> Self { + // For Fp128 pseudo-Mersenne primes, p = 2^128 - c with c < 2^64. + // Therefore any u64 is always canonical (< p), so this can be a + // direct limb construction with no reduction path. + Self(from_u128(val as u128)) + } + + fn from_i64(val: i64) -> Self { + if val >= 0 { + Self::from_u64(val as u64) + } else { + // unsigned_abs avoids overflow for i64::MIN. + -Self::from_u64(val.unsigned_abs()) + } + } +} + +impl CanonicalField for Fp128

{ + fn to_canonical_u128(self) -> u128 { + to_u128(self.0) + } + + fn from_canonical_u128_checked(val: u128) -> Option { + if val < P { + Some(Self(from_u128(val))) + } else { + None + } + } + + fn from_canonical_u128_reduced(val: u128) -> Self { + let (sub, borrow) = val.overflowing_sub(P); + Self(from_u128(if borrow { val } else { sub })) + } +} + +impl PseudoMersenneField for Fp128

{ + const MODULUS_BITS: u32 = 128; + const MODULUS_OFFSET: u128 = Self::C; +} + +/// `p = 2^128 − 2^13 − 2^4 + 2^0` (C = 8207). +pub type Prime128M13M4P0 = Fp128<0xffffffffffffffffffffffffffffdff1>; +/// `p = 2^128 − 2^37 + 2^3 + 2^0` (C = 137438953463). +pub type Prime128M37P3P0 = Fp128<0xffffffffffffffffffffffe000000009>; +/// `p = 2^128 − 2^52 − 2^3 + 2^0` (C = 4503599627370487). +pub type Prime128M52M3P0 = Fp128<0xffffffffffffffffffeffffffffffff9>; +/// `p = 2^128 − 2^54 + 2^4 + 2^0` (C = 18014398509481967). +pub type Prime128M54P4P0 = Fp128<0xffffffffffffffffffc0000000000011>; +/// `p = 2^128 − 2^8 − 2^4 − 2^1 − 2^0` (C = 275). +pub type Prime128M8M4M1M0 = Fp128<0xfffffffffffffffffffffffffffffeed>; +/// `p = 2^128 − 2^18 − 2^0` (C = 2^18 + 1). +pub type Prime128M18M0 = Fp128<0xfffffffffffffffffffffffffffbffff>; +/// `p = 2^128 − 2^54 + 2^0` (C = 2^54 − 1). +pub type Prime128M54P0 = Fp128<0xffffffffffffffffffc0000000000001>; + +#[cfg(test)] +mod tests { + use super::*; + use crate::{FieldSampling, PseudoMersenneField}; + use rand::rngs::StdRng; + use rand::SeedableRng; + use rand_core::RngCore; + + type F = Prime128M8M4M1M0; + + #[test] + fn to_limbs_roundtrip() { + let mut rng = StdRng::seed_from_u64(0xdead_beef_cafe_1234); + for _ in 0..1000 { + let a: F = FieldSampling::sample(&mut rng); + assert_eq!(Fp128(a.to_limbs()), a); + } + } + + #[test] + fn mul_wide_u64_matches_full_mul() { + let mut rng = StdRng::seed_from_u64(0x1122_3344_5566_7788); + for _ in 0..1000 { + let a: F = FieldSampling::sample(&mut rng); + let b = rng.next_u64(); + let expected = a * F::from_u64(b); + let reduced = F::solinas_reduce(&a.mul_wide_u64(b)); + assert_eq!(reduced, expected); + } + } + + #[test] + fn mul_wide_matches_full_mul() { + let mut rng = StdRng::seed_from_u64(0xaabb_ccdd_eeff_0011); + for _ in 0..1000 { + let a: F = FieldSampling::sample(&mut rng); + let b: F = FieldSampling::sample(&mut rng); + let expected = a * b; + let reduced = F::solinas_reduce(&a.mul_wide(b)); + assert_eq!(reduced, expected); + } + } + + #[test] + fn mul_wide_u128_matches_full_mul() { + let mut rng = StdRng::seed_from_u64(0x9988_7766_5544_3322); + for _ in 0..1000 { + let a: F = FieldSampling::sample(&mut rng); + let b = rng.next_u64() as u128 | ((rng.next_u64() as u128) << 64); + let expected = a * F::from_canonical_u128_reduced(b); + let reduced = F::solinas_reduce(&a.mul_wide_u128(b)); + assert_eq!(reduced, expected); + } + } + + #[test] + fn mul_wide_limbs_roundtrips_through_reduction() { + let mut rng = StdRng::seed_from_u64(0x1bad_f00d_0ddc_afe1); + for _ in 0..1000 { + let a: F = FieldSampling::sample(&mut rng); + let b3 = [rng.next_u64(), rng.next_u64(), rng.next_u64()]; + let b4 = [ + rng.next_u64(), + rng.next_u64(), + rng.next_u64(), + rng.next_u64(), + ]; + + let got3_full = a.mul_wide_limbs::<3, 5>(b3); + let got3_trunc = a.mul_wide_limbs::<3, 4>(b3); + assert_eq!( + got3_trunc, + [got3_full[0], got3_full[1], got3_full[2], got3_full[3]] + ); + let exp3 = a * F::solinas_reduce(&b3); + assert_eq!(F::solinas_reduce(&got3_full), exp3); + + let got4_full = a.mul_wide_limbs::<4, 6>(b4); + let got4_trunc = a.mul_wide_limbs::<4, 4>(b4); + assert_eq!( + got4_trunc, + [got4_full[0], got4_full[1], got4_full[2], got4_full[3]] + ); + let exp4 = a * F::solinas_reduce(&b4); + assert_eq!(F::solinas_reduce(&got4_full), exp4); + } + } + + #[test] + fn solinas_reduce_small_inputs() { + assert_eq!(F::solinas_reduce(&[]), F::zero()); + assert_eq!(F::solinas_reduce(&[42]), F::from_u64(42)); + let one_shifted = F::from_canonical_u128_reduced(1u128 << 64); + assert_eq!(F::solinas_reduce(&[0, 1]), one_shifted); + } + + #[test] + fn solinas_reduce_4_limbs_max() { + // 2^256 - 1 ≡ C² - 1 (mod P), since 2^128 ≡ C + let c = F::from_canonical_u128_reduced(::MODULUS_OFFSET); + let expected = c * c - F::one(); + assert_eq!(F::solinas_reduce(&[u64::MAX; 4]), expected); + } + + #[test] + fn solinas_reduce_9_limbs() { + // 1 + 2^512 = 1 + (2^128)^4 ≡ 1 + C^4 + let c = F::from_canonical_u128_reduced(::MODULUS_OFFSET); + let expected = F::one() + c * c * c * c; + assert_eq!(F::solinas_reduce(&[1, 0, 0, 0, 0, 0, 0, 0, 1]), expected); + } + + #[test] + fn solinas_reduce_accumulated_products() { + let mut rng = StdRng::seed_from_u64(0xfeed_face_0bad_c0de); + let mut acc = [0u64; 5]; + let mut expected = F::zero(); + + for _ in 0..200 { + let a: F = FieldSampling::sample(&mut rng); + let b = rng.next_u64(); + let wide = a.mul_wide_u64(b); + + let mut carry: u64 = 0; + for j in 0..5 { + let addend = if j < 3 { wide[j] } else { 0 }; + let sum = acc[j] as u128 + addend as u128 + carry as u128; + acc[j] = sum as u64; + carry = (sum >> 64) as u64; + } + assert_eq!(carry, 0); + expected = expected + a * F::from_u64(b); + } + + assert_eq!(F::solinas_reduce(&acc), expected); + } + + #[test] + fn solinas_reduce_cross_prime() { + // Verify with Prime128M18M0 (C = 2^18 + 1, shift+add path) + type G = Prime128M18M0; + let c = G::from_canonical_u128_reduced(::MODULUS_OFFSET); + let expected = c * c - G::one(); + assert_eq!(G::solinas_reduce(&[u64::MAX; 4]), expected); + + // Verify with Prime128M54P0 (C = 2^54 - 1, shift-sub path) + type H = Prime128M54P0; + let c = H::from_canonical_u128_reduced(::MODULUS_OFFSET); + let expected = c * c - H::one(); + assert_eq!(H::solinas_reduce(&[u64::MAX; 4]), expected); + } + + #[test] + fn from_i64_handles_min_without_overflow() { + let x = F::from_i64(i64::MIN); + let y = F::from_u64(i64::MIN.unsigned_abs()); + assert_eq!(x + y, F::zero()); + } +} diff --git a/src/algebra/fields/fp32.rs b/src/algebra/fields/fp32.rs new file mode 100644 index 00000000..9be1b7b3 --- /dev/null +++ b/src/algebra/fields/fp32.rs @@ -0,0 +1,447 @@ +//! Prime field for primes of the form `p = 2^k − c` with `c` small, backed +//! by `u32` storage. +//! +//! Uses Solinas-style two-fold reduction: the offset `c` and fold point `k` +//! are computed at compile time from the const-generic modulus `P`. + +use std::ops::{Add, Mul, Neg, Sub}; + +use rand_core::RngCore; + +use crate::primitives::serialization::{ + Compress, HachiDeserialize, HachiSerialize, SerializationError, Valid, Validate, +}; +use crate::{ + CanonicalField, FieldCore, FieldSampling, FromSmallInt, Invertible, PseudoMersenneField, +}; +use std::io::{Read, Write}; + +/// Prime field element for primes `p = 2^k − c` stored as `u32`. +/// +/// The fold point `k` and offset `c = 2^k − p` are computed at compile time +/// from the const-generic `P`. Instantiating with a modulus that does not +/// satisfy the Solinas conditions is a compile-time error. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] +pub struct Fp32(pub(crate) u32); + +impl Fp32

{ + /// Fold point: smallest `k` such that `P ≤ 2^k`. + const BITS: u32 = 32 - P.leading_zeros(); + + /// Offset `c = 2^k − P`. + const C: u32 = { + let c = if Self::BITS == 32 { + 0u32.wrapping_sub(P) + } else { + (1u32 << Self::BITS) - P + }; + assert!(P != 0, "modulus must be nonzero"); + assert!(P & 1 == 1, "modulus must be odd"); + assert!( + (c as u64) * (c as u64 + 1) < P as u64, + "C(C+1) < P required for fused canonicalize" + ); + c + }; + + /// Mask for extracting the low `BITS` bits from a u64. + const MASK: u64 = if Self::BITS == 32 { + u32::MAX as u64 + } else { + (1u64 << Self::BITS) - 1 + }; + + /// Create from a canonical representative in `[0, P)`. + #[inline] + pub fn from_canonical_u32(x: u32) -> Self { + debug_assert!(x < P); + Self(x) + } + + /// Return the canonical representative in `[0, P)`. + #[inline] + pub fn to_canonical_u32(self) -> u32 { + self.0 + } + + /// Solinas reduction: fold a u64 at bit `BITS` until the value fits, + /// then conditionally subtract `P`. + /// + /// For multiplication products (< 2^{2·BITS}) exactly 2 folds suffice; + /// for arbitrary u64 inputs (e.g. `from_u64`) the loop runs at most + /// `ceil(64 / BITS)` iterations. + #[inline(always)] + fn reduce_u64(x: u64) -> u32 { + let c = Self::C as u64; + let mut v = x; + while v >> Self::BITS != 0 { + v = (v & Self::MASK) + c * (v >> Self::BITS); + } + let reduced = v.wrapping_sub(P as u64); + let borrow = reduced >> 63; + reduced.wrapping_add(borrow.wrapping_neg() & (P as u64)) as u32 + } + + /// Reduce a `u128` to canonical form (for `from_canonical_u128_reduced`). + #[inline(always)] + fn reduce_u128(x: u128) -> u32 { + let c = Self::C as u128; + let bits = Self::BITS; + let mask = if bits == 32 { + u32::MAX as u128 + } else { + (1u128 << bits) - 1 + }; + let mut v = x; + while v >> bits != 0 { + v = (v & mask) + c * (v >> bits); + } + let f = v as u64; + let reduced = f.wrapping_sub(P as u64); + let borrow = reduced >> 63; + reduced.wrapping_add(borrow.wrapping_neg() & (P as u64)) as u32 + } + + /// Two-fold Solinas reduction for multiplication products. + /// + /// Input must be < 2^{2·BITS} (guaranteed for `a*b` where `a,b < P`). + /// Exactly 2 folds + conditional subtract, no loop. + #[inline(always)] + fn reduce_product(x: u64) -> u32 { + let c = Self::C as u64; + let f1 = (x & Self::MASK) + c * (x >> Self::BITS); + let f2 = (f1 & Self::MASK) + c * (f1 >> Self::BITS); + let reduced = f2.wrapping_sub(P as u64); + let borrow = reduced >> 63; + reduced.wrapping_add(borrow.wrapping_neg() & (P as u64)) as u32 + } + + #[inline(always)] + fn add_raw(a: u32, b: u32) -> u32 { + let s = (a as u64) + (b as u64); + let reduced = s.wrapping_sub(P as u64); + let borrow = reduced >> 63; + reduced.wrapping_add(borrow.wrapping_neg() & (P as u64)) as u32 + } + + #[inline(always)] + fn sub_raw(a: u32, b: u32) -> u32 { + let diff = (a as u64).wrapping_sub(b as u64); + let borrow = diff >> 63; + diff.wrapping_add(borrow.wrapping_neg() & (P as u64)) as u32 + } + + #[inline(always)] + fn mul_raw(a: u32, b: u32) -> u32 { + Self::reduce_product((a as u64) * (b as u64)) + } + + #[inline(always)] + fn sqr_raw(a: u32) -> u32 { + Self::mul_raw(a, a) + } + + /// Squaring, equivalent to `self * self`. + #[inline(always)] + pub fn square(self) -> Self { + Self(Self::sqr_raw(self.0)) + } + + fn pow(self, mut exp: u64) -> Self { + let mut base = self; + let mut acc = Self::one(); + while exp > 0 { + if (exp & 1) == 1 { + acc = acc * base; + } + base = base.square(); + exp >>= 1; + } + acc + } + + /// Extract the canonical value. + #[inline(always)] + pub fn to_limbs(self) -> u32 { + self.0 + } + + /// 32×32 → 64-bit widening multiply, **no reduction**. + #[inline(always)] + pub fn mul_wide(self, other: Self) -> u64 { + (self.0 as u64) * (other.0 as u64) + } + + /// 32×32 → 64-bit widening multiply with a raw `u32` operand, + /// **no reduction**. + #[inline(always)] + pub fn mul_wide_u32(self, other: u32) -> u64 { + (self.0 as u64) * (other as u64) + } + + /// Reduce a u64 value via Solinas folding to a canonical field element. + #[inline(always)] + pub fn solinas_reduce(x: u64) -> Self { + Self(Self::reduce_u64(x)) + } +} + +impl Add for Fp32

{ + type Output = Self; + #[inline] + fn add(self, rhs: Self) -> Self::Output { + Self(Self::add_raw(self.0, rhs.0)) + } +} + +impl Sub for Fp32

{ + type Output = Self; + #[inline] + fn sub(self, rhs: Self) -> Self::Output { + Self(Self::sub_raw(self.0, rhs.0)) + } +} + +impl Mul for Fp32

{ + type Output = Self; + #[inline] + fn mul(self, rhs: Self) -> Self::Output { + Self(Self::mul_raw(self.0, rhs.0)) + } +} + +impl Neg for Fp32

{ + type Output = Self; + #[inline] + fn neg(self) -> Self::Output { + Self(Self::sub_raw(0, self.0)) + } +} + +impl<'a, const P: u32> Add<&'a Self> for Fp32

{ + type Output = Self; + #[inline] + fn add(self, rhs: &'a Self) -> Self::Output { + self + *rhs + } +} + +impl<'a, const P: u32> Sub<&'a Self> for Fp32

{ + type Output = Self; + #[inline] + fn sub(self, rhs: &'a Self) -> Self::Output { + self - *rhs + } +} + +impl<'a, const P: u32> Mul<&'a Self> for Fp32

{ + type Output = Self; + #[inline] + fn mul(self, rhs: &'a Self) -> Self::Output { + self * *rhs + } +} + +impl Valid for Fp32

{ + fn check(&self) -> Result<(), SerializationError> { + if self.0 < P { + Ok(()) + } else { + Err(SerializationError::InvalidData("Fp32 out of range".into())) + } + } +} + +impl HachiSerialize for Fp32

{ + fn serialize_with_mode( + &self, + mut writer: W, + _compress: Compress, + ) -> Result<(), SerializationError> { + self.0.serialize_with_mode(&mut writer, Compress::No)?; + Ok(()) + } + + fn serialized_size(&self, _compress: Compress) -> usize { + 4 + } +} + +impl HachiDeserialize for Fp32

{ + fn deserialize_with_mode( + mut reader: R, + _compress: Compress, + validate: Validate, + ) -> Result { + let x = u32::deserialize_with_mode(&mut reader, Compress::No, validate)?; + if matches!(validate, Validate::Yes) && x >= P { + return Err(SerializationError::InvalidData( + "Fp32 out of range".to_string(), + )); + } + let out = if matches!(validate, Validate::Yes) { + Self(x) + } else { + Self(Self::reduce_u64(x as u64)) + }; + Ok(out) + } +} + +impl FieldCore for Fp32

{ + fn zero() -> Self { + Self(0) + } + + fn one() -> Self { + Self(if P > 1 { 1 } else { 0 }) + } + + fn is_zero(&self) -> bool { + self.0 == 0 + } + + fn inv(self) -> Option { + let inv = self.inv_or_zero(); + if self.is_zero() { + None + } else { + Some(inv) + } + } +} + +impl Invertible for Fp32

{ + fn inv_or_zero(self) -> Self { + let candidate = self.pow((P as u64).wrapping_sub(2)); + let nz = ((self.0 | self.0.wrapping_neg()) >> 31) & 1; + let mask = 0u32.wrapping_sub(nz); + Self(candidate.0 & mask) + } +} + +impl FieldSampling for Fp32

{ + fn sample(rng: &mut R) -> Self { + Self(Self::reduce_u64(rng.next_u64())) + } +} + +impl FromSmallInt for Fp32

{ + fn from_u64(val: u64) -> Self { + Self(Self::reduce_u64(val)) + } + + fn from_i64(val: i64) -> Self { + if val >= 0 { + Self::from_u64(val as u64) + } else { + -Self::from_u64(val.unsigned_abs()) + } + } +} + +impl CanonicalField for Fp32

{ + fn to_canonical_u128(self) -> u128 { + self.0 as u128 + } + + fn from_canonical_u128_checked(val: u128) -> Option { + if val < P as u128 { + Some(Self(val as u32)) + } else { + None + } + } + + fn from_canonical_u128_reduced(val: u128) -> Self { + Self(Self::reduce_u128(val)) + } +} + +impl PseudoMersenneField for Fp32

{ + const MODULUS_BITS: u32 = Self::BITS; + const MODULUS_OFFSET: u128 = Self::C as u128; +} + +#[cfg(test)] +mod tests { + use super::*; + use rand::rngs::StdRng; + use rand::SeedableRng; + + type F = Fp32<251>; // 2^8 - 5 + + #[test] + fn solinas_constants() { + assert_eq!(F::BITS, 8); + assert_eq!(F::C, 5); + assert_eq!(F::MASK, 255); + + type G = Fp32<{ (1u32 << 24) - 3 }>; // 2^24 - 3 + assert_eq!(G::BITS, 24); + assert_eq!(G::C, 3); + } + + #[test] + fn basic_arithmetic() { + let a = F::from_u64(100); + let b = F::from_u64(200); + assert_eq!((a + b).to_canonical_u32(), (100 + 200) % 251); + assert_eq!((a * b).to_canonical_u32(), (100 * 200) % 251); + assert_eq!((b - a).to_canonical_u32(), 100); + assert_eq!((-a).to_canonical_u32(), 251 - 100); + } + + #[test] + fn mul_wide_matches_full_mul() { + let mut rng = StdRng::seed_from_u64(0x1234_5678); + for _ in 0..1000 { + let a: F = FieldSampling::sample(&mut rng); + let b: F = FieldSampling::sample(&mut rng); + let expected = a * b; + let reduced = F::solinas_reduce(a.mul_wide(b)); + assert_eq!(reduced, expected); + } + } + + #[test] + fn mul_wide_u32_matches() { + let mut rng = StdRng::seed_from_u64(0xabcd_ef01); + for _ in 0..1000 { + let a: F = FieldSampling::sample(&mut rng); + let b = rng.next_u32() % 251; + let expected = a * F::from_canonical_u32(b); + let reduced = F::solinas_reduce(a.mul_wide_u32(b)); + assert_eq!(reduced, expected); + } + } + + #[test] + fn reduce_large_values() { + assert_eq!( + F::from_u64(u64::MAX).to_canonical_u32(), + (u64::MAX % 251) as u32 + ); + assert_eq!(F::from_u64(0).to_canonical_u32(), 0); + assert_eq!(F::from_u64(251).to_canonical_u32(), 0); + assert_eq!(F::from_u64(252).to_canonical_u32(), 1); + } + + #[test] + fn pseudo_mersenne_trait() { + assert_eq!(::MODULUS_BITS, 8); + assert_eq!(::MODULUS_OFFSET, 5); + } + + #[test] + fn cross_prime_32bit() { + type G = Fp32<{ u32::MAX - 98 }>; // 2^32 - 99 + assert_eq!(G::BITS, 32); + assert_eq!(G::C, 99); + + let a = G::from_u64(1_000_000); + let b = G::from_u64(2_000_000); + let product = (1_000_000u64 * 2_000_000u64) % ((1u64 << 32) - 99); + assert_eq!((a * b).to_canonical_u32(), product as u32); + } +} diff --git a/src/algebra/fields/fp64.rs b/src/algebra/fields/fp64.rs new file mode 100644 index 00000000..8b81218e --- /dev/null +++ b/src/algebra/fields/fp64.rs @@ -0,0 +1,561 @@ +//! Prime field for primes of the form `p = 2^k − c` with `c` small, backed +//! by `u64` storage. +//! +//! Uses Solinas-style two-fold reduction. For `c = 2^a ± 1` the fold +//! multiply is replaced by shift+add/sub, saving a u128 widening multiply. + +use std::ops::{Add, Mul, Neg, Sub}; + +use rand_core::RngCore; + +use crate::primitives::serialization::{ + Compress, HachiDeserialize, HachiSerialize, SerializationError, Valid, Validate, +}; +use crate::{ + CanonicalField, FieldCore, FieldSampling, FromSmallInt, Invertible, PseudoMersenneField, +}; +use std::io::{Read, Write}; + +use super::util::{is_pow2_u64, log2_pow2_u64, mul64_wide}; + +/// Prime field element for primes `p = 2^k − c` stored as `u64`. +/// +/// The fold point `k` and offset `c = 2^k − p` are computed at compile time +/// from the const-generic `P`. For `c = 2^a ± 1`, the fold multiply is +/// replaced by shift+add/sub. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] +pub struct Fp64(pub(crate) u64); + +impl Fp64

{ + /// Fold point: smallest `k` such that `P ≤ 2^k`. + const BITS: u32 = 64 - P.leading_zeros(); + + /// Offset `c = 2^k − P`. + const C: u64 = { + let c = if Self::BITS == 64 { + 0u64.wrapping_sub(P) + } else { + (1u64 << Self::BITS) - P + }; + assert!(P != 0, "modulus must be nonzero"); + assert!(P & 1 == 1, "modulus must be odd"); + assert!( + (c as u128) * (c as u128 + 1) < P as u128, + "C(C+1) < P required for fused canonicalize" + ); + c + }; + + /// +1 means `C = 2^a + 1`, -1 means `C = 2^a - 1`, 0 means generic. + const C_SHIFT_KIND: i8 = { + let c = Self::C; + if c > 1 && is_pow2_u64(c - 1) { + 1 + } else if c == u64::MAX || is_pow2_u64(c + 1) { + -1 + } else { + 0 + } + }; + + const C_SHIFT: u32 = { + let c = Self::C; + if Self::C_SHIFT_KIND == 1 { + log2_pow2_u64(c - 1) + } else if Self::C_SHIFT_KIND == -1 { + if c == u64::MAX { + 64 + } else { + log2_pow2_u64(c + 1) + } + } else { + 0 + } + }; + + /// Mask for extracting the low `BITS` bits from a u128. + const MASK: u128 = if Self::BITS == 64 { + u64::MAX as u128 + } else { + (1u128 << Self::BITS) - 1 + }; + + /// u64-width mask (only valid when BITS < 64). + const MASK64: u64 = if Self::BITS < 64 { + (1u64 << Self::BITS) - 1 + } else { + u64::MAX + }; + + /// Whether Solinas folding of a multiplication product can stay + /// entirely in u64. True when BITS < 64 and C·2^BITS < 2^64. + const FOLD_IN_U64: bool = Self::BITS < 64 && (Self::C as u128) < (1u128 << (64 - Self::BITS)); + + /// u64 multiply by C, split into u32-wide halves so LLVM emits + /// `umull` (32×32→64) instead of promoting to u128. + /// Only valid when C fits in u32 (always true: C < sqrt(P) < 2^32). + #[inline(always)] + fn mul_c_narrow(x: u64) -> u64 { + #[cfg(target_arch = "x86_64")] + { + // x86_64 has fast scalar 64-bit multiply; use one multiply instead + // of two widened 32-bit multiplies in the fold hot path. + Self::C.wrapping_mul(x) + } + #[cfg(not(target_arch = "x86_64"))] + { + let c = Self::C as u32; + let x_lo = x as u32; + let x_hi = (x >> 32) as u32; + (c as u64 * x_lo as u64).wrapping_add((c as u64 * x_hi as u64) << 32) + } + } + + /// Multiply `x` by `C`. For `C = 2^a ± 1` uses shift+add/sub. + #[inline(always)] + fn mul_c(x: u64) -> u128 { + if Self::C_SHIFT_KIND == 1 { + ((x as u128) << Self::C_SHIFT) + x as u128 + } else if Self::C_SHIFT_KIND == -1 { + ((x as u128) << Self::C_SHIFT) - x as u128 + } else { + (Self::C as u128) * (x as u128) + } + } + + /// Create from a canonical representative in `[0, P)`. + #[inline] + pub fn from_canonical_u64(x: u64) -> Self { + debug_assert!(x < P); + Self(x) + } + + /// Return the canonical representative in `[0, P)`. + #[inline] + pub fn to_canonical_u64(self) -> u64 { + self.0 + } + + /// Solinas reduction: fold a u128 at bit `BITS` until the value fits, + /// then conditionally subtract `P`. + /// + /// For multiplication products (< 2^{2·BITS}) exactly 2 folds suffice; + /// for arbitrary u128 inputs the loop runs at most `ceil(128 / BITS)` + /// iterations. + #[inline(always)] + fn reduce_u128(x: u128) -> u64 { + let mut v = x; + while v >> Self::BITS != 0 { + v = (v & Self::MASK) + Self::mul_c((v >> Self::BITS) as u64); + } + let reduced = v.wrapping_sub(P as u128); + let borrow = reduced >> 127; + reduced.wrapping_add(borrow.wrapping_neg() & (P as u128)) as u64 + } + + /// Two-fold Solinas reduction for multiplication products. + /// + /// Input must be < 2^{2·BITS} (guaranteed for `a*b` where `a,b < P`). + /// Exactly 2 folds + conditional subtract, no loop. + /// + /// When `FOLD_IN_U64` is true the entire reduction stays in u64, + /// avoiding expensive u128 mask/shift on sub-word primes. + #[inline(always)] + fn reduce_product(x: u128) -> u64 { + if Self::FOLD_IN_U64 { + let lo = x as u64; + let hi = (x >> 64) as u64; + let high = (lo >> Self::BITS) | (hi << (64 - Self::BITS)); + let f1 = (lo & Self::MASK64) + Self::mul_c_narrow(high); + let f2 = (f1 & Self::MASK64) + Self::mul_c_narrow(f1 >> Self::BITS); + let reduced = f2.wrapping_sub(P); + let borrow = reduced >> 63; + reduced.wrapping_add(borrow.wrapping_neg() & P) + } else { + let f1 = (x & Self::MASK) + Self::mul_c((x >> Self::BITS) as u64); + let f2 = (f1 & Self::MASK) + Self::mul_c((f1 >> Self::BITS) as u64); + let reduced = f2.wrapping_sub(P as u128); + let borrow = reduced >> 127; + reduced.wrapping_add(borrow.wrapping_neg() & (P as u128)) as u64 + } + } + + /// BMI2 fast path: avoid re-materializing `u128` product in the common + /// sub-word configuration where reduction stays in `u64`. + #[cfg(all(target_arch = "x86_64", target_feature = "bmi2"))] + #[inline(always)] + fn reduce_product_wide(lo: u64, hi: u64) -> u64 { + if Self::FOLD_IN_U64 { + let high = (lo >> Self::BITS) | (hi << (64 - Self::BITS)); + let f1 = (lo & Self::MASK64) + Self::mul_c_narrow(high); + let f2 = (f1 & Self::MASK64) + Self::mul_c_narrow(f1 >> Self::BITS); + let reduced = f2.wrapping_sub(P); + let borrow = reduced >> 63; + reduced.wrapping_add(borrow.wrapping_neg() & P) + } else { + Self::reduce_product(lo as u128 | ((hi as u128) << 64)) + } + } + + #[inline(always)] + fn add_raw(a: u64, b: u64) -> u64 { + if Self::BITS <= 62 { + let s = a + b; + let reduced = s.wrapping_sub(P); + let borrow = reduced >> 63; + reduced.wrapping_add(borrow.wrapping_neg() & P) + } else { + let s = (a as u128) + (b as u128); + let reduced = s.wrapping_sub(P as u128); + let borrow = reduced >> 127; + reduced.wrapping_add(borrow.wrapping_neg() & (P as u128)) as u64 + } + } + + #[inline(always)] + fn sub_raw(a: u64, b: u64) -> u64 { + if Self::BITS <= 62 { + let diff = a.wrapping_sub(b); + let borrow = diff >> 63; + diff.wrapping_add(borrow.wrapping_neg() & P) + } else { + let diff = (a as u128).wrapping_sub(b as u128); + let borrow = diff >> 127; + diff.wrapping_add(borrow.wrapping_neg() & (P as u128)) as u64 + } + } + + #[inline(always)] + fn mul_raw(a: u64, b: u64) -> u64 { + #[cfg(all(target_arch = "x86_64", target_feature = "bmi2"))] + { + let (lo, hi) = mul64_wide(a, b); + Self::reduce_product_wide(lo, hi) + } + #[cfg(not(all(target_arch = "x86_64", target_feature = "bmi2")))] + { + Self::reduce_product((a as u128) * (b as u128)) + } + } + + #[inline(always)] + fn sqr_raw(a: u64) -> u64 { + Self::mul_raw(a, a) + } + + /// Squaring, equivalent to `self * self`. + #[inline(always)] + pub fn square(self) -> Self { + Self(Self::sqr_raw(self.0)) + } + + fn pow(self, mut exp: u64) -> Self { + let mut base = self; + let mut acc = Self::one(); + while exp > 0 { + if (exp & 1) == 1 { + acc = acc * base; + } + base = base.square(); + exp >>= 1; + } + acc + } + + /// Extract the canonical value. + #[inline(always)] + pub fn to_limbs(self) -> u64 { + self.0 + } + + /// 64×64 → 128-bit widening multiply, **no reduction**. + #[inline(always)] + pub fn mul_wide(self, other: Self) -> u128 { + let (lo, hi) = mul64_wide(self.0, other.0); + lo as u128 | ((hi as u128) << 64) + } + + /// 64×64 → 128-bit widening multiply with a raw `u64` operand, + /// **no reduction**. + #[inline(always)] + pub fn mul_wide_u64(self, other: u64) -> u128 { + let (lo, hi) = mul64_wide(self.0, other); + lo as u128 | ((hi as u128) << 64) + } + + /// Reduce a u128 value via Solinas folding to a canonical field element. + #[inline(always)] + pub fn solinas_reduce(x: u128) -> Self { + Self(Self::reduce_u128(x)) + } +} + +impl Add for Fp64

{ + type Output = Self; + #[inline] + fn add(self, rhs: Self) -> Self::Output { + Self(Self::add_raw(self.0, rhs.0)) + } +} + +impl Sub for Fp64

{ + type Output = Self; + #[inline] + fn sub(self, rhs: Self) -> Self::Output { + Self(Self::sub_raw(self.0, rhs.0)) + } +} + +impl Mul for Fp64

{ + type Output = Self; + #[inline] + fn mul(self, rhs: Self) -> Self::Output { + Self(Self::mul_raw(self.0, rhs.0)) + } +} + +impl Neg for Fp64

{ + type Output = Self; + #[inline] + fn neg(self) -> Self::Output { + Self(Self::sub_raw(0, self.0)) + } +} + +impl<'a, const P: u64> Add<&'a Self> for Fp64

{ + type Output = Self; + #[inline] + fn add(self, rhs: &'a Self) -> Self::Output { + self + *rhs + } +} + +impl<'a, const P: u64> Sub<&'a Self> for Fp64

{ + type Output = Self; + #[inline] + fn sub(self, rhs: &'a Self) -> Self::Output { + self - *rhs + } +} + +impl<'a, const P: u64> Mul<&'a Self> for Fp64

{ + type Output = Self; + #[inline] + fn mul(self, rhs: &'a Self) -> Self::Output { + self * *rhs + } +} + +impl Valid for Fp64

{ + fn check(&self) -> Result<(), SerializationError> { + if self.0 < P { + Ok(()) + } else { + Err(SerializationError::InvalidData("Fp64 out of range".into())) + } + } +} + +impl HachiSerialize for Fp64

{ + fn serialize_with_mode( + &self, + mut writer: W, + _compress: Compress, + ) -> Result<(), SerializationError> { + self.0.serialize_with_mode(&mut writer, Compress::No)?; + Ok(()) + } + + fn serialized_size(&self, _compress: Compress) -> usize { + 8 + } +} + +impl HachiDeserialize for Fp64

{ + fn deserialize_with_mode( + mut reader: R, + _compress: Compress, + validate: Validate, + ) -> Result { + let x = u64::deserialize_with_mode(&mut reader, Compress::No, validate)?; + if matches!(validate, Validate::Yes) && x >= P { + return Err(SerializationError::InvalidData( + "Fp64 out of range".to_string(), + )); + } + let out = if matches!(validate, Validate::Yes) { + Self(x) + } else { + Self(Self::reduce_u128(x as u128)) + }; + Ok(out) + } +} + +impl FieldCore for Fp64

{ + fn zero() -> Self { + Self(0) + } + + fn one() -> Self { + Self(if P > 1 { 1 } else { 0 }) + } + + fn is_zero(&self) -> bool { + self.0 == 0 + } + + fn inv(self) -> Option { + let inv = self.inv_or_zero(); + if self.is_zero() { + None + } else { + Some(inv) + } + } +} + +impl Invertible for Fp64

{ + fn inv_or_zero(self) -> Self { + let candidate = self.pow(P.wrapping_sub(2)); + let nz = ((self.0 | self.0.wrapping_neg()) >> 63) & 1; + let mask = 0u64.wrapping_sub(nz); + Self(candidate.0 & mask) + } +} + +impl FieldSampling for Fp64

{ + fn sample(rng: &mut R) -> Self { + let lo = rng.next_u64() as u128; + let hi = rng.next_u64() as u128; + Self(Self::reduce_u128(lo | (hi << 64))) + } +} + +impl FromSmallInt for Fp64

{ + fn from_u64(val: u64) -> Self { + Self(Self::reduce_u128(val as u128)) + } + + fn from_i64(val: i64) -> Self { + if val >= 0 { + Self::from_u64(val as u64) + } else { + -Self::from_u64(val.unsigned_abs()) + } + } +} + +impl CanonicalField for Fp64

{ + fn to_canonical_u128(self) -> u128 { + self.0 as u128 + } + + fn from_canonical_u128_checked(val: u128) -> Option { + if val < P as u128 { + Some(Self(val as u64)) + } else { + None + } + } + + fn from_canonical_u128_reduced(val: u128) -> Self { + Self(Self::reduce_u128(val)) + } +} + +impl PseudoMersenneField for Fp64

{ + const MODULUS_BITS: u32 = Self::BITS; + const MODULUS_OFFSET: u128 = Self::C as u128; +} + +#[cfg(test)] +mod tests { + use super::*; + use rand::rngs::StdRng; + use rand::SeedableRng; + + type F40 = Fp64<{ (1u64 << 40) - 195 }>; // 2^40 - 195 + type F64 = Fp64<{ u64::MAX - 58 }>; // 2^64 - 59 + + #[test] + fn solinas_constants() { + assert_eq!(F40::BITS, 40); + assert_eq!(F40::C, 195); + + assert_eq!(F64::BITS, 64); + assert_eq!(F64::C, 59); + } + + #[test] + fn basic_arithmetic_sub_word() { + let a = F40::from_u64(1_000_000); + let b = F40::from_u64(2_000_000); + let p = (1u64 << 40) - 195; + assert_eq!((a + b).to_canonical_u64(), 3_000_000); + assert_eq!( + (a * b).to_canonical_u64(), + (1_000_000u128 * 2_000_000u128 % p as u128) as u64 + ); + } + + #[test] + fn basic_arithmetic_full_word() { + let a = F64::from_u64(1_000_000_000); + let b = F64::from_u64(2_000_000_000); + let p = u64::MAX - 58; + assert_eq!( + (a * b).to_canonical_u64(), + (1_000_000_000u128 * 2_000_000_000u128 % p as u128) as u64 + ); + } + + #[test] + fn mul_wide_matches_full_mul() { + let mut rng = StdRng::seed_from_u64(0xdead_beef); + for _ in 0..1000 { + let a: F40 = FieldSampling::sample(&mut rng); + let b: F40 = FieldSampling::sample(&mut rng); + let expected = a * b; + let reduced = F40::solinas_reduce(a.mul_wide(b)); + assert_eq!(reduced, expected); + } + } + + #[test] + fn mul_wide_u64_matches() { + let mut rng = StdRng::seed_from_u64(0xcafe_d00d); + for _ in 0..1000 { + let a: F40 = FieldSampling::sample(&mut rng); + let b = rng.next_u64() % ((1u64 << 40) - 195); + let expected = a * F40::from_canonical_u64(b); + let reduced = F40::solinas_reduce(a.mul_wide_u64(b)); + assert_eq!(reduced, expected); + } + } + + #[test] + fn pseudo_mersenne_trait() { + assert_eq!(::MODULUS_BITS, 40); + assert_eq!(::MODULUS_OFFSET, 195); + assert_eq!(::MODULUS_BITS, 64); + assert_eq!(::MODULUS_OFFSET, 59); + } + + #[test] + fn shift_optimization_detected() { + type G = Fp64<{ (1u64 << 56) - 27 }>; // C = 27, not 2^a±1 + assert_eq!(G::C_SHIFT_KIND, 0); + + type H = Fp64<{ u64::MAX - 58 }>; // C = 59, not 2^a±1 + assert_eq!(H::C_SHIFT_KIND, 0); + } + + #[test] + fn reduce_u128_large() { + assert_eq!(F64::from_canonical_u128_reduced(u128::MAX), { + let p = u64::MAX as u128 - 58; + F64::from_canonical_u64((u128::MAX % p) as u64) + }); + } +} diff --git a/src/algebra/fields/lift.rs b/src/algebra/fields/lift.rs new file mode 100644 index 00000000..e05a8aa7 --- /dev/null +++ b/src/algebra/fields/lift.rs @@ -0,0 +1,100 @@ +//! Helpers for embedding base fields into extension fields. + +use crate::algebra::fields::ext::{Fp2, Fp2Config, Fp4, Fp4Config}; +use crate::primitives::serialization::Valid; +use crate::{FieldCore, FromSmallInt}; + +/// Lift a base-field element into an extension field. +/// +/// This is intentionally small: for extension towers we embed into the constant term. +pub trait LiftBase: FieldCore { + /// Embed `x ∈ F` as a constant in `Self`. + fn lift_base(x: F) -> Self; +} + +/// An algebraic extension of base field `F`. +/// +/// Provides the extension degree and a constructor from a slice of base-field +/// coefficients (in the canonical basis `{1, u, u^2, ...}`). +pub trait ExtField: FieldCore + LiftBase + FromSmallInt { + /// Extension degree: `[Self : F]`. + const EXT_DEGREE: usize; + + /// Construct from a coefficient slice `[c0, c1, ..., c_{d-1}]`. + /// + /// # Panics + /// Panics if `coeffs.len() != Self::EXT_DEGREE`. + fn from_base_slice(coeffs: &[F]) -> Self; +} + +impl ExtField for F { + const EXT_DEGREE: usize = 1; + + #[inline] + fn from_base_slice(coeffs: &[F]) -> Self { + assert_eq!(coeffs.len(), 1); + coeffs[0] + } +} + +impl ExtField for Fp2 +where + F: FieldCore + FromSmallInt + Valid, + C: Fp2Config, +{ + const EXT_DEGREE: usize = 2; + + #[inline] + fn from_base_slice(coeffs: &[F]) -> Self { + assert_eq!(coeffs.len(), 2); + Self::new(coeffs[0], coeffs[1]) + } +} + +impl ExtField for Fp4 +where + F: FieldCore + FromSmallInt + Valid, + C2: Fp2Config, + C4: Fp4Config, +{ + const EXT_DEGREE: usize = 4; + + #[inline] + fn from_base_slice(coeffs: &[F]) -> Self { + assert_eq!(coeffs.len(), 4); + Self::new( + Fp2::new(coeffs[0], coeffs[1]), + Fp2::new(coeffs[2], coeffs[3]), + ) + } +} + +impl LiftBase for F { + #[inline] + fn lift_base(x: F) -> Self { + x + } +} + +impl LiftBase for Fp2 +where + F: FieldCore + Valid, + C: Fp2Config, +{ + #[inline] + fn lift_base(x: F) -> Self { + Self::new(x, F::zero()) + } +} + +impl LiftBase for Fp4 +where + F: FieldCore + Valid, + C2: Fp2Config, + C4: Fp4Config, +{ + #[inline] + fn lift_base(x: F) -> Self { + Self::new(Fp2::new(x, F::zero()), Fp2::new(F::zero(), F::zero())) + } +} diff --git a/src/algebra/fields/mod.rs b/src/algebra/fields/mod.rs new file mode 100644 index 00000000..85cf56d0 --- /dev/null +++ b/src/algebra/fields/mod.rs @@ -0,0 +1,43 @@ +//! Prime fields and extension field towers. + +pub mod ext; +pub mod fp128; +pub mod fp32; +pub mod fp64; +pub mod lift; +pub mod packed; +#[cfg(all( + target_arch = "x86_64", + target_feature = "avx2", + not(all(target_feature = "avx512f", target_feature = "avx512dq")) +))] +pub mod packed_avx2; +#[cfg(all( + target_arch = "x86_64", + target_feature = "avx512f", + target_feature = "avx512dq" +))] +pub mod packed_avx512; +#[allow(missing_docs)] +pub mod packed_ext; +#[cfg(all(target_arch = "aarch64", target_feature = "neon"))] +pub mod packed_neon; +pub mod pseudo_mersenne; +pub(crate) mod util; + +pub use ext::{Ext2, Ext4, Fp2, Fp2Config, Fp4, Fp4Config, NegOneNr, TwoNr, UnitNr}; +pub use fp128::{ + Fp128, Prime128M13M4P0, Prime128M37P3P0, Prime128M52M3P0, Prime128M54P4P0, Prime128M8M4M1M0, +}; +pub use fp32::Fp32; +pub use fp64::Fp64; +pub use lift::{ExtField, LiftBase}; +pub use packed::{ + Fp128Packing, Fp32Packing, Fp64Packing, HasPacking, NoPacking, PackedField, PackedValue, +}; +pub use pseudo_mersenne::{ + is_pow2_offset, pow2_offset, pseudo_mersenne_modulus, Pow2Offset128Field, Pow2Offset24Field, + Pow2Offset30Field, Pow2Offset31Field, Pow2Offset32Field, Pow2Offset40Field, Pow2Offset48Field, + Pow2Offset56Field, Pow2Offset64Field, Pow2OffsetPrimeSpec, POW2_OFFSET_IMPLEMENTED_MAX_BITS, + POW2_OFFSET_MAX, POW2_OFFSET_PRIMES, POW2_OFFSET_TABLE, +}; diff --git a/src/algebra/fields/packed.rs b/src/algebra/fields/packed.rs new file mode 100644 index 00000000..eaf15ab2 --- /dev/null +++ b/src/algebra/fields/packed.rs @@ -0,0 +1,434 @@ +//! Packed field abstractions and architecture-specific SIMD backends. + +use crate::algebra::fields::{Fp128, Fp32, Fp64}; +use crate::FieldCore; +use core::ops::{Add, Mul, Sub}; + +/// Array-like packed values over a scalar type. +pub trait PackedValue: 'static + Copy + Send + Sync { + /// Scalar value type carried by each lane. + type Value: 'static + Copy + Send + Sync; + + /// Number of scalar lanes. + const WIDTH: usize; + + /// Build from a lane generator. + fn from_fn(f: F) -> Self + where + F: FnMut(usize) -> Self::Value; + + /// Extract one lane. + fn extract(&self, lane: usize) -> Self::Value; + + /// Pack a scalar slice into packed values. + /// + /// # Panics + /// + /// Panics if the length is not divisible by `WIDTH`. + #[inline] + fn pack_slice(buf: &[Self::Value]) -> Vec { + assert!( + buf.len() % Self::WIDTH == 0, + "slice length {} must be divisible by WIDTH {}", + buf.len(), + Self::WIDTH + ); + buf.chunks_exact(Self::WIDTH) + .map(|chunk| Self::from_fn(|i| chunk[i])) + .collect() + } + + /// Packed prefix + scalar suffix split. + #[inline] + fn pack_slice_with_suffix(buf: &[Self::Value]) -> (Vec, &[Self::Value]) { + let split = buf.len() - (buf.len() % Self::WIDTH); + let (packed, suffix) = buf.split_at(split); + (Self::pack_slice(packed), suffix) + } + + /// Unpack packed values into a flat scalar vector. + #[inline] + fn unpack_slice(buf: &[Self]) -> Vec { + let mut out = Vec::with_capacity(buf.len() * Self::WIDTH); + for packed in buf { + for lane in 0..Self::WIDTH { + out.push(packed.extract(lane)); + } + } + out + } +} + +/// Packed arithmetic over a scalar field. +pub trait PackedField: + PackedValue + Add + Sub + Mul +{ + /// Scalar field type. + type Scalar: FieldCore; + + /// Broadcast one scalar across all lanes. + fn broadcast(value: Self::Scalar) -> Self; +} + +/// Scalar fallback packed type with one lane. +#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)] +#[repr(transparent)] +pub struct NoPacking(pub [T; 1]); + +impl PackedValue for NoPacking +where + T: 'static + Copy + Send + Sync, +{ + type Value = T; + const WIDTH: usize = 1; + + #[inline] + fn from_fn(mut f: F) -> Self + where + F: FnMut(usize) -> Self::Value, + { + Self([f(0)]) + } + + #[inline] + fn extract(&self, lane: usize) -> Self::Value { + debug_assert_eq!(lane, 0); + self.0[0] + } +} + +impl Add for NoPacking { + type Output = Self; + #[inline] + fn add(self, rhs: Self) -> Self { + Self([self.0[0] + rhs.0[0]]) + } +} + +impl Sub for NoPacking { + type Output = Self; + #[inline] + fn sub(self, rhs: Self) -> Self { + Self([self.0[0] - rhs.0[0]]) + } +} + +impl Mul for NoPacking { + type Output = Self; + #[inline] + fn mul(self, rhs: Self) -> Self { + Self([self.0[0] * rhs.0[0]]) + } +} + +impl PackedField for NoPacking { + type Scalar = T; + + #[inline] + fn broadcast(value: Self::Scalar) -> Self { + Self([value]) + } +} + +/// Scalar field -> packed field association. +pub trait HasPacking: FieldCore { + /// Packed representation for this scalar field. + type Packing: PackedField; +} + +/// Selected packed backend for `Fp128`. +#[cfg(all(target_arch = "aarch64", target_feature = "neon"))] +pub type Fp128Packing = super::packed_neon::PackedFp128Neon

; + +/// Selected packed backend for `Fp128`. +#[cfg(all( + target_arch = "x86_64", + target_feature = "avx512f", + target_feature = "avx512dq" +))] +pub type Fp128Packing = super::packed_avx512::PackedFp128Avx512

; + +/// Selected packed backend for `Fp128`. +#[cfg(all( + target_arch = "x86_64", + target_feature = "avx2", + not(all(target_feature = "avx512f", target_feature = "avx512dq")) +))] +pub type Fp128Packing = super::packed_avx2::PackedFp128Avx2

; + +/// Selected packed backend for `Fp128`. +#[cfg(not(any( + all(target_arch = "aarch64", target_feature = "neon"), + all(target_arch = "x86_64", target_feature = "avx2") +)))] +pub type Fp128Packing = NoPacking>; + +impl HasPacking for Fp128

{ + type Packing = Fp128Packing

; +} + +/// Selected packed backend for `Fp32`. +#[cfg(all(target_arch = "aarch64", target_feature = "neon"))] +pub type Fp32Packing = super::packed_neon::PackedFp32Neon

; + +/// Selected packed backend for `Fp32`. +#[cfg(all( + target_arch = "x86_64", + target_feature = "avx512f", + target_feature = "avx512dq" +))] +pub type Fp32Packing = super::packed_avx512::PackedFp32Avx512

; + +/// Selected packed backend for `Fp32`. +#[cfg(all( + target_arch = "x86_64", + target_feature = "avx2", + not(all(target_feature = "avx512f", target_feature = "avx512dq")) +))] +pub type Fp32Packing = super::packed_avx2::PackedFp32Avx2

; + +/// Selected packed backend for `Fp32`. +#[cfg(not(any( + all(target_arch = "aarch64", target_feature = "neon"), + all(target_arch = "x86_64", target_feature = "avx2") +)))] +pub type Fp32Packing = NoPacking>; + +impl HasPacking for Fp32

{ + type Packing = Fp32Packing

; +} + +/// Selected packed backend for `Fp64`. +#[cfg(all(target_arch = "aarch64", target_feature = "neon"))] +pub type Fp64Packing = super::packed_neon::PackedFp64Neon

; + +/// Selected packed backend for `Fp64`. +#[cfg(all( + target_arch = "x86_64", + target_feature = "avx512f", + target_feature = "avx512dq" +))] +pub type Fp64Packing = super::packed_avx512::PackedFp64Avx512

; + +/// Selected packed backend for `Fp64`. +#[cfg(all( + target_arch = "x86_64", + target_feature = "avx2", + not(all(target_feature = "avx512f", target_feature = "avx512dq")) +))] +pub type Fp64Packing = super::packed_avx2::PackedFp64Avx2

; + +/// Selected packed backend for `Fp64`. +#[cfg(not(any( + all(target_arch = "aarch64", target_feature = "neon"), + all(target_arch = "x86_64", target_feature = "avx2") +)))] +pub type Fp64Packing = NoPacking>; + +impl HasPacking for Fp64

{ + type Packing = Fp64Packing

; +} + +#[cfg(test)] +mod tests { + use super::{HasPacking, PackedField, PackedValue}; + use crate::algebra::fields::{ + Pow2Offset24Field, Pow2Offset31Field, Pow2Offset32Field, Pow2Offset40Field, + Pow2Offset64Field, Prime128M13M4P0, + }; + use crate::{CanonicalField, FieldCore, FieldSampling, FromSmallInt}; + use rand::{rngs::StdRng, RngCore, SeedableRng}; + + fn rand_u128(rng: &mut R) -> u128 { + let lo = rng.next_u64() as u128; + let hi = rng.next_u64() as u128; + lo | (hi << 64) + } + + fn check_packed_add_sub_mul(seed: u64) + where + F: FieldCore + FieldSampling + PartialEq + std::fmt::Debug, + PF: PackedField + PackedValue, + { + let mut rng = StdRng::seed_from_u64(seed); + let len = PF::WIDTH * 17 + 3; + let lhs: Vec = (0..len).map(|_| FieldSampling::sample(&mut rng)).collect(); + let rhs: Vec = (0..len).map(|_| FieldSampling::sample(&mut rng)).collect(); + + let (lhs_p, lhs_s) = PF::pack_slice_with_suffix(&lhs); + let (rhs_p, rhs_s) = PF::pack_slice_with_suffix(&rhs); + + let add_p: Vec = lhs_p + .iter() + .zip(rhs_p.iter()) + .map(|(&a, &b)| a + b) + .collect(); + let sub_p: Vec = lhs_p + .iter() + .zip(rhs_p.iter()) + .map(|(&a, &b)| a - b) + .collect(); + let mul_p: Vec = lhs_p + .iter() + .zip(rhs_p.iter()) + .map(|(&a, &b)| a * b) + .collect(); + + let mut add_out = PF::unpack_slice(&add_p); + let mut sub_out = PF::unpack_slice(&sub_p); + let mut mul_out = PF::unpack_slice(&mul_p); + + for (&a, &b) in lhs_s.iter().zip(rhs_s.iter()) { + add_out.push(a + b); + sub_out.push(a - b); + mul_out.push(a * b); + } + + for i in 0..len { + assert_eq!( + add_out[i], + lhs[i] + rhs[i], + "packed add mismatch at lane {i}" + ); + assert_eq!( + sub_out[i], + lhs[i] - rhs[i], + "packed sub mismatch at lane {i}" + ); + assert_eq!( + mul_out[i], + lhs[i] * rhs[i], + "packed mul mismatch at lane {i}" + ); + } + } + + fn check_broadcast_roundtrip(val: F) + where + F: FieldCore + PartialEq + std::fmt::Debug, + PF: PackedField + PackedValue, + { + let p = PF::broadcast(val); + for lane in 0..PF::WIDTH { + assert_eq!(p.extract(lane), val); + } + } + + #[test] + fn packed_fp128_add_sub_mul_match_scalar() { + type F = Prime128M13M4P0; + type PF = ::Packing; + + let mut rng = StdRng::seed_from_u64(0x55aa_4422_1177_0033); + let len = PF::WIDTH * 17 + 3; + let lhs: Vec = (0..len) + .map(|_| F::from_canonical_u128_reduced(rand_u128(&mut rng))) + .collect(); + let rhs: Vec = (0..len) + .map(|_| F::from_canonical_u128_reduced(rand_u128(&mut rng))) + .collect(); + + let (lhs_p, lhs_s) = PF::pack_slice_with_suffix(&lhs); + let (rhs_p, rhs_s) = PF::pack_slice_with_suffix(&rhs); + + let add_p: Vec = lhs_p + .iter() + .zip(rhs_p.iter()) + .map(|(&a, &b)| a + b) + .collect(); + let sub_p: Vec = lhs_p + .iter() + .zip(rhs_p.iter()) + .map(|(&a, &b)| a - b) + .collect(); + let mul_p: Vec = lhs_p + .iter() + .zip(rhs_p.iter()) + .map(|(&a, &b)| a * b) + .collect(); + + let mut add_out = PF::unpack_slice(&add_p); + let mut sub_out = PF::unpack_slice(&sub_p); + let mut mul_out = PF::unpack_slice(&mul_p); + + for (&a, &b) in lhs_s.iter().zip(rhs_s.iter()) { + add_out.push(a + b); + sub_out.push(a - b); + mul_out.push(a * b); + } + + for i in 0..len { + assert_eq!( + add_out[i], + lhs[i] + rhs[i], + "packed add mismatch at lane {i}" + ); + assert_eq!( + sub_out[i], + lhs[i] - rhs[i], + "packed sub mismatch at lane {i}" + ); + assert_eq!( + mul_out[i], + lhs[i] * rhs[i], + "packed mul mismatch at lane {i}" + ); + } + } + + #[test] + fn fp128_broadcast_and_extract_roundtrip() { + type F = Prime128M13M4P0; + type PF = ::Packing; + check_broadcast_roundtrip::(F::from_u64(42)); + } + + #[test] + fn packed_fp32_24b_add_sub_mul() { + type F = Pow2Offset24Field; + type PF = ::Packing; + check_packed_add_sub_mul::(0xaa24_bb24_cc24_dd24); + } + + #[test] + fn packed_fp32_31b_add_sub_mul() { + type F = Pow2Offset31Field; + type PF = ::Packing; + check_packed_add_sub_mul::(0xaa31_bb31_cc31_dd31); + } + + #[test] + fn packed_fp32_32b_add_sub_mul() { + type F = Pow2Offset32Field; + type PF = ::Packing; + check_packed_add_sub_mul::(0xaa32_bb32_cc32_dd32); + } + + #[test] + fn fp32_broadcast_and_extract_roundtrip() { + type F = Pow2Offset24Field; + type PF = ::Packing; + check_broadcast_roundtrip::(F::from_u64(42)); + } + + #[test] + fn packed_fp64_40b_add_sub_mul() { + type F = Pow2Offset40Field; + type PF = ::Packing; + check_packed_add_sub_mul::(0xaa40_bb40_cc40_dd40); + } + + #[test] + fn packed_fp64_64b_add_sub_mul() { + type F = Pow2Offset64Field; + type PF = ::Packing; + check_packed_add_sub_mul::(0xaa64_bb64_cc64_dd64); + } + + #[test] + fn fp64_broadcast_and_extract_roundtrip() { + type F = Pow2Offset40Field; + type PF = ::Packing; + check_broadcast_roundtrip::(F::from_u64(42)); + } +} diff --git a/src/algebra/fields/packed_avx2.rs b/src/algebra/fields/packed_avx2.rs new file mode 100644 index 00000000..95468959 --- /dev/null +++ b/src/algebra/fields/packed_avx2.rs @@ -0,0 +1,704 @@ +//! AVX2 packed backends for Fp32, Fp64, Fp128. +//! +//! Techniques adapted from plonky2 (Goldilocks) and plonky3 (Mersenne-31). + +use super::packed::{PackedField, PackedValue}; +use crate::algebra::fields::{Fp128, Fp32, Fp64}; +use core::arch::x86_64::*; +use core::fmt; +use core::mem::transmute; +use core::ops::{Add, Mul, Sub}; + +/// Duplicate high 32 bits of each 64-bit lane into the low 32 bits. +/// Uses the float `movehdup` instruction which runs on port 5 (doesn't compete +/// with multiply on ports 0/1). +#[inline(always)] +unsafe fn movehdup_epi32(x: __m256i) -> __m256i { + _mm256_castps_si256(_mm256_movehdup_ps(_mm256_castsi256_ps(x))) +} + +#[inline(always)] +unsafe fn moveldup_epi32(x: __m256i) -> __m256i { + _mm256_castps_si256(_mm256_moveldup_ps(_mm256_castsi256_ps(x))) +} + +/// 64×64→128 schoolbook multiply using 32×32→64 partial products. +/// Returns (hi, lo) representing the 128-bit product. +#[inline] +unsafe fn mul64_64_256(x: __m256i, y: __m256i) -> (__m256i, __m256i) { + let x_hi = movehdup_epi32(x); + let y_hi = movehdup_epi32(y); + + let mul_ll = _mm256_mul_epu32(x, y); + let mul_lh = _mm256_mul_epu32(x, y_hi); + let mul_hl = _mm256_mul_epu32(x_hi, y); + let mul_hh = _mm256_mul_epu32(x_hi, y_hi); + + let mul_ll_hi = _mm256_srli_epi64::<32>(mul_ll); + let t0 = _mm256_add_epi64(mul_hl, mul_ll_hi); + let mask32 = _mm256_set1_epi64x(0xFFFF_FFFF_i64); + let t0_lo = _mm256_and_si256(t0, mask32); + let t0_hi = _mm256_srli_epi64::<32>(t0); + let t1 = _mm256_add_epi64(mul_lh, t0_lo); + let t2 = _mm256_add_epi64(mul_hh, t0_hi); + let t1_hi = _mm256_srli_epi64::<32>(t1); + let res_hi = _mm256_add_epi64(t2, t1_hi); + + let t1_lo = moveldup_epi32(t1); + let res_lo = _mm256_blend_epi32::<0b10101010>(mul_ll, t1_lo); + + (res_hi, res_lo) +} + +/// Number of `Fp32` lanes in an AVX2 packed vector. +pub const FP32_WIDTH: usize = 8; + +/// AVX2 packed arithmetic for `Fp32

`, processing 8 lanes. +#[derive(Clone, Copy)] +#[repr(transparent)] +pub struct PackedFp32Avx2(pub [Fp32

; FP32_WIDTH]); + +impl PackedFp32Avx2

{ + const BITS: u32 = 32 - P.leading_zeros(); + + const C: u32 = { + let c = if Self::BITS == 32 { + 0u32.wrapping_sub(P) + } else { + (1u32 << Self::BITS) - P + }; + assert!(P != 0, "modulus must be nonzero"); + assert!(P & 1 == 1, "modulus must be odd"); + assert!( + (c as u64) * (c as u64 + 1) < P as u64, + "C(C+1) < P required for fused canonicalize" + ); + c + }; + + const MASK_U64: u64 = if Self::BITS == 32 { + u32::MAX as u64 + } else { + (1u64 << Self::BITS) - 1 + }; + + #[inline(always)] + fn to_vec(self) -> __m256i { + unsafe { transmute(self) } + } + + #[inline(always)] + unsafe fn from_vec(v: __m256i) -> Self { + unsafe { transmute(v) } + } +} + +impl Default for PackedFp32Avx2

{ + #[inline] + fn default() -> Self { + Self([Fp32(0); FP32_WIDTH]) + } +} + +impl fmt::Debug for PackedFp32Avx2

{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_tuple("PackedFp32Avx2").field(&self.0).finish() + } +} + +impl PartialEq for PackedFp32Avx2

{ + #[inline] + fn eq(&self, other: &Self) -> bool { + self.0 == other.0 + } +} + +impl Eq for PackedFp32Avx2

{} + +impl Add for PackedFp32Avx2

{ + type Output = Self; + #[inline] + fn add(self, rhs: Self) -> Self { + unsafe { + let a = self.to_vec(); + let b = rhs.to_vec(); + let p = _mm256_set1_epi32(P as i32); + + let result = if Self::BITS <= 31 { + let t = _mm256_add_epi32(a, b); + let u = _mm256_sub_epi32(t, p); + _mm256_min_epu32(t, u) + } else { + let c = _mm256_set1_epi32(Self::C as i32); + let t = _mm256_add_epi32(a, b); + // Emulate unsigned compare: XOR with 0x80000000 converts u32 compare to i32 + let sign32 = _mm256_set1_epi32(i32::MIN); + let overflow = + _mm256_cmpgt_epi32(_mm256_xor_si256(a, sign32), _mm256_xor_si256(t, sign32)); + // Step 1: correct overflow (2^32 ≡ C mod P) + let t2 = _mm256_add_epi32(t, _mm256_and_si256(overflow, c)); + // Step 2: subtract P if t2 >= P + let r = _mm256_sub_epi32(t2, p); + _mm256_min_epu32(t2, r) + }; + + Self::from_vec(result) + } + } +} + +impl Sub for PackedFp32Avx2

{ + type Output = Self; + #[inline] + fn sub(self, rhs: Self) -> Self { + unsafe { + let a = self.to_vec(); + let b = rhs.to_vec(); + let p = _mm256_set1_epi32(P as i32); + + let result = if Self::BITS <= 31 { + let t = _mm256_sub_epi32(a, b); + let u = _mm256_add_epi32(t, p); + _mm256_min_epu32(t, u) + } else { + let t = _mm256_sub_epi32(a, b); + let sign32 = _mm256_set1_epi32(i32::MIN); + let underflow = + _mm256_cmpgt_epi32(_mm256_xor_si256(b, sign32), _mm256_xor_si256(a, sign32)); + let corrected = _mm256_add_epi32(t, p); + _mm256_blendv_epi8(t, corrected, underflow) + }; + + Self::from_vec(result) + } + } +} + +impl Mul for PackedFp32Avx2

{ + type Output = Self; + #[inline] + fn mul(self, rhs: Self) -> Self { + unsafe { + let a = self.to_vec(); + let b = rhs.to_vec(); + + let prod_evn = _mm256_mul_epu32(a, b); + let a_odd = movehdup_epi32(a); + let b_odd = movehdup_epi32(b); + let prod_odd = _mm256_mul_epu32(a_odd, b_odd); + + let mask = _mm256_set1_epi64x(Self::MASK_U64 as i64); + let c_vec = _mm256_set1_epi64x(Self::C as i64); + let shift = _mm_set_epi64x(0, Self::BITS as i64); + + // Fold 1 + let evn_lo = _mm256_and_si256(prod_evn, mask); + let evn_hi = _mm256_srl_epi64(prod_evn, shift); + let evn_f1 = _mm256_add_epi64(evn_lo, _mm256_mul_epu32(evn_hi, c_vec)); + + let odd_lo = _mm256_and_si256(prod_odd, mask); + let odd_hi = _mm256_srl_epi64(prod_odd, shift); + let odd_f1 = _mm256_add_epi64(odd_lo, _mm256_mul_epu32(odd_hi, c_vec)); + + // Fold 2 + let evn_f1_lo = _mm256_and_si256(evn_f1, mask); + let evn_f1_hi = _mm256_srl_epi64(evn_f1, shift); + let evn_f2 = _mm256_add_epi64(evn_f1_lo, _mm256_mul_epu32(evn_f1_hi, c_vec)); + + let odd_f1_lo = _mm256_and_si256(odd_f1, mask); + let odd_f1_hi = _mm256_srl_epi64(odd_f1, shift); + let odd_f2 = _mm256_add_epi64(odd_f1_lo, _mm256_mul_epu32(odd_f1_hi, c_vec)); + + // Recombine even/odd: shift odd results into high 32-bit positions, blend. + let odd_shifted = _mm256_slli_epi64::<32>(odd_f2); + let combined = _mm256_blend_epi32::<0b10101010>(evn_f2, odd_shifted); + + // Conditional subtract P + let p_vec = _mm256_set1_epi32(P as i32); + let reduced = _mm256_sub_epi32(combined, p_vec); + Self::from_vec(_mm256_min_epu32(combined, reduced)) + } + } +} + +impl PackedValue for PackedFp32Avx2

{ + type Value = Fp32

; + const WIDTH: usize = FP32_WIDTH; + + #[inline] + fn from_fn(mut f: F) -> Self + where + F: FnMut(usize) -> Self::Value, + { + Self([f(0), f(1), f(2), f(3), f(4), f(5), f(6), f(7)]) + } + + #[inline] + fn extract(&self, lane: usize) -> Self::Value { + debug_assert!(lane < FP32_WIDTH); + self.0[lane] + } +} + +impl PackedField for PackedFp32Avx2

{ + type Scalar = Fp32

; + + #[inline] + fn broadcast(value: Self::Scalar) -> Self { + Self([value; FP32_WIDTH]) + } +} + +/// Number of `Fp64` lanes in an AVX2 packed vector. +pub const FP64_WIDTH: usize = 4; + +/// AVX2 packed arithmetic for `Fp64

`, processing 4 lanes. +#[derive(Clone, Copy)] +#[repr(transparent)] +pub struct PackedFp64Avx2(pub [Fp64

; FP64_WIDTH]); + +impl PackedFp64Avx2

{ + const BITS: u32 = 64 - P.leading_zeros(); + + const C_LO: u64 = { + let c = if Self::BITS == 64 { + 0u64.wrapping_sub(P) + } else { + (1u64 << Self::BITS) - P + }; + assert!(P != 0, "modulus must be nonzero"); + assert!(P & 1 == 1, "modulus must be odd"); + c + }; + + const MASK64: u64 = if Self::BITS < 64 { + (1u64 << Self::BITS) - 1 + } else { + u64::MAX + }; + + #[inline(always)] + fn to_vec(self) -> __m256i { + unsafe { transmute(self) } + } + + #[inline(always)] + unsafe fn from_vec(v: __m256i) -> Self { + unsafe { transmute(v) } + } + + #[inline] + unsafe fn reduce128_vec(hi: __m256i, lo: __m256i) -> __m256i { + if Self::BITS < 64 { + Self::reduce128_small_k(hi, lo) + } else { + Self::reduce128_full_k(hi, lo) + } + } + + /// Reduction for BITS < 64. All intermediates fit in u64 — no overflow. + #[inline] + unsafe fn reduce128_small_k(hi: __m256i, lo: __m256i) -> __m256i { + let mask_k = _mm256_set1_epi64x(Self::MASK64 as i64); + let c_vec = _mm256_set1_epi64x(Self::C_LO as i64); + let p_vec = _mm256_set1_epi64x(P as i64); + let shift_k = _mm_set_epi64x(0, Self::BITS as i64); + let shift_64mk = _mm_set_epi64x(0, (64 - Self::BITS) as i64); + + let lo_k = _mm256_and_si256(lo, mask_k); + let lo_upper = _mm256_srl_epi64(lo, shift_k); + let hi_shifted = _mm256_sll_epi64(hi, shift_64mk); + let hi_k = _mm256_or_si256(lo_upper, hi_shifted); + + let c_hi_lo = _mm256_mul_epu32(c_vec, hi_k); + let hi_k_top = _mm256_srli_epi64::<32>(hi_k); + let c_hi_top = _mm256_mul_epu32(c_vec, hi_k_top); + let c_hi_top_shifted = _mm256_slli_epi64::<32>(c_hi_top); + let c_hi_full = _mm256_add_epi64(c_hi_lo, c_hi_top_shifted); + + let fold1 = _mm256_add_epi64(lo_k, c_hi_full); + + let fold1_lo_k = _mm256_and_si256(fold1, mask_k); + let fold1_hi = _mm256_srl_epi64(fold1, shift_k); + let c_fold1_hi = _mm256_mul_epu32(c_vec, fold1_hi); + let fold2 = _mm256_add_epi64(fold1_lo_k, c_fold1_hi); + + let reduced = _mm256_sub_epi64(fold2, p_vec); + let sign = _mm256_set1_epi64x(i64::MIN); + let fold2_s = _mm256_xor_si256(fold2, sign); + let reduced_s = _mm256_xor_si256(reduced, sign); + let fold2_lt = _mm256_cmpgt_epi64(reduced_s, fold2_s); + _mm256_blendv_epi8(reduced, fold2, fold2_lt) + } + + /// Reduction for BITS == 64. Uses XOR-with-SIGN_BIT trick for unsigned + /// overflow detection. + #[inline] + unsafe fn reduce128_full_k(hi: __m256i, lo: __m256i) -> __m256i { + let c_vec = _mm256_set1_epi64x(Self::C_LO as i64); + let p_vec = _mm256_set1_epi64x(P as i64); + let sign = _mm256_set1_epi64x(i64::MIN); + let c_hi_lo = _mm256_mul_epu32(c_vec, hi); + let hi_hi = _mm256_srli_epi64::<32>(hi); + let c_hi_hi = _mm256_mul_epu32(c_vec, hi_hi); + + let c_hi_hi_lo32 = _mm256_slli_epi64::<32>(c_hi_hi); + let c_hi_carry = _mm256_srli_epi64::<32>(c_hi_hi); + + let sum_lo = _mm256_add_epi64(c_hi_lo, c_hi_hi_lo32); + let c_hi_lo_s = _mm256_xor_si256(c_hi_lo, sign); + let sum_lo_s = _mm256_xor_si256(sum_lo, sign); + let carry0 = _mm256_cmpgt_epi64(c_hi_lo_s, sum_lo_s); + let overflow = _mm256_sub_epi64(c_hi_carry, carry0); + + let s = _mm256_add_epi64(lo, sum_lo); + let lo_s = _mm256_xor_si256(lo, sign); + let s_s = _mm256_xor_si256(s, sign); + let carry1 = _mm256_cmpgt_epi64(lo_s, s_s); + let total_overflow = _mm256_sub_epi64(overflow, carry1); + + let final_corr = _mm256_mul_epu32(c_vec, total_overflow); + let result = _mm256_add_epi64(s, final_corr); + let s2_s = _mm256_xor_si256(s, sign); + let result_s = _mm256_xor_si256(result, sign); + let carry_f = _mm256_cmpgt_epi64(s2_s, result_s); + let corr_f = _mm256_and_si256(carry_f, c_vec); + let result = _mm256_add_epi64(result, corr_f); + + let result_s2 = _mm256_xor_si256(result, sign); + let p_s = _mm256_xor_si256(p_vec, sign); + let lt_p = _mm256_cmpgt_epi64(p_s, result_s2); + let sub_amt = _mm256_andnot_si256(lt_p, p_vec); + _mm256_sub_epi64(result, sub_amt) + } +} + +impl Default for PackedFp64Avx2

{ + #[inline] + fn default() -> Self { + Self([Fp64(0); FP64_WIDTH]) + } +} + +impl fmt::Debug for PackedFp64Avx2

{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_tuple("PackedFp64Avx2").field(&self.0).finish() + } +} + +impl PartialEq for PackedFp64Avx2

{ + #[inline] + fn eq(&self, other: &Self) -> bool { + self.0 == other.0 + } +} + +impl Eq for PackedFp64Avx2

{} + +impl Add for PackedFp64Avx2

{ + type Output = Self; + #[inline] + fn add(self, rhs: Self) -> Self { + unsafe { + let a = self.to_vec(); + let b = rhs.to_vec(); + let p = _mm256_set1_epi64x(P as i64); + + let result = if Self::BITS <= 62 { + // a + b < 2P < 2^63: no overflow. + let s = _mm256_add_epi64(a, b); + let r = _mm256_sub_epi64(s, p); + // s < P? Use signed compare after shift trick. + let sign = _mm256_set1_epi64x(i64::MIN); + let s_s = _mm256_xor_si256(s, sign); + let p_s = _mm256_xor_si256(p, sign); + let borrow = _mm256_cmpgt_epi64(p_s, s_s); + _mm256_blendv_epi8(r, s, borrow) + } else { + // a + b can overflow u64. + let s = _mm256_add_epi64(a, b); + let sign = _mm256_set1_epi64x(i64::MIN); + let a_s = _mm256_xor_si256(a, sign); + let s_s = _mm256_xor_si256(s, sign); + let overflow = _mm256_cmpgt_epi64(a_s, s_s); + let c = _mm256_set1_epi64x(Self::C_LO as i64); + let s_plus_c = _mm256_add_epi64(s, c); + let s_minus_p = _mm256_sub_epi64(s, p); + let p_s = _mm256_xor_si256(p, sign); + let lt_p = _mm256_cmpgt_epi64(p_s, s_s); + let no_of = _mm256_blendv_epi8(s_minus_p, s, lt_p); + _mm256_blendv_epi8(no_of, s_plus_c, overflow) + }; + + Self::from_vec(result) + } + } +} + +impl Sub for PackedFp64Avx2

{ + type Output = Self; + #[inline] + fn sub(self, rhs: Self) -> Self { + unsafe { + let a = self.to_vec(); + let b = rhs.to_vec(); + let p = _mm256_set1_epi64x(P as i64); + let d = _mm256_sub_epi64(a, b); + + let sign = _mm256_set1_epi64x(i64::MIN); + let a_s = _mm256_xor_si256(a, sign); + let b_s = _mm256_xor_si256(b, sign); + let underflow = _mm256_cmpgt_epi64(b_s, a_s); + let corrected = _mm256_add_epi64(d, p); + Self::from_vec(_mm256_blendv_epi8(d, corrected, underflow)) + } + } +} + +impl Mul for PackedFp64Avx2

{ + type Output = Self; + #[inline] + fn mul(self, rhs: Self) -> Self { + unsafe { + let (hi, lo) = mul64_64_256(self.to_vec(), rhs.to_vec()); + Self::from_vec(Self::reduce128_vec(hi, lo)) + } + } +} + +impl PackedValue for PackedFp64Avx2

{ + type Value = Fp64

; + const WIDTH: usize = FP64_WIDTH; + + #[inline] + fn from_fn(mut f: F) -> Self + where + F: FnMut(usize) -> Self::Value, + { + Self([f(0), f(1), f(2), f(3)]) + } + + #[inline] + fn extract(&self, lane: usize) -> Self::Value { + debug_assert!(lane < FP64_WIDTH); + self.0[lane] + } +} + +impl PackedField for PackedFp64Avx2

{ + type Scalar = Fp64

; + + #[inline] + fn broadcast(value: Self::Scalar) -> Self { + Self([value; FP64_WIDTH]) + } +} + +/// Number of `Fp128` lanes in an AVX2 packed vector. +pub const FP128_WIDTH: usize = 4; + +/// AVX2 packed arithmetic for `Fp128

`, 4 lanes in SoA layout. +/// +/// Stores 4 elements as separate `lo` and `hi` `u64` arrays, enabling +/// vectorized add/sub via `__m256i`. Mul remains scalar per-lane. +#[derive(Clone, Copy)] +pub struct PackedFp128Avx2 { + lo: [u64; FP128_WIDTH], + hi: [u64; FP128_WIDTH], +} + +impl PackedFp128Avx2

{ + const P_LO: u64 = P as u64; + const P_HI: u64 = (P >> 64) as u64; +} + +impl Default for PackedFp128Avx2

{ + #[inline] + fn default() -> Self { + Self { + lo: [0; FP128_WIDTH], + hi: [0; FP128_WIDTH], + } + } +} + +impl fmt::Debug for PackedFp128Avx2

{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let elems: Vec<_> = (0..FP128_WIDTH).map(|i| self.extract(i)).collect(); + f.debug_tuple("PackedFp128Avx2").field(&elems).finish() + } +} + +impl PartialEq for PackedFp128Avx2

{ + #[inline] + fn eq(&self, other: &Self) -> bool { + self.lo == other.lo && self.hi == other.hi + } +} + +impl Eq for PackedFp128Avx2

{} + +impl Add for PackedFp128Avx2

{ + type Output = Self; + #[inline] + fn add(self, rhs: Self) -> Self { + unsafe { + let a_lo = _mm256_loadu_si256(self.lo.as_ptr().cast()); + let a_hi = _mm256_loadu_si256(self.hi.as_ptr().cast()); + let b_lo = _mm256_loadu_si256(rhs.lo.as_ptr().cast()); + let b_hi = _mm256_loadu_si256(rhs.hi.as_ptr().cast()); + let p_lo = _mm256_set1_epi64x(Self::P_LO as i64); + let p_hi = _mm256_set1_epi64x(Self::P_HI as i64); + let sign = _mm256_set1_epi64x(i64::MIN); + let one = _mm256_set1_epi64x(1); + + // 128-bit add with unsigned compare emulation (XOR sign bit) + let sum_lo = _mm256_add_epi64(a_lo, b_lo); + let carry_lo = + _mm256_cmpgt_epi64(_mm256_xor_si256(a_lo, sign), _mm256_xor_si256(sum_lo, sign)); + let carry_lo_bit = _mm256_and_si256(carry_lo, one); + + let hi_tmp = _mm256_add_epi64(a_hi, b_hi); + let ov1 = + _mm256_cmpgt_epi64(_mm256_xor_si256(a_hi, sign), _mm256_xor_si256(hi_tmp, sign)); + let sum_hi = _mm256_add_epi64(hi_tmp, carry_lo_bit); + let ov2 = _mm256_cmpgt_epi64( + _mm256_xor_si256(hi_tmp, sign), + _mm256_xor_si256(sum_hi, sign), + ); + let carry_128 = _mm256_or_si256(ov1, ov2); + + // 128-bit subtract P + let red_lo = _mm256_sub_epi64(sum_lo, p_lo); + let borrow_lo = + _mm256_cmpgt_epi64(_mm256_xor_si256(p_lo, sign), _mm256_xor_si256(sum_lo, sign)); + let borrow_lo_bit = _mm256_and_si256(borrow_lo, one); + + let red_hi_tmp = _mm256_sub_epi64(sum_hi, p_hi); + let bw1 = + _mm256_cmpgt_epi64(_mm256_xor_si256(p_hi, sign), _mm256_xor_si256(sum_hi, sign)); + let red_hi = _mm256_sub_epi64(red_hi_tmp, borrow_lo_bit); + let bw2 = _mm256_cmpgt_epi64( + _mm256_xor_si256(borrow_lo_bit, sign), + _mm256_xor_si256(red_hi_tmp, sign), + ); + let borrow = _mm256_or_si256(bw1, bw2); + + // use_reduced = carry_128 | !borrow + let not_borrow = _mm256_xor_si256(borrow, _mm256_set1_epi64x(-1)); + let use_reduced = _mm256_or_si256(carry_128, not_borrow); + let out_lo = _mm256_blendv_epi8(sum_lo, red_lo, use_reduced); + let out_hi = _mm256_blendv_epi8(sum_hi, red_hi, use_reduced); + + let mut result = Self::default(); + _mm256_storeu_si256(result.lo.as_mut_ptr().cast(), out_lo); + _mm256_storeu_si256(result.hi.as_mut_ptr().cast(), out_hi); + result + } + } +} + +impl Sub for PackedFp128Avx2

{ + type Output = Self; + #[inline] + fn sub(self, rhs: Self) -> Self { + unsafe { + let a_lo = _mm256_loadu_si256(self.lo.as_ptr().cast()); + let a_hi = _mm256_loadu_si256(self.hi.as_ptr().cast()); + let b_lo = _mm256_loadu_si256(rhs.lo.as_ptr().cast()); + let b_hi = _mm256_loadu_si256(rhs.hi.as_ptr().cast()); + let p_lo = _mm256_set1_epi64x(Self::P_LO as i64); + let p_hi = _mm256_set1_epi64x(Self::P_HI as i64); + let sign = _mm256_set1_epi64x(i64::MIN); + let one = _mm256_set1_epi64x(1); + + // 128-bit sub + let diff_lo = _mm256_sub_epi64(a_lo, b_lo); + let borrow_lo = + _mm256_cmpgt_epi64(_mm256_xor_si256(b_lo, sign), _mm256_xor_si256(a_lo, sign)); + let borrow_lo_bit = _mm256_and_si256(borrow_lo, one); + + let hi_tmp = _mm256_sub_epi64(a_hi, b_hi); + let bw1 = + _mm256_cmpgt_epi64(_mm256_xor_si256(b_hi, sign), _mm256_xor_si256(a_hi, sign)); + let diff_hi = _mm256_sub_epi64(hi_tmp, borrow_lo_bit); + let bw2 = _mm256_cmpgt_epi64( + _mm256_xor_si256(borrow_lo_bit, sign), + _mm256_xor_si256(hi_tmp, sign), + ); + let borrow_128 = _mm256_or_si256(bw1, bw2); + + // Correction: add P back where underflow occurred + let corr_lo = _mm256_add_epi64(diff_lo, p_lo); + let carry_lo = _mm256_cmpgt_epi64( + _mm256_xor_si256(diff_lo, sign), + _mm256_xor_si256(corr_lo, sign), + ); + let carry_lo_bit = _mm256_and_si256(carry_lo, one); + let corr_hi = _mm256_add_epi64(diff_hi, p_hi); + let corr_hi = _mm256_add_epi64(corr_hi, carry_lo_bit); + + let out_lo = _mm256_blendv_epi8(diff_lo, corr_lo, borrow_128); + let out_hi = _mm256_blendv_epi8(diff_hi, corr_hi, borrow_128); + + let mut result = Self::default(); + _mm256_storeu_si256(result.lo.as_mut_ptr().cast(), out_lo); + _mm256_storeu_si256(result.hi.as_mut_ptr().cast(), out_hi); + result + } + } +} + +impl Mul for PackedFp128Avx2

{ + type Output = Self; + #[inline] + fn mul(self, rhs: Self) -> Self { + let mut out = Self::default(); + for i in 0..FP128_WIDTH { + let a = Fp128::

([self.lo[i], self.hi[i]]); + let b = Fp128::

([rhs.lo[i], rhs.hi[i]]); + let r = a * b; + out.lo[i] = r.0[0]; + out.hi[i] = r.0[1]; + } + out + } +} + +impl PackedValue for PackedFp128Avx2

{ + type Value = Fp128

; + const WIDTH: usize = FP128_WIDTH; + + #[inline] + fn from_fn(mut f: F) -> Self + where + F: FnMut(usize) -> Self::Value, + { + let mut lo = [0u64; FP128_WIDTH]; + let mut hi = [0u64; FP128_WIDTH]; + for i in 0..FP128_WIDTH { + let v = f(i); + lo[i] = v.0[0]; + hi[i] = v.0[1]; + } + Self { lo, hi } + } + + #[inline] + fn extract(&self, lane: usize) -> Self::Value { + debug_assert!(lane < FP128_WIDTH); + Fp128([self.lo[lane], self.hi[lane]]) + } +} + +impl PackedField for PackedFp128Avx2

{ + type Scalar = Fp128

; + + #[inline] + fn broadcast(value: Self::Scalar) -> Self { + Self { + lo: [value.0[0]; FP128_WIDTH], + hi: [value.0[1]; FP128_WIDTH], + } + } +} diff --git a/src/algebra/fields/packed_avx512.rs b/src/algebra/fields/packed_avx512.rs new file mode 100644 index 00000000..1102ac95 --- /dev/null +++ b/src/algebra/fields/packed_avx512.rs @@ -0,0 +1,666 @@ +//! AVX-512 packed backends for Fp32, Fp64, Fp128. +//! +//! Requires AVX-512F + AVX-512DQ. Uses native unsigned comparisons and mask +//! registers for branchless conditionals. + +use super::packed::{PackedField, PackedValue}; +use crate::algebra::fields::{Fp128, Fp32, Fp64}; +use crate::FieldCore; +use core::arch::x86_64::*; +use core::fmt; +use core::mem::transmute; +use core::ops::{Add, Mul, Sub}; + +#[inline(always)] +unsafe fn movehdup_epi32_512(x: __m512i) -> __m512i { + _mm512_castps_si512(_mm512_movehdup_ps(_mm512_castsi512_ps(x))) +} + +#[inline(always)] +unsafe fn moveldup_epi32_512(x: __m512i) -> __m512i { + _mm512_castps_si512(_mm512_moveldup_ps(_mm512_castsi512_ps(x))) +} + +/// 64×64→128 schoolbook multiply using 32×32→64 partial products. +/// Returns (hi, lo) representing the 128-bit product. +/// Adapted from plonky3's Goldilocks AVX-512 backend. +#[inline] +unsafe fn mul64_64_512(x: __m512i, y: __m512i) -> (__m512i, __m512i) { + let x_hi = movehdup_epi32_512(x); + let y_hi = movehdup_epi32_512(y); + + let mul_ll = _mm512_mul_epu32(x, y); + let mul_lh = _mm512_mul_epu32(x, y_hi); + let mul_hl = _mm512_mul_epu32(x_hi, y); + let mul_hh = _mm512_mul_epu32(x_hi, y_hi); + + let mul_ll_hi = _mm512_srli_epi64::<32>(mul_ll); + let t0 = _mm512_add_epi64(mul_hl, mul_ll_hi); + let mask32 = _mm512_set1_epi64(0xFFFF_FFFF_i64); + let t0_lo = _mm512_and_si512(t0, mask32); + let t0_hi = _mm512_srli_epi64::<32>(t0); + let t1 = _mm512_add_epi64(mul_lh, t0_lo); + let t2 = _mm512_add_epi64(mul_hh, t0_hi); + let t1_hi = _mm512_srli_epi64::<32>(t1); + let res_hi = _mm512_add_epi64(t2, t1_hi); + + let t1_lo = moveldup_epi32_512(t1); + let res_lo = _mm512_mask_blend_epi32(0b0101_0101_0101_0101, t1_lo, mul_ll); + + (res_hi, res_lo) +} + +/// Number of `Fp32` lanes in an AVX-512 packed vector. +pub const FP32_WIDTH: usize = 16; + +/// AVX-512 packed arithmetic for `Fp32

`, processing 16 lanes. +#[derive(Clone, Copy)] +#[repr(transparent)] +pub struct PackedFp32Avx512(pub [Fp32

; FP32_WIDTH]); + +impl PackedFp32Avx512

{ + const BITS: u32 = 32 - P.leading_zeros(); + + const C: u32 = { + let c = if Self::BITS == 32 { + 0u32.wrapping_sub(P) + } else { + (1u32 << Self::BITS) - P + }; + assert!(P != 0, "modulus must be nonzero"); + assert!(P & 1 == 1, "modulus must be odd"); + assert!( + (c as u64) * (c as u64 + 1) < P as u64, + "C(C+1) < P required for fused canonicalize" + ); + c + }; + + const MASK_U64: u64 = if Self::BITS == 32 { + u32::MAX as u64 + } else { + (1u64 << Self::BITS) - 1 + }; + + #[inline(always)] + fn to_vec(self) -> __m512i { + unsafe { transmute(self) } + } + + #[inline(always)] + unsafe fn from_vec(v: __m512i) -> Self { + unsafe { transmute(v) } + } +} + +impl Default for PackedFp32Avx512

{ + #[inline] + fn default() -> Self { + Self([Fp32(0); FP32_WIDTH]) + } +} + +impl fmt::Debug for PackedFp32Avx512

{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_tuple("PackedFp32Avx512").field(&self.0).finish() + } +} + +impl PartialEq for PackedFp32Avx512

{ + #[inline] + fn eq(&self, other: &Self) -> bool { + self.0 == other.0 + } +} + +impl Eq for PackedFp32Avx512

{} + +impl Add for PackedFp32Avx512

{ + type Output = Self; + #[inline] + fn add(self, rhs: Self) -> Self { + unsafe { + let a = self.to_vec(); + let b = rhs.to_vec(); + let p = _mm512_set1_epi32(P as i32); + + let result = if Self::BITS <= 31 { + let t = _mm512_add_epi32(a, b); + let u = _mm512_sub_epi32(t, p); + _mm512_min_epu32(t, u) + } else { + let c = _mm512_set1_epi32(Self::C as i32); + let t = _mm512_add_epi32(a, b); + // Step 1: correct overflow (2^32 ≡ C mod P) + let overflow = _mm512_cmplt_epu32_mask(t, a); + let t2 = _mm512_mask_add_epi32(t, overflow, t, c); + // Step 2: subtract P if t2 >= P + let geq_p = _mm512_cmpge_epu32_mask(t2, p); + _mm512_mask_sub_epi32(t2, geq_p, t2, p) + }; + + Self::from_vec(result) + } + } +} + +impl Sub for PackedFp32Avx512

{ + type Output = Self; + #[inline] + fn sub(self, rhs: Self) -> Self { + unsafe { + let a = self.to_vec(); + let b = rhs.to_vec(); + let p = _mm512_set1_epi32(P as i32); + + let result = if Self::BITS <= 31 { + let t = _mm512_sub_epi32(a, b); + let u = _mm512_add_epi32(t, p); + _mm512_min_epu32(t, u) + } else { + let t = _mm512_sub_epi32(a, b); + let underflow = _mm512_cmplt_epu32_mask(a, b); + _mm512_mask_add_epi32(t, underflow, t, p) + }; + + Self::from_vec(result) + } + } +} + +impl Mul for PackedFp32Avx512

{ + type Output = Self; + #[inline] + fn mul(self, rhs: Self) -> Self { + unsafe { + let a = self.to_vec(); + let b = rhs.to_vec(); + + let prod_evn = _mm512_mul_epu32(a, b); + let a_odd = movehdup_epi32_512(a); + let b_odd = movehdup_epi32_512(b); + let prod_odd = _mm512_mul_epu32(a_odd, b_odd); + + let mask = _mm512_set1_epi64(Self::MASK_U64 as i64); + let c_vec = _mm512_set1_epi64(Self::C as i64); + let shift = _mm_set_epi64x(0, Self::BITS as i64); + + // Fold 1 + let evn_lo = _mm512_and_si512(prod_evn, mask); + let evn_hi = _mm512_srl_epi64(prod_evn, shift); + let evn_f1 = _mm512_add_epi64(evn_lo, _mm512_mul_epu32(evn_hi, c_vec)); + + let odd_lo = _mm512_and_si512(prod_odd, mask); + let odd_hi = _mm512_srl_epi64(prod_odd, shift); + let odd_f1 = _mm512_add_epi64(odd_lo, _mm512_mul_epu32(odd_hi, c_vec)); + + // Fold 2 + let evn_f1_lo = _mm512_and_si512(evn_f1, mask); + let evn_f1_hi = _mm512_srl_epi64(evn_f1, shift); + let evn_f2 = _mm512_add_epi64(evn_f1_lo, _mm512_mul_epu32(evn_f1_hi, c_vec)); + + let odd_f1_lo = _mm512_and_si512(odd_f1, mask); + let odd_f1_hi = _mm512_srl_epi64(odd_f1, shift); + let odd_f2 = _mm512_add_epi64(odd_f1_lo, _mm512_mul_epu32(odd_f1_hi, c_vec)); + + // Recombine even/odd + let odd_shifted = _mm512_slli_epi64::<32>(odd_f2); + let combined = _mm512_mask_blend_epi32(0b1010101010101010, evn_f2, odd_shifted); + + // Conditional subtract P + let p_vec = _mm512_set1_epi32(P as i32); + let reduced = _mm512_sub_epi32(combined, p_vec); + Self::from_vec(_mm512_min_epu32(combined, reduced)) + } + } +} + +impl PackedValue for PackedFp32Avx512

{ + type Value = Fp32

; + const WIDTH: usize = FP32_WIDTH; + + #[inline] + fn from_fn(mut f: F) -> Self + where + F: FnMut(usize) -> Self::Value, + { + Self([ + f(0), + f(1), + f(2), + f(3), + f(4), + f(5), + f(6), + f(7), + f(8), + f(9), + f(10), + f(11), + f(12), + f(13), + f(14), + f(15), + ]) + } + + #[inline] + fn extract(&self, lane: usize) -> Self::Value { + debug_assert!(lane < FP32_WIDTH); + self.0[lane] + } +} + +impl PackedField for PackedFp32Avx512

{ + type Scalar = Fp32

; + + #[inline] + fn broadcast(value: Self::Scalar) -> Self { + Self([value; FP32_WIDTH]) + } +} + +/// Number of `Fp64` lanes in an AVX-512 packed vector. +pub const FP64_WIDTH: usize = 8; + +/// AVX-512 packed arithmetic for `Fp64

`, processing 8 lanes. +#[derive(Clone, Copy)] +#[repr(transparent)] +pub struct PackedFp64Avx512(pub [Fp64

; FP64_WIDTH]); + +impl PackedFp64Avx512

{ + const BITS: u32 = 64 - P.leading_zeros(); + + const C_LO: u64 = { + let c = if Self::BITS == 64 { + 0u64.wrapping_sub(P) + } else { + (1u64 << Self::BITS) - P + }; + assert!(P != 0, "modulus must be nonzero"); + assert!(P & 1 == 1, "modulus must be odd"); + c + }; + + const MASK64: u64 = if Self::BITS < 64 { + (1u64 << Self::BITS) - 1 + } else { + u64::MAX + }; + + #[inline(always)] + fn to_vec(self) -> __m512i { + unsafe { transmute(self) } + } + + #[inline(always)] + unsafe fn from_vec(v: __m512i) -> Self { + unsafe { transmute(v) } + } + + /// Vectorized 128-bit Solinas reduction for p = 2^BITS - C. + /// Given (hi, lo) = 128-bit product, computes result ≡ (hi*2^64 + lo) mod p. + #[inline] + unsafe fn reduce128_vec(hi: __m512i, lo: __m512i) -> __m512i { + if Self::BITS < 64 { + Self::reduce128_small_k(hi, lo) + } else { + Self::reduce128_full_k(hi, lo) + } + } + + /// Reduction for BITS < 64 (e.g. 40-bit prime). No overflow issues: all + /// intermediates fit in u64. + #[inline] + unsafe fn reduce128_small_k(hi: __m512i, lo: __m512i) -> __m512i { + let mask_k = _mm512_set1_epi64(Self::MASK64 as i64); + let c_vec = _mm512_set1_epi64(Self::C_LO as i64); + let p_vec = _mm512_set1_epi64(P as i64); + let shift_k = _mm_set_epi64x(0, Self::BITS as i64); + let shift_64mk = _mm_set_epi64x(0, (64 - Self::BITS) as i64); + + let lo_k = _mm512_and_si512(lo, mask_k); + let lo_upper = _mm512_srl_epi64(lo, shift_k); + let hi_shifted = _mm512_sll_epi64(hi, shift_64mk); + let hi_k = _mm512_or_si512(lo_upper, hi_shifted); + + // c * hi_k: hi_k may exceed 32 bits, split into lo32 and top + let c_hi_lo = _mm512_mul_epu32(c_vec, hi_k); + let hi_k_top = _mm512_srli_epi64::<32>(hi_k); + let c_hi_top = _mm512_mul_epu32(c_vec, hi_k_top); + let c_hi_top_shifted = _mm512_slli_epi64::<32>(c_hi_top); + let c_hi_full = _mm512_add_epi64(c_hi_lo, c_hi_top_shifted); + + let fold1 = _mm512_add_epi64(lo_k, c_hi_full); + + let fold1_lo_k = _mm512_and_si512(fold1, mask_k); + let fold1_hi = _mm512_srl_epi64(fold1, shift_k); + let c_fold1_hi = _mm512_mul_epu32(c_vec, fold1_hi); + let fold2 = _mm512_add_epi64(fold1_lo_k, c_fold1_hi); + + let reduced = _mm512_sub_epi64(fold2, p_vec); + _mm512_min_epu64(fold2, reduced) + } + + /// Reduction for BITS == 64 (e.g. p = 2^64 - 87). Tracks overflow from + /// c*hi exceeding 64 bits, using native unsigned comparisons. + #[inline] + unsafe fn reduce128_full_k(hi: __m512i, lo: __m512i) -> __m512i { + let c_vec = _mm512_set1_epi64(Self::C_LO as i64); + let p_vec = _mm512_set1_epi64(P as i64); + let one = _mm512_set1_epi64(1); + + // c * hi_lo32 + let c_hi_lo = _mm512_mul_epu32(c_vec, hi); + // c * hi_hi32 + let hi_hi = _mm512_srli_epi64::<32>(hi); + let c_hi_hi = _mm512_mul_epu32(c_vec, hi_hi); + + let c_hi_hi_lo32 = _mm512_slli_epi64::<32>(c_hi_hi); + let c_hi_carry = _mm512_srli_epi64::<32>(c_hi_hi); + + // Lower 64 bits of c * hi + let sum_lo = _mm512_add_epi64(c_hi_lo, c_hi_hi_lo32); + let carry0 = _mm512_cmplt_epu64_mask(sum_lo, c_hi_lo); + let overflow = _mm512_mask_add_epi64(c_hi_carry, carry0, c_hi_carry, one); + + // lo + sum_lo + let s = _mm512_add_epi64(lo, sum_lo); + let carry1 = _mm512_cmplt_epu64_mask(s, lo); + let total_overflow = _mm512_mask_add_epi64(overflow, carry1, overflow, one); + + // Fold overflow: total_overflow * c (at most ~2^15) + let final_corr = _mm512_mul_epu32(c_vec, total_overflow); + let result = _mm512_add_epi64(s, final_corr); + let carry_f = _mm512_cmplt_epu64_mask(result, s); + let result = _mm512_mask_add_epi64(result, carry_f, result, c_vec); + + let ge_mask = _mm512_cmpge_epu64_mask(result, p_vec); + _mm512_mask_sub_epi64(result, ge_mask, result, p_vec) + } +} + +impl Default for PackedFp64Avx512

{ + #[inline] + fn default() -> Self { + Self([Fp64(0); FP64_WIDTH]) + } +} + +impl fmt::Debug for PackedFp64Avx512

{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_tuple("PackedFp64Avx512").field(&self.0).finish() + } +} + +impl PartialEq for PackedFp64Avx512

{ + #[inline] + fn eq(&self, other: &Self) -> bool { + self.0 == other.0 + } +} + +impl Eq for PackedFp64Avx512

{} + +impl Add for PackedFp64Avx512

{ + type Output = Self; + #[inline] + fn add(self, rhs: Self) -> Self { + unsafe { + let a = self.to_vec(); + let b = rhs.to_vec(); + let p = _mm512_set1_epi64(P as i64); + + let result = if Self::BITS <= 62 { + let s = _mm512_add_epi64(a, b); + let geq_p = _mm512_cmpge_epu64_mask(s, p); + _mm512_mask_sub_epi64(s, geq_p, s, p) + } else { + let s = _mm512_add_epi64(a, b); + let overflow = _mm512_cmplt_epu64_mask(s, a); + let c = _mm512_set1_epi64(Self::C_LO as i64); + let geq_p = _mm512_cmpge_epu64_mask(s, p); + let no_of = _mm512_mask_sub_epi64(s, geq_p, s, p); + let s_plus_c = _mm512_add_epi64(s, c); + _mm512_mask_blend_epi64(overflow, no_of, s_plus_c) + }; + + Self::from_vec(result) + } + } +} + +impl Sub for PackedFp64Avx512

{ + type Output = Self; + #[inline] + fn sub(self, rhs: Self) -> Self { + unsafe { + let a = self.to_vec(); + let b = rhs.to_vec(); + let p = _mm512_set1_epi64(P as i64); + let d = _mm512_sub_epi64(a, b); + let underflow = _mm512_cmplt_epu64_mask(a, b); + Self::from_vec(_mm512_mask_add_epi64(d, underflow, d, p)) + } + } +} + +impl Mul for PackedFp64Avx512

{ + type Output = Self; + #[inline] + fn mul(self, rhs: Self) -> Self { + unsafe { + let (hi, lo) = mul64_64_512(self.to_vec(), rhs.to_vec()); + Self::from_vec(Self::reduce128_vec(hi, lo)) + } + } +} + +impl PackedValue for PackedFp64Avx512

{ + type Value = Fp64

; + const WIDTH: usize = FP64_WIDTH; + + #[inline] + fn from_fn(mut f: F) -> Self + where + F: FnMut(usize) -> Self::Value, + { + Self([f(0), f(1), f(2), f(3), f(4), f(5), f(6), f(7)]) + } + + #[inline] + fn extract(&self, lane: usize) -> Self::Value { + debug_assert!(lane < FP64_WIDTH); + self.0[lane] + } +} + +impl PackedField for PackedFp64Avx512

{ + type Scalar = Fp64

; + + #[inline] + fn broadcast(value: Self::Scalar) -> Self { + Self([value; FP64_WIDTH]) + } +} + +/// Number of `Fp128` lanes in an AVX-512 packed vector. +pub const FP128_WIDTH: usize = 8; + +/// AVX-512 packed arithmetic for `Fp128

`, 8 lanes in SoA layout. +/// +/// Stores 8 elements as separate `lo` and `hi` `u64` arrays, enabling +/// vectorized add/sub via `__m512i`. Mul remains scalar per-lane. +#[derive(Clone, Copy)] +pub struct PackedFp128Avx512 { + lo: [u64; FP128_WIDTH], + hi: [u64; FP128_WIDTH], +} + +impl PackedFp128Avx512

{ + const P_LO: u64 = P as u64; + const P_HI: u64 = (P >> 64) as u64; +} + +impl Default for PackedFp128Avx512

{ + #[inline] + fn default() -> Self { + Self { + lo: [0; FP128_WIDTH], + hi: [0; FP128_WIDTH], + } + } +} + +impl fmt::Debug for PackedFp128Avx512

{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let elems: Vec<_> = (0..FP128_WIDTH).map(|i| self.extract(i)).collect(); + f.debug_tuple("PackedFp128Avx512").field(&elems).finish() + } +} + +impl PartialEq for PackedFp128Avx512

{ + #[inline] + fn eq(&self, other: &Self) -> bool { + self.lo == other.lo && self.hi == other.hi + } +} + +impl Eq for PackedFp128Avx512

{} + +impl Add for PackedFp128Avx512

{ + type Output = Self; + #[inline] + fn add(self, rhs: Self) -> Self { + unsafe { + let a_lo = _mm512_loadu_si512(self.lo.as_ptr().cast()); + let a_hi = _mm512_loadu_si512(self.hi.as_ptr().cast()); + let b_lo = _mm512_loadu_si512(rhs.lo.as_ptr().cast()); + let b_hi = _mm512_loadu_si512(rhs.hi.as_ptr().cast()); + let p_lo = _mm512_set1_epi64(Self::P_LO as i64); + let p_hi = _mm512_set1_epi64(Self::P_HI as i64); + let one = _mm512_set1_epi64(1); + + // 128-bit add: (sum_hi, sum_lo) = (a_hi, a_lo) + (b_hi, b_lo) + let sum_lo = _mm512_add_epi64(a_lo, b_lo); + let carry_lo = _mm512_cmplt_epu64_mask(sum_lo, a_lo); + let hi_tmp = _mm512_add_epi64(a_hi, b_hi); + let ov1 = _mm512_cmplt_epu64_mask(hi_tmp, a_hi); + let sum_hi = _mm512_mask_add_epi64(hi_tmp, carry_lo, hi_tmp, one); + let ov2 = _mm512_cmplt_epu64_mask(sum_hi, hi_tmp); + let carry_128 = ov1 | ov2; + + // 128-bit subtract P: (red_hi, red_lo) = (sum_hi, sum_lo) - P + let red_lo = _mm512_sub_epi64(sum_lo, p_lo); + let borrow_lo = _mm512_cmplt_epu64_mask(sum_lo, p_lo); + let red_hi_tmp = _mm512_sub_epi64(sum_hi, p_hi); + let bw1 = _mm512_cmplt_epu64_mask(sum_hi, p_hi); + let red_hi = _mm512_mask_sub_epi64(red_hi_tmp, borrow_lo, red_hi_tmp, one); + let bw2 = _mm512_cmplt_epu64_mask(red_hi_tmp, _mm512_maskz_mov_epi64(borrow_lo, one)); + let borrow = bw1 | bw2; + + // Use reduced if: overflow happened OR subtraction didn't borrow + let use_reduced = carry_128 | !borrow; + let out_lo = _mm512_mask_blend_epi64(use_reduced, sum_lo, red_lo); + let out_hi = _mm512_mask_blend_epi64(use_reduced, sum_hi, red_hi); + + let mut result = Self::default(); + _mm512_storeu_si512(result.lo.as_mut_ptr().cast(), out_lo); + _mm512_storeu_si512(result.hi.as_mut_ptr().cast(), out_hi); + result + } + } +} + +impl Sub for PackedFp128Avx512

{ + type Output = Self; + #[inline] + fn sub(self, rhs: Self) -> Self { + unsafe { + let a_lo = _mm512_loadu_si512(self.lo.as_ptr().cast()); + let a_hi = _mm512_loadu_si512(self.hi.as_ptr().cast()); + let b_lo = _mm512_loadu_si512(rhs.lo.as_ptr().cast()); + let b_hi = _mm512_loadu_si512(rhs.hi.as_ptr().cast()); + let p_lo = _mm512_set1_epi64(Self::P_LO as i64); + let p_hi = _mm512_set1_epi64(Self::P_HI as i64); + let one = _mm512_set1_epi64(1); + + // 128-bit sub: (diff_hi, diff_lo) = (a_hi, a_lo) - (b_hi, b_lo) + let diff_lo = _mm512_sub_epi64(a_lo, b_lo); + let borrow_lo = _mm512_cmplt_epu64_mask(a_lo, b_lo); + let hi_tmp = _mm512_sub_epi64(a_hi, b_hi); + let bw1 = _mm512_cmplt_epu64_mask(a_hi, b_hi); + let diff_hi = _mm512_mask_sub_epi64(hi_tmp, borrow_lo, hi_tmp, one); + let bw2 = _mm512_cmplt_epu64_mask(hi_tmp, _mm512_maskz_mov_epi64(borrow_lo, one)); + let borrow_128 = bw1 | bw2; + + // Correction: add P back where underflow occurred + let corr_lo = _mm512_add_epi64(diff_lo, p_lo); + let carry_lo = _mm512_cmplt_epu64_mask(corr_lo, diff_lo); + let corr_hi = _mm512_add_epi64(diff_hi, p_hi); + let corr_hi = _mm512_mask_add_epi64(corr_hi, carry_lo, corr_hi, one); + + let out_lo = _mm512_mask_blend_epi64(borrow_128, diff_lo, corr_lo); + let out_hi = _mm512_mask_blend_epi64(borrow_128, diff_hi, corr_hi); + + let mut result = Self::default(); + _mm512_storeu_si512(result.lo.as_mut_ptr().cast(), out_lo); + _mm512_storeu_si512(result.hi.as_mut_ptr().cast(), out_hi); + result + } + } +} + +impl Mul for PackedFp128Avx512

{ + type Output = Self; + #[inline] + fn mul(self, rhs: Self) -> Self { + let mut out = Self::default(); + for i in 0..FP128_WIDTH { + let a = Fp128::

([self.lo[i], self.hi[i]]); + let b = Fp128::

([rhs.lo[i], rhs.hi[i]]); + let r = a * b; + out.lo[i] = r.0[0]; + out.hi[i] = r.0[1]; + } + out + } +} + +impl PackedValue for PackedFp128Avx512

{ + type Value = Fp128

; + const WIDTH: usize = FP128_WIDTH; + + #[inline] + fn from_fn(mut f: F) -> Self + where + F: FnMut(usize) -> Self::Value, + { + let mut lo = [0u64; FP128_WIDTH]; + let mut hi = [0u64; FP128_WIDTH]; + for i in 0..FP128_WIDTH { + let v = f(i); + lo[i] = v.0[0]; + hi[i] = v.0[1]; + } + Self { lo, hi } + } + + #[inline] + fn extract(&self, lane: usize) -> Self::Value { + debug_assert!(lane < FP128_WIDTH); + Fp128([self.lo[lane], self.hi[lane]]) + } +} + +impl PackedField for PackedFp128Avx512

{ + type Scalar = Fp128

; + + #[inline] + fn broadcast(value: Self::Scalar) -> Self { + Self { + lo: [value.0[0]; FP128_WIDTH], + hi: [value.0[1]; FP128_WIDTH], + } + } +} diff --git a/src/algebra/fields/packed_ext.rs b/src/algebra/fields/packed_ext.rs new file mode 100644 index 00000000..5bf061d9 --- /dev/null +++ b/src/algebra/fields/packed_ext.rs @@ -0,0 +1,415 @@ +//! Packed extension field types using transpose-based packing. +//! +//! A `PackedFp2` stores `[PF; 2]` where `PF` is the packed base field. +//! Each `PF` lane contains the corresponding coefficient of an `Fp2` element. +//! This enables WIDTH-fold parallel arithmetic over `Fp2` using existing SIMD +//! base-field operations. + +use crate::algebra::fields::ext::{Fp2, Fp2Config, Fp4, Fp4Config}; +use crate::algebra::fields::packed::{HasPacking, PackedField, PackedValue}; +use crate::primitives::serialization::Valid; +use crate::FieldCore; +use core::ops::{Add, Mul, Sub}; + +/// Packed `Fp2` elements stored in transpose layout: `[PF; 2]`. +/// +/// If `PF` has width `W`, this represents `W` parallel `Fp2` values. +pub struct PackedFp2, PF: PackedField> { + pub c0: PF, + pub c1: PF, + _marker: std::marker::PhantomData (F, C)>, +} + +impl, PF: PackedField> Clone for PackedFp2 { + fn clone(&self) -> Self { + *self + } +} + +impl, PF: PackedField> Copy for PackedFp2 {} + +impl, PF: PackedField> std::fmt::Debug + for PackedFp2 +{ + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("PackedFp2").finish_non_exhaustive() + } +} + +impl, PF: PackedField> PackedFp2 { + #[inline] + pub fn new(c0: PF, c1: PF) -> Self { + Self { + c0, + c1, + _marker: std::marker::PhantomData, + } + } + + #[inline] + fn mul_nr(x: PF) -> PF { + if C::IS_NEG_ONE { + let zero = PF::broadcast(F::zero()); + zero - x + } else { + PF::broadcast(C::non_residue()) * x + } + } +} + +impl PackedValue for PackedFp2 +where + F: FieldCore + Valid + 'static, + C: Fp2Config + 'static, + PF: PackedField, +{ + type Value = Fp2; + const WIDTH: usize = PF::WIDTH; + + fn from_fn(mut f: G) -> Self + where + G: FnMut(usize) -> Self::Value, + { + let mut c0s = Vec::with_capacity(PF::WIDTH); + let mut c1s = Vec::with_capacity(PF::WIDTH); + for i in 0..PF::WIDTH { + let val = f(i); + c0s.push(val.c0); + c1s.push(val.c1); + } + Self::new(PF::from_fn(|i| c0s[i]), PF::from_fn(|i| c1s[i])) + } + + fn extract(&self, lane: usize) -> Self::Value { + Fp2::new(self.c0.extract(lane), self.c1.extract(lane)) + } +} + +impl Add for PackedFp2 +where + F: FieldCore, + C: Fp2Config, + PF: PackedField, +{ + type Output = Self; + #[inline] + fn add(self, rhs: Self) -> Self { + Self::new(self.c0 + rhs.c0, self.c1 + rhs.c1) + } +} + +impl Sub for PackedFp2 +where + F: FieldCore, + C: Fp2Config, + PF: PackedField, +{ + type Output = Self; + #[inline] + fn sub(self, rhs: Self) -> Self { + Self::new(self.c0 - rhs.c0, self.c1 - rhs.c1) + } +} + +impl Mul for PackedFp2 +where + F: FieldCore, + C: Fp2Config, + PF: PackedField, +{ + type Output = Self; + #[inline] + fn mul(self, rhs: Self) -> Self { + let v0 = self.c0 * rhs.c0; + let v1 = self.c1 * rhs.c1; + Self::new( + v0 + Self::mul_nr(v1), + (self.c0 + self.c1) * (rhs.c0 + rhs.c1) - v0 - v1, + ) + } +} + +impl PackedField for PackedFp2 +where + F: FieldCore + Valid + 'static, + C: Fp2Config + 'static, + PF: PackedField, +{ + type Scalar = Fp2; + + #[inline] + fn broadcast(value: Self::Scalar) -> Self { + Self::new(PF::broadcast(value.c0), PF::broadcast(value.c1)) + } +} + +impl HasPacking for Fp2 +where + F: FieldCore + Valid + HasPacking + 'static, + C: Fp2Config + 'static, +{ + type Packing = PackedFp2; +} + +/// Packed `Fp4` elements stored in transpose layout: `[PackedFp2; 2]`. +pub struct PackedFp4< + F: FieldCore, + C2: Fp2Config, + C4: Fp4Config, + PF: PackedField, +> { + pub c0: PackedFp2, + pub c1: PackedFp2, + _marker: std::marker::PhantomData C4>, +} + +impl Clone for PackedFp4 +where + F: FieldCore, + C2: Fp2Config, + C4: Fp4Config, + PF: PackedField, +{ + fn clone(&self) -> Self { + *self + } +} + +impl Copy for PackedFp4 +where + F: FieldCore, + C2: Fp2Config, + C4: Fp4Config, + PF: PackedField, +{ +} + +impl std::fmt::Debug for PackedFp4 +where + F: FieldCore, + C2: Fp2Config, + C4: Fp4Config, + PF: PackedField, +{ + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("PackedFp4").finish_non_exhaustive() + } +} + +impl PackedFp4 +where + F: FieldCore, + C2: Fp2Config, + C4: Fp4Config, + PF: PackedField, +{ + #[inline] + pub fn new(c0: PackedFp2, c1: PackedFp2) -> Self { + Self { + c0, + c1, + _marker: std::marker::PhantomData, + } + } +} + +impl PackedValue for PackedFp4 +where + F: FieldCore + Valid + 'static, + C2: Fp2Config + 'static, + C4: Fp4Config + 'static, + PF: PackedField, +{ + type Value = Fp4; + const WIDTH: usize = PF::WIDTH; + + fn from_fn(mut f: G) -> Self + where + G: FnMut(usize) -> Self::Value, + { + let mut c0s: Vec> = Vec::with_capacity(PF::WIDTH); + let mut c1s: Vec> = Vec::with_capacity(PF::WIDTH); + for i in 0..PF::WIDTH { + let val = f(i); + c0s.push(val.c0); + c1s.push(val.c1); + } + Self::new( + PackedFp2::from_fn(|i| c0s[i]), + PackedFp2::from_fn(|i| c1s[i]), + ) + } + + fn extract(&self, lane: usize) -> Self::Value { + Fp4::new(self.c0.extract(lane), self.c1.extract(lane)) + } +} + +impl Add for PackedFp4 +where + F: FieldCore, + C2: Fp2Config, + C4: Fp4Config, + PF: PackedField, +{ + type Output = Self; + #[inline] + fn add(self, rhs: Self) -> Self { + Self::new(self.c0 + rhs.c0, self.c1 + rhs.c1) + } +} + +impl Sub for PackedFp4 +where + F: FieldCore, + C2: Fp2Config, + C4: Fp4Config, + PF: PackedField, +{ + type Output = Self; + #[inline] + fn sub(self, rhs: Self) -> Self { + Self::new(self.c0 - rhs.c0, self.c1 - rhs.c1) + } +} + +impl Mul for PackedFp4 +where + F: FieldCore + Valid + 'static, + C2: Fp2Config + 'static, + C4: Fp4Config, + PF: PackedField, +{ + type Output = Self; + #[inline] + fn mul(self, rhs: Self) -> Self { + let nr2 = PackedFp2::broadcast(C4::non_residue()); + let v0 = self.c0 * rhs.c0; + let v1 = self.c1 * rhs.c1; + Self::new( + v0 + nr2 * v1, + (self.c0 + self.c1) * (rhs.c0 + rhs.c1) - v0 - v1, + ) + } +} + +impl PackedField for PackedFp4 +where + F: FieldCore + Valid + 'static, + C2: Fp2Config + 'static, + C4: Fp4Config + 'static, + PF: PackedField, +{ + type Scalar = Fp4; + + #[inline] + fn broadcast(value: Self::Scalar) -> Self { + Self::new( + PackedFp2::broadcast(value.c0), + PackedFp2::broadcast(value.c1), + ) + } +} + +impl HasPacking for Fp4 +where + F: FieldCore + Valid + HasPacking + 'static, + C2: Fp2Config + 'static, + C4: Fp4Config + 'static, +{ + type Packing = PackedFp4; +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::algebra::fields::ext::{Ext2, Ext4, TwoNr, UnitNr}; + use crate::algebra::Fp64; + use crate::{FieldSampling, FromSmallInt}; + use rand::rngs::StdRng; + use rand::SeedableRng; + + type F = Fp64<4294967197>; + type E2 = Ext2; + type E4 = Ext4; + type PE2 = PackedFp2::Packing>; + type PE4 = PackedFp4::Packing>; + + #[test] + fn packed_fp2_add() { + let mut rng = StdRng::seed_from_u64(100); + let width = ::WIDTH; + let a_elems: Vec = (0..width).map(|_| E2::sample(&mut rng)).collect(); + let b_elems: Vec = (0..width).map(|_| E2::sample(&mut rng)).collect(); + + let pa = PE2::from_fn(|i| a_elems[i]); + let pb = PE2::from_fn(|i| b_elems[i]); + let pc = pa + pb; + + for i in 0..width { + assert_eq!(pc.extract(i), a_elems[i] + b_elems[i]); + } + } + + #[test] + fn packed_fp2_mul() { + let mut rng = StdRng::seed_from_u64(200); + let width = ::WIDTH; + let a_elems: Vec = (0..width).map(|_| E2::sample(&mut rng)).collect(); + let b_elems: Vec = (0..width).map(|_| E2::sample(&mut rng)).collect(); + + let pa = PE2::from_fn(|i| a_elems[i]); + let pb = PE2::from_fn(|i| b_elems[i]); + let pc = pa * pb; + + for i in 0..width { + assert_eq!( + pc.extract(i), + a_elems[i] * b_elems[i], + "packed Fp2 mul mismatch at lane {i}" + ); + } + } + + #[test] + fn packed_fp2_broadcast() { + let val = E2::new(F::from_u64(7), F::from_u64(11)); + let packed = PE2::broadcast(val); + let width = ::WIDTH; + for i in 0..width { + assert_eq!(packed.extract(i), val); + } + } + + #[test] + fn packed_fp4_mul() { + let mut rng = StdRng::seed_from_u64(300); + let width = ::WIDTH; + let a_elems: Vec = (0..width).map(|_| E4::sample(&mut rng)).collect(); + let b_elems: Vec = (0..width).map(|_| E4::sample(&mut rng)).collect(); + + let pa = PE4::from_fn(|i| a_elems[i]); + let pb = PE4::from_fn(|i| b_elems[i]); + let pc = pa * pb; + + for i in 0..width { + assert_eq!( + pc.extract(i), + a_elems[i] * b_elems[i], + "packed Fp4 mul mismatch at lane {i}" + ); + } + } + + #[test] + fn pack_unpack_roundtrip_fp2() { + let mut rng = StdRng::seed_from_u64(400); + let width = ::WIDTH; + let elems: Vec = (0..width * 3).map(|_| E2::sample(&mut rng)).collect(); + + let packed = PE2::pack_slice(&elems); + let unpacked = PE2::unpack_slice(&packed); + + assert_eq!(elems, unpacked); + } +} diff --git a/src/algebra/fields/packed_neon.rs b/src/algebra/fields/packed_neon.rs new file mode 100644 index 00000000..a50f73a0 --- /dev/null +++ b/src/algebra/fields/packed_neon.rs @@ -0,0 +1,718 @@ +//! AArch64 NEON packed backends for Fp32, Fp64, Fp128. + +use super::packed::{PackedField, PackedValue}; +use crate::algebra::fields::{Fp128, Fp32, Fp64}; +use crate::FieldCore; +use core::arch::aarch64::{ + uint32x2_t, uint32x4_t, uint64x2_t, vaddq_u32, vaddq_u64, vandq_u64, vbslq_u32, vbslq_u64, + vcgtq_u64, vcltq_u32, vcltq_u64, vcombine_u32, vdup_n_u32, vdupq_n_s64, vdupq_n_u32, + vdupq_n_u64, veorq_u64, vget_low_u32, vminq_u32, vmovn_u64, vmull_high_u32, vmull_u32, + vorrq_u64, vshlq_u64, vsubq_u32, vsubq_u64, +}; +use core::fmt; +use core::mem::transmute; +use core::ops::{Add, Mul, Sub}; + +/// Number of packed `Fp128` lanes in this backend. +pub const WIDTH: usize = 2; + +/// True SoA layout for two packed `Fp128` lanes. +/// +/// `lo = [lane0.lo, lane1.lo]` +/// `hi = [lane0.hi, lane1.hi]` +#[derive(Clone, Copy)] +pub struct PackedFp128Neon { + lo: [u64; 2], + hi: [u64; 2], +} + +#[inline(always)] +fn to_vec(x: [u64; 2]) -> uint64x2_t { + unsafe { transmute::<[u64; 2], uint64x2_t>(x) } +} + +#[inline(always)] +fn from_vec(v: uint64x2_t) -> [u64; 2] { + unsafe { transmute::(v) } +} + +#[inline(always)] +fn mask_to_bit(mask: uint64x2_t) -> uint64x2_t { + // SAFETY: NEON intrinsics are available under this cfg. + unsafe { vandq_u64(mask, vdupq_n_u64(1)) } +} + +#[inline(always)] +const fn modulus_lo() -> u64 { + P as u64 +} + +#[inline(always)] +const fn modulus_hi() -> u64 { + (P >> 64) as u64 +} + +use super::util::{is_pow2_u64, log2_pow2_u64}; + +impl PackedFp128Neon

{ + const C: u128 = { + let c = 0u128.wrapping_sub(P); + assert!(P != 0, "modulus must be nonzero"); + assert!(P & 1 == 1, "modulus must be odd"); + assert!(c < (1u128 << 64), "P must be 2^128 - c with c < 2^64"); + assert!( + c * (c + 1) < P, + "C(C+1) < P required for fused canonicalize" + ); + c + }; + const C_LO: u64 = Self::C as u64; + const C_SHIFT_KIND: i8 = { + let c = Self::C_LO; + if c > 1 && is_pow2_u64(c - 1) { + 1 + } else if c == u64::MAX || is_pow2_u64(c + 1) { + -1 + } else { + 0 + } + }; + const C_SHIFT: u32 = { + let c = Self::C_LO; + if Self::C_SHIFT_KIND == 1 { + log2_pow2_u64(c - 1) + } else if Self::C_SHIFT_KIND == -1 { + if c == u64::MAX { + 64 + } else { + log2_pow2_u64(c + 1) + } + } else { + 0 + } + }; + + #[inline(always)] + fn mul_wide_u64(a: u64, b: u64) -> (u64, u64) { + let prod = (a as u128) * (b as u128); + (prod as u64, (prod >> 64) as u64) + } + + #[inline(always)] + fn mul_c_wide(x: u64) -> (u64, u64) { + if Self::C_SHIFT_KIND == 1 { + let v = ((x as u128) << Self::C_SHIFT) + x as u128; + (v as u64, (v >> 64) as u64) + } else if Self::C_SHIFT_KIND == -1 { + let v = ((x as u128) << Self::C_SHIFT) - x as u128; + (v as u64, (v >> 64) as u64) + } else { + Self::mul_wide_u64(Self::C_LO, x) + } + } + + #[inline(always)] + fn fold2_canonicalize(t0: u64, t1: u64, t2: u64) -> (u64, u64) { + let (ct2_lo, ct2_hi) = Self::mul_c_wide(t2); + + let (s0, carry0) = t0.overflowing_add(ct2_lo); + let (s1a, carry1a) = t1.overflowing_add(ct2_hi); + let (s1, carry1b) = s1a.overflowing_add(carry0 as u64); + let overflow = carry1a | carry1b; + + let (r0, carry2) = s0.overflowing_add(Self::C_LO); + let (r1, carry3) = s1.overflowing_add(carry2 as u64); + + if overflow | carry3 { + (r0, r1) + } else { + (s0, s1) + } + } + + #[inline(always)] + fn mul_raw_lane(a0: u64, a1: u64, b0: u64, b1: u64) -> (u64, u64) { + let (p00_lo, p00_hi) = Self::mul_wide_u64(a0, b0); + let (p01_lo, p01_hi) = Self::mul_wide_u64(a0, b1); + let (p10_lo, p10_hi) = Self::mul_wide_u64(a1, b0); + let (p11_lo, p11_hi) = Self::mul_wide_u64(a1, b1); + + let row1 = p00_hi as u128 + p01_lo as u128 + p10_lo as u128; + let r0 = p00_lo; + let r1 = row1 as u64; + let carry1 = (row1 >> 64) as u64; + + let row2 = p01_hi as u128 + p10_hi as u128 + p11_lo as u128 + carry1 as u128; + let r2 = row2 as u64; + let carry2 = (row2 >> 64) as u64; + + let row3 = p11_hi as u128 + carry2 as u128; + let r3 = row3 as u64; + debug_assert_eq!(row3 >> 64, 0); + + let (cr2_lo, cr2_hi) = Self::mul_c_wide(r2); + let (cr3_lo, cr3_hi) = Self::mul_c_wide(r3); + + let t0_sum = r0 as u128 + cr2_lo as u128; + let t0 = t0_sum as u64; + let carryf = (t0_sum >> 64) as u64; + + let t1_sum = r1 as u128 + cr2_hi as u128 + cr3_lo as u128 + carryf as u128; + let t1 = t1_sum as u64; + + let t2_sum = cr3_hi as u128 + (t1_sum >> 64); + let t2 = t2_sum as u64; + debug_assert_eq!(t2_sum >> 64, 0); + + Self::fold2_canonicalize(t0, t1, t2) + } +} + +impl Default for PackedFp128Neon

{ + #[inline] + fn default() -> Self { + Self::broadcast(Fp128::zero()) + } +} + +impl fmt::Debug for PackedFp128Neon

{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_tuple("PackedFp128Neon") + .field(&[self.extract(0), self.extract(1)]) + .finish() + } +} + +impl PartialEq for PackedFp128Neon

{ + #[inline] + fn eq(&self, other: &Self) -> bool { + self.extract(0) == other.extract(0) && self.extract(1) == other.extract(1) + } +} + +impl Eq for PackedFp128Neon

{} + +impl PackedValue for PackedFp128Neon

{ + type Value = Fp128

; + const WIDTH: usize = WIDTH; + + #[inline] + fn from_fn(mut f: F) -> Self + where + F: FnMut(usize) -> Self::Value, + { + let x0 = f(0); + let x1 = f(1); + Self { + lo: [x0.0[0], x1.0[0]], + hi: [x0.0[1], x1.0[1]], + } + } + + #[inline] + fn extract(&self, lane: usize) -> Self::Value { + debug_assert!(lane < WIDTH); + Fp128([self.lo[lane], self.hi[lane]]) + } +} + +impl Add for PackedFp128Neon

{ + type Output = Self; + #[inline] + fn add(self, rhs: Self) -> Self { + let lo_a = to_vec(self.lo); + let hi_a = to_vec(self.hi); + let lo_b = to_vec(rhs.lo); + let hi_b = to_vec(rhs.hi); + + let (out_lo, out_hi) = unsafe { + let p_lo = vdupq_n_u64(modulus_lo::

()); + let p_hi = vdupq_n_u64(modulus_hi::

()); + + let sum_lo = vaddq_u64(lo_a, lo_b); + let carry_lo = mask_to_bit(vcltq_u64(sum_lo, lo_a)); + + let hi_tmp = vaddq_u64(hi_a, hi_b); + let carry_hi1 = vcltq_u64(hi_tmp, hi_a); + let sum_hi = vaddq_u64(hi_tmp, carry_lo); + let carry_hi2 = vcltq_u64(sum_hi, hi_tmp); + let carry_128 = vorrq_u64(carry_hi1, carry_hi2); + + let red_lo = vsubq_u64(sum_lo, p_lo); + let borrow_lo = mask_to_bit(vcgtq_u64(p_lo, sum_lo)); + + let red_hi_tmp = vsubq_u64(sum_hi, p_hi); + let borrow_hi1 = vcgtq_u64(p_hi, sum_hi); + let red_hi = vsubq_u64(red_hi_tmp, borrow_lo); + let borrow_hi2 = vcltq_u64(red_hi_tmp, borrow_lo); + let borrow = vorrq_u64(borrow_hi1, borrow_hi2); + + let not_borrow = veorq_u64(borrow, vdupq_n_u64(u64::MAX)); + let use_reduced = vorrq_u64(carry_128, not_borrow); + let out_lo = vbslq_u64(use_reduced, red_lo, sum_lo); + let out_hi = vbslq_u64(use_reduced, red_hi, sum_hi); + (out_lo, out_hi) + }; + + Self { + lo: from_vec(out_lo), + hi: from_vec(out_hi), + } + } +} + +impl Sub for PackedFp128Neon

{ + type Output = Self; + #[inline] + fn sub(self, rhs: Self) -> Self { + let lo_a = to_vec(self.lo); + let hi_a = to_vec(self.hi); + let lo_b = to_vec(rhs.lo); + let hi_b = to_vec(rhs.hi); + + let (out_lo, out_hi) = unsafe { + let p_lo = vdupq_n_u64(modulus_lo::

()); + let p_hi = vdupq_n_u64(modulus_hi::

()); + + let diff_lo = vsubq_u64(lo_a, lo_b); + let borrow_lo = mask_to_bit(vcltq_u64(lo_a, lo_b)); + + let diff_hi_tmp = vsubq_u64(hi_a, hi_b); + let borrow_hi1 = vcltq_u64(hi_a, hi_b); + let diff_hi = vsubq_u64(diff_hi_tmp, borrow_lo); + let borrow_hi2 = vcltq_u64(diff_hi_tmp, borrow_lo); + let borrow_128 = vorrq_u64(borrow_hi1, borrow_hi2); + + let corr_lo = vaddq_u64(diff_lo, p_lo); + let carry_lo = mask_to_bit(vcltq_u64(corr_lo, diff_lo)); + + let corr_hi_tmp = vaddq_u64(diff_hi, p_hi); + let corr_hi = vaddq_u64(corr_hi_tmp, carry_lo); + + let out_lo = vbslq_u64(borrow_128, corr_lo, diff_lo); + let out_hi = vbslq_u64(borrow_128, corr_hi, diff_hi); + (out_lo, out_hi) + }; + + Self { + lo: from_vec(out_lo), + hi: from_vec(out_hi), + } + } +} + +impl Mul for PackedFp128Neon

{ + type Output = Self; + #[inline] + fn mul(self, rhs: Self) -> Self { + let (o0_lo, o0_hi) = Self::mul_raw_lane(self.lo[0], self.hi[0], rhs.lo[0], rhs.hi[0]); + let (o1_lo, o1_hi) = Self::mul_raw_lane(self.lo[1], self.hi[1], rhs.lo[1], rhs.hi[1]); + + Self { + lo: [o0_lo, o1_lo], + hi: [o0_hi, o1_hi], + } + } +} + +impl PackedField for PackedFp128Neon

{ + type Scalar = Fp128

; + + #[inline] + fn broadcast(value: Self::Scalar) -> Self { + Self::from_fn(|_| value) + } +} + +/// Number of packed `Fp32` lanes. +pub const FP32_WIDTH: usize = 4; + +/// NEON packed `Fp32` backend: 4 lanes in `uint32x4_t`. +#[derive(Clone, Copy)] +pub struct PackedFp32Neon { + vals: [u32; 4], +} + +#[inline(always)] +fn to_vec32(x: [u32; 4]) -> uint32x4_t { + unsafe { transmute::<[u32; 4], uint32x4_t>(x) } +} + +#[inline(always)] +fn from_vec32(v: uint32x4_t) -> [u32; 4] { + unsafe { transmute::(v) } +} + +impl PackedFp32Neon

{ + const BITS: u32 = 32 - P.leading_zeros(); + + const C: u32 = { + let c = if Self::BITS == 32 { + 0u32.wrapping_sub(P) + } else { + (1u32 << Self::BITS) - P + }; + assert!(P != 0, "modulus must be nonzero"); + assert!(P & 1 == 1, "modulus must be odd"); + assert!( + (c as u64) * (c as u64 + 1) < P as u64, + "C(C+1) < P required for fused canonicalize" + ); + c + }; + + const MASK_U64: u64 = if Self::BITS == 32 { + u32::MAX as u64 + } else { + (1u64 << Self::BITS) - 1 + }; + + #[inline(always)] + fn mul_c_u64(hi: uint64x2_t, c: uint32x2_t) -> uint64x2_t { + unsafe { + let hi_narrow = vmovn_u64(hi); + vmull_u32(hi_narrow, c) + } + } + + #[inline(always)] + fn solinas_reduce(prod_lo: uint64x2_t, prod_hi: uint64x2_t) -> uint32x4_t { + unsafe { + let mask = vdupq_n_u64(Self::MASK_U64); + let neg_bits = vdupq_n_s64(-(Self::BITS as i64)); + let c = vdup_n_u32(Self::C); + + let f1_lo = vaddq_u64( + vandq_u64(prod_lo, mask), + Self::mul_c_u64(vshlq_u64(prod_lo, neg_bits), c), + ); + let f1_hi = vaddq_u64( + vandq_u64(prod_hi, mask), + Self::mul_c_u64(vshlq_u64(prod_hi, neg_bits), c), + ); + + let f2_lo = vaddq_u64( + vandq_u64(f1_lo, mask), + Self::mul_c_u64(vshlq_u64(f1_lo, neg_bits), c), + ); + let f2_hi = vaddq_u64( + vandq_u64(f1_hi, mask), + Self::mul_c_u64(vshlq_u64(f1_hi, neg_bits), c), + ); + + if Self::BITS < 32 { + let result = vcombine_u32(vmovn_u64(f2_lo), vmovn_u64(f2_hi)); + let p = vdupq_n_u32(P); + vminq_u32(result, vsubq_u32(result, p)) + } else { + let p_u64 = vdupq_n_u64(P as u64); + + let red_lo = vsubq_u64(f2_lo, p_u64); + let keep_lo = vcltq_u64(f2_lo, p_u64); + let out_lo = vbslq_u64(keep_lo, f2_lo, red_lo); + + let red_hi = vsubq_u64(f2_hi, p_u64); + let keep_hi = vcltq_u64(f2_hi, p_u64); + let out_hi = vbslq_u64(keep_hi, f2_hi, red_hi); + + vcombine_u32(vmovn_u64(out_lo), vmovn_u64(out_hi)) + } + } + } +} + +impl Default for PackedFp32Neon

{ + #[inline] + fn default() -> Self { + Self { vals: [0; 4] } + } +} + +impl fmt::Debug for PackedFp32Neon

{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_tuple("PackedFp32Neon").field(&self.vals).finish() + } +} + +impl PartialEq for PackedFp32Neon

{ + #[inline] + fn eq(&self, other: &Self) -> bool { + self.vals == other.vals + } +} + +impl Eq for PackedFp32Neon

{} + +impl Add for PackedFp32Neon

{ + type Output = Self; + #[inline] + fn add(self, rhs: Self) -> Self { + let a = to_vec32(self.vals); + let b = to_vec32(rhs.vals); + let result = unsafe { + let p = vdupq_n_u32(P); + if Self::BITS <= 31 { + let t = vaddq_u32(a, b); + vminq_u32(t, vsubq_u32(t, p)) + } else { + let c = vdupq_n_u32(Self::C); + let t = vaddq_u32(a, b); + let overflow = vcltq_u32(t, a); + let t_plus_c = vaddq_u32(t, c); + let no_of = vminq_u32(t, vsubq_u32(t, p)); + vbslq_u32(overflow, t_plus_c, no_of) + } + }; + Self { + vals: from_vec32(result), + } + } +} + +impl Sub for PackedFp32Neon

{ + type Output = Self; + #[inline] + fn sub(self, rhs: Self) -> Self { + let a = to_vec32(self.vals); + let b = to_vec32(rhs.vals); + let result = unsafe { + let p = vdupq_n_u32(P); + if Self::BITS <= 31 { + let t = vsubq_u32(a, b); + vminq_u32(t, vaddq_u32(t, p)) + } else { + let t = vsubq_u32(a, b); + let underflow = vcltq_u32(a, b); + vbslq_u32(underflow, vaddq_u32(t, p), t) + } + }; + Self { + vals: from_vec32(result), + } + } +} + +impl Mul for PackedFp32Neon

{ + type Output = Self; + #[inline] + fn mul(self, rhs: Self) -> Self { + let a = to_vec32(self.vals); + let b = to_vec32(rhs.vals); + let result = unsafe { + let prod_lo = vmull_u32(vget_low_u32(a), vget_low_u32(b)); + let prod_hi = vmull_high_u32(a, b); + Self::solinas_reduce(prod_lo, prod_hi) + }; + Self { + vals: from_vec32(result), + } + } +} + +impl PackedValue for PackedFp32Neon

{ + type Value = Fp32

; + const WIDTH: usize = FP32_WIDTH; + + #[inline] + fn from_fn(mut f: F) -> Self + where + F: FnMut(usize) -> Self::Value, + { + Self { + vals: [f(0).0, f(1).0, f(2).0, f(3).0], + } + } + + #[inline] + fn extract(&self, lane: usize) -> Self::Value { + debug_assert!(lane < FP32_WIDTH); + Fp32(self.vals[lane]) + } +} + +impl PackedField for PackedFp32Neon

{ + type Scalar = Fp32

; + + #[inline] + fn broadcast(value: Self::Scalar) -> Self { + Self { vals: [value.0; 4] } + } +} + +/// Number of packed `Fp64` lanes. +pub const FP64_WIDTH: usize = 2; + +/// NEON packed `Fp64` backend: 2 lanes in `uint64x2_t`. +#[derive(Clone, Copy)] +pub struct PackedFp64Neon { + vals: [u64; 2], +} + +impl PackedFp64Neon

{ + const BITS: u32 = 64 - P.leading_zeros(); + + const C_LO: u64 = { + let c = if Self::BITS == 64 { + 0u64.wrapping_sub(P) + } else { + (1u64 << Self::BITS) - P + }; + assert!(P != 0, "modulus must be nonzero"); + assert!(P & 1 == 1, "modulus must be odd"); + c + }; + + const MASK64: u64 = if Self::BITS < 64 { + (1u64 << Self::BITS) - 1 + } else { + u64::MAX + }; + + const MASK_U128: u128 = if Self::BITS == 64 { + u64::MAX as u128 + } else { + (1u128 << Self::BITS) - 1 + }; + + const FOLD_IN_U64: bool = + Self::BITS < 64 && (Self::C_LO as u128) < (1u128 << (64 - Self::BITS)); + + #[inline(always)] + fn mul_c_narrow(x: u64) -> u64 { + Self::C_LO.wrapping_mul(x) + } + + #[inline(always)] + fn reduce_product(x: u128) -> u64 { + if Self::FOLD_IN_U64 { + let lo = x as u64; + let hi = (x >> 64) as u64; + let high = (lo >> Self::BITS) | (hi << (64 - Self::BITS)); + let f1 = (lo & Self::MASK64).wrapping_add(Self::mul_c_narrow(high)); + let f2 = (f1 & Self::MASK64).wrapping_add(Self::mul_c_narrow(f1 >> Self::BITS)); + let reduced = f2.wrapping_sub(P); + let borrow = reduced >> 63; + reduced.wrapping_add(borrow.wrapping_neg() & P) + } else { + let f1 = + (x & Self::MASK_U128) + (Self::C_LO as u128) * ((x >> Self::BITS) as u64 as u128); + let f2 = + (f1 & Self::MASK_U128) + (Self::C_LO as u128) * ((f1 >> Self::BITS) as u64 as u128); + let reduced = f2.wrapping_sub(P as u128); + let borrow = reduced >> 127; + reduced.wrapping_add(borrow.wrapping_neg() & (P as u128)) as u64 + } + } +} + +impl Default for PackedFp64Neon

{ + #[inline] + fn default() -> Self { + Self { vals: [0; 2] } + } +} + +impl fmt::Debug for PackedFp64Neon

{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_tuple("PackedFp64Neon").field(&self.vals).finish() + } +} + +impl PartialEq for PackedFp64Neon

{ + #[inline] + fn eq(&self, other: &Self) -> bool { + self.vals == other.vals + } +} + +impl Eq for PackedFp64Neon

{} + +impl Add for PackedFp64Neon

{ + type Output = Self; + #[inline] + fn add(self, rhs: Self) -> Self { + let a = to_vec(self.vals); + let b = to_vec(rhs.vals); + let result = unsafe { + let p = vdupq_n_u64(P); + if Self::BITS <= 62 { + let s = vaddq_u64(a, b); + let r = vsubq_u64(s, p); + let borrow = vcltq_u64(s, p); + vbslq_u64(borrow, s, r) + } else { + let s = vaddq_u64(a, b); + let overflow = vcltq_u64(s, a); + let c = vdupq_n_u64(Self::C_LO); + let s_plus_c = vaddq_u64(s, c); + let s_minus_p = vsubq_u64(s, p); + let borrow = vcltq_u64(s, p); + let no_of = vbslq_u64(borrow, s, s_minus_p); + vbslq_u64(overflow, s_plus_c, no_of) + } + }; + Self { + vals: from_vec(result), + } + } +} + +impl Sub for PackedFp64Neon

{ + type Output = Self; + #[inline] + fn sub(self, rhs: Self) -> Self { + let a = to_vec(self.vals); + let b = to_vec(rhs.vals); + let result = unsafe { + let p = vdupq_n_u64(P); + let d = vsubq_u64(a, b); + let underflow = vcltq_u64(a, b); + vbslq_u64(underflow, vaddq_u64(d, p), d) + }; + Self { + vals: from_vec(result), + } + } +} + +impl Mul for PackedFp64Neon

{ + type Output = Self; + #[inline] + fn mul(self, rhs: Self) -> Self { + let x0 = (self.vals[0] as u128) * (rhs.vals[0] as u128); + let x1 = (self.vals[1] as u128) * (rhs.vals[1] as u128); + let r0 = Self::reduce_product(x0); + let r1 = Self::reduce_product(x1); + Self { vals: [r0, r1] } + } +} + +impl PackedValue for PackedFp64Neon

{ + type Value = Fp64

; + const WIDTH: usize = FP64_WIDTH; + + #[inline] + fn from_fn(mut f: F) -> Self + where + F: FnMut(usize) -> Self::Value, + { + Self { + vals: [f(0).0, f(1).0], + } + } + + #[inline] + fn extract(&self, lane: usize) -> Self::Value { + debug_assert!(lane < FP64_WIDTH); + Fp64(self.vals[lane]) + } +} + +impl PackedField for PackedFp64Neon

{ + type Scalar = Fp64

; + + #[inline] + fn broadcast(value: Self::Scalar) -> Self { + Self { vals: [value.0; 2] } + } +} diff --git a/src/algebra/fields/pseudo_mersenne.rs b/src/algebra/fields/pseudo_mersenne.rs new file mode 100644 index 00000000..17920899 --- /dev/null +++ b/src/algebra/fields/pseudo_mersenne.rs @@ -0,0 +1,177 @@ +//! `2^k - offset` pseudo-Mersenne registry and field aliases. +//! +//! This module models the specific flavor where each modulus is the smallest +//! prime below `2^k` with `q % 8 == 5`, written as `q = 2^k - offset`. + +use super::{Fp128, Fp32, Fp64}; + +/// Offset table (`q = 2^k - offset[k]`) imported from `labrador/data.py`. +pub const POW2_OFFSET_TABLE: [i16; 256] = [ + -1, -1, -1, 3, 3, 3, 3, 19, 27, 3, 3, 19, 3, 75, 3, 19, 99, 91, 11, 19, 3, 19, 3, 27, 3, 91, + 27, 115, 299, 3, 35, 19, 99, 355, 131, 451, 243, 123, 107, 19, 195, 75, 11, 67, 539, 139, 635, + 115, 59, 123, 27, 139, 395, 315, 131, 67, 27, 195, 27, 99, 107, 259, 171, 259, 59, 115, 203, + 19, 83, 19, 35, 411, 107, 475, 35, 427, 123, 43, 11, 67, 1307, 51, 315, 139, 35, 19, 35, 67, + 299, 99, 75, 315, 83, 51, 3, 211, 147, 595, 51, 115, 99, 99, 483, 339, 395, 139, 1187, 171, 59, + 91, 195, 835, 75, 211, 11, 67, 3, 451, 563, 867, 395, 531, 3, 67, 59, 579, 203, 507, 275, 315, + 27, 315, 347, 99, 603, 795, 243, 339, 203, 187, 27, 171, 1491, 355, 83, 355, 1371, 387, 347, + 99, 3, 195, 539, 171, 243, 499, 195, 19, 155, 91, 75, 1011, 627, 867, 155, 115, 1811, 771, + 1467, 643, 195, 19, 155, 531, 3, 267, 563, 339, 563, 507, 107, 283, 267, 147, 59, 339, 371, + 1411, 363, 819, 11, 19, 915, 123, 75, 915, 459, 75, 627, 459, 75, 1035, 195, 187, 1515, 1219, + 1443, 91, 299, 451, 171, 1099, 99, 3, 395, 1147, 683, 675, 243, 355, 395, 3, 875, 235, 363, + 1131, 155, 835, 723, 91, 27, 235, 875, 3, 83, 259, 875, 1515, 731, 531, 467, 819, 267, 475, + 1923, 163, 107, 411, 387, 75, 2331, 355, 1515, 1723, 1427, 19, +]; + +/// Maximum supported offset in this `2^k - offset` specialization. +pub const POW2_OFFSET_MAX: u128 = 1u128 << 16; + +/// Current active bit-size bound for concrete field aliases in this phase. +pub const POW2_OFFSET_IMPLEMENTED_MAX_BITS: u32 = 128; + +/// Metadata describing a `2^k - offset` pseudo-Mersenne modulus. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct Pow2OffsetPrimeSpec { + /// `k` in `2^k - offset`. + pub bits: u32, + /// `offset` in `2^k - offset`. + pub offset: u16, + /// Modulus value. + pub modulus: u128, +} + +/// Return table offset for `q = 2^k - offset` when available and positive. +pub const fn pow2_offset(bits: u32) -> Option { + if bits as usize >= POW2_OFFSET_TABLE.len() { + return None; + } + let offset = POW2_OFFSET_TABLE[bits as usize]; + if offset <= 0 { + None + } else { + Some(offset as u16) + } +} + +/// Compute `2^k - offset` for `k <= 128`. +pub const fn pseudo_mersenne_modulus(bits: u32, offset: u128) -> Option { + if bits == 0 || bits > 128 || offset == 0 { + return None; + } + if bits == 128 { + Some(u128::MAX - (offset - 1)) + } else { + Some((1u128 << bits) - offset) + } +} + +/// Check whether `(k, offset)` is accepted by the `2^k - offset` policy. +pub const fn is_pow2_offset(bits: u32, offset: u128) -> bool { + if bits > POW2_OFFSET_IMPLEMENTED_MAX_BITS || offset > POW2_OFFSET_MAX { + return false; + } + match pow2_offset(bits) { + Some(qoff) => (qoff as u128) == offset, + None => false, + } +} + +/// `offset` for `k = 24`. +pub const POW2_OFFSET_24: u16 = 3; +/// `offset` for `k = 30`. +pub const POW2_OFFSET_30: u16 = 35; +/// `offset` for `k = 31`. +pub const POW2_OFFSET_31: u16 = 19; +/// `offset` for `k = 32`. +pub const POW2_OFFSET_32: u16 = 99; +/// `offset` for `k = 40`. +pub const POW2_OFFSET_40: u16 = 195; +/// `offset` for `k = 48`. +pub const POW2_OFFSET_48: u16 = 59; +/// `offset` for `k = 56`. +pub const POW2_OFFSET_56: u16 = 27; +/// `offset` for `k = 64`. +pub const POW2_OFFSET_64: u16 = 59; +/// `offset` for `k = 128`. +pub const POW2_OFFSET_128: u16 = 275; + +/// `2^24 - 3`. +pub const POW2_OFFSET_MODULUS_24: u32 = ((1u128 << 24) - (POW2_OFFSET_24 as u128)) as u32; +/// `2^30 - 35`. +pub const POW2_OFFSET_MODULUS_30: u32 = ((1u128 << 30) - (POW2_OFFSET_30 as u128)) as u32; +/// `2^31 - 19`. +pub const POW2_OFFSET_MODULUS_31: u32 = ((1u128 << 31) - (POW2_OFFSET_31 as u128)) as u32; +/// `2^32 - 99`. +pub const POW2_OFFSET_MODULUS_32: u32 = ((1u128 << 32) - (POW2_OFFSET_32 as u128)) as u32; +/// `2^40 - 195`. +pub const POW2_OFFSET_MODULUS_40: u64 = ((1u128 << 40) - (POW2_OFFSET_40 as u128)) as u64; +/// `2^48 - 59`. +pub const POW2_OFFSET_MODULUS_48: u64 = ((1u128 << 48) - (POW2_OFFSET_48 as u128)) as u64; +/// `2^56 - 27`. +pub const POW2_OFFSET_MODULUS_56: u64 = ((1u128 << 56) - (POW2_OFFSET_56 as u128)) as u64; +/// `2^64 - 59`. +pub const POW2_OFFSET_MODULUS_64: u64 = u64::MAX - ((POW2_OFFSET_64 as u64) - 1); +/// `2^128 - 275`. +pub const POW2_OFFSET_MODULUS_128: u128 = u128::MAX - (POW2_OFFSET_128 as u128 - 1); + +/// Alias for `2^24 - offset`. +pub type Pow2Offset24Field = Fp32; +/// Alias for `2^30 - offset`. +pub type Pow2Offset30Field = Fp32; +/// Alias for `2^31 - offset`. +pub type Pow2Offset31Field = Fp32; +/// Alias for `2^32 - offset`. +pub type Pow2Offset32Field = Fp32; +/// Alias for `2^40 - offset`. +pub type Pow2Offset40Field = Fp64; +/// Alias for `2^48 - offset`. +pub type Pow2Offset48Field = Fp64; +/// Alias for `2^56 - offset`. +pub type Pow2Offset56Field = Fp64; +/// Alias for `2^64 - offset`. +pub type Pow2Offset64Field = Fp64; +/// Alias for `2^128 - offset`. +pub type Pow2Offset128Field = Fp128; + +/// `2^k - offset` profiles currently enabled in-code. +/// +/// Each listed modulus satisfies `q % 8 == 5`. +pub const POW2_OFFSET_PRIMES: [Pow2OffsetPrimeSpec; 7] = [ + Pow2OffsetPrimeSpec { + bits: 24, + offset: POW2_OFFSET_24, + modulus: POW2_OFFSET_MODULUS_24 as u128, + }, + Pow2OffsetPrimeSpec { + bits: 32, + offset: POW2_OFFSET_32, + modulus: POW2_OFFSET_MODULUS_32 as u128, + }, + Pow2OffsetPrimeSpec { + bits: 40, + offset: POW2_OFFSET_40, + modulus: POW2_OFFSET_MODULUS_40 as u128, + }, + Pow2OffsetPrimeSpec { + bits: 48, + offset: POW2_OFFSET_48, + modulus: POW2_OFFSET_MODULUS_48 as u128, + }, + Pow2OffsetPrimeSpec { + bits: 56, + offset: POW2_OFFSET_56, + modulus: POW2_OFFSET_MODULUS_56 as u128, + }, + Pow2OffsetPrimeSpec { + bits: 64, + offset: POW2_OFFSET_64, + modulus: POW2_OFFSET_MODULUS_64 as u128, + }, + Pow2OffsetPrimeSpec { + bits: 128, + offset: POW2_OFFSET_128, + modulus: POW2_OFFSET_MODULUS_128, + }, +]; + +// All PseudoMersenneField impls for Fp32/Fp64/Fp128 are blanket impls in +// their respective modules (fp32.rs, fp64.rs, fp128.rs). diff --git a/src/algebra/fields/util.rs b/src/algebra/fields/util.rs new file mode 100644 index 00000000..9496f1ce --- /dev/null +++ b/src/algebra/fields/util.rs @@ -0,0 +1,38 @@ +//! Shared helpers for field arithmetic backends. + +#[inline(always)] +pub(crate) const fn is_pow2_u64(x: u64) -> bool { + x != 0 && (x & (x - 1)) == 0 +} + +#[inline(always)] +pub(crate) const fn log2_pow2_u64(mut x: u64) -> u32 { + let mut k = 0u32; + while x > 1 { + x >>= 1; + k += 1; + } + k +} + +/// `a * b` widening to 128 bits; returns `(lo64, hi64)`. +#[inline(always)] +pub(crate) fn mul64_wide(a: u64, b: u64) -> (u64, u64) { + #[cfg(all(target_arch = "x86_64", target_feature = "bmi2"))] + { + unsafe { mul64_wide_bmi2(a, b) } + } + #[cfg(not(all(target_arch = "x86_64", target_feature = "bmi2")))] + { + let prod = (a as u128) * (b as u128); + (prod as u64, (prod >> 64) as u64) + } +} + +#[cfg(all(target_arch = "x86_64", target_feature = "bmi2"))] +#[inline(always)] +unsafe fn mul64_wide_bmi2(a: u64, b: u64) -> (u64, u64) { + let mut hi = 0; + let lo = unsafe { std::arch::x86_64::_mulx_u64(a, b, &mut hi) }; + (lo, hi) +} diff --git a/src/algebra/mod.rs b/src/algebra/mod.rs new file mode 100644 index 00000000..21aea776 --- /dev/null +++ b/src/algebra/mod.rs @@ -0,0 +1,32 @@ +//! Concrete algebra backends and arithmetic building blocks. +//! +//! This module includes: +//! - Generic prime fields and extensions (`fields`) +//! - Module and polynomial containers (`module`, `poly`) +//! - Low-level NTT and CRT+NTT arithmetic scaffolding (`ntt`) + +pub mod backend; +pub mod fields; +pub mod module; +pub mod ntt; +pub mod poly; +pub mod ring; + +// Flat re-exports for convenience. +pub use backend::{CrtReconstruct, NttPrimeOps, NttTransform, RingBackend, ScalarBackend}; +pub use fields::{ + is_pow2_offset, pow2_offset, pseudo_mersenne_modulus, ExtField, Fp128, Fp128Packing, Fp2, + Fp2Config, Fp32, Fp32Packing, Fp4, Fp4Config, Fp64, Fp64Packing, HasPacking, LiftBase, + NoPacking, PackedField, PackedValue, Pow2Offset128Field, Pow2Offset24Field, Pow2Offset30Field, + Pow2Offset31Field, Pow2Offset32Field, Pow2Offset40Field, Pow2Offset48Field, Pow2Offset56Field, + Pow2Offset64Field, Pow2OffsetPrimeSpec, Prime128M13M4P0, Prime128M37P3P0, Prime128M52M3P0, + Prime128M54P4P0, Prime128M8M4M1M0, POW2_OFFSET_IMPLEMENTED_MAX_BITS, POW2_OFFSET_MAX, + POW2_OFFSET_PRIMES, POW2_OFFSET_TABLE, +}; +pub use module::VectorModule; +pub use ntt::tables; +pub use ntt::{GarnerData, LimbQ, MontCoeff, NttPrime, PrimeWidth, RADIX_BITS}; +pub use ring::{ + CrtNttConvertibleField, CrtNttParamSet, CyclotomicCrtNtt, CyclotomicRing, SparseChallenge, + SparseChallengeConfig, +}; diff --git a/src/algebra/module.rs b/src/algebra/module.rs new file mode 100644 index 00000000..b924a0ae --- /dev/null +++ b/src/algebra/module.rs @@ -0,0 +1,194 @@ +//! Simple module implementations. + +use super::fields::{Fp128, Fp32, Fp64}; +use crate::primitives::serialization::{ + Compress, HachiDeserialize, HachiSerialize, SerializationError, Valid, Validate, +}; +use crate::{CanonicalField, FieldCore, FieldSampling, Module}; +use rand_core::RngCore; +use std::io::{Read, Write}; +use std::ops::{Add, Mul, Neg, Sub}; + +/// Fixed-length vector module over a scalar type `F`. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct VectorModule(pub [F; N]); + +impl VectorModule { + /// Construct the zero vector. + #[inline] + pub fn zero_vec() -> Self { + Self([F::zero(); N]) + } +} + +impl Add for VectorModule { + type Output = Self; + fn add(self, rhs: Self) -> Self::Output { + let mut out = self.0; + for (dst, src) in out.iter_mut().zip(rhs.0.iter()) { + *dst = *dst + *src; + } + Self(out) + } +} + +impl Sub for VectorModule { + type Output = Self; + fn sub(self, rhs: Self) -> Self::Output { + let mut out = self.0; + for (dst, src) in out.iter_mut().zip(rhs.0.iter()) { + *dst = *dst - *src; + } + Self(out) + } +} + +impl Neg for VectorModule { + type Output = Self; + fn neg(self) -> Self::Output { + let mut out = self.0; + for coeff in &mut out { + *coeff = -*coeff; + } + Self(out) + } +} + +impl<'a, F: FieldCore, const N: usize> Add<&'a Self> for VectorModule { + type Output = Self; + fn add(self, rhs: &'a Self) -> Self::Output { + self + *rhs + } +} + +impl<'a, F: FieldCore, const N: usize> Sub<&'a Self> for VectorModule { + type Output = Self; + fn sub(self, rhs: &'a Self) -> Self::Output { + self - *rhs + } +} + +impl Valid for VectorModule { + fn check(&self) -> Result<(), SerializationError> { + for x in self.0.iter() { + x.check()?; + } + Ok(()) + } +} + +impl HachiSerialize for VectorModule { + fn serialize_with_mode( + &self, + mut writer: W, + compress: Compress, + ) -> Result<(), SerializationError> { + for x in self.0.iter() { + x.serialize_with_mode(&mut writer, compress)?; + } + Ok(()) + } + + fn serialized_size(&self, compress: Compress) -> usize { + self.0.iter().map(|x| x.serialized_size(compress)).sum() + } +} + +impl HachiDeserialize for VectorModule { + fn deserialize_with_mode( + mut reader: R, + compress: Compress, + validate: Validate, + ) -> Result { + let mut arr = [F::zero(); N]; + for coeff in &mut arr { + *coeff = F::deserialize_with_mode(&mut reader, compress, validate)?; + } + let out = Self(arr); + if matches!(validate, Validate::Yes) { + out.check()?; + } + Ok(out) + } +} + +impl Module for VectorModule +where + F: FieldCore + + CanonicalField + + FieldSampling + + Valid + + Mul, Output = VectorModule> + + for<'a> Mul<&'a VectorModule, Output = VectorModule>, +{ + type Scalar = F; + + fn zero() -> Self { + Self::zero_vec() + } + + fn scale(&self, k: &Self::Scalar) -> Self { + // Delegate to Scalar * VectorModule to satisfy the Module trait’s scalar bounds. + *k * *self + } + + fn random(rng: &mut R) -> Self { + Self(std::array::from_fn(|_| F::sample(rng))) + } +} + +// Scalar * VectorModule impls for our local prime field types. + +impl Mul, N>> for Fp32

{ + type Output = VectorModule, N>; + fn mul(self, rhs: VectorModule, N>) -> Self::Output { + let mut out = rhs.0; + for coeff in &mut out { + *coeff = self * *coeff; + } + VectorModule(out) + } +} + +impl<'a, const P: u32, const N: usize> Mul<&'a VectorModule, N>> for Fp32

{ + type Output = VectorModule, N>; + fn mul(self, rhs: &'a VectorModule, N>) -> Self::Output { + self * *rhs + } +} + +impl Mul, N>> for Fp64

{ + type Output = VectorModule, N>; + fn mul(self, rhs: VectorModule, N>) -> Self::Output { + let mut out = rhs.0; + for coeff in &mut out { + *coeff = self * *coeff; + } + VectorModule(out) + } +} + +impl<'a, const P: u64, const N: usize> Mul<&'a VectorModule, N>> for Fp64

{ + type Output = VectorModule, N>; + fn mul(self, rhs: &'a VectorModule, N>) -> Self::Output { + self * *rhs + } +} + +impl Mul, N>> for Fp128

{ + type Output = VectorModule, N>; + fn mul(self, rhs: VectorModule, N>) -> Self::Output { + let mut out = rhs.0; + for coeff in &mut out { + *coeff = self * *coeff; + } + VectorModule(out) + } +} + +impl<'a, const P: u128, const N: usize> Mul<&'a VectorModule, N>> for Fp128

{ + type Output = VectorModule, N>; + fn mul(self, rhs: &'a VectorModule, N>) -> Self::Output { + self * *rhs + } +} diff --git a/src/algebra/ntt/butterfly.rs b/src/algebra/ntt/butterfly.rs new file mode 100644 index 00000000..7380fdf1 --- /dev/null +++ b/src/algebra/ntt/butterfly.rs @@ -0,0 +1,205 @@ +//! NTT butterfly transforms for negacyclic rings `Z_p[X]/(X^D + 1)`. +//! +//! Implements a negacyclic NTT via the standard **twist + cyclic NTT** method. +//! +//! Let `psi` be a primitive `2D`-th root of unity (`psi^D = -1 mod p`) and +//! `omega = psi^2`, a primitive `D`-th root of unity. For polynomials modulo +//! `X^D + 1`, we: +//! - pre-twist coefficients by `psi^i` +//! - run a cyclic size-`D` NTT using `omega` +//! - inverse-cyclic NTT using `omega^{-1}` +//! - post-untwist by `psi^{-i}` + +use super::prime::{MontCoeff, NttPrime, PrimeWidth}; + +/// Precomputed twiddle factors for a specific prime and degree `D`. +/// +/// `D` must be a power of two. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct NttTwiddles { + /// Stage roots for iterative forward cyclic NTT in Montgomery form. + pub(crate) fwd_wlen: [MontCoeff; D], + /// Stage roots for iterative inverse cyclic NTT in Montgomery form. + pub(crate) inv_wlen: [MontCoeff; D], + /// Number of active stages in the twiddle arrays (`log2(D)`). + pub(crate) num_stages: usize, + /// Twist factors `psi^i` for negacyclic embedding, in Montgomery form. + pub(crate) psi_pows: [MontCoeff; D], + /// Untwist factors `psi^{-i}`, in Montgomery form. + pub(crate) psi_inv_pows: [MontCoeff; D], + /// `D^{-1} mod p` in Montgomery form, used for inverse NTT final scaling. + pub(crate) d_inv: MontCoeff, +} + +impl NttTwiddles { + /// Compute twiddle factors for the given prime. + /// + /// Finds a primitive `2D`-th root `psi` and derives `omega = psi^2`. + /// Fills cyclic forward/inverse twiddles for `omega` and twist/untwist + /// tables for `psi`. All values are stored in Montgomery form. + /// + /// # Panics + /// + /// Panics if `D` is not a power of two, or if `2D` does not divide `p - 1`. + pub fn compute(prime: NttPrime) -> Self { + assert!(D.is_power_of_two(), "D must be a power of two"); + let p = prime.p.to_i64(); + assert!( + (p - 1) % (2 * D as i64) == 0, + "2D must divide p - 1 for NTT roots to exist" + ); + + let psi = find_primitive_root_2d(p, D); + let omega = (psi * psi) % p; + let omega_inv = pow_mod(omega, p - 2, p); + + let psi_inv = pow_mod(psi, p - 2, p); + let mut psi_pows = [MontCoeff::from_raw(W::default()); D]; + let mut psi_inv_pows = [MontCoeff::from_raw(W::default()); D]; + let mut cur = 1i64; + let mut cur_inv = 1i64; + for i in 0..D { + psi_pows[i] = prime.from_canonical(W::from_i64(cur)); + psi_inv_pows[i] = prime.from_canonical(W::from_i64(cur_inv)); + cur = (cur * psi) % p; + cur_inv = (cur_inv * psi_inv) % p; + } + + let mut fwd_wlen = [MontCoeff::from_raw(W::default()); D]; + let mut inv_wlen = [MontCoeff::from_raw(W::default()); D]; + let mut len = 1usize; + let mut stage = 0usize; + while len < D { + let exp = (D / (2 * len)) as i64; + fwd_wlen[stage] = prime.from_canonical(W::from_i64(pow_mod(omega, exp, p))); + inv_wlen[stage] = prime.from_canonical(W::from_i64(pow_mod(omega_inv, exp, p))); + len *= 2; + stage += 1; + } + + let d_inv_canonical = pow_mod(D as i64, p - 2, p); + let d_inv = prime.from_canonical(W::from_i64(d_inv_canonical)); + + Self { + fwd_wlen, + inv_wlen, + num_stages: stage, + psi_pows, + psi_inv_pows, + d_inv, + } + } +} + +/// Forward negacyclic NTT (twist + cyclic Gentleman-Sande DIF). +/// +/// Transforms `D` coefficients in-place from coefficient form to NTT +/// evaluation form. Both outputs of each butterfly are range-reduced +/// to prevent overflow. +pub fn forward_ntt( + a: &mut [MontCoeff; D], + prime: NttPrime, + tw: &NttTwiddles, +) { + for (ai, psi) in a.iter_mut().zip(tw.psi_pows.iter()) { + *ai = prime.mul(*ai, *psi); + } + + let one = prime.from_canonical(W::from_i64(1)); + + let mut len = D / 2; + let mut stage = tw.num_stages; + while len > 0 { + stage -= 1; + let wlen = tw.fwd_wlen[stage]; + let mut start = 0usize; + while start < D { + let mut w = one; + for j in 0..len { + let u = a[start + j]; + let v = a[start + j + len]; + let sum = u.raw().wrapping_add(v.raw()); + let diff = u.raw().wrapping_sub(v.raw()); + a[start + j] = prime.reduce_range(MontCoeff::from_raw(sum)); + a[start + j + len] = prime.mul(MontCoeff::from_raw(diff), w); + w = prime.mul(w, wlen); + } + start += 2 * len; + } + len /= 2; + } + + // Keep exported NTT-domain coefficients in the same reduced range expected + // by add/sub reduced operations and equality checks. + prime.reduce_range_in_place(a); +} + +/// Inverse negacyclic NTT (cyclic Cooley-Tukey DIT + untwist). +/// +/// Transforms `D` evaluations in-place back to coefficient form. +/// Includes the final `D^{-1}` scaling. +pub fn inverse_ntt( + a: &mut [MontCoeff; D], + prime: NttPrime, + tw: &NttTwiddles, +) { + let one = prime.from_canonical(W::from_i64(1)); + + let mut len = 1usize; + let mut stage = 0usize; + while len < D { + let wlen = tw.inv_wlen[stage]; + let mut start = 0usize; + while start < D { + let mut w = one; + for j in 0..len { + let u = a[start + j]; + let v = prime.mul(a[start + j + len], w); + let sum = u.raw().wrapping_add(v.raw()); + let diff = u.raw().wrapping_sub(v.raw()); + a[start + j] = prime.reduce_range(MontCoeff::from_raw(sum)); + a[start + j + len] = prime.reduce_range(MontCoeff::from_raw(diff)); + w = prime.mul(w, wlen); + } + start += 2 * len; + } + len *= 2; + stage += 1; + } + + for c in a.iter_mut() { + *c = prime.mul(*c, tw.d_inv); + } + + for (ai, psi_inv) in a.iter_mut().zip(tw.psi_inv_pows.iter()) { + *ai = prime.mul(*ai, *psi_inv); + } +} + +/// Find a primitive `2D`-th root of unity mod `p`. +fn find_primitive_root_2d(p: i64, d: usize) -> i64 { + let half = (p - 1) / 2; + let exp = (p - 1) / (2 * d as i64); + for a in 2..p { + if pow_mod(a, half, p) == p - 1 { + let psi = pow_mod(a, exp, p); + debug_assert_eq!(pow_mod(psi, d as i64, p), p - 1, "psi^D != -1"); + return psi; + } + } + panic!("no primitive root found for p={p}"); +} + +/// Modular exponentiation: `base^exp mod modulus`. +fn pow_mod(mut base: i64, mut exp: i64, modulus: i64) -> i64 { + let mut result = 1i64; + base %= modulus; + while exp > 0 { + if exp & 1 == 1 { + result = result * base % modulus; + } + base = base * base % modulus; + exp >>= 1; + } + result +} diff --git a/src/algebra/ntt/crt.rs b/src/algebra/ntt/crt.rs new file mode 100644 index 00000000..e0897937 --- /dev/null +++ b/src/algebra/ntt/crt.rs @@ -0,0 +1,202 @@ +//! CRT helpers: Garner reconstruction and limb-based modular arithmetic. + +use std::cmp::Ordering; +use std::fmt; +use std::ops::{Add, Sub}; + +use super::prime::{NttPrime, PrimeWidth}; + +/// Limb radix bit-width (`2^14`). +pub const RADIX_BITS: u32 = 14; +const RADIX: i32 = 1 << RADIX_BITS; +const RADIX_MASK: i32 = RADIX - 1; + +/// Precomputed Garner inverse table for CRT reconstruction. +/// +/// `gamma[i][j]` = `p_j^{-1} mod p_i` for `j < i`. Upper triangle and +/// diagonal entries are zero (unused). +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct GarnerData { + /// `gamma[i][j]` = `p_j^{-1} mod p_i` for `j < i`. + pub gamma: [[W; K]; K], +} + +impl GarnerData { + /// Compute Garner constants from a set of NTT primes. + pub fn compute(primes: &[NttPrime; K]) -> Self { + let mut gamma = [[W::default(); K]; K]; + for i in 1..K { + let pi = primes[i].p.to_i64(); + #[allow(clippy::needless_range_loop)] + for j in 0..i { + let pj = primes[j].p.to_i64(); + let inv = mod_inverse_i64(pj, pi); + gamma[i][j] = W::from_i64(inv); + } + } + Self { gamma } + } +} + +/// Modular inverse via extended GCD, operating in `i64`. +fn mod_inverse_i64(a: i64, modulus: i64) -> i64 { + let (mut t, mut new_t) = (0i64, 1i64); + let (mut r, mut new_r) = (modulus, ((a % modulus) + modulus) % modulus); + while new_r != 0 { + let q = r / new_r; + (t, new_t) = (new_t, t - q * new_t); + (r, new_r) = (new_r, r - q * new_r); + } + assert_eq!(r, 1, "modular inverse does not exist"); + ((t % modulus) + modulus) % modulus +} + +/// Fixed-width radix-`2^14` integer. +/// +/// Limbs are little-endian: `limbs[0]` is least significant. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct LimbQ { + /// Little-endian limbs. + pub limbs: [u16; L], +} + +impl Default for LimbQ { + #[inline] + fn default() -> Self { + Self::zero() + } +} + +impl LimbQ { + /// Zero value. + #[inline] + pub const fn zero() -> Self { + Self { limbs: [0; L] } + } + + /// Construct directly from limbs. + #[inline] + pub const fn from_limbs(limbs: [u16; L]) -> Self { + Self { limbs } + } + + /// Conditional subtraction: if `self >= modulus`, return `self - modulus` (branchless). + #[inline] + pub fn csub_mod(self, modulus: Self) -> Self { + let mut diff = [0u16; L]; + let mut borrow = 0i32; + for (i, df) in diff.iter_mut().enumerate() { + let d = self.limbs[i] as i32 - modulus.limbs[i] as i32 + borrow; + borrow = d >> 31; + if i + 1 < L { + *df = (d - borrow * RADIX) as u16; + } else { + *df = d as u16; + } + } + let mask = borrow as u16; + let mut result = [0u16; L]; + for (i, r) in result.iter_mut().enumerate() { + *r = (self.limbs[i] & mask) | (diff[i] & !mask); + } + Self { limbs: result } + } +} + +impl From for LimbQ { + fn from(mut x: u128) -> Self { + let mut out = [0u16; L]; + for (i, limb) in out.iter_mut().enumerate() { + if i + 1 < L { + *limb = (x & (RADIX_MASK as u128)) as u16; + x >>= RADIX_BITS; + } else { + *limb = x as u16; + } + } + Self { limbs: out } + } +} + +impl TryFrom> for u128 { + type Error = &'static str; + + fn try_from(limb: LimbQ) -> Result { + if (L as u32) * RADIX_BITS > 128 { + return Err("LimbQ too wide for u128"); + } + let mut acc = 0u128; + for i in (0..L).rev() { + acc <<= RADIX_BITS; + acc |= limb.limbs[i] as u128; + } + Ok(acc) + } +} + +impl PartialOrd for LimbQ { + #[inline] + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} + +impl Ord for LimbQ { + fn cmp(&self, other: &Self) -> Ordering { + for i in (0..L).rev() { + match self.limbs[i].cmp(&other.limbs[i]) { + Ordering::Equal => continue, + ord => return ord, + } + } + Ordering::Equal + } +} + +impl Add for LimbQ { + type Output = Self; + + fn add(self, rhs: Self) -> Self { + let mut out = [0u16; L]; + let mut carry = 0i32; + for (i, out_limb) in out.iter_mut().enumerate() { + let s = self.limbs[i] as i32 + rhs.limbs[i] as i32 + carry; + if i + 1 < L { + carry = s >> RADIX_BITS; + *out_limb = (s & RADIX_MASK) as u16; + } else { + *out_limb = s as u16; + } + } + Self { limbs: out } + } +} + +impl Sub for LimbQ { + type Output = Self; + + fn sub(self, rhs: Self) -> Self { + let mut out = [0u16; L]; + let mut borrow = 0i32; + for (i, out_limb) in out.iter_mut().enumerate() { + let d = self.limbs[i] as i32 - rhs.limbs[i] as i32 + borrow; + if i + 1 < L { + borrow = d >> 31; + *out_limb = (d - borrow * RADIX) as u16; + } else { + *out_limb = d as u16; + } + } + Self { limbs: out } + } +} + +impl fmt::Display for LimbQ { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + if let Ok(val) = u128::try_from(*self) { + write!(f, "{val}") + } else { + write!(f, "LimbQ{:?}", self.limbs) + } + } +} diff --git a/src/algebra/ntt/mod.rs b/src/algebra/ntt/mod.rs new file mode 100644 index 00000000..f2536363 --- /dev/null +++ b/src/algebra/ntt/mod.rs @@ -0,0 +1,10 @@ +//! NTT-friendly small-prime arithmetic and CRT helpers. + +pub mod butterfly; +pub mod crt; +pub mod prime; +pub mod tables; + +pub use butterfly::NttTwiddles; +pub use crt::{GarnerData, LimbQ, RADIX_BITS}; +pub use prime::{MontCoeff, NttPrime, PrimeWidth}; diff --git a/src/algebra/ntt/prime.rs b/src/algebra/ntt/prime.rs new file mode 100644 index 00000000..a5646234 --- /dev/null +++ b/src/algebra/ntt/prime.rs @@ -0,0 +1,382 @@ +//! NTT prime arithmetic kernels generic over coefficient width. +//! +//! Per-prime scalar operations: +//! - Montgomery multiplication ([`NttPrime::mul`]) +//! - Branchless conditional add/sub and range reduction +//! +//! Coefficients in Montgomery domain are wrapped in [`MontCoeff`] to prevent +//! accidental mixing with canonical values. +//! +//! The [`PrimeWidth`] trait abstracts over `i16` (R = 2^16, for primes < 2^14) +//! and `i32` (R = 2^32, for primes < 2^30). All NTT types are generic over +//! `W: PrimeWidth`; monomorphization produces optimal code for each width. + +use std::fmt; + +mod sealed { + pub trait Sealed {} + impl Sealed for i16 {} + impl Sealed for i32 {} +} + +/// Integer width abstraction for NTT prime arithmetic. +/// +/// Sealed with exactly two implementations: `i16` and `i32`. +pub trait PrimeWidth: + sealed::Sealed + Copy + Clone + Eq + Default + fmt::Debug + Send + Sync + 'static +{ + /// Double-width type for intermediate Montgomery products. + type Wide: Copy + Clone; + + /// log2(R) for Montgomery reduction: 16 for `i16`, 32 for `i32`. + const R_LOG: u32; + + /// Widening multiply: `a * b` as `Wide`. + fn wide_mul(a: Self, b: Self) -> Self::Wide; + + /// Truncate wide value to narrow (low half, i.e. mod R). + fn truncate(w: Self::Wide) -> Self; + + /// Arithmetic right shift of wide value by `R_LOG` bits. + fn wide_shift(w: Self::Wide) -> Self; + + /// Wide subtraction (wrapping). + fn wide_sub(a: Self::Wide, b: Self::Wide) -> Self::Wide; + + /// Wrapping addition. + fn wrapping_add(self, rhs: Self) -> Self; + /// Wrapping subtraction. + fn wrapping_sub(self, rhs: Self) -> Self; + /// Wrapping multiplication. + fn wrapping_mul(self, rhs: Self) -> Self; + /// Wrapping negation. + fn wrapping_neg(self) -> Self; + + /// Arithmetic right shift by `BITS - 1`: all-1s if negative, all-0s otherwise. + fn sign_mask(self) -> Self; + + /// Bitwise AND. + fn bitand(self, rhs: Self) -> Self; + + /// Convert from `i64` (truncating). + fn from_i64(v: i64) -> Self; + /// Convert to `i64` (sign-extending). + fn to_i64(self) -> i64; +} + +impl PrimeWidth for i16 { + type Wide = i32; + const R_LOG: u32 = 16; + + #[inline] + fn wide_mul(a: Self, b: Self) -> i32 { + (a as i32) * (b as i32) + } + #[inline] + fn truncate(w: i32) -> Self { + w as i16 + } + #[inline] + fn wide_shift(w: i32) -> Self { + (w >> 16) as i16 + } + #[inline] + fn wide_sub(a: i32, b: i32) -> i32 { + a.wrapping_sub(b) + } + #[inline] + fn wrapping_add(self, rhs: Self) -> Self { + i16::wrapping_add(self, rhs) + } + #[inline] + fn wrapping_sub(self, rhs: Self) -> Self { + i16::wrapping_sub(self, rhs) + } + #[inline] + fn wrapping_mul(self, rhs: Self) -> Self { + i16::wrapping_mul(self, rhs) + } + #[inline] + fn wrapping_neg(self) -> Self { + i16::wrapping_neg(self) + } + #[inline] + fn sign_mask(self) -> Self { + self >> 15 + } + #[inline] + fn bitand(self, rhs: Self) -> Self { + self & rhs + } + #[inline] + fn from_i64(v: i64) -> Self { + v as i16 + } + #[inline] + fn to_i64(self) -> i64 { + self as i64 + } +} + +impl PrimeWidth for i32 { + type Wide = i64; + const R_LOG: u32 = 32; + + #[inline] + fn wide_mul(a: Self, b: Self) -> i64 { + (a as i64) * (b as i64) + } + #[inline] + fn truncate(w: i64) -> Self { + w as i32 + } + #[inline] + fn wide_shift(w: i64) -> Self { + (w >> 32) as i32 + } + #[inline] + fn wide_sub(a: i64, b: i64) -> i64 { + a.wrapping_sub(b) + } + #[inline] + fn wrapping_add(self, rhs: Self) -> Self { + i32::wrapping_add(self, rhs) + } + #[inline] + fn wrapping_sub(self, rhs: Self) -> Self { + i32::wrapping_sub(self, rhs) + } + #[inline] + fn wrapping_mul(self, rhs: Self) -> Self { + i32::wrapping_mul(self, rhs) + } + #[inline] + fn wrapping_neg(self) -> Self { + i32::wrapping_neg(self) + } + #[inline] + fn sign_mask(self) -> Self { + self >> 31 + } + #[inline] + fn bitand(self, rhs: Self) -> Self { + self & rhs + } + #[inline] + fn from_i64(v: i64) -> Self { + v as i32 + } + #[inline] + fn to_i64(self) -> i64 { + self as i64 + } +} + +/// A coefficient in Montgomery domain for an NTT prime. +/// +/// Wraps a `W` representing `a * R mod p` (where `R = 2^{W::R_LOG}`). +/// Use [`NttPrime::from_canonical`] to enter and [`NttPrime::to_canonical`] +/// to leave Montgomery domain. +#[derive(Clone, Copy, PartialEq, Eq, Default)] +#[repr(transparent)] +pub struct MontCoeff(W); + +impl MontCoeff { + /// Wrap a raw Montgomery-domain value. + #[inline] + pub fn from_raw(val: W) -> Self { + Self(val) + } + + /// Extract the raw value (still in Montgomery domain). + #[inline] + pub fn raw(self) -> W { + self.0 + } +} + +impl fmt::Debug for MontCoeff { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "Mont({:?})", self.0) + } +} + +/// Per-prime constants for NTT arithmetic. +/// +/// Generic over `W: PrimeWidth` — use `i16` for primes below 2^14 (R = 2^16), +/// or `i32` for primes below 2^30 (R = 2^32). +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct NttPrime { + /// Prime modulus. + pub p: W, + /// `p^{-1} mod R` (centered signed). Used in Montgomery reduction. + pub pinv: W, + /// `R mod p` (centered signed). Montgomery form of 1. + pub mont: W, + /// `R^2 mod p` (centered signed). Used for canonical → Montgomery conversion. + pub montsq: W, +} + +impl NttPrime { + /// Derive all Montgomery constants from a raw prime value. + pub fn compute(p: W) -> Self { + let p_i64 = p.to_i64(); + debug_assert!(p_i64 > 1 && p_i64 % 2 == 1, "NTT prime must be odd and > 1"); + + // pinv via Newton's method: x_{n+1} = x_n * (2 - p * x_n). + // 5 iterations gives correctness mod 2^32 (sufficient for both i16 and i32). + let mut pinv: i64 = 1; + for _ in 0..5 { + pinv = pinv.wrapping_mul(2i64.wrapping_sub(p_i64.wrapping_mul(pinv))); + } + let pinv = W::from_i64(pinv); + + let half = p_i64 / 2; + let center = |x: i64| -> W { W::from_i64(if x > half { x - p_i64 } else { x }) }; + + let r_mod_p = ((1i128 << W::R_LOG) % (p_i64 as i128)) as i64; + let mont = center(r_mod_p); + + let rsq_mod_p = ((1i128 << (2 * W::R_LOG)) % (p_i64 as i128)) as i64; + let montsq = center(rsq_mod_p); + + Self { + p, + pinv, + mont, + montsq, + } + } + + /// Montgomery product: `a * b * R^{-1} mod p`. + #[inline] + pub fn mul(self, a: MontCoeff, b: MontCoeff) -> MontCoeff { + MontCoeff(self.mont_mul_raw(a.0, b.0)) + } + + /// Raw Montgomery multiply on bare `W` values. + #[inline] + pub(crate) fn mont_mul_raw(self, a: W, b: W) -> W { + let c = W::wide_mul(a, b); + let t = W::truncate(c).wrapping_mul(self.pinv); + let tp = W::wide_mul(t, self.p); + W::wide_shift(W::wide_sub(c, tp)) + } + + /// Conditionally subtract `p` if `a >= p` (branchless). + /// + /// Input may be in `(-2p, 2p)` during butterflies. Both i16 and i32 paths + /// widen to avoid signed overflow: i16→i32, i32→i64. + #[inline] + pub fn csubp(self, a: MontCoeff) -> MontCoeff { + if W::R_LOG == 16 { + let ai = a.0.to_i64() as i32; + let pi = self.p.to_i64() as i32; + let diff = ai - pi; + let mask = diff >> 31; + MontCoeff(W::from_i64((diff + (mask & pi)) as i64)) + } else { + let ai = a.0.to_i64(); + let pi = self.p.to_i64(); + let diff = ai - pi; + let mask = diff >> 63; + MontCoeff(W::from_i64(diff + (mask & pi))) + } + } + + /// Conditionally add `p` if `a < 0` (branchless). + /// + /// Widened arithmetic mirrors `csubp` to avoid i16 overflow edge cases. + #[inline] + pub fn caddp(self, a: MontCoeff) -> MontCoeff { + if W::R_LOG == 16 { + let ai = a.0.to_i64() as i32; + let pi = self.p.to_i64() as i32; + let mask = ai >> 31; + MontCoeff(W::from_i64((ai + (mask & pi)) as i64)) + } else { + let ai = a.0.to_i64(); + let pi = self.p.to_i64(); + let mask = ai >> 63; + MontCoeff(W::from_i64(ai + (mask & pi))) + } + } + + /// Range-reduce from `(-2p, 2p)` to `(-p, p)`. + #[inline] + pub fn reduce_range(self, a: MontCoeff) -> MontCoeff { + self.caddp(self.csubp(a)) + } + + /// Fully normalize a Montgomery coefficient to `[0, p)`. + #[inline] + pub fn normalize(self, a: MontCoeff) -> MontCoeff { + self.csubp(self.caddp(a)) + } + + /// Convert a canonical value into Montgomery domain: `a ↦ a * R mod p`. + #[inline] + pub fn from_canonical(self, a: W) -> MontCoeff { + MontCoeff(self.mont_mul_raw(a, self.montsq)) + } + + /// Convert from Montgomery domain to canonical `[0, p)`. + #[inline] + pub fn to_canonical(self, a: MontCoeff) -> W { + let raw = MontCoeff(self.mont_mul_raw(a.0, W::from_i64(1))); + self.normalize(raw).0 + } + + /// Center a canonical value from approximately `(-p, p)` into `[-p/2, p/2)`. + #[inline] + pub fn center(self, a: W) -> W { + let mask_neg = a.sign_mask(); + let canonical = a.wrapping_add(mask_neg.bitand(self.p)); + let half = W::from_i64(self.p.to_i64() / 2); + let needs_sub = half.wrapping_sub(canonical).sign_mask(); + canonical.wrapping_add(needs_sub.bitand(self.p.wrapping_neg())) + } + + /// Pointwise Montgomery multiplication of two coefficient slices. + /// + /// # Panics + /// + /// Panics if slices have different lengths. + #[inline] + pub fn pointwise_mul( + self, + out: &mut [MontCoeff], + lhs: &[MontCoeff], + rhs: &[MontCoeff], + ) { + assert_eq!(out.len(), lhs.len()); + assert_eq!(lhs.len(), rhs.len()); + for ((o, a), b) in out.iter_mut().zip(lhs.iter()).zip(rhs.iter()) { + *o = self.mul(*a, *b); + } + } + + /// In-place Montgomery scaling by a constant. + #[inline] + pub fn scale_in_place(self, coeffs: &mut [MontCoeff], scalar: MontCoeff) { + for c in coeffs { + *c = self.mul(*c, scalar); + } + } + + /// In-place range reduction on a coefficient slice. + #[inline] + pub fn reduce_range_in_place(self, coeffs: &mut [MontCoeff]) { + for c in coeffs { + *c = self.reduce_range(*c); + } + } + + /// In-place centering of canonical values to `[-p/2, p/2)`. + #[inline] + pub fn center_slice(self, coeffs: &mut [W]) { + for c in coeffs { + *c = self.center(*c); + } + } +} diff --git a/src/algebra/ntt/tables.rs b/src/algebra/ntt/tables.rs new file mode 100644 index 00000000..70e35bd1 --- /dev/null +++ b/src/algebra/ntt/tables.rs @@ -0,0 +1,215 @@ +//! Deterministic parameter presets for small-prime CRT arithmetic. +//! +//! Q32: `logq = 32` with six `i16` NTT-friendly primes (D ≤ 64). +//! Q64: `logq = 64` with `i32` NTT-friendly primes (D ≤ 1024). +//! Q128: `logq = 128` with five `i32` NTT-friendly primes (D ≤ 1024). + +use super::crt::GarnerData; +use super::prime::NttPrime; + +/// Polynomial degree for the base ring `Z_q[X]/(X^d + 1)`. +pub const RING_DEGREE: usize = 64; +/// Maximum ring degree covered by the i32 CRT parameter sets. +pub const MAX_CRT_RING_DEGREE: usize = 1024; + +/// Number of CRT primes for the `logq = 32` parameter set. +pub const Q32_NUM_PRIMES: usize = 6; + +/// The modulus `q = 2^32 - 99`. +pub const Q32_MODULUS: u64 = (1u64 << 32) - 99; + +/// CRT primes and per-prime Montgomery constants for `logq = 32`. +/// +/// All constants are for `R = 2^16` (i16 width). +pub const Q32_PRIMES: [NttPrime; Q32_NUM_PRIMES] = [ + NttPrime { + p: 13697, + pinv: 2689, + mont: -2949, + montsq: -994, + }, + NttPrime { + p: 13441, + pinv: 2945, + mont: -1669, + montsq: 3274, + }, + NttPrime { + p: 13313, + pinv: -13311, + mont: -1029, + montsq: -6199, + }, + NttPrime { + p: 12289, + pinv: -12287, + mont: 4091, + montsq: -1337, + }, + NttPrime { + p: 12161, + pinv: 4225, + mont: 4731, + montsq: -6040, + }, + NttPrime { + p: 11777, + pinv: -11775, + mont: -5126, + montsq: 1389, + }, +]; + +/// Garner CRT reconstruction constants for Q32. +pub fn q32_garner() -> GarnerData { + GarnerData::compute(&Q32_PRIMES) +} + +/// Number of CRT primes for the `logq = 64` fast profile (`P > q`). +pub const Q64_NUM_PRIMES_FAST: usize = 3; +/// Number of CRT primes for the `logq = 64` conservative profile (`P > 128*q^2`). +pub const Q64_NUM_PRIMES: usize = 5; + +/// The modulus `q = 2^64 - 59`. +pub const Q64_MODULUS: u64 = u64::MAX - 58; + +/// Number of CRT primes for the `logq = 128` parameter set. +pub const Q128_NUM_PRIMES: usize = 5; + +/// The modulus `q = 2^128 - 275`. +pub const Q128_MODULUS: u128 = u128::MAX - 274; + +/// Raw 30-bit primes for D≤1024, each satisfying `2048 | (p - 1)`. +/// +/// They are ordered descending by value. +pub const D1024_RAW_PRIMES: [i32; Q128_NUM_PRIMES] = + [1073707009, 1073698817, 1073692673, 1073682433, 1073668097]; + +/// Raw 30-bit primes for Q64 fast profile (`K=3`, `P > q`). +pub const Q64_RAW_PRIMES_FAST: [i32; Q64_NUM_PRIMES_FAST] = [ + D1024_RAW_PRIMES[0], + D1024_RAW_PRIMES[1], + D1024_RAW_PRIMES[2], +]; + +/// Raw 30-bit primes for Q64 conservative profile (`K=5`, `P > 128*q^2`). +pub const Q64_RAW_PRIMES: [i32; Q64_NUM_PRIMES] = D1024_RAW_PRIMES; + +/// Raw 30-bit primes for Q128, each satisfying `2048 | (p - 1)`. +pub const Q128_RAW_PRIMES: [i32; Q128_NUM_PRIMES] = D1024_RAW_PRIMES; + +/// CRT primes and per-prime Montgomery constants for `logq = 64` fast profile. +pub fn q64_primes_fast() -> [NttPrime; Q64_NUM_PRIMES_FAST] { + std::array::from_fn(|k| NttPrime::compute(Q64_RAW_PRIMES_FAST[k])) +} + +/// Garner CRT reconstruction constants for Q64 fast profile. +pub fn q64_garner_fast() -> GarnerData { + let primes = q64_primes_fast(); + GarnerData::compute(&primes) +} + +/// CRT primes and per-prime Montgomery constants for `logq = 64` conservative profile. +pub fn q64_primes() -> [NttPrime; Q64_NUM_PRIMES] { + std::array::from_fn(|k| NttPrime::compute(Q64_RAW_PRIMES[k])) +} + +/// Garner CRT reconstruction constants for Q64 conservative profile. +pub fn q64_garner() -> GarnerData { + let primes = q64_primes(); + GarnerData::compute(&primes) +} + +/// CRT primes and per-prime Montgomery constants for `logq = 128`. +pub fn q128_primes() -> [NttPrime; Q128_NUM_PRIMES] { + std::array::from_fn(|k| NttPrime::compute(Q128_RAW_PRIMES[k])) +} + +/// Garner CRT reconstruction constants for Q128. +pub fn q128_garner() -> GarnerData { + let primes = q128_primes(); + GarnerData::compute(&primes) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn verify_q32_prime_derived_constants() { + for prime in &Q32_PRIMES { + let recomputed = NttPrime::compute(prime.p); + assert_eq!( + prime.pinv, recomputed.pinv, + "pinv mismatch for p={}", + prime.p + ); + assert_eq!( + prime.mont, recomputed.mont, + "mont mismatch for p={}", + prime.p + ); + assert_eq!( + prime.montsq, recomputed.montsq, + "montsq mismatch for p={}", + prime.p + ); + } + } + + #[test] + fn verify_q128_primes_are_valid() { + let primes = q128_primes(); + for np in &primes { + let p = np.p as i64; + assert!(p > 1 && p % 2 == 1, "prime must be odd and > 1"); + assert_eq!( + (p - 1) % 2048, + 0, + "2048 must divide p-1 for D=1024 NTT (p={p})" + ); + // Verify pinv: p * pinv ≡ 1 (mod 2^32) + assert_eq!( + np.p.wrapping_mul(np.pinv), + 1, + "pinv verification failed for p={p}" + ); + } + } + + #[test] + fn verify_q64_primes_are_valid() { + let primes = q64_primes(); + for np in &primes { + let p = np.p as i64; + assert!(p > 1 && p % 2 == 1, "prime must be odd and > 1"); + assert_eq!( + (p - 1) % 2048, + 0, + "2048 must divide p-1 for D=1024 NTT (p={p})" + ); + assert_eq!( + np.p.wrapping_mul(np.pinv), + 1, + "pinv verification failed for p={p}" + ); + } + } + + #[test] + fn garner_data_is_consistent() { + let garner = q32_garner(); + for (i, &prime_i) in Q32_PRIMES.iter().enumerate().skip(1) { + let pi = prime_i.p as i64; + for (j, &prime_j) in Q32_PRIMES[..i].iter().enumerate() { + let pj = prime_j.p as i64; + let g = garner.gamma[i][j] as i64; + assert_eq!( + (pj * g) % pi, + 1, + "garner gamma[{i}][{j}] not inverse of p_{j} mod p_{i}" + ); + } + } + } +} diff --git a/src/algebra/poly.rs b/src/algebra/poly.rs new file mode 100644 index 00000000..cc69e42e --- /dev/null +++ b/src/algebra/poly.rs @@ -0,0 +1,163 @@ +//! Polynomial containers and evaluation utilities. + +use crate::error::HachiError; +use crate::primitives::serialization::{ + Compress, HachiDeserialize, HachiSerialize, SerializationError, Valid, Validate, +}; +use crate::FieldCore; +use crate::FromSmallInt; +use std::io::{Read, Write}; + +/// A degree-(pub [F; D]); + +impl Poly { + /// Construct the zero polynomial. + pub fn zero() -> Self { + Self([F::zero(); D]) + } +} + +impl std::ops::Add for Poly { + type Output = Self; + fn add(self, rhs: Self) -> Self::Output { + let mut out = self.0; + for (dst, src) in out.iter_mut().zip(rhs.0.iter()) { + *dst = *dst + *src; + } + Self(out) + } +} + +impl std::ops::Sub for Poly { + type Output = Self; + fn sub(self, rhs: Self) -> Self::Output { + let mut out = self.0; + for (dst, src) in out.iter_mut().zip(rhs.0.iter()) { + *dst = *dst - *src; + } + Self(out) + } +} + +impl std::ops::Neg for Poly { + type Output = Self; + fn neg(self) -> Self::Output { + let mut out = self.0; + for coeff in &mut out { + *coeff = -*coeff; + } + Self(out) + } +} + +impl Valid for Poly { + fn check(&self) -> Result<(), SerializationError> { + for x in self.0.iter() { + x.check()?; + } + Ok(()) + } +} + +impl HachiSerialize for Poly { + fn serialize_with_mode( + &self, + mut writer: W, + compress: Compress, + ) -> Result<(), SerializationError> { + for x in self.0.iter() { + x.serialize_with_mode(&mut writer, compress)?; + } + Ok(()) + } + + fn serialized_size(&self, compress: Compress) -> usize { + self.0.iter().map(|x| x.serialized_size(compress)).sum() + } +} + +impl HachiDeserialize for Poly { + fn deserialize_with_mode( + mut reader: R, + compress: Compress, + validate: Validate, + ) -> Result { + let mut arr = [F::zero(); D]; + for coeff in &mut arr { + *coeff = F::deserialize_with_mode(&mut reader, compress, validate)?; + } + let out = Self(arr); + if matches!(validate, Validate::Yes) { + out.check()?; + } + Ok(out) + } +} + +/// Evaluate the range-check polynomial `w · Π_{k=1}^{b−1} (w − k)(w + k)`. +/// +/// This polynomial vanishes exactly when `w ∈ {−(b−1), …, b−1}`. +/// Total degree in `w` is `2b − 1`. +pub fn range_check_eval(w: E, b: usize) -> E { + let mut acc = w; + for k in 1..b { + let k_e = E::from_u64(k as u64); + acc = acc * (w - k_e) * (w + k_e); + } + acc +} + +/// Evaluate a multilinear polynomial (given by boolean-hypercube evaluations in +/// little-endian bit order) at an arbitrary point via iterated folding. +/// +/// # Errors +/// +/// Returns an error if the evaluation table length is not a power of two or +/// does not match `2^point.len()`. +pub fn multilinear_eval(evals: &[E], point: &[E]) -> Result { + if !evals.len().is_power_of_two() { + return Err(HachiError::InvalidSize { + expected: 1 << point.len(), + actual: evals.len(), + }); + } + if evals.len() != 1 << point.len() { + return Err(HachiError::InvalidSize { + expected: 1 << point.len(), + actual: evals.len(), + }); + } + let mut current = evals.to_vec(); + for &r in point { + let half = current.len() / 2; + let mut next = Vec::with_capacity(half); + for i in 0..half { + next.push(current[2 * i] + r * (current[2 * i + 1] - current[2 * i])); + } + current = next; + } + Ok(current[0]) +} + +/// Fold an evaluation table in place by binding its first variable to `r`, +/// halving the table size. +/// +/// # Panics +/// +/// Panics if the evaluation table length is not a power of two or has fewer +/// than 2 elements. This is a prover-only helper where the caller guarantees +/// well-formed input. +pub fn fold_evals_in_place(evals: &mut Vec, r: E) { + assert!( + evals.len().is_power_of_two(), + "evals length must be a power of two" + ); + assert!(evals.len() >= 2, "evals must have at least 2 elements"); + let half = evals.len() / 2; + for i in 0..half { + evals[i] = evals[2 * i] + r * (evals[2 * i + 1] - evals[2 * i]); + } + evals.truncate(half); +} diff --git a/src/algebra/ring/crt_ntt_repr.rs b/src/algebra/ring/crt_ntt_repr.rs new file mode 100644 index 00000000..ad856ac7 --- /dev/null +++ b/src/algebra/ring/crt_ntt_repr.rs @@ -0,0 +1,299 @@ +//! CRT+NTT-domain representation of cyclotomic ring elements. + +use std::array::from_fn; + +use crate::algebra::backend::{CrtReconstruct, NttPrimeOps, NttTransform, ScalarBackend}; +use crate::algebra::ntt::butterfly::NttTwiddles; +use crate::algebra::ntt::crt::GarnerData; +use crate::algebra::ntt::prime::{MontCoeff, NttPrime, PrimeWidth}; +use crate::{CanonicalField, FieldCore}; + +use super::cyclotomic::CyclotomicRing; + +/// CRT+NTT-domain representation of a cyclotomic ring element. +/// +/// Stores `K` arrays of `D` [`MontCoeff`] values, one per CRT prime. +/// Multiplication is pointwise per prime — O(K*D) vs O(D^2) for coefficient form. +/// +/// Generic over: +/// - `W: PrimeWidth` — integer width (`i16` or `i32`) +/// - `K` — number of CRT primes +/// - `D` — polynomial degree +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct CyclotomicCrtNtt { + pub(crate) limbs: [[MontCoeff; D]; K], +} + +/// Field types that can convert to/from the CRT+NTT representation. +/// +/// Blanket-implemented for all `FieldCore + CanonicalField` types. +pub trait CrtNttConvertibleField: FieldCore + CanonicalField {} + +impl CrtNttConvertibleField for F {} + +/// Bundled CRT+NTT parameters for a fixed width/prime-count/degree tuple. +/// +/// Keeps primes/twiddles/Garner constants consistent and avoids passing them +/// independently at every call site. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct CrtNttParamSet { + /// CRT primes with Montgomery constants. + pub primes: [NttPrime; K], + /// Per-prime twiddle tables for forward/inverse NTT. + pub twiddles: [NttTwiddles; K], + /// Garner reconstruction constants for CRT lift-back. + pub garner: GarnerData, +} + +impl CrtNttParamSet { + /// Build a full parameter set from CRT primes. + /// + /// Computes per-prime twiddles and Garner reconstruction constants. + pub fn new(primes: [NttPrime; K]) -> Self { + let twiddles = from_fn(|k| NttTwiddles::compute(primes[k])); + let garner = GarnerData::compute(&primes); + Self { + primes, + twiddles, + garner, + } + } +} + +impl CyclotomicCrtNtt { + /// The additive identity (all zeros in every CRT limb). + pub fn zero() -> Self { + Self { + limbs: [[MontCoeff::from_raw(W::default()); D]; K], + } + } + + /// Convert a coefficient-form ring element into CRT+NTT domain + /// using the default scalar backend. + pub fn from_ring( + ring: &CyclotomicRing, + primes: &[NttPrime; K], + twiddles: &[NttTwiddles; K], + ) -> Self { + Self::from_ring_with_backend::(ring, primes, twiddles) + } + + /// Convert a coefficient-form ring element into CRT+NTT domain + /// using a bundled parameter set and the scalar backend. + pub fn from_ring_with_params( + ring: &CyclotomicRing, + params: &CrtNttParamSet, + ) -> Self { + Self::from_ring(ring, ¶ms.primes, ¶ms.twiddles) + } + + /// Convert a coefficient-form ring element into CRT+NTT domain + /// through an explicit backend implementation. + pub fn from_ring_with_backend< + F: CrtNttConvertibleField, + B: NttPrimeOps + NttTransform, + >( + ring: &CyclotomicRing, + primes: &[NttPrime; K], + twiddles: &[NttTwiddles; K], + ) -> Self { + let q = (-F::one()).to_canonical_u128() + 1; + let half_q = q / 2; + let centered_coeffs: [i128; D] = from_fn(|i| { + let canonical = ring.coeffs[i].to_canonical_u128(); + if canonical > half_q { + -((q - canonical) as i128) + } else { + canonical as i128 + } + }); + + let mut limbs = [[MontCoeff::from_raw(W::default()); D]; K]; + for ((limb, prime), tw) in limbs.iter_mut().zip(primes.iter()).zip(twiddles.iter()) { + // Interpret coefficients in centered form (-q/2, q/2] before reducing + // into the CRT primes. This makes the reduction map consistent with + // negacyclic subtraction (which naturally produces negative values). + let p = prime.p.to_i64() as i128; + let half_p = p / 2; + for (dst, centered) in limb.iter_mut().zip(centered_coeffs.iter()) { + let mut r = *centered % p; + if r < 0 { + r += p; + } + // Center residues into [-p/2, p/2) for stable signed arithmetic. + if r >= half_p { + r -= p; + } + *dst = B::from_canonical(*prime, W::from_i64(r as i64)); + } + B::forward_ntt(limb, *prime, tw); + } + Self { limbs } + } + + /// Convert from CRT+NTT domain back to coefficient form + /// using the default scalar backend. + pub fn to_ring( + &self, + primes: &[NttPrime; K], + twiddles: &[NttTwiddles; K], + garner: &GarnerData, + ) -> CyclotomicRing { + self.to_ring_with_backend::(primes, twiddles, garner) + } + + /// Convert from CRT+NTT domain back to coefficient form + /// using a bundled parameter set and the scalar backend. + pub fn to_ring_with_params( + &self, + params: &CrtNttParamSet, + ) -> CyclotomicRing { + self.to_ring(¶ms.primes, ¶ms.twiddles, ¶ms.garner) + } + + /// Convert from CRT+NTT domain back to coefficient form + /// through an explicit backend implementation. + pub fn to_ring_with_backend< + F: CrtNttConvertibleField, + B: NttPrimeOps + NttTransform + CrtReconstruct, + >( + &self, + primes: &[NttPrime; K], + twiddles: &[NttTwiddles; K], + garner: &GarnerData, + ) -> CyclotomicRing { + let mut canonical = [[W::default(); D]; K]; + for (k, ((can, prime), tw)) in canonical + .iter_mut() + .zip(primes.iter()) + .zip(twiddles.iter()) + .enumerate() + { + let mut limb = self.limbs[k]; + B::inverse_ntt(&mut limb, *prime, tw); + for (dst, src) in can.iter_mut().zip(limb.iter()) { + let canon = B::to_canonical(*prime, *src); + *dst = prime.center(canon); + } + } + + let coeffs = B::reconstruct::(primes, &canonical, garner); + + CyclotomicRing::from_coefficients(coeffs) + } + + /// Add another CRT+NTT element and reduce each coefficient with the matching + /// prime to maintain valid Montgomery ranges using the scalar backend. + pub fn add_reduced(&self, rhs: &Self, primes: &[NttPrime; K]) -> Self { + self.add_reduced_with_backend::(rhs, primes) + } + + /// Add another CRT+NTT element and reduce using a bundled parameter set. + pub fn add_reduced_with_params(&self, rhs: &Self, params: &CrtNttParamSet) -> Self { + self.add_reduced(rhs, ¶ms.primes) + } + + /// Add another CRT+NTT element and reduce each coefficient with the matching + /// prime through an explicit backend implementation. + pub fn add_reduced_with_backend>( + &self, + rhs: &Self, + primes: &[NttPrime; K], + ) -> Self { + let mut out = self.clone(); + for (k, (limb, rhs_limb)) in out.limbs.iter_mut().zip(rhs.limbs.iter()).enumerate() { + let prime = primes[k]; + for (a, b) in limb.iter_mut().zip(rhs_limb.iter()) { + let sum = MontCoeff::from_raw(a.raw().wrapping_add(b.raw())); + *a = B::reduce_range(prime, sum); + } + } + out + } + + /// Subtract another CRT+NTT element and reduce using the scalar backend. + pub fn sub_reduced(&self, rhs: &Self, primes: &[NttPrime; K]) -> Self { + self.sub_reduced_with_backend::(rhs, primes) + } + + /// Subtract another CRT+NTT element and reduce using a bundled parameter set. + pub fn sub_reduced_with_params(&self, rhs: &Self, params: &CrtNttParamSet) -> Self { + self.sub_reduced(rhs, ¶ms.primes) + } + + /// Subtract another CRT+NTT element and reduce through an explicit backend. + pub fn sub_reduced_with_backend>( + &self, + rhs: &Self, + primes: &[NttPrime; K], + ) -> Self { + let mut out = self.clone(); + for (k, (limb, rhs_limb)) in out.limbs.iter_mut().zip(rhs.limbs.iter()).enumerate() { + let prime = primes[k]; + for (a, b) in limb.iter_mut().zip(rhs_limb.iter()) { + let diff = MontCoeff::from_raw(a.raw().wrapping_sub(b.raw())); + *a = B::reduce_range(prime, diff); + } + } + out + } + + /// Negate each CRT+NTT coefficient and reduce using the scalar backend. + pub fn neg_reduced(&self, primes: &[NttPrime; K]) -> Self { + self.neg_reduced_with_backend::(primes) + } + + /// Negate each CRT+NTT coefficient and reduce using a bundled parameter set. + pub fn neg_reduced_with_params(&self, params: &CrtNttParamSet) -> Self { + self.neg_reduced(¶ms.primes) + } + + /// Negate each CRT+NTT coefficient and reduce through an explicit backend. + pub fn neg_reduced_with_backend>( + &self, + primes: &[NttPrime; K], + ) -> Self { + let mut out = self.clone(); + for (k, limb) in out.limbs.iter_mut().enumerate() { + let prime = primes[k]; + for a in limb.iter_mut() { + let neg = MontCoeff::from_raw(a.raw().wrapping_neg()); + *a = B::reduce_range(prime, neg); + } + } + out + } + + /// Pointwise multiplication in CRT+NTT domain using the scalar backend. + pub fn pointwise_mul(&self, rhs: &Self, primes: &[NttPrime; K]) -> Self { + self.pointwise_mul_with_backend::(rhs, primes) + } + + /// Pointwise multiplication in CRT+NTT domain using a bundled parameter set. + pub fn pointwise_mul_with_params(&self, rhs: &Self, params: &CrtNttParamSet) -> Self { + self.pointwise_mul(rhs, ¶ms.primes) + } + + /// Pointwise multiplication in CRT+NTT domain through an explicit backend. + pub fn pointwise_mul_with_backend>( + &self, + rhs: &Self, + primes: &[NttPrime; K], + ) -> Self { + let mut out = [[MontCoeff::from_raw(W::default()); D]; K]; + for (k, ((o, a), b)) in out + .iter_mut() + .zip(self.limbs.iter()) + .zip(rhs.limbs.iter()) + .enumerate() + { + let prime = primes[k]; + B::pointwise_mul(prime, o, a, b); + // Keep coefficients in a bounded range for subsequent inverse NTT. + for c in o.iter_mut() { + *c = B::reduce_range(prime, *c); + } + } + Self { limbs: out } + } +} diff --git a/src/algebra/ring/cyclotomic.rs b/src/algebra/ring/cyclotomic.rs new file mode 100644 index 00000000..a39b49af --- /dev/null +++ b/src/algebra/ring/cyclotomic.rs @@ -0,0 +1,486 @@ +//! Cyclotomic ring `Z_q[X]/(X^D + 1)` in coefficient form. + +use super::sparse_challenge::SparseChallenge; +use crate::primitives::serialization::{ + Compress, HachiDeserialize, HachiSerialize, SerializationError, Valid, Validate, +}; +use crate::{CanonicalField, FieldCore, FieldSampling}; +use rand_core::RngCore; +use std::array::from_fn; +use std::io::{Read, Write}; +use std::ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign}; + +/// Element of the cyclotomic ring `Z_q[X]/(X^D + 1)`. +/// +/// Stored as `D` coefficients in the base field `F`, representing +/// `a_0 + a_1*X + ... + a_{D-1}*X^{D-1}`. +/// +/// Multiplication is negacyclic convolution: `X^D = -1`, so a product +/// term at index `i + j >= D` wraps to index `(i + j) - D` with a sign flip. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct CyclotomicRing { + pub(crate) coeffs: [F; D], +} + +impl CyclotomicRing { + /// Construct from a coefficient array. + #[inline] + pub fn from_coefficients(coeffs: [F; D]) -> Self { + Self { coeffs } + } + + /// Borrow the coefficient array. + #[inline] + pub fn coefficients(&self) -> &[F; D] { + &self.coeffs + } + + /// The additive identity (all-zero polynomial). + #[inline] + pub fn zero() -> Self { + Self { + coeffs: [F::zero(); D], + } + } + + /// The multiplicative identity (`1 + 0*X + ... + 0*X^{D-1}`). + #[inline] + pub fn one() -> Self { + let mut coeffs = [F::zero(); D]; + coeffs[0] = F::one(); + Self { coeffs } + } + + /// The monomial `X` (i.e., `[0, 1, 0, ..., 0]`). + /// + /// # Panics + /// + /// Panics if `D < 2`. + #[inline] + pub fn x() -> Self { + assert!(D >= 2, "ring degree must be at least 2"); + let mut coeffs = [F::zero(); D]; + coeffs[1] = F::one(); + Self { coeffs } + } + + /// Scalar multiplication: multiply every coefficient by `k`. + #[inline] + pub fn scale(&self, k: &F) -> Self { + let mut out = self.coeffs; + for c in &mut out { + *c = *c * *k; + } + Self { coeffs: out } + } + + /// Apply the cyclotomic automorphism `sigma_k: X -> X^k` for odd `k`. + /// + /// In `Z_q[X]/(X^D + 1)`, this permutes/sign-flips coefficients using + /// exponent reduction modulo `2D`. + /// + /// # Panics + /// + /// Panics if `D == 0` or `k` is not odd modulo `2D`. + pub fn sigma(&self, k: usize) -> Self { + assert!(D > 0, "ring degree must be non-zero"); + let two_d = 2 * D; + let k_mod = k % two_d; + assert!(k_mod % 2 == 1, "sigma_k requires odd k in Z_q[X]/(X^D + 1)"); + + let mut out = [F::zero(); D]; + for (j, coeff) in self.coeffs.iter().copied().enumerate() { + let idx = (j * k_mod) % two_d; + if idx < D { + out[idx] = out[idx] + coeff; + } else { + out[idx - D] = out[idx - D] - coeff; + } + } + Self { coeffs: out } + } + + /// Apply `sigma_{-1}` (`X -> X^{-1} = X^{2D-1}` in this ring). + /// + /// # Panics + /// + /// Panics if `D == 0`. + pub fn sigma_m1(&self) -> Self { + assert!(D > 0, "ring degree must be non-zero"); + self.sigma(2 * D - 1) + } + + /// Multiply by `X^k` in `Z_q[X]/(X^D + 1)` via O(D) coefficient rotation. + /// + /// Since `X^D = -1`, coefficients that wrap past index `D` get negated. + #[inline] + pub fn negacyclic_shift(&self, k: usize) -> Self { + let k = k % D; + if k == 0 { + return *self; + } + let mut out = [F::zero(); D]; + for i in 0..D { + let target = i + k; + if target < D { + out[target] = self.coeffs[i]; + } else { + out[target - D] = -self.coeffs[i]; + } + } + Self { coeffs: out } + } + + /// Multiply `self` by a sum of monomials `X^{k_1} + X^{k_2} + ...` + /// + /// Each term is a negacyclic shift, so the total cost is + /// `O(positions.len() * D)` field additions with zero multiplications. + pub fn mul_by_monomial_sum(&self, nonzero_positions: &[usize]) -> Self { + let mut result = Self::zero(); + for &k in nonzero_positions { + result += self.negacyclic_shift(k); + } + result + } + + /// Multiply `self` by a sparse challenge element. + /// + /// Cost: `O(omega * D)` field additions instead of `O(D^2)` multiplications. + /// For `omega=23, D=256` this is 5,888 adds vs 65,536 muls. + pub fn mul_by_sparse(&self, challenge: &SparseChallenge) -> Self + where + F: CanonicalField, + { + let mut result = Self::zero(); + for (&pos, &coeff) in challenge.positions.iter().zip(challenge.coeffs.iter()) { + let shifted = self.negacyclic_shift(pos as usize); + match coeff { + 1 => result += shifted, + -1 => result -= shifted, + c => result += shifted.scale(&F::from_i64(c as i64)), + } + } + result + } + + /// Check whether all coefficients are zero. + #[inline] + pub fn is_zero(&self) -> bool { + self.coeffs.iter().all(|c| c.is_zero()) + } + + /// Count non-zero coefficients. + #[inline] + pub fn hamming_weight(&self) -> usize { + self.coeffs.iter().filter(|c| !c.is_zero()).count() + } + + /// Sample a sparse challenge with exactly `omega` non-zeros in `{+1, -1}`. + /// + /// # Panics + /// + /// Panics if `omega > D` or `D == 0` with non-zero `omega`. + pub fn sample_sparse_pm1(rng: &mut R, omega: usize) -> Self { + assert!(omega <= D, "omega must be <= ring degree"); + assert!(D > 0 || omega == 0, "ring degree must be non-zero"); + + let mut coeffs = [F::zero(); D]; + let mut placed = 0usize; + while placed < omega { + let idx = (rng.next_u64() % (D as u64)) as usize; + if coeffs[idx].is_zero() { + coeffs[idx] = if (rng.next_u32() & 1) == 0 { + F::one() + } else { + -F::one() + }; + placed += 1; + } + } + Self { coeffs } + } +} + +impl CyclotomicRing { + /// Balanced decomposition writing directly into a pre-allocated output slice. + /// + /// `out` must have length exactly `levels`. Each element receives one digit plane. + /// + /// # Panics + /// + /// Panics if `log_basis == 0`, `log_basis >= 128`, or `out.len() * log_basis > 128 + log_basis`. + pub fn balanced_decompose_pow2_into(&self, out: &mut [Self], log_basis: u32) { + let levels = out.len(); + assert!(log_basis > 0 && log_basis < 128, "invalid log_basis"); + assert!( + (levels as u32).saturating_mul(log_basis) <= 128 + log_basis, + "levels * log_basis must be <= 128 + log_basis" + ); + + let b = 1i128 << log_basis; + let half_b = b / 2; + let q = (-F::one()).to_canonical_u128() + 1; + let half_q = q / 2; + + for plane in out.iter_mut() { + *plane = Self::zero(); + } + + for i in 0..D { + let canonical = self.coeffs[i].to_canonical_u128(); + let mut c: i128 = if canonical > half_q { + -((q - canonical) as i128) + } else { + canonical as i128 + }; + + for plane in out.iter_mut() { + let d = c.rem_euclid(b); + let balanced = if d >= half_b { d - b } else { d }; + c = (c - balanced) / b; + + plane.coeffs[i] = if balanced >= 0 { + F::from_canonical_u128_reduced(balanced as u128) + } else { + F::from_canonical_u128_reduced(q - ((-balanced) as u128)) + }; + } + } + } + + /// Functional gadget recomposition (`G * digits`) for base `2^log_basis`. + /// + /// Coefficients from each part are interpreted as one digit plane and + /// recombined back into canonical integers (then reduced into the field). + /// + /// # Panics + /// + /// Panics if `log_basis == 0`, `log_basis >= 128`, or `parts.len() * log_basis > 128`. + pub fn gadget_recompose_pow2(parts: &[Self], log_basis: u32) -> Self { + if parts.is_empty() { + return Self::zero(); + } + + assert!(log_basis > 0 && log_basis < 128, "invalid log_basis"); + + let b = F::from_canonical_u128_reduced(1u128 << log_basis); + let coeffs = from_fn(|i| { + let mut acc = F::zero(); + let mut power = F::one(); + for part in parts.iter() { + acc = acc + part.coeffs[i] * power; + power = power * b; + } + acc + }); + Self { coeffs } + } + + /// Balanced (centered) base-`2^log_basis` gadget decomposition: `G^{-1}`. + /// + /// Each coefficient `c` (centered into `(-q/2, q/2]`) is decomposed into + /// `levels` balanced digits `d_k ∈ [-b/2, b/2)` satisfying + /// `c ≡ Σ_k d_k · b^k (mod q)`. + /// + /// Negative digits are stored as their field representation (`q + d`). + /// + /// # Panics + /// + /// Panics if `log_basis == 0`, `log_basis >= 128`, or `levels * log_basis > 128`. + pub fn balanced_decompose_pow2(&self, levels: usize, log_basis: u32) -> Vec { + assert!(log_basis > 0 && log_basis < 128, "invalid log_basis"); + assert!( + (levels as u32).saturating_mul(log_basis) <= 128 + log_basis, + "levels * log_basis must be <= 128 + log_basis" + ); + + let b = 1i128 << log_basis; + let half_b = b / 2; + let q = (-F::one()).to_canonical_u128() + 1; + let half_q = q / 2; + + let mut digit_planes: Vec<[F; D]> = (0..levels).map(|_| [F::zero(); D]).collect(); + + for i in 0..D { + let canonical = self.coeffs[i].to_canonical_u128(); + let mut c: i128 = if canonical > half_q { + -((q - canonical) as i128) + } else { + canonical as i128 + }; + + for plane in digit_planes.iter_mut() { + let d = c.rem_euclid(b); + let balanced = if d >= half_b { d - b } else { d }; + c = (c - balanced) / b; + + plane[i] = if balanced >= 0 { + F::from_canonical_u128_reduced(balanced as u128) + } else { + F::from_canonical_u128_reduced(q - ((-balanced) as u128)) + }; + } + } + + digit_planes + .into_iter() + .map(Self::from_coefficients) + .collect() + } +} + +impl CyclotomicRing { + /// Generate a random ring element. + pub fn random(rng: &mut R) -> Self { + Self { + coeffs: from_fn(|_| F::sample(rng)), + } + } +} + +impl AddAssign for CyclotomicRing { + fn add_assign(&mut self, rhs: Self) { + for (dst, src) in self.coeffs.iter_mut().zip(rhs.coeffs.iter()) { + *dst = *dst + *src; + } + } +} + +impl SubAssign for CyclotomicRing { + fn sub_assign(&mut self, rhs: Self) { + for (dst, src) in self.coeffs.iter_mut().zip(rhs.coeffs.iter()) { + *dst = *dst - *src; + } + } +} + +impl Add for CyclotomicRing { + type Output = Self; + fn add(mut self, rhs: Self) -> Self { + self += rhs; + self + } +} + +impl Sub for CyclotomicRing { + type Output = Self; + fn sub(mut self, rhs: Self) -> Self { + self -= rhs; + self + } +} + +impl Neg for CyclotomicRing { + type Output = Self; + fn neg(self) -> Self { + let mut out = self.coeffs; + for c in &mut out { + *c = -*c; + } + Self { coeffs: out } + } +} + +impl MulAssign for CyclotomicRing { + fn mul_assign(&mut self, rhs: Self) { + *self = *self * rhs; + } +} + +impl<'a, F: FieldCore, const D: usize> Add<&'a Self> for CyclotomicRing { + type Output = Self; + fn add(self, rhs: &'a Self) -> Self { + self + *rhs + } +} + +impl<'a, F: FieldCore, const D: usize> Sub<&'a Self> for CyclotomicRing { + type Output = Self; + fn sub(self, rhs: &'a Self) -> Self { + self - *rhs + } +} + +impl<'a, F: FieldCore, const D: usize> Mul<&'a Self> for CyclotomicRing { + type Output = Self; + fn mul(self, rhs: &'a Self) -> Self { + self * *rhs + } +} + +/// Schoolbook negacyclic convolution: O(D^2). +/// +/// For each pair `(i, j)`: +/// - If `i + j < D`: accumulate `a_i * b_j` at index `i + j`. +/// - If `i + j >= D`: accumulate `-(a_i * b_j)` at index `(i + j) - D`. +impl Mul for CyclotomicRing { + type Output = Self; + fn mul(self, rhs: Self) -> Self { + let mut out = [F::zero(); D]; + for i in 0..D { + for j in 0..D { + let product = self.coeffs[i] * rhs.coeffs[j]; + let idx = i + j; + if idx < D { + out[idx] = out[idx] + product; + } else { + out[idx - D] = out[idx - D] - product; + } + } + } + Self { coeffs: out } + } +} + +impl Valid for CyclotomicRing { + fn check(&self) -> Result<(), SerializationError> { + for x in self.coeffs.iter() { + x.check()?; + } + Ok(()) + } +} + +impl HachiSerialize for CyclotomicRing { + fn serialize_with_mode( + &self, + mut writer: W, + compress: Compress, + ) -> Result<(), SerializationError> { + for x in self.coeffs.iter() { + x.serialize_with_mode(&mut writer, compress)?; + } + Ok(()) + } + + fn serialized_size(&self, compress: Compress) -> usize { + self.coeffs + .iter() + .map(|x| x.serialized_size(compress)) + .sum() + } +} + +impl HachiDeserialize for CyclotomicRing { + fn deserialize_with_mode( + mut reader: R, + compress: Compress, + validate: Validate, + ) -> Result { + let mut coeffs = [F::zero(); D]; + for c in &mut coeffs { + *c = F::deserialize_with_mode(&mut reader, compress, validate)?; + } + let out = Self { coeffs }; + if matches!(validate, Validate::Yes) { + out.check()?; + } + Ok(out) + } +} + +impl Default for CyclotomicRing { + fn default() -> Self { + Self::zero() + } +} diff --git a/src/algebra/ring/mod.rs b/src/algebra/ring/mod.rs new file mode 100644 index 00000000..756854b3 --- /dev/null +++ b/src/algebra/ring/mod.rs @@ -0,0 +1,9 @@ +//! Cyclotomic ring types and NTT representations. + +pub mod crt_ntt_repr; +pub mod cyclotomic; +pub mod sparse_challenge; + +pub use crt_ntt_repr::{CrtNttConvertibleField, CrtNttParamSet, CyclotomicCrtNtt}; +pub use cyclotomic::CyclotomicRing; +pub use sparse_challenge::{SparseChallenge, SparseChallengeConfig}; diff --git a/src/algebra/ring/sparse_challenge.rs b/src/algebra/ring/sparse_challenge.rs new file mode 100644 index 00000000..4f55b06a --- /dev/null +++ b/src/algebra/ring/sparse_challenge.rs @@ -0,0 +1,164 @@ +//! Sparse ring challenges for cyclotomic protocols. +//! +//! Many lattice protocols sample "short/sparse" ring challenges whose coefficients +//! are mostly zero and whose non-zero coefficients come from a tiny integer alphabet +//! (e.g. `{±1}` or `{±1,±2}`), with a fixed Hamming weight `ω`. +//! +//! This module provides a minimal representation that is: +//! - independent of any specific protocol (Hachi/Greyhound/SuperNeo, etc.), +//! - easy to sample deterministically from Fiat–Shamir at the protocol layer, +//! - and efficient to evaluate at a point `α` using precomputed powers. + +use super::CyclotomicRing; +use crate::{CanonicalField, FieldCore}; + +/// Configuration for sampling a sparse challenge. +/// +/// This intentionally avoids redundant knobs: the distribution is determined by: +/// - exact `weight` (Hamming weight), +/// - and a list of allowed **non-zero** integer coefficients. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct SparseChallengeConfig { + /// Exact Hamming weight ω. + pub weight: usize, + /// Allowed non-zero coefficients (small signed integers). + /// + /// Examples: + /// - `{±1}`: `vec![-1, 1]` + /// - `{±1,±2}`: `vec![-2, -1, 1, 2]` + pub nonzero_coeffs: Vec, +} + +impl SparseChallengeConfig { + /// Validate basic invariants for a given ring degree `D`. + /// + /// # Errors + /// + /// Returns an error if `weight > D`, if `nonzero_coeffs` is empty, or if it + /// contains `0`. + pub fn validate(&self) -> Result<(), &'static str> { + if self.weight > D { + return Err("weight must be <= ring degree D"); + } + if self.nonzero_coeffs.is_empty() { + return Err("nonzero_coeffs must be non-empty"); + } + if self.nonzero_coeffs.contains(&0) { + return Err("nonzero_coeffs must not contain 0"); + } + Ok(()) + } +} + +/// Sparse polynomial in `F[X]/(X^D+1)` represented by its non-zero terms. +/// +/// Invariants: +/// - `positions.len() == coeffs.len()` +/// - all positions are `< D` +/// - positions are unique +/// - all coeffs are non-zero +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct SparseChallenge { + /// Coefficient indices (powers of `X`) where the polynomial is non-zero. + pub positions: Vec, + /// Small integer coefficients at the corresponding positions. + pub coeffs: Vec, +} + +impl SparseChallenge { + /// Construct an empty (all-zero) challenge. + #[inline] + pub fn zero() -> Self { + Self { + positions: Vec::new(), + coeffs: Vec::new(), + } + } + + /// Number of non-zero coefficients (Hamming weight). + #[inline] + pub fn hamming_weight(&self) -> usize { + debug_assert_eq!(self.positions.len(), self.coeffs.len()); + self.positions.len() + } + + /// ℓ₁ norm over integers: `Σ |coeff_i|`. + #[inline] + pub fn l1_norm(&self) -> u64 { + self.coeffs + .iter() + .map(|&c| (c as i32).unsigned_abs() as u64) + .sum() + } + + /// Validate structural invariants for a ring degree `D`. + /// + /// # Errors + /// + /// Returns an error if lengths mismatch, if any coefficient is zero, if any + /// position is out of range, or if positions contain duplicates. + pub fn validate(&self) -> Result<(), &'static str> { + if self.positions.len() != self.coeffs.len() { + return Err("positions and coeffs must have same length"); + } + // Check coeffs are non-zero and positions are in range + unique. + let mut seen = vec![false; D]; + for (&pos, &c) in self.positions.iter().zip(self.coeffs.iter()) { + if c == 0 { + return Err("coeffs must not contain 0"); + } + let p = pos as usize; + if p >= D { + return Err("position out of range"); + } + if seen[p] { + return Err("positions must be unique"); + } + seen[p] = true; + } + Ok(()) + } + + /// Convert to a dense ring element by placing coefficients in the canonical + /// coefficient basis. + /// + /// # Errors + /// + /// Returns an error if the sparse representation violates structural invariants. + pub fn to_dense( + &self, + ) -> Result, &'static str> { + self.validate::()?; + let mut out = [F::zero(); D]; + for (&pos, &c) in self.positions.iter().zip(self.coeffs.iter()) { + out[pos as usize] = out[pos as usize] + F::from_i64(c as i64); + } + Ok(CyclotomicRing::from_coefficients(out)) + } + + /// Evaluate this sparse polynomial at `α` in `E`, given precomputed powers + /// `[α^0, α^1, ..., α^{D-1}]`. + /// + /// This is `O(weight)` and is intended to be used for verifier-side oracles + /// where `D` may be large but `weight` is small. + /// + /// # Errors + /// + /// Returns an error if structural invariants fail or if `alpha_pows.len() != D`. + pub fn eval_at_alpha(&self, alpha_pows: &[E]) -> Result + where + F: FieldCore + CanonicalField, + E: FieldCore + crate::algebra::fields::LiftBase, + { + self.validate::()?; + if alpha_pows.len() != D { + return Err("alpha_pows length mismatch"); + } + let mut acc = E::zero(); + for (&pos, &c) in self.positions.iter().zip(self.coeffs.iter()) { + let coeff_f = F::from_i64(c as i64); + acc = acc + (E::lift_base(coeff_f) * alpha_pows[pos as usize]); + } + Ok(acc) + } +} diff --git a/src/lib.rs b/src/lib.rs index e3859c54..4d6635ef 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,3 +1,7 @@ +#![cfg_attr( + all(target_arch = "x86_64", target_feature = "avx512f"), + feature(stdarch_x86_avx512) +)] //! # hachi //! //! A high performance and modular implementation of the Hachi polynomial commitment scheme. @@ -35,7 +39,25 @@ pub mod error; /// Primitive traits and operations pub mod primitives; +/// Concrete algebra backends (prime fields, extensions, rings) +pub mod algebra; + +/// Protocol-layer transcript and commitment abstractions +pub mod protocol; + +/// Conditional parallelism utilities (`cfg_iter!`, `cfg_into_iter!`, etc.) +#[doc(hidden)] +#[allow(missing_docs)] +pub mod parallel; + +#[doc(hidden)] +#[allow(missing_docs)] +pub mod test_utils; + pub use error::HachiError; -pub use primitives::arithmetic::{Field, HachiRoutines, Module}; +pub use primitives::arithmetic::{ + CanonicalField, FieldCore, FieldSampling, FromSmallInt, Invertible, Module, PseudoMersenneField, +}; pub use primitives::poly::{MultilinearLagrange, Polynomial}; pub use primitives::serialization::{HachiDeserialize, HachiSerialize}; +pub use protocol::{CommitmentScheme, StreamingCommitmentScheme, Transcript}; diff --git a/src/parallel.rs b/src/parallel.rs new file mode 100644 index 00000000..d73c9381 --- /dev/null +++ b/src/parallel.rs @@ -0,0 +1,56 @@ +//! Conditional parallelism utilities. +//! +//! When the `parallel` feature is enabled, the `cfg_iter!` family of macros +//! expand to rayon's parallel iterators. Otherwise they fall back to standard +//! sequential iterators. + +#[cfg(feature = "parallel")] +pub use rayon::prelude::*; + +/// Returns `.par_iter()` when `parallel` is enabled, `.iter()` otherwise. +#[macro_export] +macro_rules! cfg_iter { + ($e:expr) => {{ + #[cfg(feature = "parallel")] + let it = $e.par_iter(); + #[cfg(not(feature = "parallel"))] + let it = $e.iter(); + it + }}; +} + +/// Returns `.par_iter_mut()` when `parallel` is enabled, `.iter_mut()` otherwise. +#[macro_export] +macro_rules! cfg_iter_mut { + ($e:expr) => {{ + #[cfg(feature = "parallel")] + let it = $e.par_iter_mut(); + #[cfg(not(feature = "parallel"))] + let it = $e.iter_mut(); + it + }}; +} + +/// Returns `.into_par_iter()` when `parallel` is enabled, `.into_iter()` otherwise. +#[macro_export] +macro_rules! cfg_into_iter { + ($e:expr) => {{ + #[cfg(feature = "parallel")] + let it = $e.into_par_iter(); + #[cfg(not(feature = "parallel"))] + let it = $e.into_iter(); + it + }}; +} + +/// Returns `.par_chunks(n)` when `parallel` is enabled, `.chunks(n)` otherwise. +#[macro_export] +macro_rules! cfg_chunks { + ($e:expr, $n:expr) => {{ + #[cfg(feature = "parallel")] + let it = $e.par_chunks($n); + #[cfg(not(feature = "parallel"))] + let it = $e.chunks($n); + it + }}; +} diff --git a/src/primitives/arithmetic.rs b/src/primitives/arithmetic.rs index ccaaa4ae..f6f1d91a 100644 --- a/src/primitives/arithmetic.rs +++ b/src/primitives/arithmetic.rs @@ -3,8 +3,8 @@ use super::{HachiDeserialize, HachiSerialize}; use rand_core::RngCore; -/// Field trait for lattice-based arithmetic -pub trait Field: +/// Core field operations required across algebra backends. +pub trait FieldCore: Sized + Clone + Copy @@ -30,28 +30,109 @@ pub trait Field: /// Check if element is zero fn is_zero(&self) -> bool; - /// Field addition - fn add(&self, rhs: &Self) -> Self; - - /// Field subtraction - fn sub(&self, rhs: &Self) -> Self; - - /// Field multiplication - fn mul(&self, rhs: &Self) -> Self; - - /// Field inversion + /// Field squaring. + /// + /// Default is `self * self`; extension fields override with specialized + /// formulas that use fewer base-field multiplications. + fn square(&self) -> Self { + *self * *self + } + + /// Field inversion. + /// + /// This API may branch on zero-check and is intended for public/non-secret + /// values. For secret-bearing paths, use [`Invertible::inv_or_zero`]. fn inv(self) -> Option; +} - /// Generate random field element - fn random(rng: &mut R) -> Self; +/// Constant-time inversion helper for secret-bearing code paths. +/// +/// Implementations return `0` when the input is `0`, and `x^{-1}` otherwise, +/// without branching on the input value. +pub trait Invertible: FieldCore { + /// Constant-time inversion with zero-mapping behavior. + fn inv_or_zero(self) -> Self; +} - /// Convert from u64 +/// Embed small integers into a field. +/// +/// Every field contains a copy of its prime subfield, and small integers embed +/// into it canonically via reduction modulo the characteristic. This trait is +/// implementable for ALL fields — base and extension alike. +/// +/// Only `from_u64` and `from_i64` need concrete implementations; the narrower +/// widths have default impls via lossless widening. +pub trait FromSmallInt: FieldCore { + /// Embed a `u8` into the field. + fn from_u8(val: u8) -> Self { + Self::from_u64(val as u64) + } + + /// Embed an `i8` into the field. + fn from_i8(val: i8) -> Self { + Self::from_i64(val as i64) + } + + /// Embed a `u16` into the field. + fn from_u16(val: u16) -> Self { + Self::from_u64(val as u64) + } + + /// Embed an `i16` into the field. + fn from_i16(val: i16) -> Self { + Self::from_i64(val as i64) + } + + /// Embed a `u32` into the field. + fn from_u32(val: u32) -> Self { + Self::from_u64(val as u64) + } + + /// Embed an `i32` into the field. + fn from_i32(val: i32) -> Self { + Self::from_i64(val as i64) + } + + /// Embed a `u64` into the field (reduce mod characteristic). fn from_u64(val: u64) -> Self; - /// Convert from i64 + /// Embed an `i64` into the field (reduce mod characteristic). fn from_i64(val: i64) -> Self; } +/// Canonical integer representation for prime (base) field elements. +/// +/// Provides a bijection between field elements and integers in `[0, p)`. +/// Only meaningful for base prime fields where elements ARE residues mod p. +/// Extension fields should NOT implement this trait. +pub trait CanonicalField: FromSmallInt { + /// Return canonical integer representation as `u128`. + fn to_canonical_u128(self) -> u128; + + /// Construct from canonical value if it is in range. + fn from_canonical_u128_checked(val: u128) -> Option; + + /// Construct from canonical value reduced modulo the field modulus. + fn from_canonical_u128_reduced(val: u128) -> Self; +} + +/// Optional sampling support for field elements. +/// +/// This is intentionally separate from core field algebra and may evolve. +pub trait FieldSampling: FieldCore { + /// Generate a sampled field element. + fn sample(rng: &mut R) -> Self; +} + +/// Metadata for pseudo-Mersenne style moduli (`2^k - c`). +pub trait PseudoMersenneField: CanonicalField { + /// Exponent `k` in `2^k - c`. + const MODULUS_BITS: u32; + + /// Offset `c` in `2^k - c`. + const MODULUS_OFFSET: u128; +} + /// Module trait for lattice-based algebraic structures /// /// This trait represents a module over a ring/field, which is fundamental @@ -73,24 +154,18 @@ pub trait Module: + for<'a> std::ops::Sub<&'a Self, Output = Self> { /// Scalar type (field/ring elements) - type Scalar: Field + type Scalar: FieldCore + + CanonicalField + + FieldSampling + std::ops::Mul + for<'a> std::ops::Mul<&'a Self, Output = Self>; /// Zero element fn zero() -> Self; - /// Addition - fn add(&self, rhs: &Self) -> Self; - - /// Negation - fn neg(&self) -> Self; - /// Scalar multiplication fn scale(&self, k: &Self::Scalar) -> Self; /// Generate random module element fn random(rng: &mut R) -> Self; } - -pub trait HachiRoutines {} diff --git a/src/primitives/mod.rs b/src/primitives/mod.rs index 9f53c2f1..21376422 100644 --- a/src/primitives/mod.rs +++ b/src/primitives/mod.rs @@ -2,6 +2,7 @@ //! This submodule defines the basic lattice arithmetic and cryptographic tools that Hachi is built upon pub mod arithmetic; +pub mod multilinear_evals; pub mod poly; pub mod serialization; pub mod transcript; diff --git a/src/primitives/multilinear_evals.rs b/src/primitives/multilinear_evals.rs new file mode 100644 index 00000000..2ba3b95a --- /dev/null +++ b/src/primitives/multilinear_evals.rs @@ -0,0 +1,155 @@ +//! Dense multilinear polynomials in evaluation form. +//! +//! This module intentionally follows the same high-level representation style as +//! Jolt's `DensePolynomial` for multilinear extensions (MLEs): store the values +//! of the multilinear polynomial on the Boolean hypercube `{0,1}^n` and provide +//! binding/evaluation by iterative folding. +//! +//! The key convention for this repo (used by the ring-switch witness table) is: +//! +//! - An evaluation index `idx` is interpreted in binary. +//! - The **lowest** index bit is the **first** variable bound under +//! [`BindingOrder::LowToHigh`]. +//! +//! This matches the row-major flattening `idx = row * d + col` when `d` is a +//! power of two: the low `log2(d)` bits correspond to the `col` coordinate. + +use crate::primitives::arithmetic::FieldCore; +use crate::primitives::poly::Polynomial; + +/// The order in which variables are bound/evaluated. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] +pub enum BindingOrder { + /// Bind the lowest index bit first (LSB → MSB). + #[default] + LowToHigh, + /// Bind the highest index bit first (MSB → LSB). + HighToLow, +} + +/// Dense multilinear polynomial in evaluation form over `{0,1}^num_vars`. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct DenseMultilinearEvals { + /// Number of variables in the multilinear extension. + pub num_vars: usize, + /// Active length (decreases as variables are bound). + pub len: usize, + /// Evaluations on the hypercube, length is a power of two. + pub evals: Vec, +} + +impl Default for DenseMultilinearEvals { + fn default() -> Self { + Self { + num_vars: 0, + len: 1, + evals: vec![F::zero()], + } + } +} + +impl DenseMultilinearEvals { + /// Construct from evaluations, padding with zeros to a power of two. + /// + /// The variable count is derived from the padded length. + pub fn new_padded(mut evals: Vec) -> Self { + if evals.is_empty() { + evals.push(F::zero()); + } + while !evals.len().is_power_of_two() { + evals.push(F::zero()); + } + let num_vars = evals.len().trailing_zeros() as usize; + let len = evals.len(); + Self { + num_vars, + len, + evals, + } + } + + /// Return the original (backing) evaluation length. + pub fn original_len(&self) -> usize { + self.evals.len() + } + + /// Bind one variable in-place, reducing `len` by a factor of 2. + /// + /// After binding, the polynomial has one fewer variable. + /// + /// # Panics + /// + /// Panics if the current backing length is not a power of two, or if attempting + /// to bind a constant (length-1) polynomial. + pub fn bind_in_place(&mut self, r: F, order: BindingOrder) { + assert!(self.len.is_power_of_two()); + assert!(self.len >= 2, "cannot bind variable of constant polynomial"); + match order { + BindingOrder::LowToHigh => self.bind_lsb_in_place(r), + BindingOrder::HighToLow => self.bind_msb_in_place(r), + } + } + + #[inline] + fn bind_lsb_in_place(&mut self, r: F) { + let next_len = self.len / 2; + for i in 0..next_len { + let v0 = self.evals[i << 1]; + let v1 = self.evals[(i << 1) | 1]; + // (1-r)*v0 + r*v1 = v0 + r*(v1-v0) + self.evals[i] = v0 + r * (v1 - v0); + } + self.len = next_len; + self.num_vars = self.num_vars.saturating_sub(1); + } + + #[inline] + fn bind_msb_in_place(&mut self, r: F) { + let next_len = self.len / 2; + let (left, right) = self.evals.split_at_mut(next_len); + for i in 0..next_len { + let v0 = left[i]; + let v1 = right[i]; + left[i] = v0 + r * (v1 - v0); + } + self.len = next_len; + self.num_vars = self.num_vars.saturating_sub(1); + } + + /// Evaluate without mutating `self`. + /// + /// # Panics + /// + /// Panics if `point.len() != self.num_vars`. + pub fn evaluate_with_order(&self, point: &[F], order: BindingOrder) -> F { + if point.is_empty() { + return self.evals[0]; + } + assert_eq!( + point.len(), + self.num_vars, + "point dimension mismatch: expected {}, got {}", + self.num_vars, + point.len() + ); + let mut tmp = self.clone(); + for r in point.iter().copied() { + tmp.bind_in_place(r, order); + } + tmp.evals[0] + } +} + +impl Polynomial for DenseMultilinearEvals { + fn num_vars(&self) -> usize { + self.num_vars + } + + fn evaluate(&self, point: &[F]) -> F { + self.evaluate_with_order(point, BindingOrder::LowToHigh) + } + + fn coeffs(&self) -> Vec { + self.evals[..self.len].to_vec() + } +} diff --git a/src/primitives/poly.rs b/src/primitives/poly.rs index 1e9a8b63..16f034cd 100644 --- a/src/primitives/poly.rs +++ b/src/primitives/poly.rs @@ -1,9 +1,9 @@ -//! Polynomial trait for multilinear polynomials +//! Polynomial trait for multilinear polynomials. -use super::arithmetic::Field; +use super::arithmetic::FieldCore; /// Trait for multilinear Lagrange polynomial operations -pub trait MultilinearLagrange: Polynomial { +pub trait MultilinearLagrange: Polynomial { /// Compute multilinear Lagrange basis evaluations at a point /// /// For variables (r₀, r₁, ..., r_{n-1}), computes all 2^n basis polynomial evaluations. @@ -30,7 +30,7 @@ pub trait MultilinearLagrange: Polynomial { /// Trait for multilinear polynomials /// /// Represents a polynomial in evaluation form (coefficients at hypercube points). -pub trait Polynomial { +pub trait Polynomial { /// Number of variables fn num_vars(&self) -> usize; @@ -52,6 +52,9 @@ pub trait Polynomial { /// # Returns /// Polynomial evaluation result fn evaluate(&self, point: &[F]) -> F; + + /// Return the coefficient/evaluation table on `{0,1}^n` in LSB-first order. + fn coeffs(&self) -> Vec; } /// Compute multilinear Lagrange basis evaluations at a point @@ -62,7 +65,7 @@ pub trait Polynomial { /// Uses an iterative doubling approach: /// - Start with [1-r₀, r₀] /// - For each variable rᵢ, split each value v into [v*(1-rᵢ), v*rᵢ] -pub(crate) fn multilinear_lagrange_basis(output: &mut [F], point: &[F]) { +pub(crate) fn multilinear_lagrange_basis(output: &mut [F], point: &[F]) { assert!( output.len() <= (1 << point.len()), "Output length must be at most 2^point.len()" @@ -115,7 +118,7 @@ pub(crate) fn multilinear_lagrange_basis(output: &mut [F], point: &[F] /// polynomial_evaluation(point) = L^T × M × R /// /// Splits variables between rows and columns based on sigma and nu. -pub fn compute_left_right_vectors( +pub fn compute_left_right_vectors( point: &[F], nu: usize, sigma: usize, diff --git a/src/primitives/serialization.rs b/src/primitives/serialization.rs index f4a0bf1e..ea14bdc3 100644 --- a/src/primitives/serialization.rs +++ b/src/primitives/serialization.rs @@ -182,11 +182,43 @@ mod primitive_impls { impl_primitive_serialization!(u16, 2); impl_primitive_serialization!(u32, 4); impl_primitive_serialization!(u64, 8); - impl_primitive_serialization!(usize, std::mem::size_of::()); + impl_primitive_serialization!(u128, 16); impl_primitive_serialization!(i8, 1); impl_primitive_serialization!(i16, 2); impl_primitive_serialization!(i32, 4); impl_primitive_serialization!(i64, 8); + impl_primitive_serialization!(i128, 16); + + impl Valid for usize { + fn check(&self) -> Result<(), SerializationError> { + Ok(()) + } + } + + impl HachiSerialize for usize { + fn serialize_with_mode( + &self, + writer: W, + compress: Compress, + ) -> Result<(), SerializationError> { + (*self as u64).serialize_with_mode(writer, compress) + } + + fn serialized_size(&self, _compress: Compress) -> usize { + 8 + } + } + + impl HachiDeserialize for usize { + fn deserialize_with_mode( + reader: R, + compress: Compress, + validate: Validate, + ) -> Result { + let val = u64::deserialize_with_mode(reader, compress, validate)?; + Ok(val as usize) + } + } impl Valid for bool { fn check(&self) -> Result<(), SerializationError> { diff --git a/src/primitives/transcript.rs b/src/primitives/transcript.rs index 89ebc29a..ce2e8f53 100644 --- a/src/primitives/transcript.rs +++ b/src/primitives/transcript.rs @@ -2,13 +2,13 @@ #![allow(missing_docs)] -use crate::primitives::arithmetic::{Field, Module}; +use crate::primitives::arithmetic::{CanonicalField, FieldCore, Module}; use crate::primitives::HachiSerialize; /// Transcript for Fiat-Shamir transformations pub trait Transcript { /// Field type for challenges - type Field: Field; + type Field: FieldCore + CanonicalField; /// Append raw bytes to the transcript fn append_bytes(&mut self, label: &[u8], bytes: &[u8]); diff --git a/src/protocol/challenges/mod.rs b/src/protocol/challenges/mod.rs new file mode 100644 index 00000000..524941ed --- /dev/null +++ b/src/protocol/challenges/mod.rs @@ -0,0 +1,6 @@ +//! Protocol-level Fiat–Shamir challenge samplers. +//! +//! These utilities derive structured challenges (e.g. sparse ring elements) from +//! the transcript while keeping the low-level representations in the algebra layer. + +pub mod sparse; diff --git a/src/protocol/challenges/sparse.rs b/src/protocol/challenges/sparse.rs new file mode 100644 index 00000000..aca54aaf --- /dev/null +++ b/src/protocol/challenges/sparse.rs @@ -0,0 +1,121 @@ +//! Sparse challenge sampling via Fiat–Shamir. + +use crate::algebra::ring::{CyclotomicRing, SparseChallenge, SparseChallengeConfig}; +use crate::error::HachiError; +use crate::protocol::transcript::labels::{ABSORB_SPARSE_CHALLENGE, CHALLENGE_SPARSE_CHALLENGE}; +use crate::protocol::transcript::Transcript; +use crate::{CanonicalField, FieldCore}; + +/// Sample a sparse ring challenge (exact weight ω) from a transcript. +/// +/// This is intentionally deterministic and label-aware: +/// - first we absorb the sampling context under `ABSORB_SPARSE_CHALLENGE`, +/// - then we derive as many `CHALLENGE_SPARSE_CHALLENGE` scalars as needed. +/// +/// Notes: +/// - Indices are sampled with a simple `mod D` reduction. For the intended +/// regimes (small `D`, cryptographic transcript), any bias is negligible. +/// - Duplicate indices are rejected to enforce exact Hamming weight. +/// +/// # Errors +/// +/// Returns an error if the provided config is invalid for degree `D`. +pub fn sparse_challenge_from_transcript( + transcript: &mut T, + label: &[u8], + instance_idx: u64, + cfg: &SparseChallengeConfig, +) -> Result +where + F: FieldCore + CanonicalField, + T: Transcript, +{ + cfg.validate::() + .map_err(|e| HachiError::InvalidInput(format!("invalid sparse challenge config: {e}")))?; + + // Absorb domain-separating context so different call sites can't collide. + transcript.append_bytes(ABSORB_SPARSE_CHALLENGE, label); + transcript.append_bytes(ABSORB_SPARSE_CHALLENGE, &instance_idx.to_le_bytes()); + transcript.append_bytes(ABSORB_SPARSE_CHALLENGE, &(D as u64).to_le_bytes()); + transcript.append_bytes(ABSORB_SPARSE_CHALLENGE, &(cfg.weight as u64).to_le_bytes()); + // Include the coefficient alphabet (as little-endian i16 stream). + let mut coeff_bytes = Vec::with_capacity(cfg.nonzero_coeffs.len() * 2); + for &c in cfg.nonzero_coeffs.iter() { + coeff_bytes.extend_from_slice(&c.to_le_bytes()); + } + transcript.append_bytes(ABSORB_SPARSE_CHALLENGE, &coeff_bytes); + + let mut seen = vec![false; D]; + let mut positions = Vec::with_capacity(cfg.weight); + let mut coeffs = Vec::with_capacity(cfg.weight); + + while positions.len() < cfg.weight { + let r = transcript + .challenge_scalar(CHALLENGE_SPARSE_CHALLENGE) + .to_canonical_u128(); + let lo = r as u64; + let hi = (r >> 64) as u64; + + let pos = (lo % (D as u64)) as usize; + if seen[pos] { + continue; + } + seen[pos] = true; + positions.push(pos as u32); + + let coeff_idx = (hi % (cfg.nonzero_coeffs.len() as u64)) as usize; + let c = cfg.nonzero_coeffs[coeff_idx]; + debug_assert_ne!(c, 0); + coeffs.push(c); + } + + Ok(SparseChallenge { positions, coeffs }) +} + +/// Sample `n` sparse challenges from a transcript, returning the sparse +/// representation directly. +/// +/// # Errors +/// +/// Returns an error if challenge sampling fails. +pub fn sample_sparse_challenges( + transcript: &mut T, + label: &[u8], + n: usize, + cfg: &SparseChallengeConfig, +) -> Result, HachiError> +where + F: FieldCore + CanonicalField, + T: Transcript, +{ + (0..n) + .map(|i| sparse_challenge_from_transcript::(transcript, label, i as u64, cfg)) + .collect() +} + +/// Sample `n` sparse challenges from a transcript and convert them to dense +/// `CyclotomicRing` elements. +/// +/// # Errors +/// +/// Returns an error if challenge sampling or dense conversion fails. +pub fn sample_dense_challenges( + transcript: &mut T, + label: &[u8], + n: usize, + cfg: &SparseChallengeConfig, +) -> Result>, HachiError> +where + F: FieldCore + CanonicalField, + T: Transcript, +{ + (0..n) + .map(|i| { + let sparse = + sparse_challenge_from_transcript::(transcript, label, i as u64, cfg)?; + sparse + .to_dense::() + .map_err(|e| HachiError::InvalidInput(e.to_string())) + }) + .collect() +} diff --git a/src/protocol/commitment/commit.rs b/src/protocol/commitment/commit.rs new file mode 100644 index 00000000..fbbf80d0 --- /dev/null +++ b/src/protocol/commitment/commit.rs @@ -0,0 +1,693 @@ +//! Ring-native §4.1 commitment core implementation. + +use super::config::{ + ensure_block_layout, ensure_matrix_shape, ensure_supported_num_vars, + validate_and_derive_layout, HachiCommitmentLayout, +}; +use super::onehot::{inner_ajtai_onehot_t_only, map_onehot_to_sparse_blocks, SparseBlockEntry}; +use super::scheme::{CommitWitness, RingCommitmentScheme}; +use super::types::RingCommitment; +use super::utils::crt_ntt::{build_ntt_cache, NttMatrixCache}; +use super::utils::linear::{decompose_block, decompose_rows, mat_vec_mul_ntt_cached, MatrixSlot}; +use super::utils::matrix::{derive_public_matrix, sample_public_matrix_seed, PublicMatrixSeed}; +use super::CommitmentConfig; +use crate::algebra::CyclotomicRing; +use crate::error::HachiError; +#[cfg(feature = "parallel")] +use crate::parallel::*; +use crate::primitives::serialization::{ + Compress, HachiDeserialize, HachiSerialize, SerializationError, Valid, Validate, +}; +use crate::{cfg_into_iter, cfg_iter, CanonicalField, FieldCore, FieldSampling}; +use std::io::{Read, Write}; + +/// Seed-only stage for deterministic setup expansion. +#[allow(non_snake_case)] +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct HachiSetupSeed { + /// Maximum supported variable count. + pub max_num_vars: usize, + /// Runtime commitment layout. + pub layout: HachiCommitmentLayout, + /// Public seed used to derive commitment matrices. + pub public_matrix_seed: PublicMatrixSeed, +} + +/// Expanded setup stage containing coefficient-form matrices. +#[allow(non_snake_case)] +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct HachiExpandedSetup { + /// Setup seed and runtime layout metadata. + pub seed: HachiSetupSeed, + /// Inner matrix `A`. + pub A: Vec>>, + /// Outer matrix `B`. + pub B: Vec>>, + /// Prover matrix `D ∈ R_q^{n_D × δ·2^R}` (§4.2). + pub D: Vec>>, +} + +/// Optional prepared setup stage for accelerated matrix-vector products. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct HachiPreparedSetup { + /// Pre-converted CRT+NTT matrices for dense mat-vec paths. + pub(crate) ntt_cache: NttMatrixCache, +} + +/// Prover setup artifact (expanded setup + optional runtime cache). +#[allow(non_snake_case)] +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct HachiProverSetup { + /// Expanded matrix stage used by both prover and verifier. + pub expanded: HachiExpandedSetup, + /// Optional runtime-prepared acceleration cache. + pub prepared: Option>, +} + +/// Verifier setup artifact derived from prover setup. +#[allow(non_snake_case)] +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct HachiVerifierSetup { + /// Expanded matrix stage used for verification. + pub expanded: HachiExpandedSetup, +} + +impl HachiProverSetup { + /// Runtime layout carried by this setup. + pub fn layout(&self) -> HachiCommitmentLayout { + self.expanded.seed.layout + } + + pub(crate) fn ntt_cache(&self) -> Result<&NttMatrixCache, HachiError> { + self.prepared + .as_ref() + .map(|p| &p.ntt_cache) + .ok_or_else(|| HachiError::InvalidSetup("missing prepared NTT cache".to_string())) + } +} + +impl Valid for HachiSetupSeed { + fn check(&self) -> Result<(), SerializationError> { + self.layout.check() + } +} + +impl HachiSerialize for HachiSetupSeed { + fn serialize_with_mode( + &self, + mut writer: W, + compress: Compress, + ) -> Result<(), SerializationError> { + self.max_num_vars + .serialize_with_mode(&mut writer, compress)?; + self.layout.serialize_with_mode(&mut writer, compress)?; + writer.write_all(&self.public_matrix_seed)?; + Ok(()) + } + + fn serialized_size(&self, compress: Compress) -> usize { + self.max_num_vars.serialized_size(compress) + self.layout.serialized_size(compress) + 32 + } +} + +impl HachiDeserialize for HachiSetupSeed { + fn deserialize_with_mode( + mut reader: R, + compress: Compress, + validate: Validate, + ) -> Result { + let max_num_vars = usize::deserialize_with_mode(&mut reader, compress, validate)?; + let layout = HachiCommitmentLayout::deserialize_with_mode(&mut reader, compress, validate)?; + let mut public_matrix_seed = [0u8; 32]; + reader.read_exact(&mut public_matrix_seed)?; + let out = Self { + max_num_vars, + layout, + public_matrix_seed, + }; + if matches!(validate, Validate::Yes) { + out.check()?; + } + Ok(out) + } +} + +impl Valid for HachiExpandedSetup { + fn check(&self) -> Result<(), SerializationError> { + self.seed.check()?; + self.A.check()?; + self.B.check()?; + self.D.check()?; + Ok(()) + } +} + +impl HachiSerialize for HachiExpandedSetup { + fn serialize_with_mode( + &self, + mut writer: W, + compress: Compress, + ) -> Result<(), SerializationError> { + self.seed.serialize_with_mode(&mut writer, compress)?; + self.A.serialize_with_mode(&mut writer, compress)?; + self.B.serialize_with_mode(&mut writer, compress)?; + self.D.serialize_with_mode(&mut writer, compress)?; + Ok(()) + } + + fn serialized_size(&self, compress: Compress) -> usize { + self.seed.serialized_size(compress) + + self.A.serialized_size(compress) + + self.B.serialized_size(compress) + + self.D.serialized_size(compress) + } +} + +impl HachiDeserialize for HachiExpandedSetup { + fn deserialize_with_mode( + mut reader: R, + compress: Compress, + validate: Validate, + ) -> Result { + let out = Self { + seed: HachiSetupSeed::deserialize_with_mode(&mut reader, compress, validate)?, + A: Vec::>>::deserialize_with_mode( + &mut reader, + compress, + validate, + )?, + B: Vec::>>::deserialize_with_mode( + &mut reader, + compress, + validate, + )?, + D: Vec::>>::deserialize_with_mode( + &mut reader, + compress, + validate, + )?, + }; + if matches!(validate, Validate::Yes) { + out.check()?; + } + Ok(out) + } +} + +impl Valid for HachiProverSetup { + fn check(&self) -> Result<(), SerializationError> { + self.expanded.check() + } +} + +impl HachiSerialize for HachiProverSetup { + fn serialize_with_mode( + &self, + writer: W, + compress: Compress, + ) -> Result<(), SerializationError> { + // Prepared cache is runtime-only and intentionally excluded. + self.expanded.serialize_with_mode(writer, compress) + } + + fn serialized_size(&self, compress: Compress) -> usize { + self.expanded.serialized_size(compress) + } +} + +impl HachiDeserialize for HachiProverSetup { + fn deserialize_with_mode( + reader: R, + compress: Compress, + validate: Validate, + ) -> Result { + Ok(Self { + expanded: HachiExpandedSetup::deserialize_with_mode(reader, compress, validate)?, + prepared: None, + }) + } +} + +impl Valid for HachiVerifierSetup { + fn check(&self) -> Result<(), SerializationError> { + self.expanded.check() + } +} + +impl HachiSerialize for HachiVerifierSetup { + fn serialize_with_mode( + &self, + writer: W, + compress: Compress, + ) -> Result<(), SerializationError> { + self.expanded.serialize_with_mode(writer, compress) + } + + fn serialized_size(&self, compress: Compress) -> usize { + self.expanded.serialized_size(compress) + } +} + +impl HachiDeserialize for HachiVerifierSetup { + fn deserialize_with_mode( + reader: R, + compress: Compress, + validate: Validate, + ) -> Result { + Ok(Self { + expanded: HachiExpandedSetup::deserialize_with_mode(reader, compress, validate)?, + }) + } +} + +/// Concrete §4.1 commitment core. +#[derive(Clone, Copy, Default)] +pub struct HachiCommitmentCore; + +impl RingCommitmentScheme for HachiCommitmentCore +where + F: FieldCore + CanonicalField + FieldSampling, + Cfg: CommitmentConfig, +{ + type ProverSetup = HachiProverSetup; + type VerifierSetup = HachiVerifierSetup; + type Commitment = RingCommitment; + + fn setup(max_num_vars: usize) -> Result<(Self::ProverSetup, Self::VerifierSetup), HachiError> { + let layout = validate_and_derive_layout::(max_num_vars)?; + ensure_supported_num_vars(max_num_vars, layout.required_num_vars::()?)?; + let public_matrix_seed = sample_public_matrix_seed(); + let a_matrix = + derive_public_matrix::(Cfg::N_A, layout.inner_width, &public_matrix_seed, b"A"); + let b_matrix = + derive_public_matrix::(Cfg::N_B, layout.outer_width, &public_matrix_seed, b"B"); + let d_matrix = derive_public_matrix::( + Cfg::N_D, + layout.d_matrix_width, + &public_matrix_seed, + b"D", + ); + + let ntt_cache = build_ntt_cache::(&a_matrix, &b_matrix, &d_matrix)?; + let expanded = HachiExpandedSetup { + seed: HachiSetupSeed { + max_num_vars, + layout, + public_matrix_seed, + }, + A: a_matrix, + B: b_matrix, + D: d_matrix, + }; + let prover_setup = HachiProverSetup { + expanded: expanded.clone(), + prepared: Some(HachiPreparedSetup { ntt_cache }), + }; + let verifier_setup = HachiVerifierSetup { expanded }; + ensure_matrix_shape(&prover_setup.expanded.A, Cfg::N_A, layout.inner_width, "A")?; + ensure_matrix_shape(&prover_setup.expanded.B, Cfg::N_B, layout.outer_width, "B")?; + ensure_matrix_shape( + &prover_setup.expanded.D, + Cfg::N_D, + layout.d_matrix_width, + "D", + )?; + Ok((prover_setup, verifier_setup)) + } + + fn layout(setup: &Self::ProverSetup) -> Result { + Ok(setup.layout()) + } + + fn commit_ring_blocks( + f_blocks: &[Vec>], + setup: &Self::ProverSetup, + ) -> Result, HachiError> { + let layout = setup.layout(); + ensure_supported_num_vars( + setup.expanded.seed.max_num_vars, + layout.required_num_vars::()?, + )?; + ensure_block_layout(f_blocks, layout)?; + ensure_matrix_shape(&setup.expanded.A, Cfg::N_A, layout.inner_width, "A")?; + ensure_matrix_shape(&setup.expanded.B, Cfg::N_B, layout.outer_width, "B")?; + + let cache = setup.ntt_cache()?; + let t_hat_all: Vec>> = cfg_iter!(f_blocks) + .map(|block| { + let s_i = decompose_block(block, Cfg::DELTA, Cfg::LOG_BASIS); + let t_i = + mat_vec_mul_ntt_cached(cache, MatrixSlot::A, &s_i).expect("inner Ajtai failed"); + decompose_rows(&t_i, Cfg::DELTA, Cfg::LOG_BASIS) + }) + .collect(); + + let t_hat_flat: Vec> = + t_hat_all.iter().flat_map(|v| v.iter().copied()).collect(); + + let u = mat_vec_mul_ntt_cached(cache, MatrixSlot::B, &t_hat_flat)?; + Ok(CommitWitness { + commitment: RingCommitment { u }, + t_hat: t_hat_all, + }) + } + + fn commit_coeffs( + f_coeffs: &[CyclotomicRing], + setup: &Self::ProverSetup, + ) -> Result, HachiError> { + let layout = setup.layout(); + let num_blocks = layout.num_blocks; + let block_len = layout.block_len; + let max_len = num_blocks + .checked_mul(block_len) + .ok_or_else(|| HachiError::InvalidSetup("coefficient length overflow".to_string()))?; + if f_coeffs.len() > max_len { + return Err(HachiError::InvalidSize { + expected: max_len, + actual: f_coeffs.len(), + }); + } + + let zero_t_hat = + vec![CyclotomicRing::::zero(); Cfg::N_A.checked_mul(Cfg::DELTA).unwrap()]; + let cache = setup.ntt_cache()?; + let coeff_len = f_coeffs.len(); + + let t_hat_all: Vec>> = cfg_into_iter!(0..num_blocks) + .map(|i| { + let start = i * block_len; + if start >= coeff_len { + zero_t_hat.clone() + } else { + let end = (start + block_len).min(coeff_len); + let block = &f_coeffs[start..end]; + let s_i = decompose_block(block, Cfg::DELTA, Cfg::LOG_BASIS); + let t_i = mat_vec_mul_ntt_cached(cache, MatrixSlot::A, &s_i) + .expect("inner Ajtai failed"); + decompose_rows(&t_i, Cfg::DELTA, Cfg::LOG_BASIS) + } + }) + .collect(); + + let t_hat_flat: Vec> = + t_hat_all.iter().flat_map(|v| v.iter().copied()).collect(); + + let u = mat_vec_mul_ntt_cached(cache, MatrixSlot::B, &t_hat_flat)?; + Ok(CommitWitness { + commitment: RingCommitment { u }, + t_hat: t_hat_all, + }) + } + + fn commit_onehot( + onehot_k: usize, + indices: &[Option], + setup: &Self::ProverSetup, + ) -> Result, HachiError> { + let layout = setup.layout(); + ensure_supported_num_vars( + setup.expanded.seed.max_num_vars, + layout.required_num_vars::()?, + )?; + ensure_matrix_shape(&setup.expanded.A, Cfg::N_A, layout.inner_width, "A")?; + ensure_matrix_shape(&setup.expanded.B, Cfg::N_B, layout.outer_width, "B")?; + + let sparse_blocks = + map_onehot_to_sparse_blocks(onehot_k, indices, layout.r_vars, layout.m_vars, D)?; + + let zero_t_hat = + vec![CyclotomicRing::::zero(); Cfg::N_A.checked_mul(Cfg::DELTA).unwrap()]; + let cache = setup.ntt_cache()?; + let a_matrix = &setup.expanded.A; + let block_len = layout.block_len; + + let t_hat_all: Vec>> = cfg_iter!(sparse_blocks) + .map(|block_entries| { + if block_entries.is_empty() { + zero_t_hat.clone() + } else { + let t_i = + inner_ajtai_onehot_t_only(a_matrix, block_entries, block_len, Cfg::DELTA); + decompose_rows(&t_i, Cfg::DELTA, Cfg::LOG_BASIS) + } + }) + .collect(); + + let t_hat_flat: Vec> = + t_hat_all.iter().flat_map(|v| v.iter().copied()).collect(); + + let u = mat_vec_mul_ntt_cached(cache, MatrixSlot::B, &t_hat_flat)?; + Ok(CommitWitness { + commitment: RingCommitment { u }, + t_hat: t_hat_all, + }) + } +} + +impl HachiCommitmentCore { + /// Create a setup with a caller-specified layout, bypassing + /// `CommitmentConfig::commitment_layout`. + /// + /// Use this when the desired `(m_vars, r_vars)` split differs from what + /// the config's heuristic would choose (e.g. mega-polynomial commitments + /// where each sub-polynomial occupies one block). + /// + /// # Errors + /// + /// Returns `HachiError` on invalid layout or matrix generation failures. + #[allow(non_snake_case)] + pub fn setup_with_layout( + layout: HachiCommitmentLayout, + ) -> Result<(HachiProverSetup, HachiVerifierSetup), HachiError> + where + F: FieldCore + CanonicalField + FieldSampling, + Cfg: CommitmentConfig, + { + let max_num_vars = layout.required_num_vars::()?; + let public_matrix_seed = sample_public_matrix_seed(); + Self::setup_with_layout_and_seed::(layout, max_num_vars, public_matrix_seed) + } + + /// Like `setup_with_layout` but reuses an existing setup's random seed and + /// A matrix (which depends only on `m_vars`). Only regenerates B and D + /// matrices for the new `r_vars`. + /// + /// Use this when creating a mega-polynomial setup that shares `m_vars` with + /// an individual polynomial setup — avoids re-deriving and NTT-transforming + /// the A matrix. + /// + /// # Errors + /// + /// Returns `HachiError` if the new layout is incompatible with the existing + /// setup or matrix shapes are inconsistent. + #[allow(non_snake_case)] + pub fn setup_from_existing( + existing: &HachiExpandedSetup, + new_r_vars: usize, + ) -> Result<(HachiProverSetup, HachiVerifierSetup), HachiError> + where + F: FieldCore + CanonicalField + FieldSampling, + Cfg: CommitmentConfig, + { + let old_layout = existing.seed.layout; + let new_layout = HachiCommitmentLayout::new::(old_layout.m_vars, new_r_vars)?; + + if new_layout.inner_width != old_layout.inner_width { + return Err(HachiError::InvalidSetup( + "setup_from_existing requires matching m_vars/inner_width".to_string(), + )); + } + + let max_num_vars = new_layout.required_num_vars::()?; + let seed = existing.seed.public_matrix_seed; + + let b_matrix = derive_public_matrix::(Cfg::N_B, new_layout.outer_width, &seed, b"B"); + let d_matrix = + derive_public_matrix::(Cfg::N_D, new_layout.d_matrix_width, &seed, b"D"); + + let ntt_cache = build_ntt_cache::(&existing.A, &b_matrix, &d_matrix)?; + let expanded = HachiExpandedSetup { + seed: HachiSetupSeed { + max_num_vars, + layout: new_layout, + public_matrix_seed: seed, + }, + A: existing.A.clone(), + B: b_matrix, + D: d_matrix, + }; + let prover_setup = HachiProverSetup { + expanded: expanded.clone(), + prepared: Some(HachiPreparedSetup { ntt_cache }), + }; + let verifier_setup = HachiVerifierSetup { expanded }; + Ok((prover_setup, verifier_setup)) + } + + #[allow(non_snake_case)] + fn setup_with_layout_and_seed( + layout: HachiCommitmentLayout, + max_num_vars: usize, + public_matrix_seed: PublicMatrixSeed, + ) -> Result<(HachiProverSetup, HachiVerifierSetup), HachiError> + where + F: FieldCore + CanonicalField + FieldSampling, + Cfg: CommitmentConfig, + { + let a_matrix = + derive_public_matrix::(Cfg::N_A, layout.inner_width, &public_matrix_seed, b"A"); + let b_matrix = + derive_public_matrix::(Cfg::N_B, layout.outer_width, &public_matrix_seed, b"B"); + let d_matrix = derive_public_matrix::( + Cfg::N_D, + layout.d_matrix_width, + &public_matrix_seed, + b"D", + ); + + let ntt_cache = build_ntt_cache::(&a_matrix, &b_matrix, &d_matrix)?; + let expanded = HachiExpandedSetup { + seed: HachiSetupSeed { + max_num_vars, + layout, + public_matrix_seed, + }, + A: a_matrix, + B: b_matrix, + D: d_matrix, + }; + let prover_setup = HachiProverSetup { + expanded: expanded.clone(), + prepared: Some(HachiPreparedSetup { ntt_cache }), + }; + let verifier_setup = HachiVerifierSetup { expanded }; + ensure_matrix_shape(&prover_setup.expanded.A, Cfg::N_A, layout.inner_width, "A")?; + ensure_matrix_shape(&prover_setup.expanded.B, Cfg::N_B, layout.outer_width, "B")?; + ensure_matrix_shape( + &prover_setup.expanded.D, + Cfg::N_D, + layout.d_matrix_width, + "D", + )?; + Ok((prover_setup, verifier_setup)) + } +} + +/// Describes one block of a mega-polynomial commitment. +/// +/// A mega-polynomial packs multiple heterogeneous polynomials into a single +/// Hachi commitment by assigning each polynomial to its own block. Blocks +/// can be dense (arbitrary ring coefficients), sparse one-hot, or zero. +pub enum MegaPolyBlock<'a, F: FieldCore, const D: usize> { + /// Dense block: full ring coefficients (length ≤ block_len). + Dense(&'a [CyclotomicRing]), + /// One-hot block: sparse entries within this block. + OneHot(&'a [SparseBlockEntry]), + /// Empty block: all coefficients are zero (no allocation or computation). + Zero, +} + +impl HachiCommitmentCore { + /// Commit a mega-polynomial composed of heterogeneous blocks. + /// + /// Each block occupies `block_len` ring elements. Dense blocks are + /// decomposed via `balanced_decompose_pow2`; one-hot blocks use sparse + /// inner Ajtai; zero blocks are free. + /// + /// The number of blocks must equal `layout.num_blocks` (power of 2). + /// + /// # Errors + /// + /// Returns `HachiError` if the number of blocks doesn't match the layout + /// or if matrix shapes are inconsistent. + /// + /// # Panics + /// + /// Panics if `Cfg::N_A * Cfg::DELTA` overflows. + #[allow(non_snake_case)] + pub fn commit_mixed( + blocks: &[MegaPolyBlock<'_, F, D>], + setup: &HachiProverSetup, + ) -> Result, F, D>, HachiError> + where + F: FieldCore + CanonicalField + FieldSampling, + Cfg: CommitmentConfig, + { + let layout = setup.layout(); + if blocks.len() != layout.num_blocks { + return Err(HachiError::InvalidSize { + expected: layout.num_blocks, + actual: blocks.len(), + }); + } + ensure_matrix_shape(&setup.expanded.A, Cfg::N_A, layout.inner_width, "A")?; + ensure_matrix_shape(&setup.expanded.B, Cfg::N_B, layout.outer_width, "B")?; + + let zero_t_hat = + vec![CyclotomicRing::::zero(); Cfg::N_A.checked_mul(Cfg::DELTA).unwrap()]; + let cache = setup.ntt_cache()?; + let a_matrix = &setup.expanded.A; + let block_len = layout.block_len; + + let t_hat_all: Vec>> = cfg_iter!(blocks) + .map(|block| match block { + MegaPolyBlock::Zero => zero_t_hat.clone(), + MegaPolyBlock::Dense(coeffs) => { + let s_i = decompose_block(coeffs, Cfg::DELTA, Cfg::LOG_BASIS); + let t_i = mat_vec_mul_ntt_cached(cache, MatrixSlot::A, &s_i) + .expect("inner Ajtai failed"); + decompose_rows(&t_i, Cfg::DELTA, Cfg::LOG_BASIS) + } + MegaPolyBlock::OneHot(sparse_entries) => { + if sparse_entries.is_empty() { + zero_t_hat.clone() + } else { + let t_i = inner_ajtai_onehot_t_only( + a_matrix, + sparse_entries, + block_len, + Cfg::DELTA, + ); + decompose_rows(&t_i, Cfg::DELTA, Cfg::LOG_BASIS) + } + } + }) + .collect(); + + let t_hat_flat: Vec> = + t_hat_all.iter().flat_map(|v| v.iter().copied()).collect(); + + let u = mat_vec_mul_ntt_cached(cache, MatrixSlot::B, &t_hat_flat)?; + Ok(CommitWitness { + commitment: RingCommitment { u }, + t_hat: t_hat_all, + }) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::primitives::{HachiDeserialize, HachiSerialize}; + use crate::test_utils::{TinyConfig, D as TestD, F as TestF}; + + #[test] + fn prover_setup_roundtrips_and_derives_same_verifier() { + let (prover_setup, verifier_setup) = + >::setup(16) + .unwrap(); + + let mut bytes = Vec::new(); + prover_setup.serialize_compressed(&mut bytes).unwrap(); + let decoded = HachiProverSetup::::deserialize_compressed(&bytes[..]).unwrap(); + + assert_eq!(decoded.expanded, prover_setup.expanded); + assert_eq!(decoded.prepared, None); + + let derived_verifier = HachiVerifierSetup { + expanded: decoded.expanded.clone(), + }; + assert_eq!(derived_verifier, verifier_setup); + } +} diff --git a/src/protocol/commitment/config.rs b/src/protocol/commitment/config.rs new file mode 100644 index 00000000..3eeba2eb --- /dev/null +++ b/src/protocol/commitment/config.rs @@ -0,0 +1,412 @@ +//! Configuration presets for ring-native commitment construction. + +use super::utils::math::checked_pow2; +use crate::algebra::ring::CyclotomicRing; +use crate::error::HachiError; +use crate::primitives::serialization::{ + Compress, HachiDeserialize, HachiSerialize, SerializationError, Valid, Validate, +}; +use crate::FieldCore; +use std::io::{Read, Write}; + +/// Runtime commitment layout authority for ring-native commitments. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct HachiCommitmentLayout { + /// Number of variables inside each committed block (`2^m_vars` entries). + pub m_vars: usize, + /// Number of block-select variables (`2^r_vars` blocks). + pub r_vars: usize, + /// Number of committed blocks (`2^r_vars`). + pub num_blocks: usize, + /// Number of ring elements per block (`2^m_vars`). + pub block_len: usize, + /// Width of inner matrix `A`. + pub inner_width: usize, + /// Width of outer matrix `B`. + pub outer_width: usize, + /// Width of prover matrix `D` (`delta * 2^r_vars`). + pub d_matrix_width: usize, +} + +impl HachiCommitmentLayout { + /// Build a layout from `(m_vars, r_vars)` and static config constants. + /// + /// # Errors + /// + /// Returns an error when powers or derived widths overflow. + pub fn new(m_vars: usize, r_vars: usize) -> Result { + let num_blocks = checked_pow2(r_vars)?; + let block_len = checked_pow2(m_vars)?; + let inner_width = block_len + .checked_mul(Cfg::DELTA) + .ok_or_else(|| HachiError::InvalidSetup("inner width overflow".to_string()))?; + let outer_width = Cfg::N_A + .checked_mul(Cfg::DELTA) + .and_then(|x| x.checked_mul(num_blocks)) + .ok_or_else(|| HachiError::InvalidSetup("outer width overflow".to_string()))?; + let d_matrix_width = Cfg::DELTA + .checked_mul(num_blocks) + .ok_or_else(|| HachiError::InvalidSetup("D-matrix width overflow".to_string()))?; + Ok(Self { + m_vars, + r_vars, + num_blocks, + block_len, + inner_width, + outer_width, + d_matrix_width, + }) + } + + /// Total number of outer variables consumed by ring coefficients. + /// + /// # Errors + /// + /// Returns an error if the variable count overflows. + pub fn outer_vars(&self) -> Result { + self.m_vars + .checked_add(self.r_vars) + .ok_or_else(|| HachiError::InvalidSetup("variable count overflow".to_string())) + } + + /// Required polynomial variable count for this layout (`outer + alpha`). + /// + /// # Errors + /// + /// Returns an error if the variable count overflows. + pub fn required_num_vars(&self) -> Result { + let alpha = D.trailing_zeros() as usize; + self.outer_vars()? + .checked_add(alpha) + .ok_or_else(|| HachiError::InvalidSetup("variable count overflow".to_string())) + } +} + +impl Valid for HachiCommitmentLayout { + fn check(&self) -> Result<(), SerializationError> { + if self.num_blocks == 0 || self.block_len == 0 { + return Err(SerializationError::InvalidData( + "invalid zero block layout".to_string(), + )); + } + Ok(()) + } +} + +impl HachiSerialize for HachiCommitmentLayout { + fn serialize_with_mode( + &self, + mut writer: W, + compress: Compress, + ) -> Result<(), SerializationError> { + self.m_vars.serialize_with_mode(&mut writer, compress)?; + self.r_vars.serialize_with_mode(&mut writer, compress)?; + self.num_blocks.serialize_with_mode(&mut writer, compress)?; + self.block_len.serialize_with_mode(&mut writer, compress)?; + self.inner_width + .serialize_with_mode(&mut writer, compress)?; + self.outer_width + .serialize_with_mode(&mut writer, compress)?; + self.d_matrix_width + .serialize_with_mode(&mut writer, compress)?; + Ok(()) + } + + fn serialized_size(&self, compress: Compress) -> usize { + self.m_vars.serialized_size(compress) + + self.r_vars.serialized_size(compress) + + self.num_blocks.serialized_size(compress) + + self.block_len.serialized_size(compress) + + self.inner_width.serialized_size(compress) + + self.outer_width.serialized_size(compress) + + self.d_matrix_width.serialized_size(compress) + } +} + +impl HachiDeserialize for HachiCommitmentLayout { + fn deserialize_with_mode( + mut reader: R, + compress: Compress, + validate: Validate, + ) -> Result { + let out = Self { + m_vars: usize::deserialize_with_mode(&mut reader, compress, validate)?, + r_vars: usize::deserialize_with_mode(&mut reader, compress, validate)?, + num_blocks: usize::deserialize_with_mode(&mut reader, compress, validate)?, + block_len: usize::deserialize_with_mode(&mut reader, compress, validate)?, + inner_width: usize::deserialize_with_mode(&mut reader, compress, validate)?, + outer_width: usize::deserialize_with_mode(&mut reader, compress, validate)?, + d_matrix_width: usize::deserialize_with_mode(&mut reader, compress, validate)?, + }; + if matches!(validate, Validate::Yes) { + out.check()?; + } + Ok(out) + } +} + +/// Parameter bundle for the ring-native commitment core (§4.1–§4.2). +pub trait CommitmentConfig: Clone + Send + Sync + 'static { + /// Ring degree used by `CyclotomicRing`. + const D: usize; + /// Inner Ajtai matrix row count. + const N_A: usize; + /// Outer commitment matrix row count. + const N_B: usize; + /// Prover commitment matrix `D` row count (§4.2). + const N_D: usize; + /// Base-2 logarithm of gadget decomposition base. + const LOG_BASIS: u32; + /// Decomposition levels `delta`. + const DELTA: usize; + /// Decomposition levels for the folded witness `z` (`τ` in the paper). + const TAU: usize; + /// Hamming weight of sparse challenges (`ω` in the paper). + const CHALLENGE_WEIGHT: usize; + + /// Choose the runtime commitment layout for `max_num_vars`. + /// + /// # Errors + /// + /// Returns an error if `max_num_vars` does not admit a valid layout. + fn commitment_layout(max_num_vars: usize) -> Result; + + /// Runtime L∞ bound for `z` (`β`) used by stage-1 folding checks. + /// + /// # Errors + /// + /// Returns an error on invalid parameters or arithmetic overflow. + fn beta_bound(layout: HachiCommitmentLayout) -> Result { + beta_linf_fold_bound(layout.r_vars, Self::CHALLENGE_WEIGHT, Self::LOG_BASIS) + } +} + +/// Deterministic upper bound for the stage-1 folded-witness infinity norm. +/// +/// This encodes the bound used in `QuadraticEquation::compute_z_hat`: +/// `||z||_inf <= 2^R * ω * (b/2)` where `b = 2^LOG_BASIS`. +/// +/// # Errors +/// +/// Returns an error when parameters are out of range or intermediate products +/// overflow `u128`. +pub(crate) fn beta_linf_fold_bound( + r: usize, + challenge_weight: usize, + log_basis: u32, +) -> Result { + if !(1..128).contains(&log_basis) { + return Err(HachiError::InvalidSetup("invalid LOG_BASIS".to_string())); + } + if r >= 128 { + return Err(HachiError::InvalidSetup("r_vars must be < 128".to_string())); + } + + let blocks = 1u128 << r; + let b = 1u128 << log_basis; + let half_b = b / 2; + + let term = blocks + .checked_mul(challenge_weight as u128) + .ok_or_else(|| HachiError::InvalidSetup("beta bound overflow".to_string()))?; + term.checked_mul(half_b) + .ok_or_else(|| HachiError::InvalidSetup("beta bound overflow".to_string())) +} + +/// Validate static config invariants and derive runtime dimensions. +/// +/// # Errors +/// +/// Returns an error when config constants are inconsistent or overflow. +pub(super) fn validate_and_derive_layout( + max_num_vars: usize, +) -> Result { + if D != Cfg::D { + return Err(HachiError::InvalidSetup(format!( + "const D={D} mismatches config D={}", + Cfg::D + ))); + } + if Cfg::LOG_BASIS == 0 || Cfg::LOG_BASIS >= 128 { + return Err(HachiError::InvalidSetup("invalid LOG_BASIS".to_string())); + } + if (Cfg::DELTA as u32).saturating_mul(Cfg::LOG_BASIS) > 128 { + return Err(HachiError::InvalidSetup( + "DELTA * LOG_BASIS must be <= 128".to_string(), + )); + } + Cfg::commitment_layout(max_num_vars) +} + +/// Ensure `max_num_vars` is sufficient for config dimensions. +/// +/// # Errors +/// +/// Returns an error when `max_num_vars < required_vars`. +pub(super) fn ensure_supported_num_vars( + max_num_vars: usize, + required_vars: usize, +) -> Result<(), HachiError> { + if max_num_vars < required_vars { + return Err(HachiError::InvalidSetup(format!( + "max_num_vars {max_num_vars} is smaller than required {required_vars}" + ))); + } + Ok(()) +} + +/// Ensure input blocks match the expected config-derived layout. +/// +/// # Errors +/// +/// Returns an error when block count or per-block size mismatch. +pub(super) fn ensure_block_layout( + f_blocks: &[Vec>], + layout: HachiCommitmentLayout, +) -> Result<(), HachiError> { + if f_blocks.len() != layout.num_blocks { + return Err(HachiError::InvalidSize { + expected: layout.num_blocks, + actual: f_blocks.len(), + }); + } + for block in f_blocks { + if block.len() != layout.block_len { + return Err(HachiError::InvalidSize { + expected: layout.block_len, + actual: block.len(), + }); + } + } + Ok(()) +} + +/// Ensure matrix shape matches expected dimensions. +/// +/// # Errors +/// +/// Returns an error if row count or row width mismatch. +pub(super) fn ensure_matrix_shape( + mat: &[Vec], + expected_rows: usize, + expected_cols: usize, + name: &str, +) -> Result<(), HachiError> { + if mat.len() != expected_rows { + return Err(HachiError::InvalidSize { + expected: expected_rows, + actual: mat.len(), + }); + } + for (row_idx, row) in mat.iter().enumerate() { + if row.len() != expected_cols { + return Err(HachiError::InvalidSetup(format!( + "{name} row {row_idx} has width {}, expected {expected_cols}", + row.len() + ))); + } + } + Ok(()) +} + +/// Small correctness-first config for tests and local benchmarks. +/// +/// Fixed layout (m_vars=4, r_vars=2) for fast test iteration. For larger +/// polynomials, use [`DynamicSmallTestCommitmentConfig`] instead. +#[derive(Clone, Copy, Debug, Default)] +pub struct SmallTestCommitmentConfig; + +impl CommitmentConfig for SmallTestCommitmentConfig { + const D: usize = 16; + const N_A: usize = 8; + const N_B: usize = 4; + const N_D: usize = 4; + const LOG_BASIS: u32 = 4; + const DELTA: usize = 9; + const TAU: usize = 4; + const CHALLENGE_WEIGHT: usize = 3; + + fn commitment_layout(_max_num_vars: usize) -> Result { + HachiCommitmentLayout::new::(4, 2) + } +} + +/// D=16 config with dynamic layout that adapts to polynomial size. +/// +/// Uses the same D=16 ring dimension as [`SmallTestCommitmentConfig`] but +/// derives `m_vars`/`r_vars` from `max_num_vars`, so it can commit +/// polynomials with an arbitrary number of variables. +#[derive(Clone, Copy, Debug, Default)] +pub struct DynamicSmallTestCommitmentConfig; + +impl CommitmentConfig for DynamicSmallTestCommitmentConfig { + const D: usize = 16; + const N_A: usize = 8; + const N_B: usize = 4; + const N_D: usize = 4; + const LOG_BASIS: u32 = 4; + const DELTA: usize = 9; + const TAU: usize = 4; + const CHALLENGE_WEIGHT: usize = 3; + + fn commitment_layout(max_num_vars: usize) -> Result { + let alpha = Self::D.trailing_zeros() as usize; + let reduced_vars = max_num_vars.checked_sub(alpha).ok_or_else(|| { + HachiError::InvalidSetup("max_num_vars is smaller than alpha".to_string()) + })?; + if reduced_vars == 0 { + return Err(HachiError::InvalidSetup( + "max_num_vars must leave at least one outer variable".to_string(), + )); + } + let m_vars = reduced_vars.min(11); + let r_vars = reduced_vars - m_vars; + HachiCommitmentLayout::new::(m_vars, r_vars) + } +} + +/// Production-oriented profile for 128-bit base fields (`Fp128

`). +/// +/// This profile targets the `D = 512`, `n_A = n_B = n_D = 1` regime with +/// base-16 decomposition over ~128-bit moduli. +/// +/// Rigorous β derivation for the stage-1 folded witness `z`: +/// - In `compute_z_hat`, each coordinate is `z[j] = Σ_i s_i[j].mul_by_sparse(c_i)`. +/// - `balanced_decompose_pow2` yields per-coefficient digits in `[-b/2, b/2)` where +/// `b = 2^LOG_BASIS`, so each input coefficient has `|·| <= b/2`. +/// - Challenges use exactly `ω = CHALLENGE_WEIGHT` nonzeros in `{±1}`. +/// - Therefore each `mul_by_sparse` output coefficient is a signed sum of `ω` +/// shifted digits, hence bounded by `ω * (b/2)`. +/// - Summing over `2^R` blocks gives: +/// `||z||_inf <= 2^R * ω * (b/2)`. +/// +/// For this profile: `R=11`, `ω=19`, `b=16`, so +/// `β = 2^11 * 19 * 8 = 311_296`. +#[derive(Clone, Copy, Debug, Default)] +pub struct ProductionFp128CommitmentConfig; + +impl CommitmentConfig for ProductionFp128CommitmentConfig { + const D: usize = 512; + const N_A: usize = 1; + const N_B: usize = 1; + const N_D: usize = 1; + const LOG_BASIS: u32 = 4; + const DELTA: usize = 32; + const TAU: usize = 5; + const CHALLENGE_WEIGHT: usize = 19; + + fn commitment_layout(max_num_vars: usize) -> Result { + let alpha = Self::D.trailing_zeros() as usize; + let reduced_vars = max_num_vars.checked_sub(alpha).ok_or_else(|| { + HachiError::InvalidSetup("max_num_vars is smaller than alpha".to_string()) + })?; + if reduced_vars == 0 { + return Err(HachiError::InvalidSetup( + "max_num_vars must leave at least one outer variable".to_string(), + )); + } + let m_vars = reduced_vars.min(11); + let r_vars = reduced_vars - m_vars; + HachiCommitmentLayout::new::(m_vars, r_vars) + } +} diff --git a/src/protocol/commitment/mod.rs b/src/protocol/commitment/mod.rs new file mode 100644 index 00000000..2801d510 --- /dev/null +++ b/src/protocol/commitment/mod.rs @@ -0,0 +1,26 @@ +//! Protocol commitment abstraction layer. + +mod commit; +mod config; +pub mod onehot; +mod scheme; +mod transcript_append; +mod types; +pub mod utils; + +pub use commit::{ + HachiCommitmentCore, HachiExpandedSetup, HachiPreparedSetup, HachiProverSetup, HachiSetupSeed, + HachiVerifierSetup, MegaPolyBlock, +}; +pub use config::{ + CommitmentConfig, DynamicSmallTestCommitmentConfig, HachiCommitmentLayout, + ProductionFp128CommitmentConfig, SmallTestCommitmentConfig, +}; +pub use onehot::{map_onehot_to_sparse_blocks, SparseBlockEntry}; +pub use scheme::{ + CommitWitness, CommitmentScheme, RingCommitmentScheme, StreamingCommitmentScheme, +}; +pub use transcript_append::AppendToTranscript; +pub use types::{ + DummyProof, HachiCommitment, HachiOpeningClaim, HachiOpeningPoint, RingCommitment, +}; diff --git a/src/protocol/commitment/onehot.rs b/src/protocol/commitment/onehot.rs new file mode 100644 index 00000000..3fe8d3fb --- /dev/null +++ b/src/protocol/commitment/onehot.rs @@ -0,0 +1,299 @@ +//! One-hot commitment path for regular one-hot ring elements. +//! +//! Exploits the sparsity of one-hot witnesses (coefficients in {0,1}) to +//! eliminate all inner ring multiplications. The inner Ajtai `t = A * s` +//! reduces to summing selected columns of `A` with negacyclic rotations. + +use std::collections::BTreeMap; + +use crate::algebra::ring::CyclotomicRing; +use crate::error::HachiError; +use crate::{CanonicalField, FieldCore}; + +/// Describes a nonzero ring element within one block of the commitment layout. +#[derive(Debug, Clone)] +pub struct SparseBlockEntry { + /// Position within the block (0..2^M). + pub pos_in_block: usize, + /// Coefficient indices that are 1 within this ring element. + pub nonzero_coeffs: Vec, +} + +/// Map a regular one-hot witness to sparse ring block entries. +/// +/// - `onehot_k`: chunk size K. The witness has T chunks of K field elements, +/// each chunk containing exactly one 1. +/// - `indices`: length-T slice where `indices[c]` is the hot position in +/// chunk `c` (must be in `[0, K)`). +/// - `r`, `m`: commitment config parameters (2^R blocks of 2^M ring elements). +/// - `D`: ring degree (const generic on caller side, passed as runtime here). +/// +/// Returns one `Vec` per block (outer len = 2^R). +/// +/// # Errors +/// +/// Returns an error if K and D are not "nicely matched" (one must divide +/// the other), if any index is out of range, or if the dimensions don't +/// fill the commitment layout. +pub fn map_onehot_to_sparse_blocks( + onehot_k: usize, + indices: &[Option], + r: usize, + m: usize, + d: usize, +) -> Result>, HachiError> { + if onehot_k == 0 || d == 0 { + return Err(HachiError::InvalidInput( + "onehot_k and D must be nonzero".into(), + )); + } + if !(onehot_k % d == 0 || d % onehot_k == 0) { + return Err(HachiError::InvalidInput(format!( + "K={onehot_k} and D={d} must be nicely matched (one divides the other)" + ))); + } + + let num_chunks = indices.len(); + let total_field_elems = num_chunks + .checked_mul(onehot_k) + .ok_or_else(|| HachiError::InvalidInput("T*K overflow".into()))?; + if total_field_elems % d != 0 { + return Err(HachiError::InvalidInput(format!( + "T*K={total_field_elems} is not divisible by D={d}" + ))); + } + let total_ring_elems = total_field_elems / d; + let num_blocks = 1usize << r; + let block_len = 1usize << m; + if total_ring_elems != num_blocks * block_len { + return Err(HachiError::InvalidSize { + expected: num_blocks * block_len, + actual: total_ring_elems, + }); + } + + // Accumulate nonzero coefficients per ring element index. + let mut ring_elem_map: BTreeMap> = BTreeMap::new(); + for (c, opt) in indices.iter().enumerate() { + let Some(&idx) = opt.as_ref() else { continue }; + if idx >= onehot_k { + return Err(HachiError::InvalidInput(format!( + "index {idx} out of range for chunk size K={onehot_k} at position {c}" + ))); + } + let field_pos = c * onehot_k + idx; + let ring_elem_idx = field_pos / d; + let coeff_idx = field_pos % d; + ring_elem_map + .entry(ring_elem_idx) + .or_default() + .push(coeff_idx); + } + + // Sequential block layout matching commit_coeffs: block i = ring elements + // [i*block_len, (i+1)*block_len). + let mut blocks: Vec> = vec![Vec::new(); num_blocks]; + for (ring_elem_idx, nonzero_coeffs) in ring_elem_map { + let block_idx = ring_elem_idx / block_len; + let pos_in_block = ring_elem_idx % block_len; + blocks[block_idx].push(SparseBlockEntry { + pos_in_block, + nonzero_coeffs, + }); + } + + Ok(blocks) +} + +/// Sparse inner Ajtai: compute `t = A * s` for a one-hot block. +/// +/// Instead of materializing the full decomposed vector `s` and doing a dense +/// matvec, we accumulate only the nonzero contributions: +/// +/// ```text +/// t[a] = sum_{entry} A[a][entry.pos * delta].mul_by_monomial_sum(entry.nonzero_coeffs) +/// ``` +/// +/// Also returns `s` (densely materialized) for the opening proof hint. +#[allow(non_snake_case)] +pub(crate) fn inner_ajtai_onehot( + A: &[Vec>], + sparse_entries: &[SparseBlockEntry], + block_len: usize, + delta: usize, +) -> (Vec>, Vec>) { + let n_a = A.len(); + let inner_width = block_len * delta; + + // Build s: mostly zeros, with level-0 entries for nonzero ring elements. + let mut s = vec![CyclotomicRing::::zero(); inner_width]; + for entry in sparse_entries { + let mut coeffs = [F::zero(); D]; + for &ci in &entry.nonzero_coeffs { + coeffs[ci] = F::one(); + } + s[entry.pos_in_block * delta] = CyclotomicRing::from_coefficients(coeffs); + } + + // Compute t[a] = sum over nonzero entries of A[a][pos*delta] * f_j, + // where f_j is the monomial sum at that position. + let mut t = vec![CyclotomicRing::::zero(); n_a]; + for entry in sparse_entries { + let col = entry.pos_in_block * delta; + for a in 0..n_a { + t[a] += A[a][col].mul_by_monomial_sum(&entry.nonzero_coeffs); + } + } + + (t, s) +} + +/// Like `inner_ajtai_onehot` but only returns `t`, skipping the `s` allocation. +#[allow(non_snake_case)] +pub(crate) fn inner_ajtai_onehot_t_only( + A: &[Vec>], + sparse_entries: &[SparseBlockEntry], + _block_len: usize, + _delta: usize, +) -> Vec> { + let n_a = A.len(); + + let mut t = vec![CyclotomicRing::::zero(); n_a]; + for entry in sparse_entries { + let col = entry.pos_in_block * _delta; + for a in 0..n_a { + t[a] += A[a][col].mul_by_monomial_sum(&entry.nonzero_coeffs); + } + } + + t +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::test_utils::F; + use crate::FromSmallInt; + use std::array::from_fn; + + #[test] + fn map_onehot_k_gt_d() { + // K=16, D=4, T=2 chunks => 32 field elements => 8 ring elements + // R=1 (2 blocks), M=2 (4 per block) => 8 ring elements total + let k = 16; + let d = 4; + let indices = vec![Some(3), Some(10)]; + let blocks = map_onehot_to_sparse_blocks(k, &indices, 1, 2, d).unwrap(); + + assert_eq!(blocks.len(), 2); + let total_entries: usize = blocks.iter().map(|b| b.len()).sum(); + assert_eq!(total_entries, 2, "T=2 nonzero ring elements"); + + for block in &blocks { + for entry in block { + assert_eq!(entry.nonzero_coeffs.len(), 1, "K>D => single monomial"); + } + } + } + + #[test] + fn map_onehot_k_eq_d() { + // K=4, D=4, T=4 chunks => 16 field elements => 4 ring elements + // R=1 (2 blocks), M=1 (2 per block) + let k = 4; + let d = 4; + let indices = vec![Some(0), Some(2), Some(3), Some(1)]; + let blocks = map_onehot_to_sparse_blocks(k, &indices, 1, 1, d).unwrap(); + + assert_eq!(blocks.len(), 2); + let total_entries: usize = blocks.iter().map(|b| b.len()).sum(); + assert_eq!(total_entries, 4, "K=D => every ring element is nonzero"); + + for block in &blocks { + for entry in block { + assert_eq!(entry.nonzero_coeffs.len(), 1, "K=D => single monomial"); + } + } + } + + #[test] + fn map_onehot_k_lt_d() { + // K=4, D=8, T=8 chunks => 32 field elements => 4 ring elements + // R=1 (2 blocks), M=1 (2 per block) + let k = 4; + let d = 8; + let indices = vec![ + Some(0), + Some(2), + Some(3), + Some(1), + Some(0), + Some(0), + Some(3), + Some(3), + ]; + let blocks = map_onehot_to_sparse_blocks(k, &indices, 1, 1, d).unwrap(); + + assert_eq!(blocks.len(), 2); + let total_entries: usize = blocks.iter().map(|b| b.len()).sum(); + assert_eq!(total_entries, 4, "D>K => all ring elements nonzero"); + + for block in &blocks { + for entry in block { + assert_eq!( + entry.nonzero_coeffs.len(), + 2, + "D=2K => 2 nonzero coeffs per ring element" + ); + } + } + } + + #[test] + fn map_onehot_rejects_non_divisible() { + let result = map_onehot_to_sparse_blocks(3, &[Some(0), Some(1)], 0, 1, 4); + assert!(result.is_err()); + } + + #[test] + fn inner_ajtai_onehot_single_monomial() { + const D: usize = 4; + type R = CyclotomicRing; + + // A is 2x4 (N_A=2, inner_width = block_len * delta = 2 * 2 = 4) + let a: Vec> = vec![ + vec![ + R::from_coefficients(from_fn(|i| F::from_u64((i + 1) as u64))), + R::from_coefficients(from_fn(|i| F::from_u64((i + 10) as u64))), + R::from_coefficients(from_fn(|i| F::from_u64((i + 20) as u64))), + R::from_coefficients(from_fn(|i| F::from_u64((i + 30) as u64))), + ], + vec![ + R::from_coefficients(from_fn(|i| F::from_u64((i + 5) as u64))), + R::from_coefficients(from_fn(|i| F::from_u64((i + 15) as u64))), + R::from_coefficients(from_fn(|i| F::from_u64((i + 25) as u64))), + R::from_coefficients(from_fn(|i| F::from_u64((i + 35) as u64))), + ], + ]; + + // One nonzero entry at pos=1, coefficient index 2 => monomial X^2 + let entries = vec![SparseBlockEntry { + pos_in_block: 1, + nonzero_coeffs: vec![2], + }]; + + let (t, s) = inner_ajtai_onehot(&a, &entries, 2, 2); + + // t[row] should equal A[row][1*2] * X^2 = A[row][2].negacyclic_shift(2) + for row in 0..2 { + let expected = a[row][2].negacyclic_shift(2); + assert_eq!(t[row], expected); + } + + // s should have a nonzero entry at position 1*2 = 2 + assert_eq!(s[2].coefficients()[2], F::one()); + assert!(s[0] == R::zero()); + assert!(s[1] == R::zero()); + assert!(s[3] == R::zero()); + } +} diff --git a/src/protocol/commitment/scheme.rs b/src/protocol/commitment/scheme.rs new file mode 100644 index 00000000..256ce197 --- /dev/null +++ b/src/protocol/commitment/scheme.rs @@ -0,0 +1,266 @@ +//! Commitment-scheme trait surface for Hachi protocol code. + +use super::config::{CommitmentConfig, HachiCommitmentLayout}; +use super::transcript_append::AppendToTranscript; +use crate::algebra::CyclotomicRing; +use crate::error::HachiError; +use crate::protocol::transcript::Transcript; +use crate::{CanonicalField, FieldCore, Polynomial}; + +/// Output type for batched commitments. +pub(crate) type BatchCommitOutput = Result, HachiError>; + +/// Witness data produced alongside a ring-native commitment. +/// +/// Contains the commitment itself plus `t_hat` (basis-decomposed inner Ajtai +/// output) from the two-layer Ajtai construction (§4.1). The decomposed input +/// vectors `s` are NOT stored; they are recomputed from `ring_coeffs` during +/// proving to avoid multi-GB memory usage at production parameters. +pub struct CommitWitness { + /// The ring commitment (outer Ajtai output `u = B · t̂`). + pub commitment: C, + /// Per-block basis-decomposed inner Ajtai output vectors. + pub t_hat: Vec>>, +} + +/// Generic commitment-scheme interface used by Hachi protocol code. +pub trait CommitmentScheme: Clone + Send + Sync + 'static +where + F: FieldCore + CanonicalField, +{ + /// Prover setup parameters. + type ProverSetup: Clone + Send + Sync; + /// Verifier setup parameters. + type VerifierSetup: Clone + Send + Sync; + /// Commitment object. + type Commitment: Clone + PartialEq + Send + Sync + AppendToTranscript; + /// Evaluation/opening proof object. + type Proof: Clone + Send + Sync; + /// Optional prover-side hint produced at commitment time. + type OpeningProofHint: Clone + Send + Sync; + + /// Build prover setup for maximum polynomial dimension. + /// + /// # Panics + /// + /// Panics if internal setup fails (programming error, not adversarial input). + fn setup_prover(max_num_vars: usize) -> Self::ProverSetup; + + /// Derive verifier setup from prover setup. + fn setup_verifier(setup: &Self::ProverSetup) -> Self::VerifierSetup; + + /// Commit to one polynomial. + /// + /// # Errors + /// + /// Returns an error when setup/parameter constraints are not satisfied. + fn commit>( + poly: &P, + setup: &Self::ProverSetup, + ) -> Result<(Self::Commitment, Self::OpeningProofHint), HachiError>; + + /// Commit to many polynomials. + /// + /// # Errors + /// + /// Returns an error if any per-polynomial commitment fails. + fn batch_commit>( + polys: &[P], + setup: &Self::ProverSetup, + ) -> BatchCommitOutput { + polys.iter().map(|p| Self::commit(p, setup)).collect() + } + + /// Produce an opening proof at `opening_point`. + /// + /// # Errors + /// + /// Returns an error if the opening point is invalid or proof generation fails. + fn prove, P: Polynomial>( + setup: &Self::ProverSetup, + poly: &P, + opening_point: &[F], + hint: Option, + transcript: &mut T, + commitment: &Self::Commitment, + ) -> Result; + + /// Verify an opening proof. + /// + /// # Errors + /// + /// Returns an error when verification fails. + fn verify>( + proof: &Self::Proof, + setup: &Self::VerifierSetup, + transcript: &mut T, + opening_point: &[F], + opening: &F, + commitment: &Self::Commitment, + ) -> Result<(), HachiError>; + + /// Homomorphic commitment combination. + fn combine_commitments(commitments: &[Self::Commitment], coeffs: &[F]) -> Self::Commitment; + + /// Homomorphic hint combination. + fn combine_hints(hints: Vec, coeffs: &[F]) -> Self::OpeningProofHint; + + /// Protocol identifier. + fn protocol_name() -> &'static [u8]; +} + +/// Streaming extension for chunked commitment workflows. +pub trait StreamingCommitmentScheme: CommitmentScheme +where + F: FieldCore + CanonicalField, +{ + /// Intermediate chunk state. + type ChunkState: Clone + Send + Sync + PartialEq + std::fmt::Debug; + + /// Process one chunk of field elements. + fn process_chunk(setup: &Self::ProverSetup, chunk: &[F]) -> Self::ChunkState; + + /// Process one chunk of one-hot values. + fn process_chunk_onehot( + setup: &Self::ProverSetup, + onehot_k: usize, + chunk: &[Option], + ) -> Self::ChunkState; + + /// Aggregate chunk states into one commitment + hint. + fn aggregate_chunks( + setup: &Self::ProverSetup, + onehot_k: Option, + chunks: &[Self::ChunkState], + ) -> (Self::Commitment, Self::OpeningProofHint); +} + +/// Ring-native commitment interface for §4.1 implementation work. +pub trait RingCommitmentScheme: Clone + Send + Sync + 'static +where + F: FieldCore + CanonicalField, + Cfg: CommitmentConfig, +{ + /// Prover setup parameters. + type ProverSetup: Clone + Send + Sync; + /// Verifier setup parameters. + type VerifierSetup: Clone + Send + Sync; + /// Ring-native commitment type. + type Commitment: Clone + PartialEq + Send + Sync; + + /// Construct commitment setup for at most `max_num_vars` variables. + /// + /// # Errors + /// + /// Returns an error if dimensions are inconsistent with `Cfg`. + fn setup(max_num_vars: usize) -> Result<(Self::ProverSetup, Self::VerifierSetup), HachiError>; + + /// Read the runtime layout carried by `setup`. + /// + /// # Errors + /// + /// Returns an error when setup metadata is inconsistent. + fn layout(setup: &Self::ProverSetup) -> Result; + + /// Commit to ring blocks arranged as `2^R` vectors of length `2^M`. + /// + /// Returns `(commitment, s, t_hat)` where `s` and `t_hat` are the + /// decomposed witness vectors from §4.1. + /// + /// # Errors + /// + /// Returns an error if block layout mismatches config or commitment fails. + fn commit_ring_blocks( + f_blocks: &[Vec>], + setup: &Self::ProverSetup, + ) -> Result, HachiError>; + + /// Commit to a flat coefficient table `(f_i)_{i∈{0,1}^ℓ}` in ring form. + /// + /// The input uses sequential block layout: ring elements + /// `[0, block_len)` form block 0, `[block_len, 2*block_len)` form + /// block 1, and so on. This matches the sequential variable ordering + /// where M variables (position in block) are lower-order and R variables + /// (block selection) are higher-order. + /// + /// # Errors + /// + /// Returns an error if `f_coeffs.len()` does not match the configured block + /// layout or if the underlying commitment routine fails. + fn commit_coeffs( + f_coeffs: &[CyclotomicRing], + setup: &Self::ProverSetup, + ) -> Result, HachiError> { + let layout = Self::layout(setup)?; + let num_blocks = layout.num_blocks; + let block_len = layout.block_len; + let expected_len = num_blocks + .checked_mul(block_len) + .ok_or_else(|| HachiError::InvalidSetup("coefficient length overflow".to_string()))?; + if f_coeffs.len() != expected_len { + return Err(HachiError::InvalidSize { + expected: expected_len, + actual: f_coeffs.len(), + }); + } + + let blocks: Vec>> = f_coeffs + .chunks_exact(block_len) + .map(|chunk| chunk.to_vec()) + .collect(); + + Self::commit_ring_blocks(&blocks, setup) + } + + /// Commit to a regular one-hot witness. + /// + /// The witness represents `T` chunks of `onehot_k` field elements, each + /// chunk containing exactly one 1 and all other entries 0. `indices[c]` + /// gives the hot position in chunk `c` (must be in `[0, onehot_k)`). + /// + /// Requires `D` and `onehot_k` to be "nicely matched": one must divide + /// the other. + /// + /// The default implementation materializes the full one-hot field vector, + /// packs it into ring elements via coefficient embedding, and delegates + /// to `commit_coeffs`. Implementations may override this with a + /// sparse-aware path that avoids all inner ring multiplications. + /// + /// # Errors + /// + /// Returns an error if dimensions are inconsistent or any index is out + /// of range. + fn commit_onehot( + onehot_k: usize, + indices: &[Option], + setup: &Self::ProverSetup, + ) -> Result, HachiError> { + let num_chunks = indices.len(); + let total_field_elems = num_chunks + .checked_mul(onehot_k) + .ok_or_else(|| HachiError::InvalidInput("T*K overflow".into()))?; + if total_field_elems % D != 0 { + return Err(HachiError::InvalidInput(format!( + "T*K={total_field_elems} is not divisible by D={D}" + ))); + } + + // Materialize the full one-hot vector as ring elements. + let total_ring_elems = total_field_elems / D; + let mut ring_coeffs = vec![CyclotomicRing::::zero(); total_ring_elems]; + for (c, opt) in indices.iter().enumerate() { + let Some(&idx) = opt.as_ref() else { continue }; + if idx >= onehot_k { + return Err(HachiError::InvalidInput(format!( + "index {idx} out of range for chunk size K={onehot_k} at position {c}" + ))); + } + let field_pos = c * onehot_k + idx; + let ring_idx = field_pos / D; + let coeff_idx = field_pos % D; + ring_coeffs[ring_idx].coeffs[coeff_idx] = F::one(); + } + + Self::commit_coeffs(&ring_coeffs, setup) + } +} diff --git a/src/protocol/commitment/transcript_append.rs b/src/protocol/commitment/transcript_append.rs new file mode 100644 index 00000000..6510ccb5 --- /dev/null +++ b/src/protocol/commitment/transcript_append.rs @@ -0,0 +1,13 @@ +//! Traits for appending commitment objects to protocol transcripts. + +use crate::protocol::transcript::Transcript; +use crate::{CanonicalField, FieldCore}; + +/// Protocol object that can be absorbed into a transcript. +pub trait AppendToTranscript +where + F: FieldCore + CanonicalField, +{ + /// Append this object to a transcript using the provided event label. + fn append_to_transcript>(&self, label: &[u8], transcript: &mut T); +} diff --git a/src/protocol/commitment/types.rs b/src/protocol/commitment/types.rs new file mode 100644 index 00000000..339a1a03 --- /dev/null +++ b/src/protocol/commitment/types.rs @@ -0,0 +1,157 @@ +//! Protocol commitment/opening wrapper types. + +use super::transcript_append::AppendToTranscript; +use crate::algebra::ring::CyclotomicRing; +use crate::primitives::serialization::{ + Compress, HachiDeserialize, HachiSerialize, SerializationError, Valid, Validate, +}; +use crate::protocol::transcript::Transcript; +use crate::{CanonicalField, FieldCore}; +use std::io::{Read, Write}; + +/// A Hachi opening point represented as field coordinates. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct HachiOpeningPoint { + /// Point coordinates used for multilinear opening. + pub r: Vec, +} + +/// A Hachi opening claim `(point, value)`. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct HachiOpeningClaim { + /// Opening point. + pub point: HachiOpeningPoint, + /// Claimed value at `point`. + pub value: F, +} + +/// Minimal commitment wrapper used by protocol traits/tests. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] +pub struct HachiCommitment(pub u128); + +/// Minimal proof wrapper used by protocol trait stubs and tests. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] +pub struct DummyProof(pub u128); + +impl Valid for HachiCommitment { + fn check(&self) -> Result<(), SerializationError> { + Ok(()) + } +} + +impl HachiSerialize for HachiCommitment { + fn serialize_with_mode( + &self, + mut writer: W, + _compress: Compress, + ) -> Result<(), SerializationError> { + self.0.serialize_with_mode(&mut writer, Compress::No) + } + + fn serialized_size(&self, _compress: Compress) -> usize { + 16 + } +} + +impl HachiDeserialize for HachiCommitment { + fn deserialize_with_mode( + mut reader: R, + _compress: Compress, + validate: Validate, + ) -> Result { + let value = u128::deserialize_with_mode(&mut reader, Compress::No, validate)?; + Ok(Self(value)) + } +} + +impl Valid for DummyProof { + fn check(&self) -> Result<(), SerializationError> { + Ok(()) + } +} + +impl HachiSerialize for DummyProof { + fn serialize_with_mode( + &self, + mut writer: W, + _compress: Compress, + ) -> Result<(), SerializationError> { + self.0.serialize_with_mode(&mut writer, Compress::No) + } + + fn serialized_size(&self, _compress: Compress) -> usize { + 16 + } +} + +impl HachiDeserialize for DummyProof { + fn deserialize_with_mode( + mut reader: R, + _compress: Compress, + validate: Validate, + ) -> Result { + let value = u128::deserialize_with_mode(&mut reader, Compress::No, validate)?; + Ok(Self(value)) + } +} + +impl AppendToTranscript for HachiCommitment +where + F: FieldCore + CanonicalField, +{ + fn append_to_transcript>(&self, label: &[u8], transcript: &mut T) { + transcript.append_serde(label, self); + } +} + +/// Ring-native commitment object `u in R_q^{n_B}` used by §4.1. +#[derive(Debug, Clone, PartialEq, Eq, Default)] +pub struct RingCommitment { + /// Outer commitment vector. + pub u: Vec>, +} + +impl Valid for RingCommitment { + fn check(&self) -> Result<(), SerializationError> { + self.u.check() + } +} + +impl HachiSerialize for RingCommitment { + fn serialize_with_mode( + &self, + mut writer: W, + compress: Compress, + ) -> Result<(), SerializationError> { + self.u.serialize_with_mode(&mut writer, compress) + } + + fn serialized_size(&self, compress: Compress) -> usize { + self.u.serialized_size(compress) + } +} + +impl HachiDeserialize for RingCommitment { + fn deserialize_with_mode( + mut reader: R, + compress: Compress, + validate: Validate, + ) -> Result { + let u = + Vec::>::deserialize_with_mode(&mut reader, compress, validate)?; + let out = Self { u }; + if matches!(validate, Validate::Yes) { + out.check()?; + } + Ok(out) + } +} + +impl AppendToTranscript for RingCommitment +where + F: FieldCore + CanonicalField, +{ + fn append_to_transcript>(&self, label: &[u8], transcript: &mut T) { + transcript.append_serde(label, self); + } +} diff --git a/src/protocol/commitment/utils/crt_ntt.rs b/src/protocol/commitment/utils/crt_ntt.rs new file mode 100644 index 00000000..33bd79c4 --- /dev/null +++ b/src/protocol/commitment/utils/crt_ntt.rs @@ -0,0 +1,139 @@ +//! Protocol-facing CRT+NTT parameter dispatch and matrix caching. + +use crate::algebra::ntt::prime::PrimeWidth; +use crate::algebra::ntt::tables::{ + q128_primes, q64_primes, MAX_CRT_RING_DEGREE, Q128_MODULUS, Q128_NUM_PRIMES, Q32_MODULUS, + Q32_NUM_PRIMES, Q32_PRIMES, Q64_MODULUS, Q64_NUM_PRIMES, RING_DEGREE, +}; +use crate::algebra::ring::{CrtNttParamSet, CyclotomicCrtNtt, CyclotomicRing}; +use crate::cfg_iter; +use crate::error::HachiError; +#[cfg(feature = "parallel")] +use crate::parallel::*; +use crate::{CanonicalField, FieldCore}; + +use super::norm::detect_field_modulus; + +/// Supported protocol CRT+NTT parameter families. +pub(crate) enum ProtocolCrtNttParams { + Q32(CrtNttParamSet), + Q64(CrtNttParamSet), + Q128(CrtNttParamSet), +} + +/// Select a CRT+NTT parameter set from field modulus and ring degree. +/// +/// Dispatch policy: +/// - `q <= 2^32-99` and `D <= 64`: Q32 (`i16`) +/// - `q <= 2^64-59` and `D <= 1024`: Q64 (`i32`, conservative K=5) +/// - `q == 2^128-275` and `D <= 1024`: Q128 (`i32`, K=5) +/// - otherwise: explicit setup error +pub(crate) fn select_crt_ntt_params( +) -> Result, HachiError> { + if !D.is_power_of_two() { + return Err(HachiError::InvalidSetup(format!( + "CRT+NTT requires power-of-two ring degree, got D={D}" + ))); + } + if D > MAX_CRT_RING_DEGREE { + return Err(HachiError::InvalidSetup(format!( + "CRT+NTT supports D <= {MAX_CRT_RING_DEGREE}, got D={D}" + ))); + } + + let modulus = detect_field_modulus::(); + + if modulus == Q128_MODULUS { + return Ok(ProtocolCrtNttParams::Q128(CrtNttParamSet::new( + q128_primes(), + ))); + } + + if modulus <= Q32_MODULUS as u128 { + if D <= RING_DEGREE { + return Ok(ProtocolCrtNttParams::Q32(CrtNttParamSet::new(Q32_PRIMES))); + } + return Ok(ProtocolCrtNttParams::Q64(CrtNttParamSet::new(q64_primes()))); + } + + if modulus <= Q64_MODULUS as u128 { + return Ok(ProtocolCrtNttParams::Q64(CrtNttParamSet::new(q64_primes()))); + } + + Err(HachiError::InvalidSetup(format!( + "no CRT+NTT parameter set for modulus {modulus} and D={D}; supported ranges: <= {Q64_MODULUS} (with Q32/Q64 dispatch) or exactly {Q128_MODULUS}" + ))) +} + +/// Pre-converted CRT+NTT matrices, keyed by parameter family. +/// +/// Avoids repeated coefficient-to-NTT conversion on every dense mat-vec. +#[derive(Debug, Clone, PartialEq, Eq)] +#[allow(non_snake_case)] +pub(crate) enum NttMatrixCache { + Q32 { + A: Vec>>, + B: Vec>>, + D: Vec>>, + params: CrtNttParamSet, + }, + Q64 { + A: Vec>>, + B: Vec>>, + D: Vec>>, + params: CrtNttParamSet, + }, + Q128 { + A: Vec>>, + B: Vec>>, + D: Vec>>, + params: CrtNttParamSet, + }, +} + +fn convert_mat( + mat: &[Vec>], + params: &CrtNttParamSet, +) -> Vec>> +where + F: FieldCore + CanonicalField, + W: PrimeWidth, +{ + cfg_iter!(mat) + .map(|row| { + row.iter() + .map(|a| CyclotomicCrtNtt::from_ring_with_params(a, params)) + .collect() + }) + .collect() +} + +#[allow(non_snake_case)] +pub(crate) fn build_ntt_cache( + a: &[Vec>], + b: &[Vec>], + d: &[Vec>], +) -> Result, HachiError> { + let params = select_crt_ntt_params::()?; + let cache = match params { + ProtocolCrtNttParams::Q32(p) => NttMatrixCache::Q32 { + A: convert_mat(a, &p), + B: convert_mat(b, &p), + D: convert_mat(d, &p), + params: p, + }, + ProtocolCrtNttParams::Q64(p) => NttMatrixCache::Q64 { + A: convert_mat(a, &p), + B: convert_mat(b, &p), + D: convert_mat(d, &p), + params: p, + }, + ProtocolCrtNttParams::Q128(p) => NttMatrixCache::Q128 { + A: convert_mat(a, &p), + B: convert_mat(b, &p), + D: convert_mat(d, &p), + params: p, + }, + }; + Ok(cache) +} diff --git a/src/protocol/commitment/utils/linear.rs b/src/protocol/commitment/utils/linear.rs new file mode 100644 index 00000000..636e7b3a --- /dev/null +++ b/src/protocol/commitment/utils/linear.rs @@ -0,0 +1,358 @@ +//! Linear algebra helpers for ring commitment. + +use crate::algebra::ntt::{MontCoeff, PrimeWidth}; +use crate::algebra::{CrtNttParamSet, CyclotomicCrtNtt, CyclotomicRing}; +use crate::cfg_iter; +use crate::error::HachiError; +#[cfg(feature = "parallel")] +use crate::parallel::*; +use crate::{CanonicalField, FieldCore}; + +use super::crt_ntt::NttMatrixCache; +#[cfg(test)] +use super::crt_ntt::{select_crt_ntt_params, ProtocolCrtNttParams}; + +#[cfg(test)] +pub(crate) fn mat_vec_mul_unchecked( + mat: &[Vec>], + vec: &[CyclotomicRing], +) -> Vec> { + let mut out = Vec::with_capacity(mat.len()); + for row in mat { + debug_assert_eq!(row.len(), vec.len()); + let mut acc = CyclotomicRing::::zero(); + for (a, x) in row.iter().zip(vec.iter()) { + acc += *a * *x; + } + out.push(acc); + } + out +} + +#[inline] +fn accumulate_pointwise_product_into( + acc: &mut CyclotomicCrtNtt, + lhs: &CyclotomicCrtNtt, + rhs: &CyclotomicCrtNtt, + params: &CrtNttParamSet, +) { + for k in 0..K { + let prime = params.primes[k]; + let acc_limb = &mut acc.limbs[k]; + let lhs_limb = &lhs.limbs[k]; + let rhs_limb = &rhs.limbs[k]; + for ((acc_coeff, lhs_coeff), rhs_coeff) in acc_limb + .iter_mut() + .zip(lhs_limb.iter()) + .zip(rhs_limb.iter()) + { + let prod = prime.mul(*lhs_coeff, *rhs_coeff); + let sum = MontCoeff::from_raw(acc_coeff.raw().wrapping_add(prod.raw())); + *acc_coeff = prime.reduce_range(sum); + } + } +} + +#[cfg(test)] +fn precompute_dense_mat_ntt_with_params< + F: FieldCore + CanonicalField, + W: PrimeWidth, + const K: usize, + const D: usize, +>( + mat: &[Vec>], + params: &CrtNttParamSet, +) -> Vec>> { + mat.iter() + .map(|row| { + row.iter() + .map(|a| CyclotomicCrtNtt::from_ring_with_params(a, params)) + .collect() + }) + .collect() +} + +#[cfg(test)] +fn mat_vec_mul_dense_with_params< + F: FieldCore + CanonicalField, + W: PrimeWidth, + const K: usize, + const D: usize, +>( + mat: &[Vec>], + vec: &[CyclotomicRing], + params: &CrtNttParamSet, +) -> Vec> { + let ntt_vec: Vec> = vec + .iter() + .map(|v| CyclotomicCrtNtt::from_ring_with_params(v, params)) + .collect(); + + mat.iter() + .map(|row| { + debug_assert_eq!(row.len(), ntt_vec.len()); + let mut acc = CyclotomicCrtNtt::::zero(); + for (a, x_ntt) in row.iter().zip(ntt_vec.iter()) { + let a_ntt = CyclotomicCrtNtt::from_ring_with_params(a, params); + accumulate_pointwise_product_into(&mut acc, &a_ntt, x_ntt, params); + } + acc.to_ring_with_params(params) + }) + .collect() +} + +#[cfg(test)] +fn mat_vec_mul_dense_many_with_params< + F: FieldCore + CanonicalField, + W: PrimeWidth, + const K: usize, + const D: usize, +>( + mat: &[Vec>], + vecs: &[Vec>], + params: &CrtNttParamSet, +) -> Vec>> { + let ntt_mat = precompute_dense_mat_ntt_with_params(mat, params); + vecs.iter() + .map(|vec| { + let ntt_vec: Vec> = vec + .iter() + .map(|v| CyclotomicCrtNtt::from_ring_with_params(v, params)) + .collect(); + + ntt_mat + .iter() + .map(|row_ntt| { + debug_assert_eq!(row_ntt.len(), ntt_vec.len()); + let mut acc = CyclotomicCrtNtt::::zero(); + for (a_ntt, x_ntt) in row_ntt.iter().zip(ntt_vec.iter()) { + accumulate_pointwise_product_into(&mut acc, a_ntt, x_ntt, params); + } + acc.to_ring_with_params(params) + }) + .collect() + }) + .collect() +} + +#[cfg(test)] +pub(crate) fn mat_vec_mul_crt_ntt( + mat: &[Vec>], + vec: &[CyclotomicRing], +) -> Result>, HachiError> { + let params = select_crt_ntt_params::()?; + let out = match ¶ms { + ProtocolCrtNttParams::Q32(p) => mat_vec_mul_dense_with_params(mat, vec, p), + ProtocolCrtNttParams::Q64(p) => mat_vec_mul_dense_with_params(mat, vec, p), + ProtocolCrtNttParams::Q128(p) => mat_vec_mul_dense_with_params(mat, vec, p), + }; + Ok(out) +} + +#[cfg(test)] +pub(crate) fn mat_vec_mul_crt_ntt_many( + mat: &[Vec>], + vecs: &[Vec>], +) -> Result>>, HachiError> { + let params = select_crt_ntt_params::()?; + let out = match ¶ms { + ProtocolCrtNttParams::Q32(p) => mat_vec_mul_dense_many_with_params(mat, vecs, p), + ProtocolCrtNttParams::Q64(p) => mat_vec_mul_dense_many_with_params(mat, vecs, p), + ProtocolCrtNttParams::Q128(p) => mat_vec_mul_dense_many_with_params(mat, vecs, p), + }; + Ok(out) +} + +/// Selector for which cached matrix to use. +#[derive(Debug, Clone, Copy)] +pub(crate) enum MatrixSlot { + A, + B, + D, +} + +fn mat_vec_mul_precomputed_with_params< + F: FieldCore + CanonicalField, + W: PrimeWidth, + const K: usize, + const D: usize, +>( + ntt_mat: &[Vec>], + vec: &[CyclotomicRing], + params: &CrtNttParamSet, +) -> Vec> { + let ntt_vec: Vec> = cfg_iter!(vec) + .map(|v| CyclotomicCrtNtt::from_ring_with_params(v, params)) + .collect(); + + cfg_iter!(ntt_mat) + .map(|row_ntt| { + debug_assert!(row_ntt.len() >= ntt_vec.len()); + let mut acc = CyclotomicCrtNtt::::zero(); + for (a_ntt, x_ntt) in row_ntt.iter().zip(ntt_vec.iter()) { + accumulate_pointwise_product_into(&mut acc, a_ntt, x_ntt, params); + } + acc.to_ring_with_params(params) + }) + .collect() +} + +macro_rules! dispatch_cached { + ($cache:expr, $which:expr, $func:ident $(, $arg:expr)*) => {{ + #[allow(non_snake_case)] + match $cache { + NttMatrixCache::Q32 { A, B, D: Dm, params: p } => { + let m = match $which { MatrixSlot::A => A, MatrixSlot::B => B, MatrixSlot::D => Dm }; + $func(m, $($arg,)* p) + } + NttMatrixCache::Q64 { A, B, D: Dm, params: p } => { + let m = match $which { MatrixSlot::A => A, MatrixSlot::B => B, MatrixSlot::D => Dm }; + $func(m, $($arg,)* p) + } + NttMatrixCache::Q128 { A, B, D: Dm, params: p } => { + let m = match $which { MatrixSlot::A => A, MatrixSlot::B => B, MatrixSlot::D => Dm }; + $func(m, $($arg,)* p) + } + } + }}; +} + +/// Dense mat-vec using a pre-converted NTT matrix from the cache. +pub(crate) fn mat_vec_mul_ntt_cached( + cache: &NttMatrixCache, + which: MatrixSlot, + vec: &[CyclotomicRing], +) -> Result>, HachiError> { + let out = dispatch_cached!(cache, which, mat_vec_mul_precomputed_with_params, vec); + Ok(out) +} + +/// Basis-decompose a block of ring elements into `block.len() * delta` gadget components. +pub fn decompose_block( + block: &[CyclotomicRing], + delta: usize, + log_basis: u32, +) -> Vec> { + let mut out = vec![CyclotomicRing::::zero(); block.len() * delta]; + for (i, coeff_vec) in block.iter().enumerate() { + coeff_vec.balanced_decompose_pow2_into(&mut out[i * delta..(i + 1) * delta], log_basis); + } + out +} + +pub(crate) fn decompose_rows( + rows: &[CyclotomicRing], + delta: usize, + log_basis: u32, +) -> Vec> { + let mut out = vec![CyclotomicRing::::zero(); rows.len() * delta]; + for (i, row) in rows.iter().enumerate() { + row.balanced_decompose_pow2_into(&mut out[i * delta..(i + 1) * delta], log_basis); + } + out +} + +#[cfg(test)] +mod tests { + use super::{mat_vec_mul_crt_ntt, mat_vec_mul_crt_ntt_many, mat_vec_mul_unchecked}; + use crate::algebra::{CyclotomicRing, Fp64}; + use crate::FromSmallInt; + + #[test] + fn dense_mat_vec_matches_schoolbook_q32_d64() { + type F = Fp64<4294967197>; + const D: usize = 64; + let mat: Vec>> = (0..3) + .map(|i| { + (0..4) + .map(|j| { + let coeffs = std::array::from_fn(|k| { + F::from_u64((i as u64 * 10_000 + j as u64 * 100 + k as u64 + 1) % 97) + }); + CyclotomicRing::from_coefficients(coeffs) + }) + .collect() + }) + .collect(); + let vec: Vec> = (0..4) + .map(|j| { + let coeffs = + std::array::from_fn(|k| F::from_u64((j as u64 * 50 + k as u64 + 3) % 89)); + CyclotomicRing::from_coefficients(coeffs) + }) + .collect(); + + let schoolbook = mat_vec_mul_unchecked(&mat, &vec); + let crt_ntt = mat_vec_mul_crt_ntt(&mat, &vec).expect("Q32 dispatch should succeed"); + assert_eq!(schoolbook, crt_ntt); + } + + #[test] + fn dense_mat_vec_matches_schoolbook_q64_dispatch_for_large_d() { + type F = Fp64<4294967197>; + const D: usize = 128; + let mat: Vec>> = (0..2) + .map(|i| { + (0..2) + .map(|j| { + let coeffs = std::array::from_fn(|k| { + F::from_u64((i as u64 * 20_000 + j as u64 * 300 + k as u64 + 7) % 113) + }); + CyclotomicRing::from_coefficients(coeffs) + }) + .collect() + }) + .collect(); + let vec: Vec> = (0..2) + .map(|j| { + let coeffs = + std::array::from_fn(|k| F::from_u64((j as u64 * 70 + k as u64 + 11) % 101)); + CyclotomicRing::from_coefficients(coeffs) + }) + .collect(); + + let schoolbook = mat_vec_mul_unchecked(&mat, &vec); + let crt_ntt = mat_vec_mul_crt_ntt(&mat, &vec).expect("Q64 dispatch should succeed"); + assert_eq!(schoolbook, crt_ntt); + } + + #[test] + fn dense_mat_vec_many_matches_individual_crt_ntt_q32_d64() { + type F = Fp64<4294967197>; + const D: usize = 64; + let mat: Vec>> = (0..3) + .map(|i| { + (0..4) + .map(|j| { + let coeffs = std::array::from_fn(|k| { + F::from_u64((i as u64 * 10_000 + j as u64 * 100 + k as u64 + 1) % 97) + }); + CyclotomicRing::from_coefficients(coeffs) + }) + .collect() + }) + .collect(); + + let vecs: Vec>> = (0..3) + .map(|seed| { + (0..4) + .map(|j| { + let coeffs = std::array::from_fn(|k| { + F::from_u64((seed as u64 * 700 + j as u64 * 50 + k as u64 + 3) % 89) + }); + CyclotomicRing::from_coefficients(coeffs) + }) + .collect() + }) + .collect(); + + let expected: Vec>> = vecs + .iter() + .map(|v| mat_vec_mul_crt_ntt(&mat, v).expect("single CRT+NTT mat-vec should succeed")) + .collect(); + + let got = + mat_vec_mul_crt_ntt_many(&mat, &vecs).expect("batched CRT+NTT mat-vec should succeed"); + assert_eq!(expected, got); + } +} diff --git a/src/protocol/commitment/utils/math.rs b/src/protocol/commitment/utils/math.rs new file mode 100644 index 00000000..16cf8618 --- /dev/null +++ b/src/protocol/commitment/utils/math.rs @@ -0,0 +1,14 @@ +//! Small math helpers for commitment internals. + +use crate::error::HachiError; + +/// Compute `2^exp` with overflow checks. +/// +/// # Errors +/// +/// Returns `InvalidSetup` if `2^exp` does not fit in `usize`. +pub(in crate::protocol::commitment) fn checked_pow2(exp: usize) -> Result { + 1usize + .checked_shl(exp as u32) + .ok_or_else(|| HachiError::InvalidSetup(format!("2^{exp} does not fit usize"))) +} diff --git a/src/protocol/commitment/utils/matrix.rs b/src/protocol/commitment/utils/matrix.rs new file mode 100644 index 00000000..fa0b80ce --- /dev/null +++ b/src/protocol/commitment/utils/matrix.rs @@ -0,0 +1,131 @@ +//! Matrix sampling helpers for setup. + +use crate::algebra::ring::CyclotomicRing; +use crate::{FieldCore, FieldSampling}; +use rand_core::{CryptoRng, RngCore}; +use sha3::digest::{ExtendableOutput, Update, XofReader}; +use sha3::Shake256; + +/// Public seed used to derive commitment matrices. +pub(crate) type PublicMatrixSeed = [u8; 32]; + +const PUBLIC_MATRIX_DOMAIN: &[u8] = b"hachi/commitment/public-matrix"; + +/// Fixed public seed for deterministic, reproducible setup. +pub(crate) fn sample_public_matrix_seed() -> PublicMatrixSeed { + let mut seed = [0u8; 32]; + seed[..8].copy_from_slice(&0xDEAD_BEEF_CAFE_BABEu64.to_le_bytes()); + seed +} + +/// Derive a public matrix from a seed using domain-separated SHAKE expansion. +/// +/// This follows the same high-level pattern used in NIST lattice specs: +/// derive deterministic public structure from a seed + indices, then sample +/// coefficients via rejection-sampling at the field layer. +/// +/// NOTE: Potential future hardening: +/// move toward stricter ML-KEM/ML-DSA-style byte layout and parsing rules +/// (fixed-format seed/index encoding and scheme-specific expansion details) +/// if we decide to maximize standards-shape interoperability. +pub(crate) fn derive_public_matrix( + rows: usize, + cols: usize, + seed: &PublicMatrixSeed, + matrix_label: &[u8], +) -> Vec>> { + (0..rows) + .map(|r| { + (0..cols) + .map(|c| { + let mut entry_rng = ShakeXofRng::new(seed, matrix_label, rows, cols, r, c); + CyclotomicRing::random(&mut entry_rng) + }) + .collect() + }) + .collect() +} + +struct ShakeXofRng { + reader: Box, +} + +impl ShakeXofRng { + fn new( + seed: &PublicMatrixSeed, + matrix_label: &[u8], + rows: usize, + cols: usize, + row: usize, + col: usize, + ) -> Self { + let mut xof = Shake256::default(); + absorb_len_prefixed(&mut xof, b"domain", PUBLIC_MATRIX_DOMAIN); + absorb_len_prefixed(&mut xof, b"seed", seed); + absorb_len_prefixed(&mut xof, b"matrix", matrix_label); + absorb_len_prefixed(&mut xof, b"rows", &(rows as u64).to_le_bytes()); + absorb_len_prefixed(&mut xof, b"cols", &(cols as u64).to_le_bytes()); + absorb_len_prefixed(&mut xof, b"row", &(row as u64).to_le_bytes()); + absorb_len_prefixed(&mut xof, b"col", &(col as u64).to_le_bytes()); + Self { + reader: Box::new(xof.finalize_xof()), + } + } +} + +impl RngCore for ShakeXofRng { + fn next_u32(&mut self) -> u32 { + let mut buf = [0u8; 4]; + self.fill_bytes(&mut buf); + u32::from_le_bytes(buf) + } + + fn next_u64(&mut self) -> u64 { + let mut buf = [0u8; 8]; + self.fill_bytes(&mut buf); + u64::from_le_bytes(buf) + } + + fn fill_bytes(&mut self, dest: &mut [u8]) { + self.reader.read(dest); + } + + fn try_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), rand_core::Error> { + self.fill_bytes(dest); + Ok(()) + } +} + +impl CryptoRng for ShakeXofRng {} + +fn absorb_len_prefixed(xof: &mut Shake256, label: &[u8], data: &[u8]) { + xof.update(&(label.len() as u64).to_le_bytes()); + xof.update(label); + xof.update(&(data.len() as u64).to_le_bytes()); + xof.update(data); +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::algebra::Fp64; + + type F = Fp64<4294967197>; + const D: usize = 64; + + #[test] + fn matrix_derivation_is_deterministic_for_same_seed() { + let seed = [42u8; 32]; + let m1 = derive_public_matrix::(3, 5, &seed, b"A"); + let m2 = derive_public_matrix::(3, 5, &seed, b"A"); + assert_eq!(m1, m2); + } + + #[test] + fn matrix_derivation_domain_separates_labels() { + let seed = [7u8; 32]; + let a = derive_public_matrix::(2, 3, &seed, b"A"); + let b = derive_public_matrix::(2, 3, &seed, b"B"); + assert_ne!(a, b); + } +} diff --git a/src/protocol/commitment/utils/mod.rs b/src/protocol/commitment/utils/mod.rs new file mode 100644 index 00000000..f7035759 --- /dev/null +++ b/src/protocol/commitment/utils/mod.rs @@ -0,0 +1,7 @@ +//! Utility helpers for commitment internals. + +pub(crate) mod crt_ntt; +pub mod linear; +pub(crate) mod math; +pub(crate) mod matrix; +pub(crate) mod norm; diff --git a/src/protocol/commitment/utils/norm.rs b/src/protocol/commitment/utils/norm.rs new file mode 100644 index 00000000..cc01f9a7 --- /dev/null +++ b/src/protocol/commitment/utils/norm.rs @@ -0,0 +1,48 @@ +//! Infinity norm utilities for ring elements over Z_q. + +use crate::algebra::ring::CyclotomicRing; +use crate::CanonicalField; + +/// Detect the field modulus from the canonical representation. +/// +/// Uses the identity: the canonical form of `−1` in `Z_q` is `q − 1`. +pub(crate) fn detect_field_modulus() -> u128 { + (-F::one()).to_canonical_u128() + 1 +} + +/// Centered absolute value of a field element. +/// +/// Maps canonical representation `v ∈ [0, q)` to `min(v, q − v)`. +#[inline] +pub(crate) fn centered_abs(x: F, modulus: u128) -> u128 { + let v = x.to_canonical_u128(); + let half = modulus / 2; + if v <= half { + v + } else { + modulus - v + } +} + +/// L∞ norm of a single ring element (maximum centered coefficient magnitude). +pub(crate) fn ring_inf_norm( + r: &CyclotomicRing, + modulus: u128, +) -> u128 { + r.coefficients() + .iter() + .map(|c| centered_abs(*c, modulus)) + .max() + .unwrap_or(0) +} + +/// L∞ norm of a vector of ring elements. +pub(crate) fn vec_inf_norm( + v: &[CyclotomicRing], + modulus: u128, +) -> u128 { + v.iter() + .map(|r| ring_inf_norm(r, modulus)) + .max() + .unwrap_or(0) +} diff --git a/src/protocol/commitment_scheme.rs b/src/protocol/commitment_scheme.rs new file mode 100644 index 00000000..d332265d --- /dev/null +++ b/src/protocol/commitment_scheme.rs @@ -0,0 +1,824 @@ +//! Commitment scheme trait implementation. + +use crate::algebra::CyclotomicRing; +use crate::cfg_into_iter; +use crate::error::HachiError; +#[cfg(feature = "parallel")] +use crate::parallel::*; +use crate::primitives::poly::multilinear_lagrange_basis; +use crate::protocol::commitment::onehot::{inner_ajtai_onehot, SparseBlockEntry}; +use crate::protocol::commitment::utils::linear::{ + decompose_block, decompose_rows, mat_vec_mul_ntt_cached, MatrixSlot, +}; +use crate::protocol::commitment::{ + AppendToTranscript, CommitmentConfig, CommitmentScheme, HachiCommitmentCore, HachiProverSetup, + HachiVerifierSetup, RingCommitment, RingCommitmentScheme, StreamingCommitmentScheme, +}; +use crate::protocol::opening_point::RingOpeningPoint; +use crate::protocol::proof::{HachiCommitmentHint, HachiProof, SumcheckAux}; +use crate::protocol::quadratic_equation::QuadraticEquation; +use crate::protocol::ring_switch::{build_w_evals, ring_switch_prover, ring_switch_verifier}; +use crate::protocol::sumcheck::hachi_sumcheck::{HachiSumcheckProver, HachiSumcheckVerifier}; +use crate::protocol::sumcheck::{prove_sumcheck, verify_sumcheck}; +use crate::protocol::transcript::labels::{ + ABSORB_COMMITMENT, ABSORB_EVALUATION_CLAIMS, CHALLENGE_SUMCHECK_BATCH, CHALLENGE_SUMCHECK_ROUND, +}; +use crate::protocol::transcript::Transcript; +use crate::{CanonicalField, FieldCore, FieldSampling, FromSmallInt, Polynomial}; + +#[cfg(test)] +use crate::protocol::quadratic_equation::compute_m_a_streaming; +#[cfg(test)] +use crate::protocol::ring_switch::expand_m_a; +#[cfg(test)] +use crate::protocol::transcript::labels::{ + ABSORB_SUMCHECK_W, CHALLENGE_RING_SWITCH, DOMAIN_HACHI_PROTOCOL, +}; +#[cfg(test)] +use crate::protocol::transcript::Blake2bTranscript; +#[cfg(test)] +use crate::protocol::SmallTestCommitmentConfig; + +/// End-to-end PCS wrapper, generic over ring degree `D` and config `Cfg`. +#[derive(Clone, Copy, Debug, Default)] +pub struct HachiCommitmentScheme { + _cfg: std::marker::PhantomData, +} + +impl CommitmentScheme for HachiCommitmentScheme +where + F: FieldCore + CanonicalField + FieldSampling, + Cfg: CommitmentConfig, +{ + type ProverSetup = HachiProverSetup; + type VerifierSetup = HachiVerifierSetup; + type Commitment = RingCommitment; + type Proof = HachiProof; + type OpeningProofHint = HachiCommitmentHint; + + fn setup_prover(max_num_vars: usize) -> Self::ProverSetup { + let (setup, _) = + >::setup(max_num_vars) + .expect("commitment setup failed"); + setup + } + + fn setup_verifier(setup: &Self::ProverSetup) -> Self::VerifierSetup { + HachiVerifierSetup { + expanded: setup.expanded.clone(), + } + } + + fn commit>( + poly: &P, + setup: &Self::ProverSetup, + ) -> Result<(Self::Commitment, Self::OpeningProofHint), HachiError> { + let ring_coeffs = + reduce_coeffs_to_ring_elements::(poly.num_vars(), &poly.coeffs())?; + let w = >::commit_coeffs( + &ring_coeffs, + setup, + )?; + let hint = HachiCommitmentHint { + t_hat: w.t_hat, + ring_coeffs, + }; + Ok((w.commitment, hint)) + } + + fn prove, P: Polynomial>( + setup: &Self::ProverSetup, + poly: &P, + opening_point: &[F], + hint: Option, + transcript: &mut T, + commitment: &Self::Commitment, + ) -> Result { + let hint = hint.ok_or_else(|| { + HachiError::InvalidInput("missing commitment hint for proving".to_string()) + })?; + let _num_vars = poly.num_vars(); + let alpha = Cfg::D.trailing_zeros() as usize; + if opening_point.len() < alpha { + return Err(HachiError::InvalidPointDimension { + expected: alpha, + actual: opening_point.len(), + }); + } + + let layout = >::layout(setup)?; + let target_num_vars = layout.m_vars + layout.r_vars + alpha; + if opening_point.len() > target_num_vars { + return Err(HachiError::InvalidPointDimension { + expected: target_num_vars, + actual: opening_point.len(), + }); + } + let mut padded_point = opening_point.to_vec(); + padded_point.resize(target_num_vars, F::zero()); + let outer_point = &padded_point[alpha..]; + + let ring_opening_point = + ring_opening_point_from_field::(outer_point, layout.r_vars, layout.m_vars)?; + + let y_ring = evaluate_packed_ring_poly::(&hint.ring_coeffs, outer_point); + + // Fiat-Shamir: bind commitment, opening point, and y_ring before any challenges. + commitment.append_to_transcript(ABSORB_COMMITMENT, transcript); + for pt in &padded_point { + transcript.append_field(ABSORB_EVALUATION_CLAIMS, pt); + } + transcript.append_serde(ABSORB_EVALUATION_CLAIMS, &y_ring); + + // §4.2 Quadratic equation + let mut quad_eq = QuadraticEquation::::new_prover( + setup, + &ring_opening_point, + &hint, + transcript, + commitment, + &y_ring, + )?; + + // §4.3 Ring switch + let rs = ring_switch_prover::(&mut quad_eq, &setup.expanded, transcript)?; + + // Sample batching coefficient for fused sumcheck + let batching_coeff: F = transcript.challenge_scalar(CHALLENGE_SUMCHECK_BATCH); + + // Fused sumcheck: norm + relation with shared w_table + let mut fused_prover = HachiSumcheckProver::new( + batching_coeff, + rs.w_evals, + &rs.tau0, + rs.b, + &rs.alpha_evals_y, + &rs.m_evals_x, + rs.num_u, + rs.num_l, + ); + + let (sumcheck_proof, ..) = + prove_sumcheck::(&mut fused_prover, transcript, |tr| { + tr.challenge_scalar(CHALLENGE_SUMCHECK_ROUND) + })?; + + Ok(HachiProof { + v: quad_eq.v, + y_ring, + sumcheck_proof, + sumcheck_aux: SumcheckAux { w: rs.w }, + w_commitment: rs.w_commitment, + }) + } + + fn verify>( + proof: &Self::Proof, + setup: &Self::VerifierSetup, + transcript: &mut T, + opening_point: &[F], + opening: &F, + commitment: &Self::Commitment, + ) -> Result<(), HachiError> { + let alpha_bits = Cfg::D.trailing_zeros() as usize; + if opening_point.len() < alpha_bits { + return Err(HachiError::InvalidSetup( + "opening point length underflow".to_string(), + )); + } + let layout = setup.expanded.seed.layout; + let target_num_vars = layout.m_vars + layout.r_vars + alpha_bits; + let mut padded_point = opening_point.to_vec(); + padded_point.resize(target_num_vars, F::zero()); + let inner_point = &padded_point[..alpha_bits]; + let reduced_opening_point = &padded_point[alpha_bits..]; + + // Fiat-Shamir: bind commitment, opening point, and y_ring before any challenges. + commitment.append_to_transcript(ABSORB_COMMITMENT, transcript); + for pt in &padded_point { + transcript.append_field(ABSORB_EVALUATION_CLAIMS, pt); + } + transcript.append_serde(ABSORB_EVALUATION_CLAIMS, &proof.y_ring); + + // §3.1 trace check + let v = reduce_inner_openings_to_ring_elements::(inner_point)?; + let d = F::from_u64(Cfg::D as u64); + let trace_lhs = trace::(&(proof.y_ring * v.sigma_m1())); + let trace_rhs = d * *opening; + if trace_lhs != trace_rhs { + return Err(HachiError::InvalidProof); + } + + // §4.2 Quadratic equation + let ring_opening_point = ring_opening_point_from_field::( + reduced_opening_point, + layout.r_vars, + layout.m_vars, + )?; + let quad_eq = QuadraticEquation::::new_verifier( + setup, + &ring_opening_point, + &proof.v, + transcript, + commitment, + &proof.y_ring, + )?; + + // §4.3 Ring switch (verifier side) + let rs = ring_switch_verifier::( + &quad_eq, + &setup.expanded, + &proof.sumcheck_aux.w, + &proof.w_commitment, + transcript, + )?; + + // Sample batching coefficient for fused sumcheck (must match prover) + let batching_coeff: F = transcript.challenge_scalar(CHALLENGE_SUMCHECK_BATCH); + + // Build full w_evals for verifier from the witness vector w. + let (w_evals_full, _, _) = build_w_evals(&proof.sumcheck_aux.w, Cfg::D)?; + + // Fused sumcheck verification: norm (F_0) + relation (F_α) + let fused_verifier = HachiSumcheckVerifier::new( + batching_coeff, + w_evals_full, + rs.tau0, + rs.b, + rs.alpha_evals_y, + rs.m_evals_x, + rs.tau1, + proof.v.clone(), + commitment.u.clone(), + proof.y_ring, + rs.alpha, + rs.num_u, + rs.num_l, + ); + + verify_sumcheck::( + &proof.sumcheck_proof, + &fused_verifier, + transcript, + |tr| tr.challenge_scalar(CHALLENGE_SUMCHECK_ROUND), + )?; + + Ok(()) + } + + fn combine_commitments(_commitments: &[Self::Commitment], _coeffs: &[F]) -> Self::Commitment { + unimplemented!() + } + + fn combine_hints(_hints: Vec, _coeffs: &[F]) -> Self::OpeningProofHint { + unimplemented!() + } + + fn protocol_name() -> &'static [u8] { + unimplemented!() + } +} + +/// Commit to a one-hot polynomial, returning both the commitment and a +/// complete `HachiCommitmentHint` (including `ring_coeffs` needed by `prove`). +/// +/// # Errors +/// +/// Returns an error if dimensions are inconsistent, any index is out of +/// range, or the underlying commitment routine fails. +pub fn commit_onehot( + onehot_k: usize, + indices: &[Option], + setup: &HachiProverSetup, +) -> Result<(RingCommitment, HachiCommitmentHint), HachiError> +where + F: FieldCore + CanonicalField + FieldSampling, + Cfg: CommitmentConfig, +{ + let num_chunks = indices.len(); + let total_field_elems = num_chunks + .checked_mul(onehot_k) + .ok_or_else(|| HachiError::InvalidInput("T*K overflow".into()))?; + if total_field_elems % D != 0 { + return Err(HachiError::InvalidInput(format!( + "T*K={total_field_elems} is not divisible by D={D}" + ))); + } + + // Build ring_coeffs (needed for prove) from the sparse one-hot indices. + let total_ring_elems = total_field_elems / D; + let mut ring_coeffs = vec![CyclotomicRing::::zero(); total_ring_elems]; + for (c, opt) in indices.iter().enumerate() { + let Some(&idx) = opt.as_ref() else { continue }; + let field_pos = c * onehot_k + idx; + let ring_idx = field_pos / D; + let coeff_idx = field_pos % D; + ring_coeffs[ring_idx].coeffs[coeff_idx] = F::one(); + } + + let w = >::commit_onehot( + onehot_k, indices, setup, + )?; + + let hint = HachiCommitmentHint { + t_hat: w.t_hat, + ring_coeffs, + }; + Ok((w.commitment, hint)) +} + +/// Per-block intermediate state for streaming Hachi commitment. +/// +/// Each chunk corresponds to one Ajtai inner block: `D * block_len` field +/// elements packed into `block_len` ring elements, decomposed, and multiplied +/// by the inner matrix A. +#[derive(Clone, PartialEq, Eq)] +pub struct HachiChunkState { + /// Original ring elements for this block (needed for `ring_coeffs` hint). + pub block: Vec>, + /// Basis-decomposed input vector `s_i = G^{-1}(block)`. + pub s_i: Vec>, + /// Basis-decomposed inner Ajtai output `t̂_i = G^{-1}(A · s_i)`. + pub t_hat_i: Vec>, +} + +impl std::fmt::Debug for HachiChunkState { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("HachiChunkState") + .field("block_len", &self.block.len()) + .field("s_i_len", &self.s_i.len()) + .field("t_hat_i_len", &self.t_hat_i.len()) + .finish() + } +} + +impl StreamingCommitmentScheme for HachiCommitmentScheme +where + F: FieldCore + CanonicalField + FieldSampling, + Cfg: CommitmentConfig, +{ + type ChunkState = HachiChunkState; + + fn process_chunk(setup: &Self::ProverSetup, chunk: &[F]) -> Self::ChunkState { + assert!( + chunk.len() % D == 0, + "chunk length {} is not divisible by D={}", + chunk.len(), + D + ); + + let block: Vec> = chunk + .chunks_exact(D) + .map(|c| CyclotomicRing::from_coefficients(std::array::from_fn(|j| c[j]))) + .collect(); + + let s_i = decompose_block(&block, Cfg::DELTA, Cfg::LOG_BASIS); + let t_i = + mat_vec_mul_ntt_cached(setup.ntt_cache().expect("NTT cache"), MatrixSlot::A, &s_i) + .expect("inner Ajtai"); + let t_hat_i = decompose_rows(&t_i, Cfg::DELTA, Cfg::LOG_BASIS); + + HachiChunkState { + block, + s_i, + t_hat_i, + } + } + + fn process_chunk_onehot( + setup: &Self::ProverSetup, + onehot_k: usize, + chunk: &[Option], + ) -> Self::ChunkState { + let layout = >::layout(setup) + .expect("layout"); + let block_len = layout.block_len; + + let num_field_elems = chunk.len() * onehot_k; + assert!( + num_field_elems % D == 0, + "chunk cycles * K = {num_field_elems} is not divisible by D={D}", + ); + + // Build sparse entries and original block ring elements. + let num_ring_elems = num_field_elems / D; + let mut ring_block = vec![CyclotomicRing::::zero(); num_ring_elems]; + let mut ring_elem_map: std::collections::BTreeMap> = + std::collections::BTreeMap::new(); + for (c, opt) in chunk.iter().enumerate() { + if let Some(k) = opt { + let field_pos = c * onehot_k + k; + let ring_elem_idx = field_pos / D; + let coeff_idx = field_pos % D; + ring_block[ring_elem_idx].coeffs[coeff_idx] = F::one(); + ring_elem_map + .entry(ring_elem_idx) + .or_default() + .push(coeff_idx); + } + } + + let sparse_entries: Vec = ring_elem_map + .into_iter() + .map(|(ring_elem_idx, nonzero_coeffs)| SparseBlockEntry { + pos_in_block: ring_elem_idx, + nonzero_coeffs, + }) + .collect(); + + let (t_i, s_i) = + inner_ajtai_onehot(&setup.expanded.A, &sparse_entries, block_len, Cfg::DELTA); + let t_hat_i = decompose_rows(&t_i, Cfg::DELTA, Cfg::LOG_BASIS); + + HachiChunkState { + block: ring_block, + s_i, + t_hat_i, + } + } + + fn aggregate_chunks( + setup: &Self::ProverSetup, + _onehot_k: Option, + chunks: &[Self::ChunkState], + ) -> (Self::Commitment, Self::OpeningProofHint) { + let t_hat_flat: Vec> = chunks + .iter() + .flat_map(|c| c.t_hat_i.iter().copied()) + .collect(); + + let u = mat_vec_mul_ntt_cached( + setup.ntt_cache().expect("NTT cache"), + MatrixSlot::B, + &t_hat_flat, + ) + .expect("outer Ajtai"); + + let t_hat_all: Vec>> = + chunks.iter().map(|c| c.t_hat_i.clone()).collect(); + let ring_coeffs: Vec> = chunks + .iter() + .flat_map(|c| c.block.iter().copied()) + .collect(); + + let commitment = RingCommitment { u }; + let hint = HachiCommitmentHint { + t_hat: t_hat_all, + ring_coeffs, + }; + (commitment, hint) + } +} + +/// Re-derive the ring-switch challenge `alpha` and the expanded `M_a` vector +/// by replaying the transcript from the proof data and setup, exactly as the +/// verifier does. +#[cfg(test)] +pub(crate) fn rederive_alpha_and_m_a( + proof: &HachiProof, + setup: &HachiVerifierSetup, + opening_point: &[F], + commitment: &RingCommitment, +) -> Result<(F, Vec), HachiError> +where + F: FieldCore + CanonicalField + FieldSampling + 'static, + Cfg: CommitmentConfig, +{ + let alpha_bits = Cfg::D.trailing_zeros() as usize; + if opening_point.len() < alpha_bits { + return Err(HachiError::InvalidSetup( + "opening point length underflow".to_string(), + )); + } + let layout = setup.expanded.seed.layout; + let ring_opening_point = ring_opening_point_from_field::( + &opening_point[alpha_bits..], + layout.r_vars, + layout.m_vars, + )?; + let mut transcript = Blake2bTranscript::::new(DOMAIN_HACHI_PROTOCOL); + + // Replay the same Fiat-Shamir absorptions the real verifier performs. + commitment.append_to_transcript(ABSORB_COMMITMENT, &mut transcript); + for pt in opening_point { + transcript.append_field(ABSORB_EVALUATION_CLAIMS, pt); + } + transcript.append_serde(ABSORB_EVALUATION_CLAIMS, &proof.y_ring); + + let quad_eq = QuadraticEquation::::new_verifier( + setup, + &ring_opening_point, + &proof.v, + &mut transcript, + commitment, + &proof.y_ring, + )?; + transcript.append_serde(ABSORB_SUMCHECK_W, &proof.w_commitment); + let alpha: F = transcript.challenge_scalar(CHALLENGE_RING_SWITCH); + let m_a = compute_m_a_streaming::( + &setup.expanded, + quad_eq.opening_point(), + &quad_eq.challenges, + &alpha, + )?; + let m_a_vec = expand_m_a::(&m_a, alpha)?; + Ok((alpha, m_a_vec)) +} + +fn lagrange_weights(point: &[F]) -> Vec { + let len = 1usize << point.len(); + let mut weights = vec![F::zero(); len]; + multilinear_lagrange_basis(&mut weights, point); + weights +} + +fn ring_opening_point_from_field( + opening_point: &[F], + r_vars: usize, + m_vars: usize, +) -> Result, HachiError> { + let expected_len = r_vars + .checked_add(m_vars) + .ok_or_else(|| HachiError::InvalidSetup("opening point length overflow".to_string()))?; + if opening_point.len() != expected_len { + return Err(HachiError::InvalidPointDimension { + expected: expected_len, + actual: opening_point.len(), + }); + } + + // Sequential ordering: M variables (position in block) come first, + // R variables (block selection) come second. + let a = lagrange_weights(&opening_point[..m_vars]); + let b = lagrange_weights(&opening_point[m_vars..]); + Ok(RingOpeningPoint { a, b }) +} + +fn reduce_coeffs_to_ring_elements( + num_vars: usize, + coeffs: &[F], +) -> Result>, HachiError> { + if D == 0 || !D.is_power_of_two() { + return Err(HachiError::InvalidInput(format!( + "ring degree D={D} is not a power of two" + ))); + } + let alpha = D.trailing_zeros() as usize; + if num_vars < alpha { + return Err(HachiError::InvalidInput(format!( + "num_vars {num_vars} is smaller than alpha {alpha}" + ))); + } + + let expected_len = 1usize + .checked_shl(num_vars as u32) + .ok_or_else(|| HachiError::InvalidInput(format!("2^{num_vars} does not fit usize")))?; + if coeffs.len() != expected_len { + return Err(HachiError::InvalidSize { + expected: expected_len, + actual: coeffs.len(), + }); + } + + // Sequential packing: ring element i = coeffs[i*D .. (i+1)*D]. + // The first alpha variables (LSBs) become coefficient slots within each + // ring element; the remaining outer_vars variables index ring elements. + let outer_len = expected_len / D; + let out: Vec> = cfg_into_iter!(0..outer_len) + .map(|i| { + let ring_coeffs = std::array::from_fn(|j| coeffs[i * D + j]); + CyclotomicRing::from_coefficients(ring_coeffs) + }) + .collect(); + Ok(out) +} + +fn reduce_inner_openings_to_ring_elements( + inner_point: &[F], +) -> Result, HachiError> { + let weights = lagrange_weights(inner_point); + if weights.len() != D { + return Err(HachiError::InvalidInput(format!( + "inner basis length {} does not match D={D}", + weights.len() + ))); + } + let coeffs = std::array::from_fn(|i| weights[i]); + Ok(CyclotomicRing::from_coefficients(coeffs)) +} + +fn evaluate_packed_ring_poly( + packed_coeffs: &[CyclotomicRing], + outer_point: &[F], +) -> CyclotomicRing { + let weights = lagrange_weights(outer_point); + debug_assert!(weights.len() >= packed_coeffs.len()); + #[cfg(feature = "parallel")] + { + packed_coeffs + .par_iter() + .zip(weights.par_iter()) + .fold( + || CyclotomicRing::::zero(), + |acc, (f_i, w_i)| acc + f_i.scale(w_i), + ) + .reduce(|| CyclotomicRing::::zero(), |a, b| a + b) + } + #[cfg(not(feature = "parallel"))] + { + packed_coeffs + .iter() + .zip(weights.iter()) + .fold(CyclotomicRing::::zero(), |acc, (f_i, w_i)| { + acc + f_i.scale(w_i) + }) + } +} + +fn trace(u: &CyclotomicRing) -> F { + let d = F::from_u64(D as u64); + u.coefficients()[0] * d +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::primitives::multilinear_evals::DenseMultilinearEvals; + use crate::protocol::commitment::CommitmentConfig; + use crate::protocol::transcript::Blake2bTranscript; + use crate::test_utils::F; + use crate::{CommitmentScheme, FromSmallInt, Polynomial}; + + type Cfg = SmallTestCommitmentConfig; + type Scheme = HachiCommitmentScheme<{ Cfg::D }, Cfg>; + + #[test] + fn verify_passes_for_consistent_opening() { + let alpha = Cfg::D.trailing_zeros() as usize; + let layout = Cfg::commitment_layout(16).unwrap(); + let num_vars = layout.m_vars + layout.r_vars + alpha; + let len = 1usize << num_vars; + + let evals: Vec = (0..len).map(|i| F::from_u64(i as u64)).collect(); + let poly = DenseMultilinearEvals::new_padded(evals); + + let setup = >::setup_prover(num_vars); + let verifier_setup = >::setup_verifier(&setup); + + let (commitment, hint) = >::commit(&poly, &setup).unwrap(); + + let opening_point: Vec = (0..num_vars).map(|i| F::from_u64((i + 2) as u64)).collect(); + let opening = poly.evaluate(&opening_point); + + let mut prover_transcript = Blake2bTranscript::::new(b"test/prove"); + let proof = >::prove( + &setup, + &poly, + &opening_point, + Some(hint), + &mut prover_transcript, + &commitment, + ) + .unwrap(); + + let mut verifier_transcript = Blake2bTranscript::::new(b"test/prove"); + let result = >::verify( + &proof, + &verifier_setup, + &mut verifier_transcript, + &opening_point, + &opening, + &commitment, + ); + + assert!(result.is_ok()); + } + + #[test] + fn verify_rejects_wrong_opening() { + let alpha = Cfg::D.trailing_zeros() as usize; + let layout = Cfg::commitment_layout(16).unwrap(); + let num_vars = layout.m_vars + layout.r_vars + alpha; + let len = 1usize << num_vars; + + let evals: Vec = (0..len).map(|i| F::from_u64(i as u64)).collect(); + let poly = DenseMultilinearEvals::new_padded(evals); + + let setup = >::setup_prover(num_vars); + let verifier_setup = >::setup_verifier(&setup); + + let (commitment, hint) = >::commit(&poly, &setup).unwrap(); + + let opening_point: Vec = (0..num_vars).map(|i| F::from_u64((i + 2) as u64)).collect(); + let opening = poly.evaluate(&opening_point); + + let mut prover_transcript = Blake2bTranscript::::new(b"test/prove"); + let proof = >::prove( + &setup, + &poly, + &opening_point, + Some(hint), + &mut prover_transcript, + &commitment, + ) + .unwrap(); + + let wrong_opening = opening + F::one(); + let mut verifier_transcript = Blake2bTranscript::::new(b"test/prove"); + let result = >::verify( + &proof, + &verifier_setup, + &mut verifier_transcript, + &opening_point, + &wrong_opening, + &commitment, + ); + + assert!( + result.is_err(), + "verify must reject an incorrect opening value" + ); + } + + #[test] + fn streaming_commit_matches_non_streaming() { + let alpha = Cfg::D.trailing_zeros() as usize; + let layout = Cfg::commitment_layout(16).unwrap(); + let num_vars = layout.m_vars + layout.r_vars + alpha; + let len = 1usize << num_vars; + + let evals: Vec = (0..len).map(|i| F::from_u64(i as u64)).collect(); + let poly = DenseMultilinearEvals::new_padded(evals.clone()); + + let setup = >::setup_prover(num_vars); + + // Non-streaming commit + let (non_streaming_commitment, non_streaming_hint) = + >::commit(&poly, &setup).unwrap(); + + // Streaming commit: split field elements into chunks of D * block_len + let chunk_size = Cfg::D * layout.block_len; + let chunks: Vec> = evals + .chunks_exact(chunk_size) + .map(|chunk| >::process_chunk(&setup, chunk)) + .collect(); + + let (streaming_commitment, streaming_hint) = + >::aggregate_chunks(&setup, None, &chunks); + + assert_eq!(non_streaming_commitment, streaming_commitment); + assert_eq!(non_streaming_hint, streaming_hint); + } + + #[test] + fn streaming_commit_then_prove_verify() { + let alpha = Cfg::D.trailing_zeros() as usize; + let layout = Cfg::commitment_layout(16).unwrap(); + let num_vars = layout.m_vars + layout.r_vars + alpha; + let len = 1usize << num_vars; + + let evals: Vec = (0..len).map(|i| F::from_u64(i as u64)).collect(); + let poly = DenseMultilinearEvals::new_padded(evals.clone()); + + let setup = >::setup_prover(num_vars); + let verifier_setup = >::setup_verifier(&setup); + + // Streaming commit + let chunk_size = Cfg::D * layout.block_len; + let chunks: Vec> = evals + .chunks_exact(chunk_size) + .map(|chunk| >::process_chunk(&setup, chunk)) + .collect(); + let (commitment, hint) = + >::aggregate_chunks(&setup, None, &chunks); + + // Prove and verify with streaming-produced commitment + hint + let opening_point: Vec = (0..num_vars).map(|i| F::from_u64((i + 2) as u64)).collect(); + let opening = poly.evaluate(&opening_point); + + let mut prover_transcript = Blake2bTranscript::::new(b"test/stream"); + let proof = >::prove( + &setup, + &poly, + &opening_point, + Some(hint), + &mut prover_transcript, + &commitment, + ) + .unwrap(); + + let mut verifier_transcript = Blake2bTranscript::::new(b"test/stream"); + let result = >::verify( + &proof, + &verifier_setup, + &mut verifier_transcript, + &opening_point, + &opening, + &commitment, + ); + assert!( + result.is_ok(), + "streaming commit should produce valid proofs" + ); + } +} diff --git a/src/protocol/mod.rs b/src/protocol/mod.rs new file mode 100644 index 00000000..27b5ac70 --- /dev/null +++ b/src/protocol/mod.rs @@ -0,0 +1,33 @@ +//! Protocol-layer transcript and commitment abstractions. +//! +//! This module defines the Hachi-native protocol interfaces used by higher-level +//! proof logic. It intentionally stays independent from external integration +//! details (for example, Jolt wiring). + +pub mod challenges; +pub mod commitment; +pub mod commitment_scheme; +pub mod opening_point; +pub mod proof; +pub mod quadratic_equation; +pub mod ring_switch; +pub mod sumcheck; +pub mod transcript; + +pub use commitment::{ + AppendToTranscript, CommitmentConfig, CommitmentScheme, DummyProof, + DynamicSmallTestCommitmentConfig, HachiCommitment, HachiCommitmentCore, HachiCommitmentLayout, + HachiExpandedSetup, HachiOpeningClaim, HachiOpeningPoint, HachiPreparedSetup, HachiProverSetup, + HachiSetupSeed, HachiVerifierSetup, ProductionFp128CommitmentConfig, RingCommitment, + RingCommitmentScheme, SmallTestCommitmentConfig, StreamingCommitmentScheme, +}; +pub use commitment_scheme::{commit_onehot, HachiChunkState, HachiCommitmentScheme}; +pub use opening_point::RingOpeningPoint; +pub use proof::{HachiProof, SumcheckAux}; +pub use quadratic_equation::QuadraticEquation; +pub use sumcheck::batched_sumcheck::{prove_batched_sumcheck, verify_batched_sumcheck}; +pub use sumcheck::{ + prove_sumcheck, verify_sumcheck, CompressedUniPoly, SumcheckInstanceProver, + SumcheckInstanceVerifier, SumcheckProof, UniPoly, +}; +pub use transcript::{sample_ext_challenge, Blake2bTranscript, KeccakTranscript, Transcript}; diff --git a/src/protocol/opening_point.rs b/src/protocol/opening_point.rs new file mode 100644 index 00000000..154402a9 --- /dev/null +++ b/src/protocol/opening_point.rs @@ -0,0 +1,20 @@ +//! Ring-native opening point for the Hachi protocol. + +use crate::FieldCore; + +/// Ring-native opening point storing field scalars (Lagrange weights). +/// +/// Contains the two vectors used by the §4.2 prover: +/// - `a`: evaluation vector of length `2^m` (inner-block coordinates). +/// - `b`: block-select vector of length `2^r` (outer coordinates). +/// +/// These are raw field scalars, not ring elements — they originate from +/// multilinear Lagrange basis evaluations and are always constant (scalar) +/// ring elements when embedded into the ring. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct RingOpeningPoint { + /// Evaluation vector of length `2^m` (field scalars). + pub a: Vec, + /// Block-select vector of length `2^r` (field scalars). + pub b: Vec, +} diff --git a/src/protocol/proof.rs b/src/protocol/proof.rs new file mode 100644 index 00000000..c020f7b7 --- /dev/null +++ b/src/protocol/proof.rs @@ -0,0 +1,170 @@ +//! Proof structures for the Hachi protocol. + +use crate::algebra::CyclotomicRing; +use crate::primitives::serialization::{Compress, SerializationError}; +use crate::primitives::serialization::{Valid, Validate}; +use crate::protocol::commitment::RingCommitment; +use crate::protocol::sumcheck::SumcheckProof; +use crate::{FieldCore, HachiDeserialize, HachiSerialize}; +use std::io::{Read, Write}; + +/// Prover-side hint produced at commitment time. +/// +/// Stores the ring-level coefficients and the decomposed inner-Ajtai outputs +/// `t̂_i`. The basis-decomposed inputs `s_i` are NOT stored; they are +/// recomputed from `ring_coeffs` during proving to avoid multi-GB memory +/// usage at production parameters. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct HachiCommitmentHint { + /// Decomposed `t̂_i` blocks from the commitment phase. + pub t_hat: Vec>>, + /// Ring coefficients from the §3.1 reduction (evaluation table). + pub ring_coeffs: Vec>, +} + +/// Temporary auxiliary data the verifier needs for sumcheck output verification. +/// +/// Will be removed once recursive PCS evaluation proofs replace the direct +/// oracle check at the end of each sumcheck instance. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct SumcheckAux { + /// `w` coefficients (z and r coefficients, concatenated). The verifier + /// reshapes this into sumcheck evaluation form to compute the expected + /// output claims for F_0 and F_alpha. + pub w: Vec, +} + +/// Hachi Proof for One Iteration. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct HachiProof { + /// `y_ring` from the §3.1 reduction. + pub y_ring: CyclotomicRing, + /// `v = D · ŵ`. + pub v: Vec>, + /// Batched sumcheck proof (F_0 norm + F_α relation, §4.3). + pub sumcheck_proof: SumcheckProof, + /// Temporary verifier auxiliary (will be removed with recursive PCS). + pub sumcheck_aux: SumcheckAux, + /// Commitment to the sumcheck witness `w`. + pub w_commitment: RingCommitment, +} + +impl HachiProof { + /// Returns the proof size in bytes (uncompressed). + pub fn size(&self) -> usize { + self.v.serialized_size(Compress::No) + + self.y_ring.serialized_size(Compress::No) + + self.sumcheck_aux.w.serialized_size(Compress::No) + + self.sumcheck_proof.serialized_size(Compress::No) + + self.w_commitment.serialized_size(Compress::No) + } +} + +impl HachiSerialize for SumcheckAux { + fn serialize_with_mode( + &self, + writer: W, + compress: Compress, + ) -> Result<(), SerializationError> { + self.w.serialize_with_mode(writer, compress) + } + fn serialized_size(&self, compress: Compress) -> usize { + self.w.serialized_size(compress) + } +} + +impl Valid for SumcheckAux { + fn check(&self) -> Result<(), SerializationError> { + Ok(()) + } +} + +impl HachiDeserialize for SumcheckAux { + fn deserialize_with_mode( + reader: R, + compress: Compress, + validate: Validate, + ) -> Result { + Ok(Self { + w: Vec::::deserialize_with_mode(reader, compress, validate)?, + }) + } +} + +impl HachiSerialize for HachiCommitmentHint { + fn serialize_with_mode( + &self, + mut writer: W, + compress: Compress, + ) -> Result<(), SerializationError> { + self.t_hat.serialize_with_mode(&mut writer, compress)?; + self.ring_coeffs.serialize_with_mode(&mut writer, compress) + } + fn serialized_size(&self, compress: Compress) -> usize { + self.t_hat.serialized_size(compress) + self.ring_coeffs.serialized_size(compress) + } +} + +impl Valid for HachiCommitmentHint { + fn check(&self) -> Result<(), SerializationError> { + Ok(()) + } +} + +impl HachiDeserialize for HachiCommitmentHint { + fn deserialize_with_mode( + mut reader: R, + compress: Compress, + validate: Validate, + ) -> Result { + Ok(Self { + t_hat: Vec::deserialize_with_mode(&mut reader, compress, validate)?, + ring_coeffs: Vec::deserialize_with_mode(&mut reader, compress, validate)?, + }) + } +} + +impl HachiSerialize for HachiProof { + fn serialize_with_mode( + &self, + mut writer: W, + compress: Compress, + ) -> Result<(), SerializationError> { + self.y_ring.serialize_with_mode(&mut writer, compress)?; + self.v.serialize_with_mode(&mut writer, compress)?; + self.sumcheck_proof + .serialize_with_mode(&mut writer, compress)?; + self.sumcheck_aux + .serialize_with_mode(&mut writer, compress)?; + self.w_commitment.serialize_with_mode(&mut writer, compress) + } + fn serialized_size(&self, compress: Compress) -> usize { + self.y_ring.serialized_size(compress) + + self.v.serialized_size(compress) + + self.sumcheck_proof.serialized_size(compress) + + self.sumcheck_aux.serialized_size(compress) + + self.w_commitment.serialized_size(compress) + } +} + +impl Valid for HachiProof { + fn check(&self) -> Result<(), SerializationError> { + Ok(()) + } +} + +impl HachiDeserialize for HachiProof { + fn deserialize_with_mode( + mut reader: R, + compress: Compress, + validate: Validate, + ) -> Result { + Ok(Self { + y_ring: CyclotomicRing::deserialize_with_mode(&mut reader, compress, validate)?, + v: Vec::deserialize_with_mode(&mut reader, compress, validate)?, + sumcheck_proof: SumcheckProof::deserialize_with_mode(&mut reader, compress, validate)?, + sumcheck_aux: SumcheckAux::deserialize_with_mode(&mut reader, compress, validate)?, + w_commitment: RingCommitment::deserialize_with_mode(&mut reader, compress, validate)?, + }) + } +} diff --git a/src/protocol/quadratic_equation.rs b/src/protocol/quadratic_equation.rs new file mode 100644 index 00000000..092a017e --- /dev/null +++ b/src/protocol/quadratic_equation.rs @@ -0,0 +1,881 @@ +//! Quadratic equation builder for the Hachi PCS (§4.2). +//! +//! This module encapsulates the stage-1 prover logic and the generation of +//! the quadratic equation components M, y, z, and v. + +use crate::algebra::{CyclotomicRing, SparseChallenge, SparseChallengeConfig}; +use crate::cfg_iter; +use crate::error::HachiError; +#[cfg(feature = "parallel")] +use crate::parallel::*; +use crate::protocol::challenges::sparse::sample_sparse_challenges; +use crate::protocol::commitment::utils::crt_ntt::NttMatrixCache; +use crate::protocol::commitment::utils::linear::{ + decompose_block, mat_vec_mul_ntt_cached, MatrixSlot, +}; +use crate::protocol::commitment::utils::norm::{detect_field_modulus, vec_inf_norm}; +use crate::protocol::commitment::{ + CommitmentConfig, HachiCommitmentLayout, HachiExpandedSetup, HachiProverSetup, + HachiVerifierSetup, RingCommitment, +}; +use crate::protocol::opening_point::RingOpeningPoint; +use crate::protocol::proof::HachiCommitmentHint; +use crate::protocol::ring_switch::eval_ring_at; +use crate::protocol::transcript::labels::{ABSORB_PROVER_V, CHALLENGE_STAGE1_FOLD}; +use crate::protocol::transcript::Transcript; +use crate::{CanonicalField, FieldCore}; + +/// **Steps 1–3.** Compute `w_i = a^T G_{2^m} s_i` and decompose: `ŵ_i = G_1^{-1}(w_i)`. +/// +/// Recomputes each block's `s_i` from `ring_coeffs` on the fly to avoid +/// storing all `s_i` simultaneously (which can be tens of GB at production +/// parameters). +fn compute_w_hat( + opening_point: &RingOpeningPoint, + ring_coeffs: &[CyclotomicRing], + layout: HachiCommitmentLayout, +) -> Vec>> +where + F: FieldCore + CanonicalField, + Cfg: CommitmentConfig, +{ + let a = &opening_point.a; + let block_len = layout.block_len; + let delta = Cfg::DELTA; + let log_basis = Cfg::LOG_BASIS; + + debug_assert_eq!(a.len(), block_len); + + let blocks: Vec<&[CyclotomicRing]> = (0..layout.num_blocks) + .map(|i| { + let start = i * block_len; + let end = (start + block_len).min(ring_coeffs.len()); + if start < ring_coeffs.len() { + &ring_coeffs[start..end] + } else { + &[] as &[CyclotomicRing] + } + }) + .collect(); + + cfg_iter!(blocks) + .map(|block| { + let s_i = decompose_block(block, delta, log_basis); + let mut w_i = CyclotomicRing::::zero(); + for (j, a_j) in a.iter().enumerate().take(block_len) { + let start = j * delta; + let end = start + delta; + let recomp_j = CyclotomicRing::gadget_recompose_pow2(&s_i[start..end], log_basis); + w_i += recomp_j.scale(a_j); + } + w_i.balanced_decompose_pow2(delta, log_basis) + }) + .collect() +} + +/// **Step 4.** Compute `v = D · ŵ` (first prover message). +fn compute_v( + cache: &NttMatrixCache, + w_hat: &[Vec>], +) -> Result>, HachiError> { + let w_hat_flat: Vec> = + w_hat.iter().flat_map(|v| v.iter().copied()).collect(); + mat_vec_mul_ntt_cached(cache, MatrixSlot::D, &w_hat_flat) +} + +/// **Steps 7–9.** Fold `z_pre = Σ c_i · s_i` and check `‖z_pre‖_∞ ≤ β`. +/// +/// Returns the pre-decomposition `z` vector (before gadget decomposition into +/// `ẑ = J^{-1}(z_pre)`). Callers that need `z_hat` can apply +/// `balanced_decompose_pow2(TAU, LOG_BASIS)` themselves. +fn compute_z_pre( + ring_coeffs: &[CyclotomicRing], + challenges: &[SparseChallenge], + layout: HachiCommitmentLayout, +) -> Result>, HachiError> +where + F: FieldCore + CanonicalField, + Cfg: CommitmentConfig, +{ + let block_len = layout.block_len; + let delta = Cfg::DELTA; + let log_basis = Cfg::LOG_BASIS; + let inner_width = block_len * delta; + + debug_assert_eq!(challenges.len(), layout.num_blocks); + + let mut z = vec![CyclotomicRing::::zero(); inner_width]; + + for (i, c_i) in challenges.iter().enumerate() { + let start = i * block_len; + let end = (start + block_len).min(ring_coeffs.len()); + let block = if start < ring_coeffs.len() { + &ring_coeffs[start..end] + } else { + &[] as &[CyclotomicRing] + }; + let s_i = decompose_block(block, delta, log_basis); + for (j, z_j) in z.iter_mut().enumerate() { + *z_j += s_i[j].mul_by_sparse(c_i); + } + } + + let modulus = detect_field_modulus::(); + let norm = vec_inf_norm(&z, modulus); + let beta = Cfg::beta_bound(layout)?; + if norm > beta { + return Err(HachiError::InvalidInput(format!( + "prover abort: ||z||_inf = {norm} > beta = {beta}" + ))); + } + + Ok(z) +} + +/// Stage-1 quadratic equation state for the Hachi protocol. +/// +/// Encapsulates the relation $M(x) \cdot z = y(x) + (X^D + 1) \cdot r(x)$ +/// along with intermediate prover witness data (`w_hat`, `z_pre`, `hint`). +/// +/// M and z are never materialized — split-eq factoring computes their +/// products on-the-fly via `compute_r_split_eq` and `compute_m_a_streaming`. +pub struct QuadraticEquation { + /// Stage-1 proof vector `v = D · ŵ`. + pub v: Vec>, + /// Stage-1 folding challenges (sparse representation). + pub challenges: Vec, + /// Vector `y`. + y: Vec>, + /// Opening point (a, b) Lagrange weights. + opening_point: RingOpeningPoint, + /// Pre-decomposition folded witness `z_pre = Σ c_i · s_i` (prover only). + /// Replaces both `z_hat` and `z`: `z_hat = J^{-1}(z_pre)`. + z_pre: Option>>, + /// Decomposed `ŵ_i = G_1^{-1}(w_i)` (prover only). + w_hat: Option>>>, + /// Commitment hint (prover only). + hint: Option>, + + _marker: std::marker::PhantomData, +} + +impl QuadraticEquation +where + F: FieldCore + CanonicalField, + Cfg: CommitmentConfig, +{ + /// Prover constructor: runs §4.2 stage 1 and builds all equation components. + /// + /// # Errors + /// + /// Returns an error if the norm check, challenge sampling, or matrix + /// generation fails. + pub fn new_prover>( + setup: &HachiProverSetup, + ring_opening_point: &RingOpeningPoint, + hint: &HachiCommitmentHint, + transcript: &mut T, + commitment: &RingCommitment, + y_ring: &CyclotomicRing, + ) -> Result { + let layout = setup.layout(); + let w_hat = compute_w_hat::(ring_opening_point, &hint.ring_coeffs, layout); + let v = compute_v(setup.ntt_cache()?, &w_hat)?; + + // Step 5: append v to transcript + transcript.append_serde(ABSORB_PROVER_V, &v); + + // Step 6: sample sparse folding challenges + let challenge_cfg = SparseChallengeConfig { + weight: Cfg::CHALLENGE_WEIGHT, + nonzero_coeffs: vec![-1, 1], + }; + let challenges = sample_sparse_challenges::( + transcript, + CHALLENGE_STAGE1_FOLD, + layout.num_blocks, + &challenge_cfg, + )?; + + let z_pre = compute_z_pre::(&hint.ring_coeffs, &challenges, layout)?; + + let y = generate_y::(&v, &commitment.u, y_ring)?; + + Ok(Self { + v, + challenges, + y, + opening_point: ring_opening_point.clone(), + z_pre: Some(z_pre), + w_hat: Some(w_hat), + hint: Some(hint.clone()), + _marker: std::marker::PhantomData, + }) + } + + /// Verifier constructor: Derives challenges and computes M and y. + /// + /// # Errors + /// + /// Returns an error if challenge derivation fails. + pub fn new_verifier>( + setup: &HachiVerifierSetup, + ring_opening_point: &RingOpeningPoint, + v: &Vec>, + transcript: &mut T, + commitment: &RingCommitment, + y_ring: &CyclotomicRing, + ) -> Result { + let layout = setup.expanded.seed.layout; + let challenges = + derive_stage1_challenges::(transcript, v, layout.num_blocks)?; + let y = generate_y::(v, &commitment.u, y_ring)?; + + Ok(Self { + v: v.to_vec(), + challenges, + y, + opening_point: ring_opening_point.clone(), + z_pre: None, + w_hat: None, + hint: None, + _marker: std::marker::PhantomData, + }) + } + + /// Get the vector y. + pub fn y(&self) -> &[CyclotomicRing] { + &self.y + } + + /// Get the vector v. + pub fn v(&self) -> &[CyclotomicRing] { + &self.v + } + + /// Get the opening point (a, b) Lagrange weights. + pub fn opening_point(&self) -> &RingOpeningPoint { + &self.opening_point + } + + /// Get the pre-decomposition folded witness `z_pre` (prover only). + pub fn z_pre(&self) -> Option<&[CyclotomicRing]> { + self.z_pre.as_deref() + } + + /// Take ownership of `z_pre`, leaving `None` in its place. + pub fn take_z_pre(&mut self) -> Option>> { + self.z_pre.take() + } + + /// Get the decomposed witness `ŵ` (prover only). + pub fn w_hat(&self) -> Option<&[Vec>]> { + self.w_hat.as_deref() + } + + /// Take ownership of `w_hat`, leaving `None` in its place. + pub fn take_w_hat(&mut self) -> Option>>> { + self.w_hat.take() + } + + /// Get the commitment hint (prover only). + pub fn hint(&self) -> Option<&HachiCommitmentHint> { + self.hint.as_ref() + } + + /// Take ownership of the hint, leaving `None` in its place. + pub fn take_hint(&mut self) -> Option> { + self.hint.take() + } +} + +pub(crate) fn derive_stage1_challenges( + transcript: &mut T, + v: &Vec>, + num_blocks: usize, +) -> Result, HachiError> +where + F: FieldCore + CanonicalField, + T: Transcript, +{ + let challenge_cfg = SparseChallengeConfig { + weight: Cfg::CHALLENGE_WEIGHT, + nonzero_coeffs: vec![-1, 1], + }; + transcript.append_serde(ABSORB_PROVER_V, v); + sample_sparse_challenges::( + transcript, + CHALLENGE_STAGE1_FOLD, + num_blocks, + &challenge_cfg, + ) +} + +fn gadget_row_scalars(levels: usize, log_basis: u32) -> Vec { + let base = F::from_canonical_u128_reduced(1u128 << log_basis); + let mut out = Vec::with_capacity(levels); + let mut power = F::one(); + for _ in 0..levels { + out.push(power); + power = power * base; + } + out +} + +/// Accumulate unreduced polynomial product `a * b` into `poly` (length 2D-1). +fn add_unreduced_product( + poly: &mut [F], + a: &CyclotomicRing, + b: &CyclotomicRing, +) { + if a.is_zero() { + return; + } + let ac = a.coefficients(); + let bc = b.coefficients(); + let is_scalar = ac[1..].iter().all(|c| c.is_zero()); + if is_scalar { + let s = ac[0]; + for k in 0..D { + poly[k] = poly[k] + s * bc[k]; + } + } else { + for t in 0..D { + for s in 0..D { + poly[t + s] = poly[t + s] + ac[t] * bc[s]; + } + } + } +} + +/// Accumulate negated unreduced product `-a * b` into `poly`. +fn sub_unreduced_product( + poly: &mut [F], + a: &CyclotomicRing, + b: &CyclotomicRing, +) { + if a.is_zero() { + return; + } + let ac = a.coefficients(); + let bc = b.coefficients(); + let is_scalar = ac[1..].iter().all(|c| c.is_zero()); + if is_scalar { + let s = ac[0]; + for k in 0..D { + poly[k] = poly[k] - s * bc[k]; + } + } else { + for t in 0..D { + for s in 0..D { + poly[t + s] = poly[t + s] - ac[t] * bc[s]; + } + } + } +} + +/// Add scalar * ring_element into the low-D coefficients of `poly`. +/// scalar * ring produces degree D-1, so no high-half contribution. +fn add_scalar_ring_product( + poly: &mut [F], + scalar: &F, + ring: &CyclotomicRing, +) { + for (k, coeff) in ring.coefficients().iter().enumerate() { + poly[k] = poly[k] + *scalar * *coeff; + } +} + +/// Subtract scalar * ring_element from the low-D coefficients of `poly`. +fn sub_scalar_ring_product( + poly: &mut [F], + scalar: &F, + ring: &CyclotomicRing, +) { + for (k, coeff) in ring.coefficients().iter().enumerate() { + poly[k] = poly[k] - *scalar * *coeff; + } +} + +/// Add sparse_challenge * ring_element as unreduced product into `poly`. +fn add_sparse_ring_product( + poly: &mut [F], + challenge: &SparseChallenge, + ring: &CyclotomicRing, +) { + let dense: CyclotomicRing = challenge.to_dense().expect("valid sparse challenge"); + add_unreduced_product(poly, &dense, ring); +} + +/// Split-eq replacement for `generate_m` + `compute_r_via_poly_division`. +/// +/// Computes `r` such that `M·z = y + (X^D+1)·r` without materializing M or z. +/// Uses split-eq factoring: `kron(left, gadget) · decomposed = left · pre_decomp`. +pub(crate) fn compute_r_split_eq( + setup: &HachiExpandedSetup, + opening_point: &RingOpeningPoint, + challenges: &[SparseChallenge], + w_hat: &[Vec>], + t_hat: &[Vec>], + z_pre: &[CyclotomicRing], + y: &[CyclotomicRing], +) -> Result>, HachiError> +where + F: FieldCore + CanonicalField, + Cfg: CommitmentConfig, +{ + let poly_len = 2 * D - 1; + let num_rows = Cfg::N_D + Cfg::N_B + 1 + 1 + Cfg::N_A; + + let w_hat_flat: Vec> = + w_hat.iter().flat_map(|v| v.iter().copied()).collect(); + let t_hat_flat: Vec> = + t_hat.iter().flat_map(|v| v.iter().copied()).collect(); + + let mut result = Vec::with_capacity(num_rows); + + for (row_idx, y_i) in y.iter().enumerate().take(num_rows) { + let mut poly = vec![F::zero(); poly_len]; + + if row_idx < Cfg::N_D { + let d_row = &setup.D[row_idx]; + for (m_ij, z_j) in d_row.iter().zip(w_hat_flat.iter()) { + add_unreduced_product(&mut poly, m_ij, z_j); + } + } else if row_idx < Cfg::N_D + Cfg::N_B { + let b_row = &setup.B[row_idx - Cfg::N_D]; + for (m_ij, z_j) in b_row.iter().zip(t_hat_flat.iter()) { + add_unreduced_product(&mut poly, m_ij, z_j); + } + } else if row_idx == Cfg::N_D + Cfg::N_B { + // row3: b . w_recomp (scalar * ring, degree D-1) + for (i, w_hat_i) in w_hat.iter().enumerate() { + let w_recomp = CyclotomicRing::gadget_recompose_pow2(w_hat_i, Cfg::LOG_BASIS); + add_scalar_ring_product(&mut poly, &opening_point.b[i], &w_recomp); + } + } else if row_idx == Cfg::N_D + Cfg::N_B + 1 { + // row4 w-segment: c . w_recomp (sparse*ring unreduced) + for (i, w_hat_i) in w_hat.iter().enumerate() { + let w_recomp = CyclotomicRing::gadget_recompose_pow2(w_hat_i, Cfg::LOG_BASIS); + add_sparse_ring_product(&mut poly, &challenges[i], &w_recomp); + } + // row4 z-segment: -(a . z_pre_recomp), double-factored + // z_pre_recomp[i] = gadget_recompose(z_pre[i*DELTA..(i+1)*DELTA]) + let block_len = opening_point.a.len(); + for i in 0..block_len { + let start = i * Cfg::DELTA; + let end = start + Cfg::DELTA; + if end <= z_pre.len() { + let z_pre_recomp = + CyclotomicRing::gadget_recompose_pow2(&z_pre[start..end], Cfg::LOG_BASIS); + sub_scalar_ring_product(&mut poly, &opening_point.a[i], &z_pre_recomp); + } + } + } else { + // row5 (N_A rows) + let a_idx = row_idx - (Cfg::N_D + Cfg::N_B + 2); + // t-segment: c . t_recomp[a_idx] (sparse*ring unreduced) + for (i, t_hat_i) in t_hat.iter().enumerate() { + let start = a_idx * Cfg::DELTA; + let end = start + Cfg::DELTA; + if end <= t_hat_i.len() { + let t_recomp = + CyclotomicRing::gadget_recompose_pow2(&t_hat_i[start..end], Cfg::LOG_BASIS); + add_sparse_ring_product(&mut poly, &challenges[i], &t_recomp); + } + } + // z-segment: -(A[a_idx] . z_pre) (ring*ring unreduced) + let a_row = &setup.A[a_idx]; + for (m_ij, z_j) in a_row.iter().zip(z_pre.iter()) { + sub_unreduced_product(&mut poly, m_ij, z_j); + } + } + + let y_coeffs = y_i.coefficients(); + for k in 0..D { + poly[k] = poly[k] - y_coeffs[k]; + } + + // Divide by X^D + 1 + let mut quotient = vec![F::zero(); D]; + for k in (D..poly_len).rev() { + let q = poly[k]; + quotient[k - D] = q; + poly[k - D] = poly[k - D] - q; + } + let coeffs: [F; D] = std::array::from_fn(|k| quotient[k]); + result.push(CyclotomicRing::from_coefficients(coeffs)); + } + + Ok(result) +} + +/// Split-eq replacement for `generate_m` + `eval_ring_matrix_at`. +/// +/// Computes the field-element evaluations of each M entry at `alpha`, +/// organized as rows of field elements, without materializing M. +pub(crate) fn compute_m_a_streaming( + setup: &HachiExpandedSetup, + opening_point: &RingOpeningPoint, + challenges: &[SparseChallenge], + alpha: &F, +) -> Result>, HachiError> +where + F: FieldCore + CanonicalField, + Cfg: CommitmentConfig, +{ + let layout = setup.seed.layout; + let num_blocks = layout.num_blocks; + let block_len = layout.block_len; + let w_len = Cfg::DELTA * num_blocks; + let t_len = Cfg::DELTA * Cfg::N_A * num_blocks; + let z_len = Cfg::TAU * Cfg::DELTA * block_len; + let total_cols = w_len + t_len + z_len; + + let g1 = gadget_row_scalars::(Cfg::DELTA, Cfg::LOG_BASIS); + let j1 = gadget_row_scalars::(Cfg::TAU, Cfg::LOG_BASIS); + + // Pre-evaluate alpha powers for gadget scalars (already field elements) + // g1 and j1 are already field scalars, so eval_ring_at(constant(g), alpha) = g. + + let mut rows = Vec::with_capacity(Cfg::N_D + Cfg::N_B + 1 + 1 + Cfg::N_A); + + // D rows: setup.D[i] evaluated at alpha, zero-padded + for d_row in setup.D.iter() { + let mut full = vec![F::zero(); total_cols]; + for (j, ring) in d_row.iter().enumerate() { + full[j] = eval_ring_at(ring, alpha); + } + rows.push(full); + } + + // B rows: setup.B[i] evaluated at alpha, in t-segment + for b_row in setup.B.iter() { + let mut full = vec![F::zero(); total_cols]; + for (j, ring) in b_row.iter().enumerate() { + full[w_len + j] = eval_ring_at(ring, alpha); + } + rows.push(full); + } + + // row3: kron(b, g1) evaluated at alpha -> b[i] * g1[d] (all scalars) + { + let mut full = vec![F::zero(); total_cols]; + for (i, &b_i) in opening_point.b.iter().enumerate() { + for (d, &g) in g1.iter().enumerate() { + full[i * Cfg::DELTA + d] = b_i * g; + } + } + rows.push(full); + } + + // row4: w-segment = kron(c, g1) evaluated at alpha + // z-segment = -kron(kron(a,g1), j1) (all scalars) + { + let mut full = vec![F::zero(); total_cols]; + // w-segment: c[i] evaluated at alpha, times g1[d] + for (i, c) in challenges.iter().enumerate() { + let c_alpha = eval_ring_at(&c.to_dense::().expect("valid challenge"), alpha); + for (d, &g) in g1.iter().enumerate() { + full[i * Cfg::DELTA + d] = c_alpha * g; + } + } + // z-segment: -kron(kron(a,g1), j1) = -(a[i] * g1[d] * j1[t]) + let z_offset = w_len + t_len; + for (i, &a_i) in opening_point.a.iter().enumerate() { + for (d, &g) in g1.iter().enumerate() { + let ag = a_i * g; + for (t, &j) in j1.iter().enumerate() { + let idx = (i * Cfg::DELTA + d) * Cfg::TAU + t; + full[z_offset + idx] = -(ag * j); + } + } + } + rows.push(full); + } + + // row5 (N_A rows): t-segment = kron(c, g_na[a_idx]) evaluated at alpha + // z-segment = -kron(A[a_idx], j1) evaluated at alpha + for a_idx in 0..Cfg::N_A { + let mut full = vec![F::zero(); total_cols]; + // t-segment: block-diagonal gadget times challenges + for (i, c) in challenges.iter().enumerate() { + let c_alpha = eval_ring_at(&c.to_dense::().expect("valid challenge"), alpha); + for (d, &g) in g1.iter().enumerate() { + let t_idx = i * (Cfg::N_A * Cfg::DELTA) + a_idx * Cfg::DELTA + d; + full[w_len + t_idx] = c_alpha * g; + } + } + // z-segment: -A[a_idx][k] evaluated at alpha, times j1[t] + let z_offset = w_len + t_len; + let a_row = &setup.A[a_idx]; + for (k, ring) in a_row.iter().enumerate() { + let ring_alpha = eval_ring_at(ring, alpha); + for (t, &j) in j1.iter().enumerate() { + full[z_offset + k * Cfg::TAU + t] = -(ring_alpha * j); + } + } + rows.push(full); + } + + Ok(rows) +} + +pub(crate) fn generate_y( + v: &[CyclotomicRing], + u: &[CyclotomicRing], + u_eval: &CyclotomicRing, +) -> Result>, HachiError> +where + F: FieldCore, +{ + if v.len() != Cfg::N_D { + return Err(HachiError::InvalidSize { + expected: Cfg::N_D, + actual: v.len(), + }); + } + if u.len() != Cfg::N_B { + return Err(HachiError::InvalidSize { + expected: Cfg::N_B, + actual: u.len(), + }); + } + let mut out = Vec::with_capacity(Cfg::N_D + Cfg::N_B + 1 + 1 + Cfg::N_A); + out.extend_from_slice(v); + out.extend_from_slice(u); + out.push(*u_eval); + out.push(CyclotomicRing::::zero()); + out.extend(std::iter::repeat_n( + CyclotomicRing::::zero(), + Cfg::N_A, + )); + Ok(out) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::algebra::{CyclotomicRing, SparseChallengeConfig}; + use crate::protocol::challenges::sparse::sample_sparse_challenges; + use crate::protocol::commitment::{HachiCommitmentCore, RingCommitmentScheme}; + use crate::protocol::proof::HachiCommitmentHint; + use crate::protocol::transcript::Blake2bTranscript; + use crate::test_utils::*; + use crate::Transcript; + + const TRANSCRIPT_SEED: &[u8] = b"test/prover-relation"; + + fn replay_challenges(v: &Vec>) -> Vec> { + let mut transcript = Blake2bTranscript::::new(TRANSCRIPT_SEED); + transcript.append_serde(ABSORB_PROVER_V, v); + + let challenge_cfg = SparseChallengeConfig { + weight: TinyConfig::CHALLENGE_WEIGHT, + nonzero_coeffs: vec![-1, 1], + }; + let sparse = sample_sparse_challenges::, D>( + &mut transcript, + CHALLENGE_STAGE1_FOLD, + NUM_BLOCKS, + &challenge_cfg, + ) + .unwrap(); + sparse + .iter() + .map(|c| c.to_dense::().unwrap()) + .collect() + } + + struct Fixture { + setup: HachiProverSetup, + commitment_u: Vec>, + point: RingOpeningPoint, + blocks: Vec>>, + quad_eq: QuadraticEquation, + /// Challenges re-derived via transcript replay (cross-check). + challenges: Vec>, + } + + fn build_fixture() -> Fixture { + let (setup, _) = + >::setup(16).unwrap(); + + let blocks = sample_blocks(); + let w = + >::commit_ring_blocks( + &blocks, &setup, + ) + .unwrap(); + + let point = RingOpeningPoint { + a: sample_a(), + b: sample_b(), + }; + + let ring_coeffs: Vec> = + blocks.iter().flat_map(|b| b.iter().copied()).collect(); + let hint = HachiCommitmentHint { + t_hat: w.t_hat, + ring_coeffs, + }; + let mut transcript = Blake2bTranscript::::new(TRANSCRIPT_SEED); + let y_ring = CyclotomicRing::::zero(); + let quad_eq = QuadraticEquation::::new_prover( + &setup, + &point, + &hint, + &mut transcript, + &w.commitment, + &y_ring, + ) + .unwrap(); + + let challenges = replay_challenges(&quad_eq.v); + + Fixture { + setup, + commitment_u: w.commitment.u.clone(), + point, + blocks, + quad_eq, + challenges, + } + } + + /// Row 1: D · ŵ = v + #[test] + fn row1_d_times_w_hat_equals_v() { + let f = build_fixture(); + + let w_hat = f.quad_eq.w_hat().unwrap(); + let w_hat_flat: Vec> = + w_hat.iter().flat_map(|v| v.iter().copied()).collect(); + let lhs = mat_vec_mul(&f.setup.expanded.D, &w_hat_flat); + + assert_eq!(lhs, f.quad_eq.v(), "Row 1 failed: D · ŵ ≠ v"); + } + + /// Row 2: B · t̂ = u (commitment vector) + #[test] + fn row2_b_times_t_hat_equals_u_commitment() { + let f = build_fixture(); + + let hint = f.quad_eq.hint().unwrap(); + let t_hat_flat: Vec> = + hint.t_hat.iter().flat_map(|v| v.iter().copied()).collect(); + let lhs = mat_vec_mul(&f.setup.expanded.B, &t_hat_flat); + + assert_eq!(lhs, f.commitment_u, "Row 2 failed: B · t̂ ≠ u"); + } + + /// Row 3: b^T · G_{2^r} · ŵ = u_eval + #[test] + fn row3_bt_gadget_w_hat_equals_u_eval() { + let f = build_fixture(); + + let w_hat = f.quad_eq.w_hat().unwrap(); + let w_recomposed: Vec> = w_hat + .iter() + .map(|w_hat_i| CyclotomicRing::gadget_recompose_pow2(w_hat_i, LOG_BASIS)) + .collect(); + + let u_eval = w_recomposed + .iter() + .zip(f.point.b.iter()) + .fold(CyclotomicRing::::zero(), |acc, (w_i, b_i)| { + acc + w_i.scale(b_i) + }); + + let u_eval_direct = f.blocks.iter().zip(f.point.b.iter()).fold( + CyclotomicRing::::zero(), + |acc, (block_i, b_i)| { + let inner: CyclotomicRing = block_i + .iter() + .zip(f.point.a.iter()) + .fold(CyclotomicRing::::zero(), |acc2, (f_ij, a_j)| { + acc2 + f_ij.scale(a_j) + }); + acc + inner.scale(b_i) + }, + ); + + assert_eq!( + u_eval, u_eval_direct, + "Row 3 failed: b^T G ŵ ≠ Σ b_i (a^T f_i)" + ); + } + + /// Derive z_hat from z_pre for test assertions. + fn derive_z_hat(z_pre: &[CyclotomicRing]) -> Vec> { + z_pre + .iter() + .flat_map(|z_j| z_j.balanced_decompose_pow2(TAU, LOG_BASIS)) + .collect() + } + + /// Row 4: (c^T ⊗ G_1) · ŵ = a^T · G_{2^m} · J · ẑ + #[test] + fn row4_challenge_fold_w_equals_a_gadget_j_z_hat() { + let f = build_fixture(); + + let w_hat = f.quad_eq.w_hat().unwrap(); + let w: Vec> = w_hat + .iter() + .map(|w_hat_i| CyclotomicRing::gadget_recompose_pow2(w_hat_i, LOG_BASIS)) + .collect(); + + let lhs = f + .challenges + .iter() + .zip(w.iter()) + .fold(CyclotomicRing::::zero(), |acc, (c_i, w_i)| { + acc + (*c_i * *w_i) + }); + + let z_hat = derive_z_hat(f.quad_eq.z_pre().unwrap()); + let z_recovered = recompose_z_hat(&z_hat); + let rhs = a_transpose_gadget_times_vec(&f.point.a, &z_recovered); + + assert_eq!(lhs, rhs, "Row 4 failed: (c^T ⊗ G_1)ŵ ≠ a^T G J ẑ"); + } + + /// Row 5: (c^T ⊗ G_{n_A}) · t̂ = A · J · ẑ + #[test] + fn row5_challenge_fold_t_equals_a_j_z_hat() { + let f = build_fixture(); + + let hint = f.quad_eq.hint().unwrap(); + let mut lhs = vec![CyclotomicRing::::zero(); N_A]; + for (c_i, t_hat_i) in f.challenges.iter().zip(hint.t_hat.iter()) { + let t_i = gadget_recompose_vec(t_hat_i); + assert_eq!(t_i.len(), N_A); + for (lhs_j, t_ij) in lhs.iter_mut().zip(t_i.iter()) { + *lhs_j += *c_i * *t_ij; + } + } + + let z_hat = derive_z_hat(f.quad_eq.z_pre().unwrap()); + let z_recovered = recompose_z_hat(&z_hat); + let rhs = mat_vec_mul(&f.setup.expanded.A, &z_recovered); + + assert_eq!(lhs, rhs, "Row 5 failed: (c^T ⊗ G_nA)t̂ ≠ A · J · ẑ"); + } + + #[test] + fn prove_output_shapes_are_correct() { + let f = build_fixture(); + + assert_eq!(f.quad_eq.v().len(), TinyConfig::N_D); + + let w_hat = f.quad_eq.w_hat().unwrap(); + assert_eq!(w_hat.len(), NUM_BLOCKS); + assert!(w_hat.iter().all(|v| v.len() == DELTA)); + + let hint = f.quad_eq.hint().unwrap(); + assert_eq!(hint.t_hat.len(), NUM_BLOCKS); + assert!(hint.t_hat.iter().all(|v| v.len() == N_A * DELTA)); + + assert_eq!(f.quad_eq.z_pre().unwrap().len(), BLOCK_LEN * DELTA); + } +} diff --git a/src/protocol/ring_switch.rs b/src/protocol/ring_switch.rs new file mode 100644 index 00000000..ef87ed0e --- /dev/null +++ b/src/protocol/ring_switch.rs @@ -0,0 +1,673 @@ +//! Ring switching logic for the Hachi PCS (Section 4.3). +//! +//! Handles the transition from the ring-based quadratic equation to field-based +//! sumcheck instances by expanding the ring elements into their coefficient +//! vectors and setting up the evaluation tables. + +use crate::algebra::CyclotomicRing; +use crate::cfg_into_iter; +use crate::error::HachiError; +#[cfg(feature = "parallel")] +use crate::parallel::*; +use crate::protocol::commitment::utils::norm::detect_field_modulus; +use crate::protocol::commitment::{ + CommitmentConfig, HachiCommitmentCore, HachiCommitmentLayout, HachiExpandedSetup, + RingCommitment, RingCommitmentScheme, +}; +use crate::protocol::quadratic_equation::{ + compute_m_a_streaming, compute_r_split_eq, QuadraticEquation, +}; +use crate::protocol::sumcheck::eq_poly::EqPolynomial; +use crate::protocol::transcript::labels::{ + ABSORB_SUMCHECK_W, CHALLENGE_RING_SWITCH, CHALLENGE_TAU0, CHALLENGE_TAU1, +}; +use crate::protocol::transcript::Transcript; +use crate::{CanonicalField, FieldCore, FieldSampling}; + +/// Output of the ring switch protocol, containing everything needed for sumchecks. +pub struct RingSwitchOutput { + /// The witness vector w (concatenation of z and r coefficients). + pub w: Vec, + /// Commitment to w. + pub w_commitment: RingCommitment, + /// Compact evaluation table of w (all entries in [-8, 7], reordered for sumcheck). + /// Populated by the prover; empty on the verifier side. + pub w_evals: Vec, + /// Evaluation table of M_alpha(x) (tau1-weighted). + pub m_evals_x: Vec, + /// Evaluation table of alpha powers (y dimension). + pub alpha_evals_y: Vec, + /// Number of upper variable bits. + pub num_u: usize, + /// Number of lower variable bits. + pub num_l: usize, + /// Challenge tau0 for F_0 sumcheck. + pub tau0: Vec, + /// Challenge tau1 for F_alpha sumcheck. + pub tau1: Vec, + /// Basis size b = 2^LOG_BASIS. + pub b: usize, + /// Ring-switch challenge alpha. + pub alpha: F, +} + +/// Execute the prover side of the ring switching protocol (Section 4.3). +/// +/// # Errors +/// +/// Returns an error if z_pre/w_hat is missing, commitment fails, or matrix expansion fails. +pub fn ring_switch_prover( + quad_eq: &mut QuadraticEquation, + setup: &HachiExpandedSetup, + transcript: &mut T, +) -> Result, HachiError> +where + F: FieldCore + CanonicalField + FieldSampling, + T: Transcript, + Cfg: CommitmentConfig, +{ + let w_hat = quad_eq + .w_hat() + .ok_or_else(|| HachiError::InvalidInput("missing w_hat in prover".to_string()))?; + let z_pre = quad_eq + .z_pre() + .ok_or_else(|| HachiError::InvalidInput("missing z_pre in prover".to_string()))?; + let hint = quad_eq + .hint() + .ok_or_else(|| HachiError::InvalidInput("missing hint in prover".to_string()))?; + let t_hat = &hint.t_hat; + + let r = compute_r_split_eq::( + setup, + quad_eq.opening_point(), + &quad_eq.challenges, + w_hat, + t_hat, + z_pre, + quad_eq.y(), + )?; + let w = build_w_coeffs::(w_hat, t_hat, z_pre, &r); + + let w_commitment = commit_w::(&w)?; + transcript.append_serde(ABSORB_SUMCHECK_W, &w_commitment); + + let alpha: F = transcript.challenge_scalar(CHALLENGE_RING_SWITCH); + + let m_a = compute_m_a_streaming::( + setup, + quad_eq.opening_point(), + &quad_eq.challenges, + &alpha, + )?; + let m_a_vec = expand_m_a::(&m_a, alpha)?; + let m_rows = m_row_count::(); + let m_cols = if m_a.is_empty() { + 0 + } else { + m_a_vec.len() / m_a.len() + }; + + let (w_evals, num_u, num_l) = build_w_evals_compact::(&w, D)?; + let alpha_evals_y = build_alpha_evals_y(alpha, D); + + let num_sc_vars = num_u + num_l; + let tau0 = sample_tau::(transcript, CHALLENGE_TAU0, num_sc_vars); + + let num_i = m_rows.next_power_of_two().trailing_zeros() as usize; + let tau1 = sample_tau::(transcript, CHALLENGE_TAU1, num_i); + + let m_evals_x = build_m_evals_x::(&m_a_vec, m_rows, m_cols, &tau1); + + Ok(RingSwitchOutput { + w, + w_commitment, + w_evals, + m_evals_x, + alpha_evals_y, + num_u, + num_l, + tau0, + tau1, + b: 1usize << Cfg::LOG_BASIS, + alpha, + }) +} + +/// Replay the verifier side of ring switching to reconstruct evaluation tables. +/// +/// # Errors +/// +/// Returns an error if matrix expansion fails. +pub fn ring_switch_verifier( + quad_eq: &QuadraticEquation, + setup: &HachiExpandedSetup, + w: &[F], + w_commitment: &RingCommitment, + transcript: &mut T, +) -> Result, HachiError> +where + F: FieldCore + CanonicalField + FieldSampling, + T: Transcript, + Cfg: CommitmentConfig, +{ + transcript.append_serde(ABSORB_SUMCHECK_W, w_commitment); + + let alpha: F = transcript.challenge_scalar(CHALLENGE_RING_SWITCH); + + let m_a = compute_m_a_streaming::( + setup, + quad_eq.opening_point(), + &quad_eq.challenges, + &alpha, + )?; + let m_a_vec = expand_m_a::(&m_a, alpha)?; + let m_rows = m_row_count::(); + let m_cols = if m_a.is_empty() { + 0 + } else { + m_a_vec.len() / m_a.len() + }; + + let num_ring_elems = w.len() / D; + let num_u = num_ring_elems.next_power_of_two().trailing_zeros() as usize; + let num_l = D.trailing_zeros() as usize; + let alpha_evals_y = build_alpha_evals_y(alpha, D); + + let num_sc_vars = num_u + num_l; + let tau0 = sample_tau::(transcript, CHALLENGE_TAU0, num_sc_vars); + + let num_i = m_rows.next_power_of_two().trailing_zeros() as usize; + let tau1 = sample_tau::(transcript, CHALLENGE_TAU1, num_i); + + let m_evals_x = build_m_evals_x::(&m_a_vec, m_rows, m_cols, &tau1); + + Ok(RingSwitchOutput { + w: w.to_vec(), + w_commitment: w_commitment.clone(), + w_evals: Vec::new(), + m_evals_x, + alpha_evals_y, + num_u, + num_l, + tau0, + tau1, + b: 1usize << Cfg::LOG_BASIS, + alpha, + }) +} + +#[cfg(test)] +pub(crate) fn compute_r_via_poly_division( + m: &[Vec>], + z: &[CyclotomicRing], + y: &[CyclotomicRing], +) -> Result>, HachiError> { + let poly_len = 2 * D - 1; + let out = m + .iter() + .zip(y.iter()) + .map(|(row, y_i)| { + let column_contribution = + |m_ij: &CyclotomicRing, z_j: &CyclotomicRing| -> Vec { + let mut local = vec![F::zero(); poly_len]; + if m_ij.is_zero() { + return local; + } + let a = m_ij.coefficients(); + let b = z_j.coefficients(); + let is_scalar = a[1..].iter().all(|c| c.is_zero()); + if is_scalar { + let scalar = a[0]; + for s in 0..D { + local[s] = scalar * b[s]; + } + } else { + for t in 0..D { + for s in 0..D { + local[t + s] = local[t + s] + a[t] * b[s]; + } + } + } + local + }; + + let pointwise_add = |mut a: Vec, b: Vec| -> Vec { + for (ai, bi) in a.iter_mut().zip(b.iter()) { + *ai = *ai + *bi; + } + a + }; + + #[cfg(feature = "parallel")] + let mut poly = row + .par_iter() + .zip(z.par_iter()) + .fold( + || vec![F::zero(); poly_len], + |acc, (m_ij, z_j)| pointwise_add(acc, column_contribution(m_ij, z_j)), + ) + .reduce(|| vec![F::zero(); poly_len], pointwise_add); + + #[cfg(not(feature = "parallel"))] + let mut poly = row + .iter() + .zip(z.iter()) + .fold(vec![F::zero(); poly_len], |acc, (m_ij, z_j)| { + pointwise_add(acc, column_contribution(m_ij, z_j)) + }); + let y_coeffs = y_i.coefficients(); + for k in 0..D { + poly[k] = poly[k] - y_coeffs[k]; + } + let mut quotient = vec![F::zero(); D]; + for k in (D..poly_len).rev() { + let q = poly[k]; + quotient[k - D] = q; + poly[k - D] = poly[k - D] - q; + } + let coeffs: [F; D] = std::array::from_fn(|k| quotient[k]); + CyclotomicRing::from_coefficients(coeffs) + }) + .collect(); + Ok(out) +} + +#[derive(Clone, Copy, Debug)] +struct WCommitmentConfig { + _cfg: std::marker::PhantomData, +} + +impl CommitmentConfig for WCommitmentConfig { + const D: usize = D; + const N_A: usize = Cfg::N_A; + const N_B: usize = Cfg::N_B; + const N_D: usize = Cfg::N_D; + const LOG_BASIS: u32 = Cfg::LOG_BASIS; + const DELTA: usize = Cfg::DELTA; + const TAU: usize = Cfg::TAU; + const CHALLENGE_WEIGHT: usize = Cfg::CHALLENGE_WEIGHT; + + fn commitment_layout(max_num_vars: usize) -> Result { + let alpha = D.trailing_zeros() as usize; + let m_vars = max_num_vars.checked_sub(alpha).ok_or_else(|| { + HachiError::InvalidSetup("max_num_vars is smaller than alpha".to_string()) + })?; + HachiCommitmentLayout::new::(m_vars, 0) + } +} + +fn commit_w(w: &[F]) -> Result, HachiError> +where + F: FieldCore + CanonicalField + FieldSampling, + Cfg: CommitmentConfig, +{ + type WCfg = WCommitmentConfig; + + let ring_elems: Vec> = w + .chunks(D) + .map(|chunk| { + let coeffs: [F; D] = + std::array::from_fn(|i| if i < chunk.len() { chunk[i] } else { F::zero() }); + CyclotomicRing::from_coefficients(coeffs) + }) + .collect(); + + let block_len = ring_elems.len().next_power_of_two().max(1); + let mut padded = ring_elems; + padded.resize(block_len, CyclotomicRing::::zero()); + let m_vars = block_len.trailing_zeros() as usize; + let max_num_vars = m_vars + D.trailing_zeros() as usize; + let blocks = vec![padded]; + + let (w_setup, _) = + >>::setup(max_num_vars)?; + + let w = >>::commit_ring_blocks( + &blocks, &w_setup, + )?; + + Ok(w.commitment) +} + +pub(crate) fn eval_ring_at(r: &CyclotomicRing, alpha: &F) -> F { + let mut acc = F::zero(); + let mut power = F::one(); + for coeff in r.coefficients() { + acc = acc + (*coeff * power); + power = power * *alpha; + } + acc +} + +pub(crate) fn r_decomp_levels() -> usize { + let modulus = detect_field_modulus::(); + let bits = 128 - (modulus.saturating_sub(1)).leading_zeros() as usize; + let log_basis = Cfg::LOG_BASIS as usize; + let mut levels = (bits + log_basis.saturating_sub(1)) / log_basis.max(1); + if levels == 0 { + levels = 1; + } + + let b = 1u128 << Cfg::LOG_BASIS; + let half_q = modulus / 2; + let half_b_minus_1 = b / 2 - 1; + let b_minus_1 = b - 1; + let mut b_pow = 1u128; + for _ in 0..levels { + b_pow = b_pow.saturating_mul(b); + } + let max_positive = half_b_minus_1.saturating_mul((b_pow - 1) / b_minus_1); + if max_positive < half_q { + levels += 1; + } + + levels +} + +pub(crate) fn expand_m_a( + m_a: &[Vec], + alpha: F, +) -> Result, HachiError> { + if m_a.is_empty() { + return Ok(Vec::new()); + } + let rows = m_a.len(); + let cols = m_a[0].len(); + if cols == 0 { + return Ok(vec![F::zero(); rows]); + } + for row in m_a.iter() { + if row.len() != cols { + return Err(HachiError::InvalidSize { + expected: cols, + actual: row.len(), + }); + } + } + + let levels = r_decomp_levels::(); + let total_cols = cols + .checked_add( + rows.checked_mul(levels) + .ok_or_else(|| HachiError::InvalidSetup("expanded M width overflow".to_string()))?, + ) + .ok_or_else(|| HachiError::InvalidSetup("expanded M width overflow".to_string()))?; + + let base = F::from_canonical_u128_reduced(1u128 << Cfg::LOG_BASIS); + let mut gadget_row = Vec::with_capacity(levels); + let mut power = F::one(); + for _ in 0..levels { + gadget_row.push(power); + power = power * base; + } + + let mut alpha_pow = F::one(); + for _ in 0..D { + alpha_pow = alpha_pow * alpha; + } + let denom = alpha_pow + F::one(); + + let mut out = vec![F::zero(); rows * total_cols]; + for (i, m_a_row) in m_a.iter().enumerate() { + let row_start = i * total_cols; + out[row_start..row_start + cols].copy_from_slice(m_a_row); + let r_start = row_start + cols + i * levels; + for (j, g) in gadget_row.iter().enumerate() { + out[r_start + j] = -denom * *g; + } + } + Ok(out) +} + +/// # Errors +/// +/// Returns an error if `w.len()` is not a multiple of `d`. +pub(crate) fn build_w_evals( + w: &[F], + d: usize, +) -> Result<(Vec, usize, usize), HachiError> { + if d == 0 || w.len() % d != 0 { + return Err(HachiError::InvalidSize { + expected: d, + actual: w.len(), + }); + } + let num_l = d.trailing_zeros() as usize; + let num_ring_elems = w.len() / d; + let num_u = num_ring_elems.next_power_of_two().trailing_zeros() as usize; + let x_len = 1usize << num_u; + let n = x_len << num_l; + + let evals: Vec = cfg_into_iter!(0..n) + .map(|dst| { + let x = dst & (x_len - 1); + let y = dst >> num_u; + let src = y + (x << num_l); + if src < w.len() { + w[src] + } else { + F::zero() + } + }) + .collect(); + Ok((evals, num_u, num_l)) +} + +/// Compact variant of `build_w_evals` returning `Vec`. +/// +/// All entries in `w` must have canonical values in `[-half_b, half_b - 1]` +/// where `half_b = 2^(LOG_BASIS-1)`. This holds when `w` is produced by +/// `build_w_coeffs` (all components go through `balanced_decompose_pow2`). +pub(crate) fn build_w_evals_compact( + w: &[F], + d: usize, +) -> Result<(Vec, usize, usize), HachiError> { + if d == 0 || w.len() % d != 0 { + return Err(HachiError::InvalidSize { + expected: d, + actual: w.len(), + }); + } + let num_l = d.trailing_zeros() as usize; + let num_ring_elems = w.len() / d; + let num_u = num_ring_elems.next_power_of_two().trailing_zeros() as usize; + let x_len = 1usize << num_u; + let n = x_len << num_l; + + let q = (-F::one()).to_canonical_u128() + 1; + let half_q = q / 2; + + let evals: Vec = (0..n) + .map(|dst| { + let x = dst & (x_len - 1); + let y = dst >> num_u; + let src = y + (x << num_l); + if src < w.len() { + let canonical = w[src].to_canonical_u128(); + if canonical <= half_q { + canonical as i8 + } else { + (canonical as i128 - q as i128) as i8 + } + } else { + 0i8 + } + }) + .collect(); + Ok((evals, num_u, num_l)) +} + +pub(crate) fn m_row_count() -> usize { + Cfg::N_D + Cfg::N_B + 1 + 1 + Cfg::N_A +} + +pub(crate) fn build_m_evals_x( + m_a_flat: &[F], + rows: usize, + cols: usize, + tau1: &[F], +) -> Vec { + let eq_tau1 = EqPolynomial::evals(tau1); + let x_len = cols.next_power_of_two(); + cfg_into_iter!(0..x_len) + .map(|x| { + let mut acc = F::zero(); + for i in 0..eq_tau1.len() { + let row_val = if i < rows && x < cols { + m_a_flat[i * cols + x] + } else { + F::zero() + }; + acc = acc + eq_tau1[i] * row_val; + } + acc + }) + .collect() +} + +pub(crate) fn build_alpha_evals_y(alpha: F, d: usize) -> Vec { + let mut out = vec![F::zero(); d]; + let mut power = F::one(); + for val in out.iter_mut() { + *val = power; + power = power * alpha; + } + out +} + +pub(crate) fn sample_tau>( + transcript: &mut T, + label: &[u8], + n: usize, +) -> Vec { + (0..n).map(|_| transcript.challenge_scalar(label)).collect() +} + +pub(crate) fn build_w_coeffs( + w_hat: &[Vec>], + t_hat: &[Vec>], + z_pre: &[CyclotomicRing], + r: &[CyclotomicRing], +) -> Vec { + let levels = r_decomp_levels::(); + let r_hat: Vec> = r + .iter() + .flat_map(|ri| ri.balanced_decompose_pow2(levels, Cfg::LOG_BASIS)) + .collect(); + + let w_hat_flat = w_hat.iter().flat_map(|v| v.iter()); + let t_hat_flat = t_hat.iter().flat_map(|v| v.iter()); + let z_hat_iter = z_pre + .iter() + .flat_map(|z_j| z_j.balanced_decompose_pow2(Cfg::TAU, Cfg::LOG_BASIS)); + + let z_count = w_hat.iter().map(|v| v.len()).sum::() + + t_hat.iter().map(|v| v.len()).sum::() + + z_pre.len() * Cfg::TAU; + let mut out = Vec::with_capacity((z_count + r_hat.len()) * D); + for elem in w_hat_flat + .chain(t_hat_flat) + .chain(z_hat_iter.collect::>().iter()) + .chain(r_hat.iter()) + { + out.extend_from_slice(elem.coefficients()); + } + out +} + +#[cfg(test)] +mod tests { + use super::compute_r_via_poly_division; + use crate::algebra::{CyclotomicRing, Fp64}; + use crate::{FieldCore, FromSmallInt}; + + fn compute_r_schoolbook( + m: &[Vec>], + z: &[CyclotomicRing], + y: &[CyclotomicRing], + ) -> Vec> { + let poly_len = 2 * D - 1; + m.iter() + .zip(y.iter()) + .map(|(row, y_i)| { + let mut poly = vec![F::zero(); poly_len]; + for (m_ij, z_j) in row.iter().zip(z.iter()) { + if m_ij.is_zero() { + continue; + } + let a = m_ij.coefficients(); + let b = z_j.coefficients(); + let is_scalar = a[1..].iter().all(|c| c.is_zero()); + if is_scalar { + let scalar = a[0]; + for s in 0..D { + poly[s] = poly[s] + scalar * b[s]; + } + } else { + for t in 0..D { + for s in 0..D { + poly[t + s] = poly[t + s] + a[t] * b[s]; + } + } + } + } + let y_coeffs = y_i.coefficients(); + for k in 0..D { + poly[k] = poly[k] - y_coeffs[k]; + } + let mut quotient = vec![F::zero(); D]; + for k in (D..poly_len).rev() { + let q = poly[k]; + quotient[k - D] = q; + poly[k - D] = poly[k - D] - q; + } + let coeffs: [F; D] = std::array::from_fn(|k| quotient[k]); + CyclotomicRing::from_coefficients(coeffs) + }) + .collect() + } + + #[test] + fn compute_r_matches_schoolbook_reference() { + type F = Fp64<4294967197>; + const D: usize = 64; + + let m: Vec>> = (0..3) + .map(|i| { + (0..4) + .map(|j| { + if (i + j) % 3 == 0 { + let mut coeffs = [F::zero(); D]; + coeffs[0] = F::from_u64((i * 5 + j + 1) as u64); + CyclotomicRing::from_coefficients(coeffs) + } else { + let coeffs = std::array::from_fn(|k| { + F::from_u64((i as u64 * 1000 + j as u64 * 100 + k as u64 + 1) % 97) + }); + CyclotomicRing::from_coefficients(coeffs) + } + }) + .collect() + }) + .collect(); + let z: Vec> = (0..4) + .map(|j| { + let coeffs = + std::array::from_fn(|k| F::from_u64((j as u64 * 37 + k as u64 + 5) % 89)); + CyclotomicRing::from_coefficients(coeffs) + }) + .collect(); + let y: Vec> = (0..3) + .map(|i| { + let coeffs = + std::array::from_fn(|k| F::from_u64((i as u64 * 29 + k as u64 + 7) % 83)); + CyclotomicRing::from_coefficients(coeffs) + }) + .collect(); + + let expected = compute_r_schoolbook(&m, &z, &y); + let got = compute_r_via_poly_division::(&m, &z, &y) + .expect("ring-switch CRT+NTT path should dispatch for D=64"); + assert_eq!(got, expected); + } +} diff --git a/src/protocol/sumcheck/batched_sumcheck.rs b/src/protocol/sumcheck/batched_sumcheck.rs new file mode 100644 index 00000000..5caa9810 --- /dev/null +++ b/src/protocol/sumcheck/batched_sumcheck.rs @@ -0,0 +1,349 @@ +//! Batched sumcheck protocol. +//! +//! Implements the standard technique for batching parallel sumchecks to reduce +//! verifier cost and proof size. +//! +//! For details, refer to Jim Posen's ["Perspectives on Sumcheck Batching"](https://hackmd.io/s/HyxaupAAA). +//! We do what they describe as "front-loaded" batch sumcheck. +//! +//! Adapted from Jolt's `BatchedSumcheck` implementation. + +use super::{SumcheckInstanceProver, SumcheckInstanceVerifier, SumcheckProof, UniPoly}; +use crate::error::HachiError; +use crate::protocol::transcript::labels; +use crate::protocol::transcript::Transcript; +use crate::{FieldCore, FromSmallInt}; + +fn mul_pow_2(x: E, k: usize) -> E { + let mut result = x; + for _ in 0..k { + result = result + result; + } + result +} + +fn linear_combination(polys: &[UniPoly], coeffs: &[E]) -> UniPoly { + let max_len = polys.iter().map(|p| p.coeffs.len()).max().unwrap_or(0); + let mut result = vec![E::zero(); max_len]; + for (poly, coeff) in polys.iter().zip(coeffs.iter()) { + for (i, c) in poly.coeffs.iter().enumerate() { + result[i] = result[i] + *c * *coeff; + } + } + UniPoly::from_coeffs(result) +} + +/// Verifier-side output of the batched sumcheck round replay. +/// +/// This carries all transcript-derived values needed for the final oracle check, +/// which is intentionally split out so callers can compute the expected output +/// claim through an external reduction (e.g. Greyhound) before enforcing +/// equality. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct BatchedSumcheckRoundResult { + /// Final claim produced by replaying all sumcheck rounds. + pub output_claim: E, + /// Challenge vector sampled during replay. + pub r_sumcheck: Vec, + /// Front-loaded batching coefficient per verifier instance. + pub batching_coeffs: Vec, + /// Maximum number of rounds among batched instances. + pub max_num_rounds: usize, +} + +/// Produce a batched sumcheck proof for multiple instances sharing the same +/// variable space, driving the Fiat–Shamir transcript. +/// +/// This function: +/// - absorbs each instance's initial claim, +/// - samples batching coefficients (one per instance), +/// - computes a single batched round polynomial per round as a linear +/// combination of the individual round polynomials, +/// - returns a single [`SumcheckProof`] and the derived challenge vector. +/// +/// Instances with fewer rounds than the maximum are padded with constant +/// "dummy" round polynomials (the Jolt "front-loaded" approach). +/// +/// # Panics +/// +/// Panics if `instances` is empty or if 2 is not invertible in the field. +/// +/// # Errors +/// +/// Returns an error if the field inverse of 2 does not exist. +pub fn prove_batched_sumcheck( + mut instances: Vec<&mut dyn SumcheckInstanceProver>, + transcript: &mut T, + mut sample_challenge: S, +) -> Result<(SumcheckProof, Vec), HachiError> +where + F: FieldCore + crate::CanonicalField, + T: Transcript, + E: FieldCore + FromSmallInt, + S: FnMut(&mut T) -> E, +{ + if instances.is_empty() { + return Err(HachiError::InvalidInput( + "no sumcheck instances provided".into(), + )); + } + + let max_num_rounds = instances + .iter() + .map(|inst| inst.num_rounds()) + .max() + .unwrap(); // safe: non-empty checked above + + // Absorb individual input claims. + for inst in instances.iter() { + let claim = inst.input_claim(); + transcript.append_serde(labels::ABSORB_SUMCHECK_CLAIM, &claim); + } + + // Sample one batching coefficient per instance. + let batching_coeffs: Vec = (0..instances.len()) + .map(|_| sample_challenge(transcript)) + .collect(); + + // To see why we may need to scale by a power of two, consider a batch of + // two sumchecks: + // claim_a = \sum_x P(x) where x \in {0, 1}^M + // claim_b = \sum_{x, y} Q(x, y) where x \in {0, 1}^M, y \in {0, 1}^N + // Then the batched sumcheck is: + // \sum_{x, y} A * P(x) + B * Q(x, y) where A and B are batching coefficients + // = A * \sum_y \sum_x P(x) + B * \sum_{x, y} Q(x, y) + // = A * \sum_y claim_a + B * claim_b + // = A * 2^N * claim_a + B * claim_b + let mut individual_claims: Vec = instances + .iter() + .map(|inst| { + let n = inst.num_rounds(); + let claim = inst.input_claim(); + mul_pow_2(claim, max_num_rounds - n) + }) + .collect(); + + let two_inv = E::from_u64(2) + .inv() + .expect("field characteristic 2 not supported"); + + let mut round_polys = Vec::with_capacity(max_num_rounds); + let mut challenges = Vec::with_capacity(max_num_rounds); + + for round in 0..max_num_rounds { + let univariate_polys: Vec> = instances + .iter_mut() + .zip(individual_claims.iter()) + .map(|(inst, previous_claim)| { + let n = inst.num_rounds(); + let offset = max_num_rounds - n; + let active = round >= offset && round < offset + n; + if active { + inst.compute_round_univariate(round - offset, *previous_claim) + } else { + UniPoly::from_coeffs(vec![*previous_claim * two_inv]) + } + }) + .collect(); + + let batched_poly = linear_combination(&univariate_polys, &batching_coeffs); + + #[cfg(debug_assertions)] + { + let g0 = batched_poly.evaluate(&E::zero()); + let g1 = batched_poly.evaluate(&E::one()); + let batched_claim: E = individual_claims + .iter() + .zip(batching_coeffs.iter()) + .map(|(c, b)| *c * *b) + .fold(E::zero(), |a, v| a + v); + debug_assert!( + g0 + g1 == batched_claim, + "round {round}: H(0) + H(1) != batched claim" + ); + } + + let compressed = batched_poly.compress(); + transcript.append_serde(labels::ABSORB_SUMCHECK_ROUND, &compressed); + let r_j = sample_challenge(transcript); + challenges.push(r_j); + + // Update individual claims from each instance's own univariate. + for (claim, poly) in individual_claims.iter_mut().zip(univariate_polys.iter()) { + *claim = poly.evaluate(&r_j); + } + + // Ingest challenge into each active instance. + for inst in instances.iter_mut() { + let n = inst.num_rounds(); + let offset = max_num_rounds - n; + let active = round >= offset && round < offset + n; + if active { + inst.ingest_challenge(round - offset, r_j); + } + } + + round_polys.push(compressed); + } + + for inst in instances.iter_mut() { + inst.finalize(); + } + + Ok((SumcheckProof { round_polys }, challenges)) +} + +/// Verify a batched sumcheck proof. +/// +/// This function: +/// - absorbs each verifier instance's initial claim, +/// - re-derives the batching coefficients, +/// - computes the batched initial claim, +/// - verifies the proof against the batched claim. +/// +/// Returns transcript-derived verifier data for the caller to perform the final +/// expected-output equality check. +/// +/// # Panics +/// +/// Panics if `verifiers` is empty. +/// +/// # Errors +/// +/// Propagates per-round verification errors. +pub fn verify_batched_sumcheck_rounds( + proof: &SumcheckProof, + verifiers: Vec<&dyn SumcheckInstanceVerifier>, + transcript: &mut T, + mut sample_challenge: S, +) -> Result, HachiError> +where + F: FieldCore + crate::CanonicalField, + T: Transcript, + E: FieldCore, + S: FnMut(&mut T) -> E, +{ + if verifiers.is_empty() { + return Err(HachiError::InvalidInput( + "no sumcheck instances provided".into(), + )); + } + + let max_degree = verifiers.iter().map(|v| v.degree_bound()).max().unwrap(); // safe: non-empty + let max_num_rounds = verifiers.iter().map(|v| v.num_rounds()).max().unwrap(); // safe: non-empty + + // Absorb individual input claims. + for v in verifiers.iter() { + let claim = v.input_claim(); + transcript.append_serde(labels::ABSORB_SUMCHECK_CLAIM, &claim); + } + + // Re-derive batching coefficients. + let batching_coeffs: Vec = (0..verifiers.len()) + .map(|_| sample_challenge(transcript)) + .collect(); + + // Compute the combined initial claim with power-of-two scaling. + let batched_claim: E = verifiers + .iter() + .zip(batching_coeffs.iter()) + .map(|(v, coeff)| { + let n = v.num_rounds(); + let claim = v.input_claim(); + mul_pow_2(claim, max_num_rounds - n) * *coeff + }) + .fold(E::zero(), |a, v| a + v); + + let (output_claim, r_sumcheck) = proof.verify::( + batched_claim, + max_num_rounds, + max_degree, + transcript, + &mut sample_challenge, + )?; + + Ok(BatchedSumcheckRoundResult { + output_claim, + r_sumcheck, + batching_coeffs, + max_num_rounds, + }) +} + +/// Compute the expected batched output claim from verifier instances and +/// transcript-derived batching data. +/// +/// # Errors +/// +/// Propagates errors from verifier `expected_output_claim` calls. +pub fn compute_batched_expected_output_claim( + verifiers: Vec<&dyn SumcheckInstanceVerifier>, + batching_coeffs: &[E], + max_num_rounds: usize, + r_sumcheck: &[E], +) -> Result { + let expected_output_claim: E = verifiers + .iter() + .zip(batching_coeffs.iter()) + .map(|(v, coeff)| { + let offset = max_num_rounds - v.num_rounds(); + let r_slice = &r_sumcheck[offset..offset + v.num_rounds()]; + v.expected_output_claim(r_slice).map(|val| val * *coeff) + }) + .try_fold(E::zero(), |a, v| v.map(|val| a + val))?; + + Ok(expected_output_claim) +} + +/// Enforce final batched output-claim equality. +/// +/// # Errors +/// +/// Returns an error if `output_claim != expected_output_claim`. +pub fn check_batched_output_claim( + output_claim: E, + expected_output_claim: E, +) -> Result<(), HachiError> { + if output_claim != expected_output_claim { + return Err(HachiError::InvalidProof); + } + + Ok(()) +} + +/// Verify a batched sumcheck proof, including final expected-output equality. +/// +/// This convenience wrapper preserves the previous behavior. Callers that need +/// to inject an external reduction should use [`verify_batched_sumcheck_rounds`] +/// and [`check_batched_output_claim`] directly. +/// +/// # Errors +/// +/// Propagates errors from round verification and output-claim equality check. +pub fn verify_batched_sumcheck( + proof: &SumcheckProof, + verifiers: Vec<&dyn SumcheckInstanceVerifier>, + transcript: &mut T, + mut sample_challenge: S, +) -> Result, HachiError> +where + F: FieldCore + crate::CanonicalField, + T: Transcript, + E: FieldCore, + S: FnMut(&mut T) -> E, +{ + let round_result = verify_batched_sumcheck_rounds::( + proof, + verifiers.clone(), + transcript, + &mut sample_challenge, + )?; + let expected_output_claim = compute_batched_expected_output_claim( + verifiers, + &round_result.batching_coeffs, + round_result.max_num_rounds, + &round_result.r_sumcheck, + )?; + check_batched_output_claim(round_result.output_claim, expected_output_claim)?; + Ok(round_result.r_sumcheck) +} diff --git a/src/protocol/sumcheck/eq_poly.rs b/src/protocol/sumcheck/eq_poly.rs new file mode 100644 index 00000000..353d0830 --- /dev/null +++ b/src/protocol/sumcheck/eq_poly.rs @@ -0,0 +1,229 @@ +//! Utilities for the equality polynomial `eq(x, y) = Πᵢ (xᵢ yᵢ + (1 − xᵢ)(1 − yᵢ))`. +//! +//! The equality polynomial evaluates to 1 when `x = y` (over the boolean hypercube) +//! and 0 otherwise. Its multilinear extension (MLE) is used throughout sumcheck +//! protocols. +//! +//! Adapted from Jolt's `EqPolynomial` implementation. +//! +//! ## Bit / index order: Little-endian +//! +//! The evaluation tables produced by this module use **little-endian** bit order: +//! entry `b` (as an integer index) corresponds to the boolean vector where +//! bit `k` of `b` equals `x[k]`. In other words, `r[0]` corresponds to the +//! **least-significant bit** (bit 0) and `r[n-1]` to the MSB. + +use crate::FieldCore; +use std::marker::PhantomData; + +/// Utilities for the equality polynomial `eq(x, y) = Πᵢ (xᵢ yᵢ + (1 − xᵢ)(1 − yᵢ))`. +pub struct EqPolynomial(PhantomData); + +impl EqPolynomial { + /// Compute the MLE of the equality polynomial at two points: + /// `eq(x, y) = Πᵢ (xᵢ yᵢ + (1 − xᵢ)(1 − yᵢ))`. + /// + /// # Panics + /// + /// Panics if `x.len() != y.len()`. + pub fn mle(x: &[E], y: &[E]) -> E { + assert_eq!(x.len(), y.len()); + x.iter() + .zip(y.iter()) + .map(|(&x_i, &y_i)| x_i * y_i + (E::one() - x_i) * (E::one() - y_i)) + .fold(E::one(), |acc, v| acc * v) + } + + /// Compute the zero selector: `eq(r, 0) = Πᵢ (1 − rᵢ)`. + pub fn zero_selector(r: &[E]) -> E { + r.iter().fold(E::one(), |acc, &r_i| acc * (E::one() - r_i)) + } + + /// Compute the full evaluation table `{ eq(r, x) : x ∈ {0,1}^n }`. + /// + /// Uses **little-endian** bit order: entry `b` has bit `k` of `b` + /// corresponding to `r[k]`. + /// + /// For a scaled table, use [`Self::evals_with_scaling`]. + pub fn evals(r: &[E]) -> Vec { + Self::evals_with_scaling(r, None) + } + + /// Compute the full evaluation table with optional scaling: + /// `scaling_factor · eq(r, x)` for all `x ∈ {0,1}^n`. + /// + /// Uses the same **little-endian** index order as [`Self::evals`]. + /// If `scaling_factor` is `None`, defaults to 1 (no scaling). + pub fn evals_with_scaling(r: &[E], scaling_factor: Option) -> Vec { + #[cfg(feature = "parallel")] + { + const PARALLEL_THRESHOLD: usize = 16; + if r.len() > PARALLEL_THRESHOLD { + return Self::evals_parallel(r, scaling_factor); + } + } + Self::evals_serial(r, scaling_factor) + } + + /// Serial (single-threaded) version of [`Self::evals_with_scaling`]. + /// + /// Uses **little-endian** index order. + pub fn evals_serial(r: &[E], scaling_factor: Option) -> Vec { + let size = 1usize << r.len(); + let mut evals = vec![E::zero(); size]; + evals[0] = scaling_factor.unwrap_or(E::one()); + let mut len = 1usize; + for &t in r.iter().rev() { + let one_minus_t = E::one() - t; + for j in (0..len).rev() { + evals[2 * j + 1] = evals[j] * t; + evals[2 * j] = evals[j] * one_minus_t; + } + len *= 2; + } + evals + } + + /// Compute eq evaluations and cache intermediate tables. + /// + /// Returns `result` where `result[j]` contains evaluations for the prefix + /// `r[..j]`: `result[j][x] = eq(r[..j], x)` for `x ∈ {0,1}^j`. + /// + /// So `result[0] = [1]`, `result[1]` has 2 entries, ..., and `result[n]` + /// equals [`Self::evals(r)`]. + pub fn evals_cached(r: &[E]) -> Vec> { + Self::evals_cached_with_scaling(r, None) + } + + /// Like [`Self::evals_cached`], but with optional scaling. + pub fn evals_cached_with_scaling(r: &[E], scaling_factor: Option) -> Vec> { + let mut result: Vec> = (0..r.len() + 1).map(|i| vec![E::zero(); 1 << i]).collect(); + result[0][0] = scaling_factor.unwrap_or(E::one()); + for j in 0..r.len() { + let idx = r.len() - 1 - j; + let t = r[idx]; + let one_minus_t = E::one() - t; + let prev_len = 1 << j; + for i in (0..prev_len).rev() { + result[j + 1][2 * i + 1] = result[j][i] * t; + result[j + 1][2 * i] = result[j][i] * one_minus_t; + } + } + result + } + + /// Parallel version of [`Self::evals_with_scaling`]. + /// + /// Uses rayon to compute the largest layers of the DP tree in parallel. + /// Uses the same **little-endian** index order as [`Self::evals`]. + #[cfg(feature = "parallel")] + pub fn evals_parallel(r: &[E], scaling_factor: Option) -> Vec { + use rayon::prelude::*; + + let final_size = 1usize << r.len(); + let mut evals = vec![E::zero(); final_size]; + evals[0] = scaling_factor.unwrap_or(E::one()); + let mut size = 1; + + // Forward iteration (r[0] first) produces little-endian ordering. + for &r_i in r.iter() { + let (evals_left, evals_right) = evals.split_at_mut(size); + let (evals_right, _) = evals_right.split_at_mut(size); + + evals_left + .par_iter_mut() + .zip(evals_right.par_iter_mut()) + .for_each(|(x, y)| { + *y = *x * r_i; + *x = *x - *y; + }); + + size *= 2; + } + + evals + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::algebra::Fp64; + use crate::{FieldSampling, FromSmallInt}; + use rand::rngs::StdRng; + use rand::SeedableRng; + + type F = Fp64<4294967197>; + + #[test] + fn evals_matches_mle_pointwise() { + let mut rng = StdRng::seed_from_u64(0xEE); + for n in 1..8 { + let r: Vec = (0..n).map(|_| F::sample(&mut rng)).collect(); + let table = EqPolynomial::evals(&r); + assert_eq!(table.len(), 1 << n); + for (idx, &val) in table.iter().enumerate() { + let bits: Vec = (0..n) + .map(|k| { + if (idx >> k) & 1 == 1 { + F::one() + } else { + F::zero() + } + }) + .collect(); + let expected = EqPolynomial::mle(&r, &bits); + assert_eq!(val, expected, "n={n} idx={idx}"); + } + } + } + + #[test] + fn evals_with_scaling_scales_uniformly() { + let mut rng = StdRng::seed_from_u64(0xAB); + let r: Vec = (0..5).map(|_| F::sample(&mut rng)).collect(); + let scale = F::from_u64(7); + let unscaled = EqPolynomial::evals(&r); + let scaled = EqPolynomial::evals_with_scaling(&r, Some(scale)); + for (u, s) in unscaled.iter().zip(scaled.iter()) { + assert_eq!(*s, *u * scale); + } + } + + #[test] + fn evals_cached_last_matches_evals() { + let mut rng = StdRng::seed_from_u64(0xCD); + for n in 1..8 { + let r: Vec = (0..n).map(|_| F::sample(&mut rng)).collect(); + let table = EqPolynomial::evals(&r); + let cached = EqPolynomial::evals_cached(&r); + assert_eq!(cached.len(), n + 1); + assert_eq!(cached[0], vec![F::one()]); + assert_eq!(*cached.last().unwrap(), table); + } + } + + #[test] + fn zero_selector_matches_mle_at_origin() { + let mut rng = StdRng::seed_from_u64(0x00); + for n in 1..8 { + let r: Vec = (0..n).map(|_| F::sample(&mut rng)).collect(); + let zeros = vec![F::zero(); n]; + let expected = EqPolynomial::mle(&r, &zeros); + let actual = EqPolynomial::zero_selector(&r); + assert_eq!(actual, expected, "n={n}"); + } + } + + #[cfg(feature = "parallel")] + #[test] + fn evals_parallel_matches_serial() { + let mut rng = StdRng::seed_from_u64(0xFF); + for n in 1..20 { + let r: Vec = (0..n).map(|_| F::sample(&mut rng)).collect(); + let serial = EqPolynomial::evals_serial(&r, None); + let parallel = EqPolynomial::evals_parallel(&r, None); + assert_eq!(serial, parallel, "n={n}"); + } + } +} diff --git a/src/protocol/sumcheck/hachi_sumcheck.rs b/src/protocol/sumcheck/hachi_sumcheck.rs new file mode 100644 index 00000000..f7e1b12a --- /dev/null +++ b/src/protocol/sumcheck/hachi_sumcheck.rs @@ -0,0 +1,548 @@ +//! Fused norm+relation sumcheck prover/verifier for the Hachi PCS. +//! +//! Eliminates the redundant `w_evals` clone by sharing a single `w_table` +//! across both the norm (F_0) and relation (F_α) sumcheck computations. +//! Supports compact `Vec` storage for round 0 (all entries in [-8, 7]), +//! transitioning to `Vec` at half size after the first fold. + +use super::eq_poly::EqPolynomial; +use super::norm_sumcheck::{ + accumulate_affine_range_coeffs, range_check_eval_precomputed, trim_trailing_zeros, + NormRoundKernel, PointEvalPrecomp, RangeAffinePrecomp, +}; +use super::split_eq::GruenSplitEq; +use super::{fold_evals_in_place, multilinear_eval, range_check_eval}; +use super::{SumcheckInstanceProver, SumcheckInstanceVerifier, UniPoly}; +use crate::algebra::CyclotomicRing; +use crate::cfg_into_iter; +use crate::error::HachiError; +#[cfg(feature = "parallel")] +use crate::parallel::*; +use crate::protocol::ring_switch::eval_ring_at; +use crate::{FieldCore, FromSmallInt}; + +enum WTable { + Compact(Vec), + Full(Vec), +} + +/// Fused norm+relation sumcheck prover. +/// +/// Holds a single `w_table` shared by both sumcheck instances, weighted +/// by `batching_coeff`. The round polynomial is +/// `batching_coeff * norm_round(t) + relation_round(t)`. +pub struct HachiSumcheckProver { + w_table: WTable, + batching_coeff: E, + + // Norm state + split_eq: GruenSplitEq, + round_kernel: NormRoundKernel, + point_precomp: Option>, + range_precomp: Option>, + b: usize, + + // Relation state + alpha_table: Vec, + m_table: Vec, + + num_vars: usize, + relation_claim: E, +} + +impl HachiSumcheckProver { + /// Create a fused norm+relation sumcheck prover. + /// + /// # Panics + /// + /// Panics if table sizes are inconsistent with `num_u` and `num_l`. + #[allow(clippy::too_many_arguments)] + pub fn new( + batching_coeff: E, + w_evals_compact: Vec, + tau0: &[E], + b: usize, + alpha_evals_y: &[E], + m_evals_x: &[E], + num_u: usize, + num_l: usize, + ) -> Self { + let num_vars = num_u + num_l; + let n = 1usize << num_vars; + assert_eq!(w_evals_compact.len(), n); + assert_eq!(tau0.len(), num_vars); + assert_eq!(alpha_evals_y.len(), 1 << num_l); + assert_eq!(m_evals_x.len(), 1 << num_u); + + let x_mask = (1usize << num_u) - 1; + let alpha_table: Vec = cfg_into_iter!(0..n) + .map(|idx| alpha_evals_y[idx >> num_u]) + .collect(); + let m_table: Vec = cfg_into_iter!(0..n) + .map(|idx| m_evals_x[idx & x_mask]) + .collect(); + + let relation_claim = + Self::compute_relation_claim_compact(&w_evals_compact, &alpha_table, &m_table); + + let round_kernel = if b <= 8 { + NormRoundKernel::PointEvalInterpolation + } else { + NormRoundKernel::AffineCoeffComposition + }; + let point_precomp = match round_kernel { + NormRoundKernel::PointEvalInterpolation => Some(PointEvalPrecomp::new(b)), + NormRoundKernel::AffineCoeffComposition => None, + }; + let range_precomp = match round_kernel { + NormRoundKernel::PointEvalInterpolation => None, + NormRoundKernel::AffineCoeffComposition => Some(RangeAffinePrecomp::new(b)), + }; + + Self { + w_table: WTable::Compact(w_evals_compact), + batching_coeff, + split_eq: GruenSplitEq::new(tau0), + round_kernel, + point_precomp, + range_precomp, + b, + alpha_table, + m_table, + num_vars, + relation_claim, + } + } + + fn compute_relation_claim_compact(w_compact: &[i8], alpha_table: &[E], m_table: &[E]) -> E { + w_compact + .iter() + .zip(alpha_table.iter()) + .zip(m_table.iter()) + .fold(E::zero(), |acc, ((&w, &a), &m)| { + acc + E::from_i64(w as i64) * a * m + }) + } + + fn lift_i8(v: i8) -> E { + E::from_i64(v as i64) + } + + fn compute_round_norm_compact(&self, w_compact: &[i8]) -> UniPoly { + let half = w_compact.len() / 2; + + let (e_first, e_second) = self.split_eq.remaining_eq_tables(); + let num_first = e_first.len(); + let first_bits = num_first.trailing_zeros(); + + match self.round_kernel { + NormRoundKernel::PointEvalInterpolation => { + let degree_q = 2 * self.b - 1; + let num_points_q = degree_q + 1; + let range_offsets = &self.point_precomp.as_ref().unwrap().range_offsets; + + let mut q_evals = vec![E::zero(); num_points_q]; + for j in 0..half { + let j_low = j & (num_first - 1); + let j_high = j >> first_bits; + let eq_rem = e_first[j_low] * e_second[j_high]; + let w_0 = Self::lift_i8(w_compact[2 * j]); + let w_1 = Self::lift_i8(w_compact[2 * j + 1]); + let delta = w_1 - w_0; + let mut w_t = w_0; + for eval in q_evals.iter_mut() { + *eval = *eval + eq_rem * range_check_eval_precomputed(w_t, range_offsets); + w_t = w_t + delta; + } + } + let q_poly = UniPoly::from_evals(&q_evals); + self.split_eq.gruen_mul(&q_poly) + } + NormRoundKernel::AffineCoeffComposition => { + let range_precomp = self.range_precomp.as_ref().unwrap(); + let num_coeffs_q = range_precomp.degree_q + 1; + let coeff_mix = &range_precomp.coeff_mix; + + let mut q_coeffs = vec![E::zero(); num_coeffs_q]; + for j in 0..half { + let j_low = j & (num_first - 1); + let j_high = j >> first_bits; + let eq_rem = e_first[j_low] * e_second[j_high]; + let w_0 = Self::lift_i8(w_compact[2 * j]); + let w_1 = Self::lift_i8(w_compact[2 * j + 1]); + let a = w_1 - w_0; + accumulate_affine_range_coeffs(&mut q_coeffs, coeff_mix, w_0, a, eq_rem); + } + trim_trailing_zeros(&mut q_coeffs); + let q_poly = UniPoly::from_coeffs(q_coeffs); + self.split_eq.gruen_mul(&q_poly) + } + } + } + + fn compute_round_relation_compact(&self, w_compact: &[i8]) -> UniPoly { + let half = w_compact.len() / 2; + let num_points = 3; + + let mut evals = vec![E::zero(); num_points]; + for j in 0..half { + let w_0 = Self::lift_i8(w_compact[2 * j]); + let w_1 = Self::lift_i8(w_compact[2 * j + 1]); + let a_0 = self.alpha_table[2 * j]; + let a_1 = self.alpha_table[2 * j + 1]; + let m_0 = self.m_table[2 * j]; + let m_1 = self.m_table[2 * j + 1]; + for (t, eval) in evals.iter_mut().enumerate() { + let t_e = E::from_u64(t as u64); + let w_t = w_0 + t_e * (w_1 - w_0); + let a_t = a_0 + t_e * (a_1 - a_0); + let m_t = m_0 + t_e * (m_1 - m_0); + *eval = *eval + w_t * a_t * m_t; + } + } + UniPoly::from_evals(&evals) + } + + fn compute_round_norm_full(&self, w_full: &[E]) -> UniPoly { + let half = w_full.len() / 2; + + let (e_first, e_second) = self.split_eq.remaining_eq_tables(); + let num_first = e_first.len(); + let first_bits = num_first.trailing_zeros(); + + match self.round_kernel { + NormRoundKernel::PointEvalInterpolation => { + let degree_q = 2 * self.b - 1; + let num_points_q = degree_q + 1; + let range_offsets = &self.point_precomp.as_ref().unwrap().range_offsets; + + #[cfg(feature = "parallel")] + let q_evals = { + (0..half) + .into_par_iter() + .fold( + || vec![E::zero(); num_points_q], + |mut evals, j| { + let j_low = j & (num_first - 1); + let j_high = j >> first_bits; + let eq_rem = e_first[j_low] * e_second[j_high]; + let w_0 = w_full[2 * j]; + let w_1 = w_full[2 * j + 1]; + let delta = w_1 - w_0; + let mut w_t = w_0; + for eval in evals.iter_mut() { + *eval = *eval + + eq_rem * range_check_eval_precomputed(w_t, range_offsets); + w_t = w_t + delta; + } + evals + }, + ) + .reduce( + || vec![E::zero(); num_points_q], + |mut a, b_vec| { + for (ai, bi) in a.iter_mut().zip(b_vec.iter()) { + *ai = *ai + *bi; + } + a + }, + ) + }; + #[cfg(not(feature = "parallel"))] + let q_evals = { + let mut evals = vec![E::zero(); num_points_q]; + for j in 0..half { + let j_low = j & (num_first - 1); + let j_high = j >> first_bits; + let eq_rem = e_first[j_low] * e_second[j_high]; + let w_0 = w_full[2 * j]; + let w_1 = w_full[2 * j + 1]; + let delta = w_1 - w_0; + let mut w_t = w_0; + for eval in evals.iter_mut() { + *eval = + *eval + eq_rem * range_check_eval_precomputed(w_t, range_offsets); + w_t = w_t + delta; + } + } + evals + }; + + let q_poly = UniPoly::from_evals(&q_evals); + self.split_eq.gruen_mul(&q_poly) + } + NormRoundKernel::AffineCoeffComposition => { + let range_precomp = self.range_precomp.as_ref().unwrap(); + let num_coeffs_q = range_precomp.degree_q + 1; + let coeff_mix = &range_precomp.coeff_mix; + + #[cfg(feature = "parallel")] + let mut q_coeffs = { + (0..half) + .into_par_iter() + .fold( + || vec![E::zero(); num_coeffs_q], + |mut coeffs, j| { + let j_low = j & (num_first - 1); + let j_high = j >> first_bits; + let eq_rem = e_first[j_low] * e_second[j_high]; + let w_0 = w_full[2 * j]; + let w_1 = w_full[2 * j + 1]; + let a = w_1 - w_0; + accumulate_affine_range_coeffs( + &mut coeffs, + coeff_mix, + w_0, + a, + eq_rem, + ); + coeffs + }, + ) + .reduce( + || vec![E::zero(); num_coeffs_q], + |mut a, b_vec| { + for (ai, bi) in a.iter_mut().zip(b_vec.iter()) { + *ai = *ai + *bi; + } + a + }, + ) + }; + #[cfg(not(feature = "parallel"))] + let mut q_coeffs = { + let mut coeffs = vec![E::zero(); num_coeffs_q]; + for j in 0..half { + let j_low = j & (num_first - 1); + let j_high = j >> first_bits; + let eq_rem = e_first[j_low] * e_second[j_high]; + let w_0 = w_full[2 * j]; + let w_1 = w_full[2 * j + 1]; + let a = w_1 - w_0; + accumulate_affine_range_coeffs(&mut coeffs, coeff_mix, w_0, a, eq_rem); + } + coeffs + }; + + trim_trailing_zeros(&mut q_coeffs); + let q_poly = UniPoly::from_coeffs(q_coeffs); + self.split_eq.gruen_mul(&q_poly) + } + } + } + + fn compute_round_relation_full(&self, w_full: &[E]) -> UniPoly { + let half = w_full.len() / 2; + let num_points = 3; + + #[cfg(feature = "parallel")] + let round_evals = { + (0..half) + .into_par_iter() + .fold( + || vec![E::zero(); num_points], + |mut evals, j| { + let w_0 = w_full[2 * j]; + let w_1 = w_full[2 * j + 1]; + let a_0 = self.alpha_table[2 * j]; + let a_1 = self.alpha_table[2 * j + 1]; + let m_0 = self.m_table[2 * j]; + let m_1 = self.m_table[2 * j + 1]; + for (t, eval) in evals.iter_mut().enumerate() { + let t_e = E::from_u64(t as u64); + let w_t = w_0 + t_e * (w_1 - w_0); + let a_t = a_0 + t_e * (a_1 - a_0); + let m_t = m_0 + t_e * (m_1 - m_0); + *eval = *eval + w_t * a_t * m_t; + } + evals + }, + ) + .reduce( + || vec![E::zero(); num_points], + |mut a, b| { + for (ai, bi) in a.iter_mut().zip(b.iter()) { + *ai = *ai + *bi; + } + a + }, + ) + }; + #[cfg(not(feature = "parallel"))] + let round_evals = { + let mut evals = vec![E::zero(); num_points]; + for j in 0..half { + let w_0 = w_full[2 * j]; + let w_1 = w_full[2 * j + 1]; + let a_0 = self.alpha_table[2 * j]; + let a_1 = self.alpha_table[2 * j + 1]; + let m_0 = self.m_table[2 * j]; + let m_1 = self.m_table[2 * j + 1]; + for (t, eval) in evals.iter_mut().enumerate() { + let t_e = E::from_u64(t as u64); + let w_t = w_0 + t_e * (w_1 - w_0); + let a_t = a_0 + t_e * (a_1 - a_0); + let m_t = m_0 + t_e * (m_1 - m_0); + *eval = *eval + w_t * a_t * m_t; + } + } + evals + }; + + UniPoly::from_evals(&round_evals) + } + + fn fold_compact_to_full(w_compact: &[i8], r: E) -> Vec { + let half = w_compact.len() / 2; + let mut out = Vec::with_capacity(half); + for j in 0..half { + let w_0 = Self::lift_i8(w_compact[2 * j]); + let w_1 = Self::lift_i8(w_compact[2 * j + 1]); + out.push(w_0 + r * (w_1 - w_0)); + } + out + } +} + +impl SumcheckInstanceProver for HachiSumcheckProver { + fn num_rounds(&self) -> usize { + self.num_vars + } + + fn degree_bound(&self) -> usize { + 2 * self.b + } + + fn input_claim(&self) -> E { + self.relation_claim + } + + fn compute_round_univariate(&mut self, _round: usize, _previous_claim: E) -> UniPoly { + let (norm_poly, relation_poly) = match &self.w_table { + WTable::Compact(w_compact) => { + let norm = self.compute_round_norm_compact(w_compact); + let relation = self.compute_round_relation_compact(w_compact); + (norm, relation) + } + WTable::Full(w_full) => { + let norm = self.compute_round_norm_full(w_full); + let relation = self.compute_round_relation_full(w_full); + (norm, relation) + } + }; + + let max_len = norm_poly.coeffs.len().max(relation_poly.coeffs.len()); + let mut combined = vec![E::zero(); max_len]; + for (i, c) in norm_poly.coeffs.iter().enumerate() { + combined[i] = combined[i] + self.batching_coeff * *c; + } + for (i, c) in relation_poly.coeffs.iter().enumerate() { + combined[i] = combined[i] + *c; + } + UniPoly::from_coeffs(combined) + } + + fn ingest_challenge(&mut self, _round: usize, r: E) { + self.split_eq.bind(r); + + self.w_table = match std::mem::replace(&mut self.w_table, WTable::Full(Vec::new())) { + WTable::Compact(w_compact) => WTable::Full(Self::fold_compact_to_full(&w_compact, r)), + WTable::Full(mut w_full) => { + fold_evals_in_place(&mut w_full, r); + WTable::Full(w_full) + } + }; + + fold_evals_in_place(&mut self.alpha_table, r); + fold_evals_in_place(&mut self.m_table, r); + } +} + +/// Fused norm+relation sumcheck verifier. +pub struct HachiSumcheckVerifier { + batching_coeff: F, + w_evals: Vec, + tau0: Vec, + b: usize, + alpha_evals_y: Vec, + m_evals_x: Vec, + num_u: usize, + num_l: usize, + relation_claim: F, + _marker: std::marker::PhantomData<[F; D]>, +} + +impl HachiSumcheckVerifier { + /// Create a fused verifier for the norm + relation sumcheck. + #[allow(clippy::too_many_arguments)] + pub fn new( + batching_coeff: F, + w_evals: Vec, + tau0: Vec, + b: usize, + alpha_evals_y: Vec, + m_evals_x: Vec, + tau1: Vec, + v: Vec>, + u: Vec>, + y_ring: CyclotomicRing, + alpha: F, + num_u: usize, + num_l: usize, + ) -> Self { + let y_a: Vec = v + .iter() + .chain(u.iter()) + .chain(std::iter::once(&y_ring)) + .map(|r| eval_ring_at(r, &alpha)) + .collect(); + let eq_tau1 = EqPolynomial::evals(&tau1); + let mut relation_claim = F::zero(); + for (i, eq_i) in eq_tau1.iter().enumerate() { + let y_i = if i < y_a.len() { y_a[i] } else { F::zero() }; + relation_claim = relation_claim + *eq_i * y_i; + } + + Self { + batching_coeff, + w_evals, + tau0, + b, + alpha_evals_y, + m_evals_x, + num_u, + num_l, + relation_claim, + _marker: std::marker::PhantomData, + } + } +} + +impl SumcheckInstanceVerifier + for HachiSumcheckVerifier +{ + fn num_rounds(&self) -> usize { + self.num_u + self.num_l + } + + fn degree_bound(&self) -> usize { + 2 * self.b + } + + fn input_claim(&self) -> F { + self.relation_claim + } + + fn expected_output_claim(&self, challenges: &[F]) -> Result { + let eq_val = EqPolynomial::mle(&self.tau0, challenges); + let w_val = multilinear_eval(&self.w_evals, challenges)?; + let norm_oracle = eq_val * range_check_eval(w_val, self.b); + + let (x_challenges, y_challenges) = challenges.split_at(self.num_u); + let alpha_val = multilinear_eval(&self.alpha_evals_y, y_challenges)?; + let m_val = multilinear_eval(&self.m_evals_x, x_challenges)?; + let relation_oracle = w_val * alpha_val * m_val; + + Ok(self.batching_coeff * norm_oracle + relation_oracle) + } +} diff --git a/src/protocol/sumcheck/mod.rs b/src/protocol/sumcheck/mod.rs new file mode 100644 index 00000000..dc824bba --- /dev/null +++ b/src/protocol/sumcheck/mod.rs @@ -0,0 +1,196 @@ +//! Sumcheck protocol: traits, proof driver, and concrete instances. +//! +//! Types (`UniPoly`, `CompressedUniPoly`, `SumcheckProof`) live in the +//! [`types`] submodule. Polynomial evaluation utilities (`multilinear_eval`, +//! `fold_evals_in_place`, `range_check_eval`) live in [`crate::algebra::poly`]. +//! +//! ## Temporary duplication notice (Jolt integration) +//! +//! Jolt already has a mature, streaming-aware sumcheck implementation. Long-term, we +//! expect to extract the common sumcheck machinery into a dedicated crate and depend +//! on it from both Hachi and Jolt. Until that exists, this module intentionally +//! duplicates the essential sumcheck data types and transcript-driving logic as a +//! pragmatic workaround. + +pub mod batched_sumcheck; +pub mod eq_poly; +pub mod hachi_sumcheck; +pub mod norm_sumcheck; +pub mod relation_sumcheck; +pub mod split_eq; +pub mod types; + +use crate::error::HachiError; +use crate::protocol::transcript::labels; +use crate::protocol::transcript::Transcript; +use crate::FieldCore; + +pub use crate::algebra::poly::{fold_evals_in_place, multilinear_eval, range_check_eval}; +pub use types::{CompressedUniPoly, SumcheckProof, UniPoly}; + +/// Prover-side sumcheck instance interface. +/// +/// This trait encapsulates the protocol-specific logic required to compute each +/// per-round univariate polynomial `g_j(X)` and to update (fold) internal state +/// after receiving the verifier challenge `r_j`. +/// +/// Hachi §4.3 will implement concrete instances for `H_0` and `H_α`. +pub trait SumcheckInstanceProver: Send + Sync { + /// Number of rounds (i.e. number of variables bound by sumcheck). + fn num_rounds(&self) -> usize; + + /// Maximum allowed degree for any round univariate polynomial. + fn degree_bound(&self) -> usize; + + /// The initial claimed sum that this sumcheck instance is proving. + fn input_claim(&self) -> E; + + /// Compute the prover message `g_round(X)` given the previous running claim. + /// + /// In standard sumcheck, `previous_claim` is the expected value of the + /// remaining sum after binding previous challenges, and must satisfy: + /// + /// `g_round(0) + g_round(1) == previous_claim`. + fn compute_round_univariate(&mut self, round: usize, previous_claim: E) -> UniPoly; + + /// Ingest the verifier challenge `r_round` to fold/bind the current variable. + fn ingest_challenge(&mut self, round: usize, r_round: E); + + /// Optional end-of-protocol hook after the last challenge has been ingested. + fn finalize(&mut self) {} +} + +/// Verifier-side sumcheck instance interface. +/// +/// Implementations provide the initial claim and the oracle evaluation at the +/// challenge point, enabling the verifier to perform the final consistency check. +pub trait SumcheckInstanceVerifier: Send + Sync { + /// Number of rounds (i.e. number of variables bound by sumcheck). + fn num_rounds(&self) -> usize; + + /// Maximum allowed degree for any round univariate polynomial. + fn degree_bound(&self) -> usize; + + /// The initial claimed sum that this sumcheck instance is proving. + fn input_claim(&self) -> E; + + /// Compute the expected final evaluation `f(r_0, ..., r_{n-1})` at the + /// challenge point derived during the protocol. + /// + /// # Errors + /// + /// May return an error if internal evaluations fail (e.g., malformed + /// evaluation tables from untrusted proof data). + fn expected_output_claim(&self, challenges: &[E]) -> Result; +} + +/// Produce a sumcheck proof for a single instance, driving the Fiat-Shamir transcript. +/// +/// This method: +/// - does **not** absorb the initial claim into the transcript (callers should do so), +/// - appends each round message under `labels::ABSORB_SUMCHECK_ROUND`, +/// - samples one challenge per round via `sample_challenge`, +/// - updates the running claim using the per-round hint (`g(0)+g(1)`). +/// +/// It returns the proof, the derived point `r`, and the final claimed value at `r`. +/// +/// # Errors +/// +/// Returns an error if any per-round polynomial exceeds the instance's degree bound. +pub fn prove_sumcheck( + instance: &mut Inst, + transcript: &mut T, + mut sample_challenge: S, +) -> Result<(SumcheckProof, Vec, E), HachiError> +where + F: crate::FieldCore + crate::CanonicalField, + T: Transcript, + E: FieldCore, + S: FnMut(&mut T) -> E, + Inst: SumcheckInstanceProver, +{ + let mut claim = instance.input_claim(); + transcript.append_serde(labels::ABSORB_SUMCHECK_CLAIM, &claim); + + let num_rounds = instance.num_rounds(); + let degree_bound = instance.degree_bound(); + + let mut round_polys = Vec::with_capacity(num_rounds); + let mut r = Vec::with_capacity(num_rounds); + + for round in 0..num_rounds { + let g = instance.compute_round_univariate(round, claim); + debug_assert!( + g.evaluate(&E::zero()) + g.evaluate(&E::one()) == claim, + "sumcheck round univariate does not match previous claim hint" + ); + + let compressed = g.compress(); + if compressed.degree() > degree_bound { + return Err(HachiError::InvalidInput(format!( + "sumcheck round poly degree {} exceeds bound {}", + compressed.degree(), + degree_bound + ))); + } + + transcript.append_serde(labels::ABSORB_SUMCHECK_ROUND, &compressed); + let r_i = sample_challenge(transcript); + r.push(r_i); + + claim = compressed.eval_from_hint(&claim, &r_i); + + instance.ingest_challenge(round, r_i); + round_polys.push(compressed); + } + + instance.finalize(); + + let proof = SumcheckProof { round_polys }; + Ok((proof, r, claim)) +} + +/// Verify a single-instance sumcheck proof. +/// +/// This function: +/// - absorbs the initial claim into the transcript, +/// - delegates round-by-round verification to [`SumcheckProof::verify`], +/// - performs the final oracle check: `final_claim == verifier.expected_output_claim(r)`. +/// +/// Returns the challenge point `r` on success. +/// +/// # Errors +/// +/// Returns [`HachiError::InvalidProof`] if the final sumcheck claim does not +/// match the oracle evaluation, or propagates any error from the per-round +/// verification (e.g. degree-bound violation, round-count mismatch). +pub fn verify_sumcheck( + proof: &SumcheckProof, + verifier: &V, + transcript: &mut T, + sample_challenge: S, +) -> Result, HachiError> +where + F: crate::FieldCore + crate::CanonicalField, + T: Transcript, + E: FieldCore, + S: FnMut(&mut T) -> E, + V: SumcheckInstanceVerifier, +{ + let claim = verifier.input_claim(); + transcript.append_serde(labels::ABSORB_SUMCHECK_CLAIM, &claim); + let (final_claim, challenges) = proof.verify::( + claim, + verifier.num_rounds(), + verifier.degree_bound(), + transcript, + sample_challenge, + )?; + + let expected = verifier.expected_output_claim(&challenges)?; + if final_claim != expected { + return Err(HachiError::InvalidProof); + } + + Ok(challenges) +} diff --git a/src/protocol/sumcheck/norm_sumcheck.rs b/src/protocol/sumcheck/norm_sumcheck.rs new file mode 100644 index 00000000..db4625cf --- /dev/null +++ b/src/protocol/sumcheck/norm_sumcheck.rs @@ -0,0 +1,750 @@ +//! Norm (range-check) sumcheck instance (F_0). +//! +//! **F_{0,τ₀}(x, y)** = ẽq(τ₀,(x,y)) · w̃(x,y) · (w̃−1)(w̃+1)···(w̃−b+1)(w̃+b−1) +//! +//! Proves that all entries of w̃ lie in {−(b−1), …, b−1}; the sum over the +//! boolean hypercube should equal zero when the range constraint holds. + +use super::eq_poly::EqPolynomial; +use super::split_eq::GruenSplitEq; +use super::{fold_evals_in_place, multilinear_eval, range_check_eval}; +use super::{SumcheckInstanceProver, SumcheckInstanceVerifier, UniPoly}; +use crate::error::HachiError; +#[cfg(feature = "parallel")] +use crate::parallel::*; +use crate::{FieldCore, FromSmallInt}; + +const SMALL_B_POINT_EVAL_MAX: usize = 8; + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub(crate) enum NormRoundKernel { + PointEvalInterpolation, + AffineCoeffComposition, +} + +fn choose_round_kernel(b: usize) -> NormRoundKernel { + if b <= SMALL_B_POINT_EVAL_MAX { + NormRoundKernel::PointEvalInterpolation + } else { + NormRoundKernel::AffineCoeffComposition + } +} + +#[derive(Clone)] +pub(crate) struct RangeAffinePrecomp { + /// `coeff_mix[i][k] = c_{i+k} * binom(i+k, i)`, where + /// `R(w) = sum_m c_m * w^m` is the range-check polynomial. + pub(crate) coeff_mix: Vec>, + pub(crate) degree_q: usize, +} + +impl RangeAffinePrecomp { + pub(crate) fn new(b: usize) -> Self { + assert!(b >= 1, "b must be at least 1"); + let range_coeffs = range_check_coeffs::(b); + let degree_q = range_coeffs.len() - 1; + let small_scalars: Vec = (0..=degree_q + 1).map(|x| E::from_u64(x as u64)).collect(); + let inv_small_scalars: Vec = (0..=degree_q + 1) + .map(|x| { + if x == 0 { + E::zero() + } else { + small_scalars[x] + .inv() + .expect("field characteristic too small for range-check precomputation") + } + }) + .collect(); + let mut coeff_mix = Vec::with_capacity(degree_q + 1); + + for i in 0..=degree_q { + let row_len = degree_q - i + 1; + let mut row = Vec::with_capacity(row_len); + let mut binom_m_i = E::one(); // binom(i, i) + for k in 0..row_len { + let m = i + k; + row.push(range_coeffs[m] * binom_m_i); + if k + 1 < row_len { + let numer = small_scalars[m + 1]; + let denom_inv = inv_small_scalars[k + 1]; + binom_m_i = binom_m_i * numer * denom_inv; + } + } + coeff_mix.push(row); + } + + Self { + coeff_mix, + degree_q, + } + } +} + +#[derive(Clone)] +pub(crate) struct PointEvalPrecomp { + pub(crate) range_offsets: Vec, +} + +impl PointEvalPrecomp { + pub(crate) fn new(b: usize) -> Self { + let range_offsets = (1..b).map(|k| E::from_u64(k as u64)).collect(); + Self { range_offsets } + } +} + +/// Coefficients of `R(w) = w * Π_{k=1}^{b-1}(w-k)(w+k)` in increasing degree order. +fn range_check_coeffs(b: usize) -> Vec { + assert!(b >= 1, "b must be at least 1"); + let mut coeffs = vec![E::zero(), E::one()]; // R(w)=w when b=1 + for k in 1..b { + let k_e = E::from_u64(k as u64); + let k_sq = k_e * k_e; + // Multiply by (w^2 - k^2). + let mut next = vec![E::zero(); coeffs.len() + 2]; + for (idx, c) in coeffs.iter().enumerate() { + next[idx] = next[idx] - *c * k_sq; + next[idx + 2] = next[idx + 2] + *c; + } + coeffs = next; + } + coeffs +} + +pub(crate) fn range_check_eval_precomputed(w: E, range_offsets: &[E]) -> E { + let mut acc = w; + for &k in range_offsets { + acc = acc * (w - k) * (w + k); + } + acc +} + +pub(crate) fn accumulate_affine_range_coeffs( + out_coeffs: &mut [E], + coeff_mix: &[Vec], + w_0: E, + a: E, + scale: E, +) { + let mut a_pow = E::one(); + for (i, row) in coeff_mix.iter().enumerate() { + let mut h_i_w0 = E::zero(); + for coeff in row.iter().rev() { + h_i_w0 = h_i_w0 * w_0 + *coeff; + } + out_coeffs[i] = out_coeffs[i] + scale * a_pow * h_i_w0; + a_pow = a_pow * a; + } +} + +pub(crate) fn trim_trailing_zeros(coeffs: &mut Vec) { + while coeffs.len() > 1 && coeffs.last().is_some_and(|c| c.is_zero()) { + coeffs.pop(); + } +} + +/// Prover for `F_{0,τ₀}(x,y) = ẽq(τ₀,(x,y)) · w̃(x,y) · range_check(w̃(x,y), b)`. +/// +/// Uses the Gruen/Dao-Thaler optimization: the eq polynomial is factored into +/// a running scalar and split tables instead of being stored as a full table +/// and folded each round. The round polynomial is computed as `l(X) · q(X)` +/// where `l(X)` is the linear eq factor and `q(X)` is the inner sum without +/// the current-variable eq contribution. +pub struct NormSumcheckProver { + split_eq: GruenSplitEq, + w_table: Vec, + round_kernel: NormRoundKernel, + point_precomp: Option>, + range_precomp: Option>, + num_vars: usize, + b: usize, +} + +impl NormSumcheckProver { + /// Create a new norm (range-check) sumcheck prover. + /// + /// # Panics + /// + /// Panics if `w_evals.len() != 2^tau.len()`. + pub fn new(tau: &[E], w_evals: Vec, b: usize) -> Self { + Self::new_with_kernel(tau, w_evals, b, choose_round_kernel(b)) + } + + fn new_with_kernel( + tau: &[E], + w_evals: Vec, + b: usize, + round_kernel: NormRoundKernel, + ) -> Self { + let num_vars = tau.len(); + assert_eq!(w_evals.len(), 1 << num_vars); + let point_precomp = match round_kernel { + NormRoundKernel::PointEvalInterpolation => Some(PointEvalPrecomp::new(b)), + NormRoundKernel::AffineCoeffComposition => None, + }; + let range_precomp = match round_kernel { + NormRoundKernel::PointEvalInterpolation => None, + NormRoundKernel::AffineCoeffComposition => Some(RangeAffinePrecomp::new(b)), + }; + Self { + split_eq: GruenSplitEq::new(tau), + w_table: w_evals, + round_kernel, + point_precomp, + range_precomp, + num_vars, + b, + } + } + + fn compute_round_univariate_point_eval(&self) -> UniPoly { + let half = self.w_table.len() / 2; + let degree_q = 2 * self.b - 1; + let num_points_q = degree_q + 1; + let point_precomp = self + .point_precomp + .as_ref() + .expect("point-eval precomputation must exist"); + let range_offsets = &point_precomp.range_offsets; + + let (e_first, e_second) = self.split_eq.remaining_eq_tables(); + let num_first = e_first.len(); + let first_bits = num_first.trailing_zeros(); + + #[cfg(feature = "parallel")] + let q_evals = { + (0..half) + .into_par_iter() + .fold( + || vec![E::zero(); num_points_q], + |mut evals, j| { + let j_low = j & (num_first - 1); + let j_high = j >> first_bits; + let eq_rem = e_first[j_low] * e_second[j_high]; + let w_0 = self.w_table[2 * j]; + let w_1 = self.w_table[2 * j + 1]; + let delta = w_1 - w_0; + let mut w_t = w_0; + for eval in evals.iter_mut() { + *eval = + *eval + eq_rem * range_check_eval_precomputed(w_t, range_offsets); + w_t = w_t + delta; + } + evals + }, + ) + .reduce( + || vec![E::zero(); num_points_q], + |mut a, b_vec| { + for (ai, bi) in a.iter_mut().zip(b_vec.iter()) { + *ai = *ai + *bi; + } + a + }, + ) + }; + #[cfg(not(feature = "parallel"))] + let q_evals = { + let mut evals = vec![E::zero(); num_points_q]; + for j in 0..half { + let j_low = j & (num_first - 1); + let j_high = j >> first_bits; + let eq_rem = e_first[j_low] * e_second[j_high]; + let w_0 = self.w_table[2 * j]; + let w_1 = self.w_table[2 * j + 1]; + let delta = w_1 - w_0; + let mut w_t = w_0; + for eval in evals.iter_mut() { + *eval = *eval + eq_rem * range_check_eval_precomputed(w_t, range_offsets); + w_t = w_t + delta; + } + } + evals + }; + + let q_poly = UniPoly::from_evals(&q_evals); + self.split_eq.gruen_mul(&q_poly) + } + + fn compute_round_univariate_affine_coeff(&self) -> UniPoly { + let half = self.w_table.len() / 2; + let range_precomp = self + .range_precomp + .as_ref() + .expect("affine-coeff precomputation must exist"); + let num_coeffs_q = range_precomp.degree_q + 1; + + let (e_first, e_second) = self.split_eq.remaining_eq_tables(); + let num_first = e_first.len(); + let first_bits = num_first.trailing_zeros(); + let coeff_mix = &range_precomp.coeff_mix; + + #[cfg(feature = "parallel")] + let q_coeffs = { + (0..half) + .into_par_iter() + .fold( + || vec![E::zero(); num_coeffs_q], + |mut coeffs, j| { + let j_low = j & (num_first - 1); + let j_high = j >> first_bits; + let eq_rem = e_first[j_low] * e_second[j_high]; + let w_0 = self.w_table[2 * j]; + let w_1 = self.w_table[2 * j + 1]; + let a = w_1 - w_0; + accumulate_affine_range_coeffs(&mut coeffs, coeff_mix, w_0, a, eq_rem); + coeffs + }, + ) + .reduce( + || vec![E::zero(); num_coeffs_q], + |mut a, b_vec| { + for (ai, bi) in a.iter_mut().zip(b_vec.iter()) { + *ai = *ai + *bi; + } + a + }, + ) + }; + #[cfg(not(feature = "parallel"))] + let q_coeffs = { + let mut coeffs = vec![E::zero(); num_coeffs_q]; + for j in 0..half { + let j_low = j & (num_first - 1); + let j_high = j >> first_bits; + let eq_rem = e_first[j_low] * e_second[j_high]; + let w_0 = self.w_table[2 * j]; + let w_1 = self.w_table[2 * j + 1]; + let a = w_1 - w_0; + accumulate_affine_range_coeffs(&mut coeffs, coeff_mix, w_0, a, eq_rem); + } + coeffs + }; + + let mut q_coeffs = q_coeffs; + trim_trailing_zeros(&mut q_coeffs); + let q_poly = UniPoly::from_coeffs(q_coeffs); + self.split_eq.gruen_mul(&q_poly) + } +} + +impl SumcheckInstanceProver for NormSumcheckProver { + fn num_rounds(&self) -> usize { + self.num_vars + } + + fn degree_bound(&self) -> usize { + 2 * self.b + } + + fn input_claim(&self) -> E { + E::zero() + } + + fn compute_round_univariate(&mut self, _round: usize, _previous_claim: E) -> UniPoly { + match self.round_kernel { + NormRoundKernel::PointEvalInterpolation => self.compute_round_univariate_point_eval(), + NormRoundKernel::AffineCoeffComposition => self.compute_round_univariate_affine_coeff(), + } + } + + fn ingest_challenge(&mut self, _round: usize, r: E) { + self.split_eq.bind(r); + fold_evals_in_place(&mut self.w_table, r); + } +} + +/// Verifier for the norm (range-check) sumcheck `F_{0,τ₀}`. +pub struct NormSumcheckVerifier { + tau: Vec, + w_evals: Vec, + num_vars: usize, + b: usize, +} + +impl NormSumcheckVerifier { + /// Create a new norm (range-check) sumcheck verifier. + /// + /// # Panics + /// + /// Panics if `w_evals.len() != 2^tau.len()`. + pub fn new(tau: Vec, w_evals: Vec, b: usize) -> Self { + let num_vars = tau.len(); + assert_eq!(w_evals.len(), 1 << num_vars); + Self { + tau, + w_evals, + num_vars, + b, + } + } +} + +impl SumcheckInstanceVerifier for NormSumcheckVerifier { + fn num_rounds(&self) -> usize { + self.num_vars + } + + fn degree_bound(&self) -> usize { + 2 * self.b + } + + fn input_claim(&self) -> E { + E::zero() + } + + fn expected_output_claim(&self, challenges: &[E]) -> Result { + let eq_val = EqPolynomial::mle(&self.tau, challenges); + let w_val = multilinear_eval(&self.w_evals, challenges)?; + Ok(eq_val * range_check_eval(w_val, self.b)) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::algebra::ring::CyclotomicRing; + use crate::algebra::Fp64; + use crate::primitives::multilinear_evals::DenseMultilinearEvals; + use crate::protocol::ring_switch::r_decomp_levels; + use crate::protocol::sumcheck::eq_poly::EqPolynomial; + use crate::protocol::sumcheck::multilinear_eval; + use crate::protocol::transcript::labels; + use crate::protocol::{ + prove_sumcheck, verify_sumcheck, Blake2bTranscript, CommitmentConfig, CommitmentScheme, + HachiCommitmentScheme, SmallTestCommitmentConfig, Transcript, + }; + use crate::{FieldCore, FromSmallInt}; + use rand::rngs::StdRng; + use rand::SeedableRng; + + type F = Fp64<4294967197>; + const D: usize = 8; + type Cfg = SmallTestCommitmentConfig; + type Scheme = HachiCommitmentScheme<{ Cfg::D }, Cfg>; + + struct PointEvalReferenceNormSumcheckProver { + split_eq: GruenSplitEq, + w_table: Vec, + num_vars: usize, + b: usize, + } + + impl PointEvalReferenceNormSumcheckProver { + fn new(tau: &[E], w_evals: Vec, b: usize) -> Self { + let num_vars = tau.len(); + assert_eq!(w_evals.len(), 1 << num_vars); + Self { + split_eq: GruenSplitEq::new(tau), + w_table: w_evals, + num_vars, + b, + } + } + } + + impl SumcheckInstanceProver + for PointEvalReferenceNormSumcheckProver + { + fn num_rounds(&self) -> usize { + self.num_vars + } + + fn degree_bound(&self) -> usize { + 2 * self.b + } + + fn input_claim(&self) -> E { + E::zero() + } + + fn compute_round_univariate(&mut self, _round: usize, _previous_claim: E) -> UniPoly { + let half = self.w_table.len() / 2; + let degree_q = 2 * self.b - 1; + let num_points_q = degree_q + 1; + + let (e_first, e_second) = self.split_eq.remaining_eq_tables(); + let num_first = e_first.len(); + let first_bits = num_first.trailing_zeros(); + let b = self.b; + + let mut q_evals = vec![E::zero(); num_points_q]; + for j in 0..half { + let j_low = j & (num_first - 1); + let j_high = j >> first_bits; + let eq_rem = e_first[j_low] * e_second[j_high]; + let w_0 = self.w_table[2 * j]; + let w_1 = self.w_table[2 * j + 1]; + for (t, eval) in q_evals.iter_mut().enumerate() { + let t_e = E::from_u64(t as u64); + let w_t = w_0 + t_e * (w_1 - w_0); + *eval = *eval + eq_rem * range_check_eval(w_t, b); + } + } + + let q_poly = UniPoly::from_evals(&q_evals); + self.split_eq.gruen_mul(&q_poly) + } + + fn ingest_challenge(&mut self, _round: usize, r: E) { + self.split_eq.bind(r); + fold_evals_in_place(&mut self.w_table, r); + } + } + + fn ring_with_small_coeff(value: u64) -> CyclotomicRing { + let coeffs = std::array::from_fn(|_| F::from_u64(value)); + CyclotomicRing::from_coefficients(coeffs) + } + + #[test] + fn norm_sumcheck_runtime_dispatch_matches_reference_kernels() { + let mut rng = StdRng::seed_from_u64(0xC0FFEE); + for (case_idx, b) in [4usize, 8, 16].into_iter().enumerate() { + let case_idx = case_idx as u64; + let num_vars = 6; + let n = 1usize << num_vars; + let w_evals: Vec = (0..n) + .map(|i| F::from_u64((i as u64 * 31 + case_idx * 17) % b as u64)) + .collect(); + let tau: Vec = (0..num_vars) + .map(|_| F::from_u64(rand::Rng::gen_range(&mut rng, 1u64..=257))) + .collect(); + + let mut dispatched = NormSumcheckProver::new(&tau, w_evals.clone(), b); + let mut point_eval = NormSumcheckProver::new_with_kernel( + &tau, + w_evals.clone(), + b, + NormRoundKernel::PointEvalInterpolation, + ); + let mut affine_coeff = NormSumcheckProver::new_with_kernel( + &tau, + w_evals.clone(), + b, + NormRoundKernel::AffineCoeffComposition, + ); + let mut reference = PointEvalReferenceNormSumcheckProver::new(&tau, w_evals, b); + + let mut claim_dispatched = F::zero(); + let mut claim_point = F::zero(); + let mut claim_affine = F::zero(); + let mut claim_reference = F::zero(); + for round in 0..num_vars { + let g_dispatch = dispatched.compute_round_univariate(round, claim_dispatched); + let g_point = point_eval.compute_round_univariate(round, claim_point); + let g_affine = affine_coeff.compute_round_univariate(round, claim_affine); + let g_ref = reference.compute_round_univariate(round, claim_reference); + + assert_eq!( + g_point, g_ref, + "point-eval mismatch for case {case_idx} round {round}" + ); + assert_eq!( + g_affine, g_ref, + "affine-coeff mismatch for case {case_idx} round {round}" + ); + match choose_round_kernel(b) { + NormRoundKernel::PointEvalInterpolation => { + assert_eq!( + g_dispatch, g_point, + "dispatch mismatch for case {case_idx} round {round}" + ); + } + NormRoundKernel::AffineCoeffComposition => { + assert_eq!( + g_dispatch, g_affine, + "dispatch mismatch for case {case_idx} round {round}" + ); + } + } + + assert_eq!( + g_dispatch.evaluate(&F::zero()) + g_dispatch.evaluate(&F::one()), + claim_dispatched, + "dispatched hint mismatch for case {case_idx} round {round}" + ); + assert_eq!( + g_ref.evaluate(&F::zero()) + g_ref.evaluate(&F::one()), + claim_reference, + "reference hint mismatch for case {case_idx} round {round}" + ); + + let r = F::from_u64(rand::Rng::gen_range(&mut rng, 1u64..=257)); + claim_dispatched = g_dispatch.evaluate(&r); + claim_point = g_point.evaluate(&r); + claim_affine = g_affine.evaluate(&r); + claim_reference = g_ref.evaluate(&r); + dispatched.ingest_challenge(round, r); + point_eval.ingest_challenge(round, r); + affine_coeff.ingest_challenge(round, r); + reference.ingest_challenge(round, r); + } + assert_eq!( + claim_dispatched, claim_reference, + "final dispatched claim mismatch for case {case_idx}" + ); + assert_eq!( + claim_point, claim_reference, + "final point claim mismatch for case {case_idx}" + ); + assert_eq!( + claim_affine, claim_reference, + "final affine claim mismatch for case {case_idx}" + ); + } + } + + #[test] + fn norm_sumcheck_uses_commitment_w_evals() { + let z = [ + ring_with_small_coeff(1), + ring_with_small_coeff(2), + ring_with_small_coeff(3), + ]; + let r = [ring_with_small_coeff(0), ring_with_small_coeff(1)]; + let levels = r_decomp_levels::(); + let r_hat: Vec> = r + .iter() + .flat_map(|ri| ri.balanced_decompose_pow2(levels, SmallTestCommitmentConfig::LOG_BASIS)) + .collect(); + let mut w_evals: Vec = z + .iter() + .chain(r_hat.iter()) + .flat_map(|elem| elem.coefficients().iter().copied()) + .collect(); + + let target_len = w_evals.len().next_power_of_two(); + w_evals.resize(target_len, F::zero()); + let num_vars = target_len.trailing_zeros() as usize; + let tau: Vec = (0..num_vars).map(|i| F::from_u64((i + 2) as u64)).collect(); + let b = 1usize << SmallTestCommitmentConfig::LOG_BASIS; + + let eq_table = EqPolynomial::evals(&tau); + let _claim: F = (0..w_evals.len()) + .map(|i| eq_table[i] * range_check_eval(w_evals[i], b)) + .fold(F::zero(), |a, v| a + v); + + let mut prover = NormSumcheckProver::new(&tau, w_evals.clone(), b); + let mut pt = Blake2bTranscript::::new(labels::DOMAIN_HACHI_PROTOCOL); + let (proof, prover_challenges, final_claim) = + prove_sumcheck::(&mut prover, &mut pt, |tr| { + tr.challenge_scalar(labels::CHALLENGE_SUMCHECK_ROUND) + }) + .unwrap(); + + let oracle = EqPolynomial::mle(&tau, &prover_challenges) + * range_check_eval(multilinear_eval(&w_evals, &prover_challenges).unwrap(), b); + assert_eq!(final_claim, oracle, "prover final claim != oracle eval"); + + let verifier = NormSumcheckVerifier::new(tau, w_evals, b); + let mut vt = Blake2bTranscript::::new(labels::DOMAIN_HACHI_PROTOCOL); + let verifier_challenges = + verify_sumcheck::(&proof, &verifier, &mut vt, |tr| { + tr.challenge_scalar(labels::CHALLENGE_SUMCHECK_ROUND) + }) + .unwrap(); + + assert_eq!(prover_challenges, verifier_challenges); + } + + #[test] + fn norm_sumcheck_uses_prove_w_evals() { + let alpha = SmallTestCommitmentConfig::D.trailing_zeros() as usize; + let layout = SmallTestCommitmentConfig::commitment_layout(8).unwrap(); + let num_vars = layout.m_vars + layout.r_vars + alpha; + let len = 1usize << num_vars; + let evals: Vec = (0..len).map(|i| F::from_u64(i as u64)).collect(); + let poly = DenseMultilinearEvals::new_padded(evals); + + let setup = Scheme::setup_prover(num_vars); + let (commitment, hint) = Scheme::commit(&poly, &setup).unwrap(); + + let opening_point: Vec = (0..num_vars).map(|i| F::from_u64((i + 2) as u64)).collect(); + let mut prover_transcript = Blake2bTranscript::::new(labels::DOMAIN_HACHI_PROTOCOL); + let proof = Scheme::prove( + &setup, + &poly, + &opening_point, + Some(hint), + &mut prover_transcript, + &commitment, + ) + .unwrap(); + + let mut w_evals = proof.sumcheck_aux.w.clone(); + let target_len = w_evals.len().next_power_of_two(); + w_evals.resize(target_len, F::zero()); + let num_sumcheck_vars = target_len.trailing_zeros() as usize; + let tau: Vec = (0..num_sumcheck_vars) + .map(|i| F::from_u64((i + 3) as u64)) + .collect(); + let b = 1usize << SmallTestCommitmentConfig::LOG_BASIS; + + let eq_table = EqPolynomial::evals(&tau); + let _claim: F = (0..w_evals.len()) + .map(|i| eq_table[i] * range_check_eval(w_evals[i], b)) + .fold(F::zero(), |a, v| a + v); + + let mut prover = NormSumcheckProver::new(&tau, w_evals.clone(), b); + let mut pt = Blake2bTranscript::::new(labels::DOMAIN_HACHI_PROTOCOL); + let (proof_sc, prover_challenges, final_claim) = + prove_sumcheck::(&mut prover, &mut pt, |tr| { + tr.challenge_scalar(labels::CHALLENGE_SUMCHECK_ROUND) + }) + .unwrap(); + + let oracle = EqPolynomial::mle(&tau, &prover_challenges) + * range_check_eval(multilinear_eval(&w_evals, &prover_challenges).unwrap(), b); + assert_eq!(final_claim, oracle, "prover final claim != oracle eval"); + + let verifier = NormSumcheckVerifier::new(tau, w_evals, b); + let mut vt = Blake2bTranscript::::new(labels::DOMAIN_HACHI_PROTOCOL); + let verifier_challenges = + verify_sumcheck::(&proof_sc, &verifier, &mut vt, |tr| { + tr.challenge_scalar(labels::CHALLENGE_SUMCHECK_ROUND) + }) + .unwrap(); + + assert_eq!(prover_challenges, verifier_challenges); + } + + #[test] + fn norm_sumcheck_over_ext2() { + use crate::algebra::fields::ext::Ext2; + use crate::algebra::fields::lift::LiftBase; + + type E2 = Ext2; + + let num_vars = 3; + let n = 1usize << num_vars; + let b = 2; + let w_evals_f: Vec = (0..n).map(|i| F::from_u64(i as u64 % b as u64)).collect(); + let tau_f: Vec = (0..num_vars).map(|i| F::from_u64((i + 2) as u64)).collect(); + + let w_evals_e: Vec = w_evals_f.iter().map(|&f| E2::lift_base(f)).collect(); + let tau_e: Vec = tau_f.iter().map(|&f| E2::lift_base(f)).collect(); + + let mut prover = NormSumcheckProver::new(&tau_e, w_evals_e.clone(), b); + + let mut pt = Blake2bTranscript::::new(labels::DOMAIN_HACHI_PROTOCOL); + let (proof, prover_challenges, final_claim) = + prove_sumcheck::(&mut prover, &mut pt, |tr| { + E2::lift_base(tr.challenge_scalar(labels::CHALLENGE_SUMCHECK_ROUND)) + }) + .unwrap(); + + let oracle = EqPolynomial::mle(&tau_e, &prover_challenges) + * range_check_eval(multilinear_eval(&w_evals_e, &prover_challenges).unwrap(), b); + assert_eq!(final_claim, oracle, "E2 prover final claim != oracle eval"); + + let verifier = NormSumcheckVerifier::new(tau_e, w_evals_e, b); + let mut vt = Blake2bTranscript::::new(labels::DOMAIN_HACHI_PROTOCOL); + let verifier_challenges = + verify_sumcheck::(&proof, &verifier, &mut vt, |tr| { + E2::lift_base(tr.challenge_scalar(labels::CHALLENGE_SUMCHECK_ROUND)) + }) + .unwrap(); + + assert_eq!(prover_challenges, verifier_challenges); + } +} diff --git a/src/protocol/sumcheck/relation_sumcheck.rs b/src/protocol/sumcheck/relation_sumcheck.rs new file mode 100644 index 00000000..1f03a09a --- /dev/null +++ b/src/protocol/sumcheck/relation_sumcheck.rs @@ -0,0 +1,408 @@ +//! Evaluation-relation sumcheck instance (F_α). +//! +//! **F_{α,τ₁}(x, y)** = w̃(x,y) · α̃(y) · m(x) +//! where m(x) = Σ_i ẽq(τ₁,i) · M̃_α(i,x). +//! +//! Proves the evaluation relation; sum equals `a = Σ_i ẽq(τ₁,i) · y_i(α)`. + +use super::eq_poly::EqPolynomial; +use super::{fold_evals_in_place, multilinear_eval}; +use super::{SumcheckInstanceProver, SumcheckInstanceVerifier, UniPoly}; +use crate::algebra::ring::CyclotomicRing; +use crate::cfg_into_iter; +use crate::error::HachiError; +#[cfg(feature = "parallel")] +use crate::parallel::*; +use crate::protocol::ring_switch::eval_ring_at; +use crate::{FieldCore, FromSmallInt}; + +/// Prover for `F_{α,τ₁}(x,y) = w̃(x,y) · α̃(y) · m(x)`. +/// +/// All three constituent evaluation tables are stored at full domain size +/// (`2^{num_u + num_l}`). `α̃` is replicated along x dimensions and `m` along +/// y dimensions so that a uniform fold-by-pairs works in every round. +/// +/// Round polynomial degree is 2 (product of at most two multilinear factors +/// depending on any single variable). +pub struct RelationSumcheckProver { + w_table: Vec, + alpha_table: Vec, + m_table: Vec, + num_vars: usize, +} + +impl RelationSumcheckProver { + /// Construct from the three constituent evaluation tables. + /// + /// - `w_evals`: evaluations of `w̃` over `{0,1}^{num_u + num_l}` (full domain). + /// - `alpha_evals_y`: evaluations of `α̃` over `{0,1}^{num_l}` (compact). + /// - `m_evals_x`: evaluations of `m` over `{0,1}^{num_u}` (compact). + /// + /// The constructor extends the compact tables to the full domain by replication. + /// + /// # Panics + /// + /// Panics if table sizes don't match `2^num_u`, `2^num_l`, or `2^(num_u+num_l)`. + pub fn new( + w_evals: Vec, + alpha_evals_y: &[E], + m_evals_x: &[E], + num_u: usize, + num_l: usize, + ) -> Self { + let num_vars = num_u + num_l; + let n = 1usize << num_vars; + assert_eq!(w_evals.len(), n); + assert_eq!(alpha_evals_y.len(), 1 << num_l); + assert_eq!(m_evals_x.len(), 1 << num_u); + + let x_mask = (1usize << num_u) - 1; + let alpha_table: Vec = cfg_into_iter!(0..n) + .map(|idx| alpha_evals_y[idx >> num_u]) + .collect(); + let m_table: Vec = cfg_into_iter!(0..n) + .map(|idx| m_evals_x[idx & x_mask]) + .collect(); + + Self { + w_table: w_evals, + alpha_table, + m_table, + num_vars, + } + } +} + +impl SumcheckInstanceProver for RelationSumcheckProver { + fn num_rounds(&self) -> usize { + self.num_vars + } + + fn degree_bound(&self) -> usize { + 2 + } + + fn input_claim(&self) -> E { + #[cfg(feature = "parallel")] + { + self.w_table + .par_iter() + .zip(self.alpha_table.par_iter()) + .zip(self.m_table.par_iter()) + .fold(|| E::zero(), |acc, ((&w, &a), &m)| acc + w * a * m) + .reduce(|| E::zero(), |a, b| a + b) + } + #[cfg(not(feature = "parallel"))] + { + self.w_table + .iter() + .zip(self.alpha_table.iter()) + .zip(self.m_table.iter()) + .fold(E::zero(), |acc, ((&w, &a), &m)| acc + w * a * m) + } + } + + fn compute_round_univariate(&mut self, _round: usize, _previous_claim: E) -> UniPoly { + let half = self.w_table.len() / 2; + let num_points = 3; // degree 2 → 3 evaluation points + + #[cfg(feature = "parallel")] + let round_evals = { + (0..half) + .into_par_iter() + .fold( + || vec![E::zero(); num_points], + |mut evals, j| { + let w_0 = self.w_table[2 * j]; + let w_1 = self.w_table[2 * j + 1]; + let a_0 = self.alpha_table[2 * j]; + let a_1 = self.alpha_table[2 * j + 1]; + let m_0 = self.m_table[2 * j]; + let m_1 = self.m_table[2 * j + 1]; + for (t, eval) in evals.iter_mut().enumerate() { + let t_e = E::from_u64(t as u64); + let w_t = w_0 + t_e * (w_1 - w_0); + let a_t = a_0 + t_e * (a_1 - a_0); + let m_t = m_0 + t_e * (m_1 - m_0); + *eval = *eval + w_t * a_t * m_t; + } + evals + }, + ) + .reduce( + || vec![E::zero(); num_points], + |mut a, b| { + for (ai, bi) in a.iter_mut().zip(b.iter()) { + *ai = *ai + *bi; + } + a + }, + ) + }; + #[cfg(not(feature = "parallel"))] + let round_evals = { + let mut evals = vec![E::zero(); num_points]; + for j in 0..half { + let w_0 = self.w_table[2 * j]; + let w_1 = self.w_table[2 * j + 1]; + let a_0 = self.alpha_table[2 * j]; + let a_1 = self.alpha_table[2 * j + 1]; + let m_0 = self.m_table[2 * j]; + let m_1 = self.m_table[2 * j + 1]; + for (t, eval) in evals.iter_mut().enumerate() { + let t_e = E::from_u64(t as u64); + let w_t = w_0 + t_e * (w_1 - w_0); + let a_t = a_0 + t_e * (a_1 - a_0); + let m_t = m_0 + t_e * (m_1 - m_0); + *eval = *eval + w_t * a_t * m_t; + } + } + evals + }; + + UniPoly::from_evals(&round_evals) + } + + fn ingest_challenge(&mut self, _round: usize, r: E) { + fold_evals_in_place(&mut self.w_table, r); + fold_evals_in_place(&mut self.alpha_table, r); + fold_evals_in_place(&mut self.m_table, r); + } +} + +/// Verifier for the evaluation-relation sumcheck `F_{α,τ₁}`. +pub struct RelationSumcheckVerifier { + w_evals: Vec, + alpha_evals_y: Vec, + m_evals_x: Vec, + tau: Vec, + v: Vec>, + u: Vec>, + y_ring: CyclotomicRing, + alpha: F, + num_u: usize, + num_l: usize, +} + +impl RelationSumcheckVerifier { + /// Create a new evaluation-relation sumcheck verifier. + /// + /// # Panics + /// + /// Panics if table sizes don't match `2^num_u`, `2^num_l`, or `2^(num_u+num_l)`. + #[allow(clippy::too_many_arguments)] + pub fn new( + w_evals: Vec, + alpha_evals_y: Vec, + m_evals_x: Vec, + tau: Vec, + v: Vec>, + u: Vec>, + y_ring: CyclotomicRing, + alpha: F, + num_u: usize, + num_l: usize, + ) -> Self { + assert_eq!(w_evals.len(), 1 << (num_u + num_l)); + assert_eq!(alpha_evals_y.len(), 1 << num_l); + assert_eq!(m_evals_x.len(), 1 << num_u); + Self { + w_evals, + alpha_evals_y, + m_evals_x, + tau, + v, + u, + y_ring, + alpha, + num_u, + num_l, + } + } +} + +impl SumcheckInstanceVerifier for RelationSumcheckVerifier { + fn num_rounds(&self) -> usize { + self.num_u + self.num_l + } + + fn degree_bound(&self) -> usize { + 2 + } + + fn input_claim(&self) -> F { + let y_a: Vec = self + .v + .iter() + .chain(self.u.iter()) + .chain(std::iter::once(&self.y_ring)) + .map(|r| eval_ring_at(r, &self.alpha)) + .collect(); + + let eq_tau = EqPolynomial::evals(&self.tau); + let mut acc = F::zero(); + for (i, eq_i) in eq_tau.iter().enumerate() { + let y_i = if i < y_a.len() { y_a[i] } else { F::zero() }; + acc = acc + *eq_i * y_i; + } + acc + } + + fn expected_output_claim(&self, challenges: &[F]) -> Result { + let (x_challenges, y_challenges) = challenges.split_at(self.num_u); + let w_val = multilinear_eval(&self.w_evals, challenges)?; + let alpha_val = multilinear_eval(&self.alpha_evals_y, y_challenges)?; + let m_val = multilinear_eval(&self.m_evals_x, x_challenges)?; + Ok(w_val * alpha_val * m_val) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::algebra::Fp64; + use crate::primitives::multilinear_evals::DenseMultilinearEvals; + use crate::protocol::commitment_scheme::rederive_alpha_and_m_a; + use crate::protocol::sumcheck::eq_poly::EqPolynomial; + use crate::protocol::transcript::labels; + use crate::protocol::{ + prove_sumcheck, verify_sumcheck, Blake2bTranscript, CommitmentConfig, CommitmentScheme, + HachiCommitmentScheme, SmallTestCommitmentConfig, Transcript, + }; + use crate::{FieldCore, FromSmallInt}; + + type F = Fp64<4294967197>; + type Cfg = SmallTestCommitmentConfig; + type Scheme = HachiCommitmentScheme<{ Cfg::D }, Cfg>; + + #[test] + fn relation_sumcheck_uses_prove_w_evals() { + let alpha_bits = SmallTestCommitmentConfig::D.trailing_zeros() as usize; + let layout = SmallTestCommitmentConfig::commitment_layout(8).unwrap(); + let num_vars = layout.m_vars + layout.r_vars + alpha_bits; + let len = 1usize << num_vars; + let evals: Vec = (0..len).map(|i| F::from_u64(i as u64)).collect(); + let poly = DenseMultilinearEvals::new_padded(evals); + + let setup = Scheme::setup_prover(num_vars); + let (commitment, hint) = Scheme::commit(&poly, &setup).unwrap(); + + let opening_point: Vec = (0..num_vars).map(|i| F::from_u64((i + 2) as u64)).collect(); + let mut prover_transcript = Blake2bTranscript::::new(labels::DOMAIN_HACHI_PROTOCOL); + let proof = Scheme::prove( + &setup, + &poly, + &opening_point, + Some(hint), + &mut prover_transcript, + &commitment, + ) + .unwrap(); + + let (alpha, m_a_vec) = rederive_alpha_and_m_a::( + &proof, + &Scheme::setup_verifier(&setup), + &opening_point, + &commitment, + ) + .unwrap(); + + let d = SmallTestCommitmentConfig::D; + assert_eq!(proof.sumcheck_aux.w.len() % d, 0); + let w_u = proof.sumcheck_aux.w.len() / d; + let rows = SmallTestCommitmentConfig::N_D + + SmallTestCommitmentConfig::N_B + + 1 + + 1 + + SmallTestCommitmentConfig::N_A; + assert!(rows > 0); + assert_eq!(m_a_vec.len() % rows, 0); + let cols = m_a_vec.len() / rows; + assert_eq!(w_u, cols); + + let num_u = cols.next_power_of_two().trailing_zeros() as usize; + let num_l = alpha_bits; + let n = 1usize << (num_u + num_l); + + let mut w_evals = vec![F::zero(); n]; + let y_len = 1usize << num_l; + let x_len = 1usize << num_u; + for x in 0..x_len { + for y in 0..y_len { + let src = y + (x << num_l); + if src < proof.sumcheck_aux.w.len() { + let dst = x + (y << num_u); + w_evals[dst] = proof.sumcheck_aux.w[src]; + } + } + } + + let num_i = rows.next_power_of_two().trailing_zeros() as usize; + let tau1: Vec = (0..num_i).map(|i| F::from_u64((i + 5) as u64)).collect(); + let eq_tau1 = EqPolynomial::evals(&tau1); + + let mut m_evals_x = vec![F::zero(); x_len]; + for x in 0..x_len { + let mut acc = F::zero(); + for i in 0..(1usize << num_i) { + let row_val = if i < rows && x < cols { + m_a_vec[i * cols + x] + } else { + F::zero() + }; + acc = acc + eq_tau1[i] * row_val; + } + m_evals_x[x] = acc; + } + + let mut alpha_evals_y = vec![F::zero(); y_len]; + let mut power = F::one(); + for val in alpha_evals_y.iter_mut() { + *val = power; + power = power * alpha; + } + + let x_mask = x_len - 1; + let alpha_full: Vec = (0..n).map(|idx| alpha_evals_y[idx >> num_u]).collect(); + let m_full: Vec = (0..n).map(|idx| m_evals_x[idx & x_mask]).collect(); + let _claim: F = (0..n) + .map(|i| w_evals[i] * alpha_full[i] * m_full[i]) + .fold(F::zero(), |a, v| a + v); + + let mut prover = + RelationSumcheckProver::new(w_evals.clone(), &alpha_evals_y, &m_evals_x, num_u, num_l); + let mut pt = Blake2bTranscript::::new(labels::DOMAIN_HACHI_PROTOCOL); + let (proof_sc, prover_challenges, final_claim) = + prove_sumcheck::(&mut prover, &mut pt, |tr| { + tr.challenge_scalar(labels::CHALLENGE_SUMCHECK_ROUND) + }) + .unwrap(); + + let (x_ch, y_ch) = prover_challenges.split_at(num_u); + let oracle = multilinear_eval(&w_evals, &prover_challenges).unwrap() + * multilinear_eval(&alpha_evals_y, y_ch).unwrap() + * multilinear_eval(&m_evals_x, x_ch).unwrap(); + assert_eq!(final_claim, oracle, "prover final claim != oracle eval"); + + let verifier = RelationSumcheckVerifier::new( + w_evals, + alpha_evals_y, + m_evals_x, + tau1, + proof.v.clone(), + commitment.u.clone(), + proof.y_ring, + alpha, + num_u, + num_l, + ); + let mut vt = Blake2bTranscript::::new(labels::DOMAIN_HACHI_PROTOCOL); + let verifier_challenges = + verify_sumcheck::(&proof_sc, &verifier, &mut vt, |tr| { + tr.challenge_scalar(labels::CHALLENGE_SUMCHECK_ROUND) + }) + .unwrap(); + + assert_eq!(prover_challenges, verifier_challenges); + } +} diff --git a/src/protocol/sumcheck/split_eq.rs b/src/protocol/sumcheck/split_eq.rs new file mode 100644 index 00000000..9ecf7269 --- /dev/null +++ b/src/protocol/sumcheck/split_eq.rs @@ -0,0 +1,213 @@ +//! Gruen/Dao-Thaler split equality polynomial for efficient sumcheck. +//! +//! Factors `eq(τ, x)` into a running scalar, a linear factor for the current +//! variable, and precomputed split tables for the remaining variables. This +//! avoids maintaining and folding a full-size eq table during sumcheck. +//! +//! For details, see . +//! +//! Adapted from Jolt's `GruenSplitEqPolynomial`. +//! +//! ## Variable Layout (forward binding, little-endian) +//! +//! ```text +//! τ = [τ_current, τ_first_half, τ_second_half] +//! 1 var m vars (n-1-m) vars +//! ``` +//! +//! where `m = (n-1) / 2` and `n = τ.len()`. +//! +//! After binding `τ_current`, the next variable comes from `τ_first_half`, +//! then from `τ_second_half`. Suffix-cached eq tables for each half enable +//! O(1) pops per round instead of an O(2^n) fold. + +use super::eq_poly::EqPolynomial; +use super::UniPoly; +use crate::FieldCore; + +/// Split equality polynomial with Gruen scalar accumulation. +/// +/// Instead of storing and folding a full eq table each round, this struct +/// maintains: +/// - `current_scalar`: accumulated `eq(τ_bound, r_bound)` from already-bound +/// variables +/// - `E_first` / `E_second`: suffix-cached eq tables for two halves of the +/// remaining (unbound, non-current) variables +/// +/// The eq contribution for a pair index `j` in the inner sum is: +/// ```text +/// eq_remaining(j) = E_first[j_low] · E_second[j_high] +/// ``` +/// and the full round polynomial is `l(X) · q(X)` where `l(X)` is the linear +/// eq factor for the current variable. +#[allow(non_snake_case)] +pub struct GruenSplitEq { + tau: Vec, + current_round: usize, + current_scalar: E, + /// Suffix-cached eq tables for the first half of remaining variables. + /// `E_first[k]` = `eq(τ[split-k..split], ·)` with `2^k` entries. + /// Invariant: never empty; `E_first[0] = [1]`. + E_first: Vec>, + /// Suffix-cached eq tables for the second half of remaining variables. + /// `E_second[k]` = `eq(τ[n-k..n], ·)` with `2^k` entries. + /// Invariant: never empty; `E_second[0] = [1]`. + E_second: Vec>, +} + +#[allow(non_snake_case)] +impl GruenSplitEq { + /// Create a new split-eq from the full challenge vector `τ`. + /// + /// Precomputes suffix-cached eq tables for two halves of `τ[1..n]`. + /// + /// # Panics + /// + /// Panics if `tau` is empty. + pub fn new(tau: &[E]) -> Self { + let n = tau.len(); + assert!(n >= 1); + let m = (n - 1) / 2; + let split = 1 + m; + let first_half = &tau[1..split]; + let second_half = &tau[split..n]; + let E_first = EqPolynomial::evals_cached(first_half); + let E_second = EqPolynomial::evals_cached(second_half); + Self { + tau: tau.to_vec(), + current_round: 0, + current_scalar: E::one(), + E_first, + E_second, + } + } + + /// The accumulated scalar `Π_{k < current_round} eq(τ[k], r[k])`. + pub fn current_scalar(&self) -> E { + self.current_scalar + } + + /// The τ value for the variable about to be bound. + pub fn current_tau(&self) -> E { + self.tau[self.current_round] + } + + /// Return the current top-level split-eq tables `(E_first, E_second)`. + /// + /// For a pair index `j` in the inner sum, the eq factor for the + /// remaining (non-current) variables is: + /// ```text + /// eq_remaining(j) = E_first[j & (E_first.len()-1)] + /// · E_second[j >> E_first.len().trailing_zeros()] + /// ``` + /// + /// # Panics + /// + /// Panics if either `E_first` or `E_second` is empty (invariant violation). + pub fn remaining_eq_tables(&self) -> (&[E], &[E]) { + ( + self.E_first.last().expect("E_first is never empty"), + self.E_second.last().expect("E_second is never empty"), + ) + } + + /// Bind the current variable to challenge `r`, advancing to the next round. + /// + /// Updates `current_scalar` with `eq(τ[current_round], r)` and pops the + /// appropriate split table level. + pub fn bind(&mut self, r: E) { + let tau_k = self.tau[self.current_round]; + self.current_scalar = + self.current_scalar * (tau_k * r + (E::one() - tau_k) * (E::one() - r)); + self.current_round += 1; + if self.E_first.len() > 1 { + self.E_first.pop(); + } else if self.E_second.len() > 1 { + self.E_second.pop(); + } + } + + /// Compute the round polynomial `s(X) = l(X) · q(X)` from the inner + /// polynomial `q` (given as evaluations at integer points `0, 1, ..., d`). + /// + /// `l(X) = current_scalar · eq(τ_current, X)` is the linear eq factor + /// for the current variable. The result has degree `d + 1`. + pub fn gruen_mul(&self, q_poly: &UniPoly) -> UniPoly { + let tau_k = self.current_tau(); + let scalar = self.current_scalar(); + let l_0 = scalar * (E::one() - tau_k); + let l_1 = scalar * tau_k; + let slope = l_1 - l_0; + let mut coeffs = vec![E::zero(); q_poly.coeffs.len() + 1]; + for (i, &c) in q_poly.coeffs.iter().enumerate() { + coeffs[i] = coeffs[i] + c * l_0; + coeffs[i + 1] = coeffs[i + 1] + c * slope; + } + UniPoly::from_coeffs(coeffs) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::algebra::Fp64; + use crate::protocol::sumcheck::fold_evals_in_place; + use crate::{FieldSampling, FromSmallInt}; + use rand::rngs::StdRng; + use rand::SeedableRng; + + type F = Fp64<4294967197>; + + #[test] + fn gruen_eq_matches_full_eq_table() { + let mut rng = StdRng::seed_from_u64(0xBB); + for n in 1..10 { + let tau: Vec = (0..n).map(|_| F::sample(&mut rng)).collect(); + let mut full_eq = EqPolynomial::evals(&tau); + let mut split_eq = GruenSplitEq::new(&tau); + + for _round in 0..n { + let half = full_eq.len() / 2; + let (e_first, e_second) = split_eq.remaining_eq_tables(); + let num_first = e_first.len(); + + for j in 0..half { + let j_low = j & (num_first - 1); + let j_high = j >> num_first.trailing_zeros(); + let eq_rem = e_first[j_low] * e_second[j_high]; + + let tau_k = split_eq.current_tau(); + let scalar = split_eq.current_scalar(); + let eq_0 = scalar * (F::one() - tau_k) * eq_rem; + let eq_1 = scalar * tau_k * eq_rem; + + assert_eq!(eq_0, full_eq[2 * j], "n={n} round={_round} j={j} eq_0"); + assert_eq!(eq_1, full_eq[2 * j + 1], "n={n} round={_round} j={j} eq_1"); + } + + let r = F::sample(&mut rng); + fold_evals_in_place(&mut full_eq, r); + split_eq.bind(r); + } + } + } + + #[test] + fn gruen_mul_matches_direct_product() { + let mut rng = StdRng::seed_from_u64(0xCC); + let tau: Vec = (0..5).map(|_| F::sample(&mut rng)).collect(); + let split_eq = GruenSplitEq::new(&tau); + + let q = UniPoly::from_coeffs(vec![F::from_u64(3), F::from_u64(7), F::from_u64(2)]); + let s = split_eq.gruen_mul(&q); + + let tau_k = split_eq.current_tau(); + let scalar = split_eq.current_scalar(); + for t in 0..10u64 { + let x = F::from_u64(t); + let l_x = scalar * (tau_k * x + (F::one() - tau_k) * (F::one() - x)); + let q_x = q.evaluate(&x); + assert_eq!(s.evaluate(&x), l_x * q_x, "t={t}"); + } + } +} diff --git a/src/protocol/sumcheck/types.rs b/src/protocol/sumcheck/types.rs new file mode 100644 index 00000000..7fbf4d3b --- /dev/null +++ b/src/protocol/sumcheck/types.rs @@ -0,0 +1,378 @@ +//! Sumcheck data types: univariate polynomials, compressed representation, and proof container. + +use crate::error::HachiError; +use crate::primitives::serialization::{ + Compress, HachiDeserialize, HachiSerialize, SerializationError, Valid, Validate, +}; +use crate::protocol::transcript::labels; +use crate::protocol::transcript::Transcript; +use crate::FieldCore; +use crate::FromSmallInt; +use std::io::{Read, Write}; + +/// Univariate polynomial in coefficient form: `p(X) = Σ_{i=0}^d coeffs[i] * X^i`. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct UniPoly { + /// Coefficients from low degree to high degree. + pub coeffs: Vec, +} + +impl UniPoly { + /// Construct from coefficients in increasing-degree order. + pub fn from_coeffs(coeffs: Vec) -> Self { + Self { coeffs } + } + + /// Degree of the polynomial (0 for empty or constant). + pub fn degree(&self) -> usize { + self.coeffs.len().saturating_sub(1) + } + + /// Evaluate at `x` via Horner's method. + pub fn evaluate(&self, x: &E) -> E { + let mut acc = E::zero(); + for c in self.coeffs.iter().rev() { + acc = acc * *x + *c; + } + acc + } + + /// Compress this polynomial by omitting the linear coefficient. + /// + /// The verifier can reconstruct/evaluate the missing linear coefficient using + /// the per-round hint `g(0)+g(1)` from the sumcheck protocol. + /// + /// This matches the technique used by Jolt's sumcheck (`CompressedUniPoly`). + pub fn compress(&self) -> CompressedUniPoly { + let coeffs = &self.coeffs; + if coeffs.is_empty() { + return CompressedUniPoly { + coeffs_except_linear_term: Vec::new(), + }; + } + if coeffs.len() == 1 { + return CompressedUniPoly { + coeffs_except_linear_term: vec![coeffs[0]], + }; + } + let mut out = Vec::with_capacity(coeffs.len().saturating_sub(1)); + out.push(coeffs[0]); + out.extend_from_slice(&coeffs[2..]); + CompressedUniPoly { + coeffs_except_linear_term: out, + } + } +} + +impl UniPoly { + /// Interpolate from evaluations at equispaced integer points `x = 0, 1, ..., d`. + /// + /// Uses Newton forward-difference interpolation: compute divided differences, + /// then expand via Horner on the nested Newton form. + /// + /// # Panics + /// + /// Panics if any required factorial inverse does not exist (field characteristic + /// must exceed the number of evaluation points). This is a prover-only + /// function and the condition always holds for Hachi's fields. + pub fn from_evals(evals: &[E]) -> Self { + let n = evals.len(); + if n == 0 { + return Self::from_coeffs(vec![]); + } + if n == 1 { + return Self::from_coeffs(vec![evals[0]]); + } + + let mut table = evals.to_vec(); + let mut deltas = vec![table[0]]; + for _ in 1..n { + for j in 0..table.len() - 1 { + table[j] = table[j + 1] - table[j]; + } + table.pop(); + deltas.push(table[0]); + } + + let mut factorial = E::one(); + let mut divided_diffs = vec![deltas[0]]; + for (k, delta_k) in deltas.iter().enumerate().skip(1) { + factorial = factorial * E::from_u64(k as u64); + divided_diffs.push( + *delta_k + * factorial + .inv() + .expect("field characteristic too small for interpolation"), + ); + } + + let mut coeffs = vec![divided_diffs[n - 1]]; + + for k in (0..n - 1).rev() { + let shift = E::from_u64(k as u64); + let old_len = coeffs.len(); + let mut new_coeffs = vec![E::zero(); old_len + 1]; + + new_coeffs[0] = divided_diffs[k]; + for i in 0..old_len { + new_coeffs[i + 1] = new_coeffs[i + 1] + coeffs[i]; + new_coeffs[i] = new_coeffs[i] - shift * coeffs[i]; + } + + coeffs = new_coeffs; + } + + while coeffs.len() > 1 && coeffs.last().is_some_and(|c| c.is_zero()) { + coeffs.pop(); + } + + Self::from_coeffs(coeffs) + } +} + +impl Valid for UniPoly { + fn check(&self) -> Result<(), SerializationError> { + self.coeffs.check() + } +} + +impl HachiSerialize for UniPoly { + fn serialize_with_mode( + &self, + mut writer: W, + compress: Compress, + ) -> Result<(), SerializationError> { + self.coeffs.serialize_with_mode(&mut writer, compress) + } + + fn serialized_size(&self, compress: Compress) -> usize { + self.coeffs.serialized_size(compress) + } +} + +impl HachiDeserialize for UniPoly { + fn deserialize_with_mode( + mut reader: R, + compress: Compress, + validate: Validate, + ) -> Result { + let coeffs = Vec::::deserialize_with_mode(&mut reader, compress, validate)?; + let out = Self { coeffs }; + if matches!(validate, Validate::Yes) { + out.check()?; + } + Ok(out) + } +} + +/// Compressed univariate polynomial representation omitting the linear term. +/// +/// We store `[c0, c2, c3, ..., cd]`. Given the sumcheck hint `hint = g(0)+g(1)`, +/// the missing linear coefficient is: +/// +/// `c1 = hint - 2*c0 - Σ_{i=2..d} ci`. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct CompressedUniPoly { + /// Coefficients excluding the linear term: `[c0, c2, c3, ..., cd]`. + pub coeffs_except_linear_term: Vec, +} + +impl CompressedUniPoly { + /// Degree of the underlying uncompressed polynomial. + /// + /// `compress()` stores `[c0, c2, ..., cd]` — exactly `d` entries for + /// degree `d >= 2`. For `len <= 1` (degree 0 or 1, which are ambiguous + /// in compressed form) we report 0; this is conservative for the + /// verifier's degree-bound check since `degree_bound >= 2` in practice. + pub fn degree(&self) -> usize { + let len = self.coeffs_except_linear_term.len(); + if len <= 1 { + 0 + } else { + len + } + } + + fn recover_linear_term(&self, hint: &E) -> E { + if self.coeffs_except_linear_term.is_empty() { + return E::zero(); + } + + let c0 = self.coeffs_except_linear_term[0]; + let mut linear = *hint - c0 - c0; + for c in self.coeffs_except_linear_term.iter().skip(1) { + linear = linear - *c; + } + linear + } + + /// Decompress using `hint = g(0)+g(1)`. + pub fn decompress(&self, hint: &E) -> UniPoly { + if self.coeffs_except_linear_term.is_empty() { + return UniPoly::from_coeffs(Vec::new()); + } + let linear = self.recover_linear_term(hint); + let mut coeffs = Vec::with_capacity(self.coeffs_except_linear_term.len() + 1); + coeffs.push(self.coeffs_except_linear_term[0]); + coeffs.push(linear); + coeffs.extend_from_slice(&self.coeffs_except_linear_term[1..]); + UniPoly::from_coeffs(coeffs) + } + + /// Evaluate the uncompressed polynomial at `x`, using `hint = g(0)+g(1)`. + /// + /// This avoids materializing the full coefficient list. + pub fn eval_from_hint(&self, hint: &E, x: &E) -> E { + if self.coeffs_except_linear_term.is_empty() { + return E::zero(); + } + + let linear = self.recover_linear_term(hint); + let mut acc = self.coeffs_except_linear_term[0] + (*x * linear); + + let mut pow = *x * *x; + for c in self.coeffs_except_linear_term.iter().skip(1) { + acc = acc + (*c * pow); + pow = pow * *x; + } + acc + } +} + +impl Valid for CompressedUniPoly { + fn check(&self) -> Result<(), SerializationError> { + self.coeffs_except_linear_term.check() + } +} + +impl HachiSerialize for CompressedUniPoly { + fn serialize_with_mode( + &self, + mut writer: W, + compress: Compress, + ) -> Result<(), SerializationError> { + self.coeffs_except_linear_term + .serialize_with_mode(&mut writer, compress) + } + + fn serialized_size(&self, compress: Compress) -> usize { + self.coeffs_except_linear_term.serialized_size(compress) + } +} + +impl HachiDeserialize for CompressedUniPoly { + fn deserialize_with_mode( + mut reader: R, + compress: Compress, + validate: Validate, + ) -> Result { + let coeffs_except_linear_term = + Vec::::deserialize_with_mode(&mut reader, compress, validate)?; + let out = Self { + coeffs_except_linear_term, + }; + if matches!(validate, Validate::Yes) { + out.check()?; + } + Ok(out) + } +} + +/// Sumcheck proof containing one compressed univariate polynomial per round. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct SumcheckProof { + /// One compressed univariate polynomial per sumcheck round. + pub round_polys: Vec>, +} + +impl Valid for SumcheckProof { + fn check(&self) -> Result<(), SerializationError> { + self.round_polys.check() + } +} + +impl HachiSerialize for SumcheckProof { + fn serialize_with_mode( + &self, + mut writer: W, + compress: Compress, + ) -> Result<(), SerializationError> { + self.round_polys.serialize_with_mode(&mut writer, compress) + } + + fn serialized_size(&self, compress: Compress) -> usize { + self.round_polys.serialized_size(compress) + } +} + +impl HachiDeserialize for SumcheckProof { + fn deserialize_with_mode( + mut reader: R, + compress: Compress, + validate: Validate, + ) -> Result { + let round_polys = + Vec::>::deserialize_with_mode(&mut reader, compress, validate)?; + let out = Self { round_polys }; + if matches!(validate, Validate::Yes) { + out.check()?; + } + Ok(out) + } +} + +impl SumcheckProof { + /// Verifier-side sumcheck transcript driver. + /// + /// This method: + /// - absorbs the per-round prover message (compressed univariate), + /// - samples one challenge per round via `sample_challenge`, + /// - updates the running claim using `eval_from_hint`. + /// + /// It does **not** perform the final oracle check `final_claim == f(r*)`. + /// Callers (e.g. ring-switching) must compute `f(r*)` themselves and compare. + /// + /// # Errors + /// + /// Returns an error if the proof length does not match `num_rounds` or if any + /// per-round polynomial exceeds `degree_bound`. + pub fn verify( + &self, + mut claim: E, + num_rounds: usize, + degree_bound: usize, + transcript: &mut T, + mut sample_challenge: S, + ) -> Result<(E, Vec), HachiError> + where + F: crate::FieldCore + crate::CanonicalField, + T: Transcript, + S: FnMut(&mut T) -> E, + { + if self.round_polys.len() != num_rounds { + return Err(HachiError::InvalidSize { + expected: num_rounds, + actual: self.round_polys.len(), + }); + } + + let mut r = Vec::with_capacity(num_rounds); + for poly in &self.round_polys { + if poly.degree() > degree_bound { + return Err(HachiError::InvalidInput(format!( + "sumcheck round poly degree {} exceeds bound {}", + poly.degree(), + degree_bound + ))); + } + + transcript.append_serde(labels::ABSORB_SUMCHECK_ROUND, poly); + let r_i = sample_challenge(transcript); + r.push(r_i); + + claim = poly.eval_from_hint(&claim, &r_i); + } + + Ok((claim, r)) + } +} diff --git a/src/protocol/transcript/hash.rs b/src/protocol/transcript/hash.rs new file mode 100644 index 00000000..cfae7b30 --- /dev/null +++ b/src/protocol/transcript/hash.rs @@ -0,0 +1,103 @@ +//! Generic hash-based transcript for protocol-layer Fiat-Shamir. +//! +//! Parameterised over any `Digest + Clone` hasher, eliminating the +//! near-identical Blake2b and Keccak implementations. + +use super::Transcript; +use crate::primitives::serialization::HachiSerialize; +use crate::{CanonicalField, FieldCore}; +use blake2::{Blake2b512, Digest}; +use sha3::Keccak256; +use std::marker::PhantomData; + +/// Hash-based transcript with labeled framing. +/// +/// Works with any cryptographic hash that implements `Digest + Clone`. +#[derive(Clone)] +pub struct HashTranscript +where + F: FieldCore + CanonicalField + 'static, +{ + hasher: D, + _field: PhantomData, +} + +impl HashTranscript +where + F: FieldCore + CanonicalField + 'static, +{ + #[inline] + fn append_bytes_impl(&mut self, label: &[u8], bytes: &[u8]) { + self.hasher.update(label); + self.hasher.update((bytes.len() as u64).to_le_bytes()); + self.hasher.update(bytes); + } + + #[inline] + fn challenge_and_chain(&mut self, label: &[u8]) -> Vec { + self.hasher.update(label); + let digest = self.hasher.clone().finalize(); + let out = digest.to_vec(); + self.hasher.update(&out); + out + } +} + +impl Transcript for HashTranscript +where + F: FieldCore + CanonicalField + 'static, +{ + fn new(domain_label: &[u8]) -> Self { + let mut hasher = D::new(); + hasher.update(domain_label); + Self { + hasher, + _field: PhantomData, + } + } + + fn append_bytes(&mut self, label: &[u8], bytes: &[u8]) { + self.append_bytes_impl(label, bytes); + } + + fn append_field(&mut self, label: &[u8], x: &F) { + self.append_bytes_impl(label, &x.to_canonical_u128().to_le_bytes()); + } + + fn append_serde(&mut self, label: &[u8], s: &S) { + let mut bytes = Vec::new(); + s.serialize_compressed(&mut bytes) + .expect("HachiSerialize should not fail"); + self.append_bytes_impl(label, &bytes); + } + + fn challenge_scalar(&mut self, label: &[u8]) -> F { + let bytes = self.challenge_and_chain(label); + let mut lo = [0u8; 16]; + lo.copy_from_slice(&bytes[..16]); + let sampled = u128::from_le_bytes(lo); + F::from_canonical_u128_reduced(sampled) + } +} + +impl HashTranscript +where + F: FieldCore + CanonicalField + 'static, +{ + /// Reset transcript state under a new domain label. + /// + /// This is an inherent method (not part of the `Transcript` trait) to + /// discourage use in production protocol code where resetting the + /// Fiat-Shamir chain would be unsound. + pub fn reset(&mut self, domain_label: &[u8]) { + let mut hasher = D::new(); + hasher.update(domain_label); + self.hasher = hasher; + } +} + +/// Blake2b512 transcript with labeled framing. +pub type Blake2bTranscript = HashTranscript; + +/// Keccak256 transcript with labeled framing. +pub type KeccakTranscript = HashTranscript; diff --git a/src/protocol/transcript/labels.rs b/src/protocol/transcript/labels.rs new file mode 100644 index 00000000..0cfa6d4a --- /dev/null +++ b/src/protocol/transcript/labels.rs @@ -0,0 +1,72 @@ +//! Hachi-native transcript labels. +//! +//! These constants are the single source of truth for protocol transcript +//! labels in Hachi core. External integrations should translate at adapter +//! boundaries instead of introducing foreign label names here. + +/// Top-level protocol domain label. +pub const DOMAIN_HACHI_PROTOCOL: &[u8] = b"hachi/protocol"; + +/// Absorb commitment object(s) (paper §4.1). +pub const ABSORB_COMMITMENT: &[u8] = b"hachi/absorb/commitment"; +/// Absorb claimed openings/evaluations before relation reduction (paper §4.2). +pub const ABSORB_EVALUATION_CLAIMS: &[u8] = b"hachi/absorb/evaluation-claims"; +/// Challenge for the evaluation-to-linear-relation reduction (paper §4.2). +pub const CHALLENGE_LINEAR_RELATION: &[u8] = b"hachi/challenge/linear-relation"; +/// Absorb ring-switch relation messages (paper §4.3). +pub const ABSORB_RING_SWITCH_MESSAGE: &[u8] = b"hachi/absorb/ring-switch-message"; +/// Challenge used by ring-switching relation checks (paper §4.3). +pub const CHALLENGE_RING_SWITCH: &[u8] = b"hachi/challenge/ring-switch"; +/// Absorb sparse-challenge sampling context (e.g. for short/sparse ring `c`). +pub const ABSORB_SPARSE_CHALLENGE: &[u8] = b"hachi/absorb/sparse-challenge"; +/// Challenge bytes used to sample sparse challenges (e.g. ring `c` with weight ω). +pub const CHALLENGE_SPARSE_CHALLENGE: &[u8] = b"hachi/challenge/sparse-challenge"; +/// Absorb the initial sumcheck claim before round messages begin. +pub const ABSORB_SUMCHECK_CLAIM: &[u8] = b"hachi/absorb/sumcheck-claim"; +/// Absorb per-round sumcheck messages (paper §4.3). +pub const ABSORB_SUMCHECK_ROUND: &[u8] = b"hachi/absorb/sumcheck-round"; +/// Challenge sampled per sumcheck round (paper §4.3). +pub const CHALLENGE_SUMCHECK_ROUND: &[u8] = b"hachi/challenge/sumcheck-round"; +/// Challenge for batched sumcheck coefficient sampling. +pub const CHALLENGE_SUMCHECK_BATCH: &[u8] = b"hachi/challenge/sumcheck-batch"; +/// Absorb recursion/stop-condition message payloads (paper §4.5). +pub const ABSORB_STOP_CONDITION: &[u8] = b"hachi/absorb/stop-condition"; +/// Challenge sampled for recursion stop-condition checks (paper §4.5). +pub const CHALLENGE_STOP_CONDITION: &[u8] = b"hachi/challenge/stop-condition"; + +/// Absorb the prover's stage-1 message `v = D · ŵ` (paper §4.2, Figure 3). +pub const ABSORB_PROVER_V: &[u8] = b"hachi/absorb/prover-stage1-v"; +/// Challenge label for stage-1 fold (sampling sparse `c_i`). +pub const CHALLENGE_STAGE1_FOLD: &[u8] = b"hachi/challenge/stage1-fold"; + +/// Absorb the `w` coefficient vector before sumcheck (paper §4.3). +pub const ABSORB_SUMCHECK_W: &[u8] = b"hachi/absorb/sumcheck-w"; +/// Challenge for sampling `τ₀` (F_0 range-check batching point, paper §4.3). +pub const CHALLENGE_TAU0: &[u8] = b"hachi/challenge/tau0"; +/// Challenge for sampling `τ₁` (F_α evaluation-relation batching point, paper §4.3). +pub const CHALLENGE_TAU1: &[u8] = b"hachi/challenge/tau1"; + +/// Return all Hachi-core transcript labels. +pub fn all_labels() -> &'static [&'static [u8]] { + &[ + DOMAIN_HACHI_PROTOCOL, + ABSORB_COMMITMENT, + ABSORB_EVALUATION_CLAIMS, + CHALLENGE_LINEAR_RELATION, + ABSORB_RING_SWITCH_MESSAGE, + CHALLENGE_RING_SWITCH, + ABSORB_SPARSE_CHALLENGE, + CHALLENGE_SPARSE_CHALLENGE, + ABSORB_SUMCHECK_CLAIM, + ABSORB_SUMCHECK_ROUND, + CHALLENGE_SUMCHECK_ROUND, + CHALLENGE_SUMCHECK_BATCH, + ABSORB_STOP_CONDITION, + CHALLENGE_STOP_CONDITION, + ABSORB_PROVER_V, + CHALLENGE_STAGE1_FOLD, + ABSORB_SUMCHECK_W, + CHALLENGE_TAU0, + CHALLENGE_TAU1, + ] +} diff --git a/src/protocol/transcript/mod.rs b/src/protocol/transcript/mod.rs new file mode 100644 index 00000000..05434ddb --- /dev/null +++ b/src/protocol/transcript/mod.rs @@ -0,0 +1,50 @@ +//! Protocol transcript contracts and implementations. + +mod hash; +pub mod labels; + +use crate::algebra::fields::lift::ExtField; +use crate::{CanonicalField, FieldCore, HachiSerialize}; + +pub use hash::{Blake2bTranscript, HashTranscript, KeccakTranscript}; + +/// Transcript interface for protocol Fiat-Shamir transforms. +/// +/// The protocol layer is label-aware and uses deterministic byte encoding for +/// all absorbed values. +pub trait Transcript: Clone + Send + Sync + 'static +where + F: FieldCore + CanonicalField, +{ + /// Construct a new transcript under a domain label. + fn new(domain_label: &[u8]) -> Self; + + /// Append labeled raw bytes. + fn append_bytes(&mut self, label: &[u8], bytes: &[u8]); + + /// Append a field element with deterministic encoding. + fn append_field(&mut self, label: &[u8], x: &F); + + /// Append a serializable protocol value. + fn append_serde(&mut self, label: &[u8], s: &S); + + /// Derive a challenge scalar under the provided label. + fn challenge_scalar(&mut self, label: &[u8]) -> F; +} + +/// Sample an extension field challenge by drawing `EXT_DEGREE` base-field +/// challenges and assembling them via `from_base_slice`. +/// +/// When `E = F` (degree 1), this compiles to a single `challenge_scalar` call. +pub fn sample_ext_challenge(tr: &mut T, label: &[u8]) -> E +where + F: FieldCore + CanonicalField, + T: Transcript, + E: ExtField, +{ + E::from_base_slice( + &(0..E::EXT_DEGREE) + .map(|_| tr.challenge_scalar(label)) + .collect::>(), + ) +} diff --git a/src/test_utils.rs b/src/test_utils.rs new file mode 100644 index 00000000..5de0450f --- /dev/null +++ b/src/test_utils.rs @@ -0,0 +1,126 @@ +//! Shared test configuration and helpers. +//! +//! This module is only compiled under `#[cfg(test)]` and provides common +//! building blocks for both unit tests (inside `src/`) and integration +//! tests (inside `tests/`). + +use crate::algebra::{CyclotomicRing, Fp64}; +use crate::protocol::commitment::CommitmentConfig; +use crate::{FieldCore, FromSmallInt}; + +pub type F = Fp64<4294967197>; +pub const D: usize = 64; + +#[derive(Clone)] +pub struct TinyConfig; + +impl CommitmentConfig for TinyConfig { + const D: usize = 64; + const N_A: usize = 2; + const N_B: usize = 2; + const N_D: usize = 2; + const LOG_BASIS: u32 = 4; + const DELTA: usize = 9; + const TAU: usize = 4; + const CHALLENGE_WEIGHT: usize = 3; + + fn commitment_layout( + _max_num_vars: usize, + ) -> Result { + crate::protocol::commitment::HachiCommitmentLayout::new::(1, 1) + } +} + +pub const BLOCK_LEN: usize = 2; +pub const NUM_BLOCKS: usize = 2; +pub const DELTA: usize = TinyConfig::DELTA; +pub const LOG_BASIS: u32 = TinyConfig::LOG_BASIS; +pub const N_A: usize = TinyConfig::N_A; +pub const TAU: usize = TinyConfig::TAU; + +pub fn mat_vec_mul( + mat: &[Vec>], + vec: &[CyclotomicRing], +) -> Vec> { + mat.iter() + .map(|row| { + assert_eq!(row.len(), vec.len()); + row.iter() + .zip(vec.iter()) + .fold(CyclotomicRing::::zero(), |acc, (a, x)| { + acc + (*a * *x) + }) + }) + .collect() +} + +pub fn sample_blocks() -> Vec>> { + (0..NUM_BLOCKS) + .map(|bi| { + (0..BLOCK_LEN) + .map(|bj| { + let coeffs = + std::array::from_fn(|k| F::from_u64((bi * 1_000 + bj * 100 + k) as u64)); + CyclotomicRing::from_coefficients(coeffs) + }) + .collect() + }) + .collect() +} + +pub fn sample_a() -> Vec { + (0..BLOCK_LEN) + .map(|j| F::from_u64((j * 10 + 1) as u64)) + .collect() +} + +pub fn sample_b() -> Vec { + (0..NUM_BLOCKS) + .map(|i| F::from_u64((i * 7 + 3) as u64)) + .collect() +} + +pub fn field_gadget_recompose( + parts: &[CyclotomicRing], + log_basis: u32, +) -> CyclotomicRing { + let b = F::from_u64(1u64 << log_basis); + let mut result = CyclotomicRing::::zero(); + let mut b_power = F::one(); + for part in parts { + result += part.scale(&b_power); + b_power = b_power * b; + } + result +} + +pub fn recompose_z_hat(z_hat: &[CyclotomicRing]) -> Vec> { + z_hat + .chunks(TAU) + .map(|chunk| field_gadget_recompose(chunk, LOG_BASIS)) + .collect() +} + +pub fn gadget_recompose_vec(x_hat: &[CyclotomicRing]) -> Vec> { + x_hat + .chunks(DELTA) + .map(|chunk| field_gadget_recompose(chunk, LOG_BASIS)) + .collect() +} + +pub fn field_gadget_recompose_vec(v: &[CyclotomicRing]) -> Vec> { + v.chunks(DELTA) + .map(|chunk| field_gadget_recompose(chunk, LOG_BASIS)) + .collect() +} + +pub fn a_transpose_gadget_times_vec(a: &[F], z: &[CyclotomicRing]) -> CyclotomicRing { + let recomposed = field_gadget_recompose_vec(z); + assert_eq!(recomposed.len(), a.len()); + recomposed + .iter() + .zip(a.iter()) + .fold(CyclotomicRing::::zero(), |acc, (z_j, a_j)| { + acc + z_j.scale(a_j) + }) +} diff --git a/tests/algebra.rs b/tests/algebra.rs new file mode 100644 index 00000000..668cacd0 --- /dev/null +++ b/tests/algebra.rs @@ -0,0 +1,1469 @@ +#![allow(missing_docs)] + +#[cfg(test)] +mod tests { + use num_bigint::BigUint; + use rand::{rngs::StdRng, SeedableRng}; + + use hachi_pcs::algebra::backend::{CrtReconstruct, NttPrimeOps}; + use hachi_pcs::algebra::ntt::butterfly::{forward_ntt, inverse_ntt, NttTwiddles}; + use hachi_pcs::algebra::poly::Poly; + use hachi_pcs::algebra::tables::{ + q128_garner, q128_primes, q32_garner, q64_garner, q64_primes, Q128_MODULUS, + Q128_NUM_PRIMES, Q32_MODULUS, Q32_NUM_PRIMES, Q32_PRIMES, Q64_MODULUS, Q64_NUM_PRIMES, + }; + use hachi_pcs::algebra::{ + pseudo_mersenne_modulus, Pow2Offset128Field, Pow2OffsetPrimeSpec, POW2_OFFSET_MAX, + POW2_OFFSET_PRIMES, POW2_OFFSET_TABLE, + }; + use hachi_pcs::algebra::{ + CyclotomicCrtNtt, CyclotomicRing, Fp128, Fp2, Fp2Config, Fp32, Fp4, Fp4Config, Fp64, LimbQ, + MontCoeff, Prime128M13M4P0, Prime128M37P3P0, Prime128M52M3P0, Prime128M54P4P0, + Prime128M8M4M1M0, ScalarBackend, VectorModule, + }; + use hachi_pcs::primitives::serialization::SerializationError; + use hachi_pcs::{ + CanonicalField, FieldCore, FieldSampling, FromSmallInt, HachiDeserialize, HachiSerialize, + Invertible, Module, PseudoMersenneField, + }; + + const P_159: u128 = 340282366920938463463374607431768211297u128; + + #[test] + fn fp32_basic_arith() { + type F = Fp32<251>; + let a = F::from_u64(17); + let b = F::from_u64(99); + assert_eq!((a + b).to_canonical_u32(), 116); + assert_eq!((a * b).to_canonical_u32(), (17u32 * 99) % 251); + + let inv = a.inv().unwrap(); + assert_eq!(a * inv, F::one()); + } + + #[test] + fn fp64_hachi_q_inv() { + type F = Fp64<4294967197>; + let two = F::from_u64(2); + let inv2 = two.inv().unwrap(); + assert_eq!(two * inv2, F::one()); + } + + #[test] + fn fp128_basic_arith() { + type F = Fp128; + + let a = F::from_u64(123); + let b = F::from_u64(456); + let c = a * b + a - b; + let inv = c.inv().unwrap(); + assert_eq!(c * inv, F::one()); + } + + fn rand_u128(rng: &mut R) -> u128 { + let lo = rng.next_u64() as u128; + let hi = rng.next_u64() as u128; + lo | (hi << 64) + } + + fn biguint_to_u128(x: &num_bigint::BigUint) -> u128 { + let mut bytes = x.to_bytes_le(); + bytes.resize(16, 0); + let mut arr = [0u8; 16]; + arr.copy_from_slice(&bytes[..16]); + u128::from_le_bytes(arr) + } + + fn big_mul_mod_u128(a: u128, b: u128, p: u128) -> u128 { + let n = BigUint::from(a) * BigUint::from(b); + let r = n % BigUint::from(p); + biguint_to_u128(&r) + } + + fn check_solinas_prime< + S: CanonicalField + FieldCore + Invertible + PseudoMersenneField + std::fmt::Debug, + >( + p: u128, + iters: usize, + seed: u64, + ) { + assert_eq!(::MODULUS_BITS, 128); + assert_eq!( + ::MODULUS_OFFSET, + 0u128.wrapping_sub(p) + ); + assert_eq!(std::mem::size_of::(), 16); + + let mut rng = StdRng::seed_from_u64(seed); + + for _ in 0..iters { + let a_raw = rand_u128(&mut rng); + let b_raw = rand_u128(&mut rng); + + let a = S::from_canonical_u128_reduced(a_raw); + let b = S::from_canonical_u128_reduced(b_raw); + + // Canonical range invariant. + assert!(a.to_canonical_u128() < p); + assert!(b.to_canonical_u128() < p); + + // Add/sub/neg identities. + assert_eq!(a + S::zero(), a); + assert_eq!(a - S::zero(), a); + assert_eq!(a + (-a), S::zero()); + + // Multiplicative identity. + assert_eq!(a * S::one(), a); + + // BigUint oracle for mul and sqr (exercises reduction). + let aa = a.to_canonical_u128(); + let bb = b.to_canonical_u128(); + let got_mul = (a * b).to_canonical_u128(); + let exp_mul = big_mul_mod_u128(aa, bb, p); + assert_eq!(got_mul, exp_mul); + + let got_sqr = (a * a).to_canonical_u128(); + let exp_sqr = big_mul_mod_u128(aa, aa, p); + assert_eq!(got_sqr, exp_sqr); + + // Inversion checks (skip explicit inv on zero). + let inv = a.inv_or_zero(); + if a.is_zero() { + assert_eq!(inv, S::zero()); + } else { + assert_eq!(a * inv, S::one()); + assert_eq!(a.inv().unwrap(), inv); + } + } + } + + #[test] + fn fp128_sparse_primes_match_biguint_oracle() { + // These are the five sparse `2^128 - c` primes we care about. + const P13: u128 = 0xffffffffffffffffffffffffffffdff1u128; + const P37: u128 = 0xffffffffffffffffffffffe000000009u128; + const P52: u128 = 0xffffffffffffffffffeffffffffffff9u128; + const P54: u128 = 0xffffffffffffffffffc0000000000011u128; + const P275: u128 = 0xfffffffffffffffffffffffffffffeedu128; + + check_solinas_prime::(P13, 2_000, 13); + check_solinas_prime::(P37, 2_000, 37); + check_solinas_prime::(P52, 2_000, 52); + check_solinas_prime::(P54, 2_000, 54); + check_solinas_prime::(P275, 2_000, 275); + } + + struct NR; + impl Fp2Config> for NR { + fn non_residue() -> Fp32<251> { + -Fp32::<251>::one() + } + } + + struct NR4; + impl Fp4Config, NR> for NR4 { + fn non_residue() -> Fp2, NR> { + Fp2::new(Fp32::<251>::zero(), Fp32::<251>::one()) + } + } + + #[test] + fn fp2_fp4_inversion_smoke() { + type F = Fp32<251>; + type F2 = Fp2; + type F4 = Fp4; + + let x = F2::new(F::from_u64(3), F::from_u64(7)); + let inv = x.inv().unwrap(); + assert!((x * inv) == F2::one()); + + let y = F4::new( + F2::new(F::from_u64(5), F::from_u64(1)), + F2::new(F::from_u64(2), F::from_u64(9)), + ); + let invy = y.inv().unwrap(); + assert!((y * invy) == F4::one()); + } + + #[test] + fn vector_module_ops() { + type F = Fp32<251>; + + let a = VectorModule::([F::from_u64(1), F::from_u64(2), F::from_u64(3)]); + let b = VectorModule::([F::from_u64(3), F::from_u64(4), F::from_u64(5)]); + + let c = a + b; + assert_eq!(c.0[0], F::from_u64(4)); + + let d = a.scale(&F::from_u64(7)); + assert_eq!(d.0[1], F::from_u64(14)); + } + + #[test] + fn inv_zero_returns_none() { + assert!(Fp32::<251>::zero().inv().is_none()); + assert!(Fp64::<4294967197>::zero().inv().is_none()); + assert!(Fp128::::zero().inv().is_none()); + } + + #[test] + fn inv_or_zero_behavior_for_prime_fields() { + type F32 = Fp32<251>; + assert_eq!(F32::zero().inv_or_zero(), F32::zero()); + let x32 = F32::from_u64(17); + let inv32 = x32.inv_or_zero(); + assert_eq!(x32 * inv32, F32::one()); + + type F64 = Fp64<4294967197>; + assert_eq!(F64::zero().inv_or_zero(), F64::zero()); + let x64 = F64::from_u64(2); + let inv64 = x64.inv_or_zero(); + assert_eq!(x64 * inv64, F64::one()); + + type F128 = Fp128; + assert_eq!(F128::zero().inv_or_zero(), F128::zero()); + let x128 = F128::from_u64(12345); + let inv128 = x128.inv_or_zero(); + assert_eq!(x128 * inv128, F128::one()); + } + + #[test] + fn field_identities_fp32() { + type F = Fp32<251>; + let a = F::from_u64(42); + let b = F::from_u64(73); + let c = F::from_u64(11); + + // Additive identity + assert_eq!(a + F::zero(), a); + // Multiplicative identity + assert_eq!(a * F::one(), a); + // Additive inverse + assert_eq!(a + (-a), F::zero()); + // Distributivity + assert_eq!(a * (b + c), a * b + a * c); + // Commutativity + assert_eq!(a * b, b * a); + assert_eq!(a + b, b + a); + } + + #[test] + fn field_identities_fp64() { + type F = Fp64<4294967197>; + let a = F::from_u64(123456); + let b = F::from_u64(789012); + let c = F::from_u64(345678); + + assert_eq!(a + F::zero(), a); + assert_eq!(a * F::one(), a); + assert_eq!(a + (-a), F::zero()); + assert_eq!(a * (b + c), a * b + a * c); + } + + #[test] + fn field_identities_fp128() { + type F = Fp128; + let a = F::from_u64(999999); + let b = F::from_u64(888888); + let c = F::from_u64(777777); + + assert_eq!(a + F::zero(), a); + assert_eq!(a * F::one(), a); + assert_eq!(a + (-a), F::zero()); + assert_eq!(a * (b + c), a * b + a * c); + } + + #[test] + fn serialization_round_trip_fp32() { + type F = Fp32<251>; + let val = F::from_u64(42); + let mut buf = Vec::new(); + val.serialize_compressed(&mut buf).unwrap(); + let restored = F::deserialize_compressed(&buf[..]).unwrap(); + assert_eq!(val, restored); + } + + #[test] + fn serialization_round_trip_fp64() { + type F = Fp64<4294967197>; + let val = F::from_u64(123456789); + let mut buf = Vec::new(); + val.serialize_compressed(&mut buf).unwrap(); + let restored = F::deserialize_compressed(&buf[..]).unwrap(); + assert_eq!(val, restored); + } + + #[test] + fn serialization_round_trip_fp128() { + type F = Fp128; + let val = F::from_u64(999999999); + let mut buf = Vec::new(); + val.serialize_compressed(&mut buf).unwrap(); + let restored = F::deserialize_compressed(&buf[..]).unwrap(); + assert_eq!(val, restored); + } + + #[test] + fn serialization_round_trip_ext() { + type F = Fp32<251>; + type F2 = Fp2; + let val = F2::new(F::from_u64(3), F::from_u64(7)); + let mut buf = Vec::new(); + val.serialize_compressed(&mut buf).unwrap(); + let restored = F2::deserialize_compressed(&buf[..]).unwrap(); + assert!(val == restored); + } + + #[test] + fn serialization_round_trip_fp4() { + type F = Fp32<251>; + type F2 = Fp2; + type F4 = Fp4; + + let val = F4::new( + F2::new(F::from_u64(5), F::from_u64(1)), + F2::new(F::from_u64(2), F::from_u64(9)), + ); + let mut buf = Vec::new(); + val.serialize_compressed(&mut buf).unwrap(); + let restored = F4::deserialize_compressed(&buf[..]).unwrap(); + assert!(val == restored); + } + + #[test] + fn serialization_round_trip_vector_module() { + type F = Fp32<251>; + let val = VectorModule::([F::from_u64(1), F::from_u64(2), F::from_u64(3)]); + let mut buf = Vec::new(); + val.serialize_compressed(&mut buf).unwrap(); + let restored = VectorModule::::deserialize_compressed(&buf[..]).unwrap(); + assert_eq!(val, restored); + } + + #[test] + fn serialization_round_trip_poly() { + type F = Fp32<251>; + + let val = Poly::([ + F::from_u64(7), + F::from_u64(11), + F::from_u64(13), + F::from_u64(29), + ]); + let mut buf = Vec::new(); + val.serialize_compressed(&mut buf).unwrap(); + let restored = Poly::::deserialize_compressed(&buf[..]).unwrap(); + assert_eq!(val, restored); + } + + #[test] + fn deserialize_checked_rejects_non_canonical_field_elements() { + type F32 = Fp32<251>; + let bad32 = 251u32.to_le_bytes(); + let err32 = F32::deserialize_compressed(&bad32[..]).unwrap_err(); + assert!(matches!(err32, SerializationError::InvalidData(_))); + let unchecked32 = F32::deserialize_compressed_unchecked(&bad32[..]).unwrap(); + assert_eq!(unchecked32, F32::zero()); + + type F64 = Fp64<4294967197>; + let bad64 = 4294967197u64.to_le_bytes(); + let err64 = F64::deserialize_compressed(&bad64[..]).unwrap_err(); + assert!(matches!(err64, SerializationError::InvalidData(_))); + let unchecked64 = F64::deserialize_compressed_unchecked(&bad64[..]).unwrap(); + assert_eq!(unchecked64, F64::zero()); + + type F128 = Fp128; + let bad128 = P_159.to_le_bytes(); + let err128 = F128::deserialize_compressed(&bad128[..]).unwrap_err(); + assert!(matches!(err128, SerializationError::InvalidData(_))); + let unchecked128 = F128::deserialize_compressed_unchecked(&bad128[..]).unwrap(); + assert_eq!(unchecked128, F128::zero()); + + // Sparse 128-bit prime: same checked/unchecked behavior. + type S13 = Prime128M13M4P0; + const P13: u128 = 0xffffffffffffffffffffffffffffdff1u128; + let bad13 = P13.to_le_bytes(); + let err13 = S13::deserialize_compressed(&bad13[..]).unwrap_err(); + assert!(matches!(err13, SerializationError::InvalidData(_))); + let unchecked13 = S13::deserialize_compressed_unchecked(&bad13[..]).unwrap(); + assert_eq!(unchecked13, S13::zero()); + } + + #[test] + fn fp2_conjugate_and_norm() { + type F = Fp32<251>; + type F2 = Fp2; + let x = F2::new(F::from_u64(3), F::from_u64(7)); + let conj = x.conjugate(); + assert!(conj == F2::new(F::from_u64(3), -F::from_u64(7))); + // For Fp2 with u^2 = -1: norm = c0^2 + c1^2 = 9 + 49 = 58 + assert_eq!(x.norm(), F::from_u64(58)); + // x * conjugate(x) should embed the norm into Fp2 + let prod = x * conj; + assert!(prod == F2::new(F::from_u64(58), F::zero())); + } + + #[test] + fn fp2_distributivity() { + type F = Fp32<251>; + type F2 = Fp2; + let a = F2::new(F::from_u64(3), F::from_u64(7)); + let b = F2::new(F::from_u64(11), F::from_u64(5)); + let c = F2::new(F::from_u64(2), F::from_u64(9)); + assert!(a * (b + c) == a * b + a * c); + } + + #[test] + fn limbq_from_to_u128_round_trip() { + for &val in &[0u128, 1, 12345, 123456789, (1u128 << 28) - 1] { + let limb: LimbQ<3> = LimbQ::from(val); + assert_eq!( + u128::try_from(limb).unwrap(), + val, + "round-trip failed for {val}" + ); + } + } + + #[test] + fn limbq_add_sub_inverse() { + let a: LimbQ<3> = LimbQ::from(12345u128); + let b: LimbQ<3> = LimbQ::from(6789u128); + let sum = a + b; + let diff = sum - b; + assert_eq!(diff, a); + } + + #[test] + fn limbq_ordering() { + let a: LimbQ<3> = LimbQ::from(100u128); + let b: LimbQ<3> = LimbQ::from(200u128); + assert!(a < b); + assert!(b > a); + assert_eq!(a, a); + } + + #[test] + fn ntt_normalize_in_range() { + for prime in &Q32_PRIMES { + for &a in &[0i16, 1, -1, 100, -100, prime.p - 1, -(prime.p - 1)] { + let n = prime.normalize(MontCoeff::from_raw(a)); + assert!( + n.raw() >= 0 && n.raw() < prime.p, + "normalize({a}) = {} for p={}", + n.raw(), + prime.p + ); + } + } + } + + #[test] + fn csubp_widened_handles_large_negative_i16() { + for &prime in &Q32_PRIMES { + let p = prime.p; + // Values in (-2p, -(2^15 - p)) that previously overflowed in narrow i16 + for &raw in &[-20000i16, -(p + p / 2), -(p + 1000)] { + if raw <= -2 * p || raw >= 0 { + continue; + } + let a = MontCoeff::from_raw(raw); + let reduced = prime.reduce_range(a); + let r = reduced.raw(); + assert!( + r > -p && r < p, + "reduce_range({raw}) = {r} not in (-{p}, {p}) for p={p}" + ); + + let norm = prime.normalize(reduced); + let n = norm.raw(); + assert!( + n >= 0 && n < p, + "normalize(reduce_range({raw})) = {n} not in [0, {p}) for p={p}" + ); + } + } + } + + #[test] + fn ntt_mul_commutative() { + let prime = Q32_PRIMES[0]; + let a = MontCoeff::from_raw(1234); + let b = MontCoeff::from_raw(5678); + assert_eq!(prime.mul(a, b), prime.mul(b, a)); + } + + #[test] + fn mont_coeff_round_trip() { + for prime in &Q32_PRIMES { + for &val in &[0i16, 1, 2, 100, prime.p - 1] { + let mont = prime.from_canonical(val); + let back = prime.to_canonical(mont); + assert_eq!(back, val, "round-trip failed for val={val}, p={}", prime.p); + } + } + } + + #[test] + fn poly_add_sub_neg() { + type F = Fp32<251>; + let a = Poly::([F::from_u64(1), F::from_u64(2), F::from_u64(3)]); + let b = Poly::([F::from_u64(10), F::from_u64(20), F::from_u64(30)]); + + let sum = a + b; + assert_eq!(sum.0[0], F::from_u64(11)); + assert_eq!(sum.0[1], F::from_u64(22)); + assert_eq!(sum.0[2], F::from_u64(33)); + + let diff = b - a; + assert_eq!(diff.0[0], F::from_u64(9)); + + let neg_a = -a; + assert_eq!(a + neg_a, Poly::zero()); + } + + #[test] + fn cyclotomic_ring_negacyclic_property() { + type F = Fp32<251>; + type R = CyclotomicRing; + + // X in the ring: [0, 1, 0, 0] + let x = R::x(); + + // X^2 + let x2 = x * x; + let expected_x2 = R::from_coefficients([F::zero(), F::zero(), F::one(), F::zero()]); + assert_eq!(x2, expected_x2); + + // X^4 should equal -1 (because X^4 + 1 = 0 in the ring) + let x4 = x2 * x2; + assert_eq!(x4, -R::one(), "X^D should equal -1 in Z_q[X]/(X^D + 1)"); + } + + #[test] + fn cyclotomic_ring_mul_identity() { + type F = Fp32<251>; + type R = CyclotomicRing; + + let a = R::from_coefficients([ + F::from_u64(3), + F::from_u64(7), + F::from_u64(11), + F::from_u64(42), + ]); + assert_eq!(a * R::one(), a); + assert_eq!(R::one() * a, a); + } + + #[test] + fn cyclotomic_ring_mul_zero() { + type F = Fp32<251>; + type R = CyclotomicRing; + + let a = R::from_coefficients([ + F::from_u64(3), + F::from_u64(7), + F::from_u64(11), + F::from_u64(42), + ]); + assert_eq!(a * R::zero(), R::zero()); + } + + #[test] + fn cyclotomic_ring_commutativity() { + type F = Fp32<251>; + type R = CyclotomicRing; + + let a = R::from_coefficients([ + F::from_u64(3), + F::from_u64(7), + F::from_u64(11), + F::from_u64(42), + ]); + let b = R::from_coefficients([ + F::from_u64(5), + F::from_u64(13), + F::from_u64(99), + F::from_u64(1), + ]); + assert_eq!(a * b, b * a); + } + + #[test] + fn cyclotomic_ring_distributivity() { + type F = Fp32<251>; + type R = CyclotomicRing; + + let a = R::from_coefficients([ + F::from_u64(3), + F::from_u64(7), + F::from_u64(11), + F::from_u64(42), + ]); + let b = R::from_coefficients([ + F::from_u64(5), + F::from_u64(13), + F::from_u64(99), + F::from_u64(1), + ]); + let c = R::from_coefficients([ + F::from_u64(2), + F::from_u64(9), + F::from_u64(50), + F::from_u64(77), + ]); + assert_eq!(a * (b + c), a * b + a * c); + } + + #[test] + fn cyclotomic_ring_associativity() { + type F = Fp32<251>; + type R = CyclotomicRing; + + let a = R::from_coefficients([ + F::from_u64(3), + F::from_u64(7), + F::from_u64(11), + F::from_u64(42), + ]); + let b = R::from_coefficients([ + F::from_u64(5), + F::from_u64(13), + F::from_u64(99), + F::from_u64(1), + ]); + let c = R::from_coefficients([ + F::from_u64(2), + F::from_u64(9), + F::from_u64(50), + F::from_u64(77), + ]); + assert_eq!((a * b) * c, a * (b * c)); + } + + #[test] + fn cyclotomic_ring_additive_inverse() { + type F = Fp32<251>; + type R = CyclotomicRing; + + let a = R::from_coefficients([ + F::from_u64(3), + F::from_u64(7), + F::from_u64(11), + F::from_u64(42), + ]); + assert_eq!(a + (-a), R::zero()); + } + + #[test] + fn cyclotomic_ring_serialization_round_trip() { + type F = Fp32<251>; + type R = CyclotomicRing; + + let a = R::from_coefficients([ + F::from_u64(3), + F::from_u64(7), + F::from_u64(11), + F::from_u64(42), + ]); + let mut buf = Vec::new(); + a.serialize_compressed(&mut buf).unwrap(); + let restored = R::deserialize_compressed(&buf[..]).unwrap(); + assert_eq!(a, restored); + } + + #[test] + fn cyclotomic_ring_degree_64() { + type F = Fp64<4294967197>; + type R = CyclotomicRing; + + // X^64 = -1 in Z_q[X]/(X^64 + 1) + let x = R::x(); + let mut power = R::one(); + for _ in 0..64 { + power *= x; + } + assert_eq!(power, -R::one(), "X^64 should equal -1"); + } + + #[test] + fn ntt_forward_inverse_round_trip() { + let prime = Q32_PRIMES[0]; + let tw = NttTwiddles::::compute(prime); + + let original: [MontCoeff; 64] = + std::array::from_fn(|i| prime.from_canonical((i as i16) % prime.p)); + + let mut a = original; + forward_ntt(&mut a, prime, &tw); + inverse_ntt(&mut a, prime, &tw); + + for (i, (got, expected)) in a.iter().zip(original.iter()).enumerate() { + let got_canon = prime.to_canonical(prime.normalize(*got)); + let exp_canon = prime.to_canonical(prime.normalize(*expected)); + assert_eq!( + got_canon, exp_canon, + "NTT round-trip mismatch at index {i}: got {got_canon}, expected {exp_canon}" + ); + } + } + + #[test] + fn ntt_forward_inverse_all_primes() { + for (pi, prime) in Q32_PRIMES.iter().enumerate() { + let tw = NttTwiddles::::compute(*prime); + + let original: [_; 64] = + std::array::from_fn(|i| prime.from_canonical(((i * (pi + 1)) as i16) % prime.p)); + + let mut a = original; + forward_ntt(&mut a, *prime, &tw); + inverse_ntt(&mut a, *prime, &tw); + + for (i, (got, expected)) in a.iter().zip(original.iter()).enumerate() { + let g = prime.to_canonical(prime.normalize(*got)); + let e = prime.to_canonical(prime.normalize(*expected)); + assert_eq!( + g, e, + "prime[{pi}] p={}: round-trip mismatch at index {i}", + prime.p + ); + } + } + } + + #[test] + fn negacyclic_ntt_mul_matches_schoolbook_single_prime_d8() { + const D: usize = 8; + let prime = Q32_PRIMES[0]; + let tw = NttTwiddles::::compute(prime); + + let a_canon: [i16; D] = std::array::from_fn(|i| ((i as i16 * 7) + 3) % prime.p); + let b_canon: [i16; D] = std::array::from_fn(|i| ((i as i16 * 5) + 11) % prime.p); + + // Schoolbook negacyclic convolution mod p: X^D = -1. + let mut school = [0i16; D]; + for (i, &ai) in a_canon.iter().enumerate() { + for (j, &bj) in b_canon.iter().enumerate() { + let prod = (ai as i64 * bj as i64) % (prime.p as i64); + let idx = i + j; + if idx < D { + school[idx] = ((school[idx] as i64 + prod) % (prime.p as i64)) as i16; + } else { + let k = idx - D; + school[k] = ((school[k] as i64 - prod) % (prime.p as i64)) as i16; + } + } + } + for x in &mut school { + if *x < 0 { + *x = (*x as i64 + prime.p as i64) as i16; + } + } + + let mut a = std::array::from_fn(|i| prime.from_canonical(a_canon[i])); + let mut b = std::array::from_fn(|i| prime.from_canonical(b_canon[i])); + forward_ntt(&mut a, prime, &tw); + forward_ntt(&mut b, prime, &tw); + + let mut c: [_; D] = std::array::from_fn(|i| prime.mul(a[i], b[i])); + inverse_ntt(&mut c, prime, &tw); + + let got: [i16; D] = std::array::from_fn(|i| prime.to_canonical(prime.normalize(c[i]))); + assert_eq!(got, school); + } + + #[test] + fn negacyclic_ntt_forward_matches_manual_evals_d8() { + const D: usize = 8; + let prime = Q32_PRIMES[0]; + let tw = NttTwiddles::::compute(prime); + let p = prime.p as i64; + + fn pow_mod(mut base: i64, mut exp: i64, modulus: i64) -> i64 { + let mut acc = 1i64; + base %= modulus; + while exp > 0 { + if exp & 1 == 1 { + acc = (acc * base) % modulus; + } + base = (base * base) % modulus; + exp >>= 1; + } + acc + } + + // Compute canonical psi (primitive 2D-th root) directly. + let half = (p - 1) / 2; + let exp = (p - 1) / (2 * D as i64); + let mut psi = None; + for a in 2..p { + if pow_mod(a, half, p) == p - 1 { + let cand = pow_mod(a, exp, p); + if pow_mod(cand, D as i64, p) == p - 1 { + psi = Some(cand); + break; + } + } + } + let psi = psi.expect("psi should exist"); + let a_canon: [i16; D] = std::array::from_fn(|i| ((i as i16 * 7) + 3) % prime.p); + + let mut expected = Vec::with_capacity(D); + for k in 0..D { + let alpha = pow_mod(psi, (2 * k + 1) as i64, p); + let mut acc = 0i64; + let mut power = 1i64; + for &ai in &a_canon { + acc = (acc + (ai as i64) * power) % p; + power = (power * alpha) % p; + } + expected.push(acc as i16); + } + expected.sort_unstable(); + + let mut a = std::array::from_fn(|i| prime.from_canonical(a_canon[i])); + forward_ntt(&mut a, prime, &tw); + let mut got: Vec = a + .iter() + .map(|x| prime.to_canonical(prime.normalize(*x))) + .collect(); + got.sort_unstable(); + + assert_eq!(got, expected); + } + + #[test] + fn negacyclic_ntt_mul_matches_schoolbook_single_prime_d64() { + const D: usize = 64; + let prime = Q32_PRIMES[0]; + let tw = NttTwiddles::::compute(prime); + let p = prime.p as i64; + + let a_canon: [i16; D] = std::array::from_fn(|i| ((i as i16 * 7) + 3) % prime.p); + let b_canon: [i16; D] = std::array::from_fn(|i| ((i as i16 * 5) + 11) % prime.p); + + let mut school = [0i16; D]; + for (i, &ai) in a_canon.iter().enumerate() { + for (j, &bj) in b_canon.iter().enumerate() { + let prod = (ai as i64 * bj as i64) % p; + let idx = i + j; + if idx < D { + school[idx] = ((school[idx] as i64 + prod) % p) as i16; + } else { + let k = idx - D; + school[k] = ((school[k] as i64 - prod) % p) as i16; + } + } + } + for x in &mut school { + if *x < 0 { + *x = (*x as i64 + p) as i16; + } + } + + let mut a = std::array::from_fn(|i| prime.from_canonical(a_canon[i])); + let mut b = std::array::from_fn(|i| prime.from_canonical(b_canon[i])); + forward_ntt(&mut a, prime, &tw); + forward_ntt(&mut b, prime, &tw); + + let mut c: [_; D] = std::array::from_fn(|i| prime.reduce_range(prime.mul(a[i], b[i]))); + inverse_ntt(&mut c, prime, &tw); + + let got: [i16; D] = std::array::from_fn(|i| prime.to_canonical(prime.normalize(c[i]))); + assert_eq!(got, school); + } + + #[test] + fn negacyclic_ntt_mul_matches_schoolbook_all_q32_primes_d64() { + const D: usize = 64; + let a_canon: [i16; D] = std::array::from_fn(|i| (i as i16 * 7 + 3)); + let b_canon: [i16; D] = std::array::from_fn(|i| (i as i16 * 5 + 11)); + + for (pi, &prime) in Q32_PRIMES.iter().enumerate() { + let tw = NttTwiddles::::compute(prime); + let p = prime.p as i64; + + let a_mod: [i16; D] = + std::array::from_fn(|i| ((a_canon[i] as i64).rem_euclid(p)) as i16); + let b_mod: [i16; D] = + std::array::from_fn(|i| ((b_canon[i] as i64).rem_euclid(p)) as i16); + + let mut school = [0i16; D]; + for (i, &ai) in a_mod.iter().enumerate() { + for (j, &bj) in b_mod.iter().enumerate() { + let prod = (ai as i64 * bj as i64) % p; + let idx = i + j; + if idx < D { + school[idx] = ((school[idx] as i64 + prod) % p) as i16; + } else { + let k = idx - D; + school[k] = ((school[k] as i64 - prod) % p) as i16; + } + } + } + for x in &mut school { + if *x < 0 { + *x = (*x as i64 + p) as i16; + } + } + + let mut a = std::array::from_fn(|i| prime.from_canonical(a_mod[i])); + let mut b = std::array::from_fn(|i| prime.from_canonical(b_mod[i])); + forward_ntt(&mut a, prime, &tw); + forward_ntt(&mut b, prime, &tw); + + let mut c = [MontCoeff::from_raw(0i16); D]; + for i in 0..D { + c[i] = prime.reduce_range(prime.mul(a[i], b[i])); + } + inverse_ntt(&mut c, prime, &tw); + + let got: [i16; D] = std::array::from_fn(|i| prime.to_canonical(prime.normalize(c[i]))); + assert_eq!(got, school, "prime[{pi}] p={} mismatch", prime.p); + } + } + + #[test] + fn cyclotomic_ntt_crt_round_trip_q32() { + type F = Fp64<{ Q32_MODULUS }>; + type R = CyclotomicRing; + type N = CyclotomicCrtNtt; + + let twiddles: [NttTwiddles; Q32_NUM_PRIMES] = + std::array::from_fn(|k| NttTwiddles::compute(Q32_PRIMES[k])); + + let coeffs: [F; 64] = + std::array::from_fn(|i| F::from_u64(((i as u64 * 17) + 5) % Q32_MODULUS)); + let ring = R::from_coefficients(coeffs); + let ntt = N::from_ring(&ring, &Q32_PRIMES, &twiddles); + let garner = q32_garner(); + let round_trip = ntt.to_ring(&Q32_PRIMES, &twiddles, &garner); + + assert_eq!(ring, round_trip); + } + + #[test] + fn cyclotomic_ntt_reduced_ops_are_stable() { + type F = Fp64<{ Q32_MODULUS }>; + type R = CyclotomicRing; + type N = CyclotomicCrtNtt; + + let twiddles: [NttTwiddles; Q32_NUM_PRIMES] = + std::array::from_fn(|k| NttTwiddles::compute(Q32_PRIMES[k])); + + let a = R::from_coefficients(std::array::from_fn(|i| { + F::from_u64(((i as u64 * 3) + 1) % Q32_MODULUS) + })); + let b = R::from_coefficients(std::array::from_fn(|i| { + F::from_u64(((i as u64 * 11) + 7) % Q32_MODULUS) + })); + + let ntt_a = N::from_ring(&a, &Q32_PRIMES, &twiddles); + let ntt_b = N::from_ring(&b, &Q32_PRIMES, &twiddles); + + let sum = ntt_a.add_reduced(&ntt_b, &Q32_PRIMES); + let back = sum.sub_reduced(&ntt_b, &Q32_PRIMES); + assert_eq!(back, ntt_a); + + let garner = q32_garner(); + let zero_ntt = ntt_a.add_reduced(&ntt_a.neg_reduced(&Q32_PRIMES), &Q32_PRIMES); + let zero_ring = zero_ntt.to_ring(&Q32_PRIMES, &twiddles, &garner); + assert_eq!(zero_ring, R::zero()); + } + + #[test] + fn backend_path_matches_default_scalar_path() { + type F = Fp64<{ Q32_MODULUS }>; + type R = CyclotomicRing; + type N = CyclotomicCrtNtt; + + let twiddles: [NttTwiddles; Q32_NUM_PRIMES] = + std::array::from_fn(|k| NttTwiddles::compute(Q32_PRIMES[k])); + let ring = R::from_coefficients(std::array::from_fn(|i| { + F::from_u64(((i as u64 * 13) + 9) % Q32_MODULUS) + })); + + let default_ntt = N::from_ring(&ring, &Q32_PRIMES, &twiddles); + let backend_ntt = + N::from_ring_with_backend::(&ring, &Q32_PRIMES, &twiddles); + assert_eq!(default_ntt, backend_ntt); + + let garner = q32_garner(); + let default_back = default_ntt.to_ring(&Q32_PRIMES, &twiddles, &garner); + let backend_back = + backend_ntt.to_ring_with_backend::(&Q32_PRIMES, &twiddles, &garner); + assert_eq!(default_back, backend_back); + } + + #[test] + fn crt_ntt_mul_matches_schoolbook_q32() { + type F = Fp64<{ Q32_MODULUS }>; + type R = CyclotomicRing; + type N = CyclotomicCrtNtt; + + let twiddles: [NttTwiddles; Q32_NUM_PRIMES] = + std::array::from_fn(|k| NttTwiddles::compute(Q32_PRIMES[k])); + let garner = q32_garner(); + + let a = R::from_coefficients(std::array::from_fn(|i| { + F::from_u64(((i as u64 * 7) + 3) % Q32_MODULUS) + })); + let b = R::from_coefficients(std::array::from_fn(|i| { + F::from_u64(((i as u64 * 5) + 11) % Q32_MODULUS) + })); + + let schoolbook = a * b; + + let ntt_a = N::from_ring(&a, &Q32_PRIMES, &twiddles); + let ntt_b = N::from_ring(&b, &Q32_PRIMES, &twiddles); + let ntt_prod = ntt_a.pointwise_mul(&ntt_b, &Q32_PRIMES); + let ntt_result: R = ntt_prod.to_ring(&Q32_PRIMES, &twiddles, &garner); + + assert_eq!(schoolbook, ntt_result); + } + + #[test] + fn q128_garner_reconstruct_matches_coeffs_no_ntt() { + type F = Fp128<{ Q128_MODULUS }>; + + let primes = q128_primes(); + let garner = q128_garner(); + + let coeffs: [F; 64] = std::array::from_fn(|i| { + if i < 8 { + F::from_u64((i as u64 * 31) + 7) + } else { + F::zero() + } + }); + + let mut canonical = [[0i32; 64]; Q128_NUM_PRIMES]; + for (k, prime) in primes.iter().enumerate() { + let p = prime.p as u32 as u128; + for (i, c) in coeffs.iter().enumerate() { + canonical[k][i] = (c.to_canonical_u128() % p) as i32; + } + } + + let reconstructed: [F; 64] = + >::reconstruct( + &primes, &canonical, &garner, + ); + + assert_eq!(reconstructed, coeffs); + } + + #[test] + fn q128_prime_ntt_round_trip_per_prime() { + let primes = q128_primes(); + let twiddles: [NttTwiddles; Q128_NUM_PRIMES] = + std::array::from_fn(|k| NttTwiddles::compute(primes[k])); + + // Use the same sparse coefficient pattern as q128_ntt_round_trip, but test + // the per-prime NTT+Montgomery machinery in isolation (no Garner/Fp128). + let residues: [u32; 64] = + std::array::from_fn(|i| if i < 8 { (i as u32 * 31) + 7 } else { 0 }); + + for k in 0..Q128_NUM_PRIMES { + let prime = primes[k]; + let mut limb = [MontCoeff::from_raw(0i32); 64]; + for (i, r) in residues.iter().enumerate() { + let reduced = (*r as i64 % (prime.p as i64)) as i32; + limb[i] = >::from_canonical(prime, reduced); + } + + forward_ntt(&mut limb, prime, &twiddles[k]); + inverse_ntt(&mut limb, prime, &twiddles[k]); + + for (i, r) in residues.iter().enumerate() { + let expected = (*r as i64 % (prime.p as i64)) as i32; + let got = >::to_canonical(prime, limb[i]); + assert_eq!(got, expected, "prime idx={k} coeff idx={i}"); + } + } + } + + #[test] + fn q128_ntt_round_trip() { + type F = Fp128<{ Q128_MODULUS }>; + type R = CyclotomicRing; + type N = CyclotomicCrtNtt; + + let primes = q128_primes(); + let twiddles: [NttTwiddles; Q128_NUM_PRIMES] = + std::array::from_fn(|k| NttTwiddles::compute(primes[k])); + let garner = q128_garner(); + + let coeffs: [F; 64] = std::array::from_fn(|i| { + if i < 8 { + F::from_u64((i as u64 * 31) + 7) + } else { + F::zero() + } + }); + let ring = R::from_coefficients(coeffs); + let ntt = N::from_ring(&ring, &primes, &twiddles); + let round_trip: R = ntt.to_ring(&primes, &twiddles, &garner); + + assert_eq!(ring, round_trip); + } + + #[test] + fn crt_ntt_mul_matches_schoolbook_q128() { + type F = Fp128<{ Q128_MODULUS }>; + type R = CyclotomicRing; + type N = CyclotomicCrtNtt; + + let primes = q128_primes(); + let twiddles: [NttTwiddles; Q128_NUM_PRIMES] = + std::array::from_fn(|k| NttTwiddles::compute(primes[k])); + let garner = q128_garner(); + + let a = R::from_coefficients(std::array::from_fn(|i| { + if i < 8 { + F::from_u64((i as u64 * 7) + 3) + } else { + F::zero() + } + })); + let b = R::from_coefficients(std::array::from_fn(|i| { + if i < 8 { + F::from_u64((i as u64 * 9) + 11) + } else { + F::zero() + } + })); + + let schoolbook = a * b; + + let ntt_a = N::from_ring(&a, &primes, &twiddles); + let ntt_b = N::from_ring(&b, &primes, &twiddles); + let ntt_prod = ntt_a.pointwise_mul(&ntt_b, &primes); + let ntt_result: R = ntt_prod.to_ring(&primes, &twiddles, &garner); + + assert_eq!(schoolbook, ntt_result); + } + + #[test] + fn q64_ntt_round_trip() { + type F = Fp64<{ Q64_MODULUS }>; + type R = CyclotomicRing; + type N = CyclotomicCrtNtt; + + let primes = q64_primes(); + let twiddles: [NttTwiddles; Q64_NUM_PRIMES] = + std::array::from_fn(|k| NttTwiddles::compute(primes[k])); + let garner = q64_garner(); + + let coeffs: [F; 64] = + std::array::from_fn(|i| F::from_u64(((i as u64 * 19) + 3) % Q64_MODULUS)); + let ring = R::from_coefficients(coeffs); + let ntt = N::from_ring(&ring, &primes, &twiddles); + let round_trip: R = ntt.to_ring(&primes, &twiddles, &garner); + + assert_eq!(ring, round_trip); + } + + #[test] + fn crt_ntt_mul_matches_schoolbook_q64() { + type F = Fp64<{ Q64_MODULUS }>; + type R = CyclotomicRing; + type N = CyclotomicCrtNtt; + + let primes = q64_primes(); + let twiddles: [NttTwiddles; Q64_NUM_PRIMES] = + std::array::from_fn(|k| NttTwiddles::compute(primes[k])); + let garner = q64_garner(); + + let a = R::from_coefficients(std::array::from_fn(|i| { + F::from_u64(((i as u64 * 5) + 9) % Q64_MODULUS) + })); + let b = R::from_coefficients(std::array::from_fn(|i| { + F::from_u64(((i as u64 * 17) + 13) % Q64_MODULUS) + })); + + let schoolbook = a * b; + + let ntt_a = N::from_ring(&a, &primes, &twiddles); + let ntt_b = N::from_ring(&b, &primes, &twiddles); + let ntt_prod = ntt_a.pointwise_mul(&ntt_b, &primes); + let ntt_result: R = ntt_prod.to_ring(&primes, &twiddles, &garner); + + assert_eq!(schoolbook, ntt_result); + } + + #[test] + fn field_sampling_respects_modulus() { + type F = Fp32<251>; + let mut rng = StdRng::seed_from_u64(42); + for _ in 0..1024 { + let x = F::sample(&mut rng); + assert!(x.to_canonical_u32() < 251); + } + } + + #[test] + fn pow2_offset_registry_is_consistent() { + fn assert_is_pseudo_mersenne() {} + assert_is_pseudo_mersenne::(); + + for Pow2OffsetPrimeSpec { + bits, + offset, + modulus, + .. + } in POW2_OFFSET_PRIMES + { + assert!((offset as u128) <= POW2_OFFSET_MAX); + assert_eq!(POW2_OFFSET_TABLE[bits as usize], offset as i16); + assert_eq!( + Some(modulus), + pseudo_mersenne_modulus(bits, offset as u128), + "2^k-offset modulus mismatch for k={bits}, offset={offset}" + ); + assert_eq!(modulus % 8, 5); + } + + let x = Pow2Offset128Field::from_u64(1234567); + let inv = x.inv().unwrap(); + assert_eq!(x * inv, Pow2Offset128Field::one()); + } + + #[test] + fn cyclotomic_sigma_is_ring_automorphism() { + type F = Fp32<251>; + type R = CyclotomicRing; + let a = R::from_coefficients(std::array::from_fn(|i| F::from_u64((3 * i + 1) as u64))); + let b = R::from_coefficients(std::array::from_fn(|i| F::from_u64((5 * i + 2) as u64))); + + let k1 = 3usize; + let k2 = 5usize; + let two_d = 16usize; + + assert_eq!(a.sigma(1), a); + assert_eq!(a.sigma_m1().sigma_m1(), a); + assert_eq!(a.sigma(k1).sigma(k2), a.sigma((k1 * k2) % two_d)); + assert_eq!((a * b).sigma(k1), a.sigma(k1) * b.sigma(k1)); + } + + #[test] + fn cyclotomic_balanced_pow2_decompose_recompose_round_trip() { + type F = Fp64<{ Q32_MODULUS }>; + type R = CyclotomicRing; + + let ring = R::from_coefficients(std::array::from_fn(|i| { + F::from_u64(((i as u64 * 73) + 17) % Q32_MODULUS) + })); + + // Q32 balanced base-16: 9 levels absorb the carry-out near q/2. + let digits = ring.balanced_decompose_pow2(9, 4); + let round_trip = R::gadget_recompose_pow2(&digits, 4); + assert_eq!(round_trip, ring); + } + + #[test] + fn sparse_pm1_challenge_has_expected_weight() { + type F = Fp32<251>; + type R = CyclotomicRing; + + let mut rng = StdRng::seed_from_u64(123); + let challenge = R::sample_sparse_pm1(&mut rng, 11); + assert_eq!(challenge.hamming_weight(), 11); + + for c in challenge.coefficients() { + let x = c.to_canonical_u32(); + if x != 0 { + assert!(x == 1 || x == 250, "nonzero coefficient must be +/-1"); + } + } + } + + #[test] + fn negacyclic_shift_equals_mul_by_monomial() { + type F = Fp32<251>; + type R = CyclotomicRing; + + let a = R::from_coefficients(std::array::from_fn(|i| F::from_u64((3 * i + 1) as u64))); + + for k in 0..8 { + let mut monomial_coeffs = [F::zero(); 8]; + monomial_coeffs[k] = F::one(); + let monomial = R::from_coefficients(monomial_coeffs); + assert_eq!( + a.negacyclic_shift(k), + a * monomial, + "negacyclic_shift({k}) != mul by X^{k}" + ); + } + + assert_eq!(a.negacyclic_shift(0), a); + assert_eq!( + a.negacyclic_shift(8), + a, + "shift by D should be identity mod D" + ); + } + + #[test] + fn negacyclic_shift_degree_64() { + type F = Fp64<4294967197>; + type R = CyclotomicRing; + + let a = R::from_coefficients(std::array::from_fn(|i| F::from_u64((7 * i + 3) as u64))); + let x = R::x(); + let mut x_pow = R::one(); + for k in 0..64 { + assert_eq!( + a.negacyclic_shift(k), + a * x_pow, + "negacyclic_shift({k}) mismatch at D=64" + ); + x_pow *= x; + } + } + + #[test] + fn mul_by_monomial_sum_matches_ring_mul() { + type F = Fp32<251>; + type R = CyclotomicRing; + + let a = R::from_coefficients(std::array::from_fn(|i| F::from_u64((5 * i + 2) as u64))); + + // Sum of X^1 + X^3 + X^5 + let positions = [1, 3, 5]; + let mut sparse = [F::zero(); 8]; + for &p in &positions { + sparse[p] = F::one(); + } + let sparse_ring = R::from_coefficients(sparse); + + assert_eq!( + a.mul_by_monomial_sum(&positions), + a * sparse_ring, + "mul_by_monomial_sum should equal ring mul by sparse element" + ); + } + + #[test] + fn mul_by_monomial_sum_single_position_equals_shift() { + type F = Fp32<251>; + type R = CyclotomicRing; + + let a = R::from_coefficients(std::array::from_fn(|i| F::from_u64((i + 1) as u64))); + for k in 0..8 { + assert_eq!( + a.mul_by_monomial_sum(&[k]), + a.negacyclic_shift(k), + "single-position monomial_sum should equal negacyclic_shift" + ); + } + } + + #[test] + fn mul_by_monomial_sum_empty_is_zero() { + type F = Fp32<251>; + type R = CyclotomicRing; + + let a = R::from_coefficients(std::array::from_fn(|i| F::from_u64((i + 1) as u64))); + assert_eq!(a.mul_by_monomial_sum(&[]), R::zero()); + } + + #[test] + fn mul_by_sparse_matches_schoolbook() { + use hachi_pcs::algebra::SparseChallenge; + + type F = Fp64<4294967197>; + type R = CyclotomicRing; + + let a = R::from_coefficients(std::array::from_fn(|i| F::from_u64((3 * i + 7) as u64))); + + let challenge = SparseChallenge { + positions: vec![2, 17, 41], + coeffs: vec![1, -1, 1], + }; + let dense: R = challenge.to_dense().unwrap(); + + let via_sparse = a.mul_by_sparse(&challenge); + let via_schoolbook = a * dense; + + assert_eq!( + via_sparse, via_schoolbook, + "mul_by_sparse must equal schoolbook multiplication" + ); + } + + #[test] + fn mul_by_sparse_with_all_negative_coeffs() { + use hachi_pcs::algebra::SparseChallenge; + + type F = Fp64<4294967197>; + type R = CyclotomicRing; + + let a = R::from_coefficients(std::array::from_fn(|i| F::from_u64((i + 1) as u64))); + + let challenge = SparseChallenge { + positions: vec![0, 5, 63], + coeffs: vec![-1, -1, -1], + }; + let dense: R = challenge.to_dense().unwrap(); + + assert_eq!(a.mul_by_sparse(&challenge), a * dense); + } + + #[test] + fn is_zero_detects_zero_and_nonzero() { + type F = Fp32<251>; + type R = CyclotomicRing; + + assert!(R::zero().is_zero()); + assert!(!R::one().is_zero()); + + let a = R::from_coefficients(std::array::from_fn(|i| F::from_u64(i as u64))); + assert!(!a.is_zero()); + } + + #[test] + fn kron_scalars_matches_kron_row_constant_rings() { + type F = Fp64<4294967197>; + type R = CyclotomicRing; + + let scalars_a: Vec = (0..4).map(|i| F::from_u64(i * 3 + 1)).collect(); + let scalars_b: Vec = (0..3).map(|i| F::from_u64(i * 7 + 2)).collect(); + + let rings_a: Vec = scalars_a + .iter() + .map(|&s| { + let mut c = [F::zero(); 16]; + c[0] = s; + R::from_coefficients(c) + }) + .collect(); + let rings_b: Vec = scalars_b + .iter() + .map(|&s| { + let mut c = [F::zero(); 16]; + c[0] = s; + R::from_coefficients(c) + }) + .collect(); + + let via_ring: Vec = rings_a + .iter() + .flat_map(|l| rings_b.iter().map(move |r| *l * *r)) + .collect(); + + let via_scalar: Vec = scalars_a + .iter() + .flat_map(|&l| { + scalars_b.iter().map(move |&r| { + let mut c = [F::zero(); 16]; + c[0] = l * r; + R::from_coefficients(c) + }) + }) + .collect(); + + assert_eq!(via_ring, via_scalar); + } +} diff --git a/tests/commitment_contract.rs b/tests/commitment_contract.rs new file mode 100644 index 00000000..4aaf3822 --- /dev/null +++ b/tests/commitment_contract.rs @@ -0,0 +1,236 @@ +#![allow(missing_docs)] + +use hachi_pcs::algebra::Fp64; +use hachi_pcs::protocol::commitment::{DummyProof, HachiCommitment}; +use hachi_pcs::protocol::transcript::labels; +use hachi_pcs::protocol::{ + AppendToTranscript, Blake2bTranscript, CommitmentScheme, StreamingCommitmentScheme, Transcript, +}; +use hachi_pcs::{CanonicalField, FieldCore, FromSmallInt, HachiError, Polynomial}; + +type F = Fp64<4294967197>; + +#[derive(Clone)] +struct SimplePoly { + coeffs: Vec, +} + +impl Polynomial for SimplePoly { + fn num_vars(&self) -> usize { + self.coeffs.len().saturating_sub(1) + } + + fn evaluate(&self, point: &[F]) -> F { + assert_eq!(point.len(), self.num_vars()); + let mut acc = self.coeffs[0]; + for (i, r_i) in point.iter().enumerate() { + acc = acc + self.coeffs[i + 1] * *r_i; + } + acc + } + + fn coeffs(&self) -> Vec { + self.coeffs.clone() + } +} + +#[derive(Clone)] +struct DummySetup { + _max_num_vars: usize, +} + +#[derive(Clone)] +struct DummyScheme; + +impl CommitmentScheme for DummyScheme { + type ProverSetup = DummySetup; + type VerifierSetup = DummySetup; + type Commitment = HachiCommitment; + type Proof = DummyProof; + type OpeningProofHint = HachiCommitment; + + fn setup_prover(max_num_vars: usize) -> Self::ProverSetup { + DummySetup { + _max_num_vars: max_num_vars, + } + } + + fn setup_verifier(setup: &Self::ProverSetup) -> Self::VerifierSetup { + setup.clone() + } + + fn commit>( + poly: &P, + _setup: &Self::ProverSetup, + ) -> Result<(Self::Commitment, Self::OpeningProofHint), HachiError> { + let zero = vec![F::zero(); poly.num_vars()]; + let c = HachiCommitment(poly.evaluate(&zero).to_canonical_u128()); + Ok((c, c)) + } + + fn prove, P: Polynomial>( + setup: &Self::ProverSetup, + poly: &P, + opening_point: &[F], + hint: Option, + transcript: &mut T, + _commitment: &Self::Commitment, + ) -> Result { + if opening_point.len() != poly.num_vars() { + return Err(HachiError::InvalidPointDimension { + expected: poly.num_vars(), + actual: opening_point.len(), + }); + } + + let absorb_commitment = if let Some(h) = hint { + h + } else { + Self::commit(poly, setup)?.0 + }; + absorb_commitment.append_to_transcript(labels::ABSORB_COMMITMENT, transcript); + + let q = transcript.challenge_scalar(labels::CHALLENGE_LINEAR_RELATION); + let opening = poly.evaluate(opening_point); + Ok(DummyProof( + opening.to_canonical_u128() ^ q.to_canonical_u128(), + )) + } + + fn verify>( + proof: &Self::Proof, + _setup: &Self::VerifierSetup, + transcript: &mut T, + _opening_point: &[F], + opening: &F, + commitment: &Self::Commitment, + ) -> Result<(), HachiError> { + commitment.append_to_transcript(labels::ABSORB_COMMITMENT, transcript); + let q = transcript.challenge_scalar(labels::CHALLENGE_LINEAR_RELATION); + let expected = opening.to_canonical_u128() ^ q.to_canonical_u128(); + if proof.0 == expected { + Ok(()) + } else { + Err(HachiError::InvalidProof) + } + } + + fn combine_commitments(commitments: &[Self::Commitment], coeffs: &[F]) -> Self::Commitment { + let acc = commitments + .iter() + .zip(coeffs.iter()) + .fold(0u128, |sum, (c, coeff)| { + sum.wrapping_add(c.0.wrapping_mul(coeff.to_canonical_u128())) + }); + HachiCommitment(acc) + } + + fn combine_hints(hints: Vec, coeffs: &[F]) -> Self::OpeningProofHint { + let acc = hints + .iter() + .zip(coeffs.iter()) + .fold(0u128, |sum, (h, coeff)| { + sum.wrapping_add(h.0.wrapping_mul(coeff.to_canonical_u128())) + }); + HachiCommitment(acc) + } + + fn protocol_name() -> &'static [u8] { + b"HachiDummy" + } +} + +impl StreamingCommitmentScheme for DummyScheme { + type ChunkState = HachiCommitment; + + fn process_chunk(_setup: &Self::ProverSetup, chunk: &[F]) -> Self::ChunkState { + let sum = chunk + .iter() + .fold(0u128, |acc, x| acc.wrapping_add(x.to_canonical_u128())); + HachiCommitment(sum) + } + + fn process_chunk_onehot( + _setup: &Self::ProverSetup, + onehot_k: usize, + chunk: &[Option], + ) -> Self::ChunkState { + let sum = chunk.iter().fold(0u128, |acc, x| { + let v = x.unwrap_or(0) as u128; + acc.wrapping_add(v) + }); + HachiCommitment(sum.wrapping_add(onehot_k as u128)) + } + + fn aggregate_chunks( + _setup: &Self::ProverSetup, + _onehot_k: Option, + chunks: &[Self::ChunkState], + ) -> (Self::Commitment, Self::OpeningProofHint) { + let sum = chunks.iter().fold(0u128, |acc, c| acc.wrapping_add(c.0)); + let c = HachiCommitment(sum); + (c, c) + } +} + +#[test] +fn commitment_scheme_round_trip() { + let poly = SimplePoly { + coeffs: vec![F::from_u64(3), F::from_u64(5), F::from_u64(7)], + }; + let opening_point = [F::from_u64(11), F::from_u64(13)]; + let opening = poly.evaluate(&opening_point); + + let psetup = DummyScheme::setup_prover(poly.num_vars()); + let vsetup = DummyScheme::setup_verifier(&psetup); + + let (commitment, hint) = DummyScheme::commit(&poly, &psetup).unwrap(); + + let mut prover_t = Blake2bTranscript::::new(labels::DOMAIN_HACHI_PROTOCOL); + let proof = DummyScheme::prove( + &psetup, + &poly, + &opening_point, + Some(hint), + &mut prover_t, + &commitment, + ) + .unwrap(); + + let mut verifier_t = Blake2bTranscript::::new(labels::DOMAIN_HACHI_PROTOCOL); + DummyScheme::verify( + &proof, + &vsetup, + &mut verifier_t, + &opening_point, + &opening, + &commitment, + ) + .unwrap(); +} + +#[test] +fn combine_commitments_and_hints_are_consistent() { + let c1 = HachiCommitment(10); + let c2 = HachiCommitment(20); + let coeffs = [F::from_u64(3), F::from_u64(7)]; + + let combined_c = DummyScheme::combine_commitments(&[c1, c2], &coeffs); + let combined_h = DummyScheme::combine_hints(vec![c1, c2], &coeffs); + + let expected = 10u128 + .wrapping_mul(coeffs[0].to_canonical_u128()) + .wrapping_add(20u128.wrapping_mul(coeffs[1].to_canonical_u128())); + assert_eq!(combined_c.0, expected); + assert_eq!(combined_h.0, expected); +} + +#[test] +fn streaming_chunk_path_aggregates() { + let setup = DummyScheme::setup_prover(4); + let c1 = DummyScheme::process_chunk(&setup, &[F::from_u64(1), F::from_u64(2)]); + let c2 = DummyScheme::process_chunk_onehot(&setup, 8, &[Some(3), None, Some(5)]); + + let (commitment, hint) = DummyScheme::aggregate_chunks(&setup, Some(8), &[c1, c2]); + assert_eq!(commitment, hint); +} diff --git a/tests/hachi_sumcheck.rs b/tests/hachi_sumcheck.rs new file mode 100644 index 00000000..ce33b91b --- /dev/null +++ b/tests/hachi_sumcheck.rs @@ -0,0 +1,225 @@ +#![allow(missing_docs)] + +use hachi_pcs::algebra::ring::CyclotomicRing; +use hachi_pcs::algebra::Fp64; +use hachi_pcs::protocol::sumcheck::eq_poly::EqPolynomial; +use hachi_pcs::protocol::sumcheck::norm_sumcheck::{NormSumcheckProver, NormSumcheckVerifier}; +use hachi_pcs::protocol::sumcheck::relation_sumcheck::{ + RelationSumcheckProver, RelationSumcheckVerifier, +}; +use hachi_pcs::protocol::sumcheck::{multilinear_eval, range_check_eval}; +use hachi_pcs::protocol::transcript::labels; +use hachi_pcs::protocol::{prove_sumcheck, verify_sumcheck, Blake2bTranscript, Transcript}; +use hachi_pcs::{FieldCore, FieldSampling, FromSmallInt}; +use rand::rngs::StdRng; +use rand::SeedableRng; +use std::time::Instant; + +type F = Fp64<4294967197>; + +fn run_f0_e2e(num_u: usize, num_l: usize, b: usize) { + let num_vars = num_u + num_l; + let n = 1usize << num_vars; + let mut rng = StdRng::seed_from_u64(0xF0); + + let w_evals: Vec = (0..n).map(|i| F::from_u64((i % b) as u64)).collect(); + let tau0: Vec = (0..num_vars).map(|_| F::sample(&mut rng)).collect(); + + let t0 = Instant::now(); + let mut prover = NormSumcheckProver::new(&tau0, w_evals.clone(), b); + let mut pt = Blake2bTranscript::::new(labels::DOMAIN_HACHI_PROTOCOL); + + let (proof, prover_challenges, final_claim) = + prove_sumcheck::(&mut prover, &mut pt, |tr| { + tr.challenge_scalar(labels::CHALLENGE_SUMCHECK_ROUND) + }) + .unwrap(); + let prove_time = t0.elapsed(); + + // Sanity: prover's final claim matches oracle evaluation. + let oracle = EqPolynomial::mle(&tau0, &prover_challenges) + * range_check_eval(multilinear_eval(&w_evals, &prover_challenges).unwrap(), b); + assert_eq!(final_claim, oracle, "prover final claim != oracle eval"); + + let t1 = Instant::now(); + let verifier = NormSumcheckVerifier::new(tau0, w_evals, b); + let mut vt = Blake2bTranscript::::new(labels::DOMAIN_HACHI_PROTOCOL); + + let verifier_challenges = verify_sumcheck::(&proof, &verifier, &mut vt, |tr| { + tr.challenge_scalar(labels::CHALLENGE_SUMCHECK_ROUND) + }) + .unwrap(); + let verify_time = t1.elapsed(); + + assert_eq!(prover_challenges, verifier_challenges); + + eprintln!( + "[F0 e2e] num_u={num_u} num_l={num_l} b={b} n=2^{num_vars}={n} \ + prove={prove_time:.2?} verify={verify_time:.2?} \ + rounds={} degree={}", + proof.round_polys.len(), + 2 * b, + ); +} + +#[test] +fn f0_sumcheck_e2e_small() { + run_f0_e2e(3, 2, 2); +} + +#[test] +fn f0_sumcheck_e2e() { + run_f0_e2e(4, 3, 2); +} + +#[test] +fn f0_sumcheck_e2e_larger_b() { + run_f0_e2e(3, 3, 3); +} + +fn run_f_alpha_e2e(num_u: usize, num_i: usize) { + let num_l = D.trailing_zeros() as usize; + let num_vars = num_u + num_l; + let n = 1usize << num_vars; + let mut rng = StdRng::seed_from_u64(0xFA); + + let w_evals: Vec = (0..n).map(|_| F::sample(&mut rng)).collect(); + let alpha_evals_y: Vec = (0..D).map(|_| F::sample(&mut rng)).collect(); + let m_alpha_evals: Vec = (0..(1usize << (num_i + num_u))) + .map(|_| F::sample(&mut rng)) + .collect(); + let tau1: Vec = (0..num_i).map(|_| F::sample(&mut rng)).collect(); + + // Compute m(x) = Σ_i ẽq(τ₁, i) · M̃_α(i, x) + let eq_tau1 = EqPolynomial::evals(&tau1); + let num_x = 1usize << num_u; + let m_evals_x: Vec = (0..num_x) + .map(|x_idx| { + (0..(1usize << num_i)) + .map(|i_idx| eq_tau1[i_idx] * m_alpha_evals[i_idx * num_x + x_idx]) + .fold(F::zero(), |a, v| a + v) + }) + .collect(); + + // Compute y_a[i] = Σ_x M̃_α(i,x) · w_α(x), where w_α(x) = Σ_y w(x,y) · α̃(y) + let num_y = D; + let num_rows = 1usize << num_i; + let w_alpha: Vec = (0..num_x) + .map(|x| { + (0..num_y) + .map(|y| w_evals[x + y * num_x] * alpha_evals_y[y]) + .fold(F::zero(), |a, v| a + v) + }) + .collect(); + let y_a: Vec = (0..num_rows) + .map(|i| { + (0..num_x) + .map(|x| m_alpha_evals[i * num_x + x] * w_alpha[x]) + .fold(F::zero(), |a, v| a + v) + }) + .collect(); + + // Embed y_a values as constant ring elements for the verifier. + let v_rings: Vec> = y_a + .iter() + .map(|&val| { + let mut coeffs = [F::zero(); D]; + coeffs[0] = val; + CyclotomicRing::from_coefficients(coeffs) + }) + .collect(); + let u_rings: Vec> = vec![]; + let u_eval_ring = CyclotomicRing::::zero(); + let ring_alpha = F::one(); + + let t0 = Instant::now(); + let mut prover = + RelationSumcheckProver::new(w_evals.clone(), &alpha_evals_y, &m_evals_x, num_u, num_l); + let mut pt = Blake2bTranscript::::new(labels::DOMAIN_HACHI_PROTOCOL); + + let (proof, prover_challenges, final_claim) = + prove_sumcheck::(&mut prover, &mut pt, |tr| { + tr.challenge_scalar(labels::CHALLENGE_SUMCHECK_ROUND) + }) + .unwrap(); + let prove_time = t0.elapsed(); + + // Sanity: prover's final claim matches oracle evaluation. + let (x_ch, y_ch) = prover_challenges.split_at(num_u); + let oracle = multilinear_eval(&w_evals, &prover_challenges).unwrap() + * multilinear_eval(&alpha_evals_y, y_ch).unwrap() + * multilinear_eval(&m_evals_x, x_ch).unwrap(); + assert_eq!(final_claim, oracle, "prover final claim != oracle eval"); + + let t1 = Instant::now(); + let verifier = RelationSumcheckVerifier::::new( + w_evals, + alpha_evals_y, + m_evals_x, + tau1, + v_rings, + u_rings, + u_eval_ring, + ring_alpha, + num_u, + num_l, + ); + let mut vt = Blake2bTranscript::::new(labels::DOMAIN_HACHI_PROTOCOL); + + let verifier_challenges = verify_sumcheck::(&proof, &verifier, &mut vt, |tr| { + tr.challenge_scalar(labels::CHALLENGE_SUMCHECK_ROUND) + }) + .unwrap(); + let verify_time = t1.elapsed(); + + assert_eq!(prover_challenges, verifier_challenges); + + eprintln!( + "[Fα e2e] num_u={num_u} num_l={num_l} num_i={num_i} n=2^{num_vars}={n} \ + prove={prove_time:.2?} verify={verify_time:.2?} \ + rounds={} degree=2", + proof.round_polys.len(), + ); +} + +#[test] +fn f_alpha_sumcheck_e2e_small() { + run_f_alpha_e2e::<4>(3, 2); +} + +#[test] +fn f_alpha_sumcheck_e2e() { + run_f_alpha_e2e::<8>(4, 3); +} + +#[test] +fn f_alpha_sumcheck_e2e_asymmetric() { + run_f_alpha_e2e::<4>(5, 4); +} + +#[test] +fn from_evals_matches_direct_polynomial() { + use hachi_pcs::protocol::UniPoly; + + // Verify that interpolation at integer points reproduces the polynomial. + let mut rng = StdRng::seed_from_u64(0xEE); + + for degree in 0..6usize { + let coeffs: Vec = (0..=degree).map(|_| F::sample(&mut rng)).collect(); + let poly = UniPoly::from_coeffs(coeffs); + + let evals: Vec = (0..=degree) + .map(|t| poly.evaluate(&F::from_u64(t as u64))) + .collect(); + let reconstructed = UniPoly::from_evals(&evals); + + for x_u64 in [0u64, 1, 2, 3, 7, 13] { + let x = F::from_u64(x_u64); + assert_eq!( + poly.evaluate(&x), + reconstructed.evaluate(&x), + "degree {degree}, x={x_u64}" + ); + } + } +} diff --git a/tests/label_schedule.rs b/tests/label_schedule.rs new file mode 100644 index 00000000..c943b83b --- /dev/null +++ b/tests/label_schedule.rs @@ -0,0 +1,63 @@ +#![allow(missing_docs)] + +use hachi_pcs::algebra::Fp64; +use hachi_pcs::protocol::transcript::labels; +use hachi_pcs::protocol::{Blake2bTranscript, Transcript}; + +type F = Fp64<4294967197>; + +#[test] +fn label_namespace_does_not_include_dory_literals() { + let banned = ["vmv_", "beta", "alpha", "gamma", "final_e", "dory"]; + for label in labels::all_labels() { + let text = std::str::from_utf8(label).expect("labels must be valid utf8 literals"); + for needle in &banned { + assert!( + !text.contains(needle), + "label `{text}` must not contain banned token `{needle}`" + ); + } + } +} + +fn run_hachi_schedule>(transcript: &mut T) -> (F, F, F) { + transcript.append_bytes(labels::ABSORB_COMMITMENT, b"C"); + transcript.append_bytes(labels::ABSORB_EVALUATION_CLAIMS, b"O"); + let c_linear_relation = transcript.challenge_scalar(labels::CHALLENGE_LINEAR_RELATION); + + transcript.append_bytes(labels::ABSORB_RING_SWITCH_MESSAGE, b"RS"); + let c_ring_switch = transcript.challenge_scalar(labels::CHALLENGE_RING_SWITCH); + + transcript.append_bytes(labels::ABSORB_SUMCHECK_ROUND, b"SC1"); + let c_sumcheck = transcript.challenge_scalar(labels::CHALLENGE_SUMCHECK_ROUND); + transcript.append_bytes(labels::ABSORB_STOP_CONDITION, b"STOP"); + let _ = transcript.challenge_scalar(labels::CHALLENGE_STOP_CONDITION); + + (c_linear_relation, c_ring_switch, c_sumcheck) +} + +#[test] +fn schedule_is_replayable_with_hachi_labels() { + let mut prover = Blake2bTranscript::::new(labels::DOMAIN_HACHI_PROTOCOL); + let mut verifier = Blake2bTranscript::::new(labels::DOMAIN_HACHI_PROTOCOL); + assert_eq!( + run_hachi_schedule(&mut prover), + run_hachi_schedule(&mut verifier) + ); +} + +#[test] +fn schedule_detects_reordered_round_messages() { + let mut t1 = Blake2bTranscript::::new(labels::DOMAIN_HACHI_PROTOCOL); + let mut t2 = Blake2bTranscript::::new(labels::DOMAIN_HACHI_PROTOCOL); + + t1.append_bytes(labels::ABSORB_COMMITMENT, b"C"); + t1.append_bytes(labels::ABSORB_EVALUATION_CLAIMS, b"O"); + let a = t1.challenge_scalar(labels::CHALLENGE_LINEAR_RELATION); + + t2.append_bytes(labels::ABSORB_EVALUATION_CLAIMS, b"O"); + t2.append_bytes(labels::ABSORB_COMMITMENT, b"C"); + let b = t2.challenge_scalar(labels::CHALLENGE_LINEAR_RELATION); + + assert_ne!(a, b); +} diff --git a/tests/onehot_commitment.rs b/tests/onehot_commitment.rs new file mode 100644 index 00000000..e3f6c109 --- /dev/null +++ b/tests/onehot_commitment.rs @@ -0,0 +1,158 @@ +#![allow(missing_docs)] + +use hachi_pcs::algebra::CyclotomicRing; +use hachi_pcs::protocol::commitment::{HachiCommitmentCore, RingCommitmentScheme}; +use hachi_pcs::test_utils::*; +use hachi_pcs::{FieldCore, FromSmallInt}; + +type Core = HachiCommitmentCore; + +fn psetup() -> >::ProverSetup { + >::setup(16) + .unwrap() + .0 +} + +/// Compare the optimized one-hot path against the default dense path. +/// +/// The default implementation materializes the full vector and calls +/// `commit_coeffs`. The optimized impl uses sparse inner Ajtai. +/// Both must produce identical (commitment, s_all, t_hat_all). +fn assert_onehot_matches_dense(onehot_k: usize, indices: &[usize]) { + let opt_indices: Vec> = indices.iter().map(|&i| Some(i)).collect(); + let setup = psetup(); + + // Optimized sparse path. + let w_sparse = >::commit_onehot( + onehot_k, + &opt_indices, + &setup, + ) + .unwrap(); + + // Reference: materialize the full one-hot vector, pack into ring elements, + // and commit via the dense path. + let total_field = indices.len() * onehot_k; + let total_ring = total_field / D; + let mut field_elems = vec![F::zero(); total_field]; + for (c, &idx) in indices.iter().enumerate() { + field_elems[c * onehot_k + idx] = F::from_u64(1); + } + let ring_coeffs: Vec> = (0..total_ring) + .map(|r| { + let coeffs: [F; D] = std::array::from_fn(|i| field_elems[r * D + i]); + CyclotomicRing::from_coefficients(coeffs) + }) + .collect(); + let w_dense = + >::commit_coeffs(&ring_coeffs, &setup) + .unwrap(); + + assert_eq!( + w_sparse.commitment, w_dense.commitment, + "commitments must match" + ); + assert_eq!( + w_sparse.t_hat, w_dense.t_hat, + "t_hat_all (decomposed inner output) must match" + ); +} + +#[test] +fn onehot_k_gt_d_basic() { + // K=128, D=64 => K/D=2, T=2 => T*K=256 => 4 ring elements + assert_onehot_matches_dense(128, &[0, 64]); +} + +#[test] +fn onehot_k_gt_d_various_positions() { + assert_onehot_matches_dense(128, &[127, 0]); + assert_onehot_matches_dense(128, &[63, 65]); + assert_onehot_matches_dense(128, &[32, 96]); +} + +#[test] +fn onehot_k_much_gt_d() { + // K=256, D=64 => K/D=4, T=1 => T*K=256 => 4 ring elements + assert_onehot_matches_dense(256, &[0]); + assert_onehot_matches_dense(256, &[63]); + assert_onehot_matches_dense(256, &[64]); + assert_onehot_matches_dense(256, &[255]); + assert_onehot_matches_dense(256, &[100]); +} + +#[test] +fn onehot_k_eq_d_basic() { + // K=64=D, T=4 => 4 ring elements, each is a monomial X^{idx}. + assert_onehot_matches_dense(64, &[0, 0, 0, 0]); +} + +#[test] +fn onehot_k_eq_d_varied() { + assert_onehot_matches_dense(64, &[0, 31, 32, 63]); + assert_onehot_matches_dense(64, &[1, 2, 3, 4]); + assert_onehot_matches_dense(64, &[63, 63, 63, 63]); +} + +#[test] +fn onehot_k_lt_d_basic() { + // K=16, D=64 => D/K=4, T=16 => T*K=256 => 4 ring elements. + // Each ring element spans 4 chunks, so has 4 nonzero coefficients. + let indices: Vec = (0..16).map(|i| i % 16).collect(); + assert_onehot_matches_dense(16, &indices); +} + +#[test] +fn onehot_k_lt_d_all_zeros() { + let indices = vec![0; 16]; + assert_onehot_matches_dense(16, &indices); +} + +#[test] +fn onehot_k_lt_d_all_max() { + let indices = vec![15; 16]; + assert_onehot_matches_dense(16, &indices); +} + +#[test] +fn onehot_k_lt_d_mixed() { + let indices = vec![0, 15, 7, 3, 12, 1, 8, 14, 5, 10, 2, 9, 6, 11, 4, 13]; + assert_onehot_matches_dense(16, &indices); +} + +#[test] +fn onehot_k_lt_d_ratio_2() { + // K=32, D=64 => D/K=2, T=8 => T*K=256 => 4 ring elements. + let indices = vec![0, 31, 16, 8, 24, 4, 12, 20]; + assert_onehot_matches_dense(32, &indices); +} + +#[test] +fn onehot_rejects_non_divisible_k_and_d() { + let setup = psetup(); + let result = + >::commit_onehot(17, &[Some(0); 4], &setup); + assert!(result.is_err()); +} + +#[test] +fn onehot_rejects_out_of_range_index() { + let setup = psetup(); + let result = >::commit_onehot( + 64, + &[Some(0), Some(64), Some(0), Some(0)], + &setup, + ); + assert!(result.is_err()); +} + +#[test] +fn onehot_rejects_wrong_total_size() { + let setup = psetup(); + let result = >::commit_onehot( + 64, + &[Some(0), Some(0), Some(0)], + &setup, + ); + assert!(result.is_err()); +} diff --git a/tests/primality.rs b/tests/primality.rs new file mode 100644 index 00000000..0d7a3a8f --- /dev/null +++ b/tests/primality.rs @@ -0,0 +1,124 @@ +#![allow(missing_docs)] + +use hachi_pcs::algebra::{pseudo_mersenne_modulus, Pow2OffsetPrimeSpec, POW2_OFFSET_PRIMES}; + +// Strong probable-prime test using multiple fixed bases. +// This is not a formal primality certificate, but is sufficient as a +// practical regression guard for the current Pow2Offset profiles. +fn is_probable_prime_miller_rabin(n: u128) -> bool { + if n < 2 { + return false; + } + if n % 2 == 0 { + return n == 2; + } + + const SMALL_PRIMES: [u128; 11] = [3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37]; + for p in SMALL_PRIMES { + if n == p { + return true; + } + if n % p == 0 { + return false; + } + } + + let (d, s) = decompose_pow2(n - 1); + const BASES: [u128; 24] = [ + 2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53, 59, 61, 67, 71, 73, 79, 83, 89, + ]; + + 'outer: for a in BASES { + if a >= n { + continue; + } + let mut x = pow_mod(a, d, n); + if x == 1 || x == n - 1 { + continue; + } + for _ in 1..s { + x = mul_mod(x, x, n); + if x == n - 1 { + continue 'outer; + } + } + return false; + } + + true +} + +fn decompose_pow2(mut d: u128) -> (u128, u32) { + let mut s = 0u32; + while d % 2 == 0 { + d >>= 1; + s += 1; + } + (d, s) +} + +fn pow_mod(mut base: u128, mut exp: u128, modulus: u128) -> u128 { + let mut result = 1u128; + base %= modulus; + while exp > 0 { + if (exp & 1) == 1 { + result = mul_mod(result, base, modulus); + } + base = mul_mod(base, base, modulus); + exp >>= 1; + } + result +} + +fn mul_mod(mut a: u128, mut b: u128, modulus: u128) -> u128 { + let mut result = 0u128; + a %= modulus; + b %= modulus; + while b > 0 { + if (b & 1) == 1 { + result = add_mod(result, a, modulus); + } + a = add_mod(a, a, modulus); + b >>= 1; + } + result +} + +fn add_mod(a: u128, b: u128, modulus: u128) -> u128 { + if a >= modulus - b { + a - (modulus - b) + } else { + a + b + } +} + +#[test] +fn pow2_offset_profiles_are_probable_primes() { + for Pow2OffsetPrimeSpec { + bits, + offset, + modulus, + } in POW2_OFFSET_PRIMES + { + assert_eq!( + Some(modulus), + pseudo_mersenne_modulus(bits, offset as u128), + "profile formula mismatch for bits={bits}, offset={offset}" + ); + assert!( + is_probable_prime_miller_rabin(modulus), + "Miller-Rabin rejected bits={bits}, offset={offset}, q={modulus}" + ); + } +} + +#[test] +fn miller_rabin_rejects_known_composites() { + let composites: [u128; 9] = [4, 9, 15, 21, 341, 561, 645, 1105, 1729]; + for n in composites { + assert!( + !is_probable_prime_miller_rabin(n), + "composite unexpectedly accepted: {n}" + ); + } +} diff --git a/tests/ring_commitment_core.rs b/tests/ring_commitment_core.rs new file mode 100644 index 00000000..4b09501e --- /dev/null +++ b/tests/ring_commitment_core.rs @@ -0,0 +1,187 @@ +#![allow(missing_docs)] + +use hachi_pcs::algebra::CyclotomicRing; +use hachi_pcs::error::HachiError; +use hachi_pcs::protocol::commitment::{ + CommitmentConfig, HachiCommitmentCore, HachiCommitmentLayout, RingCommitmentScheme, + SmallTestCommitmentConfig, +}; +use hachi_pcs::test_utils::*; + +#[derive(Clone)] +struct BadDegreeConfig; + +impl CommitmentConfig for BadDegreeConfig { + const D: usize = 32; + const N_A: usize = 8; + const N_B: usize = 4; + const N_D: usize = 4; + const LOG_BASIS: u32 = 4; + const DELTA: usize = 8; + const TAU: usize = 4; + const CHALLENGE_WEIGHT: usize = 3; + + fn commitment_layout(_max_num_vars: usize) -> Result { + HachiCommitmentLayout::new::(4, 2) + } +} + +#[derive(Clone)] +struct BadDigitBudgetConfig; + +impl CommitmentConfig for BadDigitBudgetConfig { + const D: usize = 64; + const N_A: usize = 8; + const N_B: usize = 4; + const N_D: usize = 4; + const LOG_BASIS: u32 = 32; + const DELTA: usize = 5; // 160 > 128 + const TAU: usize = 4; + const CHALLENGE_WEIGHT: usize = 3; + + fn commitment_layout(_max_num_vars: usize) -> Result { + HachiCommitmentLayout::new::(4, 2) + } +} + +#[test] +fn setup_shape_is_consistent() { + let (p1, v1) = + >::setup(16).unwrap(); + let (p2, v2) = + >::setup(16).unwrap(); + + assert_eq!(p1.expanded.seed.max_num_vars, 16); + assert_eq!(v1.expanded.seed.max_num_vars, 16); + assert_eq!(p2.expanded.seed.max_num_vars, 16); + assert_eq!(v2.expanded.seed.max_num_vars, 16); + assert_eq!(p1.expanded.A.len(), TinyConfig::N_A); + assert_eq!( + p1.expanded.A[0].len(), + hachi_pcs::test_utils::BLOCK_LEN * TinyConfig::DELTA + ); + assert_eq!(p1.expanded.B.len(), TinyConfig::N_B); + assert_eq!( + p1.expanded.B[0].len(), + TinyConfig::N_A * TinyConfig::DELTA * hachi_pcs::test_utils::NUM_BLOCKS + ); +} + +#[test] +fn commit_is_deterministic_and_shape_consistent() { + let (psetup, _) = + >::setup(16).unwrap(); + let blocks = sample_blocks(); + + let w1 = >::commit_ring_blocks( + &blocks, &psetup, + ) + .unwrap(); + let w2 = >::commit_ring_blocks( + &blocks, &psetup, + ) + .unwrap(); + + assert_eq!(w1.commitment, w2.commitment); + assert_eq!(w1.t_hat, w2.t_hat); + + let num_blocks = hachi_pcs::test_utils::NUM_BLOCKS; + assert_eq!(w1.commitment.u.len(), TinyConfig::N_B); + assert_eq!(w1.t_hat.len(), num_blocks); + assert!(w1 + .t_hat + .iter() + .all(|t| t.len() == TinyConfig::N_A * TinyConfig::DELTA)); +} + +#[test] +fn commit_ring_coeffs_matches_block_commitment() { + let (psetup, _) = + >::setup(16).unwrap(); + let blocks = sample_blocks(); + + let wb = >::commit_ring_blocks( + &blocks, &psetup, + ) + .unwrap(); + + // Sequential layout: block 0 elements, then block 1 elements, etc. + let f_coeffs: Vec<_> = blocks + .iter() + .flat_map(|block| block.iter().copied()) + .collect(); + + let wc = >::commit_coeffs( + &f_coeffs, &psetup, + ) + .unwrap(); + + assert_eq!(wb.commitment, wc.commitment); + assert_eq!(wb.t_hat, wc.t_hat); +} + +#[test] +fn opening_satisfies_inner_and_outer_equations() { + let (psetup, _) = + >::setup(16).unwrap(); + let blocks = sample_blocks(); + let w = >::commit_ring_blocks( + &blocks, &psetup, + ) + .unwrap(); + + for (i, block) in blocks.iter().enumerate() { + let s_i = hachi_pcs::protocol::commitment::utils::linear::decompose_block( + block, + TinyConfig::DELTA, + TinyConfig::LOG_BASIS, + ); + let lhs = mat_vec_mul(&psetup.expanded.A, &s_i); + let rhs: Vec> = (0..TinyConfig::N_A) + .map(|j| { + let start = j * TinyConfig::DELTA; + let end = start + TinyConfig::DELTA; + CyclotomicRing::gadget_recompose_pow2( + &w.t_hat[i][start..end], + TinyConfig::LOG_BASIS, + ) + }) + .collect(); + assert_eq!(lhs, rhs); + } + + let t_hat_flat: Vec> = + w.t_hat.iter().flat_map(|x| x.iter().copied()).collect(); + let outer = mat_vec_mul(&psetup.expanded.B, &t_hat_flat); + assert_eq!(outer, w.commitment.u); +} + +#[test] +fn small_test_config_has_expected_shape() { + assert_eq!(SmallTestCommitmentConfig::D, 16); + let layout = SmallTestCommitmentConfig::commitment_layout(8).unwrap(); + assert_eq!(layout.block_len, 16); + assert_eq!(layout.num_blocks, 4); + let delta = SmallTestCommitmentConfig::DELTA; + assert!(delta > 0); +} + +#[test] +fn setup_rejects_mismatched_degree() { + let err = >::setup(16) + .unwrap_err(); + match err { + HachiError::InvalidSetup(msg) => assert!(msg.contains("mismatches")), + other => panic!("unexpected error: {other:?}"), + } +} + +#[test] +fn setup_rejects_invalid_digit_budget() { + let err = >::setup(16) + .unwrap_err(); + match err { + HachiError::InvalidSetup(msg) => assert!(msg.contains("DELTA * LOG_BASIS")), + other => panic!("unexpected error: {other:?}"), + } +} diff --git a/tests/sparse_challenge.rs b/tests/sparse_challenge.rs new file mode 100644 index 00000000..baba7bd3 --- /dev/null +++ b/tests/sparse_challenge.rs @@ -0,0 +1,98 @@ +#![allow(missing_docs)] + +use hachi_pcs::algebra::fields::LiftBase; +use hachi_pcs::algebra::ring::{CyclotomicRing, SparseChallenge, SparseChallengeConfig}; +use hachi_pcs::algebra::Fp64; +use hachi_pcs::protocol::challenges::sparse::sparse_challenge_from_transcript; +use hachi_pcs::protocol::transcript::labels::DOMAIN_HACHI_PROTOCOL; +use hachi_pcs::protocol::transcript::{Blake2bTranscript, Transcript}; +use hachi_pcs::{FieldCore, FromSmallInt}; + +type F = Fp64<4294967197>; + +const D: usize = 16; + +fn dense_eval>(alpha: E, x: &CyclotomicRing) -> E { + let mut acc = E::zero(); + let mut pow = E::one(); + for c in x.coefficients().iter().copied() { + acc = acc + E::lift_base(c) * pow; + pow = pow * alpha; + } + acc +} + +#[test] +fn sparse_challenge_validate_and_to_dense() { + let cfg = SparseChallengeConfig { + weight: 3, + nonzero_coeffs: vec![-1, 1], + }; + cfg.validate::().unwrap(); + + let s = SparseChallenge { + positions: vec![0, 7, 12], + coeffs: vec![1, -1, 1], + }; + s.validate::().unwrap(); + assert_eq!(s.hamming_weight(), 3); + assert_eq!(s.l1_norm(), 3); + + let dense = s.to_dense::().unwrap(); + assert_eq!(dense.hamming_weight(), 3); + assert_eq!(dense.coefficients()[0], F::one()); + assert_eq!(dense.coefficients()[7], -F::one()); + assert_eq!(dense.coefficients()[12], F::one()); +} + +#[test] +fn sparse_eval_at_alpha_matches_dense_eval() { + let alpha = F::from_u64(5); + let alpha_pows = { + let mut out = Vec::with_capacity(D); + let mut acc = F::one(); + for _ in 0..D { + out.push(acc); + acc = acc * alpha; + } + out + }; + + let s = SparseChallenge { + positions: vec![1, 3, 9], + coeffs: vec![2, -1, 1], + }; + let dense = s.to_dense::().unwrap(); + + let sparse_eval = s.eval_at_alpha::(&alpha_pows).unwrap(); + let dense_eval = dense_eval::(alpha, &dense); + assert_eq!(sparse_eval, dense_eval); +} + +#[test] +fn sparse_challenge_sampling_is_deterministic_and_exact_weight() { + let cfg = SparseChallengeConfig { + weight: 8, + nonzero_coeffs: vec![-1, 1], + }; + + let mut t1 = Blake2bTranscript::::new(DOMAIN_HACHI_PROTOCOL); + let mut t2 = Blake2bTranscript::::new(DOMAIN_HACHI_PROTOCOL); + + // Make transcript state non-empty to avoid degenerate behavior. + t1.append_field(b"seed", &F::from_u64(123)); + t2.append_field(b"seed", &F::from_u64(123)); + + let c1 = sparse_challenge_from_transcript::(&mut t1, b"c", 0, &cfg).unwrap(); + let c2 = sparse_challenge_from_transcript::(&mut t2, b"c", 0, &cfg).unwrap(); + assert_eq!(c1, c2); + c1.validate::().unwrap(); + assert_eq!(c1.hamming_weight(), cfg.weight); + assert_eq!(c1.l1_norm(), cfg.weight as u64); + + // Different instance_idx should change the sample. + let mut t3 = Blake2bTranscript::::new(DOMAIN_HACHI_PROTOCOL); + t3.append_field(b"seed", &F::from_u64(123)); + let c3 = sparse_challenge_from_transcript::(&mut t3, b"c", 1, &cfg).unwrap(); + assert_ne!(c1, c3); +} diff --git a/tests/sumcheck_core.rs b/tests/sumcheck_core.rs new file mode 100644 index 00000000..dcd6486f --- /dev/null +++ b/tests/sumcheck_core.rs @@ -0,0 +1,290 @@ +#![allow(missing_docs)] + +use std::time::Instant; + +use hachi_pcs::algebra::poly::multilinear_eval; +use hachi_pcs::algebra::Fp64; +use hachi_pcs::error::HachiError; +use hachi_pcs::protocol::transcript::labels; +use hachi_pcs::protocol::{ + prove_sumcheck, verify_sumcheck, Blake2bTranscript, CompressedUniPoly, SumcheckInstanceProver, + SumcheckInstanceVerifier, SumcheckProof, Transcript, UniPoly, +}; +use hachi_pcs::{FieldCore, FieldSampling, FromSmallInt}; +use rand::rngs::StdRng; +use rand::RngCore; +use rand::SeedableRng; + +type F = Fp64<4294967197>; + +#[test] +fn compressed_unipoly_round_trip_and_eval() { + let mut rng = StdRng::seed_from_u64(123); + + for degree in 0..8usize { + let coeffs: Vec = (0..=degree).map(|_| F::sample(&mut rng)).collect(); + let poly = UniPoly::from_coeffs(coeffs); + + // Hint is g(0) + g(1). + let hint = poly.evaluate(&F::zero()) + poly.evaluate(&F::one()); + + let compressed = poly.compress(); + let decompressed = compressed.decompress(&hint); + + // Decompression should be functionally equivalent (it may materialize + // a trailing zero linear term for constant polynomials). + for x_u64 in [0u64, 1, 2, 3, 17] { + let x = F::from_u64(x_u64); + let direct = poly.evaluate(&x); + let decompressed_direct = decompressed.evaluate(&x); + let via_hint = compressed.eval_from_hint(&hint, &x); + assert_eq!(direct, decompressed_direct); + assert_eq!(direct, via_hint); + } + } +} + +#[test] +fn sumcheck_proof_verifier_driver_is_transcript_deterministic() { + // This test checks that the verifier driver absorbs messages and samples challenges + // consistently, and that the returned (final_claim, r_vec) matches a manual replay. + let mut rng = StdRng::seed_from_u64(999); + + let num_rounds = 5usize; + let degree_bound = 7usize; + + // Build random per-round univariates (degree <= degree_bound), compress them. + let round_polys: Vec> = (0..num_rounds) + .map(|_| { + let deg = (rng.next_u32() as usize) % (degree_bound + 1); + let coeffs: Vec = (0..=deg).map(|_| F::sample(&mut rng)).collect(); + UniPoly::from_coeffs(coeffs).compress() + }) + .collect(); + + let proof = SumcheckProof { round_polys }; + let claim0 = F::sample(&mut rng); + + // Verifier run. + let mut t1 = Blake2bTranscript::::new(labels::DOMAIN_HACHI_PROTOCOL); + let (final_claim_1, r_1) = proof + .verify::(claim0, num_rounds, degree_bound, &mut t1, |tr| { + tr.challenge_scalar(labels::CHALLENGE_SUMCHECK_ROUND) + }) + .unwrap(); + + // Manual replay with a fresh transcript (must match). + let mut t2 = Blake2bTranscript::::new(labels::DOMAIN_HACHI_PROTOCOL); + let mut claim = claim0; + let mut r_manual = Vec::with_capacity(num_rounds); + for poly in &proof.round_polys { + t2.append_serde(labels::ABSORB_SUMCHECK_ROUND, poly); + let r_i = t2.challenge_scalar(labels::CHALLENGE_SUMCHECK_ROUND); + r_manual.push(r_i); + claim = poly.eval_from_hint(&claim, &r_i); + } + + assert_eq!(r_1, r_manual); + assert_eq!(final_claim_1, claim); +} + +struct DenseSumcheckProver { + evals: Vec, + num_vars: usize, +} + +impl SumcheckInstanceProver for DenseSumcheckProver { + fn num_rounds(&self) -> usize { + self.num_vars + } + + fn degree_bound(&self) -> usize { + 1 + } + + fn input_claim(&self) -> E { + self.evals.iter().copied().fold(E::zero(), |a, b| a + b) + } + + fn compute_round_univariate(&mut self, _round: usize, _previous_claim: E) -> UniPoly { + let half = self.evals.len() / 2; + let mut eval_0 = E::zero(); + let mut eval_1 = E::zero(); + for i in 0..half { + eval_0 = eval_0 + self.evals[2 * i]; + eval_1 = eval_1 + self.evals[2 * i + 1]; + } + UniPoly::from_coeffs(vec![eval_0, eval_1 - eval_0]) + } + + fn ingest_challenge(&mut self, _round: usize, r: E) { + let half = self.evals.len() / 2; + let mut new_evals = Vec::with_capacity(half); + for i in 0..half { + new_evals.push(self.evals[2 * i] + r * (self.evals[2 * i + 1] - self.evals[2 * i])); + } + self.evals = new_evals; + } +} + +struct DenseSumcheckVerifier { + evals: Vec, + num_vars: usize, + claim: E, +} + +impl SumcheckInstanceVerifier for DenseSumcheckVerifier { + fn num_rounds(&self) -> usize { + self.num_vars + } + + fn degree_bound(&self) -> usize { + 1 + } + + fn input_claim(&self) -> E { + self.claim + } + + fn expected_output_claim(&self, challenges: &[E]) -> Result { + multilinear_eval(&self.evals, challenges) + } +} + +#[test] +fn prove_and_verify_single_sumcheck() { + let num_vars = 4; + let n = 1 << num_vars; + + let evals: Vec = (1..=n).map(|i| F::from_u64(i as u64)).collect(); + let claim: F = evals.iter().copied().fold(F::zero(), |a, b| a + b); + + let mut prover = DenseSumcheckProver { + evals: evals.clone(), + num_vars, + }; + + let mut prover_transcript = Blake2bTranscript::::new(labels::DOMAIN_HACHI_PROTOCOL); + + let (proof, prover_challenges, _final_claim) = + prove_sumcheck::(&mut prover, &mut prover_transcript, |tr| { + tr.challenge_scalar(labels::CHALLENGE_SUMCHECK_ROUND) + }) + .unwrap(); + + let verifier = DenseSumcheckVerifier { + evals, + num_vars, + claim, + }; + + let mut verifier_transcript = Blake2bTranscript::::new(labels::DOMAIN_HACHI_PROTOCOL); + + let verifier_challenges = + verify_sumcheck::(&proof, &verifier, &mut verifier_transcript, |tr| { + tr.challenge_scalar(labels::CHALLENGE_SUMCHECK_ROUND) + }) + .unwrap(); + + assert_eq!(prover_challenges, verifier_challenges); +} + +#[test] +fn verify_rejects_wrong_claim() { + let num_vars = 3; + let n = 1 << num_vars; + + let evals: Vec = (1..=n).map(|i| F::from_u64(i as u64)).collect(); + let correct_claim: F = evals.iter().copied().fold(F::zero(), |a, b| a + b); + let wrong_claim = correct_claim + F::one(); + + // Prove with correct claim. + let mut prover = DenseSumcheckProver { + evals: evals.clone(), + num_vars, + }; + let mut pt = Blake2bTranscript::::new(labels::DOMAIN_HACHI_PROTOCOL); + + let (proof, _, _) = prove_sumcheck::(&mut prover, &mut pt, |tr| { + tr.challenge_scalar(labels::CHALLENGE_SUMCHECK_ROUND) + }) + .unwrap(); + + // Verify with *wrong* claim — should fail. + let verifier = DenseSumcheckVerifier { + evals, + num_vars, + claim: wrong_claim, + }; + let mut vt = Blake2bTranscript::::new(labels::DOMAIN_HACHI_PROTOCOL); + + let result = verify_sumcheck::(&proof, &verifier, &mut vt, |tr| { + tr.challenge_scalar(labels::CHALLENGE_SUMCHECK_ROUND) + }); + + assert!(result.is_err()); +} + +/// End-to-end sumcheck over 2^20 random field elements. +/// +/// The prover holds a multilinear polynomial f with 2^20 evaluations and +/// proves that Σ_{b ∈ {0,1}^20} f(b) = claimed_sum. The verifier checks the +/// proof using only the proof transcript and the oracle evaluation f(r). +#[test] +fn e2e_sumcheck_2_pow_20() { + let num_vars = 20; + let n: usize = 1 << num_vars; // 1,048,576 + + let mut rng = StdRng::seed_from_u64(42); + let evals: Vec = (0..n).map(|_| F::sample(&mut rng)).collect(); + let claim: F = evals.iter().copied().fold(F::zero(), |a, b| a + b); + + let t0 = Instant::now(); + + let mut prover = DenseSumcheckProver { + evals: evals.clone(), + num_vars, + }; + let mut prover_transcript = Blake2bTranscript::::new(labels::DOMAIN_HACHI_PROTOCOL); + + let (proof, prover_challenges, final_claim) = + prove_sumcheck::(&mut prover, &mut prover_transcript, |tr| { + tr.challenge_scalar(labels::CHALLENGE_SUMCHECK_ROUND) + }) + .unwrap(); + + let prove_time = t0.elapsed(); + + // Proof is just 20 compressed univariate polynomials (degree 1 each). + assert_eq!(proof.round_polys.len(), num_vars); + + // Sanity: final claim must equal f evaluated at the challenge point. + let oracle_eval = multilinear_eval(&evals, &prover_challenges).unwrap(); + assert_eq!(final_claim, oracle_eval); + + let t1 = Instant::now(); + + let verifier = DenseSumcheckVerifier { + evals, + num_vars, + claim, + }; + let mut verifier_transcript = Blake2bTranscript::::new(labels::DOMAIN_HACHI_PROTOCOL); + + let verifier_challenges = + verify_sumcheck::(&proof, &verifier, &mut verifier_transcript, |tr| { + tr.challenge_scalar(labels::CHALLENGE_SUMCHECK_ROUND) + }) + .unwrap(); + + let verify_time = t1.elapsed(); + + assert_eq!(prover_challenges, verifier_challenges); + + eprintln!( + "[e2e_sumcheck_2_pow_20] n=2^{num_vars}={n} \ + prove={prove_time:.2?} verify={verify_time:.2?} \ + rounds={} degree=1", + proof.round_polys.len() + ); +} diff --git a/tests/sumcheck_prover_driver.rs b/tests/sumcheck_prover_driver.rs new file mode 100644 index 00000000..7ef3ebae --- /dev/null +++ b/tests/sumcheck_prover_driver.rs @@ -0,0 +1,97 @@ +#![allow(missing_docs)] + +use hachi_pcs::algebra::Fp64; +use hachi_pcs::protocol::transcript::labels; +use hachi_pcs::protocol::{ + prove_sumcheck, Blake2bTranscript, SumcheckInstanceProver, Transcript, UniPoly, +}; +use hachi_pcs::{FieldCore, FieldSampling}; +use rand::rngs::StdRng; +use rand::SeedableRng; + +type F = Fp64<4294967197>; + +/// A tiny prover-side sumcheck instance for a multilinear function in evaluation-table form. +/// +/// Variable order convention: the current round binds the least-significant index bit first, +/// i.e. pairs are `(i<<1)|0` and `(i<<1)|1` (matches the common LSB-first sumcheck table fold). +struct DenseTableSumcheck { + table: Vec, +} + +impl DenseTableSumcheck { + fn new(table: Vec) -> Self { + assert!(table.len().is_power_of_two()); + Self { table } + } +} + +impl SumcheckInstanceProver for DenseTableSumcheck { + fn num_rounds(&self) -> usize { + self.table.len().trailing_zeros() as usize + } + + fn degree_bound(&self) -> usize { + 1 + } + + fn input_claim(&self) -> F { + self.table.iter().copied().fold(F::zero(), |a, b| a + b) + } + + fn compute_round_univariate(&mut self, _round: usize, _previous_claim: F) -> UniPoly { + let half = self.table.len() / 2; + let mut s0 = F::zero(); + let mut s1 = F::zero(); + for i in 0..half { + s0 = s0 + self.table[i << 1]; + s1 = s1 + self.table[(i << 1) | 1]; + } + UniPoly::from_coeffs(vec![s0, s1 - s0]) + } + + fn ingest_challenge(&mut self, _round: usize, r_round: F) { + let half = self.table.len() / 2; + let mut next = Vec::with_capacity(half); + let one_minus = F::one() - r_round; + for i in 0..half { + let v0 = self.table[i << 1]; + let v1 = self.table[(i << 1) | 1]; + next.push(one_minus * v0 + r_round * v1); + } + self.table = next; + } +} + +#[test] +fn prover_driver_produces_proof_that_verifier_replays() { + let mut rng = StdRng::seed_from_u64(2026); + let num_rounds = 8usize; + let n = 1usize << num_rounds; + + let table: Vec = (0..n).map(|_| F::sample(&mut rng)).collect(); + let mut prover_inst = DenseTableSumcheck::new(table.clone()); + let mut prover_t = Blake2bTranscript::::new(labels::DOMAIN_HACHI_PROTOCOL); + let (proof, r_vec, final_claim) = + prove_sumcheck::(&mut prover_inst, &mut prover_t, |tr| { + tr.challenge_scalar(labels::CHALLENGE_SUMCHECK_ROUND) + }) + .unwrap(); + + // After folding all variables, the table should be a single value equal to f(r*). + assert_eq!(prover_inst.table.len(), 1); + assert_eq!(final_claim, prover_inst.table[0]); + + // Verifier replay must derive the same (final_claim, r_vec). + let initial_claim = table.iter().copied().fold(F::zero(), |acc, x| acc + x); + let mut verifier_t = Blake2bTranscript::::new(labels::DOMAIN_HACHI_PROTOCOL); + verifier_t.append_serde(labels::ABSORB_SUMCHECK_CLAIM, &initial_claim); + let (final_claim_v, r_vec_v) = proof + .verify::(initial_claim, num_rounds, 1, &mut verifier_t, |tr| { + tr.challenge_scalar(labels::CHALLENGE_SUMCHECK_ROUND) + }) + .unwrap(); + + assert_eq!(r_vec_v, r_vec); + assert_eq!(final_claim_v, final_claim); +} diff --git a/tests/transcript.rs b/tests/transcript.rs new file mode 100644 index 00000000..50335bec --- /dev/null +++ b/tests/transcript.rs @@ -0,0 +1,136 @@ +#![allow(missing_docs)] + +use hachi_pcs::algebra::Fp64; +use hachi_pcs::protocol::transcript::labels; +use hachi_pcs::protocol::{Blake2bTranscript, KeccakTranscript, Transcript}; + +type F = Fp64<4294967197>; + +fn sample_schedule>(transcript: &mut T) -> F { + transcript.append_bytes(labels::ABSORB_COMMITMENT, b"commitment-a"); + transcript.append_bytes(labels::ABSORB_COMMITMENT, b"commitment-b"); + transcript.append_serde(labels::ABSORB_EVALUATION_CLAIMS, &42u64); + let rho = transcript.challenge_scalar(labels::CHALLENGE_LINEAR_RELATION); + + transcript.append_bytes(labels::ABSORB_RING_SWITCH_MESSAGE, b"ring-switch"); + let zeta = transcript.challenge_scalar(labels::CHALLENGE_RING_SWITCH); + + transcript.append_field(labels::ABSORB_SUMCHECK_ROUND, &(rho + zeta)); + let r = transcript.challenge_scalar(labels::CHALLENGE_SUMCHECK_ROUND); + + transcript.append_field(labels::ABSORB_STOP_CONDITION, &r); + transcript.challenge_scalar(labels::CHALLENGE_STOP_CONDITION) +} + +#[test] +fn transcript_is_deterministic_for_identical_schedule() { + let mut t1 = Blake2bTranscript::::new(labels::DOMAIN_HACHI_PROTOCOL); + let mut t2 = Blake2bTranscript::::new(labels::DOMAIN_HACHI_PROTOCOL); + + let c1 = sample_schedule(&mut t1); + let c2 = sample_schedule(&mut t2); + assert_eq!(c1, c2); +} + +#[test] +fn transcript_differs_when_label_changes() { + let mut t1 = Blake2bTranscript::::new(labels::DOMAIN_HACHI_PROTOCOL); + let mut t2 = Blake2bTranscript::::new(labels::DOMAIN_HACHI_PROTOCOL); + + t1.append_bytes(labels::ABSORB_COMMITMENT, b"same-bytes"); + t2.append_bytes(labels::ABSORB_EVALUATION_CLAIMS, b"same-bytes"); + let c1 = t1.challenge_scalar(labels::CHALLENGE_SUMCHECK_ROUND); + let c2 = t2.challenge_scalar(labels::CHALLENGE_SUMCHECK_ROUND); + assert_ne!(c1, c2); +} + +#[test] +fn transcript_differs_when_absorb_order_changes() { + let mut t1 = Blake2bTranscript::::new(labels::DOMAIN_HACHI_PROTOCOL); + let mut t2 = Blake2bTranscript::::new(labels::DOMAIN_HACHI_PROTOCOL); + + t1.append_bytes(labels::ABSORB_COMMITMENT, b"A"); + t1.append_bytes(labels::ABSORB_EVALUATION_CLAIMS, b"B"); + + t2.append_bytes(labels::ABSORB_EVALUATION_CLAIMS, b"B"); + t2.append_bytes(labels::ABSORB_COMMITMENT, b"A"); + + let c1 = t1.challenge_scalar(labels::CHALLENGE_LINEAR_RELATION); + let c2 = t2.challenge_scalar(labels::CHALLENGE_LINEAR_RELATION); + assert_ne!(c1, c2); +} + +#[test] +fn transcript_reset_restores_domain_state() { + let mut t = Blake2bTranscript::::new(labels::DOMAIN_HACHI_PROTOCOL); + t.append_bytes(labels::ABSORB_COMMITMENT, b"before-reset"); + let _ = t.challenge_scalar(labels::CHALLENGE_STOP_CONDITION); + + t.reset(labels::DOMAIN_HACHI_PROTOCOL); + let after_reset = sample_schedule(&mut t); + + let mut fresh = Blake2bTranscript::::new(labels::DOMAIN_HACHI_PROTOCOL); + let fresh_challenge = sample_schedule(&mut fresh); + assert_eq!(after_reset, fresh_challenge); +} + +#[test] +fn keccak_transcript_is_deterministic_for_identical_schedule() { + let mut t1 = KeccakTranscript::::new(labels::DOMAIN_HACHI_PROTOCOL); + let mut t2 = KeccakTranscript::::new(labels::DOMAIN_HACHI_PROTOCOL); + + let c1 = sample_schedule(&mut t1); + let c2 = sample_schedule(&mut t2); + assert_eq!(c1, c2); +} + +#[test] +fn keccak_transcript_differs_when_label_changes() { + let mut t1 = KeccakTranscript::::new(labels::DOMAIN_HACHI_PROTOCOL); + let mut t2 = KeccakTranscript::::new(labels::DOMAIN_HACHI_PROTOCOL); + + t1.append_bytes(labels::ABSORB_COMMITMENT, b"same-bytes"); + t2.append_bytes(labels::ABSORB_EVALUATION_CLAIMS, b"same-bytes"); + let c1 = t1.challenge_scalar(labels::CHALLENGE_SUMCHECK_ROUND); + let c2 = t2.challenge_scalar(labels::CHALLENGE_SUMCHECK_ROUND); + assert_ne!(c1, c2); +} + +#[test] +fn keccak_transcript_differs_when_absorb_order_changes() { + let mut t1 = KeccakTranscript::::new(labels::DOMAIN_HACHI_PROTOCOL); + let mut t2 = KeccakTranscript::::new(labels::DOMAIN_HACHI_PROTOCOL); + + t1.append_bytes(labels::ABSORB_COMMITMENT, b"A"); + t1.append_bytes(labels::ABSORB_EVALUATION_CLAIMS, b"B"); + + t2.append_bytes(labels::ABSORB_EVALUATION_CLAIMS, b"B"); + t2.append_bytes(labels::ABSORB_COMMITMENT, b"A"); + + let c1 = t1.challenge_scalar(labels::CHALLENGE_LINEAR_RELATION); + let c2 = t2.challenge_scalar(labels::CHALLENGE_LINEAR_RELATION); + assert_ne!(c1, c2); +} + +#[test] +fn keccak_transcript_reset_restores_domain_state() { + let mut t = KeccakTranscript::::new(labels::DOMAIN_HACHI_PROTOCOL); + t.append_bytes(labels::ABSORB_COMMITMENT, b"before-reset"); + let _ = t.challenge_scalar(labels::CHALLENGE_STOP_CONDITION); + + t.reset(labels::DOMAIN_HACHI_PROTOCOL); + let after_reset = sample_schedule(&mut t); + + let mut fresh = KeccakTranscript::::new(labels::DOMAIN_HACHI_PROTOCOL); + let fresh_challenge = sample_schedule(&mut fresh); + assert_eq!(after_reset, fresh_challenge); +} + +#[test] +fn blake2b_and_keccak_diverge_on_same_schedule() { + let mut blake = Blake2bTranscript::::new(labels::DOMAIN_HACHI_PROTOCOL); + let mut keccak = KeccakTranscript::::new(labels::DOMAIN_HACHI_PROTOCOL); + let b = sample_schedule(&mut blake); + let k = sample_schedule(&mut keccak); + assert_ne!(b, k); +}