diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 7aa166de..8d38457b 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -6,6 +6,9 @@ on: pull_request: branches: ["**", main] +permissions: + contents: read + env: RUSTFLAGS: -D warnings CARGO_TERM_COLOR: always @@ -19,8 +22,8 @@ jobs: name: Format runs-on: ubuntu-latest steps: - - uses: actions/checkout@v4 - - uses: dtolnay/rust-toolchain@stable + - uses: actions/checkout@34e114876b0b11c390a56381ad16ebd13914f8d5 # v4 + - uses: actions-rust-lang/setup-rust-toolchain@a0b538fa0b742a6aa35d6e2c169b4bd06d225a98 # v1 with: components: rustfmt - name: Check formatting @@ -30,21 +33,21 @@ jobs: name: Clippy runs-on: ubuntu-latest steps: - - uses: actions/checkout@v4 - - uses: dtolnay/rust-toolchain@stable + - uses: actions/checkout@34e114876b0b11c390a56381ad16ebd13914f8d5 # v4 + - uses: actions-rust-lang/setup-rust-toolchain@a0b538fa0b742a6aa35d6e2c169b4bd06d225a98 # v1 with: components: clippy - name: Clippy (all features) - run: cargo clippy -q --message-format=short --all-features --all-targets -- -D warnings + run: cargo clippy --all --all-targets --all-features -- -D warnings - name: Clippy (no default features) - run: cargo clippy -q --message-format=short --no-default-features --lib -- -D warnings + run: cargo clippy --all --all-targets --no-default-features -- -D warnings doc: name: Documentation runs-on: ubuntu-latest steps: - - uses: actions/checkout@v4 - - uses: dtolnay/rust-toolchain@stable + - uses: actions/checkout@34e114876b0b11c390a56381ad16ebd13914f8d5 # v4 + - uses: actions-rust-lang/setup-rust-toolchain@a0b538fa0b742a6aa35d6e2c169b4bd06d225a98 # v1 - name: Build documentation run: cargo doc -q --no-deps --all-features env: @@ -54,9 +57,9 @@ jobs: name: Test runs-on: ubuntu-latest steps: - - uses: actions/checkout@v4 - - uses: dtolnay/rust-toolchain@stable + - uses: actions/checkout@34e114876b0b11c390a56381ad16ebd13914f8d5 # v4 + - uses: 
actions-rust-lang/setup-rust-toolchain@a0b538fa0b742a6aa35d6e2c169b4bd06d225a98 # v1 - name: Install cargo-nextest - uses: taiki-e/install-action@nextest + uses: taiki-e/install-action@f092c064826410a38929a5791d2c0225b94432fe # nextest - name: Run tests - run: cargo nextest run -q --all-features + run: cargo nextest run --all-features diff --git a/.gitignore b/.gitignore index bd04f281..98e812bf 100644 --- a/.gitignore +++ b/.gitignore @@ -5,3 +5,7 @@ .urs PUBLISH_CHECKLIST.md + +profile_traces/ + +.cursor/ diff --git a/AGENTS.md b/AGENTS.md new file mode 100644 index 00000000..68f667fa --- /dev/null +++ b/AGENTS.md @@ -0,0 +1,35 @@ +# AGENTS.md + +**Compatibility notice (explicit): This repo makes NO backward-compatibility guarantees. Breaking changes are allowed and expected.** + +## Project Overview + +Hachi is a lattice-based polynomial commitment scheme (PCS) with transparent setup and post-quantum security. Built in Rust. Intended to replace Dory in Jolt. + +## Essential Commands + +```bash +cargo clippy --all --message-format=short -q -- -D warnings +cargo fmt -q +cargo test # no nextest yet +``` + +## Crate Structure + +Two workspace members: `hachi-pcs` (root) and `derive` (proc macros). 
+ +- `src/primitives/` — Core traits: `FieldCore`, `Module`, `MultilinearLagrange`, `Transcript`, serialization +- `src/algebra/` — Concrete backends: prime fields, extension fields, cyclotomic rings, NTT, domains +- `src/protocol/` — Protocol layer: commitment, prover, verifier, opening (ring-switch), challenges, transcript +- `src/error.rs` — Error types + +## Key Abstractions + +- `CommitmentScheme` / `StreamingCommitmentScheme` — top-level PCS traits +- `FieldCore` + `PseudoMersenneField` + `Module` — arithmetic over lattice-friendly fields and rings +- `MultilinearLagrange` — multilinear polynomial in Lagrange basis +- `Transcript` — Fiat-Shamir + +## Feature Flags + +- `parallel` — Rayon parallelization diff --git a/CLAUDE.md b/CLAUDE.md new file mode 120000 index 00000000..47dc3e3d --- /dev/null +++ b/CLAUDE.md @@ -0,0 +1 @@ +AGENTS.md \ No newline at end of file diff --git a/CONSTANT_TIME_NOTES.md b/CONSTANT_TIME_NOTES.md new file mode 100644 index 00000000..dac24600 --- /dev/null +++ b/CONSTANT_TIME_NOTES.md @@ -0,0 +1,42 @@ +# Constant-Time Review Notes (Phase 0/1 Algebra) + +This note tracks timing-sensitive implementation decisions for the current +algebra and ring stack. + +## Reviewed Components + +- `src/algebra/fields/fp32.rs` +- `src/algebra/fields/fp64.rs` +- `src/algebra/fields/fp128.rs` +- `src/algebra/ntt/prime.rs` +- `src/algebra/ntt/butterfly.rs` +- `src/algebra/ring/cyclotomic.rs` +- `src/algebra/ring/crt_ntt_repr.rs` + +## Current State + +- Branchless primitives are in place for: + - `Fp32/Fp64/Fp128` add/sub/neg raw helpers. + - `Fp128` multiplication reduction (`reduce_u256`) with branchless conditional subtract. + - `Fp32/Fp64` multiplication reduction (division-free fixed-iteration paths). + - NTT helper operations `csubp`, `caddp`, and `center`. +- NTT butterfly arithmetic runs in fixed loop structure independent of data. +- Ring multiplication (`CyclotomicRing`) is fixed-structure schoolbook over `D`. 
+- CRT reconstruction inner accumulation now uses fixed-trip, branchless + modular add/mul-by-small-factor helpers. +- Prime fields now expose `Invertible::inv_or_zero()` for secret-bearing + inversion use-cases without input-dependent branching on zero. +- CRT reconstruction final projection now uses a division-free fixed-iteration + reducer (`reduce_u128_divfree`) instead of `% q`. + +## Known Timing Risks / Follow-ups + +- `FieldCore::inv()` still returns `Option` and therefore branches on zero; + treat that API as public-value oriented. Use `Invertible::inv_or_zero()` + in secret-dependent paths. + +## Action Items Before Production-Critical Use + +1. Wire secret-bearing call sites to `Invertible::inv_or_zero()` as + protocol code matures. +2. Add dedicated CT review tests/checklists for any arithmetic subsystem changes. diff --git a/Cargo.lock b/Cargo.lock index a505ce5c..0a1f995c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2,6 +2,29 @@ # It is not intended for manual editing. version = 3 +[[package]] +name = "aes" +version = "0.8.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b169f7a6d4742236a0a00c541b845991d0ac43e546831af1249753ab4c3aa3a0" +dependencies = [ + "cfg-if", + "cipher", + "cpufeatures", +] + +[[package]] +name = "ahash" +version = "0.8.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5a15f179cd60c4584b8a8c596927aadc462e27f2ca70c04e0071964a73ba7a75" +dependencies = [ + "cfg-if", + "once_cell", + "version_check", + "zerocopy", +] + [[package]] name = "aho-corasick" version = "1.1.4" @@ -11,6 +34,33 @@ dependencies = [ "memchr", ] +[[package]] +name = "allocative" +version = "0.3.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8fac2ce611db8b8cee9b2aa886ca03c924e9da5e5295d0dbd0526e5d0b0710f7" +dependencies = [ + "allocative_derive", + "ctor", +] + +[[package]] +name = "allocative_derive" +version = "0.3.3" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "fe233a377643e0fc1a56421d7c90acdec45c291b30345eb9f08e8d0ddce5a4ab" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.110", +] + +[[package]] +name = "allocator-api2" +version = "0.2.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "683d7910e743518b0e34f1186f92494becacb047c7b6bf616c96772180fef923" + [[package]] name = "anes" version = "0.1.6" @@ -23,12 +73,154 @@ version = "1.0.13" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5192cca8006f1fd4f7237516f40fa183bb07f8fbdfedaa0036de5ea9b0b45e78" +[[package]] +name = "ark-bn254" +version = "0.5.0" +source = "git+https://github.com/a16z/arkworks-algebra?branch=dev/twist-shout#76bb3a4518928f1ff7f15875f940d614bb9845e6" +dependencies = [ + "ark-ec", + "ark-ff", + "ark-serialize", + "ark-std", +] + +[[package]] +name = "ark-ec" +version = "0.5.0" +source = "git+https://github.com/a16z/arkworks-algebra?branch=dev/twist-shout#76bb3a4518928f1ff7f15875f940d614bb9845e6" +dependencies = [ + "ahash", + "ark-ff", + "ark-poly", + "ark-serialize", + "ark-std", + "educe", + "fnv", + "hashbrown", + "itertools 0.13.0", + "num-bigint", + "num-integer", + "num-traits", + "zeroize", +] + +[[package]] +name = "ark-ff" +version = "0.5.0" +source = "git+https://github.com/a16z/arkworks-algebra?branch=dev/twist-shout#76bb3a4518928f1ff7f15875f940d614bb9845e6" +dependencies = [ + "allocative", + "ark-ff-asm", + "ark-ff-macros", + "ark-serialize", + "ark-std", + "arrayvec", + "digest", + "educe", + "itertools 0.13.0", + "num-bigint", + "num-traits", + "paste", + "zeroize", +] + +[[package]] +name = "ark-ff-asm" +version = "0.5.0" +source = "git+https://github.com/a16z/arkworks-algebra?branch=dev/twist-shout#76bb3a4518928f1ff7f15875f940d614bb9845e6" +dependencies = [ + "quote", + "syn 2.0.110", +] + +[[package]] +name = "ark-ff-macros" +version = "0.5.0" +source = 
"git+https://github.com/a16z/arkworks-algebra?branch=dev/twist-shout#76bb3a4518928f1ff7f15875f940d614bb9845e6" +dependencies = [ + "num-bigint", + "num-traits", + "proc-macro2", + "quote", + "syn 2.0.110", +] + +[[package]] +name = "ark-poly" +version = "0.5.0" +source = "git+https://github.com/a16z/arkworks-algebra?branch=dev/twist-shout#76bb3a4518928f1ff7f15875f940d614bb9845e6" +dependencies = [ + "ahash", + "ark-ff", + "ark-serialize", + "ark-std", + "educe", + "fnv", + "hashbrown", +] + +[[package]] +name = "ark-serialize" +version = "0.5.0" +source = "git+https://github.com/a16z/arkworks-algebra?branch=dev/twist-shout#76bb3a4518928f1ff7f15875f940d614bb9845e6" +dependencies = [ + "ark-serialize-derive", + "ark-std", + "arrayvec", + "digest", + "num-bigint", +] + +[[package]] +name = "ark-serialize-derive" +version = "0.5.0" +source = "git+https://github.com/a16z/arkworks-algebra?branch=dev/twist-shout#76bb3a4518928f1ff7f15875f940d614bb9845e6" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.110", +] + +[[package]] +name = "ark-std" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "246a225cc6131e9ee4f24619af0f19d67761fff15d7ccc22e42b80846e69449a" +dependencies = [ + "num-traits", + "rand", +] + +[[package]] +name = "arrayvec" +version = "0.7.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7c02d123df017efcdfbd739ef81735b36c5ba83ec3c59c80a9d7ecc718f92e50" + [[package]] name = "autocfg" version = "1.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c08606f8c3cbf4ce6ec8e28fb0014a2c086708fe954eaa885384a6165172e7e8" +[[package]] +name = "blake2" +version = "0.10.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "46502ad458c9a52b69d4d4d32775c788b7a1b85e8bc9d482d92250fc0e3f8efe" +dependencies = [ + "digest", +] + +[[package]] +name = "block-buffer" +version = "0.10.4" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "3078c7629b62d3f0439517fa394996acacc5cbc91c5a20d8c658e77abd503a71" +dependencies = [ + "generic-array", +] + [[package]] name = "bumpalo" version = "3.19.0" @@ -74,6 +266,16 @@ dependencies = [ "half", ] +[[package]] +name = "cipher" +version = "0.4.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "773f3b9af64447d2ce9850330c473515014aa235e6a783b02db81ff39e4a3dad" +dependencies = [ + "crypto-common", + "inout", +] + [[package]] name = "clap" version = "4.5.51" @@ -99,6 +301,15 @@ version = "0.7.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a1d728cc89cf3aee9ff92b05e62b19ee65a02b5702cff7d5a377e32c6ae29d8d" +[[package]] +name = "cpufeatures" +version = "0.2.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "59ed5838eebb26a2bb2e58f6d5b5316989ae9d08bab10e0e6d103e656d1b0280" +dependencies = [ + "libc", +] + [[package]] name = "criterion" version = "0.5.1" @@ -111,7 +322,7 @@ dependencies = [ "clap", "criterion-plot", "is-terminal", - "itertools", + "itertools 0.10.5", "num-traits", "once_cell", "oorandom", @@ -132,7 +343,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6b50826342786a51a89e2da3a28f1c32b06e387201bc2d19791f622c673706b1" dependencies = [ "cast", - "itertools", + "itertools 0.10.5", ] [[package]] @@ -166,12 +377,100 @@ version = "0.2.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "460fbee9c2c2f33933d720630a6a0bac33ba7053db5344fac858d4b8952d77d5" +[[package]] +name = "crypto-common" +version = "0.1.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "78c8292055d1c1df0cce5d180393dc8cce0abec0a7102adb6c7b1eef6016d60a" +dependencies = [ + "generic-array", + "typenum", +] + +[[package]] +name = "ctor" +version = "0.1.26" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"6d2301688392eb071b0bf1a37be05c469d3cc4dbbd95df672fe28ab021e6a096" +dependencies = [ + "quote", + "syn 1.0.109", +] + +[[package]] +name = "ctr" +version = "0.9.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0369ee1ad671834580515889b80f2ea915f23b8be8d0daa4bbaf2ac5c7590835" +dependencies = [ + "cipher", +] + +[[package]] +name = "digest" +version = "0.10.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9ed9a281f7bc9b7576e61468ba615a66a5c8cfdff42420a70aa82701a3b1e292" +dependencies = [ + "block-buffer", + "crypto-common", + "subtle", +] + +[[package]] +name = "educe" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1d7bc049e1bd8cdeb31b68bbd586a9464ecf9f3944af3958a7a9d0f8b9799417" +dependencies = [ + "enum-ordinalize", + "proc-macro2", + "quote", + "syn 2.0.110", +] + [[package]] name = "either" version = "1.15.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "48c757948c5ede0e46177b7add2e67155f70e33c07fea8284df6576da70b3719" +[[package]] +name = "enum-ordinalize" +version = "4.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4a1091a7bb1f8f2c4b28f1fe2cef4980ca2d410a3d727d67ecc3178c9b0800f0" +dependencies = [ + "enum-ordinalize-derive", +] + +[[package]] +name = "enum-ordinalize-derive" +version = "4.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8ca9601fb2d62598ee17836250842873a413586e5d7ed88b356e38ddbb0ec631" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.110", +] + +[[package]] +name = "fnv" +version = "1.0.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1" + +[[package]] +name = "generic-array" +version = "0.14.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"85649ca51fd72272d7821adaf274ad91c288277713d9c18820d8499a7ff69e9a" +dependencies = [ + "typenum", + "version_check", +] + [[package]] name = "getrandom" version = "0.2.16" @@ -189,20 +488,29 @@ version = "0.1.0" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.110", ] [[package]] name = "hachi-pcs" version = "0.1.0" dependencies = [ + "aes", + "ark-bn254", + "ark-ff", + "blake2", "criterion", + "ctr", "hachi-derive", + "num-bigint", "rand", "rand_core", "rayon", + "sha3", "thiserror", "tracing", + "tracing-chrome", + "tracing-subscriber", ] [[package]] @@ -216,12 +524,30 @@ dependencies = [ "zerocopy", ] +[[package]] +name = "hashbrown" +version = "0.15.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9229cfe53dfd69f0609a49f65461bd93001ea1ef889cd5529dd176593f5338a1" +dependencies = [ + "allocator-api2", +] + [[package]] name = "hermit-abi" version = "0.5.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "fc0fef456e4baa96da950455cd02c081ca953b141298e41db3fc7e36b1da849c" +[[package]] +name = "inout" +version = "0.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "879f10e63c20629ecabbb64a8010319738c66a5cd0c29b02d63d272b03751d01" +dependencies = [ + "generic-array", +] + [[package]] name = "is-terminal" version = "0.4.17" @@ -242,6 +568,15 @@ dependencies = [ "either", ] +[[package]] +name = "itertools" +version = "0.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "413ee7dfc52ee1a4949ceeb7dbc8a33f2d6c088194d9f922fb8318faf1f01186" +dependencies = [ + "either", +] + [[package]] name = "itoa" version = "1.0.15" @@ -258,18 +593,76 @@ dependencies = [ "wasm-bindgen", ] +[[package]] +name = "keccak" +version = "0.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cb26cec98cce3a3d96cbb7bced3c4b16e3d13f27ec56dbd62cbc8f39cfb9d653" +dependencies = [ + "cpufeatures", +] + +[[package]] +name = "lazy_static" 
+version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bbd2bcb4c963f2ddae06a2efc7e9f3591312473c50c6685e1f298068316e66fe" + [[package]] name = "libc" version = "0.2.177" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2874a2af47a2325c2001a6e6fad9b16a53b802102b528163885171cf92b15976" +[[package]] +name = "log" +version = "0.4.29" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5e5032e24019045c762d3c0f28f5b6b8bbf38563a65908389bf7978758920897" + +[[package]] +name = "matchers" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d1525a2a28c7f4fa0fc98bb91ae755d1e2d1505079e05539e35bc876b5d65ae9" +dependencies = [ + "regex-automata", +] + [[package]] name = "memchr" version = "2.7.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f52b00d39961fc5b2736ea853c9cc86238e165017a493d1d5c8eac6bdc4cc273" +[[package]] +name = "nu-ansi-term" +version = "0.50.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7957b9740744892f114936ab4a57b3f487491bbeafaf8083688b16841a4240e5" +dependencies = [ + "windows-sys", +] + +[[package]] +name = "num-bigint" +version = "0.4.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a5e44f723f1133c9deac646763579fdb3ac745e418f2a7af9cd0c431da1f20b9" +dependencies = [ + "num-integer", + "num-traits", +] + +[[package]] +name = "num-integer" +version = "0.1.46" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7969661fd2958a5cb096e56c8e1ad0444ac2bbcd0061bd28660485a44879858f" +dependencies = [ + "num-traits", +] + [[package]] name = "num-traits" version = "0.2.19" @@ -291,6 +684,12 @@ version = "11.1.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d6790f58c7ff633d8771f42965289203411a5e5c68388703c06e14f24770b41e" +[[package]] +name = "paste" +version = "1.0.15" 
+source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "57c0d7b74b563b49d38dae00a0c37d4d6de9b432382b2892f0574ddcae73fd0a" + [[package]] name = "pin-project-lite" version = "0.2.16" @@ -479,7 +878,7 @@ checksum = "d540f220d3187173da220f885ab66608367b6574e925011a9353e4badda91d79" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.110", ] [[package]] @@ -495,6 +894,48 @@ dependencies = [ "serde_core", ] +[[package]] +name = "sha3" +version = "0.10.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "75872d278a8f37ef87fa0ddbda7802605cb18344497949862c0d4dcb291eba60" +dependencies = [ + "digest", + "keccak", +] + +[[package]] +name = "sharded-slab" +version = "0.1.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f40ca3c46823713e0d4209592e8d6e826aa57e928f09752619fc696c499637f6" +dependencies = [ + "lazy_static", +] + +[[package]] +name = "smallvec" +version = "1.15.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "67b1b7a3b5fe4f1376887184045fcf45c69e92af734b7aaddc05fb777b6fbd03" + +[[package]] +name = "subtle" +version = "2.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "13c2bddecc57b384dee18652358fb23172facb8a2c51ccc10d74c157bdea3292" + +[[package]] +name = "syn" +version = "1.0.109" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "72b64191b275b66ffe2469e8af2c1cfe3bafa67b529ead792a6d0160888b4237" +dependencies = [ + "proc-macro2", + "quote", + "unicode-ident", +] + [[package]] name = "syn" version = "2.0.110" @@ -523,7 +964,16 @@ checksum = "3ff15c8ecd7de3849db632e14d18d2571fa09dfc5ed93479bc4485c7a517c913" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.110", +] + +[[package]] +name = "thread_local" +version = "1.1.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f60246a4944f24f6e018aa17cdeffb7818b76356965d03b07d6a9886e8962185" 
+dependencies = [ + "cfg-if", ] [[package]] @@ -538,9 +988,9 @@ dependencies = [ [[package]] name = "tracing" -version = "0.1.41" +version = "0.1.44" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "784e0ac535deb450455cbfa28a6f0df145ea1bb7ae51b821cf5e7927fdcfbdd0" +checksum = "63e71662fa4b2a2c3a26f570f037eb95bb1f85397f3cd8076caed2f026a6d100" dependencies = [ "pin-project-lite", "tracing-attributes", @@ -549,30 +999,89 @@ dependencies = [ [[package]] name = "tracing-attributes" -version = "0.1.30" +version = "0.1.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "81383ab64e72a7a8b8e13130c49e3dab29def6d0c7d76a03087b3cf71c5c6903" +checksum = "7490cfa5ec963746568740651ac6781f701c9c5ea257c58e057f3ba8cf69e8da" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.110", +] + +[[package]] +name = "tracing-chrome" +version = "0.7.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bf0a738ed5d6450a9fb96e86a23ad808de2b727fd1394585da5cdd6788ffe724" +dependencies = [ + "serde_json", + "tracing-core", + "tracing-subscriber", ] [[package]] name = "tracing-core" -version = "0.1.34" +version = "0.1.36" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b9d12581f227e93f094d3af2ae690a574abb8a2b9b7a96e7cfe9647b2b617678" +checksum = "db97caf9d906fbde555dd62fa95ddba9eecfd14cb388e4f491a66d74cd5fb79a" dependencies = [ "once_cell", + "valuable", ] +[[package]] +name = "tracing-log" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ee855f1f400bd0e5c02d150ae5de3840039a3f54b025156404e34c23c03f47c3" +dependencies = [ + "log", + "once_cell", + "tracing-core", +] + +[[package]] +name = "tracing-subscriber" +version = "0.3.22" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2f30143827ddab0d256fd843b7a66d164e9f271cfa0dde49142c5ca0ca291f1e" +dependencies = [ + "matchers", + "nu-ansi-term", + 
"once_cell", + "regex-automata", + "sharded-slab", + "smallvec", + "thread_local", + "tracing", + "tracing-core", + "tracing-log", +] + +[[package]] +name = "typenum" +version = "1.19.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "562d481066bde0658276a35467c4af00bdc6ee726305698a55b86e61d7ad82bb" + [[package]] name = "unicode-ident" version = "1.0.22" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9312f7c4f6ff9069b165498234ce8be658059c6728633667c526e27dc2cf1df5" +[[package]] +name = "valuable" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ba73ea9cf16a25df0c8caa16c51acb937d5712a8429db78a3ee29d5dcacd3a65" + +[[package]] +name = "version_check" +version = "0.9.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0b928f33d975fc6ad9f86c8f283853ad26bdd5b10b7f1542aa2fa15e2289105a" + [[package]] name = "walkdir" version = "2.5.0" @@ -621,7 +1130,7 @@ dependencies = [ "bumpalo", "proc-macro2", "quote", - "syn", + "syn 2.0.110", "wasm-bindgen-shared", ] @@ -685,5 +1194,25 @@ checksum = "88d2b8d9c68ad2b9e4340d7832716a4d21a22a1154777ad56ea55c51a9cf3831" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.110", +] + +[[package]] +name = "zeroize" +version = "1.8.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b97154e67e32c85465826e8bcc1c59429aaaf107c1e4a9e53c8d8ccd5eff88d0" +dependencies = [ + "zeroize_derive", +] + +[[package]] +name = "zeroize_derive" +version = "1.4.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "85a5b4158499876c763cb03bc4e49185d3cccbabb15b33c627f7884f43db852e" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.110", ] diff --git a/Cargo.toml b/Cargo.toml index 47abb4d8..1068b531 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -6,9 +6,11 @@ resolver = "2" name = "hachi-pcs" version = "0.1.0" edition = "2021" -rust-version = "1.75" 
+rust-version = "1.88" authors = [ "Markos Georghiades ", + "Quang Dao ", + "Omid Bodaghi ", ] license = "Apache-2.0 OR MIT" description = "A high performance and modular implementation of the Hachi polynomial commitment scheme." @@ -32,19 +34,52 @@ include = [ all-features = true [features] -default = [] +default = ["parallel"] parallel = ["dep:rayon"] +disk-persistence = [] [dependencies] thiserror = "2.0" -rand_core = "0.6" +rand_core = { version = "0.6", features = ["getrandom"] } hachi-derive = { version = "0.1.0", path = "derive" } tracing = "0.1" rayon = { version = "1.10", optional = true } +blake2 = "0.10.6" +sha3 = "0.10.8" +aes = "0.8.4" +ctr = "0.9.2" [dev-dependencies] rand = "0.8" criterion = { version = "0.5", features = ["html_reports"] } +num-bigint = "0.4.6" +ark-bn254 = { git = "https://github.com/a16z/arkworks-algebra", branch = "dev/twist-shout", features = ["scalar_field"] } +ark-ff = { git = "https://github.com/a16z/arkworks-algebra", branch = "dev/twist-shout" } +tracing-chrome = "0.7" +tracing-subscriber = { version = "0.3", features = ["env-filter", "registry"] } + +[[example]] +name = "profile" + +[[bench]] +name = "ring_ntt" +harness = false + +[[bench]] +name = "field_arith" +harness = false + +[[bench]] +name = "fp64_reduce_probe" +harness = false + +[[bench]] +name = "hachi_e2e" +harness = false + +[[bench]] +name = "norm_sumcheck" +harness = false [lints.rust] missing_docs = "warn" diff --git a/HACHI_PROGRESS.md b/HACHI_PROGRESS.md new file mode 100644 index 00000000..afca63e6 --- /dev/null +++ b/HACHI_PROGRESS.md @@ -0,0 +1,200 @@ +## Hachi PCS implementation progress + +This file is the **single source of truth** for implementation status and near-term priorities. + +### Goals (project-level) + +- **Production-ready implementation**: correctness, security, maintainability, and performance are first-class goals. 
+- **Standalone codebase**: implementation and comments should stand on their own; external acknowledgements live in `README.md`. +- **Constant-time cryptographic core**: arithmetic and protocol-critical paths must be constant-time with respect to secret data. +- **No shortcuts / no fallback design**: avoid temporary or degraded code paths in the core implementation. + +### Non-negotiable requirements + +- **Constant-time discipline** + - No secret-dependent branches or memory access patterns in cryptographic hot paths. + - No secret-indexed table lookups; table access patterns must be independent of secret data. + - Keep data representations and reductions explicit and auditable for timing behavior. + - Add targeted tests/reviews for constant-time-sensitive code as features land. +- **Code quality bar** + - Clear naming, explicit invariants, small cohesive modules, and API docs for public interfaces. + - No placeholder crypto logic in mainline code (no "temporary" arithmetic shortcuts). + - Tests are required for correctness-critical arithmetic before dependent protocol code is built. + - No section-banner comments (e.g., `// ---- Section ----`, `// === ... ===`). Let the code and doc-comments speak for themselves. +- **Standalone implementation policy** + - Do not mention external inspirations/ports in core code comments. + - Keep terminology and structure internally coherent and project-native. + - Keep external attribution limited to dedicated docs (for now: `README.md` acknowledgements). +- **Git discipline** + - Do not commit or push without explicit user approval. + +### Implementation workflow (cautious + approval-driven) + +- Before each major subsystem, present implementation options with trade-offs. +- Seek explicit approval before proceeding with a selected option. +- Pause at milestone boundaries for review and feedback before continuing. +- Prefer slow, verifiable progress over rapid, high-risk changes. 
+- Ask for user input frequently when requirements are ambiguous or involve design trade-offs. + +### Definition of Done (all crypto-critical work) + +- **Security / constant-time** + - Secret-independent control flow and memory access in cryptographic paths. + - Constant-time review notes included for non-trivial arithmetic/ring changes. +- **Correctness** + - Unit tests for edge cases and algebraic identities. + - Cross-check vectors/reference checks added where practical. +- **Code quality** + - Clear naming, explicit invariants, and no placeholder logic in core paths. + - Public interfaces documented sufficiently for safe usage. +- **Performance** + - Hot-path performance impact evaluated (benchmark or measured rationale). +- **Tooling + CI** + - `cargo fmt --all --check` passes. + - `cargo clippy --all --all-targets --all-features` passes. + - `cargo test` (or targeted suite for touched modules) passes. +- **Process** + - Implementation options reviewed with user before major subsystem changes. + - Milestone update recorded in this file. + +### Scope (current) + +- **Implemented so far (Phase 0 + Phase 1 functional core)**: prime fields (32/64/128-bit representations), extension fields, cyclotomic `R_q = Z_q[X]/(X^d + 1)`, CRT+NTT representation, backend/domain layering, ring automorphisms, and functional gadget decomposition. +- **Phase 2+ protocol status**: interface scaffold plus ring-native §4.1 commitment core are present (`Transcript`, Blake2b/Keccak backends, phase-grounded labels, `RingCommitmentScheme`, config layer, and setup/commit implementation). Sumcheck core building blocks (univariate messages + transcript-driving prover/verifier driver) are now implemented, with tests. Open-check prover/verifier paths remain stubbed. +- **Deferred future phase**: integration into Jolt (replacement of Dory with Hachi) is intentionally out of current execution scope; cross-repo analysis is design input only. 
+ +### Critical review snapshot (2026-02-13) + +- **Phase 1 functional milestone appears complete** + - Ring/gadget components listed in Phase 1 are implemented and currently checked off. + - Conversion and arithmetic paths in coefficient and CRT+NTT domains are exercised by passing tests. +- **Not yet "production-ready" despite functional completion** + - Constant-time hardening follow-ups narrowed: secret-bearing call-sites still need to migrate from `FieldCore::inv()` to `Invertible::inv_or_zero()` as protocol code lands (see `CONSTANT_TIME_NOTES.md`). + - Current ring multiplication in coefficient form remains `O(D^2)` schoolbook (`src/algebra/ring/cyclotomic.rs`), with CRT+NTT available as the faster domain path. +- **Tooling/quality gate status (current branch snapshot)** + - `cargo test` passes, including protocol transcript/label/commitment contract tests and new ring-commitment core/config/stub tests. + - `cargo fmt --all --check` passes. + - `cargo clippy --all --all-targets --all-features` passes. +- **Phase 2 scaffold + commitment core landed; proof-system work still pending** + - `src/protocol/*` now provides transcript + commitment abstraction boundaries with `Transcript` naming. + - Two transcript backends are wired (`Blake2bTranscript`, `KeccakTranscript`) with deterministic replay/order/reset tests. + - Hachi-native labels are now calibrated to paper-stage phases (§4.1, §4.2, §4.3, §4.5). + - Commitment absorption is label-directed at call sites (`AppendToTranscript` no longer hardcodes commitment labels). + - Ring-native commitment setup/commit flow for §4.1 is implemented in `src/protocol/commitment/commit.rs` behind `RingCommitmentScheme`. + - Sumcheck core module landed (`src/protocol/sumcheck.rs`) with unit/integration tests (`tests/sumcheck_core.rs`, `tests/sumcheck_prover_driver.rs`). 
+ - Prover/verifier split folders are wired with explicit stubs (`src/protocol/prover/stub.rs`, `src/protocol/verifier/stub.rs`) for future open-check implementation. +- **Conclusion** + - Treat **Phase 1 as functionally complete**. + - Treat **Phase 2 as active/in-progress** (commitment core implemented; prove/verify and later reductions still open). + - Remaining strict CT follow-ups stay tracked in `CONSTANT_TIME_NOTES.md`. + +### Status board + +#### Phase 0 — Algebra + +- [x] Prime field `Fp32` (u32 storage; u64 mul) implementing `FieldCore + CanonicalField` (`src/algebra/fields/fp32.rs`) +- [x] Prime field `Fp64` (u64 storage; u128 mul) implementing `FieldCore + CanonicalField` (`src/algebra/fields/fp64.rs`) +- [x] Prime field `Fp128` (u128 storage; 256-bit intermediate) implementing `FieldCore + CanonicalField` (`src/algebra/fields/fp128.rs`, `src/algebra/fields/u256.rs`) +- [x] Branchless constant-time `add_raw`, `sub_raw`, `neg` for all field types +- [x] Constant-time inversion helper for prime fields: `Invertible::inv_or_zero()` (`src/primitives/arithmetic.rs`, `src/algebra/fields/fp*.rs`) +- [x] Division-free fixed-iteration reduction for `Fp32/Fp64` multiplication paths +- [x] Division-free fixed-iteration CRT final projection (replaced `% q` in scalar reconstruction path) +- [x] Rejection-sampled `FieldSampling::sample()` for all field types (no modular bias) +- [x] Pow2Offset pseudo-Mersenne registry + aliases (`q = 2^k - offset`, bounded `k <= 128`, `q % 8 == 5`) (`src/algebra/fields/pseudo_mersenne.rs`) +- [x] Constant-time review notes for current algebra/ring paths (`CONSTANT_TIME_NOTES.md`) +- [x] Deterministic parameter presets + - [x] `q = 2^32 - 99` constants scaffold (`src/algebra/ntt/tables.rs`) + - [x] `Pow2Offset` presets selected for 64/128-bit path: + - `q = 2^64 - 59` (`POW2_OFFSET_MODULUS_64`) + - `q = 2^128 - 275` (`POW2_OFFSET_MODULUS_128`) + - source: `src/algebra/fields/pseudo_mersenne.rs` +- [x] `Module` implementations: + - [x] 
`VectorModule` (fixed-length vectors; `Module` via scalar*vector mul) (`src/algebra/module.rs`) + - [x] `PolyModule` removed from current scope (not needed for near-term Hachi milestones) +- [ ] Extension fields: + - [x] `Fp2` quadratic extension (`src/algebra/fields/ext.rs`) + - [x] `Fp4` tower extension (`src/algebra/fields/ext.rs`) +- [x] Serialization for algebra types (`HachiSerialize` / `HachiDeserialize`) (+ `u128/i128` primitives in `src/primitives/serialization.rs`) +- [x] NTT small-prime arithmetic: Montgomery-like `fpmul`, Barrett-like `fpred`, branchless `csubq`/`caddq`/`center` (`src/algebra/ntt/prime.rs`) +- [x] CRT limb arithmetic: `LimbQ`, `QData` (`src/algebra/ntt/crt.rs`) +- [x] Tests (49 total in `tests/algebra.rs`): + - [x] field arithmetic, identities, distributivity (Fp32/Fp64/Fp128) + - [x] zero inversion returns None + - [x] serialization round-trips (all field types, extensions, Poly, VectorModule) + - [x] Fp2 conjugate, norm, distributivity + - [x] U256 wide multiply and bit access + - [x] LimbQ round-trip, add/sub inverse, QData consistency + - [x] NTT normalize range, fpmul commutativity + - [x] Poly add/sub/neg + - [x] Cyclotomic ring identities and serialization (D=4, D=64) + - [x] NTT forward/inverse round-trips (single prime and all Q32 primes) + - [x] Cyclotomic CRT+NTT full round-trip (`from_ring` -> `to_ring`) + - [x] Scalar backend path equivalence (`*_with_backend` vs default path) + - [x] Pow2Offset profile invariants (`q = 2^k - offset`, `q % 8 == 5`) + - [x] `FieldSampling::sample()` output bound checks + - [x] Checked deserialization rejects non-canonical field encodings + - [x] Galois automorphism checks (`sigma` composition + multiplicativity) + - [x] Functional gadget decompose/recompose round-trip checks + - [x] Sparse `+/-1` challenge support checks (`hamming_weight = omega`) +- [x] Dedicated Pow2Offset primality regression tests (`tests/primality.rs`) + - [x] Miller-Rabin probable-prime checks for all registered 
Pow2Offset moduli + - [x] Composite sanity rejection checks + +#### Phase 1 — Ring + gadgets (functional core) + +- [x] Cyclotomic ring `Rq` with `X^D = -1` (`src/algebra/ring/cyclotomic.rs`) +- [x] CRT+NTT-domain ring representation + CRT conversion (`src/algebra/ring/crt_ntt_repr.rs`) +- [x] Backend/domain layering for ring execution (`src/algebra/backend/*`, `src/algebra/domains/*`) +- [x] Galois automorphisms `sigma_i: X ↦ X^i` (odd `i`) +- [x] Functional gadget decomposition/recomposition (`G^{-1}` / `G` behavior) for base-`2^d` digits, without materializing dense gadget matrices +- [x] sparse short challenges (paper: `||c||_1 ≤ ω`, sparse ±1) + +#### Phase 2+ — Protocol (later) + +- [x] Protocol module scaffold (`src/protocol/*`) and top-level re-exports +- [x] Transcript interface (`Transcript`) plus Blake2b/Keccak implementations +- [x] Hachi-native transcript label schedule aligned to paper phases (§4.1/§4.2/§4.3/§4.5) +- [x] Commitment trait surface + streaming trait surface + contract tests +- [x] Label-directed transcript absorption for commitments (`AppendToTranscript` takes label at call site) +- [x] ring-native commitment core (`RingCommitmentScheme`, `commit.rs`, config wiring) for §4.1 setup/commit +- [x] protocol prover/verifier folder split with explicit stubs (`prover/stub.rs`, `verifier/stub.rs`) +- [x] ring-commitment tests (`ring_commitment_core`, `ring_commitment_config`, `prover_verifier_stub_contract`) +- [x] sumcheck core building blocks (univariate messages + transcript-driving prover/verifier driver) (`src/protocol/sumcheck.rs`) +- [x] sumcheck core tests (`tests/sumcheck_core.rs`, `tests/sumcheck_prover_driver.rs`) +- [ ] commitment open-check prove/verify implementation (currently stubs) +- [ ] evaluation → linear relation (paper §4.2) +- [ ] ring-switching constraints as sumcheck instances (paper §4.3, Fig. 
4–7) +- [ ] recursion / “stop condition” + optional Greyhound composition (§4.5) + +#### Phase 3 — Integration into Jolt (deferred; not active now) + +- [ ] Define compatibility boundary document (what must match Jolt/Dory behavior vs what can remain Hachi-native) +- [ ] Provide Jolt-facing transcript adapter design (`Jolt` transcript pattern ↔ Hachi transcript object) +- [ ] Provide Jolt-facing PCS shim design (`CommitmentScheme`/`StreamingCommitmentScheme` mapping) +- [ ] Add transcript/commitment compatibility tests for integration-readiness (without wiring into Jolt yet) + +### Conventions + +- **Correctness first**: lock arithmetic with tests before touching protocol code. +- **Security first**: enforce constant-time behavior for secret-dependent operations. +- **Lean deps**: avoid heavyweight crypto crates until there is a clear need. +- **Explicit parameter sets**: each field/ring preset lives in code with a clear name and rationale. + +### Module layout + +``` +src/algebra/ +├── backend/ Backend execution traits + scalar backend +├── domains/ Domain-level aliases (coefficient / CRT+NTT) +├── fields/ Prime fields, pseudo-mersenne registry, u256, and extensions +├── ntt/ NTT kernels (butterfly), prime kernels (prime), CRT helpers (crt), presets (tables) +├── module.rs VectorModule +├── poly.rs Poly container +└── ring/ Cyclotomic ring and CRT+NTT representation +``` + +### References + +- Hachi paper: `paper/hachi.pdf` +- Core traits: `src/primitives/arithmetic.rs`, `src/primitives/serialization.rs` + diff --git a/NTT_PRIME_ANALYSIS.md b/NTT_PRIME_ANALYSIS.md new file mode 100644 index 00000000..e6b9c85c --- /dev/null +++ b/NTT_PRIME_ANALYSIS.md @@ -0,0 +1,146 @@ +# NTT Prime Analysis (Pow2Offset / Solinas Context) + +This note records the current analysis for small NTT primes and CRT coverage targets. 
+ +## References + +- NIST ML-KEM: `paper/standards/NIST.FIPS.203.pdf` +- NIST ML-DSA: `paper/standards/NIST.FIPS.204.pdf` +- Current small-prime table: `src/algebra/ntt/tables.rs` +- Labrador generator heuristic: `../labrador/data.py` + +## Why does `2D` divide `p - 1`? + +For negacyclic NTT on `Z_p[X]/(X^D + 1)`, we need a primitive `2D`-th root `psi` such that: + +- `psi^D = -1 (mod p)` +- `psi^(2D) = 1 (mod p)` + +Over prime fields, `F_p^*` is cyclic of size `p - 1`, so an element of order `2D` exists iff: + +- `2D | (p - 1)` + +So yes, the `128 | (p - 1)` condition is directly tied to `D = 64`. + +## What if `D = 1024`? + +Then requirement becomes: + +- `2D = 2048`, so `2048 | (p - 1)`. + +Under the current "small prime" cap (`p < 2^14`), this is extremely restrictive. + +## Why `< 2^14` in current code? + +This is a backend implementation constraint, not a hard NTT math requirement: + +- Current small-prime NTT backend stores modulus/coefficients in signed 16-bit lanes (`i16`). +- It relies on centered signed arithmetic and butterfly add/sub before full normalization. +- Keeping `p < 2^14` leaves practical headroom in those 16-bit operations. +- Current CRT limb code is also radix-`2^14`, matching this design style. + +So the `2^14` cap is about the present `i16` scalar kernel design. If we introduce an `i32` backend, this cap can be raised substantially. + +## Exhaustive counts (for `p < 2^14`) + +We classify exact Solinas as: + +- `p = 2^x - 2^y + 1`. 
+ +Results: + +- `D = 64` (`128 | p-1`) + - all small NTT primes: **31** + - exact Solinas NTT primes: **6** + - all-prime set: + - `257, 641, 769, 1153, 1409, 2689, 3329, 3457, 4481, 4993, 6529, 7297, 7681, 7937, 9473, 9601, 9857, 10369, 10753, 11393, 11777, 12161, 12289, 13313, 13441, 13697, 14081, 14593, 15233, 15361, 16001` + - Solinas set: `257, 769, 7681, 7937, 12289, 15361` +- `D = 256` (`512 | p-1`) + - all small NTT primes: **6** + - exact Solinas NTT primes: **3** + - Solinas set: `7681, 12289, 15361` +- `D = 1024` (`2048 | p-1`) + - all small NTT primes: **1** + - exact Solinas NTT primes: **1** + - Solinas set: `12289` + +Conclusion: for higher `D`, the small-prime pool shrinks rapidly. + +## 30-bit exploration (`p < 2^30`) with NTT constraints + +To assess a larger-prime backend direction, we scanned for primes under `2^30` with: + +- `p ≡ 1 (mod 2D)`. + +Below are the **full outputs of the bounded search run** (top 30 largest primes found by descending scan): + +### `D = 64` (`2D = 128`) + +- Top-30 list: + - `1073741441, 1073739649, 1073738753, 1073736449, 1073735297, 1073734913, 1073732993, 1073732609, 1073731201, 1073731073, 1073730817, 1073728897, 1073727617, 1073726977, 1073722753, 1073719681, 1073717377, 1073716993, 1073713409, 1073712769, 1073712257, 1073710721, 1073708929, 1073707009, 1073703809, 1073702657, 1073702401, 1073698817, 1073696257, 1073693441` +- Coverage for `q = 2^128 - 275`: + - `P > q`: **5** limbs + - `P > 128*q^2`: **9** limbs + +### `D = 1024` (`2D = 2048`) + +- Top-30 list: + - `1073707009, 1073698817, 1073692673, 1073682433, 1073668097, 1073655809, 1073651713, 1073643521, 1073620993, 1073600513, 1073569793, 1073563649, 1073551361, 1073539073, 1073522689, 1073510401, 1073508353, 1073479681, 1073453057, 1073442817, 1073440769, 1073430529, 1073412097, 1073391617, 1073385473, 1073354753, 1073350657, 1073330177, 1073299457, 1073268737` +- Coverage for `q = 2^128 - 275`: + - `P > q`: **5** limbs + - `P > 128*q^2`: **9** limbs + 
+### Bit-estimate sanity check + +- `ceil(128 / 30) = 5` +- `ceil(263 / 30) = 9` (for `128*q^2 ~ 2^263`) + +This matches the concrete product counts above. + +## CRT size targets for `q = 2^128 - 275` + +Two common thresholds: + +1. Minimal uniqueness target: + - `P = prod(p_i) > q` +2. Labrador conservative heuristic: + - `P > 128 * q^2` (from `data.py`, with `FIXME` comment) + +### Limb counts at `D = 64` with current small-prime pool + +- Using all small NTT primes (`31` available): + - `P > q` achievable with **10** limbs + - `P > 128*q^2` achievable with **20** limbs +- Using only exact Solinas NTT primes (`6` available): + - `P > q`: **not achievable** + - `P > 128*q^2`: **not achievable** + - total product is only about `2^70` + +### Limb counts at `D = 1024` with current small-prime pool + +- Only one qualifying prime (`12289`) under `p < 2^14`, so neither threshold is achievable. + +## What is Labrador's safety margin doing? + +In Labrador code, prime selection stops at: + +- `P > 128 * q^2` + +Interpretation: + +- `q^2` tracks product-scale growth, +- extra factor `128` gives additional headroom (for `N=64`, this is `2N`), +- but their own `FIXME` comment indicates this is a conservative engineering bound, not a tight proof. + +So treat this as a robust heuristic rather than a formal minimum. + +## Practical implication for Hachi + +- If we stay with `D=64` and small i16-ish primes, we need non-Solinas primes in the CRT set. +- If we push to `D=1024`, we must either: + - lift prime size beyond `<2^14`, or + - change CRT strategy (fewer larger limbs / different backend), or + - avoid strict small-prime CRT-NTT at that degree. +- A mixed backend model is sensible: + - keep the current `i16` backend for small-prime kernels, + - add an `i32`/wider backend for larger-prime kernels (e.g., up to ~30-bit). 
diff --git a/README.md b/README.md index 492b4d12..8036344f 100644 --- a/README.md +++ b/README.md @@ -3,3 +3,7 @@ A high performance and modular implementation of the Hachi polynomial commitment scheme. Hachi is a lattice-based polynomial commitment scheme with transparent setup and post-quantum security. + +## Acknowledgements + +The CRT/NTT and small-prime arithmetic design in this repository is informed by the Labrador/Greyhound C implementation family. In particular, the current pseudo-Mersenne profile uses moduli of the form `q = 2^k - offset` (smallest prime below `2^k` with `q % 8 == 5`). Hachi provides a Rust-native architecture and APIs, while drawing algorithmic inspiration from those implementations. diff --git a/benches/field_arith.rs b/benches/field_arith.rs new file mode 100644 index 00000000..e5f07f25 --- /dev/null +++ b/benches/field_arith.rs @@ -0,0 +1,1448 @@ +#![allow(missing_docs)] + +use ark_bn254::Fr as BN254Fr; +use ark_ff::{AdditiveGroup, Field}; +use criterion::{black_box, criterion_group, criterion_main, Criterion, Throughput}; +use hachi_pcs::algebra::fields::fp128::{Prime128M18M0, Prime128M54P0}; +use hachi_pcs::algebra::fields::fp32::Fp32; +use hachi_pcs::algebra::{HasPacking, PackedField, PackedValue, Prime128M13M4P0, Prime128M8M4M1M0}; +use hachi_pcs::algebra::{ + Pow2Offset24Field, Pow2Offset30Field, Pow2Offset31Field, Pow2Offset32Field, Pow2Offset40Field, + Pow2Offset48Field, Pow2Offset56Field, Pow2Offset64Field, +}; +use hachi_pcs::{CanonicalField, FieldCore, FieldSampling, FromSmallInt, Invertible}; +use rand::{rngs::StdRng, RngCore, SeedableRng}; +use std::env; +#[cfg(feature = "parallel")] +use std::thread; + +#[cfg(feature = "parallel")] +use rayon::prelude::*; +#[cfg(feature = "parallel")] +use rayon::ThreadPoolBuilder; + +fn rand_u128(rng: &mut R) -> u128 { + let lo = rng.next_u64() as u128; + let hi = rng.next_u64() as u128; + lo | (hi << 64) +} + +fn env_usize(name: &str, default: usize) -> usize { + env::var(name) + .ok() 
+ .and_then(|v| v.parse::().ok()) + .unwrap_or(default) +} + +fn bench_mul(c: &mut Criterion) { + type F13 = Prime128M13M4P0; + type F275 = Prime128M8M4M1M0; + type F2p18p1 = Prime128M18M0; + type F2p54m1 = Prime128M54P0; + + let mut rng = StdRng::seed_from_u64(0x5eed); + let inputs_u128: Vec = (0..2048).map(|_| rand_u128(&mut rng)).collect(); + + let inputs_f13: Vec = inputs_u128 + .iter() + .copied() + .map(F13::from_canonical_u128_reduced) + .collect(); + + let inputs_f275: Vec = inputs_u128 + .iter() + .copied() + .map(F275::from_canonical_u128_reduced) + .collect(); + let inputs_f2p18p1: Vec = inputs_u128 + .iter() + .copied() + .map(F2p18p1::from_canonical_u128_reduced) + .collect(); + let inputs_f2p54m1: Vec = inputs_u128 + .iter() + .copied() + .map(F2p54m1::from_canonical_u128_reduced) + .collect(); + + let mut group = c.benchmark_group("field_mul"); + + group.bench_function("fp128_prime128m13m4p0", |b| { + b.iter(|| { + let mut acc = F13::one(); + for x in inputs_f13.iter() { + acc = acc * *x + acc; + } + black_box(acc) + }) + }); + + group.bench_function("fp128_prime128m8m4m1m0", |b| { + b.iter(|| { + let mut acc = F275::one(); + for x in inputs_f275.iter() { + acc = acc * *x + acc; + } + black_box(acc) + }) + }); + + group.bench_function("fp128_prime128m18m0_shift_special", |b| { + b.iter(|| { + let mut acc = F2p18p1::one(); + for x in inputs_f2p18p1.iter() { + acc = acc * *x + acc; + } + black_box(acc) + }) + }); + + group.bench_function("fp128_prime128m54p0_shift_special", |b| { + b.iter(|| { + let mut acc = F2p54m1::one(); + for x in inputs_f2p54m1.iter() { + acc = acc * *x + acc; + } + black_box(acc) + }) + }); + + group.finish(); +} + +fn bench_mul_only(c: &mut Criterion) { + type F13 = Prime128M13M4P0; + type F2p18p1 = Prime128M18M0; + type F2p54m1 = Prime128M54P0; + + let mut rng = StdRng::seed_from_u64(0x5eed); + let inputs_f13: Vec = (0..2048) + .map(|_| F13::from_canonical_u128_reduced(rand_u128(&mut rng))) + .collect(); + let inputs_f2p18p1: 
Vec = (0..2048) + .map(|_| F2p18p1::from_canonical_u128_reduced(rand_u128(&mut rng))) + .collect(); + let inputs_f2p54m1: Vec = (0..2048) + .map(|_| F2p54m1::from_canonical_u128_reduced(rand_u128(&mut rng))) + .collect(); + + let mut group = c.benchmark_group("field_mul_only"); + + group.bench_function("mul_chain_2048", |b| { + b.iter(|| { + let mut acc = F13::one(); + for x in inputs_f13.iter() { + acc *= *x; + } + black_box(acc) + }) + }); + + group.bench_function("mul_chain_16384", |b| { + b.iter(|| { + let mut acc = F13::one(); + for _ in 0..8 { + for x in inputs_f13.iter() { + acc *= *x; + } + } + black_box(acc) + }) + }); + + group.bench_function("mul_parallel_1024", |b| { + b.iter(|| { + let mut sum = F13::zero(); + for pair in inputs_f13.chunks_exact(2) { + sum += pair[0] * pair[1]; + } + black_box(sum) + }) + }); + + group.bench_function("mul_chain_2048_special_m18m0", |b| { + b.iter(|| { + let mut acc = F2p18p1::one(); + for x in inputs_f2p18p1.iter() { + acc *= *x; + } + black_box(acc) + }) + }); + + group.bench_function("mul_chain_2048_special_m54p0", |b| { + b.iter(|| { + let mut acc = F2p54m1::one(); + for x in inputs_f2p54m1.iter() { + acc *= *x; + } + black_box(acc) + }) + }); + + group.finish(); +} + +fn bench_mul_isolated(c: &mut Criterion) { + use ark_ff::UniformRand; + + type F13 = Prime128M13M4P0; + + let mut rng = StdRng::seed_from_u64(0x5eed); + let a_fp128 = F13::from_canonical_u128_reduced(rand_u128(&mut rng)); + let b_fp128 = F13::from_canonical_u128_reduced(rand_u128(&mut rng)); + let a_bn254 = BN254Fr::rand(&mut rng); + let b_bn254 = BN254Fr::rand(&mut rng); + + let mut group = c.benchmark_group("field_mul_isolated"); + + group.bench_function("fp128_black_box_only", |b| b.iter(|| black_box(a_fp128))); + + group.bench_function("bn254_black_box_only", |b| b.iter(|| black_box(a_bn254))); + + group.bench_function("fp128_pair_passthrough", |b| { + b.iter(|| { + let x = black_box(a_fp128); + let y = black_box(b_fp128); + black_box((x, y)) + }) 
+ }); + + group.bench_function("bn254_pair_passthrough", |b| { + b.iter(|| { + let x = black_box(a_bn254); + let y = black_box(b_bn254); + black_box((x, y)) + }) + }); + + group.bench_function("fp128_mul_single", |b| { + b.iter(|| { + let x = black_box(a_fp128); + let y = black_box(b_fp128); + black_box(x * y) + }) + }); + + group.bench_function("bn254_mul_single", |b| { + b.iter(|| { + let x = black_box(a_bn254); + let y = black_box(b_bn254); + black_box(x * y) + }) + }); + + let lanes_fp128: [(F13, F13); 8] = std::array::from_fn(|_| { + ( + F13::from_canonical_u128_reduced(rand_u128(&mut rng)), + F13::from_canonical_u128_reduced(rand_u128(&mut rng)), + ) + }); + let lanes_bn254: [(BN254Fr, BN254Fr); 8] = + std::array::from_fn(|_| (BN254Fr::rand(&mut rng), BN254Fr::rand(&mut rng))); + + group.bench_function("fp128_mul_8way_independent", |b| { + b.iter(|| { + let lanes = black_box(&lanes_fp128); + let p0 = lanes[0].0 * lanes[0].1; + let p1 = lanes[1].0 * lanes[1].1; + let p2 = lanes[2].0 * lanes[2].1; + let p3 = lanes[3].0 * lanes[3].1; + let p4 = lanes[4].0 * lanes[4].1; + let p5 = lanes[5].0 * lanes[5].1; + let p6 = lanes[6].0 * lanes[6].1; + let p7 = lanes[7].0 * lanes[7].1; + black_box([p0, p1, p2, p3, p4, p5, p6, p7]) + }) + }); + + group.bench_function("fp128_8way_passthrough", |b| { + b.iter(|| { + let lanes = black_box(&lanes_fp128); + let p0 = lanes[0].0; + let p1 = lanes[1].0; + let p2 = lanes[2].0; + let p3 = lanes[3].0; + let p4 = lanes[4].0; + let p5 = lanes[5].0; + let p6 = lanes[6].0; + let p7 = lanes[7].0; + black_box([p0, p1, p2, p3, p4, p5, p6, p7]) + }) + }); + + group.bench_function("bn254_mul_8way_independent", |b| { + b.iter(|| { + let lanes = black_box(&lanes_bn254); + let p0 = lanes[0].0 * lanes[0].1; + let p1 = lanes[1].0 * lanes[1].1; + let p2 = lanes[2].0 * lanes[2].1; + let p3 = lanes[3].0 * lanes[3].1; + let p4 = lanes[4].0 * lanes[4].1; + let p5 = lanes[5].0 * lanes[5].1; + let p6 = lanes[6].0 * lanes[6].1; + let p7 = lanes[7].0 * 
lanes[7].1; + black_box([p0, p1, p2, p3, p4, p5, p6, p7]) + }) + }); + + group.bench_function("bn254_8way_passthrough", |b| { + b.iter(|| { + let lanes = black_box(&lanes_bn254); + let p0 = lanes[0].0; + let p1 = lanes[1].0; + let p2 = lanes[2].0; + let p3 = lanes[3].0; + let p4 = lanes[4].0; + let p5 = lanes[5].0; + let p6 = lanes[6].0; + let p7 = lanes[7].0; + black_box([p0, p1, p2, p3, p4, p5, p6, p7]) + }) + }); + + group.finish(); +} + +fn bench_sqr(c: &mut Criterion) { + type F13 = Prime128M13M4P0; + + let mut rng = StdRng::seed_from_u64(0x5eed); + let start = F13::from_canonical_u128_reduced(rand_u128(&mut rng)); + + let mut group = c.benchmark_group("field_sqr"); + + group.bench_function("sqr_chain_2048", |b| { + b.iter(|| { + let mut acc = start; + for _ in 0..2048 { + acc = acc.square(); + } + black_box(acc) + }) + }); + + group.bench_function("mul_self_chain_2048", |b| { + b.iter(|| { + let mut acc = start; + for _ in 0..2048 { + acc = acc * acc; + } + black_box(acc) + }) + }); + + group.finish(); +} + +fn bench_inv(c: &mut Criterion) { + type F13 = Prime128M13M4P0; + + let mut rng = StdRng::seed_from_u64(0x1a2b_3c4d_5e6f_7788); + let inputs: Vec = (0..256) + .map(|_| F13::from_canonical_u128_reduced(rand_u128(&mut rng))) + .collect(); + + c.bench_function("fp128_inv_or_zero_prime128m13m4p0", |b| { + b.iter(|| { + let mut acc = F13::one(); + for x in inputs.iter() { + acc *= x.inv_or_zero(); + } + black_box(acc) + }) + }); +} + +fn bench_bn254(c: &mut Criterion) { + use ark_ff::UniformRand; + + let mut rng = StdRng::seed_from_u64(0x5eed); + let inputs: Vec = (0..2048).map(|_| BN254Fr::rand(&mut rng)).collect(); + + let mut group = c.benchmark_group("bn254_fr"); + + group.bench_function("mul_add_chain_2048", |b| { + b.iter(|| { + let mut acc = BN254Fr::ONE; + for x in inputs.iter() { + acc = acc * x + acc; + } + black_box(acc) + }) + }); + + group.bench_function("mul_chain_2048", |b| { + b.iter(|| { + let mut acc = BN254Fr::ONE; + for x in inputs.iter() { 
+ acc *= x; + } + black_box(acc) + }) + }); + + group.bench_function("mul_chain_16384", |b| { + b.iter(|| { + let mut acc = BN254Fr::ONE; + for _ in 0..8 { + for x in inputs.iter() { + acc *= x; + } + } + black_box(acc) + }) + }); + + group.bench_function("mul_parallel_1024", |b| { + b.iter(|| { + let mut sum = BN254Fr::ZERO; + for pair in inputs.chunks_exact(2) { + sum += pair[0] * pair[1]; + } + black_box(sum) + }) + }); + + group.bench_function("sqr_chain_2048", |b| { + b.iter(|| { + let mut acc = inputs[0]; + for _ in 0..2048 { + acc.square_in_place(); + } + black_box(acc) + }) + }); + + group.bench_function("inv_256", |b| { + b.iter(|| { + let mut acc = BN254Fr::ONE; + for x in inputs[..256].iter() { + acc *= x.inverse().unwrap_or(BN254Fr::ZERO); + } + black_box(acc) + }) + }); + + group.finish(); +} + +fn bench_packed_fp128_backend(c: &mut Criterion) { + type F = Prime128M13M4P0; + type PF = ::Packing; + let packed_streams = env_usize("HACHI_BENCH_PACKED_STREAMS", 8); + let latency_iters = env_usize("HACHI_BENCH_LATENCY_ITERS", 4096); + let throughput_iters = env_usize("HACHI_BENCH_THROUGHPUT_ITERS", 256); + let stream_iters = env_usize("HACHI_BENCH_STREAM_ITERS", 2048); + let mix_iters = env_usize("HACHI_BENCH_MIX_ITERS", 256); + let mix_muls = env_usize("HACHI_BENCH_MIX_MULS", 3); + let mix_adds = env_usize("HACHI_BENCH_MIX_ADDS", 1); + let mix_subs = env_usize("HACHI_BENCH_MIX_SUBS", 1); + + assert!(packed_streams > 0, "HACHI_BENCH_PACKED_STREAMS must be > 0"); + assert!(latency_iters > 0, "HACHI_BENCH_LATENCY_ITERS must be > 0"); + assert!( + throughput_iters > 0, + "HACHI_BENCH_THROUGHPUT_ITERS must be > 0" + ); + assert!(stream_iters > 0, "HACHI_BENCH_STREAM_ITERS must be > 0"); + assert!(mix_iters > 0, "HACHI_BENCH_MIX_ITERS must be > 0"); + + let muls_per_stream = throughput_iters + 1; + let mix_ops = mix_muls + mix_adds + mix_subs; + assert!(mix_ops > 0, "at least one mix operation must be enabled"); + + let backend = if cfg!(all(target_arch = 
"aarch64", target_feature = "neon")) { + "aarch64_neon" + } else { + "scalar_fallback" + }; + let mut group = c.benchmark_group(format!("field_packed_backend/{backend}/w{}", PF::WIDTH)); + + let mut rng = StdRng::seed_from_u64(0xd00d_f00d_1122_3344); + let scalar_stream_len = PF::WIDTH * stream_iters; + let lhs: Vec = (0..scalar_stream_len) + .map(|_| F::from_canonical_u128_reduced(rand_u128(&mut rng))) + .collect(); + let rhs: Vec = (0..scalar_stream_len) + .map(|_| F::from_canonical_u128_reduced(rand_u128(&mut rng))) + .collect(); + + let packed_lhs: Vec = PF::pack_slice(&lhs); + let packed_rhs: Vec = PF::pack_slice(&rhs); + let scalar_latency_inputs: Vec = (0..latency_iters) + .map(|_| F::from_canonical_u128_reduced(rand_u128(&mut rng))) + .collect(); + let packed_latency_inputs: Vec = (0..latency_iters) + .map(|_| PF::from_fn(|_| F::from_canonical_u128_reduced(rand_u128(&mut rng)))) + .collect(); + + let scalar_streams = packed_streams * PF::WIDTH; + let scalar_lanes: Vec<(F, F)> = (0..scalar_streams) + .map(|_| { + ( + F::from_canonical_u128_reduced(rand_u128(&mut rng)), + F::from_canonical_u128_reduced(rand_u128(&mut rng)), + ) + }) + .collect(); + let packed_lanes: Vec<(PF, PF)> = (0..packed_streams) + .map(|_| { + ( + PF::from_fn(|_| F::from_canonical_u128_reduced(rand_u128(&mut rng))), + PF::from_fn(|_| F::from_canonical_u128_reduced(rand_u128(&mut rng))), + ) + }) + .collect(); + + group.throughput(Throughput::Elements(scalar_stream_len as u64)); + group.bench_function("scalar_add_stream", |b| { + let mut out = lhs.clone(); + b.iter(|| { + for (dst, src) in out.iter_mut().zip(rhs.iter()) { + *dst += *src; + } + black_box(out[0]) + }) + }); + + group.throughput(Throughput::Elements(scalar_stream_len as u64)); + group.bench_function("packed_add_stream", |b| { + let mut out = packed_lhs.clone(); + b.iter(|| { + for (dst, src) in out.iter_mut().zip(packed_rhs.iter()) { + *dst += *src; + } + black_box(out[0].extract(0)) + }) + }); + + 
group.throughput(Throughput::Elements(latency_iters as u64)); + group.bench_function("scalar_mul_latency_chain", |b| { + b.iter(|| { + let mut acc = F::one(); + for x in scalar_latency_inputs.iter() { + acc *= *x; + } + black_box(acc) + }) + }); + + group.throughput(Throughput::Elements((latency_iters * PF::WIDTH) as u64)); + group.bench_function("packed_mul_latency_chain", |b| { + b.iter(|| { + let mut acc = PF::broadcast(F::one()); + for x in packed_latency_inputs.iter() { + acc *= *x; + } + black_box(acc.extract(0)) + }) + }); + + group.throughput(Throughput::Elements( + (scalar_streams * muls_per_stream) as u64, + )); + group.bench_function("scalar_mul_throughput_8way", |b| { + b.iter(|| { + let lanes = black_box(&scalar_lanes); + let mut acc: Vec = lanes.iter().map(|(a, b)| *a * *b).collect(); + for _ in 0..throughput_iters { + for (acc_i, lane) in acc.iter_mut().zip(lanes.iter()) { + *acc_i *= lane.0; + } + } + black_box(acc[0]) + }) + }); + + group.throughput(Throughput::Elements( + (packed_streams * muls_per_stream * PF::WIDTH) as u64, + )); + group.bench_function("packed_mul_throughput_8way", |b| { + b.iter(|| { + let lanes = black_box(&packed_lanes); + let mut acc: Vec = lanes.iter().map(|(a, b)| *a * *b).collect(); + for _ in 0..throughput_iters { + for (acc_i, lane) in acc.iter_mut().zip(lanes.iter()) { + *acc_i *= lane.0; + } + } + black_box(acc[0].extract(0)) + }) + }); + + group.throughput(Throughput::Elements( + (scalar_streams * mix_iters * mix_ops) as u64, + )); + group.bench_function("scalar_mix_sumcheck_like", |b| { + b.iter(|| { + let lanes = black_box(&scalar_lanes); + let mut acc: Vec = lanes.iter().map(|(a, b)| *a + *b).collect(); + for _ in 0..mix_iters { + for (acc_i, lane) in acc.iter_mut().zip(lanes.iter()) { + let (x, y) = *lane; + for _ in 0..mix_muls { + *acc_i *= x; + } + for _ in 0..mix_adds { + *acc_i += y; + } + for _ in 0..mix_subs { + *acc_i -= x; + } + } + } + black_box(acc[0]) + }) + }); + + 
group.throughput(Throughput::Elements( + (packed_streams * PF::WIDTH * mix_iters * mix_ops) as u64, + )); + group.bench_function("packed_mix_sumcheck_like", |b| { + b.iter(|| { + let lanes = black_box(&packed_lanes); + let mut acc: Vec = lanes.iter().map(|(a, b)| *a + *b).collect(); + for _ in 0..mix_iters { + for (acc_i, lane) in acc.iter_mut().zip(lanes.iter()) { + let (x, y) = *lane; + for _ in 0..mix_muls { + *acc_i *= x; + } + for _ in 0..mix_adds { + *acc_i += y; + } + for _ in 0..mix_subs { + *acc_i -= x; + } + } + } + black_box(acc[0].extract(0)) + }) + }); + + group.finish(); +} + +fn bench_fp32_fp64_mul(c: &mut Criterion) { + let mut rng = StdRng::seed_from_u64(0x3264_3264); + let n = 2048; + + let inputs_24: Vec = + (0..n).map(|_| FieldSampling::sample(&mut rng)).collect(); + let inputs_30: Vec = + (0..n).map(|_| FieldSampling::sample(&mut rng)).collect(); + let inputs_31: Vec = + (0..n).map(|_| FieldSampling::sample(&mut rng)).collect(); + let inputs_32: Vec = + (0..n).map(|_| FieldSampling::sample(&mut rng)).collect(); + let inputs_40: Vec = + (0..n).map(|_| FieldSampling::sample(&mut rng)).collect(); + let inputs_64: Vec = + (0..n).map(|_| FieldSampling::sample(&mut rng)).collect(); + + let mut group = c.benchmark_group("fp32_fp64_mul"); + + macro_rules! 
chain_bench { + ($name:expr, $ty:ty, $inputs:expr) => { + group.bench_function(concat!($name, "_mul_chain_2048"), |b| { + b.iter(|| { + let mut acc = <$ty>::one(); + for x in $inputs.iter() { + acc *= *x; + } + black_box(acc) + }) + }); + group.bench_function(concat!($name, "_mul_add_chain_2048"), |b| { + b.iter(|| { + let mut acc = <$ty>::one(); + for x in $inputs.iter() { + acc = acc * *x + acc; + } + black_box(acc) + }) + }); + }; + } + + chain_bench!("fp32_2pow24m3", Pow2Offset24Field, inputs_24); + chain_bench!("fp32_2pow30m35", Pow2Offset30Field, inputs_30); + chain_bench!("fp32_2pow31m19", Pow2Offset31Field, inputs_31); + chain_bench!("fp32_2pow32m99", Pow2Offset32Field, inputs_32); + chain_bench!("fp64_2pow40m195", Pow2Offset40Field, inputs_40); + chain_bench!("fp64_2pow64m59", Pow2Offset64Field, inputs_64); + + group.finish(); +} + +fn bench_widening_ops(c: &mut Criterion) { + type F = Prime128M8M4M1M0; + + let mut rng = StdRng::seed_from_u64(0x01de_be0c_0001); + let a = F::from_canonical_u128_reduced(rand_u128(&mut rng)); + let b = F::from_canonical_u128_reduced(rand_u128(&mut rng)); + let b_u64 = rng.next_u64(); + + let mut group = c.benchmark_group("widening_ops"); + + group.bench_function("mul_wide_u64_only", |bench| { + bench.iter(|| black_box(black_box(a).mul_wide_u64(black_box(b_u64)))) + }); + + group.bench_function("mul_wide_only", |bench| { + bench.iter(|| black_box(black_box(a).mul_wide(black_box(b)))) + }); + + let limbs3 = [rng.next_u64(), rng.next_u64(), rng.next_u64()]; + let limbs4 = [ + rng.next_u64(), + rng.next_u64(), + rng.next_u64(), + rng.next_u64(), + ]; + + group.bench_function("mul_wide_limbs_3_to_5_only", |bench| { + bench.iter(|| black_box(black_box(a).mul_wide_limbs::<3, 5>(black_box(limbs3)))) + }); + group.bench_function("mul_wide_limbs_3_to_4_only", |bench| { + bench.iter(|| black_box(black_box(a).mul_wide_limbs::<3, 4>(black_box(limbs3)))) + }); + group.bench_function("mul_wide_limbs_4_to_5_only", |bench| { + bench.iter(|| 
black_box(black_box(a).mul_wide_limbs::<4, 5>(black_box(limbs4)))) + }); + group.bench_function("mul_wide_limbs_4_to_4_only", |bench| { + bench.iter(|| black_box(black_box(a).mul_wide_limbs::<4, 4>(black_box(limbs4)))) + }); + + group.bench_function("full_mul_u64_reduce", |bench| { + bench.iter(|| black_box(black_box(a) * F::from_u64(black_box(b_u64)))) + }); + + group.bench_function("full_mul_reduce", |bench| { + bench.iter(|| black_box(black_box(a) * black_box(b))) + }); + + let wide3 = a.mul_wide_u64(b_u64); + let wide4 = a.mul_wide(b); + let wide5 = { + let mut l = [0u64; 5]; + l[..3].copy_from_slice(&wide3); + l[4] = rng.next_u64() & 0xFF; + l + }; + + group.bench_function("solinas_reduce_3_limbs", |bench| { + bench.iter(|| black_box(F::solinas_reduce(black_box(&wide3)))) + }); + + group.bench_function("solinas_reduce_4_limbs", |bench| { + bench.iter(|| black_box(F::solinas_reduce(black_box(&wide4)))) + }); + + group.bench_function("solinas_reduce_5_limbs", |bench| { + bench.iter(|| black_box(F::solinas_reduce(black_box(&wide5)))) + }); + + group.bench_function("mul_wide_u64_roundtrip", |bench| { + bench.iter(|| { + let x = black_box(a); + let y = black_box(b_u64); + black_box(F::solinas_reduce(&x.mul_wide_u64(y))) + }) + }); + + group.bench_function("mul_wide_roundtrip", |bench| { + bench.iter(|| { + let x = black_box(a); + let y = black_box(b); + black_box(F::solinas_reduce(&x.mul_wide(y))) + }) + }); + + group.bench_function("mul_wide_limbs_3_to_5_roundtrip", |bench| { + bench.iter(|| { + let x = black_box(a); + let m = black_box(limbs3); + black_box(F::solinas_reduce(&x.mul_wide_limbs::<3, 5>(m))) + }) + }); + group.bench_function("mul_wide_limbs_3_to_4_roundtrip", |bench| { + bench.iter(|| { + let x = black_box(a); + let m = black_box(limbs3); + black_box(F::solinas_reduce(&x.mul_wide_limbs::<3, 4>(m))) + }) + }); + group.bench_function("mul_wide_limbs_4_to_5_roundtrip", |bench| { + bench.iter(|| { + let x = black_box(a); + let m = black_box(limbs4); + 
black_box(F::solinas_reduce(&x.mul_wide_limbs::<4, 5>(m))) + }) + }); + group.bench_function("mul_wide_limbs_4_to_4_roundtrip", |bench| { + bench.iter(|| { + let x = black_box(a); + let m = black_box(limbs4); + black_box(F::solinas_reduce(&x.mul_wide_limbs::<4, 4>(m))) + }) + }); + + group.finish(); +} + +fn bench_accumulator_pattern(c: &mut Criterion) { + type F = Prime128M8M4M1M0; + + let mut rng = StdRng::seed_from_u64(0xacc0_1a70_0002); + let inputs_a: Vec = (0..256) + .map(|_| F::from_canonical_u128_reduced(rand_u128(&mut rng))) + .collect(); + let inputs_b_u64: Vec = (0..256).map(|_| rng.next_u64()).collect(); + let inputs_b_f: Vec = (0..256) + .map(|_| F::from_canonical_u128_reduced(rand_u128(&mut rng))) + .collect(); + + let mut group = c.benchmark_group("accumulator_pattern"); + + for &n in &[16, 64, 256] { + group.bench_function(format!("eager_mul_u64_{n}"), |bench| { + bench.iter(|| { + let a_s = black_box(&inputs_a[..n]); + let b_s = black_box(&inputs_b_u64[..n]); + let mut acc = F::zero(); + for i in 0..n { + acc += a_s[i] * F::from_u64(b_s[i]); + } + black_box(acc) + }) + }); + + group.bench_function(format!("widening_accum_u64_{n}"), |bench| { + bench.iter(|| { + let a_s = black_box(&inputs_a[..n]); + let b_s = black_box(&inputs_b_u64[..n]); + let mut acc = [0u64; 5]; + for i in 0..n { + let wide = a_s[i].mul_wide_u64(b_s[i]); + let mut carry: u64 = 0; + for j in 0..3 { + let sum = acc[j] as u128 + wide[j] as u128 + carry as u128; + acc[j] = sum as u64; + carry = (sum >> 64) as u64; + } + for item in &mut acc[3..5] { + let sum = *item as u128 + carry as u128; + *item = sum as u64; + carry = (sum >> 64) as u64; + } + } + black_box(F::solinas_reduce(&acc)) + }) + }); + + group.bench_function(format!("eager_mul_full_{n}"), |bench| { + bench.iter(|| { + let a_s = black_box(&inputs_a[..n]); + let b_s = black_box(&inputs_b_f[..n]); + let mut acc = F::zero(); + for i in 0..n { + acc += a_s[i] * b_s[i]; + } + black_box(acc) + }) + }); + + 
group.bench_function(format!("widening_accum_full_{n}"), |bench| { + bench.iter(|| { + let a_s = black_box(&inputs_a[..n]); + let b_s = black_box(&inputs_b_f[..n]); + let mut acc = [0u64; 6]; + for i in 0..n { + let wide = a_s[i].mul_wide(b_s[i]); + let mut carry: u64 = 0; + for j in 0..4 { + let sum = acc[j] as u128 + wide[j] as u128 + carry as u128; + acc[j] = sum as u64; + carry = (sum >> 64) as u64; + } + for item in &mut acc[4..6] { + let sum = *item as u128 + carry as u128; + *item = sum as u64; + carry = (sum >> 64) as u64; + } + } + black_box(F::solinas_reduce(&acc)) + }) + }); + } + + group.finish(); +} + +fn bench_throughput(c: &mut Criterion) { + let n = 4096u64; + let mut rng = StdRng::seed_from_u64(0xdead_cafe); + + type M31 = Fp32<{ (1u32 << 31) - 1 }>; + + let a24: Vec = (0..n).map(|_| FieldSampling::sample(&mut rng)).collect(); + let b24: Vec = (0..n).map(|_| FieldSampling::sample(&mut rng)).collect(); + let a30: Vec = (0..n).map(|_| FieldSampling::sample(&mut rng)).collect(); + let b30: Vec = (0..n).map(|_| FieldSampling::sample(&mut rng)).collect(); + let a31: Vec = (0..n).map(|_| FieldSampling::sample(&mut rng)).collect(); + let b31: Vec = (0..n).map(|_| FieldSampling::sample(&mut rng)).collect(); + let am31: Vec = (0..n).map(|_| FieldSampling::sample(&mut rng)).collect(); + let bm31: Vec = (0..n).map(|_| FieldSampling::sample(&mut rng)).collect(); + let a32: Vec = (0..n).map(|_| FieldSampling::sample(&mut rng)).collect(); + let b32: Vec = (0..n).map(|_| FieldSampling::sample(&mut rng)).collect(); + let a40: Vec = (0..n).map(|_| FieldSampling::sample(&mut rng)).collect(); + let b40: Vec = (0..n).map(|_| FieldSampling::sample(&mut rng)).collect(); + let a48: Vec = (0..n).map(|_| FieldSampling::sample(&mut rng)).collect(); + let b48: Vec = (0..n).map(|_| FieldSampling::sample(&mut rng)).collect(); + let a56: Vec = (0..n).map(|_| FieldSampling::sample(&mut rng)).collect(); + let b56: Vec = (0..n).map(|_| FieldSampling::sample(&mut rng)).collect(); + 
let a64: Vec = (0..n).map(|_| FieldSampling::sample(&mut rng)).collect(); + let b64: Vec = (0..n).map(|_| FieldSampling::sample(&mut rng)).collect(); + let a128: Vec = (0..n) + .map(|_| Prime128M8M4M1M0::from_canonical_u128_reduced(rand_u128(&mut rng))) + .collect(); + let b128: Vec = (0..n) + .map(|_| Prime128M8M4M1M0::from_canonical_u128_reduced(rand_u128(&mut rng))) + .collect(); + + let mut out24 = vec![Pow2Offset24Field::zero(); n as usize]; + let mut out30 = vec![Pow2Offset30Field::zero(); n as usize]; + let mut out31 = vec![Pow2Offset31Field::zero(); n as usize]; + let mut outm31 = vec![M31::zero(); n as usize]; + let mut out32 = vec![Pow2Offset32Field::zero(); n as usize]; + let mut out40 = vec![Pow2Offset40Field::zero(); n as usize]; + let mut out48 = vec![Pow2Offset48Field::zero(); n as usize]; + let mut out56 = vec![Pow2Offset56Field::zero(); n as usize]; + let mut out64 = vec![Pow2Offset64Field::zero(); n as usize]; + let mut out128 = vec![Prime128M8M4M1M0::zero(); n as usize]; + + let mut group = c.benchmark_group("throughput"); + group.throughput(Throughput::Elements(n)); + + macro_rules! 
bench_op { + ($name:expr, $a:expr, $b:expr, $out:expr, $op:tt) => { + group.bench_function($name, |bench| { + bench.iter(|| { + let a = black_box(&$a); + let b = black_box(&$b); + let out = &mut $out; + for i in 0..n as usize { + out[i] = a[i] $op b[i]; + } + }) + }); + }; + } + + bench_op!("fp32_24b_mul", a24, b24, out24, *); + bench_op!("fp32_24b_add", a24, b24, out24, +); + bench_op!("fp32_30b_mul", a30, b30, out30, *); + bench_op!("fp32_30b_add", a30, b30, out30, +); + bench_op!("fp32_31b_mul", a31, b31, out31, *); + bench_op!("fp32_31b_add", a31, b31, out31, +); + bench_op!("fp32_m31_mul", am31, bm31, outm31, *); + bench_op!("fp32_m31_add", am31, bm31, outm31, +); + bench_op!("fp32_32b_mul", a32, b32, out32, *); + bench_op!("fp32_32b_add", a32, b32, out32, +); + bench_op!("fp64_40b_mul", a40, b40, out40, *); + bench_op!("fp64_40b_add", a40, b40, out40, +); + bench_op!("fp64_48b_mul", a48, b48, out48, *); + bench_op!("fp64_48b_add", a48, b48, out48, +); + bench_op!("fp64_56b_mul", a56, b56, out56, *); + bench_op!("fp64_56b_add", a56, b56, out56, +); + bench_op!("fp64_64b_mul", a64, b64, out64, *); + bench_op!("fp64_64b_add", a64, b64, out64, +); + bench_op!("fp128_mul", a128, b128, out128, *); + bench_op!("fp128_add", a128, b128, out128, +); + + group.finish(); +} + +fn bench_packed_throughput(c: &mut Criterion) { + use hachi_pcs::algebra::{Fp128Packing, Fp32Packing, Fp64Packing}; + + let n = 4096u64; + let mut rng = StdRng::seed_from_u64(0xbeef_cafe); + + macro_rules! 
packed_bench { + ($group:expr, $label:expr, $field:ty, $packing:ty, $rng:expr, $n:expr) => {{ + let lhs: Vec<$field> = (0..$n).map(|_| FieldSampling::sample($rng)).collect(); + let rhs: Vec<$field> = (0..$n).map(|_| FieldSampling::sample($rng)).collect(); + let lhs_p = <$packing>::pack_slice(&lhs); + let rhs_p = <$packing>::pack_slice(&rhs); + let mut out_p = vec![<$packing>::broadcast(<$field>::zero()); lhs_p.len()]; + + $group.bench_function(concat!($label, "_packed_mul"), |b| { + b.iter(|| { + let a = black_box(&lhs_p); + let b_v = black_box(&rhs_p); + let out = &mut out_p; + for i in 0..out.len() { + out[i] = a[i] * b_v[i]; + } + }) + }); + $group.bench_function(concat!($label, "_packed_add"), |b| { + b.iter(|| { + let a = black_box(&lhs_p); + let b_v = black_box(&rhs_p); + let out = &mut out_p; + for i in 0..out.len() { + out[i] = a[i] + b_v[i]; + } + }) + }); + $group.bench_function(concat!($label, "_packed_sub"), |b| { + b.iter(|| { + let a = black_box(&lhs_p); + let b_v = black_box(&rhs_p); + let out = &mut out_p; + for i in 0..out.len() { + out[i] = a[i] - b_v[i]; + } + }) + }); + }}; + } + + let mut group = c.benchmark_group("packed_throughput"); + group.throughput(Throughput::Elements(n)); + + use hachi_pcs::algebra::fields::pseudo_mersenne::*; + type M31 = Fp32<{ (1u32 << 31) - 1 }>; + + type P24 = Fp32Packing<{ POW2_OFFSET_MODULUS_24 }>; + type P30 = Fp32Packing<{ POW2_OFFSET_MODULUS_30 }>; + type P31 = Fp32Packing<{ POW2_OFFSET_MODULUS_31 }>; + type PM31 = Fp32Packing<{ (1u32 << 31) - 1 }>; + type P32 = Fp32Packing<{ POW2_OFFSET_MODULUS_32 }>; + type P40 = Fp64Packing<{ POW2_OFFSET_MODULUS_40 }>; + type P48 = Fp64Packing<{ POW2_OFFSET_MODULUS_48 }>; + type P56 = Fp64Packing<{ POW2_OFFSET_MODULUS_56 }>; + type P64 = Fp64Packing<{ POW2_OFFSET_MODULUS_64 }>; + type P128 = Fp128Packing<{ POW2_OFFSET_MODULUS_128 }>; + + packed_bench!(group, "fp32_24b", Pow2Offset24Field, P24, &mut rng, n); + packed_bench!(group, "fp32_30b", Pow2Offset30Field, P30, &mut 
rng, n); + packed_bench!(group, "fp32_31b", Pow2Offset31Field, P31, &mut rng, n); + packed_bench!(group, "fp32_m31", M31, PM31, &mut rng, n); + packed_bench!(group, "fp32_32b", Pow2Offset32Field, P32, &mut rng, n); + packed_bench!(group, "fp64_40b", Pow2Offset40Field, P40, &mut rng, n); + packed_bench!(group, "fp64_48b", Pow2Offset48Field, P48, &mut rng, n); + packed_bench!(group, "fp64_56b", Pow2Offset56Field, P56, &mut rng, n); + packed_bench!(group, "fp64_64b", Pow2Offset64Field, P64, &mut rng, n); + packed_bench!(group, "fp128", Prime128M8M4M1M0, P128, &mut rng, n); + + group.finish(); +} + +#[cfg(feature = "parallel")] +fn bench_parallel_throughput(c: &mut Criterion) { + use hachi_pcs::algebra::{Fp32Packing, Fp64Packing}; + + let profile = env::var("HACHI_BENCH_PAR_PROFILE").unwrap_or_else(|_| "dev".to_string()); + let default_n = match profile.as_str() { + "scale" | "large" => 1 << 20, + "xlarge" => 1 << 22, + _ => 1 << 15, + }; + let n = env_usize("HACHI_BENCH_PAR_N", default_n); + let default_chunk = match profile.as_str() { + "scale" | "large" => 1 << 14, + "xlarge" => 1 << 15, + _ => 1 << 12, + }; + let chunk = env_usize("HACHI_BENCH_PAR_CHUNK", default_chunk); + let threads = env_usize( + "HACHI_BENCH_PAR_THREADS", + thread::available_parallelism() + .map(|v| v.get()) + .unwrap_or(1), + ); + + assert!(threads > 0, "HACHI_BENCH_PAR_THREADS must be > 0"); + assert!(n > 0, "HACHI_BENCH_PAR_N must be > 0"); + assert!(chunk > 0, "HACHI_BENCH_PAR_CHUNK must be > 0"); + assert!(n % 4 == 0, "HACHI_BENCH_PAR_N must be divisible by 4"); + + let pool = ThreadPoolBuilder::new() + .num_threads(threads) + .build() + .expect("failed to build rayon pool"); + + let mut rng = StdRng::seed_from_u64(0xfeed_face); + + let lhs31: Vec = (0..n).map(|_| FieldSampling::sample(&mut rng)).collect(); + let rhs31: Vec = (0..n).map(|_| FieldSampling::sample(&mut rng)).collect(); + let lhs64: Vec = (0..n).map(|_| FieldSampling::sample(&mut rng)).collect(); + let rhs64: Vec = 
(0..n).map(|_| FieldSampling::sample(&mut rng)).collect(); + let lhs128: Vec = (0..n) + .map(|_| Prime128M13M4P0::from_canonical_u128_reduced(rand_u128(&mut rng))) + .collect(); + let rhs128: Vec = (0..n) + .map(|_| Prime128M13M4P0::from_canonical_u128_reduced(rand_u128(&mut rng))) + .collect(); + + type P31 = Fp32Packing<{ hachi_pcs::algebra::fields::pseudo_mersenne::POW2_OFFSET_MODULUS_31 }>; + type P64 = Fp64Packing<{ hachi_pcs::algebra::fields::pseudo_mersenne::POW2_OFFSET_MODULUS_64 }>; + type F128 = Prime128M13M4P0; + type P128 = ::Packing; + let chunk31_p = (chunk / P31::WIDTH).max(1); + let chunk64_p = (chunk / P64::WIDTH).max(1); + let chunk128_p = (chunk / P128::WIDTH).max(1); + + let lhs31_p = P31::pack_slice(&lhs31); + let rhs31_p = P31::pack_slice(&rhs31); + let lhs64_p = P64::pack_slice(&lhs64); + let rhs64_p = P64::pack_slice(&rhs64); + let lhs128_p = P128::pack_slice(&lhs128); + let rhs128_p = P128::pack_slice(&rhs128); + + let mut out31 = vec![Pow2Offset31Field::zero(); n]; + let mut out64 = vec![Pow2Offset64Field::zero(); n]; + let mut out128 = vec![F128::zero(); n]; + let mut out31_p = vec![P31::broadcast(Pow2Offset31Field::zero()); lhs31_p.len()]; + let mut out64_p = vec![P64::broadcast(Pow2Offset64Field::zero()); lhs64_p.len()]; + let mut out128_p = vec![P128::broadcast(F128::zero()); lhs128_p.len()]; + + let mut group = c.benchmark_group(format!( + "parallel_throughput/{profile}/t{threads}/n{n}/c{chunk}" + )); + group.throughput(Throughput::Elements(n as u64)); + + group.bench_function("fp32_31b_mul_seq", |b| { + b.iter(|| { + let a = black_box(&lhs31); + let b_v = black_box(&rhs31); + let out = &mut out31; + for i in 0..out.len() { + out[i] = a[i] * b_v[i]; + } + black_box(out[0]) + }) + }); + + group.bench_function("fp32_31b_mul_par_zip", |b| { + b.iter(|| { + let a = black_box(&lhs31); + let b_v = black_box(&rhs31); + let out = &mut out31; + pool.install(|| { + out.par_iter_mut() + .zip(a.par_iter()) + .zip(b_v.par_iter()) + 
.for_each(|((dst, lhs), rhs)| *dst = *lhs * *rhs); + }); + black_box(out[0]) + }) + }); + + group.bench_function("fp32_31b_mul_par_chunked", |b| { + b.iter(|| { + let a = black_box(&lhs31); + let b_v = black_box(&rhs31); + let out = &mut out31; + pool.install(|| { + out.par_chunks_mut(chunk) + .zip(a.par_chunks(chunk)) + .zip(b_v.par_chunks(chunk)) + .for_each(|((dst, lhs), rhs)| { + for i in 0..dst.len() { + dst[i] = lhs[i] * rhs[i]; + } + }); + }); + black_box(out[0]) + }) + }); + + group.bench_function("fp32_31b_packed_mul_seq", |b| { + b.iter(|| { + let a = black_box(&lhs31_p); + let b_v = black_box(&rhs31_p); + let out = &mut out31_p; + for i in 0..out.len() { + out[i] = a[i] * b_v[i]; + } + black_box(out[0].extract(0)) + }) + }); + + group.bench_function("fp32_31b_packed_mul_par_zip", |b| { + b.iter(|| { + let a = black_box(&lhs31_p); + let b_v = black_box(&rhs31_p); + let out = &mut out31_p; + pool.install(|| { + out.par_iter_mut() + .zip(a.par_iter()) + .zip(b_v.par_iter()) + .for_each(|((dst, lhs), rhs)| *dst = *lhs * *rhs); + }); + black_box(out[0].extract(0)) + }) + }); + + group.bench_function("fp32_31b_packed_mul_par_chunked", |b| { + b.iter(|| { + let a = black_box(&lhs31_p); + let b_v = black_box(&rhs31_p); + let out = &mut out31_p; + pool.install(|| { + out.par_chunks_mut(chunk31_p) + .zip(a.par_chunks(chunk31_p)) + .zip(b_v.par_chunks(chunk31_p)) + .for_each(|((dst, lhs), rhs)| { + for i in 0..dst.len() { + dst[i] = lhs[i] * rhs[i]; + } + }); + }); + black_box(out[0].extract(0)) + }) + }); + + group.bench_function("fp64_64b_mul_seq", |b| { + b.iter(|| { + let a = black_box(&lhs64); + let b_v = black_box(&rhs64); + let out = &mut out64; + for i in 0..out.len() { + out[i] = a[i] * b_v[i]; + } + black_box(out[0]) + }) + }); + + group.bench_function("fp64_64b_mul_par_zip", |b| { + b.iter(|| { + let a = black_box(&lhs64); + let b_v = black_box(&rhs64); + let out = &mut out64; + pool.install(|| { + out.par_iter_mut() + .zip(a.par_iter()) + 
.zip(b_v.par_iter()) + .for_each(|((dst, lhs), rhs)| *dst = *lhs * *rhs); + }); + black_box(out[0]) + }) + }); + + group.bench_function("fp64_64b_mul_par_chunked", |b| { + b.iter(|| { + let a = black_box(&lhs64); + let b_v = black_box(&rhs64); + let out = &mut out64; + pool.install(|| { + out.par_chunks_mut(chunk) + .zip(a.par_chunks(chunk)) + .zip(b_v.par_chunks(chunk)) + .for_each(|((dst, lhs), rhs)| { + for i in 0..dst.len() { + dst[i] = lhs[i] * rhs[i]; + } + }); + }); + black_box(out[0]) + }) + }); + + group.bench_function("fp64_64b_packed_mul_seq", |b| { + b.iter(|| { + let a = black_box(&lhs64_p); + let b_v = black_box(&rhs64_p); + let out = &mut out64_p; + for i in 0..out.len() { + out[i] = a[i] * b_v[i]; + } + black_box(out[0].extract(0)) + }) + }); + + group.bench_function("fp64_64b_packed_mul_par_zip", |b| { + b.iter(|| { + let a = black_box(&lhs64_p); + let b_v = black_box(&rhs64_p); + let out = &mut out64_p; + pool.install(|| { + out.par_iter_mut() + .zip(a.par_iter()) + .zip(b_v.par_iter()) + .for_each(|((dst, lhs), rhs)| *dst = *lhs * *rhs); + }); + black_box(out[0].extract(0)) + }) + }); + + group.bench_function("fp64_64b_packed_mul_par_chunked", |b| { + b.iter(|| { + let a = black_box(&lhs64_p); + let b_v = black_box(&rhs64_p); + let out = &mut out64_p; + pool.install(|| { + out.par_chunks_mut(chunk64_p) + .zip(a.par_chunks(chunk64_p)) + .zip(b_v.par_chunks(chunk64_p)) + .for_each(|((dst, lhs), rhs)| { + for i in 0..dst.len() { + dst[i] = lhs[i] * rhs[i]; + } + }); + }); + black_box(out[0].extract(0)) + }) + }); + + group.bench_function("fp128_mul_seq", |b| { + b.iter(|| { + let a = black_box(&lhs128); + let b_v = black_box(&rhs128); + let out = &mut out128; + for i in 0..out.len() { + out[i] = a[i] * b_v[i]; + } + black_box(out[0]) + }) + }); + + group.bench_function("fp128_mul_par_chunked", |b| { + b.iter(|| { + let a = black_box(&lhs128); + let b_v = black_box(&rhs128); + let out = &mut out128; + pool.install(|| { + out.par_chunks_mut(chunk) + 
.zip(a.par_chunks(chunk)) + .zip(b_v.par_chunks(chunk)) + .for_each(|((dst, lhs), rhs)| { + for i in 0..dst.len() { + dst[i] = lhs[i] * rhs[i]; + } + }); + }); + black_box(out[0]) + }) + }); + + group.bench_function("fp128_packed_mul_seq", |b| { + b.iter(|| { + let a = black_box(&lhs128_p); + let b_v = black_box(&rhs128_p); + let out = &mut out128_p; + for i in 0..out.len() { + out[i] = a[i] * b_v[i]; + } + black_box(out[0].extract(0)) + }) + }); + + group.bench_function("fp128_packed_mul_par_chunked", |b| { + b.iter(|| { + let a = black_box(&lhs128_p); + let b_v = black_box(&rhs128_p); + let out = &mut out128_p; + pool.install(|| { + out.par_chunks_mut(chunk128_p) + .zip(a.par_chunks(chunk128_p)) + .zip(b_v.par_chunks(chunk128_p)) + .for_each(|((dst, lhs), rhs)| { + for i in 0..dst.len() { + dst[i] = lhs[i] * rhs[i]; + } + }); + }); + black_box(out[0].extract(0)) + }) + }); + + group.finish(); +} + +#[cfg(not(feature = "parallel"))] +fn bench_parallel_throughput(_: &mut Criterion) {} + +fn bench_packed_sumcheck_mix(c: &mut Criterion) { + use hachi_pcs::algebra::{Fp128Packing, Fp32Packing, Fp64Packing}; + + let n = 4096u64; + let mut rng = StdRng::seed_from_u64(0xface_bead); + + macro_rules! 
sumcheck_bench { + ($group:expr, $label:expr, $field:ty, $packing:ty, $rng:expr, $n:expr) => {{ + let eq: Vec<$field> = (0..$n).map(|_| FieldSampling::sample($rng)).collect(); + let poly: Vec<$field> = (0..$n).map(|_| FieldSampling::sample($rng)).collect(); + let eq_p = <$packing>::pack_slice(&eq); + let poly_p = <$packing>::pack_slice(&poly); + let mut acc = <$packing>::broadcast(<$field>::zero()); + + $group.bench_function(concat!($label, "_packed_macc"), |b| { + b.iter(|| { + let e = black_box(&eq_p); + let p_v = black_box(&poly_p); + acc = <$packing>::broadcast(<$field>::zero()); + for i in 0..e.len() { + acc += e[i] * p_v[i]; + } + black_box(acc) + }) + }); + }}; + } + + let mut group = c.benchmark_group("packed_sumcheck_mix"); + group.throughput(Throughput::Elements(n)); + + use hachi_pcs::algebra::fields::pseudo_mersenne::*; + type M31 = Fp32<{ (1u32 << 31) - 1 }>; + + type P24 = Fp32Packing<{ POW2_OFFSET_MODULUS_24 }>; + type P30 = Fp32Packing<{ POW2_OFFSET_MODULUS_30 }>; + type P31 = Fp32Packing<{ POW2_OFFSET_MODULUS_31 }>; + type PM31 = Fp32Packing<{ (1u32 << 31) - 1 }>; + type P32 = Fp32Packing<{ POW2_OFFSET_MODULUS_32 }>; + type P40 = Fp64Packing<{ POW2_OFFSET_MODULUS_40 }>; + type P48 = Fp64Packing<{ POW2_OFFSET_MODULUS_48 }>; + type P56 = Fp64Packing<{ POW2_OFFSET_MODULUS_56 }>; + type P64 = Fp64Packing<{ POW2_OFFSET_MODULUS_64 }>; + type P128 = Fp128Packing<{ POW2_OFFSET_MODULUS_128 }>; + + sumcheck_bench!(group, "fp32_24b", Pow2Offset24Field, P24, &mut rng, n); + sumcheck_bench!(group, "fp32_30b", Pow2Offset30Field, P30, &mut rng, n); + sumcheck_bench!(group, "fp32_31b", Pow2Offset31Field, P31, &mut rng, n); + sumcheck_bench!(group, "fp32_m31", M31, PM31, &mut rng, n); + sumcheck_bench!(group, "fp32_32b", Pow2Offset32Field, P32, &mut rng, n); + sumcheck_bench!(group, "fp64_40b", Pow2Offset40Field, P40, &mut rng, n); + sumcheck_bench!(group, "fp64_48b", Pow2Offset48Field, P48, &mut rng, n); + sumcheck_bench!(group, "fp64_56b", Pow2Offset56Field, P56, 
&mut rng, n); + sumcheck_bench!(group, "fp64_64b", Pow2Offset64Field, P64, &mut rng, n); + sumcheck_bench!(group, "fp128", Prime128M8M4M1M0, P128, &mut rng, n); + + group.finish(); +} + +criterion_group!( + field_arith, + bench_mul, + bench_mul_only, + bench_mul_isolated, + bench_sqr, + bench_inv, + bench_packed_fp128_backend, + bench_bn254, + bench_fp32_fp64_mul, + bench_widening_ops, + bench_accumulator_pattern, + bench_throughput, + bench_packed_throughput, + bench_packed_sumcheck_mix, + bench_parallel_throughput +); +criterion_main!(field_arith); diff --git a/benches/fp64_reduce_probe.rs b/benches/fp64_reduce_probe.rs new file mode 100644 index 00000000..073544a9 --- /dev/null +++ b/benches/fp64_reduce_probe.rs @@ -0,0 +1,118 @@ +#![allow(missing_docs)] + +use criterion::{black_box, criterion_group, criterion_main, Criterion, Throughput}; + +const P40: u64 = hachi_pcs::algebra::fields::pseudo_mersenne::POW2_OFFSET_MODULUS_40; +const P64: u64 = hachi_pcs::algebra::fields::pseudo_mersenne::POW2_OFFSET_MODULUS_64; +const C40: u64 = (1u64 << 40) - P40; // 195 +const C64: u64 = 0u64.wrapping_sub(P64); // 59 +const MASK40: u64 = (1u64 << 40) - 1; +const MASK64_U128: u128 = u64::MAX as u128; + +#[inline(always)] +fn mul_c40_split(x: u64) -> u64 { + let c = C40 as u32; + let x_lo = x as u32; + let x_hi = (x >> 32) as u32; + (c as u64 * x_lo as u64).wrapping_add((c as u64 * x_hi as u64) << 32) +} + +#[inline(always)] +fn reduce40_split(lo: u64, hi: u64) -> u64 { + let high = (lo >> 40) | (hi << 24); + let f1 = (lo & MASK40).wrapping_add(mul_c40_split(high)); + let f2 = (f1 & MASK40).wrapping_add(mul_c40_split(f1 >> 40)); + let reduced = f2.wrapping_sub(P40); + let borrow = reduced >> 63; + reduced.wrapping_add(borrow.wrapping_neg() & P40) +} + +#[inline(always)] +fn reduce40_direct(lo: u64, hi: u64) -> u64 { + let high = (lo >> 40) | (hi << 24); + let f1 = (lo & MASK40).wrapping_add(C40.wrapping_mul(high)); + let f2 = (f1 & MASK40).wrapping_add(C40.wrapping_mul(f1 >> 
40)); + let reduced = f2.wrapping_sub(P40); + let borrow = reduced >> 63; + reduced.wrapping_add(borrow.wrapping_neg() & P40) +} + +#[inline(always)] +fn reduce64(lo: u64, hi: u64) -> u64 { + let f1 = (lo as u128) + (C64 as u128) * (hi as u128); + let f2 = (f1 & MASK64_U128) + (C64 as u128) * ((f1 >> 64) as u64 as u128); + let reduced = f2.wrapping_sub(P64 as u128); + let borrow = reduced >> 127; + reduced.wrapping_add(borrow.wrapping_neg() & (P64 as u128)) as u64 +} + +#[inline(always)] +fn next_u64(state: &mut u64) -> u64 { + let mut x = *state; + x ^= x << 13; + x ^= x >> 7; + x ^= x << 17; + *state = x; + x +} + +fn bench_fp64_reduce_probe(c: &mut Criterion) { + let n = 1 << 13; + let mut seed = 0x9e37_79b9_7f4a_7c15u64; + + let mut pairs40 = Vec::with_capacity(n); + let mut pairs64 = Vec::with_capacity(n); + for _ in 0..n { + let a40 = next_u64(&mut seed) % P40; + let b40 = next_u64(&mut seed) % P40; + let x40 = (a40 as u128) * (b40 as u128); + pairs40.push((x40 as u64, (x40 >> 64) as u64)); + + let a64 = next_u64(&mut seed); + let b64 = next_u64(&mut seed); + let x64 = (a64 as u128) * (b64 as u128); + pairs64.push((x64 as u64, (x64 >> 64) as u64)); + } + + for &(lo, hi) in &pairs40 { + assert_eq!(reduce40_split(lo, hi), reduce40_direct(lo, hi)); + } + + let mut group = c.benchmark_group("fp64_reduce_probe"); + group.throughput(Throughput::Elements(n as u64)); + + group.bench_function("reduce40_split", |b| { + b.iter(|| { + let mut acc = 0u64; + for &(lo, hi) in black_box(&pairs40) { + acc ^= reduce40_split(lo, hi); + } + black_box(acc) + }) + }); + + group.bench_function("reduce40_direct", |b| { + b.iter(|| { + let mut acc = 0u64; + for &(lo, hi) in black_box(&pairs40) { + acc ^= reduce40_direct(lo, hi); + } + black_box(acc) + }) + }); + + group.bench_function("reduce64", |b| { + b.iter(|| { + let mut acc = 0u64; + for &(lo, hi) in black_box(&pairs64) { + acc ^= reduce64(lo, hi); + } + black_box(acc) + }) + }); + + group.finish(); +} + 
+criterion_group!(fp64_reduce_probe, bench_fp64_reduce_probe); +criterion_main!(fp64_reduce_probe); diff --git a/benches/hachi_e2e.rs b/benches/hachi_e2e.rs new file mode 100644 index 00000000..cb95c607 --- /dev/null +++ b/benches/hachi_e2e.rs @@ -0,0 +1,420 @@ +#![allow(missing_docs)] + +use criterion::measurement::WallTime; +use criterion::{black_box, criterion_group, BatchSize, BenchmarkGroup, Criterion}; +use hachi_pcs::algebra::poly::multilinear_eval; +use hachi_pcs::algebra::Fp128; +use hachi_pcs::protocol::commitment::{ + Fp128FullCommitmentConfig, Fp128LogBasisCommitmentConfig, Fp128OneHotCommitmentConfig, +}; +use hachi_pcs::protocol::commitment_scheme::HachiCommitmentScheme; +use hachi_pcs::protocol::hachi_poly_ops::{DensePoly, OneHotPoly}; +use hachi_pcs::protocol::transcript::Blake2bTranscript; +use hachi_pcs::protocol::CommitmentConfig; +use hachi_pcs::{BasisMode, CanonicalField, CommitmentScheme, FromSmallInt, Transcript}; +use rand::rngs::StdRng; +use rand::{Rng, SeedableRng}; +use std::time::Duration; + +type F = Fp128<0xfffffffffffffffffffffffffffffeed>; + +fn make_dense_evals(nv: usize) -> Vec { + let mut rng = StdRng::seed_from_u64(0xdead_beef); + let len = 1usize << nv; + let decomp = Cfg::decomposition(); + if decomp.log_commit_bound >= 128 { + (0..len) + .map(|_| F::from_canonical_u128_reduced(rng.gen::())) + .collect() + } else { + let half_bound = 1i64 << (decomp.log_commit_bound.min(62) - 1); + (0..len) + .map(|_| F::from_i64(rng.gen_range(-half_bound..half_bound))) + .collect() + } +} + +fn random_point(nv: usize) -> Vec { + let mut rng = StdRng::seed_from_u64(0xcafe_babe); + (0..nv) + .map(|_| F::from_canonical_u128_reduced(rng.gen::())) + .collect() +} + +fn configure_group(group: &mut BenchmarkGroup<'_, WallTime>, nv: usize) { + if nv >= 20 { + group.sample_size(10); + group.measurement_time(Duration::from_secs(30)); + } +} + +fn bench_dense_phases( + c: &mut Criterion, + label: &str, + nv: usize, +) { + let layout = 
Cfg::commitment_layout(nv).expect("benchmark layout"); + let evals = make_dense_evals::(nv); + let poly = DensePoly::::from_field_evals(nv, &evals).unwrap(); + let pt = random_point(nv); + let opening = multilinear_eval(&evals, &pt).unwrap(); + + let mut group = c.benchmark_group(format!("hachi/{label}/nv{nv}")); + configure_group(&mut group, nv); + + group.bench_function("setup", |b| { + b.iter(|| { + black_box( + as CommitmentScheme>::setup_prover(black_box( + nv, + )), + ) + }) + }); + + let setup = as CommitmentScheme>::setup_prover(nv); + + group.bench_function("commit", |b| { + b.iter(|| { + black_box( + as CommitmentScheme>::commit( + black_box(&poly), + black_box(&setup), + black_box(&layout), + ) + .unwrap(), + ) + }) + }); + + let (commitment, hint) = + as CommitmentScheme>::commit(&poly, &setup, &layout) + .unwrap(); + + group.bench_function("prove", |b| { + b.iter_batched( + || hint.clone(), + |h| { + let mut transcript = Blake2bTranscript::::new(b"bench"); + black_box( + as CommitmentScheme>::prove( + &setup, + &poly, + &pt, + h, + &mut transcript, + &commitment, + BasisMode::Lagrange, + &layout, + ) + .unwrap(), + ) + }, + BatchSize::LargeInput, + ) + }); + + let verifier_setup = + as CommitmentScheme>::setup_verifier(&setup); + let mut prover_transcript = Blake2bTranscript::::new(b"bench"); + let proof = as CommitmentScheme>::prove( + &setup, + &poly, + &pt, + hint, + &mut prover_transcript, + &commitment, + BasisMode::Lagrange, + &layout, + ) + .unwrap(); + + group.bench_function("verify", |b| { + b.iter(|| { + let mut transcript = Blake2bTranscript::::new(b"bench"); + as CommitmentScheme>::verify( + black_box(&proof), + black_box(&verifier_setup), + &mut transcript, + black_box(&pt), + black_box(&opening), + black_box(&commitment), + BasisMode::Lagrange, + black_box(&layout), + ) + .unwrap(); + }) + }); + + group.bench_function("e2e", |b| { + b.iter(|| { + let (cm, h) = as CommitmentScheme>::commit( + &poly, &setup, &layout, + ) + .unwrap(); + let 
mut pt_tr = Blake2bTranscript::::new(b"bench"); + let pf = as CommitmentScheme>::prove( + &setup, + &poly, + &pt, + h, + &mut pt_tr, + &cm, + BasisMode::Lagrange, + &layout, + ) + .unwrap(); + let mut vt_tr = Blake2bTranscript::::new(b"bench"); + as CommitmentScheme>::verify( + &pf, + &verifier_setup, + &mut vt_tr, + &pt, + &opening, + &cm, + BasisMode::Lagrange, + &layout, + ) + .unwrap(); + black_box(()) + }) + }); + + group.finish(); +} + +fn bench_onehot_phases( + c: &mut Criterion, + label: &str, + nv: usize, +) { + let layout = Cfg::commitment_layout(nv).expect("benchmark layout"); + let total_ring = layout.num_blocks * layout.block_len; + let onehot_k = D; + + let mut rng = StdRng::seed_from_u64(0xbeef_cafe); + let indices: Vec> = (0..total_ring) + .map(|_| Some(rng.gen_range(0..onehot_k))) + .collect(); + + let onehot_poly = + OneHotPoly::::new(onehot_k, indices.clone(), layout.r_vars, layout.m_vars).unwrap(); + + let dense_evals: Vec = { + let mut evals = vec![F::from_u64(0); total_ring * onehot_k]; + for (ci, opt_idx) in indices.iter().enumerate() { + if let Some(idx) = opt_idx { + evals[ci * onehot_k + idx] = F::from_u64(1); + } + } + evals + }; + let dense_poly = DensePoly::::from_field_evals(nv, &dense_evals).unwrap(); + let pt = random_point(nv); + let opening = multilinear_eval(&dense_evals, &pt).unwrap(); + + let setup = as CommitmentScheme>::setup_prover(nv); + + let mut group = c.benchmark_group(format!("hachi/{label}/nv{nv}")); + configure_group(&mut group, nv); + + group.bench_function("commit_onehot", |b| { + b.iter(|| { + black_box( + as CommitmentScheme>::commit( + black_box(&onehot_poly), + black_box(&setup), + black_box(&layout), + ) + .unwrap(), + ) + }) + }); + + let (commitment, hint) = as CommitmentScheme>::commit( + &onehot_poly, + &setup, + &layout, + ) + .unwrap(); + + group.bench_function("prove", |b| { + b.iter_batched( + || hint.clone(), + |h| { + let mut transcript = Blake2bTranscript::::new(b"bench"); + black_box( + as 
CommitmentScheme>::prove( + &setup, + &dense_poly, + &pt, + h, + &mut transcript, + &commitment, + BasisMode::Lagrange, + &layout, + ) + .unwrap(), + ) + }, + BatchSize::LargeInput, + ) + }); + + let verifier_setup = + as CommitmentScheme>::setup_verifier(&setup); + let mut prover_transcript = Blake2bTranscript::::new(b"bench"); + let proof = as CommitmentScheme>::prove( + &setup, + &dense_poly, + &pt, + hint, + &mut prover_transcript, + &commitment, + BasisMode::Lagrange, + &layout, + ) + .unwrap(); + + group.bench_function("verify", |b| { + b.iter(|| { + let mut transcript = Blake2bTranscript::::new(b"bench"); + as CommitmentScheme>::verify( + black_box(&proof), + black_box(&verifier_setup), + &mut transcript, + black_box(&pt), + black_box(&opening), + black_box(&commitment), + BasisMode::Lagrange, + black_box(&layout), + ) + .unwrap(); + }) + }); + + group.bench_function("e2e", |b| { + b.iter(|| { + let (cm, h) = as CommitmentScheme>::commit( + &onehot_poly, + &setup, + &layout, + ) + .unwrap(); + let mut pt_tr = Blake2bTranscript::::new(b"bench"); + let pf = as CommitmentScheme>::prove( + &setup, + &dense_poly, + &pt, + h, + &mut pt_tr, + &cm, + BasisMode::Lagrange, + &layout, + ) + .unwrap(); + let mut vt_tr = Blake2bTranscript::::new(b"bench"); + as CommitmentScheme>::verify( + &pf, + &verifier_setup, + &mut vt_tr, + &pt, + &opening, + &cm, + BasisMode::Lagrange, + &layout, + ) + .unwrap(); + black_box(()) + }) + }); + + group.finish(); +} + +fn bench_full_nv15(c: &mut Criterion) { + bench_dense_phases::<{ Fp128FullCommitmentConfig::D }, Fp128FullCommitmentConfig>( + c, "full", 15, + ); +} +fn bench_full_nv20(c: &mut Criterion) { + bench_dense_phases::<{ Fp128FullCommitmentConfig::D }, Fp128FullCommitmentConfig>( + c, "full", 20, + ); +} +fn bench_full_nv25(c: &mut Criterion) { + bench_dense_phases::<{ Fp128FullCommitmentConfig::D }, Fp128FullCommitmentConfig>( + c, "full", 25, + ); +} + +fn bench_onehot_nv15(c: &mut Criterion) { + bench_onehot_phases::<{ 
Fp128OneHotCommitmentConfig::D }, Fp128OneHotCommitmentConfig>( + c, "onehot", 15, + ); +} +fn bench_onehot_nv20(c: &mut Criterion) { + bench_onehot_phases::<{ Fp128OneHotCommitmentConfig::D }, Fp128OneHotCommitmentConfig>( + c, "onehot", 20, + ); +} +fn bench_onehot_nv25(c: &mut Criterion) { + bench_onehot_phases::<{ Fp128OneHotCommitmentConfig::D }, Fp128OneHotCommitmentConfig>( + c, "onehot", 25, + ); +} + +fn bench_logbasis_nv15(c: &mut Criterion) { + bench_dense_phases::<{ Fp128LogBasisCommitmentConfig::D }, Fp128LogBasisCommitmentConfig>( + c, "logbasis", 15, + ); +} +fn bench_logbasis_nv20(c: &mut Criterion) { + bench_dense_phases::<{ Fp128LogBasisCommitmentConfig::D }, Fp128LogBasisCommitmentConfig>( + c, "logbasis", 20, + ); +} +fn bench_logbasis_nv25(c: &mut Criterion) { + bench_dense_phases::<{ Fp128LogBasisCommitmentConfig::D }, Fp128LogBasisCommitmentConfig>( + c, "logbasis", 25, + ); +} + +criterion_group!( + hachi_benches, + bench_full_nv15, + bench_full_nv20, + bench_full_nv25, + bench_onehot_nv15, + bench_onehot_nv20, + bench_onehot_nv25, + bench_logbasis_nv15, + bench_logbasis_nv20, + bench_logbasis_nv25, +); + +/// Set `HACHI_PARALLEL=0` to run benchmarks single-threaded. 
+fn main() { + #[cfg(feature = "parallel")] + { + let num_threads = if std::env::var("HACHI_PARALLEL") + .map(|v| v == "0") + .unwrap_or(false) + { + eprintln!("HACHI_PARALLEL=0: running single-threaded"); + 1 + } else { + 0 + }; + rayon::ThreadPoolBuilder::new() + .num_threads(num_threads) + .stack_size(64 * 1024 * 1024) + .build_global() + .ok(); + } + + hachi_benches(); + criterion::Criterion::default() + .configure_from_args() + .final_summary(); +} diff --git a/benches/norm_sumcheck.rs b/benches/norm_sumcheck.rs new file mode 100644 index 00000000..266bc880 --- /dev/null +++ b/benches/norm_sumcheck.rs @@ -0,0 +1,203 @@ +#![allow(missing_docs)] + +use criterion::{black_box, criterion_group, criterion_main, BatchSize, BenchmarkId, Criterion}; +use hachi_pcs::algebra::Fp128; +use hachi_pcs::protocol::sumcheck::norm_sumcheck::NormSumcheckProver; +use hachi_pcs::protocol::sumcheck::split_eq::GruenSplitEq; +use hachi_pcs::protocol::sumcheck::{ + fold_evals_in_place, prove_sumcheck, range_check_eval, SumcheckInstanceProver, UniPoly, +}; +use hachi_pcs::protocol::transcript::labels; +use hachi_pcs::protocol::Blake2bTranscript; +use hachi_pcs::{FieldCore, FromSmallInt, Transcript}; +use rand::rngs::StdRng; +use rand::{Rng, SeedableRng}; +#[cfg(feature = "parallel")] +use rayon::prelude::*; +use std::time::Duration; + +type F = Fp128<0xfffffffffffffffffffffffffffffeed>; + +/// Baseline prover keeps the pre-dispatch point-eval kernel for apples-to-apples benchmarks. +/// It is intentionally local to this bench and should not be used in production code. 
+struct BaselineNormSumcheckProver { + split_eq: GruenSplitEq, + w_table: Vec, + num_vars: usize, + b: usize, +} + +impl BaselineNormSumcheckProver { + fn new(tau: &[E], w_evals: Vec, b: usize) -> Self { + let num_vars = tau.len(); + assert_eq!(w_evals.len(), 1 << num_vars); + Self { + split_eq: GruenSplitEq::new(tau), + w_table: w_evals, + num_vars, + b, + } + } +} + +impl SumcheckInstanceProver for BaselineNormSumcheckProver { + fn num_rounds(&self) -> usize { + self.num_vars + } + + fn degree_bound(&self) -> usize { + 2 * self.b + } + + fn input_claim(&self) -> E { + E::zero() + } + + fn compute_round_univariate(&mut self, _round: usize, _previous_claim: E) -> UniPoly { + let half = self.w_table.len() / 2; + let degree_q = 2 * self.b - 1; + let num_points_q = degree_q + 1; + + let (e_first, e_second) = self.split_eq.remaining_eq_tables(); + let num_first = e_first.len(); + let first_bits = num_first.trailing_zeros(); + let b = self.b; + + #[cfg(feature = "parallel")] + let q_evals = { + (0..half) + .into_par_iter() + .fold( + || vec![E::zero(); num_points_q], + |mut evals, j| { + let j_low = j & (num_first - 1); + let j_high = j >> first_bits; + let eq_rem = e_first[j_low] * e_second[j_high]; + let w_0 = self.w_table[2 * j]; + let w_1 = self.w_table[2 * j + 1]; + for (t, eval) in evals.iter_mut().enumerate() { + let t_e = E::from_u64(t as u64); + let w_t = w_0 + t_e * (w_1 - w_0); + *eval += eq_rem * range_check_eval(w_t, b); + } + evals + }, + ) + .reduce( + || vec![E::zero(); num_points_q], + |mut a, b_vec| { + for (ai, bi) in a.iter_mut().zip(b_vec.iter()) { + *ai += *bi; + } + a + }, + ) + }; + #[cfg(not(feature = "parallel"))] + let q_evals = { + let mut q_evals = vec![E::zero(); num_points_q]; + for j in 0..half { + let j_low = j & (num_first - 1); + let j_high = j >> first_bits; + let eq_rem = e_first[j_low] * e_second[j_high]; + let w_0 = self.w_table[2 * j]; + let w_1 = self.w_table[2 * j + 1]; + for (t, eval) in q_evals.iter_mut().enumerate() { + let 
t_e = E::from_u64(t as u64); + let w_t = w_0 + t_e * (w_1 - w_0); + *eval += eq_rem * range_check_eval(w_t, b); + } + } + q_evals + }; + + let q_poly = UniPoly::from_evals(&q_evals); + self.split_eq.gruen_mul(&q_poly) + } + + fn ingest_challenge(&mut self, _round: usize, r: E) { + self.split_eq.bind(r); + fold_evals_in_place(&mut self.w_table, r); + } +} + +#[derive(Clone)] +struct NormCase { + num_vars: usize, + b: usize, + tau: Vec, + w_evals: Vec, +} + +fn build_case(num_vars: usize, b: usize, seed: u64) -> NormCase { + let mut rng = StdRng::seed_from_u64(seed); + let n = 1usize << num_vars; + let tau: Vec = (0..num_vars) + .map(|_| F::from_u64(rng.gen_range(0u64..(1u64 << 24)))) + .collect(); + let w_evals: Vec = (0..n) + .map(|_| F::from_u64(rng.gen_range(0u64..(1u64 << 24)))) + .collect(); + NormCase { + num_vars, + b, + tau, + w_evals, + } +} + +fn bench_norm_sumcheck(c: &mut Criterion) { + let cases = [ + build_case(10, 4, 0xA11CE001), + build_case(10, 8, 0xA11CE002), + build_case(14, 4, 0xA11CE003), + build_case(14, 8, 0xA11CE004), + build_case(14, 16, 0xA11CE005), + build_case(18, 8, 0xA11CE006), + ]; + + let mut group = c.benchmark_group("norm_sumcheck"); + group.warm_up_time(Duration::from_secs(8)); + group.measurement_time(Duration::from_secs(24)); + group.sample_size(35); + + for case in &cases { + let case_tag = format!("nv{}_b{}", case.num_vars, case.b); + group.bench_function(BenchmarkId::new("baseline", &case_tag), |bencher| { + bencher.iter_batched( + || BaselineNormSumcheckProver::new(&case.tau, case.w_evals.clone(), case.b), + |mut prover| { + let mut transcript = Blake2bTranscript::::new(labels::DOMAIN_HACHI_PROTOCOL); + black_box( + prove_sumcheck::(&mut prover, &mut transcript, |tr| { + tr.challenge_scalar(labels::CHALLENGE_SUMCHECK_ROUND) + }) + .unwrap(), + ) + }, + BatchSize::SmallInput, + ); + }); + + group.bench_function(BenchmarkId::new("dispatched", &case_tag), |bencher| { + bencher.iter_batched( + || 
NormSumcheckProver::new(&case.tau, case.w_evals.clone(), case.b), + |mut prover| { + let mut transcript = Blake2bTranscript::::new(labels::DOMAIN_HACHI_PROTOCOL); + black_box( + prove_sumcheck::(&mut prover, &mut transcript, |tr| { + tr.challenge_scalar(labels::CHALLENGE_SUMCHECK_ROUND) + }) + .unwrap(), + ) + }, + BatchSize::SmallInput, + ); + }); + } + + group.finish(); +} + +criterion_group!(benches, bench_norm_sumcheck); +criterion_main!(benches); diff --git a/benches/ring_ntt.rs b/benches/ring_ntt.rs new file mode 100644 index 00000000..aa2ac9ff --- /dev/null +++ b/benches/ring_ntt.rs @@ -0,0 +1,68 @@ +#![allow(missing_docs)] + +use criterion::{black_box, criterion_group, criterion_main, Criterion}; +use hachi_pcs::algebra::ntt::butterfly::{forward_ntt, inverse_ntt, NttTwiddles}; +use hachi_pcs::algebra::tables::{q32_garner, Q32_MODULUS, Q32_NUM_PRIMES, Q32_PRIMES}; +use hachi_pcs::algebra::{CyclotomicCrtNtt, CyclotomicRing, Fp64, MontCoeff}; +use hachi_pcs::FromSmallInt; + +type F = Fp64<{ Q32_MODULUS }>; +type R = CyclotomicRing; +type N = CyclotomicCrtNtt; + +fn sample_ring(seed: u64) -> R { + let coeffs = std::array::from_fn(|i| { + let x = seed + .wrapping_mul(31) + .wrapping_add((i as u64).wrapping_mul(17)); + F::from_u64(x % Q32_MODULUS) + }); + R::from_coefficients(coeffs) +} + +fn bench_ring_schoolbook_mul(c: &mut Criterion) { + let lhs = sample_ring(3); + let rhs = sample_ring(11); + c.bench_function("ring_schoolbook_mul_d64", |b| { + b.iter(|| black_box(lhs) * black_box(rhs)) + }); +} + +fn bench_ntt_single_prime_round_trip(c: &mut Criterion) { + let prime = Q32_PRIMES[0]; + let tw = NttTwiddles::::compute(prime); + let base: [MontCoeff; 64] = + std::array::from_fn(|i| prime.from_canonical(((i * 5 + 7) as i16) % prime.p)); + + c.bench_function("ntt_single_prime_forward_inverse_d64", |b| { + b.iter(|| { + let mut a = base; + forward_ntt(&mut a, prime, &tw); + inverse_ntt(&mut a, prime, &tw); + black_box(a) + }) + }); +} + +fn bench_crt_round_trip(c: 
&mut Criterion) { + let ring = sample_ring(19); + let twiddles: [NttTwiddles; Q32_NUM_PRIMES] = + std::array::from_fn(|k| NttTwiddles::compute(Q32_PRIMES[k])); + let garner = q32_garner(); + + c.bench_function("ring_ntt_crt_round_trip_d64_k6", |b| { + b.iter(|| { + let ntt = N::from_ring(black_box(&ring), &Q32_PRIMES, &twiddles); + let back: R = ntt.to_ring(&Q32_PRIMES, &twiddles, &garner); + black_box(back) + }) + }); +} + +criterion_group!( + ring_ntt, + bench_ring_schoolbook_mul, + bench_ntt_single_prime_round_trip, + bench_crt_round_trip +); +criterion_main!(ring_ntt); diff --git a/examples/codegen_probe_special.rs b/examples/codegen_probe_special.rs new file mode 100644 index 00000000..bc6d5b50 --- /dev/null +++ b/examples/codegen_probe_special.rs @@ -0,0 +1,143 @@ +#![allow(missing_docs)] + +//! Codegen probe for packed/scalar Fp64 multiply kernels. +//! +//! Build with: +//! `cargo rustc --example codegen_probe_special --release -- --emit=asm` + +use hachi_pcs::algebra::fields::pseudo_mersenne::{POW2_OFFSET_MODULUS_40, POW2_OFFSET_MODULUS_64}; +use hachi_pcs::algebra::{Fp64, Fp64Packing, PackedValue}; +use hachi_pcs::CanonicalField; + +const MASK40: u64 = (1u64 << 40) - 1; +const P40: u64 = POW2_OFFSET_MODULUS_40; +const C40: u64 = (1u64 << 40) - P40; // 195 +const P64: u64 = POW2_OFFSET_MODULUS_64; +const C64: u64 = 0u64.wrapping_sub(P64); // 59 + +#[inline(always)] +fn mul_c40_split(x: u64) -> u64 { + let c = C40 as u32; + let x_lo = x as u32; + let x_hi = (x >> 32) as u32; + (c as u64 * x_lo as u64).wrapping_add((c as u64 * x_hi as u64) << 32) +} + +#[inline(always)] +fn mul_c40_shiftadd(x: u64) -> u64 { + // 195x = (128 + 64 + 2 + 1) * x + (x << 7) + .wrapping_add(x << 6) + .wrapping_add(x << 1) + .wrapping_add(x) +} + +#[inline(always)] +fn reduce40_with_mulc(lo: u64, hi: u64, mulc: fn(u64) -> u64) -> u64 { + let high = (lo >> 40) | (hi << 24); + let f1 = (lo & MASK40).wrapping_add(mulc(high)); + let f2 = (f1 & MASK40).wrapping_add(mulc(f1 >> 40)); + 
let reduced = f2.wrapping_sub(P40); + let borrow = reduced >> 63; + reduced.wrapping_add(borrow.wrapping_neg() & P40) +} + +#[inline(always)] +fn reduce64(lo: u64, hi: u64) -> u64 { + let f1 = (lo as u128) + (C64 as u128) * (hi as u128); + let f2 = (f1 as u64 as u128) + (C64 as u128) * ((f1 >> 64) as u64 as u128); + let reduced = f2.wrapping_sub(P64 as u128); + let borrow = reduced >> 127; + reduced.wrapping_add(borrow.wrapping_neg() & (P64 as u128)) as u64 +} + +#[inline(never)] +#[no_mangle] +pub extern "C" fn probe_reduce40_split(lo: u64, hi: u64) -> u64 { + reduce40_with_mulc(lo, hi, mul_c40_split) +} + +#[inline(never)] +#[no_mangle] +pub extern "C" fn probe_reduce40_shiftadd(lo: u64, hi: u64) -> u64 { + reduce40_with_mulc(lo, hi, mul_c40_shiftadd) +} + +#[inline(never)] +#[no_mangle] +pub extern "C" fn probe_reduce64(lo: u64, hi: u64) -> u64 { + reduce64(lo, hi) +} + +#[inline(never)] +#[no_mangle] +pub extern "C" fn probe_packed_fp64_40_mul(a0: u64, a1: u64, b0: u64, b1: u64) -> u64 { + type F = Fp64<{ POW2_OFFSET_MODULUS_40 }>; + type PF = Fp64Packing<{ POW2_OFFSET_MODULUS_40 }>; + + let a = PF::from_fn(|i| { + if i == 0 { + F::from_canonical_u64(a0) + } else { + F::from_canonical_u64(a1) + } + }); + let b = PF::from_fn(|i| { + if i == 0 { + F::from_canonical_u64(b0) + } else { + F::from_canonical_u64(b1) + } + }); + let c = a * b; + (c.extract(0).to_canonical_u128() as u64) ^ (c.extract(1).to_canonical_u128() as u64) +} + +#[inline(never)] +#[no_mangle] +pub extern "C" fn probe_packed_fp64_64_mul(a0: u64, a1: u64, b0: u64, b1: u64) -> u64 { + type F = Fp64<{ POW2_OFFSET_MODULUS_64 }>; + type PF = Fp64Packing<{ POW2_OFFSET_MODULUS_64 }>; + + let a = PF::from_fn(|i| { + if i == 0 { + F::from_canonical_u64(a0) + } else { + F::from_canonical_u64(a1) + } + }); + let b = PF::from_fn(|i| { + if i == 0 { + F::from_canonical_u64(b0) + } else { + F::from_canonical_u64(b1) + } + }); + let c = a * b; + (c.extract(0).to_canonical_u128() as u64) ^ 
(c.extract(1).to_canonical_u128() as u64) +} + +#[inline(never)] +#[no_mangle] +pub extern "C" fn probe_scalar_fp64_40_mul(a: u64, b: u64) -> u64 { + type F = Fp64<{ POW2_OFFSET_MODULUS_40 }>; + (F::from_canonical_u64(a) * F::from_canonical_u64(b)).to_canonical_u128() as u64 +} + +#[inline(never)] +#[no_mangle] +pub extern "C" fn probe_scalar_fp64_64_mul(a: u64, b: u64) -> u64 { + type F = Fp64<{ POW2_OFFSET_MODULUS_64 }>; + (F::from_canonical_u64(a) * F::from_canonical_u64(b)).to_canonical_u128() as u64 +} + +fn main() { + let x = probe_packed_fp64_40_mul(1, 2, 3, 4) + ^ probe_packed_fp64_64_mul(5, 6, 7, 8) + ^ probe_scalar_fp64_40_mul(9, 10) + ^ probe_scalar_fp64_64_mul(11, 12) + ^ probe_reduce40_split(13, 14) + ^ probe_reduce40_shiftadd(15, 16) + ^ probe_reduce64(17, 18); + std::hint::black_box(x); +} diff --git a/examples/profile.rs b/examples/profile.rs new file mode 100644 index 00000000..6fef67f4 --- /dev/null +++ b/examples/profile.rs @@ -0,0 +1,272 @@ +#![allow(missing_docs)] + +use hachi_pcs::algebra::poly::multilinear_eval; +use hachi_pcs::algebra::Fp128; +use hachi_pcs::primitives::serialization::Compress; +use hachi_pcs::protocol::commitment::{ + Fp128FullCommitmentConfig, Fp128LogBasisCommitmentConfig, Fp128OneHotCommitmentConfig, + HachiCommitmentLayout, +}; +use hachi_pcs::protocol::commitment_scheme::HachiCommitmentScheme; +use hachi_pcs::protocol::hachi_poly_ops::{DensePoly, OneHotPoly}; +use hachi_pcs::protocol::proof::HachiProof; +use hachi_pcs::protocol::transcript::Blake2bTranscript; +use hachi_pcs::protocol::CommitmentConfig; +use hachi_pcs::{ + BasisMode, CanonicalField, CommitmentScheme, FromSmallInt, HachiPolyOps, HachiSerialize, + Transcript, +}; +use rand::rngs::StdRng; +use rand::{Rng, SeedableRng}; +use std::env; +use std::fs; +use std::time::{Instant, SystemTime, UNIX_EPOCH}; +use tracing_chrome::ChromeLayerBuilder; +use tracing_subscriber::prelude::*; + +type F = Fp128<0xfffffffffffffffffffffffffffffeed>; + +fn run_prove>( + label: 
&str, + setup: & as CommitmentScheme>::ProverSetup, + poly: &P, + pt: &[F], + opening: F, + layout: &HachiCommitmentLayout, +) { + type Scheme = HachiCommitmentScheme; + + let t0 = Instant::now(); + let (commitment, hint) = + as CommitmentScheme>::commit(poly, setup, layout).unwrap(); + eprintln!("[{label}] commit: {:.3}s", t0.elapsed().as_secs_f64()); + + let t0 = Instant::now(); + let mut prover_transcript = Blake2bTranscript::::new(b"profile"); + let proof = as CommitmentScheme>::prove( + setup, + poly, + pt, + hint, + &mut prover_transcript, + &commitment, + BasisMode::Lagrange, + layout, + ) + .unwrap(); + eprintln!("[{label}] prove: {:.3}s", t0.elapsed().as_secs_f64()); + print_proof_summary(label, &proof); + + let t0 = Instant::now(); + let verifier_setup = as CommitmentScheme>::setup_verifier(setup); + let mut verifier_transcript = Blake2bTranscript::::new(b"profile"); + match as CommitmentScheme>::verify( + &proof, + &verifier_setup, + &mut verifier_transcript, + pt, + &opening, + &commitment, + BasisMode::Lagrange, + layout, + ) { + Ok(()) => eprintln!("[{label}] verify: {:.3}s OK", t0.elapsed().as_secs_f64()), + Err(e) => eprintln!( + "[{label}] verify: {:.3}s FAILED ({e})", + t0.elapsed().as_secs_f64() + ), + } +} + +fn print_proof_summary(label: &str, proof: &HachiProof) { + eprintln!( + "[{label}] levels: {}, proof size: {} bytes", + proof.levels.len(), + proof.size() + ); + for (i, lp) in proof.levels.iter().enumerate() { + let w_comm_size = lp.w_commitment.serialized_size(Compress::No); + let sc_size = lp.sumcheck_proof.serialized_size(Compress::No); + eprintln!( + "[{label}] L{i}: w_commitment={} ring elems (D={}, {} bytes), sumcheck={} bytes", + lp.w_commitment.count(), + lp.w_commit_d(), + w_comm_size, + sc_size, + ); + } + eprintln!( + "[{label}] final_w: {} elems, {} bits/elem, packed {} bytes", + proof.final_w.num_elems, + proof.final_w.bits_per_elem, + proof.final_w.serialized_size(Compress::No), + ); +} + +fn print_layout(layout: 
&HachiCommitmentLayout) { + eprintln!( + " layout: m_vars={}, r_vars={}, num_blocks={}, block_len={}, \ + delta_commit={}, delta_open={}, delta_fold={}, log_basis={}", + layout.m_vars, + layout.r_vars, + layout.num_blocks, + layout.block_len, + layout.num_digits_commit, + layout.num_digits_open, + layout.num_digits_fold, + layout.log_basis, + ); +} + +fn run_dense(nv: usize, layout: &HachiCommitmentLayout) { + let mut rng = StdRng::seed_from_u64(0xbeef_cafe); + let len = 1usize << nv; + let decomp = Cfg::decomposition(); + let half_bound = 1i64 << (decomp.log_commit_bound.min(62) - 1); + let evals: Vec = if decomp.log_commit_bound >= 128 { + (0..len) + .map(|_| F::from_canonical_u128_reduced(rng.gen::())) + .collect() + } else { + (0..len) + .map(|_| F::from_i64(rng.gen_range(-half_bound..half_bound))) + .collect() + }; + let poly = DensePoly::::from_field_evals(nv, &evals).unwrap(); + let pt: Vec = (0..nv) + .map(|_| F::from_canonical_u128_reduced(rng.gen::())) + .collect(); + let opening = multilinear_eval(&evals, &pt).unwrap(); + + let t0 = Instant::now(); + let setup = as CommitmentScheme>::setup_prover(nv); + eprintln!(" setup: {:.3}s", t0.elapsed().as_secs_f64()); + + run_prove::("dense", &setup, &poly, &pt, opening, layout); +} + +fn run_onehot(nv: usize, layout: &HachiCommitmentLayout) { + let mut rng = StdRng::seed_from_u64(0xbeef_cafe); + let total_ring = layout.num_blocks * layout.block_len; + let onehot_k = D; + + let indices: Vec> = (0..total_ring) + .map(|_| Some(rng.gen_range(0..onehot_k))) + .collect(); + let onehot_poly = + OneHotPoly::::new(onehot_k, indices.clone(), layout.r_vars, layout.m_vars).unwrap(); + + let onehot_evals: Vec = { + let mut evals = vec![F::from_u64(0); total_ring * onehot_k]; + for (ci, opt_idx) in indices.iter().enumerate() { + if let Some(idx) = opt_idx { + evals[ci * onehot_k + idx] = F::from_u64(1); + } + } + evals + }; + let pt: Vec = (0..nv) + .map(|_| F::from_canonical_u128_reduced(rng.gen::())) + .collect(); + let 
opening = multilinear_eval(&onehot_evals, &pt).unwrap(); + + let t0 = Instant::now(); + let setup = as CommitmentScheme>::setup_prover(nv); + eprintln!(" setup: {:.3}s", t0.elapsed().as_secs_f64()); + + run_prove::("onehot", &setup, &onehot_poly, &pt, opening, layout); +} + +fn main() { + #[cfg(feature = "parallel")] + rayon::ThreadPoolBuilder::new() + .stack_size(64 * 1024 * 1024) + .build_global() + .ok(); + + let trace_dir = "profile_traces"; + fs::create_dir_all(trace_dir).ok(); + + let nv: usize = env::var("HACHI_NUM_VARS") + .ok() + .and_then(|s| s.parse().ok()) + .unwrap_or(25); + + let mode = env::var("HACHI_MODE").unwrap_or_else(|_| "full".to_string()); + + let timestamp = SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap() + .as_secs(); + let trace_file = format!("{trace_dir}/hachi_nv{nv}_{mode}_{timestamp}.json"); + + let (chrome_layer, _guard) = ChromeLayerBuilder::new() + .include_args(true) + .file(&trace_file) + .build(); + + tracing_subscriber::registry().with(chrome_layer).init(); + + eprintln!("Perfetto trace: {trace_file}"); + eprintln!("num_vars={nv}, mode={mode}"); + eprintln!(); + + match mode.as_str() { + "full" => { + type Cfg = Fp128FullCommitmentConfig; + let layout = resolve_layout::(nv); + eprintln!("=== full (dense, log_commit_bound=128) ==="); + print_layout(&layout); + run_dense::<{ Fp128FullCommitmentConfig::D }, Cfg>(nv, &layout); + } + "onehot" => { + type Cfg = Fp128OneHotCommitmentConfig; + let layout = resolve_layout::(nv); + eprintln!("=== onehot (log_commit_bound=1) ==="); + print_layout(&layout); + run_onehot::<{ Fp128OneHotCommitmentConfig::D }, Cfg>(nv, &layout); + } + "logbasis" => { + type Cfg = Fp128LogBasisCommitmentConfig; + let layout = resolve_layout::(nv); + eprintln!("=== logbasis (dense, log_commit_bound=3) ==="); + print_layout(&layout); + run_dense::<{ Fp128LogBasisCommitmentConfig::D }, Cfg>(nv, &layout); + } + "all" => { + { + type Cfg = Fp128FullCommitmentConfig; + let layout = resolve_layout::(nv); + 
eprintln!("=== full (dense, log_commit_bound=128) ==="); + print_layout(&layout); + run_dense::<{ Fp128FullCommitmentConfig::D }, Cfg>(nv, &layout); + eprintln!(); + } + { + type Cfg = Fp128OneHotCommitmentConfig; + let layout = resolve_layout::(nv); + eprintln!("=== onehot (log_commit_bound=1) ==="); + print_layout(&layout); + run_onehot::<{ Fp128OneHotCommitmentConfig::D }, Cfg>(nv, &layout); + eprintln!(); + } + { + type Cfg = Fp128LogBasisCommitmentConfig; + let layout = resolve_layout::(nv); + eprintln!("=== logbasis (dense, log_commit_bound=3) ==="); + print_layout(&layout); + run_dense::<{ Fp128LogBasisCommitmentConfig::D }, Cfg>(nv, &layout); + } + } + other => { + eprintln!("Unknown HACHI_MODE={other}. Use: full, onehot, logbasis, all"); + std::process::exit(1); + } + } + + eprintln!("\nDone. Trace saved to {trace_file}"); +} + +fn resolve_layout(nv: usize) -> HachiCommitmentLayout { + Cfg::commitment_layout(nv).expect("layout") +} diff --git a/paper/hachi.pdf b/paper/hachi.pdf deleted file mode 100644 index 33354fc2..00000000 Binary files a/paper/hachi.pdf and /dev/null differ diff --git a/rust-toolchain.toml b/rust-toolchain.toml new file mode 100644 index 00000000..c5b9f7f3 --- /dev/null +++ b/rust-toolchain.toml @@ -0,0 +1,4 @@ +[toolchain] +channel = "1.88" +profile = "minimal" +components = ["cargo", "rustc", "clippy", "rustfmt"] diff --git a/rustfmt.toml b/rustfmt.toml new file mode 100644 index 00000000..541c50b7 --- /dev/null +++ b/rustfmt.toml @@ -0,0 +1,5 @@ +edition = "2021" +tab_spaces = 4 +newline_style = "Unix" +use_try_shorthand = true +use_field_init_shorthand = true diff --git a/src/algebra/backend/mod.rs b/src/algebra/backend/mod.rs new file mode 100644 index 00000000..3b885ce1 --- /dev/null +++ b/src/algebra/backend/mod.rs @@ -0,0 +1,7 @@ +//! Backend contracts and concrete backend implementations. 
+ +pub mod scalar; +pub mod traits; + +pub use scalar::ScalarBackend; +pub use traits::{CrtReconstruct, NttPrimeOps, NttTransform, RingBackend}; diff --git a/src/algebra/backend/scalar.rs b/src/algebra/backend/scalar.rs new file mode 100644 index 00000000..a1386d72 --- /dev/null +++ b/src/algebra/backend/scalar.rs @@ -0,0 +1,94 @@ +//! Default scalar backend: delegates to NTT kernels and uses Garner's +//! algorithm for CRT reconstruction. + +use super::traits::{CrtReconstruct, NttPrimeOps, NttTransform}; +use crate::algebra::ntt::butterfly::{forward_ntt, inverse_ntt, NttTwiddles}; +use crate::algebra::ntt::crt::GarnerData; +use crate::algebra::ntt::prime::{MontCoeff, NttPrime, PrimeWidth}; +use crate::algebra::ring::CrtNttConvertibleField; + +/// Default scalar backend implementation. +#[derive(Debug, Clone, Copy, Default)] +pub struct ScalarBackend; + +impl NttPrimeOps for ScalarBackend { + #[inline] + fn from_canonical(prime: NttPrime, value: W) -> MontCoeff { + prime.from_canonical(value) + } + + #[inline] + fn to_canonical(prime: NttPrime, value: MontCoeff) -> W { + prime.to_canonical(value) + } + + #[inline] + fn reduce_range(prime: NttPrime, value: MontCoeff) -> MontCoeff { + prime.reduce_range(value) + } + + #[inline] + fn pointwise_mul( + prime: NttPrime, + out: &mut [MontCoeff; D], + lhs: &[MontCoeff; D], + rhs: &[MontCoeff; D], + ) { + prime.pointwise_mul(out, lhs, rhs); + } +} + +impl NttTransform for ScalarBackend { + #[inline] + fn forward_ntt(limb: &mut [MontCoeff; D], prime: NttPrime, twiddles: &NttTwiddles) { + forward_ntt(limb, prime, twiddles); + } + + #[inline] + fn inverse_ntt(limb: &mut [MontCoeff; D], prime: NttPrime, twiddles: &NttTwiddles) { + inverse_ntt(limb, prime, twiddles); + } +} + +impl CrtReconstruct for ScalarBackend { + fn reconstruct( + primes: &[NttPrime; K], + canonical: &[[W; D]; K], + garner: &GarnerData, + ) -> [F; D] { + let mut coeffs = [F::zero(); D]; + for (d, coeff) in coeffs.iter_mut().enumerate() { + // Garner 
mixed-radix decomposition (all arithmetic in i64, mod p_i). + let mut v = [0i64; K]; + v[0] = canonical[0][d].to_i64(); + for i in 1..K { + let pi = primes[i].p.to_i64(); + let mut temp = canonical[i][d].to_i64(); + #[allow(clippy::needless_range_loop)] + for j in 0..i { + temp -= v[j]; + temp = ((temp % pi) + pi) % pi; + temp = (temp * garner.gamma[i][j].to_i64()) % pi; + } + // Center the mixed-radix digit to keep the final reconstruction + // in a small signed range when inputs are centered. + if temp > pi / 2 { + temp -= pi; + } + v[i] = temp; + } + + // Horner accumulation in the target field F. + let mut result = F::from_i64(v[0]); + let mut partial_prod = F::from_i64(primes[0].p.to_i64()); + for i in 1..K { + result += F::from_i64(v[i]) * partial_prod; + if i + 1 < K { + partial_prod = partial_prod * F::from_i64(primes[i].p.to_i64()); + } + } + *coeff = result; + } + coeffs + } +} diff --git a/src/algebra/backend/traits.rs b/src/algebra/backend/traits.rs new file mode 100644 index 00000000..118a5f1f --- /dev/null +++ b/src/algebra/backend/traits.rs @@ -0,0 +1,59 @@ +//! Backend traits for CRT+NTT execution semantics. +//! +//! All traits are generic over `W: PrimeWidth` to support both +//! `i16` (primes < 2^14) and `i32` (primes < 2^30) NTT backends. + +use crate::algebra::ntt::butterfly::NttTwiddles; +use crate::algebra::ntt::crt::GarnerData; +use crate::algebra::ntt::prime::{MontCoeff, NttPrime, PrimeWidth}; +use crate::algebra::ring::CrtNttConvertibleField; + +/// Per-prime arithmetic primitives used by CRT+NTT domains. +pub trait NttPrimeOps { + /// Convert canonical coefficient to backend prime representation. + fn from_canonical(prime: NttPrime, value: W) -> MontCoeff; + + /// Convert backend prime representation back to canonical coefficient. + fn to_canonical(prime: NttPrime, value: MontCoeff) -> W; + + /// Range-reduce one coefficient from `(-2p, 2p)` to `(-p, p)`. 
+ fn reduce_range(prime: NttPrime, value: MontCoeff) -> MontCoeff; + + /// Pointwise multiplication in backend prime representation. + fn pointwise_mul( + prime: NttPrime, + out: &mut [MontCoeff; D], + lhs: &[MontCoeff; D], + rhs: &[MontCoeff; D], + ); +} + +/// Forward/inverse transform kernels for one NTT limb. +pub trait NttTransform { + /// Forward transform from coefficient limb to NTT limb. + fn forward_ntt(limb: &mut [MontCoeff; D], prime: NttPrime, twiddles: &NttTwiddles); + + /// Inverse transform from NTT limb to coefficient limb. + fn inverse_ntt(limb: &mut [MontCoeff; D], prime: NttPrime, twiddles: &NttTwiddles); +} + +/// CRT reconstruction from per-prime canonical coefficients via Garner's algorithm. +pub trait CrtReconstruct { + /// Reconstruct coefficient-domain values from canonical CRT residues. + fn reconstruct( + primes: &[NttPrime; K], + canonical_limbs: &[[W; D]; K], + garner: &GarnerData, + ) -> [F; D]; +} + +/// Convenience composition trait for full ring backend capability. +pub trait RingBackend: + NttPrimeOps + NttTransform + CrtReconstruct +{ +} + +impl RingBackend for T where + T: NttPrimeOps + NttTransform + CrtReconstruct +{ +} diff --git a/src/algebra/fields/ext.rs b/src/algebra/fields/ext.rs new file mode 100644 index 00000000..b46cbbf9 --- /dev/null +++ b/src/algebra/fields/ext.rs @@ -0,0 +1,789 @@ +//! Quadratic and quartic extension fields. + +use super::wide::{AccumPair, HasUnreducedOps}; +use crate::algebra::module::VectorModule; +use crate::primitives::serialization::{ + Compress, HachiDeserialize, HachiSerialize, SerializationError, Valid, Validate, +}; +use crate::{AdditiveGroup, FieldCore, FieldSampling, FromSmallInt}; + +/// `Fp2Config` with non-residue = -1. +/// +/// Valid when `p ≡ 3 (mod 4)`, i.e. -1 is a quadratic non-residue. +pub struct NegOneNr; + +impl Fp2Config for NegOneNr { + const IS_NEG_ONE: bool = true; + + fn non_residue() -> F { + -F::one() + } +} + +/// `Fp2Config` with non-residue = 2. 
+/// +/// Valid when `p ≡ 5 (mod 8)`, i.e. 2 is a quadratic non-residue. +/// All Hachi pseudo-Mersenne primes (`2^k - c` with `c ≡ 3 mod 8`) +/// satisfy this. +pub struct TwoNr; + +impl Fp2Config for TwoNr { + fn non_residue() -> F { + F::from_u64(2) + } +} +use rand_core::RngCore; +use std::io::{Read, Write}; +use std::marker::PhantomData; +use std::ops::{Add, AddAssign, Mul, Neg, Sub, SubAssign}; + +/// Parameters for an `Fp2` quadratic extension over base field `F`. +pub trait Fp2Config { + /// Whether the non-residue is -1. + /// + /// When `true`, multiplication by the non-residue is a free negation and + /// the Karatsuba/squaring routines can avoid a base-field multiply. + const IS_NEG_ONE: bool = false; + + /// Non-residue `NR` such that `u^2 = NR`. + fn non_residue() -> F; +} + +/// Quadratic extension element `c0 + c1 * u` with `u^2 = NR`. +pub struct Fp2> { + /// Constant term. + pub c0: F, + /// Coefficient of `u`. + pub c1: F, + _cfg: PhantomData C>, +} + +impl> Fp2 { + /// Construct `c0 + c1 * u`. + #[inline] + pub fn new(c0: F, c1: F) -> Self { + Self { + c0, + c1, + _cfg: PhantomData, + } + } + + /// Multiply a base-field element by the non-residue. + /// + /// When `IS_NEG_ONE` is true this is just a negation (no multiply). + #[inline(always)] + fn mul_nr(x: F) -> F { + if C::IS_NEG_ONE { + -x + } else { + C::non_residue() * x + } + } + + /// Return the conjugate `c0 - c1 * u`. + #[inline] + pub fn conjugate(self) -> Self { + Self::new(self.c0, -self.c1) + } + + /// Return the norm in the base field: `c0^2 - NR * c1^2`. 
+ #[inline] + pub fn norm(self) -> F { + (self.c0 * self.c0) - Self::mul_nr(self.c1 * self.c1) + } +} + +impl> std::fmt::Debug for Fp2 { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("Fp2") + .field("c0", &self.c0) + .field("c1", &self.c1) + .finish() + } +} + +impl> Clone for Fp2 { + fn clone(&self) -> Self { + *self + } +} + +impl> Copy for Fp2 {} + +impl> Default for Fp2 { + fn default() -> Self { + Self::new(F::zero(), F::zero()) + } +} + +impl> PartialEq for Fp2 { + fn eq(&self, other: &Self) -> bool { + self.c0 == other.c0 && self.c1 == other.c1 + } +} + +impl> Eq for Fp2 {} + +impl> Add for Fp2 { + type Output = Self; + fn add(self, rhs: Self) -> Self::Output { + Self::new(self.c0 + rhs.c0, self.c1 + rhs.c1) + } +} +impl> Sub for Fp2 { + type Output = Self; + fn sub(self, rhs: Self) -> Self::Output { + Self::new(self.c0 - rhs.c0, self.c1 - rhs.c1) + } +} +impl> Neg for Fp2 { + type Output = Self; + fn neg(self) -> Self::Output { + Self::new(-self.c0, -self.c1) + } +} +impl> AddAssign for Fp2 { + #[inline] + fn add_assign(&mut self, rhs: Self) { + *self = *self + rhs; + } +} +impl> SubAssign for Fp2 { + #[inline] + fn sub_assign(&mut self, rhs: Self) { + *self = *self - rhs; + } +} +impl> Mul for Fp2 { + type Output = Self; + fn mul(self, rhs: Self) -> Self::Output { + let v0 = self.c0 * rhs.c0; + let v1 = self.c1 * rhs.c1; + Self::new( + v0 + Self::mul_nr(v1), + (self.c0 + self.c1) * (rhs.c0 + rhs.c1) - v0 - v1, + ) + } +} + +impl<'a, F: FieldCore, C: Fp2Config> Add<&'a Self> for Fp2 { + type Output = Self; + fn add(self, rhs: &'a Self) -> Self::Output { + self + *rhs + } +} +impl<'a, F: FieldCore, C: Fp2Config> Sub<&'a Self> for Fp2 { + type Output = Self; + fn sub(self, rhs: &'a Self) -> Self::Output { + self - *rhs + } +} +impl<'a, F: FieldCore, C: Fp2Config> Mul<&'a Self> for Fp2 { + type Output = Self; + fn mul(self, rhs: &'a Self) -> Self::Output { + self * *rhs + } +} + +impl> Valid for Fp2 { + fn check(&self) 
-> Result<(), SerializationError> { + self.c0.check()?; + self.c1.check()?; + Ok(()) + } +} + +impl> HachiSerialize for Fp2 { + fn serialize_with_mode( + &self, + mut writer: W, + compress: Compress, + ) -> Result<(), SerializationError> { + self.c0.serialize_with_mode(&mut writer, compress)?; + self.c1.serialize_with_mode(&mut writer, compress)?; + Ok(()) + } + + fn serialized_size(&self, compress: Compress) -> usize { + self.c0.serialized_size(compress) + self.c1.serialized_size(compress) + } +} + +impl> HachiDeserialize for Fp2 { + fn deserialize_with_mode( + mut reader: R, + compress: Compress, + validate: Validate, + ) -> Result { + let c0 = F::deserialize_with_mode(&mut reader, compress, validate)?; + let c1 = F::deserialize_with_mode(&mut reader, compress, validate)?; + let out = Self::new(c0, c1); + if matches!(validate, Validate::Yes) { + out.check()?; + } + Ok(out) + } +} + +impl> AdditiveGroup for Fp2 { + const ZERO: Self = Self { + c0: F::ZERO, + c1: F::ZERO, + _cfg: PhantomData, + }; +} + +impl> FieldCore for Fp2 { + fn one() -> Self { + Self::new(F::one(), F::zero()) + } + + fn is_zero(&self) -> bool { + self.c0.is_zero() && self.c1.is_zero() + } + + /// Specialized squaring: 2 base-field multiplications instead of 3. 
+ /// + /// `(c0 + c1·u)^2 = (c0^2 + NR·c1^2) + (2·c0·c1)·u` + fn square(&self) -> Self { + let v0 = self.c0 * self.c0; + let v1 = self.c1 * self.c1; + Self::new(v0 + Self::mul_nr(v1), (self.c0 + self.c0) * self.c1) + } + + fn inv(self) -> Option { + if self.is_zero() { + return None; + } + let inv_n = self.norm().inv()?; + Some(Self::new(self.c0 * inv_n, (-self.c1) * inv_n)) + } + + const TWO_INV: Self = Self { + c0: F::TWO_INV, + c1: F::ZERO, + _cfg: PhantomData, + }; +} + +impl> FieldSampling for Fp2 { + fn sample(rng: &mut R) -> Self { + Self::new(F::sample(rng), F::sample(rng)) + } +} + +impl> FromSmallInt for Fp2 { + fn from_u64(val: u64) -> Self { + Self::new(F::from_u64(val), F::zero()) + } + + fn from_i64(val: i64) -> Self { + Self::new(F::from_i64(val), F::zero()) + } +} + +impl> HasUnreducedOps for Fp2 { + type MulU64Accum = AccumPair; + type ProductAccum = AccumPair; + + #[inline] + fn mul_u64_unreduced(self, small: u64) -> AccumPair { + AccumPair( + self.c0.mul_u64_unreduced(small), + self.c1.mul_u64_unreduced(small), + ) + } + + #[inline] + fn mul_to_product_accum(self, other: Self) -> AccumPair { + // Karatsuba: (c0 + c1·u)(d0 + d1·u) = (c0·d0 + NR·c1·d1) + (c0·d1 + c1·d0)·u + let v0 = self.c0.mul_to_product_accum(other.c0); + let v1 = self.c1.mul_to_product_accum(other.c1); + let cross = (self.c0 + self.c1).mul_to_product_accum(other.c0 + other.c1); + + let nr_v1 = if C::IS_NEG_ONE { -v1 } else { v1 + v1 }; + AccumPair(v0 + nr_v1, cross - v0 - v1) + } + + #[inline] + fn reduce_mul_u64_accum(accum: AccumPair) -> Self { + Self::new( + F::reduce_mul_u64_accum(accum.0), + F::reduce_mul_u64_accum(accum.1), + ) + } + + #[inline] + fn reduce_product_accum(accum: AccumPair) -> Self { + Self::new( + F::reduce_product_accum(accum.0), + F::reduce_product_accum(accum.1), + ) + } +} + +/// Parameters for an `Fp4` quadratic extension over `Fp2`. +pub trait Fp4Config> { + /// Non-residue `NR2` in `Fp2` such that `v^2 = NR2`. 
+ fn non_residue() -> Fp2; +} + +/// `Fp4Config` with non-residue `u ∈ Fp2` (the element `(0, 1)`). +/// +/// This is the standard tower choice: `Fp4 = Fp2[v] / (v^2 - u)`. +pub struct UnitNr; + +impl> Fp4Config for UnitNr { + fn non_residue() -> Fp2 { + Fp2::new(F::zero(), F::one()) + } +} + +/// Quartic extension element `c0 + c1 * v` over `Fp2`, where `v^2 = NR2`. +pub struct Fp4, C4: Fp4Config> { + /// Constant term. + pub c0: Fp2, + /// Coefficient of `v`. + pub c1: Fp2, + _cfg: PhantomData C4>, +} + +impl, C4: Fp4Config> Fp4 { + /// Construct `c0 + c1 * v`. + #[inline] + pub fn new(c0: Fp2, c1: Fp2) -> Self { + Self { + c0, + c1, + _cfg: PhantomData, + } + } + + /// Return the norm in `Fp2`: `c0^2 - NR2 * c1^2`. + #[inline] + pub fn norm(self) -> Fp2 { + let nr2 = C4::non_residue(); + (self.c0 * self.c0) - (nr2 * (self.c1 * self.c1)) + } +} + +impl, C4: Fp4Config> std::fmt::Debug + for Fp4 +{ + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("Fp4") + .field("c0", &self.c0) + .field("c1", &self.c1) + .finish() + } +} + +impl, C4: Fp4Config> Clone for Fp4 { + fn clone(&self) -> Self { + *self + } +} + +impl, C4: Fp4Config> Copy for Fp4 {} + +impl, C4: Fp4Config> Default for Fp4 { + fn default() -> Self { + Self::new( + Fp2::new(F::zero(), F::zero()), + Fp2::new(F::zero(), F::zero()), + ) + } +} + +impl, C4: Fp4Config> PartialEq for Fp4 { + fn eq(&self, other: &Self) -> bool { + self.c0 == other.c0 && self.c1 == other.c1 + } +} + +impl, C4: Fp4Config> Eq for Fp4 {} + +impl, C4: Fp4Config> Add for Fp4 { + type Output = Self; + fn add(self, rhs: Self) -> Self::Output { + Self::new(self.c0 + rhs.c0, self.c1 + rhs.c1) + } +} +impl, C4: Fp4Config> Sub for Fp4 { + type Output = Self; + fn sub(self, rhs: Self) -> Self::Output { + Self::new(self.c0 - rhs.c0, self.c1 - rhs.c1) + } +} +impl, C4: Fp4Config> Neg for Fp4 { + type Output = Self; + fn neg(self) -> Self::Output { + Self::new(-self.c0, -self.c1) + } +} +impl, C4: Fp4Config> 
AddAssign for Fp4 { + #[inline] + fn add_assign(&mut self, rhs: Self) { + *self = *self + rhs; + } +} +impl, C4: Fp4Config> SubAssign for Fp4 { + #[inline] + fn sub_assign(&mut self, rhs: Self) { + *self = *self - rhs; + } +} +impl, C4: Fp4Config> Mul for Fp4 { + type Output = Self; + fn mul(self, rhs: Self) -> Self::Output { + let nr2 = C4::non_residue(); + let v0 = self.c0 * rhs.c0; + let v1 = self.c1 * rhs.c1; + Self::new( + v0 + (nr2 * v1), + (self.c0 + self.c1) * (rhs.c0 + rhs.c1) - v0 - v1, + ) + } +} + +impl<'a, F: FieldCore, C2: Fp2Config, C4: Fp4Config> Add<&'a Self> for Fp4 { + type Output = Self; + fn add(self, rhs: &'a Self) -> Self::Output { + self + *rhs + } +} +impl<'a, F: FieldCore, C2: Fp2Config, C4: Fp4Config> Sub<&'a Self> for Fp4 { + type Output = Self; + fn sub(self, rhs: &'a Self) -> Self::Output { + self - *rhs + } +} +impl<'a, F: FieldCore, C2: Fp2Config, C4: Fp4Config> Mul<&'a Self> for Fp4 { + type Output = Self; + fn mul(self, rhs: &'a Self) -> Self::Output { + self * *rhs + } +} + +impl, C4: Fp4Config> Valid for Fp4 { + fn check(&self) -> Result<(), SerializationError> { + self.c0.check()?; + self.c1.check()?; + Ok(()) + } +} + +impl, C4: Fp4Config> HachiSerialize for Fp4 { + fn serialize_with_mode( + &self, + mut writer: W, + compress: Compress, + ) -> Result<(), SerializationError> { + self.c0.serialize_with_mode(&mut writer, compress)?; + self.c1.serialize_with_mode(&mut writer, compress)?; + Ok(()) + } + + fn serialized_size(&self, compress: Compress) -> usize { + self.c0.serialized_size(compress) + self.c1.serialized_size(compress) + } +} + +impl, C4: Fp4Config> HachiDeserialize + for Fp4 +{ + fn deserialize_with_mode( + mut reader: R, + compress: Compress, + validate: Validate, + ) -> Result { + let c0 = Fp2::::deserialize_with_mode(&mut reader, compress, validate)?; + let c1 = Fp2::::deserialize_with_mode(&mut reader, compress, validate)?; + let out = Self::new(c0, c1); + if matches!(validate, Validate::Yes) { + out.check()?; + } 
+ Ok(out) + } +} + +impl, C4: Fp4Config> AdditiveGroup for Fp4 { + const ZERO: Self = Self { + c0: Fp2::ZERO, + c1: Fp2::ZERO, + _cfg: PhantomData, + }; +} + +impl, C4: Fp4Config> FieldCore for Fp4 { + fn one() -> Self { + Self::new(Fp2::one(), Fp2::zero()) + } + + fn is_zero(&self) -> bool { + self.c0.is_zero() && self.c1.is_zero() + } + + fn square(&self) -> Self { + let nr2 = C4::non_residue(); + let v0 = self.c0.square(); + let v1 = self.c1.square(); + Self::new(v0 + nr2 * v1, (self.c0 + self.c0) * self.c1) + } + + fn inv(self) -> Option { + if self.is_zero() { + return None; + } + let inv_n = self.norm().inv()?; + Some(Self::new(self.c0 * inv_n, (-self.c1) * inv_n)) + } + + const TWO_INV: Self = Self { + c0: Fp2::TWO_INV, + c1: Fp2::ZERO, + _cfg: PhantomData, + }; +} + +impl, C4: Fp4Config> FieldSampling + for Fp4 +{ + fn sample(rng: &mut R) -> Self { + Self::new(Fp2::sample(rng), Fp2::sample(rng)) + } +} + +impl, C4: Fp4Config> FromSmallInt + for Fp4 +{ + fn from_u64(val: u64) -> Self { + Self::new(Fp2::from_u64(val), Fp2::zero()) + } + + fn from_i64(val: i64) -> Self { + Self::new(Fp2::from_i64(val), Fp2::zero()) + } +} + +// Scalar * VectorModule impls for extension scalars. 
+ +impl Mul, N>> for Fp2 +where + F: FieldCore + Valid, + C: Fp2Config, +{ + type Output = VectorModule, N>; + fn mul(self, rhs: VectorModule, N>) -> Self::Output { + let mut out = rhs.0; + for coeff in &mut out { + *coeff = self * *coeff; + } + VectorModule(out) + } +} + +impl<'a, F, C, const N: usize> Mul<&'a VectorModule, N>> for Fp2 +where + F: FieldCore + Valid, + C: Fp2Config, +{ + type Output = VectorModule, N>; + fn mul(self, rhs: &'a VectorModule, N>) -> Self::Output { + self * *rhs + } +} + +impl Mul, N>> for Fp4 +where + F: FieldCore + Valid, + C2: Fp2Config, + C4: Fp4Config, +{ + type Output = VectorModule, N>; + fn mul(self, rhs: VectorModule, N>) -> Self::Output { + let mut out = rhs.0; + for coeff in &mut out { + *coeff = self * *coeff; + } + VectorModule(out) + } +} + +impl<'a, F, C2, C4, const N: usize> Mul<&'a VectorModule, N>> for Fp4 +where + F: FieldCore + Valid, + C2: Fp2Config, + C4: Fp4Config, +{ + type Output = VectorModule, N>; + fn mul(self, rhs: &'a VectorModule, N>) -> Self::Output { + self * *rhs + } +} + +// Convenience aliases for extension fields with NR = 2 (valid for all Hachi +// pseudo-Mersenne primes where p ≡ 5 mod 8). + +/// Quadratic extension over any `F` with non-residue 2. +pub type Ext2 = Fp2; + +/// Quartic extension as tower `Ext2[v]/(v^2 - u)`. 
+pub type Ext4 = Fp4; + +#[cfg(test)] +mod tests { + use super::*; + use crate::algebra::fields::lift::ExtField; + use crate::algebra::Fp64; + use crate::{FieldCore, FieldSampling, FromSmallInt}; + use rand::rngs::StdRng; + use rand::SeedableRng; + + type F = Fp64<4294967197>; + type E2 = Ext2; + type E4 = Ext4; + + #[test] + fn fp2_add_sub_identity() { + let a = E2::new(F::from_u64(3), F::from_u64(5)); + let b = E2::new(F::from_u64(7), F::from_u64(11)); + let c = a + b; + assert_eq!(c - b, a); + assert_eq!(c - a, b); + } + + #[test] + fn fp2_mul_one() { + let a = E2::new(F::from_u64(42), F::from_u64(13)); + assert_eq!(a * E2::one(), a); + assert_eq!(E2::one() * a, a); + } + + #[test] + fn fp2_mul_commutativity() { + let mut rng = StdRng::seed_from_u64(1234); + let a = E2::sample(&mut rng); + let b = E2::sample(&mut rng); + assert_eq!(a * b, b * a); + } + + #[test] + fn fp2_karatsuba_matches_schoolbook() { + let mut rng = StdRng::seed_from_u64(5678); + for _ in 0..100 { + let a = E2::sample(&mut rng); + let b = E2::sample(&mut rng); + let nr = >::non_residue(); + let expected = E2::new( + (a.c0 * b.c0) + (nr * (a.c1 * b.c1)), + (a.c0 * b.c1) + (a.c1 * b.c0), + ); + assert_eq!(a * b, expected); + } + } + + #[test] + fn fp2_square_matches_mul() { + let mut rng = StdRng::seed_from_u64(9012); + for _ in 0..100 { + let a = E2::sample(&mut rng); + assert_eq!(a.square(), a * a, "square mismatch for {a:?}"); + } + } + + #[test] + fn fp2_inv() { + let mut rng = StdRng::seed_from_u64(3456); + for _ in 0..50 { + let a = E2::sample(&mut rng); + if !a.is_zero() { + let inv = a.inv().unwrap(); + assert_eq!(a * inv, E2::one()); + } + } + } + + #[test] + fn fp4_mul_commutativity() { + let mut rng = StdRng::seed_from_u64(7890); + let a = E4::sample(&mut rng); + let b = E4::sample(&mut rng); + assert_eq!(a * b, b * a); + } + + #[test] + fn fp4_square_matches_mul() { + let mut rng = StdRng::seed_from_u64(1111); + for _ in 0..50 { + let a = E4::sample(&mut rng); + 
assert_eq!(a.square(), a * a); + } + } + + #[test] + fn fp4_inv() { + let mut rng = StdRng::seed_from_u64(2222); + for _ in 0..50 { + let a = E4::sample(&mut rng); + if !a.is_zero() { + let inv = a.inv().unwrap(); + assert_eq!(a * inv, E4::one()); + } + } + } + + #[test] + fn from_small_int_fp2() { + let a = E2::from_u64(42); + assert_eq!(a, E2::new(F::from_u64(42), F::zero())); + + let b = E2::from_i64(-3); + assert_eq!(b, E2::new(F::from_i64(-3), F::zero())); + + let c = E2::from_u8(7); + assert_eq!(c, E2::from_u64(7)); + + let d = E2::from_u32(100_000); + assert_eq!(d, E2::from_u64(100_000)); + } + + #[test] + fn from_small_int_fp4() { + let a = E4::from_u64(42); + assert_eq!(a, E4::new(E2::from_u64(42), E2::zero())); + + let b = E4::from_i64(-7); + assert_eq!(b, E4::new(E2::from_i64(-7), E2::zero())); + } + + #[test] + fn ext_field_degree() { + assert_eq!(>::EXT_DEGREE, 1); + assert_eq!(>::EXT_DEGREE, 2); + assert_eq!(>::EXT_DEGREE, 4); + } + + #[test] + fn ext_field_from_base_slice() { + let c0 = F::from_u64(3); + let c1 = F::from_u64(5); + let e2 = E2::from_base_slice(&[c0, c1]); + assert_eq!(e2, E2::new(c0, c1)); + + let c2 = F::from_u64(7); + let c3 = F::from_u64(11); + let e4 = E4::from_base_slice(&[c0, c1, c2, c3]); + assert_eq!(e4, E4::new(E2::new(c0, c1), E2::new(c2, c3))); + } + + #[test] + fn eq_impl() { + let a = E2::new(F::from_u64(1), F::from_u64(2)); + let b = E2::new(F::from_u64(1), F::from_u64(2)); + let c = E2::new(F::from_u64(1), F::from_u64(3)); + assert_eq!(a, b); + assert_ne!(a, c); + } +} diff --git a/src/algebra/fields/fp128.rs b/src/algebra/fields/fp128.rs new file mode 100644 index 00000000..78531bdc --- /dev/null +++ b/src/algebra/fields/fp128.rs @@ -0,0 +1,1065 @@ +//! 128-bit prime field for primes of the form `p = 2^128 − c` with `c < 2^64`. +//! +//! Uses Solinas-style two-fold reduction: no Montgomery form, ~23 cycles/mul +//! on both AArch64 and x86-64. The offset `c` is computed at compile time +//! 
from the const-generic modulus `P`. +//! +//! ## Naming convention for built-in primes +//! +//! The built-in type names encode the **signed terms as they appear in the +//! modulus `p`** (excluding the leading `+2^128` term). For example, +//! `Prime128M13M4P0` denotes `p = 2^128 − 2^13 − 2^4 + 2^0`. + +use std::io::{Read, Write}; +use std::ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign}; + +use rand_core::RngCore; + +use crate::primitives::serialization::{ + Compress, HachiDeserialize, HachiSerialize, SerializationError, Valid, Validate, +}; +use crate::{ + AdditiveGroup, CanonicalField, FieldCore, FieldSampling, FromSmallInt, Invertible, + PseudoMersenneField, +}; + +/// Pack two u64 limbs into `[lo, hi]`. +#[inline(always)] +const fn pack(lo: u64, hi: u64) -> [u64; 2] { + [lo, hi] +} + +/// Convert `u128` → `[u64; 2]`. +#[inline(always)] +const fn from_u128(x: u128) -> [u64; 2] { + [x as u64, (x >> 64) as u64] +} + +/// Convert `[u64; 2]` → `u128`. +#[inline(always)] +const fn to_u128(x: [u64; 2]) -> u128 { + x[0] as u128 | (x[1] as u128) << 64 +} + +use super::util::{is_pow2_u64, log2_pow2_u64, mul64_wide}; + +/// 128-bit prime field element for primes `p = 2^128 − c` with `c < 2^64`. +/// +/// Stored as `[u64; 2]` (lo, hi) for 8-byte alignment and direct limb access. +/// +/// The offset `c = 2^128 − p` and all derived constants are computed at +/// compile time from the const-generic `P`. Instantiating `Fp128` with a +/// modulus that is not of this form is a compile-time error. +#[derive(Debug, Clone, Copy, Default)] +pub struct Fp128(pub(crate) [u64; 2]); + +impl PartialEq for Fp128

{ + fn eq(&self, other: &Self) -> bool { + self.0 == other.0 + } +} + +impl Eq for Fp128

{} + +impl Fp128

{ + /// Offset `c = 2^128 − p`. Validated at compile time. + pub const C: u128 = { + let c = 0u128.wrapping_sub(P); + assert!(P != 0, "modulus must be nonzero"); + assert!(P & 1 == 1, "modulus must be odd"); + assert!(c < (1u128 << 64), "P must be 2^128 - c with c < 2^64"); + // Fused overflow+canonicalize requires C(C+1) < P. + assert!( + c * (c + 1) < P, + "C(C+1) < P required for fused canonicalize" + ); + c + }; + /// Low 64 bits of `C` (always equals `C` since `C < 2^64`). + pub const C_LO: u64 = Self::C as u64; + /// +1 means `C = 2^a + 1`, -1 means `C = 2^a - 1`, 0 means generic. + const C_SHIFT_KIND: i8 = { + let c = Self::C_LO; + if c > 1 && is_pow2_u64(c - 1) { + 1 + } else if c == u64::MAX || is_pow2_u64(c + 1) { + -1 + } else { + 0 + } + }; + const C_SHIFT: u32 = { + let c = Self::C_LO; + if Self::C_SHIFT_KIND == 1 { + log2_pow2_u64(c - 1) + } else if Self::C_SHIFT_KIND == -1 { + if c == u64::MAX { + 64 + } else { + log2_pow2_u64(c + 1) + } + } else { + 0 + } + }; + + /// Multiply by `C = 2^128 - P`. For `C = 2^a ± 1`, this is shift/add or + /// shift/sub only; otherwise it falls back to generic widening multiply. + #[inline(always)] + fn mul_c_wide(x: u64) -> (u64, u64) { + if Self::C_SHIFT_KIND == 1 { + let v = ((x as u128) << Self::C_SHIFT) + x as u128; + (v as u64, (v >> 64) as u64) + } else if Self::C_SHIFT_KIND == -1 { + let v = ((x as u128) << Self::C_SHIFT) - x as u128; + (v as u64, (v >> 64) as u64) + } else { + mul64_wide(Self::C_LO, x) + } + } + + /// Create from a canonical representative in `[0, p)`. + #[inline] + pub fn from_canonical_u128(x: u128) -> Self { + debug_assert!(x < P); + Self(from_u128(x)) + } + + /// Return the canonical representative in `[0, p)`. + #[inline] + pub fn to_canonical_u128(self) -> u128 { + to_u128(self.0) + } + + /// Const-evaluable `from_i64`. Embeds a small signed integer into `Fp`. 
+ pub const fn from_i64_const(val: i64) -> Self { + if val >= 0 { + Self(from_u128(val as u128)) + } else { + Self(Self::sub_raw( + pack(0, 0), + from_u128(val.unsigned_abs() as u128), + )) + } + } + + /// Const-evaluable lookup table for balanced digits in `[-b/2, b/2)` + /// where `b = 2^log_basis`. Requires `log_basis <= 4`. + pub const fn digit_lut(log_basis: u32) -> [Self; 16] { + assert!(log_basis > 0 && log_basis <= 4); + let b = 1u32 << log_basis; + let half_b = (b / 2) as i64; + let mut lut = [Self(pack(0, 0)); 16]; + let mut i = 0u32; + while i < b { + lut[i as usize] = Self::from_i64_const(i as i64 - half_b); + i += 1; + } + lut + } + + #[inline(always)] + fn add_raw(a: [u64; 2], b: [u64; 2]) -> [u64; 2] { + let (s, carry) = to_u128(a).overflowing_add(to_u128(b)); + let (reduced, borrow) = s.overflowing_sub(P); + from_u128(if carry | !borrow { + reduced + } else { + reduced.wrapping_add(P) + }) + } + + #[inline(always)] + const fn sub_raw(a: [u64; 2], b: [u64; 2]) -> [u64; 2] { + let (diff, borrow) = to_u128(a).overflowing_sub(to_u128(b)); + from_u128(if borrow { diff.wrapping_add(P) } else { diff }) + } + + /// Fold 2 + canonicalize: reduce `[t0, t1] + t2·2^128` into `[0, p)`. + /// + /// Correctness argument for the fused overflow+canonicalize: + /// + /// Let `v = base + C·t2` (mathematical, not mod 2^128). + /// From the fold-1 mac chain, `t2 ≤ C`, so `C·t2 ≤ C²`. + /// + /// - **No overflow** (`v < 2^128`): `s = v`, and the standard + /// canonicalize applies — `s + C` carries iff `s ≥ P`. + /// - **Overflow** (`v ≥ 2^128`): `s = v − 2^128`, so `s < C·t2 ≤ C²`. + /// The correct reduced value is `s + C` (since `2^128 ≡ C mod P`). + /// Because `s + C < C² + C = C(C+1)` and `C(C+1) < P` for all + /// `C < 2^64`, the value `s + C` is already in `[0, P)` — no + /// further canonicalization is needed, and `s + C < 2^128` so the + /// add does NOT carry. 
+ /// + /// Therefore `if (overflow | carry) { s + C } else { s }` is correct + /// in both cases, fusing the overflow correction with canonicalization. + #[inline(always)] + fn fold2_canonicalize(t0: u64, t1: u64, t2: u64) -> [u64; 2] { + let (ct2_lo, ct2_hi) = Self::mul_c_wide(t2); + + let (s0, carry0) = t0.overflowing_add(ct2_lo); + let (s1a, carry1a) = t1.overflowing_add(ct2_hi); + let (s1, carry1b) = s1a.overflowing_add(carry0 as u64); + let overflow = carry1a | carry1b; + + let (r0, carry2) = s0.overflowing_add(Self::C_LO); + let (r1, carry3) = s1.overflowing_add(carry2 as u64); + + pack( + if overflow | carry3 { r0 } else { s0 }, + if overflow | carry3 { r1 } else { s1 }, + ) + } + + /// Solinas fold for exactly 4 limbs: `[r0,r1] + C·[r2,r3]` → 3 limbs, + /// then `fold2_canonicalize`. + #[inline(always)] + fn reduce_4(r0: u64, r1: u64, r2: u64, r3: u64) -> [u64; 2] { + let (cr2_lo, cr2_hi) = Self::mul_c_wide(r2); + let (cr3_lo, cr3_hi) = Self::mul_c_wide(r3); + + let t0_sum = r0 as u128 + cr2_lo as u128; + let t0 = t0_sum as u64; + let carryf = (t0_sum >> 64) as u64; + + let t1_sum = r1 as u128 + cr2_hi as u128 + cr3_lo as u128 + carryf as u128; + let t1 = t1_sum as u64; + + let t2_sum = cr3_hi as u128 + (t1_sum >> 64); + let t2 = t2_sum as u64; + debug_assert_eq!(t2_sum >> 64, 0); + + Self::fold2_canonicalize(t0, t1, t2) + } + + #[inline(always)] + fn mul_raw(a: [u64; 2], b: [u64; 2]) -> [u64; 2] { + let [r0, r1, r2, r3] = Self(a).mul_wide(Self(b)); + Self::reduce_4(r0, r1, r2, r3) + } + + #[inline(always)] + fn sqr_wide(self) -> [u64; 4] { + let (a0, a1) = (self.0[0], self.0[1]); + let (p00_lo, p00_hi) = mul64_wide(a0, a0); + let (p01_lo, p01_hi) = mul64_wide(a0, a1); + let (p11_lo, p11_hi) = mul64_wide(a1, a1); + + let row1 = p00_hi as u128 + (p01_lo as u128) * 2; + let r0 = p00_lo; + let r1 = row1 as u64; + let carry1 = (row1 >> 64) as u64; + + let row2 = (p01_hi as u128) * 2 + p11_lo as u128 + carry1 as u128; + let r2 = row2 as u64; + let carry2 = 
(row2 >> 64) as u64; + + let row3 = p11_hi as u128 + carry2 as u128; + let r3 = row3 as u64; + debug_assert_eq!(row3 >> 64, 0); + + [r0, r1, r2, r3] + } + + #[inline(always)] + fn sqr_raw(a: [u64; 2]) -> [u64; 2] { + let [r0, r1, r2, r3] = Self(a).sqr_wide(); + Self::reduce_4(r0, r1, r2, r3) + } + + /// Squaring, equivalent to `self * self`. + #[inline(always)] + pub fn square(self) -> Self { + Self(Self::sqr_raw(self.0)) + } + + fn pow_u128(self, mut exp: u128) -> Self { + let mut base = self; + let mut acc = Self::one(); + while exp > 0 { + if (exp & 1) == 1 { + acc *= base; + } + base = Self(Self::sqr_raw(base.0)); + exp >>= 1; + } + acc + } + + /// Extract the canonical `[lo, hi]` limb representation. + #[inline(always)] + pub fn to_limbs(self) -> [u64; 2] { + self.0 + } + + /// 128×64 → 192-bit widening multiply, **no reduction**. + /// + /// Returns `[lo, mid, hi]` representing `self · other` as a 192-bit + /// integer. Cost: 2 widening `mul64`. + #[inline(always)] + pub fn mul_wide_u64(self, other: u64) -> [u64; 3] { + let (a0, a1) = (self.0[0], self.0[1]); + let (p0_lo, p0_hi) = mul64_wide(a0, other); + let (p1_lo, p1_hi) = mul64_wide(a1, other); + let mid = p0_hi as u128 + p1_lo as u128; + let hi = p1_hi + (mid >> 64) as u64; + [p0_lo, mid as u64, hi] + } + + /// 128×128 → 256-bit widening multiply, **no reduction**. + /// + /// Returns `[r0, r1, r2, r3]` representing `self · other` as a 256-bit + /// integer. This is the schoolbook 2×2 portion of the Solinas multiply, + /// without the reduction fold. Cost: 4 widening `mul64`. 
+ #[inline(always)] + pub fn mul_wide(self, other: Self) -> [u64; 4] { + let (a0, a1) = (self.0[0], self.0[1]); + let (b0, b1) = (other.0[0], other.0[1]); + let (p00_lo, p00_hi) = mul64_wide(a0, b0); + let (p01_lo, p01_hi) = mul64_wide(a0, b1); + let (p10_lo, p10_hi) = mul64_wide(a1, b0); + let (p11_lo, p11_hi) = mul64_wide(a1, b1); + + let row1 = p00_hi as u128 + p01_lo as u128 + p10_lo as u128; + let r0 = p00_lo; + let r1 = row1 as u64; + let carry1 = (row1 >> 64) as u64; + + let row2 = p01_hi as u128 + p10_hi as u128 + p11_lo as u128 + carry1 as u128; + let r2 = row2 as u64; + let carry2 = (row2 >> 64) as u64; + + let row3 = p11_hi as u128 + carry2 as u128; + let r3 = row3 as u64; + debug_assert_eq!(row3 >> 64, 0); + + [r0, r1, r2, r3] + } + + /// 128×128 → 256-bit widening multiply with a raw `u128` operand, + /// **no reduction**. + #[inline(always)] + pub fn mul_wide_u128(self, other: u128) -> [u64; 4] { + self.mul_wide(Self(from_u128(other))) + } + + /// 128×(64*M) → (64*OUT) widening multiply, **no reduction**. + /// + /// Multiplies a canonical Fp128 value (`[u64; 2]`) by an arbitrary + /// little-endian limb array and returns the little-endian product + /// truncated/extended to `OUT` limbs. + #[inline(always)] + pub fn mul_wide_limbs(self, other: [u64; M]) -> [u64; OUT] { + let (a0, a1) = (self.0[0], self.0[1]); + + // Hot-path specializations used by Jolt (M in {3,4}, OUT in {4,5,6}). + // These avoid loop/control-flow overhead in tight sumcheck FMAs. 
+ if M == 3 && OUT == 5 { + let b0 = other[0]; + let b1 = other[1]; + let b2 = other[2]; + + let (p00_lo, p00_hi) = mul64_wide(a0, b0); + let (p01_lo, p01_hi) = mul64_wide(a0, b1); + let (p02_lo, p02_hi) = mul64_wide(a0, b2); + let (p10_lo, p10_hi) = mul64_wide(a1, b0); + let (p11_lo, p11_hi) = mul64_wide(a1, b1); + let (p12_lo, p12_hi) = mul64_wide(a1, b2); + + let r0 = p00_lo; + + let row1 = p00_hi as u128 + p01_lo as u128 + p10_lo as u128; + let r1 = row1 as u64; + let carry1 = row1 >> 64; + + let row2 = p01_hi as u128 + p02_lo as u128 + p10_hi as u128 + p11_lo as u128 + carry1; + let r2 = row2 as u64; + let carry2 = row2 >> 64; + + let row3 = p02_hi as u128 + p11_hi as u128 + p12_lo as u128 + carry2; + let r3 = row3 as u64; + let carry3 = row3 >> 64; + + let row4 = p12_hi as u128 + carry3; + let r4 = row4 as u64; + debug_assert_eq!(row4 >> 64, 0); + + let mut out = [0u64; OUT]; + out[0] = r0; + out[1] = r1; + out[2] = r2; + out[3] = r3; + out[4] = r4; + return out; + } + if M == 3 && OUT == 4 { + let b0 = other[0]; + let b1 = other[1]; + let b2 = other[2]; + + let (p00_lo, p00_hi) = mul64_wide(a0, b0); + let (p01_lo, p01_hi) = mul64_wide(a0, b1); + let (p02_lo, p02_hi) = mul64_wide(a0, b2); + let (p10_lo, p10_hi) = mul64_wide(a1, b0); + let (p11_lo, p11_hi) = mul64_wide(a1, b1); + let p12_lo = a1.wrapping_mul(b2); + + let r0 = p00_lo; + + let row1 = p00_hi as u128 + p01_lo as u128 + p10_lo as u128; + let r1 = row1 as u64; + let carry1 = row1 >> 64; + + let row2 = p01_hi as u128 + p02_lo as u128 + p10_hi as u128 + p11_lo as u128 + carry1; + let r2 = row2 as u64; + let carry2 = row2 >> 64; + + let row3 = p02_hi as u128 + p11_hi as u128 + p12_lo as u128 + carry2; + let r3 = row3 as u64; + + let mut out = [0u64; OUT]; + out[0] = r0; + out[1] = r1; + out[2] = r2; + out[3] = r3; + return out; + } + if M == 4 && OUT == 6 { + let b0 = other[0]; + let b1 = other[1]; + let b2 = other[2]; + let b3 = other[3]; + + let (p00_lo, p00_hi) = mul64_wide(a0, b0); + let (p01_lo, 
p01_hi) = mul64_wide(a0, b1); + let (p02_lo, p02_hi) = mul64_wide(a0, b2); + let (p03_lo, p03_hi) = mul64_wide(a0, b3); + let (p10_lo, p10_hi) = mul64_wide(a1, b0); + let (p11_lo, p11_hi) = mul64_wide(a1, b1); + let (p12_lo, p12_hi) = mul64_wide(a1, b2); + let (p13_lo, p13_hi) = mul64_wide(a1, b3); + + let r0 = p00_lo; + + let row1 = p00_hi as u128 + p01_lo as u128 + p10_lo as u128; + let r1 = row1 as u64; + let carry1 = row1 >> 64; + + let row2 = p01_hi as u128 + p02_lo as u128 + p10_hi as u128 + p11_lo as u128 + carry1; + let r2 = row2 as u64; + let carry2 = row2 >> 64; + + let row3 = p02_hi as u128 + p03_lo as u128 + p11_hi as u128 + p12_lo as u128 + carry2; + let r3 = row3 as u64; + let carry3 = row3 >> 64; + + let row4 = p03_hi as u128 + p12_hi as u128 + p13_lo as u128 + carry3; + let r4 = row4 as u64; + let carry4 = row4 >> 64; + + let row5 = p13_hi as u128 + carry4; + let r5 = row5 as u64; + debug_assert_eq!(row5 >> 64, 0); + + let mut out = [0u64; OUT]; + out[0] = r0; + out[1] = r1; + out[2] = r2; + out[3] = r3; + out[4] = r4; + out[5] = r5; + return out; + } + if M == 4 && OUT == 5 { + let b0 = other[0]; + let b1 = other[1]; + let b2 = other[2]; + let b3 = other[3]; + + let (p00_lo, p00_hi) = mul64_wide(a0, b0); + let (p01_lo, p01_hi) = mul64_wide(a0, b1); + let (p02_lo, p02_hi) = mul64_wide(a0, b2); + let (p03_lo, p03_hi) = mul64_wide(a0, b3); + let (p10_lo, p10_hi) = mul64_wide(a1, b0); + let (p11_lo, p11_hi) = mul64_wide(a1, b1); + let (p12_lo, p12_hi) = mul64_wide(a1, b2); + let p13_lo = a1.wrapping_mul(b3); + + let r0 = p00_lo; + + let row1 = p00_hi as u128 + p01_lo as u128 + p10_lo as u128; + let r1 = row1 as u64; + let carry1 = row1 >> 64; + + let row2 = p01_hi as u128 + p02_lo as u128 + p10_hi as u128 + p11_lo as u128 + carry1; + let r2 = row2 as u64; + let carry2 = row2 >> 64; + + let row3 = p02_hi as u128 + p03_lo as u128 + p11_hi as u128 + p12_lo as u128 + carry2; + let r3 = row3 as u64; + let carry3 = row3 >> 64; + + let row4 = p03_hi as u128 + 
p12_hi as u128 + p13_lo as u128 + carry3; + let r4 = row4 as u64; + + let mut out = [0u64; OUT]; + out[0] = r0; + out[1] = r1; + out[2] = r2; + out[3] = r3; + out[4] = r4; + return out; + } + if M == 4 && OUT == 4 { + let b0 = other[0]; + let b1 = other[1]; + let b2 = other[2]; + let b3 = other[3]; + + let (p00_lo, p00_hi) = mul64_wide(a0, b0); + let (p01_lo, p01_hi) = mul64_wide(a0, b1); + let (p02_lo, p02_hi) = mul64_wide(a0, b2); + let p03_lo = a0.wrapping_mul(b3); + let (p10_lo, p10_hi) = mul64_wide(a1, b0); + let (p11_lo, p11_hi) = mul64_wide(a1, b1); + let p12_lo = a1.wrapping_mul(b2); + + let r0 = p00_lo; + + let row1 = p00_hi as u128 + p01_lo as u128 + p10_lo as u128; + let r1 = row1 as u64; + let carry1 = row1 >> 64; + + let row2 = p01_hi as u128 + p02_lo as u128 + p10_hi as u128 + p11_lo as u128 + carry1; + let r2 = row2 as u64; + let carry2 = row2 >> 64; + + let row3 = p02_hi as u128 + p03_lo as u128 + p11_hi as u128 + p12_lo as u128 + carry2; + let r3 = row3 as u64; + + let mut out = [0u64; OUT]; + out[0] = r0; + out[1] = r1; + out[2] = r2; + out[3] = r3; + return out; + } + + let mut out = [0u64; OUT]; + + for (i, &b) in other.iter().enumerate() { + if i >= OUT { + break; + } + + let (p0_lo, p0_hi) = mul64_wide(a0, b); + let (p1_lo, p1_hi) = mul64_wide(a1, b); + + let s0 = out[i] as u128 + p0_lo as u128; + out[i] = s0 as u64; + let mut carry = s0 >> 64; + + if i + 1 >= OUT { + continue; + } + let s1 = out[i + 1] as u128 + p0_hi as u128 + p1_lo as u128 + carry; + out[i + 1] = s1 as u64; + carry = s1 >> 64; + + if i + 2 >= OUT { + continue; + } + let s2 = out[i + 2] as u128 + p1_hi as u128 + carry; + out[i + 2] = s2 as u64; + + let mut carry_hi = s2 >> 64; + let mut j = i + 3; + while carry_hi != 0 && j < OUT { + let sj = out[j] as u128 + carry_hi; + out[j] = sj as u64; + carry_hi = sj >> 64; + j += 1; + } + } + + out + } + + /// Reduce an arbitrary-width little-endian limb array to a canonical + /// field element via iterated Solinas folding. 
+ /// + /// Each fold splits at the 128-bit boundary and replaces + /// `hi · 2^128` with `hi · C`, reducing width by one limb per + /// iteration. Supports 0–10 input limbs (up to 640 bits). + /// + /// # Panics + /// + /// Panics if `limbs.len() > 10`. + #[inline(always)] + pub fn solinas_reduce(limbs: &[u64]) -> Self { + match limbs.len() { + 0 => Self::zero(), + 1 => Self(pack(limbs[0], 0)), + 2 => Self::from_canonical_u128_reduced(to_u128([limbs[0], limbs[1]])), + 3 => Self(Self::fold2_canonicalize(limbs[0], limbs[1], limbs[2])), + 4 => Self(Self::reduce_4(limbs[0], limbs[1], limbs[2], limbs[3])), + 5 => { + let (l0, l1, l2, l3, l4) = (limbs[0], limbs[1], limbs[2], limbs[3], limbs[4]); + let (c2_lo, c2_hi) = Self::mul_c_wide(l2); + let (c3_lo, c3_hi) = Self::mul_c_wide(l3); + let (c4_lo, c4_hi) = Self::mul_c_wide(l4); + + let s0 = l0 as u128 + c2_lo as u128; + let s1 = l1 as u128 + c2_hi as u128 + c3_lo as u128 + (s0 >> 64); + let s2 = c3_hi as u128 + c4_lo as u128 + (s1 >> 64); + let s3 = c4_hi as u128 + (s2 >> 64); + debug_assert_eq!(s3 >> 64, 0); + + Self(Self::reduce_4(s0 as u64, s1 as u64, s2 as u64, s3 as u64)) + } + n => { + assert!(n <= 10, "solinas_reduce supports at most 10 limbs"); + let mut buf = [0u64; 11]; + buf[..n].copy_from_slice(limbs); + let mut len = n; + let c = Self::C_LO; + + while len > 5 { + let high_len = len - 2; + let mut next = [0u64; 11]; + + let mut carry: u64 = 0; + for i in 0..high_len { + let wide = c as u128 * buf[i + 2] as u128 + carry as u128; + next[i] = wide as u64; + carry = (wide >> 64) as u64; + } + next[high_len] = carry; + + let s0 = next[0] as u128 + buf[0] as u128; + next[0] = s0 as u64; + let s1 = next[1] as u128 + buf[1] as u128 + (s0 >> 64); + next[1] = s1 as u64; + let mut c_out = (s1 >> 64) as u64; + for limb in &mut next[2..=high_len] { + if c_out == 0 { + break; + } + let s = *limb as u128 + c_out as u128; + *limb = s as u64; + c_out = (s >> 64) as u64; + } + debug_assert_eq!(c_out, 0); + + buf = next; + len 
-= 1; + while len > 5 && buf[len - 1] == 0 { + len -= 1; + } + } + + Self::solinas_reduce(&buf[..len]) + } + } + } +} + +impl Add for Fp128

{ + type Output = Self; + #[inline] + fn add(self, rhs: Self) -> Self::Output { + Self(Self::add_raw(self.0, rhs.0)) + } +} + +impl Sub for Fp128

{ + type Output = Self; + #[inline] + fn sub(self, rhs: Self) -> Self::Output { + Self(Self::sub_raw(self.0, rhs.0)) + } +} + +impl Mul for Fp128

{ + type Output = Self; + #[inline] + fn mul(self, rhs: Self) -> Self::Output { + Self(Self::mul_raw(self.0, rhs.0)) + } +} + +impl Neg for Fp128

{ + type Output = Self; + #[inline] + fn neg(self) -> Self::Output { + Self(Self::sub_raw(pack(0, 0), self.0)) + } +} + +impl AddAssign for Fp128

{ + #[inline] + fn add_assign(&mut self, rhs: Self) { + *self = *self + rhs; + } +} + +impl SubAssign for Fp128

{ + #[inline] + fn sub_assign(&mut self, rhs: Self) { + *self = *self - rhs; + } +} + +impl MulAssign for Fp128

{ + #[inline] + fn mul_assign(&mut self, rhs: Self) { + *self = *self * rhs; + } +} + +impl<'a, const P: u128> Add<&'a Self> for Fp128

{ + type Output = Self; + #[inline] + fn add(self, rhs: &'a Self) -> Self::Output { + self + *rhs + } +} + +impl<'a, const P: u128> Sub<&'a Self> for Fp128

{ + type Output = Self; + #[inline] + fn sub(self, rhs: &'a Self) -> Self::Output { + self - *rhs + } +} + +impl<'a, const P: u128> Mul<&'a Self> for Fp128

{ + type Output = Self; + #[inline] + fn mul(self, rhs: &'a Self) -> Self::Output { + self * *rhs + } +} + +impl Valid for Fp128

{ + fn check(&self) -> Result<(), SerializationError> { + if to_u128(self.0) < P { + Ok(()) + } else { + Err(SerializationError::InvalidData("Fp128 out of range".into())) + } + } +} + +impl HachiSerialize for Fp128

{ + fn serialize_with_mode( + &self, + mut writer: W, + _compress: Compress, + ) -> Result<(), SerializationError> { + to_u128(self.0).serialize_with_mode(&mut writer, Compress::No)?; + Ok(()) + } + + fn serialized_size(&self, _compress: Compress) -> usize { + 16 + } +} + +impl HachiDeserialize for Fp128

{ + fn deserialize_with_mode( + mut reader: R, + _compress: Compress, + validate: Validate, + ) -> Result { + let x = u128::deserialize_with_mode(&mut reader, Compress::No, validate)?; + if matches!(validate, Validate::Yes) && x >= P { + return Err(SerializationError::InvalidData( + "Fp128 out of range".to_string(), + )); + } + + // Without validation, reduce without division. + // For `p = 2^128 − c` with `c < 2^64` we have `p > 2^127`, + // so any `u128` is in `[0, 2p)` and one conditional subtract suffices. + let out = if matches!(validate, Validate::Yes) { + x + } else { + let (sub, borrow) = x.overflowing_sub(P); + if borrow { + x + } else { + sub + } + }; + Ok(Self(from_u128(out))) + } +} + +impl AdditiveGroup for Fp128

{ + const ZERO: Self = Self(pack(0, 0)); +} + +impl FieldCore for Fp128

{ + fn one() -> Self { + Self(pack(1, 0)) + } + + fn is_zero(&self) -> bool { + self.0 == [0, 0] + } + + fn inv(self) -> Option { + let inv = self.inv_or_zero(); + if self.is_zero() { + None + } else { + Some(inv) + } + } + + const TWO_INV: Self = { + let v = (P >> 1) + 1; + Self(pack(v as u64, (v >> 64) as u64)) + }; +} + +impl Invertible for Fp128

{ + fn inv_or_zero(self) -> Self { + let candidate = self.pow_u128(P.wrapping_sub(2)); + let v = to_u128(self.0); + let nz = ((v | v.wrapping_neg()) >> 127) & 1; + let mask = 0u128.wrapping_sub(nz); + let masked = to_u128(candidate.0) & mask; + Self(from_u128(masked)) + } +} + +impl FieldSampling for Fp128

{ + fn sample(rng: &mut R) -> Self { + loop { + let lo = rng.next_u64(); + let hi = rng.next_u64(); + let x = lo as u128 | (hi as u128) << 64; + if x < P { + return Self(pack(lo, hi)); + } + } + } +} + +impl FromSmallInt for Fp128

{ + fn from_u64(val: u64) -> Self { + // For Fp128 pseudo-Mersenne primes, p = 2^128 - c with c < 2^64. + // Therefore any u64 is always canonical (< p), so this can be a + // direct limb construction with no reduction path. + Self(from_u128(val as u128)) + } + + fn from_i64(val: i64) -> Self { + Self::from_i64_const(val) + } + + fn digit_lut(log_basis: u32) -> [Self; 16] { + Self::digit_lut(log_basis) + } +} + +impl CanonicalField for Fp128

{ + fn to_canonical_u128(self) -> u128 { + to_u128(self.0) + } + + fn from_canonical_u128_checked(val: u128) -> Option { + if val < P { + Some(Self(from_u128(val))) + } else { + None + } + } + + fn from_canonical_u128_reduced(val: u128) -> Self { + let (sub, borrow) = val.overflowing_sub(P); + Self(from_u128(if borrow { val } else { sub })) + } +} + +impl PseudoMersenneField for Fp128

{ + const MODULUS_BITS: u32 = 128; + const MODULUS_OFFSET: u128 = Self::C; +} + +/// `p = 2^128 − 2^13 − 2^4 + 2^0` (C = 8207). +pub type Prime128M13M4P0 = Fp128<0xffffffffffffffffffffffffffffdff1>; +/// `p = 2^128 − 2^37 + 2^3 + 2^0` (C = 137438953463). +pub type Prime128M37P3P0 = Fp128<0xffffffffffffffffffffffe000000009>; +/// `p = 2^128 − 2^52 − 2^3 + 2^0` (C = 4503599627370487). +pub type Prime128M52M3P0 = Fp128<0xffffffffffffffffffeffffffffffff9>; +/// `p = 2^128 − 2^54 + 2^4 + 2^0` (C = 18014398509481967). +pub type Prime128M54P4P0 = Fp128<0xffffffffffffffffffc0000000000011>; +/// `p = 2^128 − 2^8 − 2^4 − 2^1 − 2^0` (C = 275). +pub type Prime128M8M4M1M0 = Fp128<0xfffffffffffffffffffffffffffffeed>; +/// `p = 2^128 − 2^18 − 2^0` (C = 2^18 + 1). +pub type Prime128M18M0 = Fp128<0xfffffffffffffffffffffffffffbffff>; +/// `p = 2^128 − 2^54 + 2^0` (C = 2^54 − 1). +pub type Prime128M54P0 = Fp128<0xffffffffffffffffffc0000000000001>; + +#[cfg(test)] +mod tests { + use super::*; + use crate::{FieldSampling, PseudoMersenneField}; + use rand::rngs::StdRng; + use rand::SeedableRng; + use rand_core::RngCore; + + type F = Prime128M8M4M1M0; + + #[test] + fn to_limbs_roundtrip() { + let mut rng = StdRng::seed_from_u64(0xdead_beef_cafe_1234); + for _ in 0..1000 { + let a: F = FieldSampling::sample(&mut rng); + assert_eq!(Fp128(a.to_limbs()), a); + } + } + + #[test] + fn mul_wide_u64_matches_full_mul() { + let mut rng = StdRng::seed_from_u64(0x1122_3344_5566_7788); + for _ in 0..1000 { + let a: F = FieldSampling::sample(&mut rng); + let b = rng.next_u64(); + let expected = a * F::from_u64(b); + let reduced = F::solinas_reduce(&a.mul_wide_u64(b)); + assert_eq!(reduced, expected); + } + } + + #[test] + fn mul_wide_matches_full_mul() { + let mut rng = StdRng::seed_from_u64(0xaabb_ccdd_eeff_0011); + for _ in 0..1000 { + let a: F = FieldSampling::sample(&mut rng); + let b: F = FieldSampling::sample(&mut rng); + let expected = a * b; + let reduced = F::solinas_reduce(&a.mul_wide(b)); + 
assert_eq!(reduced, expected); + } + } + + #[test] + fn mul_wide_u128_matches_full_mul() { + let mut rng = StdRng::seed_from_u64(0x9988_7766_5544_3322); + for _ in 0..1000 { + let a: F = FieldSampling::sample(&mut rng); + let b = rng.next_u64() as u128 | ((rng.next_u64() as u128) << 64); + let expected = a * F::from_canonical_u128_reduced(b); + let reduced = F::solinas_reduce(&a.mul_wide_u128(b)); + assert_eq!(reduced, expected); + } + } + + #[test] + fn mul_wide_limbs_roundtrips_through_reduction() { + let mut rng = StdRng::seed_from_u64(0x1bad_f00d_0ddc_afe1); + for _ in 0..1000 { + let a: F = FieldSampling::sample(&mut rng); + let b3 = [rng.next_u64(), rng.next_u64(), rng.next_u64()]; + let b4 = [ + rng.next_u64(), + rng.next_u64(), + rng.next_u64(), + rng.next_u64(), + ]; + + let got3_full = a.mul_wide_limbs::<3, 5>(b3); + let got3_trunc = a.mul_wide_limbs::<3, 4>(b3); + assert_eq!( + got3_trunc, + [got3_full[0], got3_full[1], got3_full[2], got3_full[3]] + ); + let exp3 = a * F::solinas_reduce(&b3); + assert_eq!(F::solinas_reduce(&got3_full), exp3); + + let got4_full = a.mul_wide_limbs::<4, 6>(b4); + let got4_trunc = a.mul_wide_limbs::<4, 4>(b4); + assert_eq!( + got4_trunc, + [got4_full[0], got4_full[1], got4_full[2], got4_full[3]] + ); + let exp4 = a * F::solinas_reduce(&b4); + assert_eq!(F::solinas_reduce(&got4_full), exp4); + } + } + + #[test] + fn solinas_reduce_small_inputs() { + assert_eq!(F::solinas_reduce(&[]), F::zero()); + assert_eq!(F::solinas_reduce(&[42]), F::from_u64(42)); + let one_shifted = F::from_canonical_u128_reduced(1u128 << 64); + assert_eq!(F::solinas_reduce(&[0, 1]), one_shifted); + } + + #[test] + fn solinas_reduce_4_limbs_max() { + // 2^256 - 1 ≡ C² - 1 (mod P), since 2^128 ≡ C + let c = F::from_canonical_u128_reduced(::MODULUS_OFFSET); + let expected = c * c - F::one(); + assert_eq!(F::solinas_reduce(&[u64::MAX; 4]), expected); + } + + #[test] + fn solinas_reduce_9_limbs() { + // 1 + 2^512 = 1 + (2^128)^4 ≡ 1 + C^4 + let c = 
F::from_canonical_u128_reduced(::MODULUS_OFFSET); + let expected = F::one() + c * c * c * c; + assert_eq!(F::solinas_reduce(&[1, 0, 0, 0, 0, 0, 0, 0, 1]), expected); + } + + #[test] + fn solinas_reduce_accumulated_products() { + let mut rng = StdRng::seed_from_u64(0xfeed_face_0bad_c0de); + let mut acc = [0u64; 5]; + let mut expected = F::zero(); + + for _ in 0..200 { + let a: F = FieldSampling::sample(&mut rng); + let b = rng.next_u64(); + let wide = a.mul_wide_u64(b); + + let mut carry: u64 = 0; + for j in 0..5 { + let addend = if j < 3 { wide[j] } else { 0 }; + let sum = acc[j] as u128 + addend as u128 + carry as u128; + acc[j] = sum as u64; + carry = (sum >> 64) as u64; + } + assert_eq!(carry, 0); + expected += a * F::from_u64(b); + } + + assert_eq!(F::solinas_reduce(&acc), expected); + } + + #[test] + fn solinas_reduce_cross_prime() { + // Verify with Prime128M18M0 (C = 2^18 + 1, shift+add path) + type G = Prime128M18M0; + let c = G::from_canonical_u128_reduced(::MODULUS_OFFSET); + let expected = c * c - G::one(); + assert_eq!(G::solinas_reduce(&[u64::MAX; 4]), expected); + + // Verify with Prime128M54P0 (C = 2^54 - 1, shift-sub path) + type H = Prime128M54P0; + let c = H::from_canonical_u128_reduced(::MODULUS_OFFSET); + let expected = c * c - H::one(); + assert_eq!(H::solinas_reduce(&[u64::MAX; 4]), expected); + } + + #[test] + fn from_i64_handles_min_without_overflow() { + let x = F::from_i64(i64::MIN); + let y = F::from_u64(i64::MIN.unsigned_abs()); + assert_eq!(x + y, F::zero()); + } +} diff --git a/src/algebra/fields/fp32.rs b/src/algebra/fields/fp32.rs new file mode 100644 index 00000000..77419476 --- /dev/null +++ b/src/algebra/fields/fp32.rs @@ -0,0 +1,471 @@ +//! Prime field for primes of the form `p = 2^k − c` with `c` small, backed +//! by `u32` storage. +//! +//! Uses Solinas-style two-fold reduction: the offset `c` and fold point `k` +//! are computed at compile time from the const-generic modulus `P`. 
+ +use std::ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign}; + +use rand_core::RngCore; + +use crate::primitives::serialization::{ + Compress, HachiDeserialize, HachiSerialize, SerializationError, Valid, Validate, +}; +use crate::{ + AdditiveGroup, CanonicalField, FieldCore, FieldSampling, FromSmallInt, Invertible, + PseudoMersenneField, +}; +use std::io::{Read, Write}; + +/// Prime field element for primes `p = 2^k − c` stored as `u32`. +/// +/// The fold point `k` and offset `c = 2^k − p` are computed at compile time +/// from the const-generic `P`. Instantiating with a modulus that does not +/// satisfy the Solinas conditions is a compile-time error. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] +pub struct Fp32(pub(crate) u32); + +impl Fp32

{ + /// Fold point: smallest `k` such that `P ≤ 2^k`. + const BITS: u32 = 32 - P.leading_zeros(); + + /// Offset `c = 2^k − P`. + pub const C: u32 = { + let c = if Self::BITS == 32 { + 0u32.wrapping_sub(P) + } else { + (1u32 << Self::BITS) - P + }; + assert!(P != 0, "modulus must be nonzero"); + assert!(P & 1 == 1, "modulus must be odd"); + assert!( + (c as u64) * (c as u64 + 1) < P as u64, + "C(C+1) < P required for fused canonicalize" + ); + c + }; + + /// Mask for extracting the low `BITS` bits from a u64. + const MASK: u64 = if Self::BITS == 32 { + u32::MAX as u64 + } else { + (1u64 << Self::BITS) - 1 + }; + + /// Create from a canonical representative in `[0, P)`. + #[inline] + pub fn from_canonical_u32(x: u32) -> Self { + debug_assert!(x < P); + Self(x) + } + + /// Return the canonical representative in `[0, P)`. + #[inline] + pub fn to_canonical_u32(self) -> u32 { + self.0 + } + + /// Solinas reduction: fold a u64 at bit `BITS` until the value fits, + /// then conditionally subtract `P`. + /// + /// For multiplication products (< 2^{2·BITS}) exactly 2 folds suffice; + /// for arbitrary u64 inputs (e.g. `from_u64`) the loop runs at most + /// `ceil(64 / BITS)` iterations. + #[inline(always)] + fn reduce_u64(x: u64) -> u32 { + let c = Self::C as u64; + let mut v = x; + while v >> Self::BITS != 0 { + v = (v & Self::MASK) + c * (v >> Self::BITS); + } + let reduced = v.wrapping_sub(P as u64); + let borrow = reduced >> 63; + reduced.wrapping_add(borrow.wrapping_neg() & (P as u64)) as u32 + } + + /// Reduce a `u128` to canonical form (for `from_canonical_u128_reduced`). 
+ #[inline(always)] + fn reduce_u128(x: u128) -> u32 { + let c = Self::C as u128; + let bits = Self::BITS; + let mask = if bits == 32 { + u32::MAX as u128 + } else { + (1u128 << bits) - 1 + }; + let mut v = x; + while v >> bits != 0 { + v = (v & mask) + c * (v >> bits); + } + let f = v as u64; + let reduced = f.wrapping_sub(P as u64); + let borrow = reduced >> 63; + reduced.wrapping_add(borrow.wrapping_neg() & (P as u64)) as u32 + } + + /// Two-fold Solinas reduction for multiplication products. + /// + /// Input must be < 2^{2·BITS} (guaranteed for `a*b` where `a,b < P`). + /// Exactly 2 folds + conditional subtract, no loop. + #[inline(always)] + fn reduce_product(x: u64) -> u32 { + let c = Self::C as u64; + let f1 = (x & Self::MASK) + c * (x >> Self::BITS); + let f2 = (f1 & Self::MASK) + c * (f1 >> Self::BITS); + let reduced = f2.wrapping_sub(P as u64); + let borrow = reduced >> 63; + reduced.wrapping_add(borrow.wrapping_neg() & (P as u64)) as u32 + } + + #[inline(always)] + fn add_raw(a: u32, b: u32) -> u32 { + let s = (a as u64) + (b as u64); + let reduced = s.wrapping_sub(P as u64); + let borrow = reduced >> 63; + reduced.wrapping_add(borrow.wrapping_neg() & (P as u64)) as u32 + } + + #[inline(always)] + fn sub_raw(a: u32, b: u32) -> u32 { + let diff = (a as u64).wrapping_sub(b as u64); + let borrow = diff >> 63; + diff.wrapping_add(borrow.wrapping_neg() & (P as u64)) as u32 + } + + #[inline(always)] + fn mul_raw(a: u32, b: u32) -> u32 { + Self::reduce_product((a as u64) * (b as u64)) + } + + #[inline(always)] + fn sqr_raw(a: u32) -> u32 { + Self::mul_raw(a, a) + } + + /// Squaring, equivalent to `self * self`. + #[inline(always)] + pub fn square(self) -> Self { + Self(Self::sqr_raw(self.0)) + } + + fn pow(self, mut exp: u64) -> Self { + let mut base = self; + let mut acc = Self::one(); + while exp > 0 { + if (exp & 1) == 1 { + acc *= base; + } + base = base.square(); + exp >>= 1; + } + acc + } + + /// Extract the canonical value. 
+ #[inline(always)] + pub fn to_limbs(self) -> u32 { + self.0 + } + + /// 32×32 → 64-bit widening multiply, **no reduction**. + #[inline(always)] + pub fn mul_wide(self, other: Self) -> u64 { + (self.0 as u64) * (other.0 as u64) + } + + /// 32×32 → 64-bit widening multiply with a raw `u32` operand, + /// **no reduction**. + #[inline(always)] + pub fn mul_wide_u32(self, other: u32) -> u64 { + (self.0 as u64) * (other as u64) + } + + /// Reduce a u64 value via Solinas folding to a canonical field element. + #[inline(always)] + pub fn solinas_reduce(x: u64) -> Self { + Self(Self::reduce_u64(x)) + } +} + +impl Add for Fp32

{ + type Output = Self; + #[inline] + fn add(self, rhs: Self) -> Self::Output { + Self(Self::add_raw(self.0, rhs.0)) + } +} + +impl Sub for Fp32

{ + type Output = Self; + #[inline] + fn sub(self, rhs: Self) -> Self::Output { + Self(Self::sub_raw(self.0, rhs.0)) + } +} + +impl Mul for Fp32

{ + type Output = Self; + #[inline] + fn mul(self, rhs: Self) -> Self::Output { + Self(Self::mul_raw(self.0, rhs.0)) + } +} + +impl Neg for Fp32

{ + type Output = Self; + #[inline] + fn neg(self) -> Self::Output { + Self(Self::sub_raw(0, self.0)) + } +} + +impl AddAssign for Fp32

{ + #[inline] + fn add_assign(&mut self, rhs: Self) { + *self = *self + rhs; + } +} + +impl SubAssign for Fp32

{ + #[inline] + fn sub_assign(&mut self, rhs: Self) { + *self = *self - rhs; + } +} + +impl MulAssign for Fp32

{ + #[inline] + fn mul_assign(&mut self, rhs: Self) { + *self = *self * rhs; + } +} + +impl<'a, const P: u32> Add<&'a Self> for Fp32

{ + type Output = Self; + #[inline] + fn add(self, rhs: &'a Self) -> Self::Output { + self + *rhs + } +} + +impl<'a, const P: u32> Sub<&'a Self> for Fp32

{ + type Output = Self; + #[inline] + fn sub(self, rhs: &'a Self) -> Self::Output { + self - *rhs + } +} + +impl<'a, const P: u32> Mul<&'a Self> for Fp32

{ + type Output = Self; + #[inline] + fn mul(self, rhs: &'a Self) -> Self::Output { + self * *rhs + } +} + +impl Valid for Fp32

{ + fn check(&self) -> Result<(), SerializationError> { + if self.0 < P { + Ok(()) + } else { + Err(SerializationError::InvalidData("Fp32 out of range".into())) + } + } +} + +impl HachiSerialize for Fp32

{ + fn serialize_with_mode( + &self, + mut writer: W, + _compress: Compress, + ) -> Result<(), SerializationError> { + self.0.serialize_with_mode(&mut writer, Compress::No)?; + Ok(()) + } + + fn serialized_size(&self, _compress: Compress) -> usize { + 4 + } +} + +impl HachiDeserialize for Fp32

{ + fn deserialize_with_mode( + mut reader: R, + _compress: Compress, + validate: Validate, + ) -> Result { + let x = u32::deserialize_with_mode(&mut reader, Compress::No, validate)?; + if matches!(validate, Validate::Yes) && x >= P { + return Err(SerializationError::InvalidData( + "Fp32 out of range".to_string(), + )); + } + let out = if matches!(validate, Validate::Yes) { + Self(x) + } else { + Self(Self::reduce_u64(x as u64)) + }; + Ok(out) + } +} + +impl AdditiveGroup for Fp32

{ + const ZERO: Self = Self(0); +} + +impl FieldCore for Fp32

{ + fn one() -> Self { + Self(if P > 1 { 1 } else { 0 }) + } + + fn is_zero(&self) -> bool { + self.0 == 0 + } + + fn inv(self) -> Option { + let inv = self.inv_or_zero(); + if self.is_zero() { + None + } else { + Some(inv) + } + } + + const TWO_INV: Self = Self((P as u64).div_ceil(2) as u32); +} + +impl Invertible for Fp32

{ + fn inv_or_zero(self) -> Self { + let candidate = self.pow((P as u64).wrapping_sub(2)); + let nz = ((self.0 | self.0.wrapping_neg()) >> 31) & 1; + let mask = 0u32.wrapping_sub(nz); + Self(candidate.0 & mask) + } +} + +impl FieldSampling for Fp32

{ + fn sample(rng: &mut R) -> Self { + Self(Self::reduce_u64(rng.next_u64())) + } +} + +impl FromSmallInt for Fp32

{ + fn from_u64(val: u64) -> Self { + Self(Self::reduce_u64(val)) + } + + fn from_i64(val: i64) -> Self { + if val >= 0 { + Self::from_u64(val as u64) + } else { + -Self::from_u64(val.unsigned_abs()) + } + } +} + +impl CanonicalField for Fp32

{ + fn to_canonical_u128(self) -> u128 { + self.0 as u128 + } + + fn from_canonical_u128_checked(val: u128) -> Option { + if val < P as u128 { + Some(Self(val as u32)) + } else { + None + } + } + + fn from_canonical_u128_reduced(val: u128) -> Self { + Self(Self::reduce_u128(val)) + } +} + +impl PseudoMersenneField for Fp32

{ + const MODULUS_BITS: u32 = Self::BITS; + const MODULUS_OFFSET: u128 = Self::C as u128; +} + +#[cfg(test)] +mod tests { + use super::*; + use rand::rngs::StdRng; + use rand::SeedableRng; + + type F = Fp32<251>; // 2^8 - 5 + + #[test] + fn solinas_constants() { + assert_eq!(F::BITS, 8); + assert_eq!(F::C, 5); + assert_eq!(F::MASK, 255); + + type G = Fp32<{ (1u32 << 24) - 3 }>; // 2^24 - 3 + assert_eq!(G::BITS, 24); + assert_eq!(G::C, 3); + } + + #[test] + fn basic_arithmetic() { + let a = F::from_u64(100); + let b = F::from_u64(200); + assert_eq!((a + b).to_canonical_u32(), (100 + 200) % 251); + assert_eq!((a * b).to_canonical_u32(), (100 * 200) % 251); + assert_eq!((b - a).to_canonical_u32(), 100); + assert_eq!((-a).to_canonical_u32(), 251 - 100); + } + + #[test] + fn mul_wide_matches_full_mul() { + let mut rng = StdRng::seed_from_u64(0x1234_5678); + for _ in 0..1000 { + let a: F = FieldSampling::sample(&mut rng); + let b: F = FieldSampling::sample(&mut rng); + let expected = a * b; + let reduced = F::solinas_reduce(a.mul_wide(b)); + assert_eq!(reduced, expected); + } + } + + #[test] + fn mul_wide_u32_matches() { + let mut rng = StdRng::seed_from_u64(0xabcd_ef01); + for _ in 0..1000 { + let a: F = FieldSampling::sample(&mut rng); + let b = rng.next_u32() % 251; + let expected = a * F::from_canonical_u32(b); + let reduced = F::solinas_reduce(a.mul_wide_u32(b)); + assert_eq!(reduced, expected); + } + } + + #[test] + fn reduce_large_values() { + assert_eq!( + F::from_u64(u64::MAX).to_canonical_u32(), + (u64::MAX % 251) as u32 + ); + assert_eq!(F::from_u64(0).to_canonical_u32(), 0); + assert_eq!(F::from_u64(251).to_canonical_u32(), 0); + assert_eq!(F::from_u64(252).to_canonical_u32(), 1); + } + + #[test] + fn pseudo_mersenne_trait() { + assert_eq!(::MODULUS_BITS, 8); + assert_eq!(::MODULUS_OFFSET, 5); + } + + #[test] + fn cross_prime_32bit() { + type G = Fp32<{ u32::MAX - 98 }>; // 2^32 - 99 + assert_eq!(G::BITS, 32); + assert_eq!(G::C, 99); + + let a = 
G::from_u64(1_000_000); + let b = G::from_u64(2_000_000); + let product = (1_000_000u64 * 2_000_000u64) % ((1u64 << 32) - 99); + assert_eq!((a * b).to_canonical_u32(), product as u32); + } +} diff --git a/src/algebra/fields/fp64.rs b/src/algebra/fields/fp64.rs new file mode 100644 index 00000000..5e88ec5b --- /dev/null +++ b/src/algebra/fields/fp64.rs @@ -0,0 +1,585 @@ +//! Prime field for primes of the form `p = 2^k − c` with `c` small, backed +//! by `u64` storage. +//! +//! Uses Solinas-style two-fold reduction. For `c = 2^a ± 1` the fold +//! multiply is replaced by shift+add/sub, saving a u128 widening multiply. + +use std::ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign}; + +use rand_core::RngCore; + +use crate::primitives::serialization::{ + Compress, HachiDeserialize, HachiSerialize, SerializationError, Valid, Validate, +}; +use crate::{ + AdditiveGroup, CanonicalField, FieldCore, FieldSampling, FromSmallInt, Invertible, + PseudoMersenneField, +}; +use std::io::{Read, Write}; + +use super::util::{is_pow2_u64, log2_pow2_u64, mul64_wide}; + +/// Prime field element for primes `p = 2^k − c` stored as `u64`. +/// +/// The fold point `k` and offset `c = 2^k − p` are computed at compile time +/// from the const-generic `P`. For `c = 2^a ± 1`, the fold multiply is +/// replaced by shift+add/sub. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] +pub struct Fp64(pub(crate) u64); + +impl Fp64

{ + /// Fold point: smallest `k` such that `P ≤ 2^k`. + const BITS: u32 = 64 - P.leading_zeros(); + + /// Offset `c = 2^k − P`. + pub const C: u64 = { + let c = if Self::BITS == 64 { + 0u64.wrapping_sub(P) + } else { + (1u64 << Self::BITS) - P + }; + assert!(P != 0, "modulus must be nonzero"); + assert!(P & 1 == 1, "modulus must be odd"); + assert!( + (c as u128) * (c as u128 + 1) < P as u128, + "C(C+1) < P required for fused canonicalize" + ); + c + }; + + /// +1 means `C = 2^a + 1`, -1 means `C = 2^a - 1`, 0 means generic. + const C_SHIFT_KIND: i8 = { + let c = Self::C; + if c > 1 && is_pow2_u64(c - 1) { + 1 + } else if c == u64::MAX || is_pow2_u64(c + 1) { + -1 + } else { + 0 + } + }; + + const C_SHIFT: u32 = { + let c = Self::C; + if Self::C_SHIFT_KIND == 1 { + log2_pow2_u64(c - 1) + } else if Self::C_SHIFT_KIND == -1 { + if c == u64::MAX { + 64 + } else { + log2_pow2_u64(c + 1) + } + } else { + 0 + } + }; + + /// Mask for extracting the low `BITS` bits from a u128. + const MASK: u128 = if Self::BITS == 64 { + u64::MAX as u128 + } else { + (1u128 << Self::BITS) - 1 + }; + + /// u64-width mask (only valid when BITS < 64). + const MASK64: u64 = if Self::BITS < 64 { + (1u64 << Self::BITS) - 1 + } else { + u64::MAX + }; + + /// Whether Solinas folding of a multiplication product can stay + /// entirely in u64. True when BITS < 64 and C·2^BITS < 2^64. + const FOLD_IN_U64: bool = Self::BITS < 64 && (Self::C as u128) < (1u128 << (64 - Self::BITS)); + + /// u64 multiply by C, split into u32-wide halves so LLVM emits + /// `umull` (32×32→64) instead of promoting to u128. + /// Only valid when C fits in u32 (always true: C < sqrt(P) < 2^32). + #[inline(always)] + fn mul_c_narrow(x: u64) -> u64 { + #[cfg(target_arch = "x86_64")] + { + // x86_64 has fast scalar 64-bit multiply; use one multiply instead + // of two widened 32-bit multiplies in the fold hot path. 
+ Self::C.wrapping_mul(x) + } + #[cfg(not(target_arch = "x86_64"))] + { + let c = Self::C as u32; + let x_lo = x as u32; + let x_hi = (x >> 32) as u32; + (c as u64 * x_lo as u64).wrapping_add((c as u64 * x_hi as u64) << 32) + } + } + + /// Multiply `x` by `C`. For `C = 2^a ± 1` uses shift+add/sub. + #[inline(always)] + fn mul_c(x: u64) -> u128 { + if Self::C_SHIFT_KIND == 1 { + ((x as u128) << Self::C_SHIFT) + x as u128 + } else if Self::C_SHIFT_KIND == -1 { + ((x as u128) << Self::C_SHIFT) - x as u128 + } else { + (Self::C as u128) * (x as u128) + } + } + + /// Create from a canonical representative in `[0, P)`. + #[inline] + pub fn from_canonical_u64(x: u64) -> Self { + debug_assert!(x < P); + Self(x) + } + + /// Return the canonical representative in `[0, P)`. + #[inline] + pub fn to_canonical_u64(self) -> u64 { + self.0 + } + + /// Solinas reduction: fold a u128 at bit `BITS` until the value fits, + /// then conditionally subtract `P`. + /// + /// For multiplication products (< 2^{2·BITS}) exactly 2 folds suffice; + /// for arbitrary u128 inputs the loop runs at most `ceil(128 / BITS)` + /// iterations. + #[inline(always)] + fn reduce_u128(x: u128) -> u64 { + let mut v = x; + while v >> Self::BITS != 0 { + v = (v & Self::MASK) + Self::mul_c((v >> Self::BITS) as u64); + } + let reduced = v.wrapping_sub(P as u128); + let borrow = reduced >> 127; + reduced.wrapping_add(borrow.wrapping_neg() & (P as u128)) as u64 + } + + /// Two-fold Solinas reduction for multiplication products. + /// + /// Input must be < 2^{2·BITS} (guaranteed for `a*b` where `a,b < P`). + /// Exactly 2 folds + conditional subtract, no loop. + /// + /// When `FOLD_IN_U64` is true the entire reduction stays in u64, + /// avoiding expensive u128 mask/shift on sub-word primes. 
+ #[inline(always)] + fn reduce_product(x: u128) -> u64 { + if Self::FOLD_IN_U64 { + let lo = x as u64; + let hi = (x >> 64) as u64; + let high = (lo >> Self::BITS) | (hi << (64 - Self::BITS)); + let f1 = (lo & Self::MASK64) + Self::mul_c_narrow(high); + let f2 = (f1 & Self::MASK64) + Self::mul_c_narrow(f1 >> Self::BITS); + let reduced = f2.wrapping_sub(P); + let borrow = reduced >> 63; + reduced.wrapping_add(borrow.wrapping_neg() & P) + } else { + let f1 = (x & Self::MASK) + Self::mul_c((x >> Self::BITS) as u64); + let f2 = (f1 & Self::MASK) + Self::mul_c((f1 >> Self::BITS) as u64); + let reduced = f2.wrapping_sub(P as u128); + let borrow = reduced >> 127; + reduced.wrapping_add(borrow.wrapping_neg() & (P as u128)) as u64 + } + } + + /// BMI2 fast path: avoid re-materializing `u128` product in the common + /// sub-word configuration where reduction stays in `u64`. + #[cfg(all(target_arch = "x86_64", target_feature = "bmi2"))] + #[inline(always)] + fn reduce_product_wide(lo: u64, hi: u64) -> u64 { + if Self::FOLD_IN_U64 { + let high = (lo >> Self::BITS) | (hi << (64 - Self::BITS)); + let f1 = (lo & Self::MASK64) + Self::mul_c_narrow(high); + let f2 = (f1 & Self::MASK64) + Self::mul_c_narrow(f1 >> Self::BITS); + let reduced = f2.wrapping_sub(P); + let borrow = reduced >> 63; + reduced.wrapping_add(borrow.wrapping_neg() & P) + } else { + Self::reduce_product(lo as u128 | ((hi as u128) << 64)) + } + } + + #[inline(always)] + fn add_raw(a: u64, b: u64) -> u64 { + if Self::BITS <= 62 { + let s = a + b; + let reduced = s.wrapping_sub(P); + let borrow = reduced >> 63; + reduced.wrapping_add(borrow.wrapping_neg() & P) + } else { + let s = (a as u128) + (b as u128); + let reduced = s.wrapping_sub(P as u128); + let borrow = reduced >> 127; + reduced.wrapping_add(borrow.wrapping_neg() & (P as u128)) as u64 + } + } + + #[inline(always)] + fn sub_raw(a: u64, b: u64) -> u64 { + if Self::BITS <= 62 { + let diff = a.wrapping_sub(b); + let borrow = diff >> 63; + 
diff.wrapping_add(borrow.wrapping_neg() & P) + } else { + let diff = (a as u128).wrapping_sub(b as u128); + let borrow = diff >> 127; + diff.wrapping_add(borrow.wrapping_neg() & (P as u128)) as u64 + } + } + + #[inline(always)] + fn mul_raw(a: u64, b: u64) -> u64 { + #[cfg(all(target_arch = "x86_64", target_feature = "bmi2"))] + { + let (lo, hi) = mul64_wide(a, b); + Self::reduce_product_wide(lo, hi) + } + #[cfg(not(all(target_arch = "x86_64", target_feature = "bmi2")))] + { + Self::reduce_product((a as u128) * (b as u128)) + } + } + + #[inline(always)] + fn sqr_raw(a: u64) -> u64 { + Self::mul_raw(a, a) + } + + /// Squaring, equivalent to `self * self`. + #[inline(always)] + pub fn square(self) -> Self { + Self(Self::sqr_raw(self.0)) + } + + fn pow(self, mut exp: u64) -> Self { + let mut base = self; + let mut acc = Self::one(); + while exp > 0 { + if (exp & 1) == 1 { + acc *= base; + } + base = base.square(); + exp >>= 1; + } + acc + } + + /// Extract the canonical value. + #[inline(always)] + pub fn to_limbs(self) -> u64 { + self.0 + } + + /// 64×64 → 128-bit widening multiply, **no reduction**. + #[inline(always)] + pub fn mul_wide(self, other: Self) -> u128 { + let (lo, hi) = mul64_wide(self.0, other.0); + lo as u128 | ((hi as u128) << 64) + } + + /// 64×64 → 128-bit widening multiply with a raw `u64` operand, + /// **no reduction**. + #[inline(always)] + pub fn mul_wide_u64(self, other: u64) -> u128 { + let (lo, hi) = mul64_wide(self.0, other); + lo as u128 | ((hi as u128) << 64) + } + + /// Reduce a u128 value via Solinas folding to a canonical field element. + #[inline(always)] + pub fn solinas_reduce(x: u128) -> Self { + Self(Self::reduce_u128(x)) + } +} + +impl Add for Fp64

{ + type Output = Self; + #[inline] + fn add(self, rhs: Self) -> Self::Output { + Self(Self::add_raw(self.0, rhs.0)) + } +} + +impl Sub for Fp64

{ + type Output = Self; + #[inline] + fn sub(self, rhs: Self) -> Self::Output { + Self(Self::sub_raw(self.0, rhs.0)) + } +} + +impl Mul for Fp64

{ + type Output = Self; + #[inline] + fn mul(self, rhs: Self) -> Self::Output { + Self(Self::mul_raw(self.0, rhs.0)) + } +} + +impl Neg for Fp64

{ + type Output = Self; + #[inline] + fn neg(self) -> Self::Output { + Self(Self::sub_raw(0, self.0)) + } +} + +impl AddAssign for Fp64

{ + #[inline] + fn add_assign(&mut self, rhs: Self) { + *self = *self + rhs; + } +} + +impl SubAssign for Fp64

{ + #[inline] + fn sub_assign(&mut self, rhs: Self) { + *self = *self - rhs; + } +} + +impl MulAssign for Fp64

{ + #[inline] + fn mul_assign(&mut self, rhs: Self) { + *self = *self * rhs; + } +} + +impl<'a, const P: u64> Add<&'a Self> for Fp64

{ + type Output = Self; + #[inline] + fn add(self, rhs: &'a Self) -> Self::Output { + self + *rhs + } +} + +impl<'a, const P: u64> Sub<&'a Self> for Fp64

{ + type Output = Self; + #[inline] + fn sub(self, rhs: &'a Self) -> Self::Output { + self - *rhs + } +} + +impl<'a, const P: u64> Mul<&'a Self> for Fp64

{ + type Output = Self; + #[inline] + fn mul(self, rhs: &'a Self) -> Self::Output { + self * *rhs + } +} + +impl Valid for Fp64

{ + fn check(&self) -> Result<(), SerializationError> { + if self.0 < P { + Ok(()) + } else { + Err(SerializationError::InvalidData("Fp64 out of range".into())) + } + } +} + +impl HachiSerialize for Fp64

{ + fn serialize_with_mode( + &self, + mut writer: W, + _compress: Compress, + ) -> Result<(), SerializationError> { + self.0.serialize_with_mode(&mut writer, Compress::No)?; + Ok(()) + } + + fn serialized_size(&self, _compress: Compress) -> usize { + 8 + } +} + +impl HachiDeserialize for Fp64

{ + fn deserialize_with_mode( + mut reader: R, + _compress: Compress, + validate: Validate, + ) -> Result { + let x = u64::deserialize_with_mode(&mut reader, Compress::No, validate)?; + if matches!(validate, Validate::Yes) && x >= P { + return Err(SerializationError::InvalidData( + "Fp64 out of range".to_string(), + )); + } + let out = if matches!(validate, Validate::Yes) { + Self(x) + } else { + Self(Self::reduce_u128(x as u128)) + }; + Ok(out) + } +} + +impl AdditiveGroup for Fp64

{ + const ZERO: Self = Self(0); +} + +impl FieldCore for Fp64

{ + fn one() -> Self { + Self(if P > 1 { 1 } else { 0 }) + } + + fn is_zero(&self) -> bool { + self.0 == 0 + } + + fn inv(self) -> Option { + let inv = self.inv_or_zero(); + if self.is_zero() { + None + } else { + Some(inv) + } + } + + const TWO_INV: Self = Self((P as u128).div_ceil(2) as u64); +} + +impl Invertible for Fp64

{ + fn inv_or_zero(self) -> Self { + let candidate = self.pow(P.wrapping_sub(2)); + let nz = ((self.0 | self.0.wrapping_neg()) >> 63) & 1; + let mask = 0u64.wrapping_sub(nz); + Self(candidate.0 & mask) + } +} + +impl FieldSampling for Fp64

{ + fn sample(rng: &mut R) -> Self { + let lo = rng.next_u64() as u128; + let hi = rng.next_u64() as u128; + Self(Self::reduce_u128(lo | (hi << 64))) + } +} + +impl FromSmallInt for Fp64

{ + fn from_u64(val: u64) -> Self { + Self(Self::reduce_u128(val as u128)) + } + + fn from_i64(val: i64) -> Self { + if val >= 0 { + Self::from_u64(val as u64) + } else { + -Self::from_u64(val.unsigned_abs()) + } + } +} + +impl CanonicalField for Fp64

{ + fn to_canonical_u128(self) -> u128 { + self.0 as u128 + } + + fn from_canonical_u128_checked(val: u128) -> Option { + if val < P as u128 { + Some(Self(val as u64)) + } else { + None + } + } + + fn from_canonical_u128_reduced(val: u128) -> Self { + Self(Self::reduce_u128(val)) + } +} + +impl PseudoMersenneField for Fp64

{ + const MODULUS_BITS: u32 = Self::BITS; + const MODULUS_OFFSET: u128 = Self::C as u128; +} + +#[cfg(test)] +mod tests { + use super::*; + use rand::rngs::StdRng; + use rand::SeedableRng; + + type F40 = Fp64<{ (1u64 << 40) - 195 }>; // 2^40 - 195 + type F64 = Fp64<{ u64::MAX - 58 }>; // 2^64 - 59 + + #[test] + fn solinas_constants() { + assert_eq!(F40::BITS, 40); + assert_eq!(F40::C, 195); + + assert_eq!(F64::BITS, 64); + assert_eq!(F64::C, 59); + } + + #[test] + fn basic_arithmetic_sub_word() { + let a = F40::from_u64(1_000_000); + let b = F40::from_u64(2_000_000); + let p = (1u64 << 40) - 195; + assert_eq!((a + b).to_canonical_u64(), 3_000_000); + assert_eq!( + (a * b).to_canonical_u64(), + (1_000_000u128 * 2_000_000u128 % p as u128) as u64 + ); + } + + #[test] + fn basic_arithmetic_full_word() { + let a = F64::from_u64(1_000_000_000); + let b = F64::from_u64(2_000_000_000); + let p = u64::MAX - 58; + assert_eq!( + (a * b).to_canonical_u64(), + (1_000_000_000u128 * 2_000_000_000u128 % p as u128) as u64 + ); + } + + #[test] + fn mul_wide_matches_full_mul() { + let mut rng = StdRng::seed_from_u64(0xdead_beef); + for _ in 0..1000 { + let a: F40 = FieldSampling::sample(&mut rng); + let b: F40 = FieldSampling::sample(&mut rng); + let expected = a * b; + let reduced = F40::solinas_reduce(a.mul_wide(b)); + assert_eq!(reduced, expected); + } + } + + #[test] + fn mul_wide_u64_matches() { + let mut rng = StdRng::seed_from_u64(0xcafe_d00d); + for _ in 0..1000 { + let a: F40 = FieldSampling::sample(&mut rng); + let b = rng.next_u64() % ((1u64 << 40) - 195); + let expected = a * F40::from_canonical_u64(b); + let reduced = F40::solinas_reduce(a.mul_wide_u64(b)); + assert_eq!(reduced, expected); + } + } + + #[test] + fn pseudo_mersenne_trait() { + assert_eq!(::MODULUS_BITS, 40); + assert_eq!(::MODULUS_OFFSET, 195); + assert_eq!(::MODULUS_BITS, 64); + assert_eq!(::MODULUS_OFFSET, 59); + } + + #[test] + fn shift_optimization_detected() { + type G = Fp64<{ (1u64 << 56) - 27 }>; 
// C = 27, not 2^a±1 + assert_eq!(G::C_SHIFT_KIND, 0); + + type H = Fp64<{ u64::MAX - 58 }>; // C = 59, not 2^a±1 + assert_eq!(H::C_SHIFT_KIND, 0); + } + + #[test] + fn reduce_u128_large() { + assert_eq!(F64::from_canonical_u128_reduced(u128::MAX), { + let p = u64::MAX as u128 - 58; + F64::from_canonical_u64((u128::MAX % p) as u64) + }); + } +} diff --git a/src/algebra/fields/lift.rs b/src/algebra/fields/lift.rs new file mode 100644 index 00000000..e05a8aa7 --- /dev/null +++ b/src/algebra/fields/lift.rs @@ -0,0 +1,100 @@ +//! Helpers for embedding base fields into extension fields. + +use crate::algebra::fields::ext::{Fp2, Fp2Config, Fp4, Fp4Config}; +use crate::primitives::serialization::Valid; +use crate::{FieldCore, FromSmallInt}; + +/// Lift a base-field element into an extension field. +/// +/// This is intentionally small: for extension towers we embed into the constant term. +pub trait LiftBase: FieldCore { + /// Embed `x ∈ F` as a constant in `Self`. + fn lift_base(x: F) -> Self; +} + +/// An algebraic extension of base field `F`. +/// +/// Provides the extension degree and a constructor from a slice of base-field +/// coefficients (in the canonical basis `{1, u, u^2, ...}`). +pub trait ExtField: FieldCore + LiftBase + FromSmallInt { + /// Extension degree: `[Self : F]`. + const EXT_DEGREE: usize; + + /// Construct from a coefficient slice `[c0, c1, ..., c_{d-1}]`. + /// + /// # Panics + /// Panics if `coeffs.len() != Self::EXT_DEGREE`. 
+ fn from_base_slice(coeffs: &[F]) -> Self; +} + +impl ExtField for F { + const EXT_DEGREE: usize = 1; + + #[inline] + fn from_base_slice(coeffs: &[F]) -> Self { + assert_eq!(coeffs.len(), 1); + coeffs[0] + } +} + +impl ExtField for Fp2 +where + F: FieldCore + FromSmallInt + Valid, + C: Fp2Config, +{ + const EXT_DEGREE: usize = 2; + + #[inline] + fn from_base_slice(coeffs: &[F]) -> Self { + assert_eq!(coeffs.len(), 2); + Self::new(coeffs[0], coeffs[1]) + } +} + +impl ExtField for Fp4 +where + F: FieldCore + FromSmallInt + Valid, + C2: Fp2Config, + C4: Fp4Config, +{ + const EXT_DEGREE: usize = 4; + + #[inline] + fn from_base_slice(coeffs: &[F]) -> Self { + assert_eq!(coeffs.len(), 4); + Self::new( + Fp2::new(coeffs[0], coeffs[1]), + Fp2::new(coeffs[2], coeffs[3]), + ) + } +} + +impl LiftBase for F { + #[inline] + fn lift_base(x: F) -> Self { + x + } +} + +impl LiftBase for Fp2 +where + F: FieldCore + Valid, + C: Fp2Config, +{ + #[inline] + fn lift_base(x: F) -> Self { + Self::new(x, F::zero()) + } +} + +impl LiftBase for Fp4 +where + F: FieldCore + Valid, + C2: Fp2Config, + C4: Fp4Config, +{ + #[inline] + fn lift_base(x: F) -> Self { + Self::new(Fp2::new(x, F::zero()), Fp2::new(F::zero(), F::zero())) + } +} diff --git a/src/algebra/fields/mod.rs b/src/algebra/fields/mod.rs new file mode 100644 index 00000000..90ff6c91 --- /dev/null +++ b/src/algebra/fields/mod.rs @@ -0,0 +1,47 @@ +//! Prime fields and extension field towers. 
+ +pub mod ext; +pub mod fp128; +pub mod fp32; +pub mod fp64; +pub mod lift; +pub mod packed; +#[cfg(all( + target_arch = "x86_64", + target_feature = "avx2", + not(all(target_feature = "avx512f", target_feature = "avx512dq")) +))] +pub mod packed_avx2; +#[cfg(all( + target_arch = "x86_64", + target_feature = "avx512f", + target_feature = "avx512dq" +))] +pub mod packed_avx512; +pub mod packed_ext; +#[cfg(all(target_arch = "aarch64", target_feature = "neon"))] +pub mod packed_neon; +pub mod pseudo_mersenne; +pub(crate) mod util; +pub mod wide; + +pub use ext::{Ext2, Ext4, Fp2, Fp2Config, Fp4, Fp4Config, NegOneNr, TwoNr, UnitNr}; +pub use fp128::{ + Fp128, Prime128M13M4P0, Prime128M37P3P0, Prime128M52M3P0, Prime128M54P4P0, Prime128M8M4M1M0, +}; +pub use fp32::Fp32; +pub use fp64::Fp64; +pub use lift::{ExtField, LiftBase}; +pub use packed::{ + Fp128Packing, Fp32Packing, Fp64Packing, HasPacking, NoPacking, PackedField, PackedValue, +}; +pub use pseudo_mersenne::{ + is_pow2_offset, pow2_offset, pseudo_mersenne_modulus, Pow2Offset128Field, Pow2Offset24Field, + Pow2Offset30Field, Pow2Offset31Field, Pow2Offset32Field, Pow2Offset40Field, Pow2Offset48Field, + Pow2Offset56Field, Pow2Offset64Field, Pow2OffsetPrimeSpec, POW2_OFFSET_IMPLEMENTED_MAX_BITS, + POW2_OFFSET_MAX, POW2_OFFSET_PRIMES, POW2_OFFSET_TABLE, +}; +pub use wide::{ + AccumPair, Fp128MulU64Accum, Fp128ProductAccum, Fp128x8i32, Fp32x2i32, Fp64ProductAccum, + Fp64x4i32, HasUnreducedOps, HasWide, ReduceTo, +}; diff --git a/src/algebra/fields/packed.rs b/src/algebra/fields/packed.rs new file mode 100644 index 00000000..6da7917b --- /dev/null +++ b/src/algebra/fields/packed.rs @@ -0,0 +1,455 @@ +//! Packed field abstractions and architecture-specific SIMD backends. + +use crate::algebra::fields::{Fp128, Fp32, Fp64}; +use crate::FieldCore; +use core::ops::{Add, AddAssign, Mul, MulAssign, Sub, SubAssign}; + +/// Array-like packed values over a scalar type. 
+pub trait PackedValue: 'static + Copy + Send + Sync { + /// Scalar value type carried by each lane. + type Value: 'static + Copy + Send + Sync; + + /// Number of scalar lanes. + const WIDTH: usize; + + /// Build from a lane generator. + fn from_fn(f: F) -> Self + where + F: FnMut(usize) -> Self::Value; + + /// Extract one lane. + fn extract(&self, lane: usize) -> Self::Value; + + /// Pack a scalar slice into packed values. + /// + /// # Panics + /// + /// Panics if the length is not divisible by `WIDTH`. + #[inline] + fn pack_slice(buf: &[Self::Value]) -> Vec { + assert!( + buf.len() % Self::WIDTH == 0, + "slice length {} must be divisible by WIDTH {}", + buf.len(), + Self::WIDTH + ); + buf.chunks_exact(Self::WIDTH) + .map(|chunk| Self::from_fn(|i| chunk[i])) + .collect() + } + + /// Packed prefix + scalar suffix split. + #[inline] + fn pack_slice_with_suffix(buf: &[Self::Value]) -> (Vec, &[Self::Value]) { + let split = buf.len() - (buf.len() % Self::WIDTH); + let (packed, suffix) = buf.split_at(split); + (Self::pack_slice(packed), suffix) + } + + /// Unpack packed values into a flat scalar vector. + #[inline] + fn unpack_slice(buf: &[Self]) -> Vec { + let mut out = Vec::with_capacity(buf.len() * Self::WIDTH); + for packed in buf { + for lane in 0..Self::WIDTH { + out.push(packed.extract(lane)); + } + } + out + } +} + +/// Packed arithmetic over a scalar field. +pub trait PackedField: + PackedValue + Add + Sub + Mul +{ + /// Scalar field type. + type Scalar: FieldCore; + + /// Broadcast one scalar across all lanes. + fn broadcast(value: Self::Scalar) -> Self; +} + +/// Scalar fallback packed type with one lane. 
+#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)] +#[repr(transparent)] +pub struct NoPacking(pub [T; 1]); + +impl PackedValue for NoPacking +where + T: 'static + Copy + Send + Sync, +{ + type Value = T; + const WIDTH: usize = 1; + + #[inline] + fn from_fn(mut f: F) -> Self + where + F: FnMut(usize) -> Self::Value, + { + Self([f(0)]) + } + + #[inline] + fn extract(&self, lane: usize) -> Self::Value { + debug_assert_eq!(lane, 0); + self.0[0] + } +} + +impl Add for NoPacking { + type Output = Self; + #[inline] + fn add(self, rhs: Self) -> Self { + Self([self.0[0] + rhs.0[0]]) + } +} + +impl Sub for NoPacking { + type Output = Self; + #[inline] + fn sub(self, rhs: Self) -> Self { + Self([self.0[0] - rhs.0[0]]) + } +} + +impl Mul for NoPacking { + type Output = Self; + #[inline] + fn mul(self, rhs: Self) -> Self { + Self([self.0[0] * rhs.0[0]]) + } +} + +impl AddAssign for NoPacking { + #[inline] + fn add_assign(&mut self, rhs: Self) { + *self = *self + rhs; + } +} + +impl SubAssign for NoPacking { + #[inline] + fn sub_assign(&mut self, rhs: Self) { + *self = *self - rhs; + } +} + +impl MulAssign for NoPacking { + #[inline] + fn mul_assign(&mut self, rhs: Self) { + *self = *self * rhs; + } +} + +impl PackedField for NoPacking { + type Scalar = T; + + #[inline] + fn broadcast(value: Self::Scalar) -> Self { + Self([value]) + } +} + +/// Scalar field -> packed field association. +pub trait HasPacking: FieldCore { + /// Packed representation for this scalar field. + type Packing: PackedField; +} + +/// Selected packed backend for `Fp128`. +#[cfg(all(target_arch = "aarch64", target_feature = "neon"))] +pub type Fp128Packing = super::packed_neon::PackedFp128Neon

; + +/// Selected packed backend for `Fp128`. +#[cfg(all( + target_arch = "x86_64", + target_feature = "avx512f", + target_feature = "avx512dq" +))] +pub type Fp128Packing = super::packed_avx512::PackedFp128Avx512

; + +/// Selected packed backend for `Fp128`. +#[cfg(all( + target_arch = "x86_64", + target_feature = "avx2", + not(all(target_feature = "avx512f", target_feature = "avx512dq")) +))] +pub type Fp128Packing = super::packed_avx2::PackedFp128Avx2

; + +/// Selected packed backend for `Fp128`. +#[cfg(not(any( + all(target_arch = "aarch64", target_feature = "neon"), + all(target_arch = "x86_64", target_feature = "avx2") +)))] +pub type Fp128Packing = NoPacking>; + +impl HasPacking for Fp128

{ + type Packing = Fp128Packing

; +} + +/// Selected packed backend for `Fp32`. +#[cfg(all(target_arch = "aarch64", target_feature = "neon"))] +pub type Fp32Packing = super::packed_neon::PackedFp32Neon

; + +/// Selected packed backend for `Fp32`. +#[cfg(all( + target_arch = "x86_64", + target_feature = "avx512f", + target_feature = "avx512dq" +))] +pub type Fp32Packing = super::packed_avx512::PackedFp32Avx512

; + +/// Selected packed backend for `Fp32`. +#[cfg(all( + target_arch = "x86_64", + target_feature = "avx2", + not(all(target_feature = "avx512f", target_feature = "avx512dq")) +))] +pub type Fp32Packing = super::packed_avx2::PackedFp32Avx2

; + +/// Selected packed backend for `Fp32`. +#[cfg(not(any( + all(target_arch = "aarch64", target_feature = "neon"), + all(target_arch = "x86_64", target_feature = "avx2") +)))] +pub type Fp32Packing = NoPacking>; + +impl HasPacking for Fp32

{ + type Packing = Fp32Packing

; +} + +/// Selected packed backend for `Fp64`. +#[cfg(all(target_arch = "aarch64", target_feature = "neon"))] +pub type Fp64Packing = super::packed_neon::PackedFp64Neon

; + +/// Selected packed backend for `Fp64`. +#[cfg(all( + target_arch = "x86_64", + target_feature = "avx512f", + target_feature = "avx512dq" +))] +pub type Fp64Packing = super::packed_avx512::PackedFp64Avx512

; + +/// Selected packed backend for `Fp64`. +#[cfg(all( + target_arch = "x86_64", + target_feature = "avx2", + not(all(target_feature = "avx512f", target_feature = "avx512dq")) +))] +pub type Fp64Packing = super::packed_avx2::PackedFp64Avx2

; + +/// Selected packed backend for `Fp64`. +#[cfg(not(any( + all(target_arch = "aarch64", target_feature = "neon"), + all(target_arch = "x86_64", target_feature = "avx2") +)))] +pub type Fp64Packing = NoPacking>; + +impl HasPacking for Fp64

{ + type Packing = Fp64Packing

; +} + +#[cfg(test)] +mod tests { + use super::{HasPacking, PackedField, PackedValue}; + use crate::algebra::fields::{ + Pow2Offset24Field, Pow2Offset31Field, Pow2Offset32Field, Pow2Offset40Field, + Pow2Offset64Field, Prime128M13M4P0, + }; + use crate::{CanonicalField, FieldCore, FieldSampling, FromSmallInt}; + use rand::{rngs::StdRng, RngCore, SeedableRng}; + + fn rand_u128(rng: &mut R) -> u128 { + let lo = rng.next_u64() as u128; + let hi = rng.next_u64() as u128; + lo | (hi << 64) + } + + fn check_packed_add_sub_mul(seed: u64) + where + F: FieldCore + FieldSampling + PartialEq + std::fmt::Debug, + PF: PackedField + PackedValue, + { + let mut rng = StdRng::seed_from_u64(seed); + let len = PF::WIDTH * 17 + 3; + let lhs: Vec = (0..len).map(|_| FieldSampling::sample(&mut rng)).collect(); + let rhs: Vec = (0..len).map(|_| FieldSampling::sample(&mut rng)).collect(); + + let (lhs_p, lhs_s) = PF::pack_slice_with_suffix(&lhs); + let (rhs_p, rhs_s) = PF::pack_slice_with_suffix(&rhs); + + let add_p: Vec = lhs_p + .iter() + .zip(rhs_p.iter()) + .map(|(&a, &b)| a + b) + .collect(); + let sub_p: Vec = lhs_p + .iter() + .zip(rhs_p.iter()) + .map(|(&a, &b)| a - b) + .collect(); + let mul_p: Vec = lhs_p + .iter() + .zip(rhs_p.iter()) + .map(|(&a, &b)| a * b) + .collect(); + + let mut add_out = PF::unpack_slice(&add_p); + let mut sub_out = PF::unpack_slice(&sub_p); + let mut mul_out = PF::unpack_slice(&mul_p); + + for (&a, &b) in lhs_s.iter().zip(rhs_s.iter()) { + add_out.push(a + b); + sub_out.push(a - b); + mul_out.push(a * b); + } + + for i in 0..len { + assert_eq!( + add_out[i], + lhs[i] + rhs[i], + "packed add mismatch at lane {i}" + ); + assert_eq!( + sub_out[i], + lhs[i] - rhs[i], + "packed sub mismatch at lane {i}" + ); + assert_eq!( + mul_out[i], + lhs[i] * rhs[i], + "packed mul mismatch at lane {i}" + ); + } + } + + fn check_broadcast_roundtrip(val: F) + where + F: FieldCore + PartialEq + std::fmt::Debug, + PF: PackedField + PackedValue, + { + let p = 
PF::broadcast(val); + for lane in 0..PF::WIDTH { + assert_eq!(p.extract(lane), val); + } + } + + #[test] + fn packed_fp128_add_sub_mul_match_scalar() { + type F = Prime128M13M4P0; + type PF = ::Packing; + + let mut rng = StdRng::seed_from_u64(0x55aa_4422_1177_0033); + let len = PF::WIDTH * 17 + 3; + let lhs: Vec = (0..len) + .map(|_| F::from_canonical_u128_reduced(rand_u128(&mut rng))) + .collect(); + let rhs: Vec = (0..len) + .map(|_| F::from_canonical_u128_reduced(rand_u128(&mut rng))) + .collect(); + + let (lhs_p, lhs_s) = PF::pack_slice_with_suffix(&lhs); + let (rhs_p, rhs_s) = PF::pack_slice_with_suffix(&rhs); + + let add_p: Vec = lhs_p + .iter() + .zip(rhs_p.iter()) + .map(|(&a, &b)| a + b) + .collect(); + let sub_p: Vec = lhs_p + .iter() + .zip(rhs_p.iter()) + .map(|(&a, &b)| a - b) + .collect(); + let mul_p: Vec = lhs_p + .iter() + .zip(rhs_p.iter()) + .map(|(&a, &b)| a * b) + .collect(); + + let mut add_out = PF::unpack_slice(&add_p); + let mut sub_out = PF::unpack_slice(&sub_p); + let mut mul_out = PF::unpack_slice(&mul_p); + + for (&a, &b) in lhs_s.iter().zip(rhs_s.iter()) { + add_out.push(a + b); + sub_out.push(a - b); + mul_out.push(a * b); + } + + for i in 0..len { + assert_eq!( + add_out[i], + lhs[i] + rhs[i], + "packed add mismatch at lane {i}" + ); + assert_eq!( + sub_out[i], + lhs[i] - rhs[i], + "packed sub mismatch at lane {i}" + ); + assert_eq!( + mul_out[i], + lhs[i] * rhs[i], + "packed mul mismatch at lane {i}" + ); + } + } + + #[test] + fn fp128_broadcast_and_extract_roundtrip() { + type F = Prime128M13M4P0; + type PF = ::Packing; + check_broadcast_roundtrip::(F::from_u64(42)); + } + + #[test] + fn packed_fp32_24b_add_sub_mul() { + type F = Pow2Offset24Field; + type PF = ::Packing; + check_packed_add_sub_mul::(0xaa24_bb24_cc24_dd24); + } + + #[test] + fn packed_fp32_31b_add_sub_mul() { + type F = Pow2Offset31Field; + type PF = ::Packing; + check_packed_add_sub_mul::(0xaa31_bb31_cc31_dd31); + } + + #[test] + fn packed_fp32_32b_add_sub_mul() { 
+ type F = Pow2Offset32Field; + type PF = ::Packing; + check_packed_add_sub_mul::(0xaa32_bb32_cc32_dd32); + } + + #[test] + fn fp32_broadcast_and_extract_roundtrip() { + type F = Pow2Offset24Field; + type PF = ::Packing; + check_broadcast_roundtrip::(F::from_u64(42)); + } + + #[test] + fn packed_fp64_40b_add_sub_mul() { + type F = Pow2Offset40Field; + type PF = ::Packing; + check_packed_add_sub_mul::(0xaa40_bb40_cc40_dd40); + } + + #[test] + fn packed_fp64_64b_add_sub_mul() { + type F = Pow2Offset64Field; + type PF = ::Packing; + check_packed_add_sub_mul::(0xaa64_bb64_cc64_dd64); + } + + #[test] + fn fp64_broadcast_and_extract_roundtrip() { + type F = Pow2Offset40Field; + type PF = ::Packing; + check_broadcast_roundtrip::(F::from_u64(42)); + } +} diff --git a/src/algebra/fields/packed_avx2.rs b/src/algebra/fields/packed_avx2.rs new file mode 100644 index 00000000..c29685e2 --- /dev/null +++ b/src/algebra/fields/packed_avx2.rs @@ -0,0 +1,767 @@ +//! AVX2 packed backends for Fp32, Fp64, Fp128. +//! +//! Techniques adapted from plonky2 (Goldilocks) and plonky3 (Mersenne-31). + +use super::packed::{PackedField, PackedValue}; +use crate::algebra::fields::{Fp128, Fp32, Fp64}; +use core::arch::x86_64::*; +use core::fmt; +use core::mem::transmute; +use core::ops::{Add, AddAssign, Mul, MulAssign, Sub, SubAssign}; + +/// Duplicate high 32 bits of each 64-bit lane into the low 32 bits. +/// Uses the float `movehdup` instruction which runs on port 5 (doesn't compete +/// with multiply on ports 0/1). +#[inline(always)] +unsafe fn movehdup_epi32(x: __m256i) -> __m256i { + _mm256_castps_si256(_mm256_movehdup_ps(_mm256_castsi256_ps(x))) +} + +#[inline(always)] +unsafe fn moveldup_epi32(x: __m256i) -> __m256i { + _mm256_castps_si256(_mm256_moveldup_ps(_mm256_castsi256_ps(x))) +} + +/// 64×64→128 schoolbook multiply using 32×32→64 partial products. +/// Returns (hi, lo) representing the 128-bit product. 
+#[inline] +unsafe fn mul64_64_256(x: __m256i, y: __m256i) -> (__m256i, __m256i) { + let x_hi = movehdup_epi32(x); + let y_hi = movehdup_epi32(y); + + let mul_ll = _mm256_mul_epu32(x, y); + let mul_lh = _mm256_mul_epu32(x, y_hi); + let mul_hl = _mm256_mul_epu32(x_hi, y); + let mul_hh = _mm256_mul_epu32(x_hi, y_hi); + + let mul_ll_hi = _mm256_srli_epi64::<32>(mul_ll); + let t0 = _mm256_add_epi64(mul_hl, mul_ll_hi); + let mask32 = _mm256_set1_epi64x(0xFFFF_FFFF_i64); + let t0_lo = _mm256_and_si256(t0, mask32); + let t0_hi = _mm256_srli_epi64::<32>(t0); + let t1 = _mm256_add_epi64(mul_lh, t0_lo); + let t2 = _mm256_add_epi64(mul_hh, t0_hi); + let t1_hi = _mm256_srli_epi64::<32>(t1); + let res_hi = _mm256_add_epi64(t2, t1_hi); + + let t1_lo = moveldup_epi32(t1); + let res_lo = _mm256_blend_epi32::<0b10101010>(mul_ll, t1_lo); + + (res_hi, res_lo) +} + +/// Number of `Fp32` lanes in an AVX2 packed vector. +pub const FP32_WIDTH: usize = 8; + +/// AVX2 packed arithmetic for `Fp32

`, processing 8 lanes. +#[derive(Clone, Copy)] +#[repr(transparent)] +pub struct PackedFp32Avx2(pub [Fp32

; FP32_WIDTH]); + +impl PackedFp32Avx2

{ + const BITS: u32 = 32 - P.leading_zeros(); + + const C: u32 = { + let c = if Self::BITS == 32 { + 0u32.wrapping_sub(P) + } else { + (1u32 << Self::BITS) - P + }; + assert!(P != 0, "modulus must be nonzero"); + assert!(P & 1 == 1, "modulus must be odd"); + assert!( + (c as u64) * (c as u64 + 1) < P as u64, + "C(C+1) < P required for fused canonicalize" + ); + c + }; + + const MASK_U64: u64 = if Self::BITS == 32 { + u32::MAX as u64 + } else { + (1u64 << Self::BITS) - 1 + }; + + #[inline(always)] + fn to_vec(self) -> __m256i { + unsafe { transmute(self) } + } + + #[inline(always)] + unsafe fn from_vec(v: __m256i) -> Self { + unsafe { transmute(v) } + } +} + +impl Default for PackedFp32Avx2

{ + #[inline] + fn default() -> Self { + Self([Fp32(0); FP32_WIDTH]) + } +} + +impl fmt::Debug for PackedFp32Avx2

{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_tuple("PackedFp32Avx2").field(&self.0).finish() + } +} + +impl PartialEq for PackedFp32Avx2

{ + #[inline] + fn eq(&self, other: &Self) -> bool { + self.0 == other.0 + } +} + +impl Eq for PackedFp32Avx2

{} + +impl Add for PackedFp32Avx2

{ + type Output = Self; + #[inline] + fn add(self, rhs: Self) -> Self { + unsafe { + let a = self.to_vec(); + let b = rhs.to_vec(); + let p = _mm256_set1_epi32(P as i32); + + let result = if Self::BITS <= 31 { + let t = _mm256_add_epi32(a, b); + let u = _mm256_sub_epi32(t, p); + _mm256_min_epu32(t, u) + } else { + let c = _mm256_set1_epi32(Self::C as i32); + let t = _mm256_add_epi32(a, b); + // Emulate unsigned compare: XOR with 0x80000000 converts u32 compare to i32 + let sign32 = _mm256_set1_epi32(i32::MIN); + let overflow = + _mm256_cmpgt_epi32(_mm256_xor_si256(a, sign32), _mm256_xor_si256(t, sign32)); + // Step 1: correct overflow (2^32 ≡ C mod P) + let t2 = _mm256_add_epi32(t, _mm256_and_si256(overflow, c)); + // Step 2: subtract P if t2 >= P + let r = _mm256_sub_epi32(t2, p); + _mm256_min_epu32(t2, r) + }; + + Self::from_vec(result) + } + } +} + +impl Sub for PackedFp32Avx2

{ + type Output = Self; + #[inline] + fn sub(self, rhs: Self) -> Self { + unsafe { + let a = self.to_vec(); + let b = rhs.to_vec(); + let p = _mm256_set1_epi32(P as i32); + + let result = if Self::BITS <= 31 { + let t = _mm256_sub_epi32(a, b); + let u = _mm256_add_epi32(t, p); + _mm256_min_epu32(t, u) + } else { + let t = _mm256_sub_epi32(a, b); + let sign32 = _mm256_set1_epi32(i32::MIN); + let underflow = + _mm256_cmpgt_epi32(_mm256_xor_si256(b, sign32), _mm256_xor_si256(a, sign32)); + let corrected = _mm256_add_epi32(t, p); + _mm256_blendv_epi8(t, corrected, underflow) + }; + + Self::from_vec(result) + } + } +} + +impl Mul for PackedFp32Avx2

{ + type Output = Self; + #[inline] + fn mul(self, rhs: Self) -> Self { + unsafe { + let a = self.to_vec(); + let b = rhs.to_vec(); + + let prod_evn = _mm256_mul_epu32(a, b); + let a_odd = movehdup_epi32(a); + let b_odd = movehdup_epi32(b); + let prod_odd = _mm256_mul_epu32(a_odd, b_odd); + + let mask = _mm256_set1_epi64x(Self::MASK_U64 as i64); + let c_vec = _mm256_set1_epi64x(Self::C as i64); + let shift = _mm_set_epi64x(0, Self::BITS as i64); + + // Fold 1 + let evn_lo = _mm256_and_si256(prod_evn, mask); + let evn_hi = _mm256_srl_epi64(prod_evn, shift); + let evn_f1 = _mm256_add_epi64(evn_lo, _mm256_mul_epu32(evn_hi, c_vec)); + + let odd_lo = _mm256_and_si256(prod_odd, mask); + let odd_hi = _mm256_srl_epi64(prod_odd, shift); + let odd_f1 = _mm256_add_epi64(odd_lo, _mm256_mul_epu32(odd_hi, c_vec)); + + // Fold 2 + let evn_f1_lo = _mm256_and_si256(evn_f1, mask); + let evn_f1_hi = _mm256_srl_epi64(evn_f1, shift); + let evn_f2 = _mm256_add_epi64(evn_f1_lo, _mm256_mul_epu32(evn_f1_hi, c_vec)); + + let odd_f1_lo = _mm256_and_si256(odd_f1, mask); + let odd_f1_hi = _mm256_srl_epi64(odd_f1, shift); + let odd_f2 = _mm256_add_epi64(odd_f1_lo, _mm256_mul_epu32(odd_f1_hi, c_vec)); + + // Recombine even/odd: shift odd results into high 32-bit positions, blend. + let odd_shifted = _mm256_slli_epi64::<32>(odd_f2); + let combined = _mm256_blend_epi32::<0b10101010>(evn_f2, odd_shifted); + + // Conditional subtract P + let p_vec = _mm256_set1_epi32(P as i32); + let reduced = _mm256_sub_epi32(combined, p_vec); + Self::from_vec(_mm256_min_epu32(combined, reduced)) + } + } +} + +impl PackedValue for PackedFp32Avx2

{ + type Value = Fp32

; + const WIDTH: usize = FP32_WIDTH; + + #[inline] + fn from_fn(mut f: F) -> Self + where + F: FnMut(usize) -> Self::Value, + { + Self([f(0), f(1), f(2), f(3), f(4), f(5), f(6), f(7)]) + } + + #[inline] + fn extract(&self, lane: usize) -> Self::Value { + debug_assert!(lane < FP32_WIDTH); + self.0[lane] + } +} + +impl AddAssign for PackedFp32Avx2

{ + #[inline] + fn add_assign(&mut self, rhs: Self) { + *self = *self + rhs; + } +} + +impl SubAssign for PackedFp32Avx2

{ + #[inline] + fn sub_assign(&mut self, rhs: Self) { + *self = *self - rhs; + } +} + +impl MulAssign for PackedFp32Avx2

{ + #[inline] + fn mul_assign(&mut self, rhs: Self) { + *self = *self * rhs; + } +} + +impl PackedField for PackedFp32Avx2

{ + type Scalar = Fp32

; + + #[inline] + fn broadcast(value: Self::Scalar) -> Self { + Self([value; FP32_WIDTH]) + } +} + +/// Number of `Fp64` lanes in an AVX2 packed vector. +pub const FP64_WIDTH: usize = 4; + +/// AVX2 packed arithmetic for `Fp64

`, processing 4 lanes. +#[derive(Clone, Copy)] +#[repr(transparent)] +pub struct PackedFp64Avx2(pub [Fp64

; FP64_WIDTH]); + +impl PackedFp64Avx2

{ + const BITS: u32 = 64 - P.leading_zeros(); + + const C_LO: u64 = { + let c = if Self::BITS == 64 { + 0u64.wrapping_sub(P) + } else { + (1u64 << Self::BITS) - P + }; + assert!(P != 0, "modulus must be nonzero"); + assert!(P & 1 == 1, "modulus must be odd"); + c + }; + + const MASK64: u64 = if Self::BITS < 64 { + (1u64 << Self::BITS) - 1 + } else { + u64::MAX + }; + + #[inline(always)] + fn to_vec(self) -> __m256i { + unsafe { transmute(self) } + } + + #[inline(always)] + unsafe fn from_vec(v: __m256i) -> Self { + unsafe { transmute(v) } + } + + #[inline] + unsafe fn reduce128_vec(hi: __m256i, lo: __m256i) -> __m256i { + if Self::BITS < 64 { + Self::reduce128_small_k(hi, lo) + } else { + Self::reduce128_full_k(hi, lo) + } + } + + /// Reduction for BITS < 64. All intermediates fit in u64 — no overflow. + #[inline] + unsafe fn reduce128_small_k(hi: __m256i, lo: __m256i) -> __m256i { + let mask_k = _mm256_set1_epi64x(Self::MASK64 as i64); + let c_vec = _mm256_set1_epi64x(Self::C_LO as i64); + let p_vec = _mm256_set1_epi64x(P as i64); + let shift_k = _mm_set_epi64x(0, Self::BITS as i64); + let shift_64mk = _mm_set_epi64x(0, (64 - Self::BITS) as i64); + + let lo_k = _mm256_and_si256(lo, mask_k); + let lo_upper = _mm256_srl_epi64(lo, shift_k); + let hi_shifted = _mm256_sll_epi64(hi, shift_64mk); + let hi_k = _mm256_or_si256(lo_upper, hi_shifted); + + let c_hi_lo = _mm256_mul_epu32(c_vec, hi_k); + let hi_k_top = _mm256_srli_epi64::<32>(hi_k); + let c_hi_top = _mm256_mul_epu32(c_vec, hi_k_top); + let c_hi_top_shifted = _mm256_slli_epi64::<32>(c_hi_top); + let c_hi_full = _mm256_add_epi64(c_hi_lo, c_hi_top_shifted); + + let fold1 = _mm256_add_epi64(lo_k, c_hi_full); + + let fold1_lo_k = _mm256_and_si256(fold1, mask_k); + let fold1_hi = _mm256_srl_epi64(fold1, shift_k); + let c_fold1_hi = _mm256_mul_epu32(c_vec, fold1_hi); + let fold2 = _mm256_add_epi64(fold1_lo_k, c_fold1_hi); + + let reduced = _mm256_sub_epi64(fold2, p_vec); + let sign = _mm256_set1_epi64x(i64::MIN); + 
let fold2_s = _mm256_xor_si256(fold2, sign); + let reduced_s = _mm256_xor_si256(reduced, sign); + let fold2_lt = _mm256_cmpgt_epi64(reduced_s, fold2_s); + _mm256_blendv_epi8(reduced, fold2, fold2_lt) + } + + /// Reduction for BITS == 64. Uses XOR-with-SIGN_BIT trick for unsigned + /// overflow detection. + #[inline] + unsafe fn reduce128_full_k(hi: __m256i, lo: __m256i) -> __m256i { + let c_vec = _mm256_set1_epi64x(Self::C_LO as i64); + let p_vec = _mm256_set1_epi64x(P as i64); + let sign = _mm256_set1_epi64x(i64::MIN); + let c_hi_lo = _mm256_mul_epu32(c_vec, hi); + let hi_hi = _mm256_srli_epi64::<32>(hi); + let c_hi_hi = _mm256_mul_epu32(c_vec, hi_hi); + + let c_hi_hi_lo32 = _mm256_slli_epi64::<32>(c_hi_hi); + let c_hi_carry = _mm256_srli_epi64::<32>(c_hi_hi); + + let sum_lo = _mm256_add_epi64(c_hi_lo, c_hi_hi_lo32); + let c_hi_lo_s = _mm256_xor_si256(c_hi_lo, sign); + let sum_lo_s = _mm256_xor_si256(sum_lo, sign); + let carry0 = _mm256_cmpgt_epi64(c_hi_lo_s, sum_lo_s); + let overflow = _mm256_sub_epi64(c_hi_carry, carry0); + + let s = _mm256_add_epi64(lo, sum_lo); + let lo_s = _mm256_xor_si256(lo, sign); + let s_s = _mm256_xor_si256(s, sign); + let carry1 = _mm256_cmpgt_epi64(lo_s, s_s); + let total_overflow = _mm256_sub_epi64(overflow, carry1); + + let final_corr = _mm256_mul_epu32(c_vec, total_overflow); + let result = _mm256_add_epi64(s, final_corr); + let s2_s = _mm256_xor_si256(s, sign); + let result_s = _mm256_xor_si256(result, sign); + let carry_f = _mm256_cmpgt_epi64(s2_s, result_s); + let corr_f = _mm256_and_si256(carry_f, c_vec); + let result = _mm256_add_epi64(result, corr_f); + + let result_s2 = _mm256_xor_si256(result, sign); + let p_s = _mm256_xor_si256(p_vec, sign); + let lt_p = _mm256_cmpgt_epi64(p_s, result_s2); + let sub_amt = _mm256_andnot_si256(lt_p, p_vec); + _mm256_sub_epi64(result, sub_amt) + } +} + +impl Default for PackedFp64Avx2

{ + #[inline] + fn default() -> Self { + Self([Fp64(0); FP64_WIDTH]) + } +} + +impl fmt::Debug for PackedFp64Avx2

{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_tuple("PackedFp64Avx2").field(&self.0).finish() + } +} + +impl PartialEq for PackedFp64Avx2

{ + #[inline] + fn eq(&self, other: &Self) -> bool { + self.0 == other.0 + } +} + +impl Eq for PackedFp64Avx2

{} + +impl Add for PackedFp64Avx2

{ + type Output = Self; + #[inline] + fn add(self, rhs: Self) -> Self { + unsafe { + let a = self.to_vec(); + let b = rhs.to_vec(); + let p = _mm256_set1_epi64x(P as i64); + + let result = if Self::BITS <= 62 { + // a + b < 2P < 2^63: no overflow. + let s = _mm256_add_epi64(a, b); + let r = _mm256_sub_epi64(s, p); + // s < P? Use signed compare after shift trick. + let sign = _mm256_set1_epi64x(i64::MIN); + let s_s = _mm256_xor_si256(s, sign); + let p_s = _mm256_xor_si256(p, sign); + let borrow = _mm256_cmpgt_epi64(p_s, s_s); + _mm256_blendv_epi8(r, s, borrow) + } else { + // a + b can overflow u64. + let s = _mm256_add_epi64(a, b); + let sign = _mm256_set1_epi64x(i64::MIN); + let a_s = _mm256_xor_si256(a, sign); + let s_s = _mm256_xor_si256(s, sign); + let overflow = _mm256_cmpgt_epi64(a_s, s_s); + let c = _mm256_set1_epi64x(Self::C_LO as i64); + let s_plus_c = _mm256_add_epi64(s, c); + let s_minus_p = _mm256_sub_epi64(s, p); + let p_s = _mm256_xor_si256(p, sign); + let lt_p = _mm256_cmpgt_epi64(p_s, s_s); + let no_of = _mm256_blendv_epi8(s_minus_p, s, lt_p); + _mm256_blendv_epi8(no_of, s_plus_c, overflow) + }; + + Self::from_vec(result) + } + } +} + +impl Sub for PackedFp64Avx2

{ + type Output = Self; + #[inline] + fn sub(self, rhs: Self) -> Self { + unsafe { + let a = self.to_vec(); + let b = rhs.to_vec(); + let p = _mm256_set1_epi64x(P as i64); + let d = _mm256_sub_epi64(a, b); + + let sign = _mm256_set1_epi64x(i64::MIN); + let a_s = _mm256_xor_si256(a, sign); + let b_s = _mm256_xor_si256(b, sign); + let underflow = _mm256_cmpgt_epi64(b_s, a_s); + let corrected = _mm256_add_epi64(d, p); + Self::from_vec(_mm256_blendv_epi8(d, corrected, underflow)) + } + } +} + +impl Mul for PackedFp64Avx2

{ + type Output = Self; + #[inline] + fn mul(self, rhs: Self) -> Self { + unsafe { + let (hi, lo) = mul64_64_256(self.to_vec(), rhs.to_vec()); + Self::from_vec(Self::reduce128_vec(hi, lo)) + } + } +} + +impl PackedValue for PackedFp64Avx2

{ + type Value = Fp64

; + const WIDTH: usize = FP64_WIDTH; + + #[inline] + fn from_fn(mut f: F) -> Self + where + F: FnMut(usize) -> Self::Value, + { + Self([f(0), f(1), f(2), f(3)]) + } + + #[inline] + fn extract(&self, lane: usize) -> Self::Value { + debug_assert!(lane < FP64_WIDTH); + self.0[lane] + } +} + +impl AddAssign for PackedFp64Avx2

{ + #[inline] + fn add_assign(&mut self, rhs: Self) { + *self = *self + rhs; + } +} + +impl SubAssign for PackedFp64Avx2

{ + #[inline] + fn sub_assign(&mut self, rhs: Self) { + *self = *self - rhs; + } +} + +impl MulAssign for PackedFp64Avx2

{ + #[inline] + fn mul_assign(&mut self, rhs: Self) { + *self = *self * rhs; + } +} + +impl PackedField for PackedFp64Avx2

{ + type Scalar = Fp64

; + + #[inline] + fn broadcast(value: Self::Scalar) -> Self { + Self([value; FP64_WIDTH]) + } +} + +/// Number of `Fp128` lanes in an AVX2 packed vector. +pub const FP128_WIDTH: usize = 4; + +/// AVX2 packed arithmetic for `Fp128

`, 4 lanes in SoA layout. +/// +/// Stores 4 elements as separate `lo` and `hi` `u64` arrays, enabling +/// vectorized add/sub via `__m256i`. Mul remains scalar per-lane. +#[derive(Clone, Copy)] +pub struct PackedFp128Avx2 { + lo: [u64; FP128_WIDTH], + hi: [u64; FP128_WIDTH], +} + +impl PackedFp128Avx2

{ + const P_LO: u64 = P as u64; + const P_HI: u64 = (P >> 64) as u64; +} + +impl Default for PackedFp128Avx2

{ + #[inline] + fn default() -> Self { + Self { + lo: [0; FP128_WIDTH], + hi: [0; FP128_WIDTH], + } + } +} + +impl fmt::Debug for PackedFp128Avx2

{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let elems: Vec<_> = (0..FP128_WIDTH).map(|i| self.extract(i)).collect(); + f.debug_tuple("PackedFp128Avx2").field(&elems).finish() + } +} + +impl PartialEq for PackedFp128Avx2

{ + #[inline] + fn eq(&self, other: &Self) -> bool { + self.lo == other.lo && self.hi == other.hi + } +} + +impl Eq for PackedFp128Avx2

{} + +impl Add for PackedFp128Avx2

{ + type Output = Self; + #[inline] + fn add(self, rhs: Self) -> Self { + unsafe { + let a_lo = _mm256_loadu_si256(self.lo.as_ptr().cast()); + let a_hi = _mm256_loadu_si256(self.hi.as_ptr().cast()); + let b_lo = _mm256_loadu_si256(rhs.lo.as_ptr().cast()); + let b_hi = _mm256_loadu_si256(rhs.hi.as_ptr().cast()); + let p_lo = _mm256_set1_epi64x(Self::P_LO as i64); + let p_hi = _mm256_set1_epi64x(Self::P_HI as i64); + let sign = _mm256_set1_epi64x(i64::MIN); + let one = _mm256_set1_epi64x(1); + + // 128-bit add with unsigned compare emulation (XOR sign bit) + let sum_lo = _mm256_add_epi64(a_lo, b_lo); + let carry_lo = + _mm256_cmpgt_epi64(_mm256_xor_si256(a_lo, sign), _mm256_xor_si256(sum_lo, sign)); + let carry_lo_bit = _mm256_and_si256(carry_lo, one); + + let hi_tmp = _mm256_add_epi64(a_hi, b_hi); + let ov1 = + _mm256_cmpgt_epi64(_mm256_xor_si256(a_hi, sign), _mm256_xor_si256(hi_tmp, sign)); + let sum_hi = _mm256_add_epi64(hi_tmp, carry_lo_bit); + let ov2 = _mm256_cmpgt_epi64( + _mm256_xor_si256(hi_tmp, sign), + _mm256_xor_si256(sum_hi, sign), + ); + let carry_128 = _mm256_or_si256(ov1, ov2); + + // 128-bit subtract P + let red_lo = _mm256_sub_epi64(sum_lo, p_lo); + let borrow_lo = + _mm256_cmpgt_epi64(_mm256_xor_si256(p_lo, sign), _mm256_xor_si256(sum_lo, sign)); + let borrow_lo_bit = _mm256_and_si256(borrow_lo, one); + + let red_hi_tmp = _mm256_sub_epi64(sum_hi, p_hi); + let bw1 = + _mm256_cmpgt_epi64(_mm256_xor_si256(p_hi, sign), _mm256_xor_si256(sum_hi, sign)); + let red_hi = _mm256_sub_epi64(red_hi_tmp, borrow_lo_bit); + let bw2 = _mm256_cmpgt_epi64( + _mm256_xor_si256(borrow_lo_bit, sign), + _mm256_xor_si256(red_hi_tmp, sign), + ); + let borrow = _mm256_or_si256(bw1, bw2); + + // use_reduced = carry_128 | !borrow + let not_borrow = _mm256_xor_si256(borrow, _mm256_set1_epi64x(-1)); + let use_reduced = _mm256_or_si256(carry_128, not_borrow); + let out_lo = _mm256_blendv_epi8(sum_lo, red_lo, use_reduced); + let out_hi = _mm256_blendv_epi8(sum_hi, red_hi, 
use_reduced); + + let mut result = Self::default(); + _mm256_storeu_si256(result.lo.as_mut_ptr().cast(), out_lo); + _mm256_storeu_si256(result.hi.as_mut_ptr().cast(), out_hi); + result + } + } +} + +impl Sub for PackedFp128Avx2

{ + type Output = Self; + #[inline] + fn sub(self, rhs: Self) -> Self { + unsafe { + let a_lo = _mm256_loadu_si256(self.lo.as_ptr().cast()); + let a_hi = _mm256_loadu_si256(self.hi.as_ptr().cast()); + let b_lo = _mm256_loadu_si256(rhs.lo.as_ptr().cast()); + let b_hi = _mm256_loadu_si256(rhs.hi.as_ptr().cast()); + let p_lo = _mm256_set1_epi64x(Self::P_LO as i64); + let p_hi = _mm256_set1_epi64x(Self::P_HI as i64); + let sign = _mm256_set1_epi64x(i64::MIN); + let one = _mm256_set1_epi64x(1); + + // 128-bit sub + let diff_lo = _mm256_sub_epi64(a_lo, b_lo); + let borrow_lo = + _mm256_cmpgt_epi64(_mm256_xor_si256(b_lo, sign), _mm256_xor_si256(a_lo, sign)); + let borrow_lo_bit = _mm256_and_si256(borrow_lo, one); + + let hi_tmp = _mm256_sub_epi64(a_hi, b_hi); + let bw1 = + _mm256_cmpgt_epi64(_mm256_xor_si256(b_hi, sign), _mm256_xor_si256(a_hi, sign)); + let diff_hi = _mm256_sub_epi64(hi_tmp, borrow_lo_bit); + let bw2 = _mm256_cmpgt_epi64( + _mm256_xor_si256(borrow_lo_bit, sign), + _mm256_xor_si256(hi_tmp, sign), + ); + let borrow_128 = _mm256_or_si256(bw1, bw2); + + // Correction: add P back where underflow occurred + let corr_lo = _mm256_add_epi64(diff_lo, p_lo); + let carry_lo = _mm256_cmpgt_epi64( + _mm256_xor_si256(diff_lo, sign), + _mm256_xor_si256(corr_lo, sign), + ); + let carry_lo_bit = _mm256_and_si256(carry_lo, one); + let corr_hi = _mm256_add_epi64(diff_hi, p_hi); + let corr_hi = _mm256_add_epi64(corr_hi, carry_lo_bit); + + let out_lo = _mm256_blendv_epi8(diff_lo, corr_lo, borrow_128); + let out_hi = _mm256_blendv_epi8(diff_hi, corr_hi, borrow_128); + + let mut result = Self::default(); + _mm256_storeu_si256(result.lo.as_mut_ptr().cast(), out_lo); + _mm256_storeu_si256(result.hi.as_mut_ptr().cast(), out_hi); + result + } + } +} + +impl Mul for PackedFp128Avx2

{ + type Output = Self; + #[inline] + fn mul(self, rhs: Self) -> Self { + let mut out = Self::default(); + for i in 0..FP128_WIDTH { + let a = Fp128::

([self.lo[i], self.hi[i]]); + let b = Fp128::

([rhs.lo[i], rhs.hi[i]]); + let r = a * b; + out.lo[i] = r.0[0]; + out.hi[i] = r.0[1]; + } + out + } +} + +impl PackedValue for PackedFp128Avx2

{ + type Value = Fp128

; + const WIDTH: usize = FP128_WIDTH; + + #[inline] + fn from_fn(mut f: F) -> Self + where + F: FnMut(usize) -> Self::Value, + { + let mut lo = [0u64; FP128_WIDTH]; + let mut hi = [0u64; FP128_WIDTH]; + for i in 0..FP128_WIDTH { + let v = f(i); + lo[i] = v.0[0]; + hi[i] = v.0[1]; + } + Self { lo, hi } + } + + #[inline] + fn extract(&self, lane: usize) -> Self::Value { + debug_assert!(lane < FP128_WIDTH); + Fp128([self.lo[lane], self.hi[lane]]) + } +} + +impl AddAssign for PackedFp128Avx2

{ + #[inline] + fn add_assign(&mut self, rhs: Self) { + *self = *self + rhs; + } +} + +impl SubAssign for PackedFp128Avx2

{ + #[inline] + fn sub_assign(&mut self, rhs: Self) { + *self = *self - rhs; + } +} + +impl MulAssign for PackedFp128Avx2

{ + #[inline] + fn mul_assign(&mut self, rhs: Self) { + *self = *self * rhs; + } +} + +impl PackedField for PackedFp128Avx2

{ + type Scalar = Fp128

; + + #[inline] + fn broadcast(value: Self::Scalar) -> Self { + Self { + lo: [value.0[0]; FP128_WIDTH], + hi: [value.0[1]; FP128_WIDTH], + } + } +} diff --git a/src/algebra/fields/packed_avx512.rs b/src/algebra/fields/packed_avx512.rs new file mode 100644 index 00000000..343c1dfc --- /dev/null +++ b/src/algebra/fields/packed_avx512.rs @@ -0,0 +1,729 @@ +//! AVX-512 packed backends for Fp32, Fp64, Fp128. +//! +//! Requires AVX-512F + AVX-512DQ. Uses native unsigned comparisons and mask +//! registers for branchless conditionals. + +use super::packed::{PackedField, PackedValue}; +use crate::algebra::fields::{Fp128, Fp32, Fp64}; +use crate::FieldCore; +use core::arch::x86_64::*; +use core::fmt; +use core::mem::transmute; +use core::ops::{Add, AddAssign, Mul, MulAssign, Sub, SubAssign}; + +#[inline(always)] +unsafe fn movehdup_epi32_512(x: __m512i) -> __m512i { + _mm512_castps_si512(_mm512_movehdup_ps(_mm512_castsi512_ps(x))) +} + +#[inline(always)] +unsafe fn moveldup_epi32_512(x: __m512i) -> __m512i { + _mm512_castps_si512(_mm512_moveldup_ps(_mm512_castsi512_ps(x))) +} + +/// 64×64→128 schoolbook multiply using 32×32→64 partial products. +/// Returns (hi, lo) representing the 128-bit product. +/// Adapted from plonky3's Goldilocks AVX-512 backend. 
+#[inline] +unsafe fn mul64_64_512(x: __m512i, y: __m512i) -> (__m512i, __m512i) { + let x_hi = movehdup_epi32_512(x); + let y_hi = movehdup_epi32_512(y); + + let mul_ll = _mm512_mul_epu32(x, y); + let mul_lh = _mm512_mul_epu32(x, y_hi); + let mul_hl = _mm512_mul_epu32(x_hi, y); + let mul_hh = _mm512_mul_epu32(x_hi, y_hi); + + let mul_ll_hi = _mm512_srli_epi64::<32>(mul_ll); + let t0 = _mm512_add_epi64(mul_hl, mul_ll_hi); + let mask32 = _mm512_set1_epi64(0xFFFF_FFFF_i64); + let t0_lo = _mm512_and_si512(t0, mask32); + let t0_hi = _mm512_srli_epi64::<32>(t0); + let t1 = _mm512_add_epi64(mul_lh, t0_lo); + let t2 = _mm512_add_epi64(mul_hh, t0_hi); + let t1_hi = _mm512_srli_epi64::<32>(t1); + let res_hi = _mm512_add_epi64(t2, t1_hi); + + let t1_lo = moveldup_epi32_512(t1); + let res_lo = _mm512_mask_blend_epi32(0b0101_0101_0101_0101, t1_lo, mul_ll); + + (res_hi, res_lo) +} + +/// Number of `Fp32` lanes in an AVX-512 packed vector. +pub const FP32_WIDTH: usize = 16; + +/// AVX-512 packed arithmetic for `Fp32

`, processing 16 lanes. +#[derive(Clone, Copy)] +#[repr(transparent)] +pub struct PackedFp32Avx512(pub [Fp32

; FP32_WIDTH]); + +impl PackedFp32Avx512

{ + const BITS: u32 = 32 - P.leading_zeros(); + + const C: u32 = { + let c = if Self::BITS == 32 { + 0u32.wrapping_sub(P) + } else { + (1u32 << Self::BITS) - P + }; + assert!(P != 0, "modulus must be nonzero"); + assert!(P & 1 == 1, "modulus must be odd"); + assert!( + (c as u64) * (c as u64 + 1) < P as u64, + "C(C+1) < P required for fused canonicalize" + ); + c + }; + + const MASK_U64: u64 = if Self::BITS == 32 { + u32::MAX as u64 + } else { + (1u64 << Self::BITS) - 1 + }; + + #[inline(always)] + fn to_vec(self) -> __m512i { + unsafe { transmute(self) } + } + + #[inline(always)] + unsafe fn from_vec(v: __m512i) -> Self { + unsafe { transmute(v) } + } +} + +impl Default for PackedFp32Avx512

{ + #[inline] + fn default() -> Self { + Self([Fp32(0); FP32_WIDTH]) + } +} + +impl fmt::Debug for PackedFp32Avx512

{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_tuple("PackedFp32Avx512").field(&self.0).finish() + } +} + +impl PartialEq for PackedFp32Avx512

{ + #[inline] + fn eq(&self, other: &Self) -> bool { + self.0 == other.0 + } +} + +impl Eq for PackedFp32Avx512

{} + +impl Add for PackedFp32Avx512

{ + type Output = Self; + #[inline] + fn add(self, rhs: Self) -> Self { + unsafe { + let a = self.to_vec(); + let b = rhs.to_vec(); + let p = _mm512_set1_epi32(P as i32); + + let result = if Self::BITS <= 31 { + let t = _mm512_add_epi32(a, b); + let u = _mm512_sub_epi32(t, p); + _mm512_min_epu32(t, u) + } else { + let c = _mm512_set1_epi32(Self::C as i32); + let t = _mm512_add_epi32(a, b); + // Step 1: correct overflow (2^32 ≡ C mod P) + let overflow = _mm512_cmplt_epu32_mask(t, a); + let t2 = _mm512_mask_add_epi32(t, overflow, t, c); + // Step 2: subtract P if t2 >= P + let geq_p = _mm512_cmpge_epu32_mask(t2, p); + _mm512_mask_sub_epi32(t2, geq_p, t2, p) + }; + + Self::from_vec(result) + } + } +} + +impl Sub for PackedFp32Avx512

{ + type Output = Self; + #[inline] + fn sub(self, rhs: Self) -> Self { + unsafe { + let a = self.to_vec(); + let b = rhs.to_vec(); + let p = _mm512_set1_epi32(P as i32); + + let result = if Self::BITS <= 31 { + let t = _mm512_sub_epi32(a, b); + let u = _mm512_add_epi32(t, p); + _mm512_min_epu32(t, u) + } else { + let t = _mm512_sub_epi32(a, b); + let underflow = _mm512_cmplt_epu32_mask(a, b); + _mm512_mask_add_epi32(t, underflow, t, p) + }; + + Self::from_vec(result) + } + } +} + +impl Mul for PackedFp32Avx512

{ + type Output = Self; + #[inline] + fn mul(self, rhs: Self) -> Self { + unsafe { + let a = self.to_vec(); + let b = rhs.to_vec(); + + let prod_evn = _mm512_mul_epu32(a, b); + let a_odd = movehdup_epi32_512(a); + let b_odd = movehdup_epi32_512(b); + let prod_odd = _mm512_mul_epu32(a_odd, b_odd); + + let mask = _mm512_set1_epi64(Self::MASK_U64 as i64); + let c_vec = _mm512_set1_epi64(Self::C as i64); + let shift = _mm_set_epi64x(0, Self::BITS as i64); + + // Fold 1 + let evn_lo = _mm512_and_si512(prod_evn, mask); + let evn_hi = _mm512_srl_epi64(prod_evn, shift); + let evn_f1 = _mm512_add_epi64(evn_lo, _mm512_mul_epu32(evn_hi, c_vec)); + + let odd_lo = _mm512_and_si512(prod_odd, mask); + let odd_hi = _mm512_srl_epi64(prod_odd, shift); + let odd_f1 = _mm512_add_epi64(odd_lo, _mm512_mul_epu32(odd_hi, c_vec)); + + // Fold 2 + let evn_f1_lo = _mm512_and_si512(evn_f1, mask); + let evn_f1_hi = _mm512_srl_epi64(evn_f1, shift); + let evn_f2 = _mm512_add_epi64(evn_f1_lo, _mm512_mul_epu32(evn_f1_hi, c_vec)); + + let odd_f1_lo = _mm512_and_si512(odd_f1, mask); + let odd_f1_hi = _mm512_srl_epi64(odd_f1, shift); + let odd_f2 = _mm512_add_epi64(odd_f1_lo, _mm512_mul_epu32(odd_f1_hi, c_vec)); + + // Recombine even/odd + let odd_shifted = _mm512_slli_epi64::<32>(odd_f2); + let combined = _mm512_mask_blend_epi32(0b1010101010101010, evn_f2, odd_shifted); + + // Conditional subtract P + let p_vec = _mm512_set1_epi32(P as i32); + let reduced = _mm512_sub_epi32(combined, p_vec); + Self::from_vec(_mm512_min_epu32(combined, reduced)) + } + } +} + +impl PackedValue for PackedFp32Avx512

{ + type Value = Fp32

; + const WIDTH: usize = FP32_WIDTH; + + #[inline] + fn from_fn(mut f: F) -> Self + where + F: FnMut(usize) -> Self::Value, + { + Self([ + f(0), + f(1), + f(2), + f(3), + f(4), + f(5), + f(6), + f(7), + f(8), + f(9), + f(10), + f(11), + f(12), + f(13), + f(14), + f(15), + ]) + } + + #[inline] + fn extract(&self, lane: usize) -> Self::Value { + debug_assert!(lane < FP32_WIDTH); + self.0[lane] + } +} + +impl AddAssign for PackedFp32Avx512

{ + #[inline] + fn add_assign(&mut self, rhs: Self) { + *self = *self + rhs; + } +} + +impl SubAssign for PackedFp32Avx512

{ + #[inline] + fn sub_assign(&mut self, rhs: Self) { + *self = *self - rhs; + } +} + +impl MulAssign for PackedFp32Avx512

{ + #[inline] + fn mul_assign(&mut self, rhs: Self) { + *self = *self * rhs; + } +} + +impl PackedField for PackedFp32Avx512

{ + type Scalar = Fp32

; + + #[inline] + fn broadcast(value: Self::Scalar) -> Self { + Self([value; FP32_WIDTH]) + } +} + +/// Number of `Fp64` lanes in an AVX-512 packed vector. +pub const FP64_WIDTH: usize = 8; + +/// AVX-512 packed arithmetic for `Fp64

`, processing 8 lanes. +#[derive(Clone, Copy)] +#[repr(transparent)] +pub struct PackedFp64Avx512(pub [Fp64

; FP64_WIDTH]); + +impl PackedFp64Avx512

{ + const BITS: u32 = 64 - P.leading_zeros(); + + const C_LO: u64 = { + let c = if Self::BITS == 64 { + 0u64.wrapping_sub(P) + } else { + (1u64 << Self::BITS) - P + }; + assert!(P != 0, "modulus must be nonzero"); + assert!(P & 1 == 1, "modulus must be odd"); + c + }; + + const MASK64: u64 = if Self::BITS < 64 { + (1u64 << Self::BITS) - 1 + } else { + u64::MAX + }; + + #[inline(always)] + fn to_vec(self) -> __m512i { + unsafe { transmute(self) } + } + + #[inline(always)] + unsafe fn from_vec(v: __m512i) -> Self { + unsafe { transmute(v) } + } + + /// Vectorized 128-bit Solinas reduction for p = 2^BITS - C. + /// Given (hi, lo) = 128-bit product, computes result ≡ (hi*2^64 + lo) mod p. + #[inline] + unsafe fn reduce128_vec(hi: __m512i, lo: __m512i) -> __m512i { + if Self::BITS < 64 { + Self::reduce128_small_k(hi, lo) + } else { + Self::reduce128_full_k(hi, lo) + } + } + + /// Reduction for BITS < 64 (e.g. 40-bit prime). No overflow issues: all + /// intermediates fit in u64. + #[inline] + unsafe fn reduce128_small_k(hi: __m512i, lo: __m512i) -> __m512i { + let mask_k = _mm512_set1_epi64(Self::MASK64 as i64); + let c_vec = _mm512_set1_epi64(Self::C_LO as i64); + let p_vec = _mm512_set1_epi64(P as i64); + let shift_k = _mm_set_epi64x(0, Self::BITS as i64); + let shift_64mk = _mm_set_epi64x(0, (64 - Self::BITS) as i64); + + let lo_k = _mm512_and_si512(lo, mask_k); + let lo_upper = _mm512_srl_epi64(lo, shift_k); + let hi_shifted = _mm512_sll_epi64(hi, shift_64mk); + let hi_k = _mm512_or_si512(lo_upper, hi_shifted); + + // c * hi_k: hi_k may exceed 32 bits, split into lo32 and top + let c_hi_lo = _mm512_mul_epu32(c_vec, hi_k); + let hi_k_top = _mm512_srli_epi64::<32>(hi_k); + let c_hi_top = _mm512_mul_epu32(c_vec, hi_k_top); + let c_hi_top_shifted = _mm512_slli_epi64::<32>(c_hi_top); + let c_hi_full = _mm512_add_epi64(c_hi_lo, c_hi_top_shifted); + + let fold1 = _mm512_add_epi64(lo_k, c_hi_full); + + let fold1_lo_k = _mm512_and_si512(fold1, mask_k); + let fold1_hi = 
_mm512_srl_epi64(fold1, shift_k); + let c_fold1_hi = _mm512_mul_epu32(c_vec, fold1_hi); + let fold2 = _mm512_add_epi64(fold1_lo_k, c_fold1_hi); + + let reduced = _mm512_sub_epi64(fold2, p_vec); + _mm512_min_epu64(fold2, reduced) + } + + /// Reduction for BITS == 64 (e.g. p = 2^64 - 87). Tracks overflow from + /// c*hi exceeding 64 bits, using native unsigned comparisons. + #[inline] + unsafe fn reduce128_full_k(hi: __m512i, lo: __m512i) -> __m512i { + let c_vec = _mm512_set1_epi64(Self::C_LO as i64); + let p_vec = _mm512_set1_epi64(P as i64); + let one = _mm512_set1_epi64(1); + + // c * hi_lo32 + let c_hi_lo = _mm512_mul_epu32(c_vec, hi); + // c * hi_hi32 + let hi_hi = _mm512_srli_epi64::<32>(hi); + let c_hi_hi = _mm512_mul_epu32(c_vec, hi_hi); + + let c_hi_hi_lo32 = _mm512_slli_epi64::<32>(c_hi_hi); + let c_hi_carry = _mm512_srli_epi64::<32>(c_hi_hi); + + // Lower 64 bits of c * hi + let sum_lo = _mm512_add_epi64(c_hi_lo, c_hi_hi_lo32); + let carry0 = _mm512_cmplt_epu64_mask(sum_lo, c_hi_lo); + let overflow = _mm512_mask_add_epi64(c_hi_carry, carry0, c_hi_carry, one); + + // lo + sum_lo + let s = _mm512_add_epi64(lo, sum_lo); + let carry1 = _mm512_cmplt_epu64_mask(s, lo); + let total_overflow = _mm512_mask_add_epi64(overflow, carry1, overflow, one); + + // Fold overflow: total_overflow * c (at most ~2^15) + let final_corr = _mm512_mul_epu32(c_vec, total_overflow); + let result = _mm512_add_epi64(s, final_corr); + let carry_f = _mm512_cmplt_epu64_mask(result, s); + let result = _mm512_mask_add_epi64(result, carry_f, result, c_vec); + + let ge_mask = _mm512_cmpge_epu64_mask(result, p_vec); + _mm512_mask_sub_epi64(result, ge_mask, result, p_vec) + } +} + +impl Default for PackedFp64Avx512

{ + #[inline] + fn default() -> Self { + Self([Fp64(0); FP64_WIDTH]) + } +} + +impl fmt::Debug for PackedFp64Avx512

{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_tuple("PackedFp64Avx512").field(&self.0).finish() + } +} + +impl PartialEq for PackedFp64Avx512

{ + #[inline] + fn eq(&self, other: &Self) -> bool { + self.0 == other.0 + } +} + +impl Eq for PackedFp64Avx512

{} + +impl Add for PackedFp64Avx512

{ + type Output = Self; + #[inline] + fn add(self, rhs: Self) -> Self { + unsafe { + let a = self.to_vec(); + let b = rhs.to_vec(); + let p = _mm512_set1_epi64(P as i64); + + let result = if Self::BITS <= 62 { + let s = _mm512_add_epi64(a, b); + let geq_p = _mm512_cmpge_epu64_mask(s, p); + _mm512_mask_sub_epi64(s, geq_p, s, p) + } else { + let s = _mm512_add_epi64(a, b); + let overflow = _mm512_cmplt_epu64_mask(s, a); + let c = _mm512_set1_epi64(Self::C_LO as i64); + let geq_p = _mm512_cmpge_epu64_mask(s, p); + let no_of = _mm512_mask_sub_epi64(s, geq_p, s, p); + let s_plus_c = _mm512_add_epi64(s, c); + _mm512_mask_blend_epi64(overflow, no_of, s_plus_c) + }; + + Self::from_vec(result) + } + } +} + +impl Sub for PackedFp64Avx512

{ + type Output = Self; + #[inline] + fn sub(self, rhs: Self) -> Self { + unsafe { + let a = self.to_vec(); + let b = rhs.to_vec(); + let p = _mm512_set1_epi64(P as i64); + let d = _mm512_sub_epi64(a, b); + let underflow = _mm512_cmplt_epu64_mask(a, b); + Self::from_vec(_mm512_mask_add_epi64(d, underflow, d, p)) + } + } +} + +impl Mul for PackedFp64Avx512

{ + type Output = Self; + #[inline] + fn mul(self, rhs: Self) -> Self { + unsafe { + let (hi, lo) = mul64_64_512(self.to_vec(), rhs.to_vec()); + Self::from_vec(Self::reduce128_vec(hi, lo)) + } + } +} + +impl PackedValue for PackedFp64Avx512

{ + type Value = Fp64

; + const WIDTH: usize = FP64_WIDTH; + + #[inline] + fn from_fn(mut f: F) -> Self + where + F: FnMut(usize) -> Self::Value, + { + Self([f(0), f(1), f(2), f(3), f(4), f(5), f(6), f(7)]) + } + + #[inline] + fn extract(&self, lane: usize) -> Self::Value { + debug_assert!(lane < FP64_WIDTH); + self.0[lane] + } +} + +impl AddAssign for PackedFp64Avx512

{ + #[inline] + fn add_assign(&mut self, rhs: Self) { + *self = *self + rhs; + } +} + +impl SubAssign for PackedFp64Avx512

{ + #[inline] + fn sub_assign(&mut self, rhs: Self) { + *self = *self - rhs; + } +} + +impl MulAssign for PackedFp64Avx512

{ + #[inline] + fn mul_assign(&mut self, rhs: Self) { + *self = *self * rhs; + } +} + +impl PackedField for PackedFp64Avx512

{ + type Scalar = Fp64

; + + #[inline] + fn broadcast(value: Self::Scalar) -> Self { + Self([value; FP64_WIDTH]) + } +} + +/// Number of `Fp128` lanes in an AVX-512 packed vector. +pub const FP128_WIDTH: usize = 8; + +/// AVX-512 packed arithmetic for `Fp128

`, 8 lanes in SoA layout. +/// +/// Stores 8 elements as separate `lo` and `hi` `u64` arrays, enabling +/// vectorized add/sub via `__m512i`. Mul remains scalar per-lane. +#[derive(Clone, Copy)] +pub struct PackedFp128Avx512 { + lo: [u64; FP128_WIDTH], + hi: [u64; FP128_WIDTH], +} + +impl PackedFp128Avx512

{ + const P_LO: u64 = P as u64; + const P_HI: u64 = (P >> 64) as u64; +} + +impl Default for PackedFp128Avx512

{ + #[inline] + fn default() -> Self { + Self { + lo: [0; FP128_WIDTH], + hi: [0; FP128_WIDTH], + } + } +} + +impl fmt::Debug for PackedFp128Avx512

{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let elems: Vec<_> = (0..FP128_WIDTH).map(|i| self.extract(i)).collect(); + f.debug_tuple("PackedFp128Avx512").field(&elems).finish() + } +} + +impl PartialEq for PackedFp128Avx512

{ + #[inline] + fn eq(&self, other: &Self) -> bool { + self.lo == other.lo && self.hi == other.hi + } +} + +impl Eq for PackedFp128Avx512

{} + +impl Add for PackedFp128Avx512

{ + type Output = Self; + #[inline] + fn add(self, rhs: Self) -> Self { + unsafe { + let a_lo = _mm512_loadu_si512(self.lo.as_ptr().cast()); + let a_hi = _mm512_loadu_si512(self.hi.as_ptr().cast()); + let b_lo = _mm512_loadu_si512(rhs.lo.as_ptr().cast()); + let b_hi = _mm512_loadu_si512(rhs.hi.as_ptr().cast()); + let p_lo = _mm512_set1_epi64(Self::P_LO as i64); + let p_hi = _mm512_set1_epi64(Self::P_HI as i64); + let one = _mm512_set1_epi64(1); + + // 128-bit add: (sum_hi, sum_lo) = (a_hi, a_lo) + (b_hi, b_lo) + let sum_lo = _mm512_add_epi64(a_lo, b_lo); + let carry_lo = _mm512_cmplt_epu64_mask(sum_lo, a_lo); + let hi_tmp = _mm512_add_epi64(a_hi, b_hi); + let ov1 = _mm512_cmplt_epu64_mask(hi_tmp, a_hi); + let sum_hi = _mm512_mask_add_epi64(hi_tmp, carry_lo, hi_tmp, one); + let ov2 = _mm512_cmplt_epu64_mask(sum_hi, hi_tmp); + let carry_128 = ov1 | ov2; + + // 128-bit subtract P: (red_hi, red_lo) = (sum_hi, sum_lo) - P + let red_lo = _mm512_sub_epi64(sum_lo, p_lo); + let borrow_lo = _mm512_cmplt_epu64_mask(sum_lo, p_lo); + let red_hi_tmp = _mm512_sub_epi64(sum_hi, p_hi); + let bw1 = _mm512_cmplt_epu64_mask(sum_hi, p_hi); + let red_hi = _mm512_mask_sub_epi64(red_hi_tmp, borrow_lo, red_hi_tmp, one); + let bw2 = _mm512_cmplt_epu64_mask(red_hi_tmp, _mm512_maskz_mov_epi64(borrow_lo, one)); + let borrow = bw1 | bw2; + + // Use reduced if: overflow happened OR subtraction didn't borrow + let use_reduced = carry_128 | !borrow; + let out_lo = _mm512_mask_blend_epi64(use_reduced, sum_lo, red_lo); + let out_hi = _mm512_mask_blend_epi64(use_reduced, sum_hi, red_hi); + + let mut result = Self::default(); + _mm512_storeu_si512(result.lo.as_mut_ptr().cast(), out_lo); + _mm512_storeu_si512(result.hi.as_mut_ptr().cast(), out_hi); + result + } + } +} + +impl Sub for PackedFp128Avx512

{ + type Output = Self; + #[inline] + fn sub(self, rhs: Self) -> Self { + unsafe { + let a_lo = _mm512_loadu_si512(self.lo.as_ptr().cast()); + let a_hi = _mm512_loadu_si512(self.hi.as_ptr().cast()); + let b_lo = _mm512_loadu_si512(rhs.lo.as_ptr().cast()); + let b_hi = _mm512_loadu_si512(rhs.hi.as_ptr().cast()); + let p_lo = _mm512_set1_epi64(Self::P_LO as i64); + let p_hi = _mm512_set1_epi64(Self::P_HI as i64); + let one = _mm512_set1_epi64(1); + + // 128-bit sub: (diff_hi, diff_lo) = (a_hi, a_lo) - (b_hi, b_lo) + let diff_lo = _mm512_sub_epi64(a_lo, b_lo); + let borrow_lo = _mm512_cmplt_epu64_mask(a_lo, b_lo); + let hi_tmp = _mm512_sub_epi64(a_hi, b_hi); + let bw1 = _mm512_cmplt_epu64_mask(a_hi, b_hi); + let diff_hi = _mm512_mask_sub_epi64(hi_tmp, borrow_lo, hi_tmp, one); + let bw2 = _mm512_cmplt_epu64_mask(hi_tmp, _mm512_maskz_mov_epi64(borrow_lo, one)); + let borrow_128 = bw1 | bw2; + + // Correction: add P back where underflow occurred + let corr_lo = _mm512_add_epi64(diff_lo, p_lo); + let carry_lo = _mm512_cmplt_epu64_mask(corr_lo, diff_lo); + let corr_hi = _mm512_add_epi64(diff_hi, p_hi); + let corr_hi = _mm512_mask_add_epi64(corr_hi, carry_lo, corr_hi, one); + + let out_lo = _mm512_mask_blend_epi64(borrow_128, diff_lo, corr_lo); + let out_hi = _mm512_mask_blend_epi64(borrow_128, diff_hi, corr_hi); + + let mut result = Self::default(); + _mm512_storeu_si512(result.lo.as_mut_ptr().cast(), out_lo); + _mm512_storeu_si512(result.hi.as_mut_ptr().cast(), out_hi); + result + } + } +} + +impl Mul for PackedFp128Avx512

{ + type Output = Self; + #[inline] + fn mul(self, rhs: Self) -> Self { + let mut out = Self::default(); + for i in 0..FP128_WIDTH { + let a = Fp128::

([self.lo[i], self.hi[i]]); + let b = Fp128::

([rhs.lo[i], rhs.hi[i]]); + let r = a * b; + out.lo[i] = r.0[0]; + out.hi[i] = r.0[1]; + } + out + } +} + +impl PackedValue for PackedFp128Avx512

{ + type Value = Fp128

; + const WIDTH: usize = FP128_WIDTH; + + #[inline] + fn from_fn(mut f: F) -> Self + where + F: FnMut(usize) -> Self::Value, + { + let mut lo = [0u64; FP128_WIDTH]; + let mut hi = [0u64; FP128_WIDTH]; + for i in 0..FP128_WIDTH { + let v = f(i); + lo[i] = v.0[0]; + hi[i] = v.0[1]; + } + Self { lo, hi } + } + + #[inline] + fn extract(&self, lane: usize) -> Self::Value { + debug_assert!(lane < FP128_WIDTH); + Fp128([self.lo[lane], self.hi[lane]]) + } +} + +impl AddAssign for PackedFp128Avx512

{ + #[inline] + fn add_assign(&mut self, rhs: Self) { + *self = *self + rhs; + } +} + +impl SubAssign for PackedFp128Avx512

{ + #[inline] + fn sub_assign(&mut self, rhs: Self) { + *self = *self - rhs; + } +} + +impl MulAssign for PackedFp128Avx512

{ + #[inline] + fn mul_assign(&mut self, rhs: Self) { + *self = *self * rhs; + } +} + +impl PackedField for PackedFp128Avx512

{ + type Scalar = Fp128

; + + #[inline] + fn broadcast(value: Self::Scalar) -> Self { + Self { + lo: [value.0[0]; FP128_WIDTH], + hi: [value.0[1]; FP128_WIDTH], + } + } +} diff --git a/src/algebra/fields/packed_ext.rs b/src/algebra/fields/packed_ext.rs new file mode 100644 index 00000000..9725b598 --- /dev/null +++ b/src/algebra/fields/packed_ext.rs @@ -0,0 +1,421 @@ +//! Packed extension field types using transpose-based packing. +//! +//! A `PackedFp2` stores `[PF; 2]` where `PF` is the packed base field. +//! Each `PF` lane contains the corresponding coefficient of an `Fp2` element. +//! This enables WIDTH-fold parallel arithmetic over `Fp2` using existing SIMD +//! base-field operations. + +use crate::algebra::fields::ext::{Fp2, Fp2Config, Fp4, Fp4Config}; +use crate::algebra::fields::packed::{HasPacking, PackedField, PackedValue}; +use crate::primitives::serialization::Valid; +use crate::FieldCore; +use core::ops::{Add, Mul, Sub}; + +/// Packed `Fp2` elements stored in transpose layout: `[PF; 2]`. +/// +/// If `PF` has width `W`, this represents `W` parallel `Fp2` values. +pub struct PackedFp2, PF: PackedField> { + /// Degree-0 coefficient (packed across SIMD lanes). + pub c0: PF, + /// Degree-1 coefficient (packed across SIMD lanes). + pub c1: PF, + _marker: std::marker::PhantomData (F, C)>, +} + +impl, PF: PackedField> Clone for PackedFp2 { + fn clone(&self) -> Self { + *self + } +} + +impl, PF: PackedField> Copy for PackedFp2 {} + +impl, PF: PackedField> std::fmt::Debug + for PackedFp2 +{ + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("PackedFp2").finish_non_exhaustive() + } +} + +impl, PF: PackedField> PackedFp2 { + /// Create a `PackedFp2` from its two packed coefficients. 
+ #[inline] + pub fn new(c0: PF, c1: PF) -> Self { + Self { + c0, + c1, + _marker: std::marker::PhantomData, + } + } + + #[inline] + fn mul_nr(x: PF) -> PF { + if C::IS_NEG_ONE { + let zero = PF::broadcast(F::zero()); + zero - x + } else { + PF::broadcast(C::non_residue()) * x + } + } +} + +impl PackedValue for PackedFp2 +where + F: FieldCore + Valid + 'static, + C: Fp2Config + 'static, + PF: PackedField, +{ + type Value = Fp2; + const WIDTH: usize = PF::WIDTH; + + fn from_fn(mut f: G) -> Self + where + G: FnMut(usize) -> Self::Value, + { + let mut c0s = Vec::with_capacity(PF::WIDTH); + let mut c1s = Vec::with_capacity(PF::WIDTH); + for i in 0..PF::WIDTH { + let val = f(i); + c0s.push(val.c0); + c1s.push(val.c1); + } + Self::new(PF::from_fn(|i| c0s[i]), PF::from_fn(|i| c1s[i])) + } + + fn extract(&self, lane: usize) -> Self::Value { + Fp2::new(self.c0.extract(lane), self.c1.extract(lane)) + } +} + +impl Add for PackedFp2 +where + F: FieldCore, + C: Fp2Config, + PF: PackedField, +{ + type Output = Self; + #[inline] + fn add(self, rhs: Self) -> Self { + Self::new(self.c0 + rhs.c0, self.c1 + rhs.c1) + } +} + +impl Sub for PackedFp2 +where + F: FieldCore, + C: Fp2Config, + PF: PackedField, +{ + type Output = Self; + #[inline] + fn sub(self, rhs: Self) -> Self { + Self::new(self.c0 - rhs.c0, self.c1 - rhs.c1) + } +} + +impl Mul for PackedFp2 +where + F: FieldCore, + C: Fp2Config, + PF: PackedField, +{ + type Output = Self; + #[inline] + fn mul(self, rhs: Self) -> Self { + let v0 = self.c0 * rhs.c0; + let v1 = self.c1 * rhs.c1; + Self::new( + v0 + Self::mul_nr(v1), + (self.c0 + self.c1) * (rhs.c0 + rhs.c1) - v0 - v1, + ) + } +} + +impl PackedField for PackedFp2 +where + F: FieldCore + Valid + 'static, + C: Fp2Config + 'static, + PF: PackedField, +{ + type Scalar = Fp2; + + #[inline] + fn broadcast(value: Self::Scalar) -> Self { + Self::new(PF::broadcast(value.c0), PF::broadcast(value.c1)) + } +} + +impl HasPacking for Fp2 +where + F: FieldCore + Valid + HasPacking + 
'static, + C: Fp2Config + 'static, +{ + type Packing = PackedFp2; +} + +/// Packed `Fp4` elements stored in transpose layout: `[PackedFp2; 2]`. +pub struct PackedFp4< + F: FieldCore, + C2: Fp2Config, + C4: Fp4Config, + PF: PackedField, +> { + /// Low half (`Fp2` coefficient 0). + pub c0: PackedFp2, + /// High half (`Fp2` coefficient 1). + pub c1: PackedFp2, + _marker: std::marker::PhantomData C4>, +} + +impl Clone for PackedFp4 +where + F: FieldCore, + C2: Fp2Config, + C4: Fp4Config, + PF: PackedField, +{ + fn clone(&self) -> Self { + *self + } +} + +impl Copy for PackedFp4 +where + F: FieldCore, + C2: Fp2Config, + C4: Fp4Config, + PF: PackedField, +{ +} + +impl std::fmt::Debug for PackedFp4 +where + F: FieldCore, + C2: Fp2Config, + C4: Fp4Config, + PF: PackedField, +{ + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("PackedFp4").finish_non_exhaustive() + } +} + +impl PackedFp4 +where + F: FieldCore, + C2: Fp2Config, + C4: Fp4Config, + PF: PackedField, +{ + /// Create a `PackedFp4` from its two `PackedFp2` halves. 
+ #[inline] + pub fn new(c0: PackedFp2, c1: PackedFp2) -> Self { + Self { + c0, + c1, + _marker: std::marker::PhantomData, + } + } +} + +impl PackedValue for PackedFp4 +where + F: FieldCore + Valid + 'static, + C2: Fp2Config + 'static, + C4: Fp4Config + 'static, + PF: PackedField, +{ + type Value = Fp4; + const WIDTH: usize = PF::WIDTH; + + fn from_fn(mut f: G) -> Self + where + G: FnMut(usize) -> Self::Value, + { + let mut c0s: Vec> = Vec::with_capacity(PF::WIDTH); + let mut c1s: Vec> = Vec::with_capacity(PF::WIDTH); + for i in 0..PF::WIDTH { + let val = f(i); + c0s.push(val.c0); + c1s.push(val.c1); + } + Self::new( + PackedFp2::from_fn(|i| c0s[i]), + PackedFp2::from_fn(|i| c1s[i]), + ) + } + + fn extract(&self, lane: usize) -> Self::Value { + Fp4::new(self.c0.extract(lane), self.c1.extract(lane)) + } +} + +impl Add for PackedFp4 +where + F: FieldCore, + C2: Fp2Config, + C4: Fp4Config, + PF: PackedField, +{ + type Output = Self; + #[inline] + fn add(self, rhs: Self) -> Self { + Self::new(self.c0 + rhs.c0, self.c1 + rhs.c1) + } +} + +impl Sub for PackedFp4 +where + F: FieldCore, + C2: Fp2Config, + C4: Fp4Config, + PF: PackedField, +{ + type Output = Self; + #[inline] + fn sub(self, rhs: Self) -> Self { + Self::new(self.c0 - rhs.c0, self.c1 - rhs.c1) + } +} + +impl Mul for PackedFp4 +where + F: FieldCore + Valid + 'static, + C2: Fp2Config + 'static, + C4: Fp4Config, + PF: PackedField, +{ + type Output = Self; + #[inline] + fn mul(self, rhs: Self) -> Self { + let nr2 = PackedFp2::broadcast(C4::non_residue()); + let v0 = self.c0 * rhs.c0; + let v1 = self.c1 * rhs.c1; + Self::new( + v0 + nr2 * v1, + (self.c0 + self.c1) * (rhs.c0 + rhs.c1) - v0 - v1, + ) + } +} + +impl PackedField for PackedFp4 +where + F: FieldCore + Valid + 'static, + C2: Fp2Config + 'static, + C4: Fp4Config + 'static, + PF: PackedField, +{ + type Scalar = Fp4; + + #[inline] + fn broadcast(value: Self::Scalar) -> Self { + Self::new( + PackedFp2::broadcast(value.c0), + PackedFp2::broadcast(value.c1), + 
) + } +} + +impl HasPacking for Fp4 +where + F: FieldCore + Valid + HasPacking + 'static, + C2: Fp2Config + 'static, + C4: Fp4Config + 'static, +{ + type Packing = PackedFp4; +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::algebra::fields::ext::{Ext2, Ext4, TwoNr, UnitNr}; + use crate::algebra::Fp64; + use crate::{FieldSampling, FromSmallInt}; + use rand::rngs::StdRng; + use rand::SeedableRng; + + type F = Fp64<4294967197>; + type E2 = Ext2; + type E4 = Ext4; + type PE2 = PackedFp2::Packing>; + type PE4 = PackedFp4::Packing>; + + #[test] + fn packed_fp2_add() { + let mut rng = StdRng::seed_from_u64(100); + let width = ::WIDTH; + let a_elems: Vec = (0..width).map(|_| E2::sample(&mut rng)).collect(); + let b_elems: Vec = (0..width).map(|_| E2::sample(&mut rng)).collect(); + + let pa = PE2::from_fn(|i| a_elems[i]); + let pb = PE2::from_fn(|i| b_elems[i]); + let pc = pa + pb; + + for i in 0..width { + assert_eq!(pc.extract(i), a_elems[i] + b_elems[i]); + } + } + + #[test] + fn packed_fp2_mul() { + let mut rng = StdRng::seed_from_u64(200); + let width = ::WIDTH; + let a_elems: Vec = (0..width).map(|_| E2::sample(&mut rng)).collect(); + let b_elems: Vec = (0..width).map(|_| E2::sample(&mut rng)).collect(); + + let pa = PE2::from_fn(|i| a_elems[i]); + let pb = PE2::from_fn(|i| b_elems[i]); + let pc = pa * pb; + + for i in 0..width { + assert_eq!( + pc.extract(i), + a_elems[i] * b_elems[i], + "packed Fp2 mul mismatch at lane {i}" + ); + } + } + + #[test] + fn packed_fp2_broadcast() { + let val = E2::new(F::from_u64(7), F::from_u64(11)); + let packed = PE2::broadcast(val); + let width = ::WIDTH; + for i in 0..width { + assert_eq!(packed.extract(i), val); + } + } + + #[test] + fn packed_fp4_mul() { + let mut rng = StdRng::seed_from_u64(300); + let width = ::WIDTH; + let a_elems: Vec = (0..width).map(|_| E4::sample(&mut rng)).collect(); + let b_elems: Vec = (0..width).map(|_| E4::sample(&mut rng)).collect(); + + let pa = PE4::from_fn(|i| a_elems[i]); + let pb = 
PE4::from_fn(|i| b_elems[i]); + let pc = pa * pb; + + for i in 0..width { + assert_eq!( + pc.extract(i), + a_elems[i] * b_elems[i], + "packed Fp4 mul mismatch at lane {i}" + ); + } + } + + #[test] + fn pack_unpack_roundtrip_fp2() { + let mut rng = StdRng::seed_from_u64(400); + let width = ::WIDTH; + let elems: Vec = (0..width * 3).map(|_| E2::sample(&mut rng)).collect(); + + let packed = PE2::pack_slice(&elems); + let unpacked = PE2::unpack_slice(&packed); + + assert_eq!(elems, unpacked); + } +} diff --git a/src/algebra/fields/packed_neon.rs b/src/algebra/fields/packed_neon.rs new file mode 100644 index 00000000..e76c1ed2 --- /dev/null +++ b/src/algebra/fields/packed_neon.rs @@ -0,0 +1,781 @@ +//! AArch64 NEON packed backends for Fp32, Fp64, Fp128. + +use super::packed::{PackedField, PackedValue}; +use crate::algebra::fields::{Fp128, Fp32, Fp64}; +use crate::FieldCore; +use core::arch::aarch64::{ + uint32x2_t, uint32x4_t, uint64x2_t, vaddq_u32, vaddq_u64, vandq_u64, vbslq_u32, vbslq_u64, + vcgtq_u64, vcltq_u32, vcltq_u64, vcombine_u32, vdup_n_u32, vdupq_n_s64, vdupq_n_u32, + vdupq_n_u64, veorq_u64, vget_low_u32, vminq_u32, vmovn_u64, vmull_high_u32, vmull_u32, + vorrq_u64, vshlq_u64, vsubq_u32, vsubq_u64, +}; +use core::fmt; +use core::mem::transmute; +use core::ops::{Add, AddAssign, Mul, MulAssign, Sub, SubAssign}; + +/// Number of packed `Fp128` lanes in this backend. +pub const WIDTH: usize = 2; + +/// True SoA layout for two packed `Fp128` lanes. +/// +/// `lo = [lane0.lo, lane1.lo]` +/// `hi = [lane0.hi, lane1.hi]` +#[derive(Clone, Copy)] +pub struct PackedFp128Neon { + lo: [u64; 2], + hi: [u64; 2], +} + +#[inline(always)] +fn to_vec(x: [u64; 2]) -> uint64x2_t { + unsafe { transmute::<[u64; 2], uint64x2_t>(x) } +} + +#[inline(always)] +fn from_vec(v: uint64x2_t) -> [u64; 2] { + unsafe { transmute::(v) } +} + +#[inline(always)] +fn mask_to_bit(mask: uint64x2_t) -> uint64x2_t { + // SAFETY: NEON intrinsics are available under this cfg. 
+ unsafe { vandq_u64(mask, vdupq_n_u64(1)) } +} + +#[inline(always)] +const fn modulus_lo() -> u64 { + P as u64 +} + +#[inline(always)] +const fn modulus_hi() -> u64 { + (P >> 64) as u64 +} + +use super::util::{is_pow2_u64, log2_pow2_u64}; + +impl PackedFp128Neon

{ + const C: u128 = { + let c = 0u128.wrapping_sub(P); + assert!(P != 0, "modulus must be nonzero"); + assert!(P & 1 == 1, "modulus must be odd"); + assert!(c < (1u128 << 64), "P must be 2^128 - c with c < 2^64"); + assert!( + c * (c + 1) < P, + "C(C+1) < P required for fused canonicalize" + ); + c + }; + const C_LO: u64 = Self::C as u64; + const C_SHIFT_KIND: i8 = { + let c = Self::C_LO; + if c > 1 && is_pow2_u64(c - 1) { + 1 + } else if c == u64::MAX || is_pow2_u64(c + 1) { + -1 + } else { + 0 + } + }; + const C_SHIFT: u32 = { + let c = Self::C_LO; + if Self::C_SHIFT_KIND == 1 { + log2_pow2_u64(c - 1) + } else if Self::C_SHIFT_KIND == -1 { + if c == u64::MAX { + 64 + } else { + log2_pow2_u64(c + 1) + } + } else { + 0 + } + }; + + #[inline(always)] + fn mul_wide_u64(a: u64, b: u64) -> (u64, u64) { + let prod = (a as u128) * (b as u128); + (prod as u64, (prod >> 64) as u64) + } + + #[inline(always)] + fn mul_c_wide(x: u64) -> (u64, u64) { + if Self::C_SHIFT_KIND == 1 { + let v = ((x as u128) << Self::C_SHIFT) + x as u128; + (v as u64, (v >> 64) as u64) + } else if Self::C_SHIFT_KIND == -1 { + let v = ((x as u128) << Self::C_SHIFT) - x as u128; + (v as u64, (v >> 64) as u64) + } else { + Self::mul_wide_u64(Self::C_LO, x) + } + } + + #[inline(always)] + fn fold2_canonicalize(t0: u64, t1: u64, t2: u64) -> (u64, u64) { + let (ct2_lo, ct2_hi) = Self::mul_c_wide(t2); + + let (s0, carry0) = t0.overflowing_add(ct2_lo); + let (s1a, carry1a) = t1.overflowing_add(ct2_hi); + let (s1, carry1b) = s1a.overflowing_add(carry0 as u64); + let overflow = carry1a | carry1b; + + let (r0, carry2) = s0.overflowing_add(Self::C_LO); + let (r1, carry3) = s1.overflowing_add(carry2 as u64); + + if overflow | carry3 { + (r0, r1) + } else { + (s0, s1) + } + } + + #[inline(always)] + fn mul_raw_lane(a0: u64, a1: u64, b0: u64, b1: u64) -> (u64, u64) { + let (p00_lo, p00_hi) = Self::mul_wide_u64(a0, b0); + let (p01_lo, p01_hi) = Self::mul_wide_u64(a0, b1); + let (p10_lo, p10_hi) = 
Self::mul_wide_u64(a1, b0); + let (p11_lo, p11_hi) = Self::mul_wide_u64(a1, b1); + + let row1 = p00_hi as u128 + p01_lo as u128 + p10_lo as u128; + let r0 = p00_lo; + let r1 = row1 as u64; + let carry1 = (row1 >> 64) as u64; + + let row2 = p01_hi as u128 + p10_hi as u128 + p11_lo as u128 + carry1 as u128; + let r2 = row2 as u64; + let carry2 = (row2 >> 64) as u64; + + let row3 = p11_hi as u128 + carry2 as u128; + let r3 = row3 as u64; + debug_assert_eq!(row3 >> 64, 0); + + let (cr2_lo, cr2_hi) = Self::mul_c_wide(r2); + let (cr3_lo, cr3_hi) = Self::mul_c_wide(r3); + + let t0_sum = r0 as u128 + cr2_lo as u128; + let t0 = t0_sum as u64; + let carryf = (t0_sum >> 64) as u64; + + let t1_sum = r1 as u128 + cr2_hi as u128 + cr3_lo as u128 + carryf as u128; + let t1 = t1_sum as u64; + + let t2_sum = cr3_hi as u128 + (t1_sum >> 64); + let t2 = t2_sum as u64; + debug_assert_eq!(t2_sum >> 64, 0); + + Self::fold2_canonicalize(t0, t1, t2) + } +} + +impl Default for PackedFp128Neon

{ + #[inline] + fn default() -> Self { + Self::broadcast(Fp128::zero()) + } +} + +impl fmt::Debug for PackedFp128Neon

{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_tuple("PackedFp128Neon") + .field(&[self.extract(0), self.extract(1)]) + .finish() + } +} + +impl PartialEq for PackedFp128Neon

{ + #[inline] + fn eq(&self, other: &Self) -> bool { + self.extract(0) == other.extract(0) && self.extract(1) == other.extract(1) + } +} + +impl Eq for PackedFp128Neon

{} + +impl PackedValue for PackedFp128Neon

{ + type Value = Fp128

; + const WIDTH: usize = WIDTH; + + #[inline] + fn from_fn(mut f: F) -> Self + where + F: FnMut(usize) -> Self::Value, + { + let x0 = f(0); + let x1 = f(1); + Self { + lo: [x0.0[0], x1.0[0]], + hi: [x0.0[1], x1.0[1]], + } + } + + #[inline] + fn extract(&self, lane: usize) -> Self::Value { + debug_assert!(lane < WIDTH); + Fp128([self.lo[lane], self.hi[lane]]) + } +} + +impl Add for PackedFp128Neon

{ + type Output = Self; + #[inline] + fn add(self, rhs: Self) -> Self { + let lo_a = to_vec(self.lo); + let hi_a = to_vec(self.hi); + let lo_b = to_vec(rhs.lo); + let hi_b = to_vec(rhs.hi); + + let (out_lo, out_hi) = unsafe { + let p_lo = vdupq_n_u64(modulus_lo::

()); + let p_hi = vdupq_n_u64(modulus_hi::

()); + + let sum_lo = vaddq_u64(lo_a, lo_b); + let carry_lo = mask_to_bit(vcltq_u64(sum_lo, lo_a)); + + let hi_tmp = vaddq_u64(hi_a, hi_b); + let carry_hi1 = vcltq_u64(hi_tmp, hi_a); + let sum_hi = vaddq_u64(hi_tmp, carry_lo); + let carry_hi2 = vcltq_u64(sum_hi, hi_tmp); + let carry_128 = vorrq_u64(carry_hi1, carry_hi2); + + let red_lo = vsubq_u64(sum_lo, p_lo); + let borrow_lo = mask_to_bit(vcgtq_u64(p_lo, sum_lo)); + + let red_hi_tmp = vsubq_u64(sum_hi, p_hi); + let borrow_hi1 = vcgtq_u64(p_hi, sum_hi); + let red_hi = vsubq_u64(red_hi_tmp, borrow_lo); + let borrow_hi2 = vcltq_u64(red_hi_tmp, borrow_lo); + let borrow = vorrq_u64(borrow_hi1, borrow_hi2); + + let not_borrow = veorq_u64(borrow, vdupq_n_u64(u64::MAX)); + let use_reduced = vorrq_u64(carry_128, not_borrow); + let out_lo = vbslq_u64(use_reduced, red_lo, sum_lo); + let out_hi = vbslq_u64(use_reduced, red_hi, sum_hi); + (out_lo, out_hi) + }; + + Self { + lo: from_vec(out_lo), + hi: from_vec(out_hi), + } + } +} + +impl Sub for PackedFp128Neon

{ + type Output = Self; + #[inline] + fn sub(self, rhs: Self) -> Self { + let lo_a = to_vec(self.lo); + let hi_a = to_vec(self.hi); + let lo_b = to_vec(rhs.lo); + let hi_b = to_vec(rhs.hi); + + let (out_lo, out_hi) = unsafe { + let p_lo = vdupq_n_u64(modulus_lo::

()); + let p_hi = vdupq_n_u64(modulus_hi::

()); + + let diff_lo = vsubq_u64(lo_a, lo_b); + let borrow_lo = mask_to_bit(vcltq_u64(lo_a, lo_b)); + + let diff_hi_tmp = vsubq_u64(hi_a, hi_b); + let borrow_hi1 = vcltq_u64(hi_a, hi_b); + let diff_hi = vsubq_u64(diff_hi_tmp, borrow_lo); + let borrow_hi2 = vcltq_u64(diff_hi_tmp, borrow_lo); + let borrow_128 = vorrq_u64(borrow_hi1, borrow_hi2); + + let corr_lo = vaddq_u64(diff_lo, p_lo); + let carry_lo = mask_to_bit(vcltq_u64(corr_lo, diff_lo)); + + let corr_hi_tmp = vaddq_u64(diff_hi, p_hi); + let corr_hi = vaddq_u64(corr_hi_tmp, carry_lo); + + let out_lo = vbslq_u64(borrow_128, corr_lo, diff_lo); + let out_hi = vbslq_u64(borrow_128, corr_hi, diff_hi); + (out_lo, out_hi) + }; + + Self { + lo: from_vec(out_lo), + hi: from_vec(out_hi), + } + } +} + +impl Mul for PackedFp128Neon

{ + type Output = Self; + #[inline] + fn mul(self, rhs: Self) -> Self { + let (o0_lo, o0_hi) = Self::mul_raw_lane(self.lo[0], self.hi[0], rhs.lo[0], rhs.hi[0]); + let (o1_lo, o1_hi) = Self::mul_raw_lane(self.lo[1], self.hi[1], rhs.lo[1], rhs.hi[1]); + + Self { + lo: [o0_lo, o1_lo], + hi: [o0_hi, o1_hi], + } + } +} + +impl AddAssign for PackedFp128Neon

{ + #[inline] + fn add_assign(&mut self, rhs: Self) { + *self = *self + rhs; + } +} + +impl SubAssign for PackedFp128Neon

{ + #[inline] + fn sub_assign(&mut self, rhs: Self) { + *self = *self - rhs; + } +} + +impl MulAssign for PackedFp128Neon

{ + #[inline] + fn mul_assign(&mut self, rhs: Self) { + *self = *self * rhs; + } +} + +impl PackedField for PackedFp128Neon

{ + type Scalar = Fp128

; + + #[inline] + fn broadcast(value: Self::Scalar) -> Self { + Self::from_fn(|_| value) + } +} + +/// Number of packed `Fp32` lanes. +pub const FP32_WIDTH: usize = 4; + +/// NEON packed `Fp32` backend: 4 lanes in `uint32x4_t`. +#[derive(Clone, Copy)] +pub struct PackedFp32Neon { + vals: [u32; 4], +} + +#[inline(always)] +fn to_vec32(x: [u32; 4]) -> uint32x4_t { + unsafe { transmute::<[u32; 4], uint32x4_t>(x) } +} + +#[inline(always)] +fn from_vec32(v: uint32x4_t) -> [u32; 4] { + unsafe { transmute::(v) } +} + +impl PackedFp32Neon

{ + const BITS: u32 = 32 - P.leading_zeros(); + + const C: u32 = { + let c = if Self::BITS == 32 { + 0u32.wrapping_sub(P) + } else { + (1u32 << Self::BITS) - P + }; + assert!(P != 0, "modulus must be nonzero"); + assert!(P & 1 == 1, "modulus must be odd"); + assert!( + (c as u64) * (c as u64 + 1) < P as u64, + "C(C+1) < P required for fused canonicalize" + ); + c + }; + + const MASK_U64: u64 = if Self::BITS == 32 { + u32::MAX as u64 + } else { + (1u64 << Self::BITS) - 1 + }; + + #[inline(always)] + fn mul_c_u64(hi: uint64x2_t, c: uint32x2_t) -> uint64x2_t { + unsafe { + let hi_narrow = vmovn_u64(hi); + vmull_u32(hi_narrow, c) + } + } + + #[inline(always)] + fn solinas_reduce(prod_lo: uint64x2_t, prod_hi: uint64x2_t) -> uint32x4_t { + unsafe { + let mask = vdupq_n_u64(Self::MASK_U64); + let neg_bits = vdupq_n_s64(-(Self::BITS as i64)); + let c = vdup_n_u32(Self::C); + + let f1_lo = vaddq_u64( + vandq_u64(prod_lo, mask), + Self::mul_c_u64(vshlq_u64(prod_lo, neg_bits), c), + ); + let f1_hi = vaddq_u64( + vandq_u64(prod_hi, mask), + Self::mul_c_u64(vshlq_u64(prod_hi, neg_bits), c), + ); + + let f2_lo = vaddq_u64( + vandq_u64(f1_lo, mask), + Self::mul_c_u64(vshlq_u64(f1_lo, neg_bits), c), + ); + let f2_hi = vaddq_u64( + vandq_u64(f1_hi, mask), + Self::mul_c_u64(vshlq_u64(f1_hi, neg_bits), c), + ); + + if Self::BITS < 32 { + let result = vcombine_u32(vmovn_u64(f2_lo), vmovn_u64(f2_hi)); + let p = vdupq_n_u32(P); + vminq_u32(result, vsubq_u32(result, p)) + } else { + let p_u64 = vdupq_n_u64(P as u64); + + let red_lo = vsubq_u64(f2_lo, p_u64); + let keep_lo = vcltq_u64(f2_lo, p_u64); + let out_lo = vbslq_u64(keep_lo, f2_lo, red_lo); + + let red_hi = vsubq_u64(f2_hi, p_u64); + let keep_hi = vcltq_u64(f2_hi, p_u64); + let out_hi = vbslq_u64(keep_hi, f2_hi, red_hi); + + vcombine_u32(vmovn_u64(out_lo), vmovn_u64(out_hi)) + } + } + } +} + +impl Default for PackedFp32Neon

{ + #[inline] + fn default() -> Self { + Self { vals: [0; 4] } + } +} + +impl fmt::Debug for PackedFp32Neon

{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_tuple("PackedFp32Neon").field(&self.vals).finish() + } +} + +impl PartialEq for PackedFp32Neon

{ + #[inline] + fn eq(&self, other: &Self) -> bool { + self.vals == other.vals + } +} + +impl Eq for PackedFp32Neon

{} + +impl Add for PackedFp32Neon

{ + type Output = Self; + #[inline] + fn add(self, rhs: Self) -> Self { + let a = to_vec32(self.vals); + let b = to_vec32(rhs.vals); + let result = unsafe { + let p = vdupq_n_u32(P); + if Self::BITS <= 31 { + let t = vaddq_u32(a, b); + vminq_u32(t, vsubq_u32(t, p)) + } else { + let c = vdupq_n_u32(Self::C); + let t = vaddq_u32(a, b); + let overflow = vcltq_u32(t, a); + let t_plus_c = vaddq_u32(t, c); + let no_of = vminq_u32(t, vsubq_u32(t, p)); + vbslq_u32(overflow, t_plus_c, no_of) + } + }; + Self { + vals: from_vec32(result), + } + } +} + +impl Sub for PackedFp32Neon

{ + type Output = Self; + #[inline] + fn sub(self, rhs: Self) -> Self { + let a = to_vec32(self.vals); + let b = to_vec32(rhs.vals); + let result = unsafe { + let p = vdupq_n_u32(P); + if Self::BITS <= 31 { + let t = vsubq_u32(a, b); + vminq_u32(t, vaddq_u32(t, p)) + } else { + let t = vsubq_u32(a, b); + let underflow = vcltq_u32(a, b); + vbslq_u32(underflow, vaddq_u32(t, p), t) + } + }; + Self { + vals: from_vec32(result), + } + } +} + +impl Mul for PackedFp32Neon

{ + type Output = Self; + #[inline] + fn mul(self, rhs: Self) -> Self { + let a = to_vec32(self.vals); + let b = to_vec32(rhs.vals); + let result = unsafe { + let prod_lo = vmull_u32(vget_low_u32(a), vget_low_u32(b)); + let prod_hi = vmull_high_u32(a, b); + Self::solinas_reduce(prod_lo, prod_hi) + }; + Self { + vals: from_vec32(result), + } + } +} + +impl PackedValue for PackedFp32Neon

{ + type Value = Fp32

; + const WIDTH: usize = FP32_WIDTH; + + #[inline] + fn from_fn(mut f: F) -> Self + where + F: FnMut(usize) -> Self::Value, + { + Self { + vals: [f(0).0, f(1).0, f(2).0, f(3).0], + } + } + + #[inline] + fn extract(&self, lane: usize) -> Self::Value { + debug_assert!(lane < FP32_WIDTH); + Fp32(self.vals[lane]) + } +} + +impl AddAssign for PackedFp32Neon

{ + #[inline] + fn add_assign(&mut self, rhs: Self) { + *self = *self + rhs; + } +} + +impl SubAssign for PackedFp32Neon

{ + #[inline] + fn sub_assign(&mut self, rhs: Self) { + *self = *self - rhs; + } +} + +impl MulAssign for PackedFp32Neon

{ + #[inline] + fn mul_assign(&mut self, rhs: Self) { + *self = *self * rhs; + } +} + +impl PackedField for PackedFp32Neon

{ + type Scalar = Fp32

; + + #[inline] + fn broadcast(value: Self::Scalar) -> Self { + Self { vals: [value.0; 4] } + } +} + +/// Number of packed `Fp64` lanes. +pub const FP64_WIDTH: usize = 2; + +/// NEON packed `Fp64` backend: 2 lanes in `uint64x2_t`. +#[derive(Clone, Copy)] +pub struct PackedFp64Neon { + vals: [u64; 2], +} + +impl PackedFp64Neon

{ + const BITS: u32 = 64 - P.leading_zeros(); + + const C_LO: u64 = { + let c = if Self::BITS == 64 { + 0u64.wrapping_sub(P) + } else { + (1u64 << Self::BITS) - P + }; + assert!(P != 0, "modulus must be nonzero"); + assert!(P & 1 == 1, "modulus must be odd"); + c + }; + + const MASK64: u64 = if Self::BITS < 64 { + (1u64 << Self::BITS) - 1 + } else { + u64::MAX + }; + + const MASK_U128: u128 = if Self::BITS == 64 { + u64::MAX as u128 + } else { + (1u128 << Self::BITS) - 1 + }; + + const FOLD_IN_U64: bool = + Self::BITS < 64 && (Self::C_LO as u128) < (1u128 << (64 - Self::BITS)); + + #[inline(always)] + fn mul_c_narrow(x: u64) -> u64 { + Self::C_LO.wrapping_mul(x) + } + + #[inline(always)] + fn reduce_product(x: u128) -> u64 { + if Self::FOLD_IN_U64 { + let lo = x as u64; + let hi = (x >> 64) as u64; + let high = (lo >> Self::BITS) | (hi << (64 - Self::BITS)); + let f1 = (lo & Self::MASK64).wrapping_add(Self::mul_c_narrow(high)); + let f2 = (f1 & Self::MASK64).wrapping_add(Self::mul_c_narrow(f1 >> Self::BITS)); + let reduced = f2.wrapping_sub(P); + let borrow = reduced >> 63; + reduced.wrapping_add(borrow.wrapping_neg() & P) + } else { + let f1 = + (x & Self::MASK_U128) + (Self::C_LO as u128) * ((x >> Self::BITS) as u64 as u128); + let f2 = + (f1 & Self::MASK_U128) + (Self::C_LO as u128) * ((f1 >> Self::BITS) as u64 as u128); + let reduced = f2.wrapping_sub(P as u128); + let borrow = reduced >> 127; + reduced.wrapping_add(borrow.wrapping_neg() & (P as u128)) as u64 + } + } +} + +impl Default for PackedFp64Neon

{ + #[inline] + fn default() -> Self { + Self { vals: [0; 2] } + } +} + +impl fmt::Debug for PackedFp64Neon

{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_tuple("PackedFp64Neon").field(&self.vals).finish() + } +} + +impl PartialEq for PackedFp64Neon

{ + #[inline] + fn eq(&self, other: &Self) -> bool { + self.vals == other.vals + } +} + +impl Eq for PackedFp64Neon

{} + +impl Add for PackedFp64Neon

{ + type Output = Self; + #[inline] + fn add(self, rhs: Self) -> Self { + let a = to_vec(self.vals); + let b = to_vec(rhs.vals); + let result = unsafe { + let p = vdupq_n_u64(P); + if Self::BITS <= 62 { + let s = vaddq_u64(a, b); + let r = vsubq_u64(s, p); + let borrow = vcltq_u64(s, p); + vbslq_u64(borrow, s, r) + } else { + let s = vaddq_u64(a, b); + let overflow = vcltq_u64(s, a); + let c = vdupq_n_u64(Self::C_LO); + let s_plus_c = vaddq_u64(s, c); + let s_minus_p = vsubq_u64(s, p); + let borrow = vcltq_u64(s, p); + let no_of = vbslq_u64(borrow, s, s_minus_p); + vbslq_u64(overflow, s_plus_c, no_of) + } + }; + Self { + vals: from_vec(result), + } + } +} + +impl Sub for PackedFp64Neon

{ + type Output = Self; + #[inline] + fn sub(self, rhs: Self) -> Self { + let a = to_vec(self.vals); + let b = to_vec(rhs.vals); + let result = unsafe { + let p = vdupq_n_u64(P); + let d = vsubq_u64(a, b); + let underflow = vcltq_u64(a, b); + vbslq_u64(underflow, vaddq_u64(d, p), d) + }; + Self { + vals: from_vec(result), + } + } +} + +impl Mul for PackedFp64Neon

{ + type Output = Self; + #[inline] + fn mul(self, rhs: Self) -> Self { + let x0 = (self.vals[0] as u128) * (rhs.vals[0] as u128); + let x1 = (self.vals[1] as u128) * (rhs.vals[1] as u128); + let r0 = Self::reduce_product(x0); + let r1 = Self::reduce_product(x1); + Self { vals: [r0, r1] } + } +} + +impl PackedValue for PackedFp64Neon

{ + type Value = Fp64

; + const WIDTH: usize = FP64_WIDTH; + + #[inline] + fn from_fn(mut f: F) -> Self + where + F: FnMut(usize) -> Self::Value, + { + Self { + vals: [f(0).0, f(1).0], + } + } + + #[inline] + fn extract(&self, lane: usize) -> Self::Value { + debug_assert!(lane < FP64_WIDTH); + Fp64(self.vals[lane]) + } +} + +impl AddAssign for PackedFp64Neon

{ + #[inline] + fn add_assign(&mut self, rhs: Self) { + *self = *self + rhs; + } +} + +impl SubAssign for PackedFp64Neon

{ + #[inline] + fn sub_assign(&mut self, rhs: Self) { + *self = *self - rhs; + } +} + +impl MulAssign for PackedFp64Neon

{ + #[inline] + fn mul_assign(&mut self, rhs: Self) { + *self = *self * rhs; + } +} + +impl PackedField for PackedFp64Neon

{ + type Scalar = Fp64

; + + #[inline] + fn broadcast(value: Self::Scalar) -> Self { + Self { vals: [value.0; 2] } + } +} diff --git a/src/algebra/fields/pseudo_mersenne.rs b/src/algebra/fields/pseudo_mersenne.rs new file mode 100644 index 00000000..17920899 --- /dev/null +++ b/src/algebra/fields/pseudo_mersenne.rs @@ -0,0 +1,177 @@ +//! `2^k - offset` pseudo-Mersenne registry and field aliases. +//! +//! This module models the specific flavor where each modulus is the smallest +//! prime below `2^k` with `q % 8 == 5`, written as `q = 2^k - offset`. + +use super::{Fp128, Fp32, Fp64}; + +/// Offset table (`q = 2^k - offset[k]`) imported from `labrador/data.py`. +pub const POW2_OFFSET_TABLE: [i16; 256] = [ + -1, -1, -1, 3, 3, 3, 3, 19, 27, 3, 3, 19, 3, 75, 3, 19, 99, 91, 11, 19, 3, 19, 3, 27, 3, 91, + 27, 115, 299, 3, 35, 19, 99, 355, 131, 451, 243, 123, 107, 19, 195, 75, 11, 67, 539, 139, 635, + 115, 59, 123, 27, 139, 395, 315, 131, 67, 27, 195, 27, 99, 107, 259, 171, 259, 59, 115, 203, + 19, 83, 19, 35, 411, 107, 475, 35, 427, 123, 43, 11, 67, 1307, 51, 315, 139, 35, 19, 35, 67, + 299, 99, 75, 315, 83, 51, 3, 211, 147, 595, 51, 115, 99, 99, 483, 339, 395, 139, 1187, 171, 59, + 91, 195, 835, 75, 211, 11, 67, 3, 451, 563, 867, 395, 531, 3, 67, 59, 579, 203, 507, 275, 315, + 27, 315, 347, 99, 603, 795, 243, 339, 203, 187, 27, 171, 1491, 355, 83, 355, 1371, 387, 347, + 99, 3, 195, 539, 171, 243, 499, 195, 19, 155, 91, 75, 1011, 627, 867, 155, 115, 1811, 771, + 1467, 643, 195, 19, 155, 531, 3, 267, 563, 339, 563, 507, 107, 283, 267, 147, 59, 339, 371, + 1411, 363, 819, 11, 19, 915, 123, 75, 915, 459, 75, 627, 459, 75, 1035, 195, 187, 1515, 1219, + 1443, 91, 299, 451, 171, 1099, 99, 3, 395, 1147, 683, 675, 243, 355, 395, 3, 875, 235, 363, + 1131, 155, 835, 723, 91, 27, 235, 875, 3, 83, 259, 875, 1515, 731, 531, 467, 819, 267, 475, + 1923, 163, 107, 411, 387, 75, 2331, 355, 1515, 1723, 1427, 19, +]; + +/// Maximum supported offset in this `2^k - offset` specialization. 
+pub const POW2_OFFSET_MAX: u128 = 1u128 << 16; + +/// Current active bit-size bound for concrete field aliases in this phase. +pub const POW2_OFFSET_IMPLEMENTED_MAX_BITS: u32 = 128; + +/// Metadata describing a `2^k - offset` pseudo-Mersenne modulus. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct Pow2OffsetPrimeSpec { + /// `k` in `2^k - offset`. + pub bits: u32, + /// `offset` in `2^k - offset`. + pub offset: u16, + /// Modulus value. + pub modulus: u128, +} + +/// Return table offset for `q = 2^k - offset` when available and positive. +pub const fn pow2_offset(bits: u32) -> Option { + if bits as usize >= POW2_OFFSET_TABLE.len() { + return None; + } + let offset = POW2_OFFSET_TABLE[bits as usize]; + if offset <= 0 { + None + } else { + Some(offset as u16) + } +} + +/// Compute `2^k - offset` for `k <= 128`. +pub const fn pseudo_mersenne_modulus(bits: u32, offset: u128) -> Option { + if bits == 0 || bits > 128 || offset == 0 { + return None; + } + if bits == 128 { + Some(u128::MAX - (offset - 1)) + } else { + Some((1u128 << bits) - offset) + } +} + +/// Check whether `(k, offset)` is accepted by the `2^k - offset` policy. +pub const fn is_pow2_offset(bits: u32, offset: u128) -> bool { + if bits > POW2_OFFSET_IMPLEMENTED_MAX_BITS || offset > POW2_OFFSET_MAX { + return false; + } + match pow2_offset(bits) { + Some(qoff) => (qoff as u128) == offset, + None => false, + } +} + +/// `offset` for `k = 24`. +pub const POW2_OFFSET_24: u16 = 3; +/// `offset` for `k = 30`. +pub const POW2_OFFSET_30: u16 = 35; +/// `offset` for `k = 31`. +pub const POW2_OFFSET_31: u16 = 19; +/// `offset` for `k = 32`. +pub const POW2_OFFSET_32: u16 = 99; +/// `offset` for `k = 40`. +pub const POW2_OFFSET_40: u16 = 195; +/// `offset` for `k = 48`. +pub const POW2_OFFSET_48: u16 = 59; +/// `offset` for `k = 56`. +pub const POW2_OFFSET_56: u16 = 27; +/// `offset` for `k = 64`. +pub const POW2_OFFSET_64: u16 = 59; +/// `offset` for `k = 128`. 
+pub const POW2_OFFSET_128: u16 = 275; + +/// `2^24 - 3`. +pub const POW2_OFFSET_MODULUS_24: u32 = ((1u128 << 24) - (POW2_OFFSET_24 as u128)) as u32; +/// `2^30 - 35`. +pub const POW2_OFFSET_MODULUS_30: u32 = ((1u128 << 30) - (POW2_OFFSET_30 as u128)) as u32; +/// `2^31 - 19`. +pub const POW2_OFFSET_MODULUS_31: u32 = ((1u128 << 31) - (POW2_OFFSET_31 as u128)) as u32; +/// `2^32 - 99`. +pub const POW2_OFFSET_MODULUS_32: u32 = ((1u128 << 32) - (POW2_OFFSET_32 as u128)) as u32; +/// `2^40 - 195`. +pub const POW2_OFFSET_MODULUS_40: u64 = ((1u128 << 40) - (POW2_OFFSET_40 as u128)) as u64; +/// `2^48 - 59`. +pub const POW2_OFFSET_MODULUS_48: u64 = ((1u128 << 48) - (POW2_OFFSET_48 as u128)) as u64; +/// `2^56 - 27`. +pub const POW2_OFFSET_MODULUS_56: u64 = ((1u128 << 56) - (POW2_OFFSET_56 as u128)) as u64; +/// `2^64 - 59`. +pub const POW2_OFFSET_MODULUS_64: u64 = u64::MAX - ((POW2_OFFSET_64 as u64) - 1); +/// `2^128 - 275`. +pub const POW2_OFFSET_MODULUS_128: u128 = u128::MAX - (POW2_OFFSET_128 as u128 - 1); + +/// Alias for `2^24 - offset`. +pub type Pow2Offset24Field = Fp32; +/// Alias for `2^30 - offset`. +pub type Pow2Offset30Field = Fp32; +/// Alias for `2^31 - offset`. +pub type Pow2Offset31Field = Fp32; +/// Alias for `2^32 - offset`. +pub type Pow2Offset32Field = Fp32; +/// Alias for `2^40 - offset`. +pub type Pow2Offset40Field = Fp64; +/// Alias for `2^48 - offset`. +pub type Pow2Offset48Field = Fp64; +/// Alias for `2^56 - offset`. +pub type Pow2Offset56Field = Fp64; +/// Alias for `2^64 - offset`. +pub type Pow2Offset64Field = Fp64; +/// Alias for `2^128 - offset`. +pub type Pow2Offset128Field = Fp128; + +/// `2^k - offset` profiles currently enabled in-code. +/// +/// Each listed modulus satisfies `q % 8 == 5`. 
+pub const POW2_OFFSET_PRIMES: [Pow2OffsetPrimeSpec; 7] = [ + Pow2OffsetPrimeSpec { + bits: 24, + offset: POW2_OFFSET_24, + modulus: POW2_OFFSET_MODULUS_24 as u128, + }, + Pow2OffsetPrimeSpec { + bits: 32, + offset: POW2_OFFSET_32, + modulus: POW2_OFFSET_MODULUS_32 as u128, + }, + Pow2OffsetPrimeSpec { + bits: 40, + offset: POW2_OFFSET_40, + modulus: POW2_OFFSET_MODULUS_40 as u128, + }, + Pow2OffsetPrimeSpec { + bits: 48, + offset: POW2_OFFSET_48, + modulus: POW2_OFFSET_MODULUS_48 as u128, + }, + Pow2OffsetPrimeSpec { + bits: 56, + offset: POW2_OFFSET_56, + modulus: POW2_OFFSET_MODULUS_56 as u128, + }, + Pow2OffsetPrimeSpec { + bits: 64, + offset: POW2_OFFSET_64, + modulus: POW2_OFFSET_MODULUS_64 as u128, + }, + Pow2OffsetPrimeSpec { + bits: 128, + offset: POW2_OFFSET_128, + modulus: POW2_OFFSET_MODULUS_128, + }, +]; + +// All PseudoMersenneField impls for Fp32/Fp64/Fp128 are blanket impls in +// their respective modules (fp32.rs, fp64.rs, fp128.rs). diff --git a/src/algebra/fields/util.rs b/src/algebra/fields/util.rs new file mode 100644 index 00000000..9496f1ce --- /dev/null +++ b/src/algebra/fields/util.rs @@ -0,0 +1,38 @@ +//! Shared helpers for field arithmetic backends. + +#[inline(always)] +pub(crate) const fn is_pow2_u64(x: u64) -> bool { + x != 0 && (x & (x - 1)) == 0 +} + +#[inline(always)] +pub(crate) const fn log2_pow2_u64(mut x: u64) -> u32 { + let mut k = 0u32; + while x > 1 { + x >>= 1; + k += 1; + } + k +} + +/// `a * b` widening to 128 bits; returns `(lo64, hi64)`. 
+#[inline(always)] +pub(crate) fn mul64_wide(a: u64, b: u64) -> (u64, u64) { + #[cfg(all(target_arch = "x86_64", target_feature = "bmi2"))] + { + unsafe { mul64_wide_bmi2(a, b) } + } + #[cfg(not(all(target_arch = "x86_64", target_feature = "bmi2")))] + { + let prod = (a as u128) * (b as u128); + (prod as u64, (prod >> 64) as u64) + } +} + +#[cfg(all(target_arch = "x86_64", target_feature = "bmi2"))] +#[inline(always)] +unsafe fn mul64_wide_bmi2(a: u64, b: u64) -> (u64, u64) { + let mut hi = 0; + let lo = unsafe { std::arch::x86_64::_mulx_u64(a, b, &mut hi) }; + (lo, hi) +} diff --git a/src/algebra/fields/wide.rs b/src/algebra/fields/wide.rs new file mode 100644 index 00000000..e160f3d1 --- /dev/null +++ b/src/algebra/fields/wide.rs @@ -0,0 +1,1238 @@ +//! Wide unreduced field accumulators for carry-free signed addition. +//! +//! Each type splits a canonical field element into 16-bit limbs stored in +//! `i32` slots. Addition and negation are element-wise i32 ops — no carry +//! propagation, no modular reduction. Reduction back to canonical form +//! happens once after accumulation via [`reduce`](Fp128x8i32::reduce). +//! +//! The i32 overflow budget is `i32::MAX / u16::MAX ≈ 32,769` signed +//! additions before any limb can overflow. + +use std::ops::{Add, AddAssign, Neg, Sub, SubAssign}; + +use crate::{AdditiveGroup, CanonicalField, FieldCore}; + +use super::fp128::Fp128; +use super::fp32::Fp32; +use super::fp64::Fp64; + +/// Wide unreduced accumulator for `Fp32`: 2 × i32 limbs (16-bit data each). +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +#[repr(C)] +pub struct Fp32x2i32(pub [i32; 2]); + +impl Fp32x2i32 { + /// Returns the zero accumulator. + #[inline] + pub fn zero() -> Self { + ::ZERO + } +} + +impl From> for Fp32x2i32 { + #[inline] + fn from(x: Fp32

) -> Self { + let v = x.0; + Self([(v & 0xFFFF) as i32, (v >> 16) as i32]) + } +} + +impl Fp32x2i32 { + /// Multiply every limb by a small signed scalar. + /// + /// Safe when `|small| * max_limb_magnitude` fits in i32. After `From`, + /// limbs are in `[0, 0xFFFF]`, so `|small| ≤ 32_767` is safe for a single + /// product. For accumulation of `k` scaled values, require + /// `k * |small| * 0xFFFF < i32::MAX`, i.e. roughly `k * |small| < 32_768`. + #[inline] + pub fn scale_i32(self, small: i32) -> Self { + Self([self.0[0] * small, self.0[1] * small]) + } + + /// Reduce back to canonical `Fp32
<P>
`. + /// + /// Carry-propagates the i32 limbs into a signed value, normalizes to + /// `[0, p)`, and returns the canonical field element. + #[inline] + pub fn reduce(self) -> Fp32
<P>
{ + let [l0, l1] = self.0; + // Carry-propagate: value = l0 + l1 * 2^16 + let wide = l0 as i64 + (l1 as i64) * (1i64 << 16); + // Normalize to [0, p) + let p = P as i64; + let normalized = ((wide % p) + p) % p; + Fp32::from_canonical_u32(normalized as u32) + } +} + +impl Add for Fp32x2i32 { + type Output = Self; + #[inline] + fn add(self, rhs: Self) -> Self { + Self([self.0[0] + rhs.0[0], self.0[1] + rhs.0[1]]) + } +} + +impl AddAssign for Fp32x2i32 { + #[inline] + fn add_assign(&mut self, rhs: Self) { + self.0[0] += rhs.0[0]; + self.0[1] += rhs.0[1]; + } +} + +impl Sub for Fp32x2i32 { + type Output = Self; + #[inline] + fn sub(self, rhs: Self) -> Self { + Self([self.0[0] - rhs.0[0], self.0[1] - rhs.0[1]]) + } +} + +impl SubAssign for Fp32x2i32 { + #[inline] + fn sub_assign(&mut self, rhs: Self) { + self.0[0] -= rhs.0[0]; + self.0[1] -= rhs.0[1]; + } +} + +impl Neg for Fp32x2i32 { + type Output = Self; + #[inline] + fn neg(self) -> Self { + Self([-self.0[0], -self.0[1]]) + } +} + +/// Wide unreduced accumulator for `Fp64`: 4 × i32 limbs (16-bit data each). +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +#[repr(C)] +pub struct Fp64x4i32(pub [i32; 4]); + +impl Fp64x4i32 { + /// Returns the zero accumulator. + #[inline] + pub fn zero() -> Self { + ::ZERO + } +} + +impl From> for Fp64x4i32 { + #[inline] + fn from(x: Fp64
<P>
) -> Self { + let v = x.0; + Self([ + (v & 0xFFFF) as i32, + ((v >> 16) & 0xFFFF) as i32, + ((v >> 32) & 0xFFFF) as i32, + ((v >> 48) & 0xFFFF) as i32, + ]) + } +} + +impl Fp64x4i32 { + /// Multiply every limb by a small signed scalar. See [`Fp32x2i32::scale_i32`]. + #[inline] + pub fn scale_i32(self, small: i32) -> Self { + Self([ + self.0[0] * small, + self.0[1] * small, + self.0[2] * small, + self.0[3] * small, + ]) + } + + /// Reduce back to canonical `Fp64
<P>
`. + #[inline] + pub fn reduce(self) -> Fp64
<P>
{ + let [l0, l1, l2, l3] = self.0; + // Carry-propagate: value = l0 + l1*2^16 + l2*2^32 + l3*2^48 + let wide = l0 as i128 + + (l1 as i128) * (1i128 << 16) + + (l2 as i128) * (1i128 << 32) + + (l3 as i128) * (1i128 << 48); + let p = P as i128; + let normalized = ((wide % p) + p) % p; + Fp64::
<P>
::from_canonical_u64(normalized as u64) + } +} + +#[cfg(target_arch = "aarch64")] +impl Add for Fp64x4i32 { + type Output = Self; + #[inline] + fn add(self, rhs: Self) -> Self { + unsafe { + use std::arch::aarch64::*; + let a = vld1q_s32(self.0.as_ptr()); + let b = vld1q_s32(rhs.0.as_ptr()); + let mut out = [0i32; 4]; + vst1q_s32(out.as_mut_ptr(), vaddq_s32(a, b)); + Self(out) + } + } +} + +#[cfg(target_arch = "aarch64")] +impl AddAssign for Fp64x4i32 { + #[inline] + fn add_assign(&mut self, rhs: Self) { + *self = *self + rhs; + } +} + +#[cfg(target_arch = "aarch64")] +impl Sub for Fp64x4i32 { + type Output = Self; + #[inline] + fn sub(self, rhs: Self) -> Self { + unsafe { + use std::arch::aarch64::*; + let a = vld1q_s32(self.0.as_ptr()); + let b = vld1q_s32(rhs.0.as_ptr()); + let mut out = [0i32; 4]; + vst1q_s32(out.as_mut_ptr(), vsubq_s32(a, b)); + Self(out) + } + } +} + +#[cfg(target_arch = "aarch64")] +impl SubAssign for Fp64x4i32 { + #[inline] + fn sub_assign(&mut self, rhs: Self) { + *self = *self - rhs; + } +} + +#[cfg(target_arch = "aarch64")] +impl Neg for Fp64x4i32 { + type Output = Self; + #[inline] + fn neg(self) -> Self { + unsafe { + use std::arch::aarch64::*; + let a = vld1q_s32(self.0.as_ptr()); + let mut out = [0i32; 4]; + vst1q_s32(out.as_mut_ptr(), vnegq_s32(a)); + Self(out) + } + } +} + +#[cfg(not(target_arch = "aarch64"))] +impl Add for Fp64x4i32 { + type Output = Self; + #[inline] + fn add(self, rhs: Self) -> Self { + Self([ + self.0[0] + rhs.0[0], + self.0[1] + rhs.0[1], + self.0[2] + rhs.0[2], + self.0[3] + rhs.0[3], + ]) + } +} + +#[cfg(not(target_arch = "aarch64"))] +impl AddAssign for Fp64x4i32 { + #[inline] + fn add_assign(&mut self, rhs: Self) { + self.0[0] += rhs.0[0]; + self.0[1] += rhs.0[1]; + self.0[2] += rhs.0[2]; + self.0[3] += rhs.0[3]; + } +} + +#[cfg(not(target_arch = "aarch64"))] +impl Sub for Fp64x4i32 { + type Output = Self; + #[inline] + fn sub(self, rhs: Self) -> Self { + Self([ + self.0[0] - rhs.0[0], + self.0[1] - 
rhs.0[1], + self.0[2] - rhs.0[2], + self.0[3] - rhs.0[3], + ]) + } +} + +#[cfg(not(target_arch = "aarch64"))] +impl SubAssign for Fp64x4i32 { + #[inline] + fn sub_assign(&mut self, rhs: Self) { + self.0[0] -= rhs.0[0]; + self.0[1] -= rhs.0[1]; + self.0[2] -= rhs.0[2]; + self.0[3] -= rhs.0[3]; + } +} + +#[cfg(not(target_arch = "aarch64"))] +impl Neg for Fp64x4i32 { + type Output = Self; + #[inline] + fn neg(self) -> Self { + Self([-self.0[0], -self.0[1], -self.0[2], -self.0[3]]) + } +} + +/// Wide unreduced accumulator for `Fp128`: 8 × i32 limbs (16-bit data each). +/// +/// On AVX2, one element fits a single 256-bit YMM register. On NEON, it +/// spans two 128-bit Q registers. All arithmetic is carry-free element-wise +/// i32 operations. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +#[repr(C)] +pub struct Fp128x8i32(pub [i32; 8]); + +impl Fp128x8i32 { + /// Returns the zero accumulator. + #[inline] + pub fn zero() -> Self { + ::ZERO + } +} + +impl From> for Fp128x8i32 { + #[inline] + fn from(x: Fp128
<P>
) -> Self { + let lo = x.0[0]; + let hi = x.0[1]; + Self([ + (lo & 0xFFFF) as i32, + ((lo >> 16) & 0xFFFF) as i32, + ((lo >> 32) & 0xFFFF) as i32, + ((lo >> 48) & 0xFFFF) as i32, + (hi & 0xFFFF) as i32, + ((hi >> 16) & 0xFFFF) as i32, + ((hi >> 32) & 0xFFFF) as i32, + ((hi >> 48) & 0xFFFF) as i32, + ]) + } +} + +impl Fp128x8i32 { + /// Multiply every limb by a small signed scalar. See [`Fp32x2i32::scale_i32`]. + #[inline] + pub fn scale_i32(self, small: i32) -> Self { + Self([ + self.0[0] * small, + self.0[1] * small, + self.0[2] * small, + self.0[3] * small, + self.0[4] * small, + self.0[5] * small, + self.0[6] * small, + self.0[7] * small, + ]) + } + + /// Reduce back to canonical `Fp128
<P>
`. + /// + /// Carry-propagates the 8 × i32 limbs into unsigned u64 limbs, then + /// applies Solinas reduction. + #[inline] + pub fn reduce(self) -> Fp128
<P>
{ + let limbs = self.0; + + // Carry-propagate from low to high, accumulating into i64 slots. + // Each i32 limb can be in [-32769*65535, 32769*65535] ≈ ±2^31. + // After propagation, each 16-bit "digit" is in [0, 65535] and we + // may have a signed residual in the top that overflows 128 bits. + let mut carry: i64 = 0; + let mut digits = [0u16; 8]; + for i in 0..8 { + let v = limbs[i] as i64 + carry; + // Arithmetic right-shift to propagate sign correctly + digits[i] = (v & 0xFFFF) as u16; + carry = v >> 16; + } + + // Reassemble into u64 limbs + let lo = digits[0] as u64 + | (digits[1] as u64) << 16 + | (digits[2] as u64) << 32 + | (digits[3] as u64) << 48; + let hi = digits[4] as u64 + | (digits[5] as u64) << 16 + | (digits[6] as u64) << 32 + | (digits[7] as u64) << 48; + + // p = 2^128 - c, so 2^128 ≡ c (mod p). + // value = lo + hi*2^64 + carry*2^128 ≡ lo + hi*2^64 + carry*c (mod p). + let c = Fp128::
<P>
::C_LO; + if carry == 0 { + Fp128::
<P>
::from_canonical_u128_reduced(lo as u128 | (hi as u128) << 64) + } else if carry > 0 { + Fp128::
<P>
::solinas_reduce(&[lo, hi, carry as u64]) + } else { + // carry < 0: value = base - |carry|*c. + let neg_carry = (-carry) as u64; + let sub = neg_carry as u128 * c as u128; + let base = lo as u128 | (hi as u128) << 64; + if base >= sub { + Fp128::
<P>
::from_canonical_u128_reduced(base - sub) + } else { + let diff = sub - base; + Fp128::
<P>
::from_canonical_u128_reduced(P - diff) + } + } + } +} + +#[cfg(target_arch = "aarch64")] +impl Add for Fp128x8i32 { + type Output = Self; + #[inline] + fn add(self, rhs: Self) -> Self { + unsafe { + use std::arch::aarch64::*; + let a0 = vld1q_s32(self.0.as_ptr()); + let a1 = vld1q_s32(self.0.as_ptr().add(4)); + let b0 = vld1q_s32(rhs.0.as_ptr()); + let b1 = vld1q_s32(rhs.0.as_ptr().add(4)); + let mut out = [0i32; 8]; + vst1q_s32(out.as_mut_ptr(), vaddq_s32(a0, b0)); + vst1q_s32(out.as_mut_ptr().add(4), vaddq_s32(a1, b1)); + Self(out) + } + } +} + +#[cfg(target_arch = "aarch64")] +impl AddAssign for Fp128x8i32 { + #[inline] + fn add_assign(&mut self, rhs: Self) { + *self = *self + rhs; + } +} + +#[cfg(target_arch = "aarch64")] +impl Sub for Fp128x8i32 { + type Output = Self; + #[inline] + fn sub(self, rhs: Self) -> Self { + unsafe { + use std::arch::aarch64::*; + let a0 = vld1q_s32(self.0.as_ptr()); + let a1 = vld1q_s32(self.0.as_ptr().add(4)); + let b0 = vld1q_s32(rhs.0.as_ptr()); + let b1 = vld1q_s32(rhs.0.as_ptr().add(4)); + let mut out = [0i32; 8]; + vst1q_s32(out.as_mut_ptr(), vsubq_s32(a0, b0)); + vst1q_s32(out.as_mut_ptr().add(4), vsubq_s32(a1, b1)); + Self(out) + } + } +} + +#[cfg(target_arch = "aarch64")] +impl SubAssign for Fp128x8i32 { + #[inline] + fn sub_assign(&mut self, rhs: Self) { + *self = *self - rhs; + } +} + +#[cfg(target_arch = "aarch64")] +impl Neg for Fp128x8i32 { + type Output = Self; + #[inline] + fn neg(self) -> Self { + unsafe { + use std::arch::aarch64::*; + let a0 = vld1q_s32(self.0.as_ptr()); + let a1 = vld1q_s32(self.0.as_ptr().add(4)); + let mut out = [0i32; 8]; + vst1q_s32(out.as_mut_ptr(), vnegq_s32(a0)); + vst1q_s32(out.as_mut_ptr().add(4), vnegq_s32(a1)); + Self(out) + } + } +} + +#[cfg(not(target_arch = "aarch64"))] +impl Add for Fp128x8i32 { + type Output = Self; + #[inline] + fn add(self, rhs: Self) -> Self { + Self([ + self.0[0] + rhs.0[0], + self.0[1] + rhs.0[1], + self.0[2] + rhs.0[2], + self.0[3] + rhs.0[3], + self.0[4] + 
rhs.0[4], + self.0[5] + rhs.0[5], + self.0[6] + rhs.0[6], + self.0[7] + rhs.0[7], + ]) + } +} + +#[cfg(not(target_arch = "aarch64"))] +impl AddAssign for Fp128x8i32 { + #[inline] + fn add_assign(&mut self, rhs: Self) { + self.0[0] += rhs.0[0]; + self.0[1] += rhs.0[1]; + self.0[2] += rhs.0[2]; + self.0[3] += rhs.0[3]; + self.0[4] += rhs.0[4]; + self.0[5] += rhs.0[5]; + self.0[6] += rhs.0[6]; + self.0[7] += rhs.0[7]; + } +} + +#[cfg(not(target_arch = "aarch64"))] +impl Sub for Fp128x8i32 { + type Output = Self; + #[inline] + fn sub(self, rhs: Self) -> Self { + Self([ + self.0[0] - rhs.0[0], + self.0[1] - rhs.0[1], + self.0[2] - rhs.0[2], + self.0[3] - rhs.0[3], + self.0[4] - rhs.0[4], + self.0[5] - rhs.0[5], + self.0[6] - rhs.0[6], + self.0[7] - rhs.0[7], + ]) + } +} + +#[cfg(not(target_arch = "aarch64"))] +impl SubAssign for Fp128x8i32 { + #[inline] + fn sub_assign(&mut self, rhs: Self) { + self.0[0] -= rhs.0[0]; + self.0[1] -= rhs.0[1]; + self.0[2] -= rhs.0[2]; + self.0[3] -= rhs.0[3]; + self.0[4] -= rhs.0[4]; + self.0[5] -= rhs.0[5]; + self.0[6] -= rhs.0[6]; + self.0[7] -= rhs.0[7]; + } +} + +#[cfg(not(target_arch = "aarch64"))] +impl Neg for Fp128x8i32 { + type Output = Self; + #[inline] + fn neg(self) -> Self { + Self([ + -self.0[0], -self.0[1], -self.0[2], -self.0[3], -self.0[4], -self.0[5], -self.0[6], + -self.0[7], + ]) + } +} + +impl AdditiveGroup for Fp32x2i32 { + const ZERO: Self = Self([0; 2]); +} + +impl AdditiveGroup for Fp64x4i32 { + const ZERO: Self = Self([0; 4]); +} + +impl AdditiveGroup for Fp128x8i32 { + const ZERO: Self = Self([0; 8]); +} + +/// Accumulator for `Fp64 × u64` products (also used for `Fp64 × Fp64`). +/// +/// Each product is ≤ 128 bits, split into two u64 halves stored as u128 slots. +/// Headroom: 2^64 additions per slot before overflow. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct Fp64ProductAccum(pub [u128; 2]); + +impl Fp64ProductAccum { + /// Reduce accumulated products to a canonical `Fp64
<P>
`. + #[inline] + pub fn reduce(self) -> Fp64
<P>
{ + let [s0, s1] = self.0; + // s0 = Σ lo_i, s1 = Σ hi_i; value = s0 + s1 * 2^64 + let a = Fp64::
<P>
::solinas_reduce(s0); + let b = Fp64::
<P>
::solinas_reduce(s1); + let shift = Fp64::
<P>
::solinas_reduce(1u128 << 64); + let b_shifted = Fp64::
<P>
::solinas_reduce(b.mul_wide_u64(shift.to_limbs())); + a + b_shifted + } +} + +impl From> for Fp64ProductAccum { + #[inline] + fn from(x: Fp64
<P>
) -> Self { + Self([x.to_limbs() as u128, 0]) + } +} + +impl AdditiveGroup for Fp64ProductAccum { + const ZERO: Self = Self([0; 2]); +} + +impl Add for Fp64ProductAccum { + type Output = Self; + #[inline] + fn add(self, rhs: Self) -> Self { + Self([self.0[0] + rhs.0[0], self.0[1] + rhs.0[1]]) + } +} +impl AddAssign for Fp64ProductAccum { + #[inline] + fn add_assign(&mut self, rhs: Self) { + self.0[0] += rhs.0[0]; + self.0[1] += rhs.0[1]; + } +} +impl Sub for Fp64ProductAccum { + type Output = Self; + #[inline] + fn sub(self, rhs: Self) -> Self { + Self([ + self.0[0].wrapping_sub(rhs.0[0]), + self.0[1].wrapping_sub(rhs.0[1]), + ]) + } +} +impl SubAssign for Fp64ProductAccum { + #[inline] + fn sub_assign(&mut self, rhs: Self) { + self.0[0] = self.0[0].wrapping_sub(rhs.0[0]); + self.0[1] = self.0[1].wrapping_sub(rhs.0[1]); + } +} +impl Neg for Fp64ProductAccum { + type Output = Self; + #[inline] + fn neg(self) -> Self { + Self([self.0[0].wrapping_neg(), self.0[1].wrapping_neg()]) + } +} + +/// Accumulator for `Fp128 × u64` products. +/// +/// Each `mul_wide_u64` produces 3 u64 limbs; stored as `[u128; 3]`. +/// Headroom: 2^64 additions per slot. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct Fp128MulU64Accum(pub [u128; 3]); + +impl Fp128MulU64Accum { + /// Reduce to canonical `Fp128
<P>
`. + #[inline] + pub fn reduce(self) -> Fp128
<P>
{ + let [s0, s1, s2] = self.0; + let c0 = s0 >> 64; + let r0 = s0 as u64; + let t1 = s1 + c0; + let r1 = t1 as u64; + let c1 = t1 >> 64; + let t2 = s2 + c1; + let r2 = t2 as u64; + let r3 = (t2 >> 64) as u64; + Fp128::
<P>
::solinas_reduce(&[r0, r1, r2, r3]) + } +} + +impl From> for Fp128MulU64Accum { + #[inline] + fn from(x: Fp128
<P>
) -> Self { + let [lo, hi] = x.to_limbs(); + Self([lo as u128, hi as u128, 0]) + } +} + +impl AdditiveGroup for Fp128MulU64Accum { + const ZERO: Self = Self([0; 3]); +} + +impl Add for Fp128MulU64Accum { + type Output = Self; + #[inline] + fn add(self, rhs: Self) -> Self { + Self([ + self.0[0] + rhs.0[0], + self.0[1] + rhs.0[1], + self.0[2] + rhs.0[2], + ]) + } +} +impl AddAssign for Fp128MulU64Accum { + #[inline] + fn add_assign(&mut self, rhs: Self) { + self.0[0] += rhs.0[0]; + self.0[1] += rhs.0[1]; + self.0[2] += rhs.0[2]; + } +} +impl Sub for Fp128MulU64Accum { + type Output = Self; + #[inline] + fn sub(self, rhs: Self) -> Self { + Self([ + self.0[0].wrapping_sub(rhs.0[0]), + self.0[1].wrapping_sub(rhs.0[1]), + self.0[2].wrapping_sub(rhs.0[2]), + ]) + } +} +impl SubAssign for Fp128MulU64Accum { + #[inline] + fn sub_assign(&mut self, rhs: Self) { + self.0[0] = self.0[0].wrapping_sub(rhs.0[0]); + self.0[1] = self.0[1].wrapping_sub(rhs.0[1]); + self.0[2] = self.0[2].wrapping_sub(rhs.0[2]); + } +} +impl Neg for Fp128MulU64Accum { + type Output = Self; + #[inline] + fn neg(self) -> Self { + Self([ + self.0[0].wrapping_neg(), + self.0[1].wrapping_neg(), + self.0[2].wrapping_neg(), + ]) + } +} + +/// Accumulator for `Fp128 × Fp128` products. +/// +/// Each `mul_wide` produces 4 u64 limbs; stored as `[u128; 4]`. +/// Headroom: 2^64 additions per slot. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct Fp128ProductAccum(pub [u128; 4]); + +impl Fp128ProductAccum { + /// Reduce to canonical `Fp128
<P>
`. + #[inline] + pub fn reduce(self) -> Fp128
<P>
{ + let [s0, s1, s2, s3] = self.0; + let c0 = s0 >> 64; + let r0 = s0 as u64; + let t1 = s1 + c0; + let r1 = t1 as u64; + let c1 = t1 >> 64; + let t2 = s2 + c1; + let r2 = t2 as u64; + let c2 = t2 >> 64; + let t3 = s3 + c2; + let r3 = t3 as u64; + let r4 = (t3 >> 64) as u64; + Fp128::
<P>
::solinas_reduce(&[r0, r1, r2, r3, r4]) + } +} + +impl From> for Fp128ProductAccum { + #[inline] + fn from(x: Fp128
<P>
) -> Self { + let [lo, hi] = x.to_limbs(); + Self([lo as u128, hi as u128, 0, 0]) + } +} + +impl AdditiveGroup for Fp128ProductAccum { + const ZERO: Self = Self([0; 4]); +} + +impl Add for Fp128ProductAccum { + type Output = Self; + #[inline] + fn add(self, rhs: Self) -> Self { + Self([ + self.0[0] + rhs.0[0], + self.0[1] + rhs.0[1], + self.0[2] + rhs.0[2], + self.0[3] + rhs.0[3], + ]) + } +} +impl AddAssign for Fp128ProductAccum { + #[inline] + fn add_assign(&mut self, rhs: Self) { + self.0[0] += rhs.0[0]; + self.0[1] += rhs.0[1]; + self.0[2] += rhs.0[2]; + self.0[3] += rhs.0[3]; + } +} +impl Sub for Fp128ProductAccum { + type Output = Self; + #[inline] + fn sub(self, rhs: Self) -> Self { + Self([ + self.0[0].wrapping_sub(rhs.0[0]), + self.0[1].wrapping_sub(rhs.0[1]), + self.0[2].wrapping_sub(rhs.0[2]), + self.0[3].wrapping_sub(rhs.0[3]), + ]) + } +} +impl SubAssign for Fp128ProductAccum { + #[inline] + fn sub_assign(&mut self, rhs: Self) { + self.0[0] = self.0[0].wrapping_sub(rhs.0[0]); + self.0[1] = self.0[1].wrapping_sub(rhs.0[1]); + self.0[2] = self.0[2].wrapping_sub(rhs.0[2]); + self.0[3] = self.0[3].wrapping_sub(rhs.0[3]); + } +} +impl Neg for Fp128ProductAccum { + type Output = Self; + #[inline] + fn neg(self) -> Self { + Self([ + self.0[0].wrapping_neg(), + self.0[1].wrapping_neg(), + self.0[2].wrapping_neg(), + self.0[3].wrapping_neg(), + ]) + } +} + +/// Pair accumulator for extension fields. +/// +/// Wraps two base-field accumulators `(c0, c1)` component-wise. 
+#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct AccumPair(pub A, pub A); + +impl AdditiveGroup for AccumPair { + const ZERO: Self = Self(A::ZERO, A::ZERO); +} + +impl Add for AccumPair { + type Output = Self; + #[inline] + fn add(self, rhs: Self) -> Self { + Self(self.0 + rhs.0, self.1 + rhs.1) + } +} +impl AddAssign for AccumPair { + #[inline] + fn add_assign(&mut self, rhs: Self) { + self.0 += rhs.0; + self.1 += rhs.1; + } +} +impl Sub for AccumPair { + type Output = Self; + #[inline] + fn sub(self, rhs: Self) -> Self { + Self(self.0 - rhs.0, self.1 - rhs.1) + } +} +impl SubAssign for AccumPair { + #[inline] + fn sub_assign(&mut self, rhs: Self) { + self.0 -= rhs.0; + self.1 -= rhs.1; + } +} +impl Neg for AccumPair { + type Output = Self; + #[inline] + fn neg(self) -> Self { + Self(-self.0, -self.1) + } +} + +/// Reduce a wide unreduced accumulator back to a canonical field element. +pub trait ReduceTo { + /// Carry-propagate and reduce to a canonical field element. + fn reduce(self) -> F; +} + +impl ReduceTo> for Fp32x2i32 { + #[inline] + fn reduce(self) -> Fp32
<P>
{ + Fp32x2i32::reduce::
<P>
(self) + } +} + +impl ReduceTo> for Fp64x4i32 { + #[inline] + fn reduce(self) -> Fp64
<P>
{ + Fp64x4i32::reduce::
<P>
(self) + } +} + +impl ReduceTo> for Fp128x8i32 { + #[inline] + fn reduce(self) -> Fp128
<P>
{ + Fp128x8i32::reduce::
<P>
(self) + } +} + +/// Multi-level unreduced multiplication hierarchy. +/// +/// Provides `field × u64` and `field × field` widening multiplies that return +/// accumulator types supporting carry-free addition. Reduction back to a +/// canonical field element happens once after accumulation. +pub trait HasUnreducedOps: FieldCore { + /// Accumulator for `self × u64` products (narrower than full product). + type MulU64Accum: AdditiveGroup; + /// Accumulator for `self × self` products. + type ProductAccum: AdditiveGroup; + + /// Widening `self × small` with no reduction. + fn mul_u64_unreduced(self, small: u64) -> Self::MulU64Accum; + /// Widening `self × other` with no reduction. + fn mul_to_product_accum(self, other: Self) -> Self::ProductAccum; + + /// Reduce a narrow-mul accumulator to a canonical field element. + fn reduce_mul_u64_accum(accum: Self::MulU64Accum) -> Self; + /// Reduce a full-product accumulator to a canonical field element. + fn reduce_product_accum(accum: Self::ProductAccum) -> Self; +} + +impl HasUnreducedOps for Fp64
<P>
{ + type MulU64Accum = Fp64ProductAccum; + type ProductAccum = Fp64ProductAccum; + + #[inline] + fn mul_u64_unreduced(self, small: u64) -> Fp64ProductAccum { + let wide = self.mul_wide_u64(small); + Fp64ProductAccum([wide & u64::MAX as u128, wide >> 64]) + } + + #[inline] + fn mul_to_product_accum(self, other: Self) -> Fp64ProductAccum { + let wide = self.mul_wide(other); + Fp64ProductAccum([wide & u64::MAX as u128, wide >> 64]) + } + + #[inline] + fn reduce_mul_u64_accum(accum: Fp64ProductAccum) -> Self { + accum.reduce::
<P>
() + } + + #[inline] + fn reduce_product_accum(accum: Fp64ProductAccum) -> Self { + accum.reduce::
<P>
() + } +} + +impl HasUnreducedOps for Fp128
<P>
{ + type MulU64Accum = Fp128MulU64Accum; + type ProductAccum = Fp128ProductAccum; + + #[inline] + fn mul_u64_unreduced(self, small: u64) -> Fp128MulU64Accum { + let [lo, mid, hi] = self.mul_wide_u64(small); + Fp128MulU64Accum([lo as u128, mid as u128, hi as u128]) + } + + #[inline] + fn mul_to_product_accum(self, other: Self) -> Fp128ProductAccum { + let [r0, r1, r2, r3] = self.mul_wide(other); + Fp128ProductAccum([r0 as u128, r1 as u128, r2 as u128, r3 as u128]) + } + + #[inline] + fn reduce_mul_u64_accum(accum: Fp128MulU64Accum) -> Self { + accum.reduce::
<P>
() + } + + #[inline] + fn reduce_product_accum(accum: Fp128ProductAccum) -> Self { + accum.reduce::
<P>
() + } +} + +/// Element-wise scaling of a wide accumulator by a small signed integer. +pub trait ScaleI32 { + /// Scale each element by `small`. + fn scale_i32(self, small: i32) -> Self; +} + +impl ScaleI32 for Fp32x2i32 { + #[inline] + fn scale_i32(self, small: i32) -> Self { + self.scale_i32(small) + } +} + +impl ScaleI32 for Fp64x4i32 { + #[inline] + fn scale_i32(self, small: i32) -> Self { + self.scale_i32(small) + } +} + +impl ScaleI32 for Fp128x8i32 { + #[inline] + fn scale_i32(self, small: i32) -> Self { + self.scale_i32(small) + } +} + +/// Associates a field type with its wide unreduced accumulator. +pub trait HasWide: FieldCore { + /// The wide accumulator type. + type Wide: AdditiveGroup + From + ReduceTo + ScaleI32; + + /// Convert `self` to wide form and scale every limb by `small`. + /// + /// Equivalent to `Self::Wide::from(self).scale_i32(small)` but avoids + /// the trait-method ambiguity at call sites. + #[inline] + fn mul_small_to_wide(self, small: i32) -> Self::Wide { + Self::Wide::from(self).scale_i32(small) + } +} + +impl HasWide for Fp32
<P>
{ + type Wide = Fp32x2i32; +} + +impl HasWide for Fp64
<P>
{ + type Wide = Fp64x4i32; +} + +impl HasWide for Fp128
<P>
{ + type Wide = Fp128x8i32; +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::algebra::fields::{Pow2Offset24Field, Pow2Offset40Field, Prime128M8M4M1M0}; + use crate::{FieldCore, FieldSampling, FromSmallInt}; + use rand::rngs::StdRng; + use rand::SeedableRng; + use rand_core::RngCore; + + type F128 = Prime128M8M4M1M0; + type F32 = Pow2Offset24Field; + type F64 = Pow2Offset40Field; + + const P128: u128 = 0xfffffffffffffffffffffffffffffeed; + const P32: u32 = (1 << 24) - 3; + const P64: u64 = (1 << 40) - 195; + + #[test] + fn fp128_roundtrip() { + let mut rng = StdRng::seed_from_u64(0xdead_1234); + for _ in 0..1000 { + let a: F128 = FieldSampling::sample(&mut rng); + let wide = Fp128x8i32::from(a); + let back = wide.reduce::(); + assert_eq!(a, back, "roundtrip failed for {a:?}"); + } + } + + #[test] + fn fp128_accumulate_matches_scalar() { + let mut rng = StdRng::seed_from_u64(0xbeef_cafe_4321); + let n = 1000; + let vals: Vec = (0..n).map(|_| FieldSampling::sample(&mut rng)).collect(); + + let scalar_sum = vals.iter().fold(F128::zero(), |acc, &x| acc + x); + + let wide_sum = vals + .iter() + .fold(Fp128x8i32::zero(), |acc, &x| acc + Fp128x8i32::from(x)); + let reduced = wide_sum.reduce::(); + + assert_eq!(scalar_sum, reduced); + } + + #[test] + fn fp128_add_sub_neg_match_scalar() { + let mut rng = StdRng::seed_from_u64(0x1122_3344_5566); + for _ in 0..500 { + let a: F128 = FieldSampling::sample(&mut rng); + let b: F128 = FieldSampling::sample(&mut rng); + + let wa = Fp128x8i32::from(a); + let wb = Fp128x8i32::from(b); + + assert_eq!((wa + wb).reduce::(), a + b); + assert_eq!((wa - wb).reduce::(), a - b); + assert_eq!((-wa).reduce::(), -a); + } + } + + #[test] + fn fp128_mixed_add_sub_stress() { + let mut rng = StdRng::seed_from_u64(0xaaaa_bbbb_cccc); + let n = 500; + let vals: Vec = (0..n).map(|_| FieldSampling::sample(&mut rng)).collect(); + + let mut scalar = F128::zero(); + let mut wide = Fp128x8i32::zero(); + for (i, &v) in vals.iter().enumerate() { + 
let wv = Fp128x8i32::from(v); + if i % 3 == 0 { + scalar -= v; + wide -= wv; + } else { + scalar += v; + wide += wv; + } + } + assert_eq!(wide.reduce::(), scalar); + } + + #[test] + fn fp32_roundtrip() { + let mut rng = StdRng::seed_from_u64(0x3232_3232); + for _ in 0..1000 { + let a: F32 = FieldSampling::sample(&mut rng); + let wide = Fp32x2i32::from(a); + let back = wide.reduce::(); + assert_eq!(a, back); + } + } + + #[test] + fn fp32_accumulate_matches_scalar() { + let mut rng = StdRng::seed_from_u64(0x3232_abcd); + let n = 1000; + let vals: Vec = (0..n).map(|_| FieldSampling::sample(&mut rng)).collect(); + + let scalar_sum = vals.iter().fold(F32::zero(), |acc, &x| acc + x); + let wide_sum = vals + .iter() + .fold(Fp32x2i32::zero(), |acc, &x| acc + Fp32x2i32::from(x)); + assert_eq!(wide_sum.reduce::(), scalar_sum); + } + + #[test] + fn fp64_roundtrip() { + let mut rng = StdRng::seed_from_u64(0x6464_6464); + for _ in 0..1000 { + let a: F64 = FieldSampling::sample(&mut rng); + let wide = Fp64x4i32::from(a); + let back = wide.reduce::(); + assert_eq!(a, back); + } + } + + #[test] + fn fp64_accumulate_matches_scalar() { + let mut rng = StdRng::seed_from_u64(0x6464_beef); + let n = 1000; + let vals: Vec = (0..n).map(|_| FieldSampling::sample(&mut rng)).collect(); + + let scalar_sum = vals.iter().fold(F64::zero(), |acc, &x| acc + x); + let wide_sum = vals + .iter() + .fold(Fp64x4i32::zero(), |acc, &x| acc + Fp64x4i32::from(x)); + assert_eq!(wide_sum.reduce::(), scalar_sum); + } + + #[test] + fn fp64_product_accum_matches_scalar() { + let mut rng = StdRng::seed_from_u64(0x6464_4444); + let n = 500; + let a_vals: Vec = (0..n).map(|_| FieldSampling::sample(&mut rng)).collect(); + let b_vals: Vec = (0..n).map(|_| FieldSampling::sample(&mut rng)).collect(); + + let scalar_sum: F64 = a_vals + .iter() + .zip(b_vals.iter()) + .fold(F64::zero(), |acc, (&a, &b)| acc + a * b); + + let accum_sum = a_vals + .iter() + .zip(b_vals.iter()) + .fold(Fp64ProductAccum::ZERO, |acc, (&a, 
&b)| { + acc + a.mul_to_product_accum(b) + }); + assert_eq!(F64::reduce_product_accum(accum_sum), scalar_sum); + } + + #[test] + fn fp64_mul_u64_accum_matches_scalar() { + let mut rng = StdRng::seed_from_u64(0x6464_5555); + let n = 500; + let a_vals: Vec = (0..n).map(|_| FieldSampling::sample(&mut rng)).collect(); + let b_vals: Vec = (0..n).map(|_| rng.next_u64() >> 32).collect(); + + let scalar_sum: F64 = a_vals + .iter() + .zip(b_vals.iter()) + .fold(F64::zero(), |acc, (&a, &b)| acc + a * F64::from_u64(b)); + + let accum_sum = a_vals + .iter() + .zip(b_vals.iter()) + .fold(Fp64ProductAccum::ZERO, |acc, (&a, &b)| { + acc + a.mul_u64_unreduced(b) + }); + assert_eq!(F64::reduce_mul_u64_accum(accum_sum), scalar_sum); + } + + #[test] + fn fp128_product_accum_matches_scalar() { + let mut rng = StdRng::seed_from_u64(0x0128_6666); + let n = 500; + let a_vals: Vec = (0..n).map(|_| FieldSampling::sample(&mut rng)).collect(); + let b_vals: Vec = (0..n).map(|_| FieldSampling::sample(&mut rng)).collect(); + + let scalar_sum: F128 = a_vals + .iter() + .zip(b_vals.iter()) + .fold(F128::zero(), |acc, (&a, &b)| acc + a * b); + + let accum_sum = a_vals + .iter() + .zip(b_vals.iter()) + .fold(Fp128ProductAccum::ZERO, |acc, (&a, &b)| { + acc + a.mul_to_product_accum(b) + }); + assert_eq!(F128::reduce_product_accum(accum_sum), scalar_sum); + } + + #[test] + fn fp128_mul_u64_accum_matches_scalar() { + let mut rng = StdRng::seed_from_u64(0x0128_7777); + let n = 500; + let a_vals: Vec = (0..n).map(|_| FieldSampling::sample(&mut rng)).collect(); + let b_vals: Vec = (0..n).map(|_| rng.next_u64()).collect(); + + let scalar_sum: F128 = a_vals + .iter() + .zip(b_vals.iter()) + .fold(F128::zero(), |acc, (&a, &b)| acc + a * F128::from_u64(b)); + + let accum_sum = a_vals + .iter() + .zip(b_vals.iter()) + .fold(Fp128MulU64Accum::ZERO, |acc, (&a, &b)| { + acc + a.mul_u64_unreduced(b) + }); + assert_eq!(F128::reduce_mul_u64_accum(accum_sum), scalar_sum); + } + + #[test] + fn 
fp128_product_accum_sub_neg() { + let mut rng = StdRng::seed_from_u64(0x0128_8888); + let n = 500; + let a_vals: Vec = (0..n).map(|_| FieldSampling::sample(&mut rng)).collect(); + let b_vals: Vec = (0..n).map(|_| FieldSampling::sample(&mut rng)).collect(); + + let mut scalar_sum = F128::zero(); + let mut accum_pos = Fp128ProductAccum::ZERO; + let mut accum_neg = Fp128ProductAccum::ZERO; + for (i, (&a, &b)) in a_vals.iter().zip(b_vals.iter()).enumerate() { + let prod = a.mul_to_product_accum(b); + if i % 2 == 0 { + scalar_sum += a * b; + accum_pos += prod; + } else { + scalar_sum -= a * b; + accum_neg += prod; + } + } + let result = F128::reduce_product_accum(accum_pos) - F128::reduce_product_accum(accum_neg); + assert_eq!(result, scalar_sum); + } +} diff --git a/src/algebra/mod.rs b/src/algebra/mod.rs new file mode 100644 index 00000000..a3e0c839 --- /dev/null +++ b/src/algebra/mod.rs @@ -0,0 +1,32 @@ +//! Concrete algebra backends and arithmetic building blocks. +//! +//! This module includes: +//! - Generic prime fields and extensions (`fields`) +//! - Module and polynomial containers (`module`, `poly`) +//! - Low-level NTT and CRT+NTT arithmetic scaffolding (`ntt`) + +pub mod backend; +pub mod fields; +pub mod module; +pub mod ntt; +pub mod poly; +pub mod ring; + +// Flat re-exports for convenience. 
+pub use backend::{CrtReconstruct, NttPrimeOps, NttTransform, RingBackend, ScalarBackend}; +pub use fields::{ + is_pow2_offset, pow2_offset, pseudo_mersenne_modulus, ExtField, Fp128, Fp128Packing, Fp2, + Fp2Config, Fp32, Fp32Packing, Fp4, Fp4Config, Fp64, Fp64Packing, HasPacking, LiftBase, + NoPacking, PackedField, PackedValue, Pow2Offset128Field, Pow2Offset24Field, Pow2Offset30Field, + Pow2Offset31Field, Pow2Offset32Field, Pow2Offset40Field, Pow2Offset48Field, Pow2Offset56Field, + Pow2Offset64Field, Pow2OffsetPrimeSpec, Prime128M13M4P0, Prime128M37P3P0, Prime128M52M3P0, + Prime128M54P4P0, Prime128M8M4M1M0, POW2_OFFSET_IMPLEMENTED_MAX_BITS, POW2_OFFSET_MAX, + POW2_OFFSET_PRIMES, POW2_OFFSET_TABLE, +}; +pub use module::VectorModule; +pub use ntt::tables; +pub use ntt::{GarnerData, LimbQ, MontCoeff, NttPrime, PrimeWidth, RADIX_BITS}; +pub use ring::{ + CrtNttConvertibleField, CrtNttParamSet, CyclotomicCrtNtt, CyclotomicRing, DigitMontLut, + SparseChallenge, SparseChallengeConfig, +}; diff --git a/src/algebra/module.rs b/src/algebra/module.rs new file mode 100644 index 00000000..bf41c567 --- /dev/null +++ b/src/algebra/module.rs @@ -0,0 +1,194 @@ +//! Simple module implementations. + +use super::fields::{Fp128, Fp32, Fp64}; +use crate::primitives::serialization::{ + Compress, HachiDeserialize, HachiSerialize, SerializationError, Valid, Validate, +}; +use crate::{CanonicalField, FieldCore, FieldSampling, Module}; +use rand_core::RngCore; +use std::io::{Read, Write}; +use std::ops::{Add, Mul, Neg, Sub}; + +/// Fixed-length vector module over a scalar type `F`. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct VectorModule(pub [F; N]); + +impl VectorModule { + /// Construct the zero vector. 
+ #[inline] + pub fn zero_vec() -> Self { + Self([F::zero(); N]) + } +} + +impl Add for VectorModule { + type Output = Self; + fn add(self, rhs: Self) -> Self::Output { + let mut out = self.0; + for (dst, src) in out.iter_mut().zip(rhs.0.iter()) { + *dst += *src; + } + Self(out) + } +} + +impl Sub for VectorModule { + type Output = Self; + fn sub(self, rhs: Self) -> Self::Output { + let mut out = self.0; + for (dst, src) in out.iter_mut().zip(rhs.0.iter()) { + *dst -= *src; + } + Self(out) + } +} + +impl Neg for VectorModule { + type Output = Self; + fn neg(self) -> Self::Output { + let mut out = self.0; + for coeff in &mut out { + *coeff = -*coeff; + } + Self(out) + } +} + +impl<'a, F: FieldCore, const N: usize> Add<&'a Self> for VectorModule { + type Output = Self; + fn add(self, rhs: &'a Self) -> Self::Output { + self + *rhs + } +} + +impl<'a, F: FieldCore, const N: usize> Sub<&'a Self> for VectorModule { + type Output = Self; + fn sub(self, rhs: &'a Self) -> Self::Output { + self - *rhs + } +} + +impl Valid for VectorModule { + fn check(&self) -> Result<(), SerializationError> { + for x in self.0.iter() { + x.check()?; + } + Ok(()) + } +} + +impl HachiSerialize for VectorModule { + fn serialize_with_mode( + &self, + mut writer: W, + compress: Compress, + ) -> Result<(), SerializationError> { + for x in self.0.iter() { + x.serialize_with_mode(&mut writer, compress)?; + } + Ok(()) + } + + fn serialized_size(&self, compress: Compress) -> usize { + self.0.iter().map(|x| x.serialized_size(compress)).sum() + } +} + +impl HachiDeserialize for VectorModule { + fn deserialize_with_mode( + mut reader: R, + compress: Compress, + validate: Validate, + ) -> Result { + let mut arr = [F::zero(); N]; + for coeff in &mut arr { + *coeff = F::deserialize_with_mode(&mut reader, compress, validate)?; + } + let out = Self(arr); + if matches!(validate, Validate::Yes) { + out.check()?; + } + Ok(out) + } +} + +impl Module for VectorModule +where + F: FieldCore + + CanonicalField + + 
FieldSampling + + Valid + + Mul, Output = VectorModule> + + for<'a> Mul<&'a VectorModule, Output = VectorModule>, +{ + type Scalar = F; + + fn zero() -> Self { + Self::zero_vec() + } + + fn scale(&self, k: &Self::Scalar) -> Self { + // Delegate to Scalar * VectorModule to satisfy the Module trait’s scalar bounds. + *k * *self + } + + fn random(rng: &mut R) -> Self { + Self(std::array::from_fn(|_| F::sample(rng))) + } +} + +// Scalar * VectorModule impls for our local prime field types. + +impl Mul, N>> for Fp32

{ + type Output = VectorModule, N>; + fn mul(self, rhs: VectorModule, N>) -> Self::Output { + let mut out = rhs.0; + for coeff in &mut out { + *coeff = self * *coeff; + } + VectorModule(out) + } +} + +impl<'a, const P: u32, const N: usize> Mul<&'a VectorModule, N>> for Fp32

{ + type Output = VectorModule, N>; + fn mul(self, rhs: &'a VectorModule, N>) -> Self::Output { + self * *rhs + } +} + +impl Mul, N>> for Fp64

{ + type Output = VectorModule, N>; + fn mul(self, rhs: VectorModule, N>) -> Self::Output { + let mut out = rhs.0; + for coeff in &mut out { + *coeff = self * *coeff; + } + VectorModule(out) + } +} + +impl<'a, const P: u64, const N: usize> Mul<&'a VectorModule, N>> for Fp64

{ + type Output = VectorModule, N>; + fn mul(self, rhs: &'a VectorModule, N>) -> Self::Output { + self * *rhs + } +} + +impl Mul, N>> for Fp128

{ + type Output = VectorModule, N>; + fn mul(self, rhs: VectorModule, N>) -> Self::Output { + let mut out = rhs.0; + for coeff in &mut out { + *coeff = self * *coeff; + } + VectorModule(out) + } +} + +impl<'a, const P: u128, const N: usize> Mul<&'a VectorModule, N>> for Fp128

{ + type Output = VectorModule, N>; + fn mul(self, rhs: &'a VectorModule, N>) -> Self::Output { + self * *rhs + } +} diff --git a/src/algebra/ntt/butterfly.rs b/src/algebra/ntt/butterfly.rs new file mode 100644 index 00000000..fb942750 --- /dev/null +++ b/src/algebra/ntt/butterfly.rs @@ -0,0 +1,382 @@ +//! NTT butterfly transforms for negacyclic rings `Z_p[X]/(X^D + 1)`. +//! +//! Implements a negacyclic NTT via the standard **twist + cyclic NTT** method. +//! +//! Let `psi` be a primitive `2D`-th root of unity (`psi^D = -1 mod p`) and +//! `omega = psi^2`, a primitive `D`-th root of unity. For polynomials modulo +//! `X^D + 1`, we: +//! - pre-twist coefficients by `psi^i` +//! - run a cyclic size-`D` NTT using `omega` +//! - inverse-cyclic NTT using `omega^{-1}` +//! - post-untwist by `psi^{-i}` + +use super::prime::{MontCoeff, NttPrime, PrimeWidth}; + +/// Precomputed twiddle factors for a specific prime and degree `D`. +/// +/// `D` must be a power of two. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct NttTwiddles { + /// Stage roots for iterative forward cyclic NTT in Montgomery form. + pub(crate) fwd_wlen: [MontCoeff; D], + /// Stage roots for iterative inverse cyclic NTT in Montgomery form. + pub(crate) inv_wlen: [MontCoeff; D], + /// Number of active stages in the twiddle arrays (`log2(D)`). + pub(crate) num_stages: usize, + /// Twist factors `psi^i` for negacyclic embedding, in Montgomery form. + pub(crate) psi_pows: [MontCoeff; D], + /// Untwist factors `psi^{-i}`, in Montgomery form. + pub(crate) psi_inv_pows: [MontCoeff; D], + /// `D^{-1} mod p` in Montgomery form, used for inverse NTT final scaling. + pub(crate) d_inv: MontCoeff, + /// Fused `D^{-1} * psi^{-i}` for each index, in Montgomery form. + pub(crate) d_inv_psi_inv: [MontCoeff; D], + /// Per-position forward twiddles, packed across stages. + /// Stage s (with butterfly half-length 2^s) occupies `[2^s - 1 .. 2^(s+1) - 2]`. 
+ /// Breaks the serial `w = mul(w, wlen)` dependency chain in butterfly loops. + pub(crate) fwd_twiddles: [MontCoeff; D], + /// Per-position inverse twiddles, same layout as `fwd_twiddles`. + pub(crate) inv_twiddles: [MontCoeff; D], +} + +impl NttTwiddles { + /// Compute twiddle factors for the given prime. + /// + /// Finds a primitive `2D`-th root `psi` and derives `omega = psi^2`. + /// Fills cyclic forward/inverse twiddles for `omega` and twist/untwist + /// tables for `psi`. All values are stored in Montgomery form. + /// + /// # Panics + /// + /// Panics if `D` is not a power of two, or if `2D` does not divide `p - 1`. + pub fn compute(prime: NttPrime) -> Self { + assert!(D.is_power_of_two(), "D must be a power of two"); + let p = prime.p.to_i64(); + assert!( + (p - 1) % (2 * D as i64) == 0, + "2D must divide p - 1 for NTT roots to exist" + ); + + let psi = find_primitive_root_2d(p, D); + let omega = (psi * psi) % p; + let omega_inv = pow_mod(omega, p - 2, p); + + let psi_inv = pow_mod(psi, p - 2, p); + let mut psi_pows = [MontCoeff::from_raw(W::default()); D]; + let mut psi_inv_pows = [MontCoeff::from_raw(W::default()); D]; + let mut cur = 1i64; + let mut cur_inv = 1i64; + for i in 0..D { + psi_pows[i] = prime.from_canonical(W::from_i64(cur)); + psi_inv_pows[i] = prime.from_canonical(W::from_i64(cur_inv)); + cur = (cur * psi) % p; + cur_inv = (cur_inv * psi_inv) % p; + } + + let mut fwd_wlen = [MontCoeff::from_raw(W::default()); D]; + let mut inv_wlen = [MontCoeff::from_raw(W::default()); D]; + let mut len = 1usize; + let mut stage = 0usize; + while len < D { + let exp = (D / (2 * len)) as i64; + fwd_wlen[stage] = prime.from_canonical(W::from_i64(pow_mod(omega, exp, p))); + inv_wlen[stage] = prime.from_canonical(W::from_i64(pow_mod(omega_inv, exp, p))); + len *= 2; + stage += 1; + } + + let d_inv_canonical = pow_mod(D as i64, p - 2, p); + let d_inv = prime.from_canonical(W::from_i64(d_inv_canonical)); + + let mut d_inv_psi_inv = 
[MontCoeff::from_raw(W::default()); D]; + for i in 0..D { + d_inv_psi_inv[i] = prime.mul(d_inv, psi_inv_pows[i]); + } + + let num_stages = stage; + let mut fwd_twiddles = [MontCoeff::from_raw(W::default()); D]; + let mut inv_twiddles = [MontCoeff::from_raw(W::default()); D]; + let one = prime.from_canonical(W::from_i64(1)); + for s in 0..num_stages { + let len = 1usize << s; + let base = len - 1; + let mut w_fwd = one; + let mut w_inv = one; + for j in 0..len { + fwd_twiddles[base + j] = w_fwd; + inv_twiddles[base + j] = w_inv; + w_fwd = prime.mul(w_fwd, fwd_wlen[s]); + w_inv = prime.mul(w_inv, inv_wlen[s]); + } + } + + Self { + fwd_wlen, + inv_wlen, + num_stages, + psi_pows, + psi_inv_pows, + d_inv, + d_inv_psi_inv, + fwd_twiddles, + inv_twiddles, + } + } +} + +/// Forward negacyclic NTT (twist + cyclic Gentleman-Sande DIF). +/// +/// Transforms `D` coefficients in-place from coefficient form to NTT +/// evaluation form. Both outputs of each butterfly are range-reduced +/// to prevent overflow. 
+pub fn forward_ntt( + a: &mut [MontCoeff; D], + prime: NttPrime, + tw: &NttTwiddles, +) { + #[cfg(target_arch = "aarch64")] + if super::neon::use_neon_ntt() { + if std::mem::size_of::() == std::mem::size_of::() { + unsafe { + super::neon::forward_ntt_i32( + &mut *(a as *mut _ as *mut [MontCoeff; D]), + *(&prime as *const _ as *const NttPrime), + &*(tw as *const _ as *const NttTwiddles), + ); + } + return; + } + if std::mem::size_of::() == std::mem::size_of::() { + unsafe { + super::neon::forward_ntt_i16( + &mut *(a as *mut _ as *mut [MontCoeff; D]), + *(&prime as *const _ as *const NttPrime), + &*(tw as *const _ as *const NttTwiddles), + ); + } + return; + } + } + + for (ai, psi) in a.iter_mut().zip(tw.psi_pows.iter()) { + *ai = prime.mul(*ai, *psi); + } + + let mut len = D / 2; + while len > 0 { + let twiddle_base = len - 1; + let mut start = 0usize; + while start < D { + for j in 0..len { + let w = tw.fwd_twiddles[twiddle_base + j]; + let u = a[start + j]; + let v = a[start + j + len]; + let sum = u.raw().wrapping_add(v.raw()); + let diff = u.raw().wrapping_sub(v.raw()); + a[start + j] = prime.reduce_range(MontCoeff::from_raw(sum)); + a[start + j + len] = prime.mul(MontCoeff::from_raw(diff), w); + } + start += 2 * len; + } + len /= 2; + } + + prime.reduce_range_in_place(a); +} + +/// Inverse negacyclic NTT (cyclic Cooley-Tukey DIT + untwist). +/// +/// Transforms `D` evaluations in-place back to coefficient form. +/// Includes the final `D^{-1}` scaling. 
+pub fn inverse_ntt( + a: &mut [MontCoeff; D], + prime: NttPrime, + tw: &NttTwiddles, +) { + #[cfg(target_arch = "aarch64")] + if super::neon::use_neon_ntt() { + if std::mem::size_of::() == std::mem::size_of::() { + unsafe { + super::neon::inverse_ntt_i32( + &mut *(a as *mut _ as *mut [MontCoeff; D]), + *(&prime as *const _ as *const NttPrime), + &*(tw as *const _ as *const NttTwiddles), + ); + } + return; + } + if std::mem::size_of::() == std::mem::size_of::() { + unsafe { + super::neon::inverse_ntt_i16( + &mut *(a as *mut _ as *mut [MontCoeff; D]), + *(&prime as *const _ as *const NttPrime), + &*(tw as *const _ as *const NttTwiddles), + ); + } + return; + } + } + + let mut len = 1usize; + while len < D { + let twiddle_base = len - 1; + let mut start = 0usize; + while start < D { + for j in 0..len { + let w = tw.inv_twiddles[twiddle_base + j]; + let u = a[start + j]; + let v = prime.mul(a[start + j + len], w); + let sum = u.raw().wrapping_add(v.raw()); + let diff = u.raw().wrapping_sub(v.raw()); + a[start + j] = prime.reduce_range(MontCoeff::from_raw(sum)); + a[start + j + len] = prime.reduce_range(MontCoeff::from_raw(diff)); + } + start += 2 * len; + } + len *= 2; + } + + for (ai, fused) in a.iter_mut().zip(tw.d_inv_psi_inv.iter()) { + *ai = prime.mul(*ai, *fused); + } +} + +/// Forward cyclic NTT (Gentleman-Sande DIF, **no** negacyclic twist). +/// +/// Evaluates a polynomial at the D-th roots of *unity* (roots of X^D - 1) +/// rather than X^D + 1. Used with `inverse_ntt_cyclic` to compute unreduced +/// polynomial products via CRT over (X^D - 1)(X^D + 1). 
+pub fn forward_ntt_cyclic( + a: &mut [MontCoeff; D], + prime: NttPrime, + tw: &NttTwiddles, +) { + #[cfg(target_arch = "aarch64")] + if super::neon::use_neon_ntt() { + if std::mem::size_of::() == std::mem::size_of::() { + unsafe { + super::neon::forward_ntt_cyclic_i32( + &mut *(a as *mut _ as *mut [MontCoeff; D]), + *(&prime as *const _ as *const NttPrime), + &*(tw as *const _ as *const NttTwiddles), + ); + } + return; + } + if std::mem::size_of::() == std::mem::size_of::() { + unsafe { + super::neon::forward_ntt_cyclic_i16( + &mut *(a as *mut _ as *mut [MontCoeff; D]), + *(&prime as *const _ as *const NttPrime), + &*(tw as *const _ as *const NttTwiddles), + ); + } + return; + } + } + + let mut len = D / 2; + while len > 0 { + let twiddle_base = len - 1; + let mut start = 0usize; + while start < D { + for j in 0..len { + let w = tw.fwd_twiddles[twiddle_base + j]; + let u = a[start + j]; + let v = a[start + j + len]; + let sum = u.raw().wrapping_add(v.raw()); + let diff = u.raw().wrapping_sub(v.raw()); + a[start + j] = prime.reduce_range(MontCoeff::from_raw(sum)); + a[start + j + len] = prime.mul(MontCoeff::from_raw(diff), w); + } + start += 2 * len; + } + len /= 2; + } + prime.reduce_range_in_place(a); +} + +/// Inverse cyclic NTT (Cooley-Tukey DIT, **no** negacyclic untwist). +/// +/// Recovers coefficients of a polynomial from evaluations at D-th roots of unity. +/// Includes the `D^{-1}` scaling factor. 
+pub fn inverse_ntt_cyclic( + a: &mut [MontCoeff; D], + prime: NttPrime, + tw: &NttTwiddles, +) { + #[cfg(target_arch = "aarch64")] + if super::neon::use_neon_ntt() { + if std::mem::size_of::() == std::mem::size_of::() { + unsafe { + super::neon::inverse_ntt_cyclic_i32( + &mut *(a as *mut _ as *mut [MontCoeff; D]), + *(&prime as *const _ as *const NttPrime), + &*(tw as *const _ as *const NttTwiddles), + ); + } + return; + } + if std::mem::size_of::() == std::mem::size_of::() { + unsafe { + super::neon::inverse_ntt_cyclic_i16( + &mut *(a as *mut _ as *mut [MontCoeff; D]), + *(&prime as *const _ as *const NttPrime), + &*(tw as *const _ as *const NttTwiddles), + ); + } + return; + } + } + + let mut len = 1usize; + while len < D { + let twiddle_base = len - 1; + let mut start = 0usize; + while start < D { + for j in 0..len { + let w = tw.inv_twiddles[twiddle_base + j]; + let u = a[start + j]; + let v = prime.mul(a[start + j + len], w); + let sum = u.raw().wrapping_add(v.raw()); + let diff = u.raw().wrapping_sub(v.raw()); + a[start + j] = prime.reduce_range(MontCoeff::from_raw(sum)); + a[start + j + len] = prime.reduce_range(MontCoeff::from_raw(diff)); + } + start += 2 * len; + } + len *= 2; + } + + for c in a.iter_mut() { + *c = prime.mul(*c, tw.d_inv); + } +} + +/// Find a primitive `2D`-th root of unity mod `p`. +fn find_primitive_root_2d(p: i64, d: usize) -> i64 { + let half = (p - 1) / 2; + let exp = (p - 1) / (2 * d as i64); + for a in 2..p { + if pow_mod(a, half, p) == p - 1 { + let psi = pow_mod(a, exp, p); + debug_assert_eq!(pow_mod(psi, d as i64, p), p - 1, "psi^D != -1"); + return psi; + } + } + panic!("no primitive root found for p={p}"); +} + +/// Modular exponentiation: `base^exp mod modulus`. 
+fn pow_mod(mut base: i64, mut exp: i64, modulus: i64) -> i64 { + let mut result = 1i64; + base %= modulus; + while exp > 0 { + if exp & 1 == 1 { + result = result * base % modulus; + } + base = base * base % modulus; + exp >>= 1; + } + result +} diff --git a/src/algebra/ntt/crt.rs b/src/algebra/ntt/crt.rs new file mode 100644 index 00000000..e0897937 --- /dev/null +++ b/src/algebra/ntt/crt.rs @@ -0,0 +1,202 @@ +//! CRT helpers: Garner reconstruction and limb-based modular arithmetic. + +use std::cmp::Ordering; +use std::fmt; +use std::ops::{Add, Sub}; + +use super::prime::{NttPrime, PrimeWidth}; + +/// Limb radix bit-width (`2^14`). +pub const RADIX_BITS: u32 = 14; +const RADIX: i32 = 1 << RADIX_BITS; +const RADIX_MASK: i32 = RADIX - 1; + +/// Precomputed Garner inverse table for CRT reconstruction. +/// +/// `gamma[i][j]` = `p_j^{-1} mod p_i` for `j < i`. Upper triangle and +/// diagonal entries are zero (unused). +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct GarnerData { + /// `gamma[i][j]` = `p_j^{-1} mod p_i` for `j < i`. + pub gamma: [[W; K]; K], +} + +impl GarnerData { + /// Compute Garner constants from a set of NTT primes. + pub fn compute(primes: &[NttPrime; K]) -> Self { + let mut gamma = [[W::default(); K]; K]; + for i in 1..K { + let pi = primes[i].p.to_i64(); + #[allow(clippy::needless_range_loop)] + for j in 0..i { + let pj = primes[j].p.to_i64(); + let inv = mod_inverse_i64(pj, pi); + gamma[i][j] = W::from_i64(inv); + } + } + Self { gamma } + } +} + +/// Modular inverse via extended GCD, operating in `i64`. +fn mod_inverse_i64(a: i64, modulus: i64) -> i64 { + let (mut t, mut new_t) = (0i64, 1i64); + let (mut r, mut new_r) = (modulus, ((a % modulus) + modulus) % modulus); + while new_r != 0 { + let q = r / new_r; + (t, new_t) = (new_t, t - q * new_t); + (r, new_r) = (new_r, r - q * new_r); + } + assert_eq!(r, 1, "modular inverse does not exist"); + ((t % modulus) + modulus) % modulus +} + +/// Fixed-width radix-`2^14` integer. 
+/// +/// Limbs are little-endian: `limbs[0]` is least significant. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct LimbQ { + /// Little-endian limbs. + pub limbs: [u16; L], +} + +impl Default for LimbQ { + #[inline] + fn default() -> Self { + Self::zero() + } +} + +impl LimbQ { + /// Zero value. + #[inline] + pub const fn zero() -> Self { + Self { limbs: [0; L] } + } + + /// Construct directly from limbs. + #[inline] + pub const fn from_limbs(limbs: [u16; L]) -> Self { + Self { limbs } + } + + /// Conditional subtraction: if `self >= modulus`, return `self - modulus` (branchless). + #[inline] + pub fn csub_mod(self, modulus: Self) -> Self { + let mut diff = [0u16; L]; + let mut borrow = 0i32; + for (i, df) in diff.iter_mut().enumerate() { + let d = self.limbs[i] as i32 - modulus.limbs[i] as i32 + borrow; + borrow = d >> 31; + if i + 1 < L { + *df = (d - borrow * RADIX) as u16; + } else { + *df = d as u16; + } + } + let mask = borrow as u16; + let mut result = [0u16; L]; + for (i, r) in result.iter_mut().enumerate() { + *r = (self.limbs[i] & mask) | (diff[i] & !mask); + } + Self { limbs: result } + } +} + +impl From for LimbQ { + fn from(mut x: u128) -> Self { + let mut out = [0u16; L]; + for (i, limb) in out.iter_mut().enumerate() { + if i + 1 < L { + *limb = (x & (RADIX_MASK as u128)) as u16; + x >>= RADIX_BITS; + } else { + *limb = x as u16; + } + } + Self { limbs: out } + } +} + +impl TryFrom> for u128 { + type Error = &'static str; + + fn try_from(limb: LimbQ) -> Result { + if (L as u32) * RADIX_BITS > 128 { + return Err("LimbQ too wide for u128"); + } + let mut acc = 0u128; + for i in (0..L).rev() { + acc <<= RADIX_BITS; + acc |= limb.limbs[i] as u128; + } + Ok(acc) + } +} + +impl PartialOrd for LimbQ { + #[inline] + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} + +impl Ord for LimbQ { + fn cmp(&self, other: &Self) -> Ordering { + for i in (0..L).rev() { + match self.limbs[i].cmp(&other.limbs[i]) { + Ordering::Equal 
=> continue, + ord => return ord, + } + } + Ordering::Equal + } +} + +impl Add for LimbQ { + type Output = Self; + + fn add(self, rhs: Self) -> Self { + let mut out = [0u16; L]; + let mut carry = 0i32; + for (i, out_limb) in out.iter_mut().enumerate() { + let s = self.limbs[i] as i32 + rhs.limbs[i] as i32 + carry; + if i + 1 < L { + carry = s >> RADIX_BITS; + *out_limb = (s & RADIX_MASK) as u16; + } else { + *out_limb = s as u16; + } + } + Self { limbs: out } + } +} + +impl Sub for LimbQ { + type Output = Self; + + fn sub(self, rhs: Self) -> Self { + let mut out = [0u16; L]; + let mut borrow = 0i32; + for (i, out_limb) in out.iter_mut().enumerate() { + let d = self.limbs[i] as i32 - rhs.limbs[i] as i32 + borrow; + if i + 1 < L { + borrow = d >> 31; + *out_limb = (d - borrow * RADIX) as u16; + } else { + *out_limb = d as u16; + } + } + Self { limbs: out } + } +} + +impl fmt::Display for LimbQ { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + if let Ok(val) = u128::try_from(*self) { + write!(f, "{val}") + } else { + write!(f, "LimbQ{:?}", self.limbs) + } + } +} diff --git a/src/algebra/ntt/mod.rs b/src/algebra/ntt/mod.rs new file mode 100644 index 00000000..b059e9b6 --- /dev/null +++ b/src/algebra/ntt/mod.rs @@ -0,0 +1,12 @@ +//! NTT-friendly small-prime arithmetic and CRT helpers. + +pub mod butterfly; +pub mod crt; +#[cfg(target_arch = "aarch64")] +pub(crate) mod neon; +pub mod prime; +pub mod tables; + +pub use butterfly::NttTwiddles; +pub use crt::{GarnerData, LimbQ, RADIX_BITS}; +pub use prime::{MontCoeff, NttPrime, PrimeWidth}; diff --git a/src/algebra/ntt/neon.rs b/src/algebra/ntt/neon.rs new file mode 100644 index 00000000..10ed1817 --- /dev/null +++ b/src/algebra/ntt/neon.rs @@ -0,0 +1,1001 @@ +//! AArch64 NEON SIMD kernels for NTT butterfly, Montgomery multiply, +//! and pointwise operations. +//! +//! Provides vectorized i32 (for Q64/Q128) and i16 (for Q32) paths. +//! Dispatch is controlled by [`use_neon_ntt`]: set `HACHI_SCALAR_NTT=1` +//! 
to force the scalar fallback for A/B performance comparison. + +use std::arch::aarch64::*; +use std::sync::OnceLock; + +use super::butterfly::NttTwiddles; +use super::prime::{MontCoeff, NttPrime}; + +/// Whether the NEON NTT path is active. Cached on first call. +/// Set `HACHI_SCALAR_NTT=1` to force scalar fallback. +pub(crate) fn use_neon_ntt() -> bool { + static ENABLED: OnceLock = OnceLock::new(); + *ENABLED.get_or_init(|| std::env::var("HACHI_SCALAR_NTT").map_or(true, |v| v != "1")) +} + +/// 4-wide Montgomery multiply for i32 primes. +/// +/// Uses two 2-wide `vmull_s32` chains (since i32×i32→i64 only fills 2 lanes +/// of a 128-bit register) and combines the results. +#[inline(always)] +unsafe fn mont_mul_4x_i32(a: int32x4_t, b: int32x4_t, p: int32x2_t, pinv: int32x2_t) -> int32x4_t { + let a_lo = vget_low_s32(a); + let a_hi = vget_high_s32(a); + let b_lo = vget_low_s32(b); + let b_hi = vget_high_s32(b); + + // Low pair + let c_lo = vmull_s32(a_lo, b_lo); + let t_lo = vmul_s32(vmovn_s64(c_lo), pinv); + let tp_lo = vmull_s32(t_lo, p); + let r_lo = vmovn_s64(vshrq_n_s64::<32>(vsubq_s64(c_lo, tp_lo))); + + // High pair + let c_hi = vmull_s32(a_hi, b_hi); + let t_hi = vmul_s32(vmovn_s64(c_hi), pinv); + let tp_hi = vmull_s32(t_hi, p); + let r_hi = vmovn_s64(vshrq_n_s64::<32>(vsubq_s64(c_hi, tp_hi))); + + vcombine_s32(r_lo, r_hi) +} + +/// 4-wide range reduction for i32: maps `(-2p, 2p)` → `(-p, p)`. +/// +/// Uses comparison-first approach to avoid the i64 widening that the +/// scalar `csubp`/`caddp` path requires (since `a - p` can overflow i32). 
+#[inline(always)] +unsafe fn reduce_range_4x_i32(a: int32x4_t, p: int32x4_t) -> int32x4_t { + let zero = vdupq_n_s32(0); + + // csubp: subtract p where a >= p + let ge_mask = vcgeq_s32(a, p); + let after_sub = vsubq_s32( + a, + vreinterpretq_s32_u32(vandq_u32(vreinterpretq_u32_s32(p), ge_mask)), + ); + + // caddp: add p where result < 0 + let lt_mask = vcltq_s32(after_sub, zero); + vaddq_s32( + after_sub, + vreinterpretq_s32_u32(vandq_u32(vreinterpretq_u32_s32(p), lt_mask)), + ) +} + +/// NEON-accelerated forward negacyclic NTT for i32 primes. +/// +/// Processes 4 butterfly pairs per iteration when `len >= 4`; +/// falls back to scalar for the final 2 stages (`len = 2, 1`). +pub(crate) unsafe fn forward_ntt_i32( + a: &mut [MontCoeff; D], + prime: NttPrime, + tw: &NttTwiddles, +) { + let p_d = vdup_n_s32(prime.p); + let pinv_d = vdup_n_s32(prime.pinv); + let p_q = vdupq_n_s32(prime.p); + let a_ptr = a.as_mut_ptr() as *mut i32; + + // Pre-twist by psi^i + { + let psi_ptr = tw.psi_pows.as_ptr() as *const i32; + let mut i = 0; + while i + 4 <= D { + let ai = vld1q_s32(a_ptr.add(i)); + let psi = vld1q_s32(psi_ptr.add(i)); + vst1q_s32(a_ptr.add(i), mont_mul_4x_i32(ai, psi, p_d, pinv_d)); + i += 4; + } + } + + // DIF butterfly stages + let mut len = D / 2; + while len > 0 { + let twiddle_base = len - 1; + let tw_ptr = tw.fwd_twiddles.as_ptr() as *const i32; + let mut start = 0usize; + while start < D { + if len >= 4 { + let mut j = 0; + while j < len { + let u = vld1q_s32(a_ptr.add(start + j)); + let v = vld1q_s32(a_ptr.add(start + j + len)); + let w = vld1q_s32(tw_ptr.add(twiddle_base + j)); + + let sum = vaddq_s32(u, v); + let diff = vsubq_s32(u, v); + + vst1q_s32(a_ptr.add(start + j), reduce_range_4x_i32(sum, p_q)); + vst1q_s32( + a_ptr.add(start + j + len), + mont_mul_4x_i32(diff, w, p_d, pinv_d), + ); + j += 4; + } + } else { + for j in 0..len { + let w = tw.fwd_twiddles[twiddle_base + j]; + let u = a[start + j]; + let v = a[start + j + len]; + let sum = 
u.raw().wrapping_add(v.raw()); + let diff = u.raw().wrapping_sub(v.raw()); + a[start + j] = prime.reduce_range(MontCoeff::from_raw(sum)); + a[start + j + len] = prime.mul(MontCoeff::from_raw(diff), w); + } + } + start += 2 * len; + } + len /= 2; + } + + // Final reduce_range pass + reduce_range_in_place_i32(a, p_q); +} + +/// NEON-accelerated inverse negacyclic NTT for i32 primes. +pub(crate) unsafe fn inverse_ntt_i32( + a: &mut [MontCoeff; D], + prime: NttPrime, + tw: &NttTwiddles, +) { + let p_d = vdup_n_s32(prime.p); + let pinv_d = vdup_n_s32(prime.pinv); + let p_q = vdupq_n_s32(prime.p); + let a_ptr = a.as_mut_ptr() as *mut i32; + + // DIT butterfly stages + let mut len = 1usize; + while len < D { + let twiddle_base = len - 1; + let tw_ptr = tw.inv_twiddles.as_ptr() as *const i32; + let mut start = 0usize; + while start < D { + if len >= 4 { + let mut j = 0; + while j < len { + let w = vld1q_s32(tw_ptr.add(twiddle_base + j)); + let u = vld1q_s32(a_ptr.add(start + j)); + let v_raw = vld1q_s32(a_ptr.add(start + j + len)); + let v = mont_mul_4x_i32(v_raw, w, p_d, pinv_d); + + let sum = vaddq_s32(u, v); + let diff = vsubq_s32(u, v); + + vst1q_s32(a_ptr.add(start + j), reduce_range_4x_i32(sum, p_q)); + vst1q_s32(a_ptr.add(start + j + len), reduce_range_4x_i32(diff, p_q)); + j += 4; + } + } else { + for j in 0..len { + let w = tw.inv_twiddles[twiddle_base + j]; + let u = a[start + j]; + let v = prime.mul(a[start + j + len], w); + let sum = u.raw().wrapping_add(v.raw()); + let diff = u.raw().wrapping_sub(v.raw()); + a[start + j] = prime.reduce_range(MontCoeff::from_raw(sum)); + a[start + j + len] = prime.reduce_range(MontCoeff::from_raw(diff)); + } + } + start += 2 * len; + } + len *= 2; + } + + // Fused D^{-1} * psi^{-i} untwist + { + let fused_ptr = tw.d_inv_psi_inv.as_ptr() as *const i32; + let mut i = 0; + while i + 4 <= D { + let ai = vld1q_s32(a_ptr.add(i)); + let f = vld1q_s32(fused_ptr.add(i)); + vst1q_s32(a_ptr.add(i), mont_mul_4x_i32(ai, f, p_d, pinv_d)); + 
i += 4; + } + } +} + +/// NEON-accelerated forward cyclic NTT for i32 (no negacyclic twist). +pub(crate) unsafe fn forward_ntt_cyclic_i32( + a: &mut [MontCoeff; D], + prime: NttPrime, + tw: &NttTwiddles, +) { + let p_d = vdup_n_s32(prime.p); + let pinv_d = vdup_n_s32(prime.pinv); + let p_q = vdupq_n_s32(prime.p); + let a_ptr = a.as_mut_ptr() as *mut i32; + + let mut len = D / 2; + while len > 0 { + let twiddle_base = len - 1; + let tw_ptr = tw.fwd_twiddles.as_ptr() as *const i32; + let mut start = 0usize; + while start < D { + if len >= 4 { + let mut j = 0; + while j < len { + let u = vld1q_s32(a_ptr.add(start + j)); + let v = vld1q_s32(a_ptr.add(start + j + len)); + let w = vld1q_s32(tw_ptr.add(twiddle_base + j)); + let sum = vaddq_s32(u, v); + let diff = vsubq_s32(u, v); + vst1q_s32(a_ptr.add(start + j), reduce_range_4x_i32(sum, p_q)); + vst1q_s32( + a_ptr.add(start + j + len), + mont_mul_4x_i32(diff, w, p_d, pinv_d), + ); + j += 4; + } + } else { + for j in 0..len { + let w = tw.fwd_twiddles[twiddle_base + j]; + let u = a[start + j]; + let v = a[start + j + len]; + let sum = u.raw().wrapping_add(v.raw()); + let diff = u.raw().wrapping_sub(v.raw()); + a[start + j] = prime.reduce_range(MontCoeff::from_raw(sum)); + a[start + j + len] = prime.mul(MontCoeff::from_raw(diff), w); + } + } + start += 2 * len; + } + len /= 2; + } + reduce_range_in_place_i32(a, p_q); +} + +/// NEON-accelerated inverse cyclic NTT for i32 (no negacyclic untwist). 
+pub(crate) unsafe fn inverse_ntt_cyclic_i32( + a: &mut [MontCoeff; D], + prime: NttPrime, + tw: &NttTwiddles, +) { + let p_d = vdup_n_s32(prime.p); + let pinv_d = vdup_n_s32(prime.pinv); + let p_q = vdupq_n_s32(prime.p); + let a_ptr = a.as_mut_ptr() as *mut i32; + + let mut len = 1usize; + while len < D { + let twiddle_base = len - 1; + let tw_ptr = tw.inv_twiddles.as_ptr() as *const i32; + let mut start = 0usize; + while start < D { + if len >= 4 { + let mut j = 0; + while j < len { + let w = vld1q_s32(tw_ptr.add(twiddle_base + j)); + let u = vld1q_s32(a_ptr.add(start + j)); + let v_raw = vld1q_s32(a_ptr.add(start + j + len)); + let v = mont_mul_4x_i32(v_raw, w, p_d, pinv_d); + let sum = vaddq_s32(u, v); + let diff = vsubq_s32(u, v); + vst1q_s32(a_ptr.add(start + j), reduce_range_4x_i32(sum, p_q)); + vst1q_s32(a_ptr.add(start + j + len), reduce_range_4x_i32(diff, p_q)); + j += 4; + } + } else { + for j in 0..len { + let w = tw.inv_twiddles[twiddle_base + j]; + let u = a[start + j]; + let v = prime.mul(a[start + j + len], w); + let sum = u.raw().wrapping_add(v.raw()); + let diff = u.raw().wrapping_sub(v.raw()); + a[start + j] = prime.reduce_range(MontCoeff::from_raw(sum)); + a[start + j + len] = prime.reduce_range(MontCoeff::from_raw(diff)); + } + } + start += 2 * len; + } + len *= 2; + } + + // D^{-1} scaling + { + let d_inv = tw.d_inv; + let d_inv_q = vdupq_n_s32(d_inv.raw()); + let mut i = 0; + while i + 4 <= D { + let ai = vld1q_s32(a_ptr.add(i)); + vst1q_s32(a_ptr.add(i), mont_mul_4x_i32(ai, d_inv_q, p_d, pinv_d)); + i += 4; + } + } +} + +/// 4-wide pointwise multiply-accumulate for a single CRT limb (i32). +/// +/// `acc[i] = reduce_range(acc[i] + mont_mul(lhs[i], rhs[i]))` for `i in 0..d`. 
+pub(crate) unsafe fn pointwise_mul_acc_i32( + acc: *mut i32, + lhs: *const i32, + rhs: *const i32, + d: usize, + p: i32, + pinv: i32, +) { + let p_d = vdup_n_s32(p); + let pinv_d = vdup_n_s32(pinv); + let p_q = vdupq_n_s32(p); + let mut i = 0; + while i + 4 <= d { + let a = vld1q_s32(acc.add(i)); + let l = vld1q_s32(lhs.add(i)); + let r = vld1q_s32(rhs.add(i)); + let prod = mont_mul_4x_i32(l, r, p_d, pinv_d); + let sum = vaddq_s32(a, prod); + vst1q_s32(acc.add(i), reduce_range_4x_i32(sum, p_q)); + i += 4; + } +} + +/// 4-wide add-and-reduce for a single CRT limb (i32). +/// +/// `acc[i] = reduce_range(acc[i] + other[i])` for `i in 0..d`. +pub(crate) unsafe fn add_reduce_i32(acc: *mut i32, other: *const i32, d: usize, p: i32) { + let p_q = vdupq_n_s32(p); + let mut i = 0; + while i + 4 <= d { + let a = vld1q_s32(acc.add(i)); + let b = vld1q_s32(other.add(i)); + vst1q_s32(acc.add(i), reduce_range_4x_i32(vaddq_s32(a, b), p_q)); + i += 4; + } +} + +/// In-place reduce_range over a full array. +unsafe fn reduce_range_in_place_i32(a: &mut [MontCoeff; D], p_q: int32x4_t) { + let ptr = a.as_mut_ptr() as *mut i32; + let mut i = 0; + while i + 4 <= D { + let val = vld1q_s32(ptr.add(i)); + vst1q_s32(ptr.add(i), reduce_range_4x_i32(val, p_q)); + i += 4; + } +} + +/// 4-wide Montgomery multiply for i16 primes. +/// +/// Natural 4-wide: `vmull_s16` produces `int32x4_t`. +#[inline(always)] +unsafe fn mont_mul_4x_i16(a: int16x4_t, b: int16x4_t, p: int16x4_t, pinv: int16x4_t) -> int16x4_t { + let c = vmull_s16(a, b); + let t = vmul_s16(vmovn_s32(c), pinv); + let tp = vmull_s16(t, p); + vmovn_s32(vshrq_n_s32::<16>(vsubq_s32(c, tp))) +} + +/// 8-wide Montgomery multiply for i16 primes (two 4-wide chains). 
+#[inline(always)] +unsafe fn mont_mul_8x_i16(a: int16x8_t, b: int16x8_t, p: int16x4_t, pinv: int16x4_t) -> int16x8_t { + let r_lo = mont_mul_4x_i16(vget_low_s16(a), vget_low_s16(b), p, pinv); + let r_hi = mont_mul_4x_i16(vget_high_s16(a), vget_high_s16(b), p, pinv); + vcombine_s16(r_lo, r_hi) +} + +/// 8-wide range reduction for i16: `(-2p, 2p)` → `(-p, p)`. +/// +/// Same comparison-first approach as i32 but on `int16x8_t`. +#[inline(always)] +unsafe fn reduce_range_8x_i16(a: int16x8_t, p: int16x8_t) -> int16x8_t { + let zero = vdupq_n_s16(0); + let ge_mask = vcgeq_s16(a, p); + let after_sub = vsubq_s16( + a, + vreinterpretq_s16_u16(vandq_u16(vreinterpretq_u16_s16(p), ge_mask)), + ); + let lt_mask = vcltq_s16(after_sub, zero); + vaddq_s16( + after_sub, + vreinterpretq_s16_u16(vandq_u16(vreinterpretq_u16_s16(p), lt_mask)), + ) +} + +/// NEON-accelerated forward negacyclic NTT for i16 primes. +/// +/// Processes 4 butterflies per iteration when `len >= 4`; +/// scalar fallback for `len < 4`. 
+pub(crate) unsafe fn forward_ntt_i16( + a: &mut [MontCoeff; D], + prime: NttPrime, + tw: &NttTwiddles, +) { + let p_d = vdup_n_s16(prime.p); + let pinv_d = vdup_n_s16(prime.pinv); + let p_q = vdupq_n_s16(prime.p); + let a_ptr = a.as_mut_ptr() as *mut i16; + + // Pre-twist by psi^i + { + let psi_ptr = tw.psi_pows.as_ptr() as *const i16; + let mut i = 0; + while i + 4 <= D { + let ai = vld1_s16(a_ptr.add(i)); + let psi = vld1_s16(psi_ptr.add(i)); + vst1_s16(a_ptr.add(i), mont_mul_4x_i16(ai, psi, p_d, pinv_d)); + i += 4; + } + } + + // DIF butterfly stages + let mut len = D / 2; + while len > 0 { + let twiddle_base = len - 1; + let tw_ptr = tw.fwd_twiddles.as_ptr() as *const i16; + let mut start = 0usize; + while start < D { + if len >= 4 { + let mut j = 0; + while j < len { + let u = vld1_s16(a_ptr.add(start + j)); + let v = vld1_s16(a_ptr.add(start + j + len)); + let w = vld1_s16(tw_ptr.add(twiddle_base + j)); + let sum = vadd_s16(u, v); + let diff = vsub_s16(u, v); + + // reduce_range on 4-wide i16 (use 8-wide by padding) + let sum_q = vcombine_s16(sum, vdup_n_s16(0)); + let sum_reduced = vget_low_s16(reduce_range_8x_i16(sum_q, p_q)); + + let diff_mul_w = mont_mul_4x_i16(diff, w, p_d, pinv_d); + vst1_s16(a_ptr.add(start + j), sum_reduced); + vst1_s16(a_ptr.add(start + j + len), diff_mul_w); + j += 4; + } + } else { + for j in 0..len { + let w = tw.fwd_twiddles[twiddle_base + j]; + let u = a[start + j]; + let v = a[start + j + len]; + let sum = u.raw().wrapping_add(v.raw()); + let diff = u.raw().wrapping_sub(v.raw()); + a[start + j] = prime.reduce_range(MontCoeff::from_raw(sum)); + a[start + j + len] = prime.mul(MontCoeff::from_raw(diff), w); + } + } + start += 2 * len; + } + len /= 2; + } + + // Final reduce_range pass + reduce_range_in_place_i16(a, p_q); +} + +/// NEON-accelerated inverse negacyclic NTT for i16 primes. 
+pub(crate) unsafe fn inverse_ntt_i16( + a: &mut [MontCoeff; D], + prime: NttPrime, + tw: &NttTwiddles, +) { + let p_d = vdup_n_s16(prime.p); + let pinv_d = vdup_n_s16(prime.pinv); + let p_q = vdupq_n_s16(prime.p); + let a_ptr = a.as_mut_ptr() as *mut i16; + + let mut len = 1usize; + while len < D { + let twiddle_base = len - 1; + let tw_ptr = tw.inv_twiddles.as_ptr() as *const i16; + let mut start = 0usize; + while start < D { + if len >= 4 { + let mut j = 0; + while j < len { + let w = vld1_s16(tw_ptr.add(twiddle_base + j)); + let u = vld1_s16(a_ptr.add(start + j)); + let v_raw = vld1_s16(a_ptr.add(start + j + len)); + let v = mont_mul_4x_i16(v_raw, w, p_d, pinv_d); + let sum = vadd_s16(u, v); + let diff = vsub_s16(u, v); + let reduced = reduce_range_8x_i16(vcombine_s16(sum, diff), p_q); + vst1_s16(a_ptr.add(start + j), vget_low_s16(reduced)); + vst1_s16(a_ptr.add(start + j + len), vget_high_s16(reduced)); + j += 4; + } + } else { + for j in 0..len { + let w = tw.inv_twiddles[twiddle_base + j]; + let u = a[start + j]; + let v = prime.mul(a[start + j + len], w); + let sum = u.raw().wrapping_add(v.raw()); + let diff = u.raw().wrapping_sub(v.raw()); + a[start + j] = prime.reduce_range(MontCoeff::from_raw(sum)); + a[start + j + len] = prime.reduce_range(MontCoeff::from_raw(diff)); + } + } + start += 2 * len; + } + len *= 2; + } + + // Fused D^{-1} * psi^{-i} untwist + { + let fused_ptr = tw.d_inv_psi_inv.as_ptr() as *const i16; + let mut i = 0; + while i + 4 <= D { + let ai = vld1_s16(a_ptr.add(i)); + let f = vld1_s16(fused_ptr.add(i)); + vst1_s16(a_ptr.add(i), mont_mul_4x_i16(ai, f, p_d, pinv_d)); + i += 4; + } + } +} + +/// NEON-accelerated forward cyclic NTT for i16. 
+pub(crate) unsafe fn forward_ntt_cyclic_i16( + a: &mut [MontCoeff; D], + prime: NttPrime, + tw: &NttTwiddles, +) { + let p_d = vdup_n_s16(prime.p); + let pinv_d = vdup_n_s16(prime.pinv); + let p_q = vdupq_n_s16(prime.p); + let a_ptr = a.as_mut_ptr() as *mut i16; + + let mut len = D / 2; + while len > 0 { + let twiddle_base = len - 1; + let tw_ptr = tw.fwd_twiddles.as_ptr() as *const i16; + let mut start = 0usize; + while start < D { + if len >= 4 { + let mut j = 0; + while j < len { + let u = vld1_s16(a_ptr.add(start + j)); + let v = vld1_s16(a_ptr.add(start + j + len)); + let w = vld1_s16(tw_ptr.add(twiddle_base + j)); + let sum = vadd_s16(u, v); + let diff = vsub_s16(u, v); + let sum_q = vcombine_s16(sum, vdup_n_s16(0)); + vst1_s16( + a_ptr.add(start + j), + vget_low_s16(reduce_range_8x_i16(sum_q, p_q)), + ); + vst1_s16( + a_ptr.add(start + j + len), + mont_mul_4x_i16(diff, w, p_d, pinv_d), + ); + j += 4; + } + } else { + for j in 0..len { + let w = tw.fwd_twiddles[twiddle_base + j]; + let u = a[start + j]; + let v = a[start + j + len]; + let sum = u.raw().wrapping_add(v.raw()); + let diff = u.raw().wrapping_sub(v.raw()); + a[start + j] = prime.reduce_range(MontCoeff::from_raw(sum)); + a[start + j + len] = prime.mul(MontCoeff::from_raw(diff), w); + } + } + start += 2 * len; + } + len /= 2; + } + reduce_range_in_place_i16(a, p_q); +} + +/// NEON-accelerated inverse cyclic NTT for i16. 
+pub(crate) unsafe fn inverse_ntt_cyclic_i16( + a: &mut [MontCoeff; D], + prime: NttPrime, + tw: &NttTwiddles, +) { + let p_d = vdup_n_s16(prime.p); + let pinv_d = vdup_n_s16(prime.pinv); + let p_q = vdupq_n_s16(prime.p); + let a_ptr = a.as_mut_ptr() as *mut i16; + + let mut len = 1usize; + while len < D { + let twiddle_base = len - 1; + let tw_ptr = tw.inv_twiddles.as_ptr() as *const i16; + let mut start = 0usize; + while start < D { + if len >= 4 { + let mut j = 0; + while j < len { + let w = vld1_s16(tw_ptr.add(twiddle_base + j)); + let u = vld1_s16(a_ptr.add(start + j)); + let v_raw = vld1_s16(a_ptr.add(start + j + len)); + let v = mont_mul_4x_i16(v_raw, w, p_d, pinv_d); + let sum = vadd_s16(u, v); + let diff = vsub_s16(u, v); + let reduced = reduce_range_8x_i16(vcombine_s16(sum, diff), p_q); + vst1_s16(a_ptr.add(start + j), vget_low_s16(reduced)); + vst1_s16(a_ptr.add(start + j + len), vget_high_s16(reduced)); + j += 4; + } + } else { + for j in 0..len { + let w = tw.inv_twiddles[twiddle_base + j]; + let u = a[start + j]; + let v = prime.mul(a[start + j + len], w); + let sum = u.raw().wrapping_add(v.raw()); + let diff = u.raw().wrapping_sub(v.raw()); + a[start + j] = prime.reduce_range(MontCoeff::from_raw(sum)); + a[start + j + len] = prime.reduce_range(MontCoeff::from_raw(diff)); + } + } + start += 2 * len; + } + len *= 2; + } + + // D^{-1} scaling + { + let d_inv = tw.d_inv; + let d_inv_d = vdup_n_s16(d_inv.raw()); + let mut i = 0; + while i + 4 <= D { + let ai = vld1_s16(a_ptr.add(i)); + vst1_s16(a_ptr.add(i), mont_mul_4x_i16(ai, d_inv_d, p_d, pinv_d)); + i += 4; + } + } +} + +/// 8-wide pointwise multiply-accumulate for a single CRT limb (i16). 
+pub(crate) unsafe fn pointwise_mul_acc_i16( + acc: *mut i16, + lhs: *const i16, + rhs: *const i16, + d: usize, + p: i16, + pinv: i16, +) { + let p_d = vdup_n_s16(p); + let pinv_d = vdup_n_s16(pinv); + let p_q = vdupq_n_s16(p); + let mut i = 0; + while i + 8 <= d { + let a = vld1q_s16(acc.add(i)); + let l = vld1q_s16(lhs.add(i)); + let r = vld1q_s16(rhs.add(i)); + let prod = mont_mul_8x_i16(l, r, p_d, pinv_d); + let sum = vaddq_s16(a, prod); + vst1q_s16(acc.add(i), reduce_range_8x_i16(sum, p_q)); + i += 8; + } + while i + 4 <= d { + let a = vld1_s16(acc.add(i)); + let l = vld1_s16(lhs.add(i)); + let r = vld1_s16(rhs.add(i)); + let prod = mont_mul_4x_i16(l, r, p_d, pinv_d); + let sum = vadd_s16(a, prod); + let sum_q = vcombine_s16(sum, vdup_n_s16(0)); + vst1_s16(acc.add(i), vget_low_s16(reduce_range_8x_i16(sum_q, p_q))); + i += 4; + } +} + +/// 8-wide add-and-reduce for a single CRT limb (i16). +pub(crate) unsafe fn add_reduce_i16(acc: *mut i16, other: *const i16, d: usize, p: i16) { + let p_q = vdupq_n_s16(p); + let mut i = 0; + while i + 8 <= d { + let a = vld1q_s16(acc.add(i)); + let b = vld1q_s16(other.add(i)); + vst1q_s16(acc.add(i), reduce_range_8x_i16(vaddq_s16(a, b), p_q)); + i += 8; + } + while i + 4 <= d { + let a = vld1_s16(acc.add(i)); + let b = vld1_s16(other.add(i)); + let sum_q = vcombine_s16(vadd_s16(a, b), vdup_n_s16(0)); + vst1_s16(acc.add(i), vget_low_s16(reduce_range_8x_i16(sum_q, p_q))); + i += 4; + } +} + +/// In-place reduce_range over a full i16 array. 
+unsafe fn reduce_range_in_place_i16(a: &mut [MontCoeff; D], p_q: int16x8_t) { + let ptr = a.as_mut_ptr() as *mut i16; + let mut i = 0; + while i + 8 <= D { + let val = vld1q_s16(ptr.add(i)); + vst1q_s16(ptr.add(i), reduce_range_8x_i16(val, p_q)); + i += 8; + } + while i + 4 <= D { + let val = vld1_s16(ptr.add(i)); + let padded = vcombine_s16(val, vdup_n_s16(0)); + vst1_s16(ptr.add(i), vget_low_s16(reduce_range_8x_i16(padded, p_q))); + i += 4; + } +} + +#[cfg(test)] +mod tests { + use super::super::butterfly::{ + forward_ntt as scalar_forward_ntt, forward_ntt_cyclic as scalar_forward_ntt_cyclic, + inverse_ntt as scalar_inverse_ntt, inverse_ntt_cyclic as scalar_inverse_ntt_cyclic, + NttTwiddles, + }; + use super::super::prime::{MontCoeff, NttPrime}; + use super::*; + + fn random_mont_array_i32( + prime: NttPrime, + seed: u64, + ) -> [MontCoeff; D] { + let mut state = seed; + std::array::from_fn(|_| { + state = state + .wrapping_mul(6364136223846793005) + .wrapping_add(1442695040888963407); + let val = ((state >> 33) as i64 % prime.p as i64) as i32; + prime.from_canonical(val) + }) + } + + fn random_mont_array_i16( + prime: NttPrime, + seed: u64, + ) -> [MontCoeff; D] { + let mut state = seed; + std::array::from_fn(|_| { + state = state + .wrapping_mul(6364136223846793005) + .wrapping_add(1442695040888963407); + let val = ((state >> 33) as i64 % prime.p as i64) as i16; + prime.from_canonical(val) + }) + } + + const TEST_PRIME_I32: i32 = 1073707009; + const TEST_PRIME_I16: i16 = 13697; + + #[test] + fn neon_forward_ntt_i32_matches_scalar() { + let prime = NttPrime::compute(TEST_PRIME_I32); + let tw = NttTwiddles::::compute(prime); + let input = random_mont_array_i32::<512>(prime, 0xCAFE); + + let mut neon_result = input; + unsafe { forward_ntt_i32(&mut neon_result, prime, &tw) }; + + let mut scalar_result = input; + scalar_forward_ntt(&mut scalar_result, prime, &tw); + + for i in 0..512 { + let n = prime.to_canonical(neon_result[i]); + let s = 
prime.to_canonical(scalar_result[i]); + assert_eq!(n, s, "mismatch at index {i}: neon={n}, scalar={s}"); + } + } + + #[test] + fn neon_inverse_ntt_i32_matches_scalar() { + let prime = NttPrime::compute(TEST_PRIME_I32); + let tw = NttTwiddles::::compute(prime); + let input = random_mont_array_i32::<512>(prime, 0xBEEF); + + let mut neon_result = input; + unsafe { inverse_ntt_i32(&mut neon_result, prime, &tw) }; + + let mut scalar_result = input; + scalar_inverse_ntt(&mut scalar_result, prime, &tw); + + for i in 0..512 { + let n = prime.to_canonical(neon_result[i]); + let s = prime.to_canonical(scalar_result[i]); + assert_eq!(n, s, "mismatch at index {i}: neon={n}, scalar={s}"); + } + } + + #[test] + fn neon_forward_inverse_roundtrip_i32() { + let prime = NttPrime::compute(TEST_PRIME_I32); + let tw = NttTwiddles::::compute(prime); + let input = random_mont_array_i32::<512>(prime, 0xDEAD); + let canonical_input: Vec = input.iter().map(|c| prime.to_canonical(*c)).collect(); + + let mut a = input; + unsafe { + forward_ntt_i32(&mut a, prime, &tw); + inverse_ntt_i32(&mut a, prime, &tw); + } + + for i in 0..512 { + let result = prime.to_canonical(a[i]); + assert_eq!( + result, canonical_input[i], + "roundtrip mismatch at index {i}" + ); + } + } + + #[test] + fn neon_cyclic_ntt_i32_matches_scalar() { + let prime = NttPrime::compute(TEST_PRIME_I32); + let tw = NttTwiddles::::compute(prime); + let input = random_mont_array_i32::<512>(prime, 0xFACE); + + let mut neon_fwd = input; + unsafe { forward_ntt_cyclic_i32(&mut neon_fwd, prime, &tw) }; + + let mut scalar_fwd = input; + scalar_forward_ntt_cyclic(&mut scalar_fwd, prime, &tw); + + for i in 0..512 { + let n = prime.to_canonical(neon_fwd[i]); + let s = prime.to_canonical(scalar_fwd[i]); + assert_eq!(n, s, "forward cyclic mismatch at {i}: neon={n}, scalar={s}"); + } + + let mut neon_inv = neon_fwd; + unsafe { inverse_ntt_cyclic_i32(&mut neon_inv, prime, &tw) }; + + let mut scalar_inv = scalar_fwd; + 
scalar_inverse_ntt_cyclic(&mut scalar_inv, prime, &tw); + + for i in 0..512 { + let n = prime.to_canonical(neon_inv[i]); + let s = prime.to_canonical(scalar_inv[i]); + assert_eq!(n, s, "inverse cyclic mismatch at {i}: neon={n}, scalar={s}"); + } + } + + #[test] + fn neon_pointwise_mul_acc_i32_matches_scalar() { + let prime = NttPrime::compute(TEST_PRIME_I32); + const D: usize = 512; + let acc_init = random_mont_array_i32::(prime, 0x1111); + let lhs = random_mont_array_i32::(prime, 0x2222); + let rhs = random_mont_array_i32::(prime, 0x3333); + + let mut neon_acc = acc_init; + unsafe { + pointwise_mul_acc_i32( + neon_acc.as_mut_ptr() as *mut i32, + lhs.as_ptr() as *const i32, + rhs.as_ptr() as *const i32, + D, + prime.p, + prime.pinv, + ); + } + + let mut scalar_acc = acc_init; + for i in 0..D { + let prod = prime.mul(lhs[i], rhs[i]); + let sum = MontCoeff::from_raw(scalar_acc[i].raw().wrapping_add(prod.raw())); + scalar_acc[i] = prime.reduce_range(sum); + } + + for i in 0..D { + let n = prime.to_canonical(neon_acc[i]); + let s = prime.to_canonical(scalar_acc[i]); + assert_eq!(n, s, "pointwise mul acc mismatch at {i}"); + } + } + + #[test] + fn neon_forward_ntt_i16_matches_scalar() { + let prime = NttPrime::compute(TEST_PRIME_I16); + let tw = NttTwiddles::::compute(prime); + let input = random_mont_array_i16::<64>(prime, 0xABCD); + + let mut neon_result = input; + unsafe { forward_ntt_i16(&mut neon_result, prime, &tw) }; + + let mut scalar_result = input; + scalar_forward_ntt(&mut scalar_result, prime, &tw); + + for i in 0..64 { + let n = prime.to_canonical(neon_result[i]); + let s = prime.to_canonical(scalar_result[i]); + assert_eq!(n, s, "i16 forward mismatch at {i}: neon={n}, scalar={s}"); + } + } + + #[test] + fn neon_inverse_ntt_i16_matches_scalar() { + let prime = NttPrime::compute(TEST_PRIME_I16); + let tw = NttTwiddles::::compute(prime); + let input = random_mont_array_i16::<64>(prime, 0xFEED); + + let mut neon_result = input; + unsafe { inverse_ntt_i16(&mut 
neon_result, prime, &tw) }; + + let mut scalar_result = input; + scalar_inverse_ntt(&mut scalar_result, prime, &tw); + + for i in 0..64 { + let n = prime.to_canonical(neon_result[i]); + let s = prime.to_canonical(scalar_result[i]); + assert_eq!(n, s, "i16 inverse mismatch at {i}: neon={n}, scalar={s}"); + } + } + + #[test] + fn neon_forward_inverse_roundtrip_i16() { + let prime = NttPrime::compute(TEST_PRIME_I16); + let tw = NttTwiddles::::compute(prime); + let input = random_mont_array_i16::<64>(prime, 0x7777); + let canonical_input: Vec = input.iter().map(|c| prime.to_canonical(*c)).collect(); + + let mut a = input; + unsafe { + forward_ntt_i16(&mut a, prime, &tw); + inverse_ntt_i16(&mut a, prime, &tw); + } + + for i in 0..64 { + let result = prime.to_canonical(a[i]); + assert_eq!(result, canonical_input[i], "i16 roundtrip mismatch at {i}"); + } + } + + #[test] + fn neon_cyclic_i16_matches_scalar() { + let prime = NttPrime::compute(TEST_PRIME_I16); + let tw = NttTwiddles::::compute(prime); + let input = random_mont_array_i16::<64>(prime, 0x9999); + + let mut neon_fwd = input; + unsafe { forward_ntt_cyclic_i16(&mut neon_fwd, prime, &tw) }; + + let mut scalar_fwd = input; + scalar_forward_ntt_cyclic(&mut scalar_fwd, prime, &tw); + + for i in 0..64 { + let n = prime.to_canonical(neon_fwd[i]); + let s = prime.to_canonical(scalar_fwd[i]); + assert_eq!(n, s, "i16 fwd cyclic mismatch at {i}"); + } + + let mut neon_inv = neon_fwd; + unsafe { inverse_ntt_cyclic_i16(&mut neon_inv, prime, &tw) }; + + let mut scalar_inv = scalar_fwd; + scalar_inverse_ntt_cyclic(&mut scalar_inv, prime, &tw); + + for i in 0..64 { + let n = prime.to_canonical(neon_inv[i]); + let s = prime.to_canonical(scalar_inv[i]); + assert_eq!(n, s, "i16 inv cyclic mismatch at {i}"); + } + } + + #[test] + fn neon_pointwise_mul_acc_i16_matches_scalar() { + let prime = NttPrime::compute(TEST_PRIME_I16); + const D: usize = 64; + let acc_init = random_mont_array_i16::(prime, 0xAAAA); + let lhs = 
random_mont_array_i16::(prime, 0xBBBB); + let rhs = random_mont_array_i16::(prime, 0xCCCC); + + let mut neon_acc = acc_init; + unsafe { + pointwise_mul_acc_i16( + neon_acc.as_mut_ptr() as *mut i16, + lhs.as_ptr() as *const i16, + rhs.as_ptr() as *const i16, + D, + prime.p, + prime.pinv, + ); + } + + let mut scalar_acc = acc_init; + for i in 0..D { + let prod = prime.mul(lhs[i], rhs[i]); + let sum = MontCoeff::from_raw(scalar_acc[i].raw().wrapping_add(prod.raw())); + scalar_acc[i] = prime.reduce_range(sum); + } + + for i in 0..D { + let n = prime.to_canonical(neon_acc[i]); + let s = prime.to_canonical(scalar_acc[i]); + assert_eq!(n, s, "i16 pointwise mul acc mismatch at {i}"); + } + } +} diff --git a/src/algebra/ntt/prime.rs b/src/algebra/ntt/prime.rs new file mode 100644 index 00000000..a5646234 --- /dev/null +++ b/src/algebra/ntt/prime.rs @@ -0,0 +1,382 @@ +//! NTT prime arithmetic kernels generic over coefficient width. +//! +//! Per-prime scalar operations: +//! - Montgomery multiplication ([`NttPrime::mul`]) +//! - Branchless conditional add/sub and range reduction +//! +//! Coefficients in Montgomery domain are wrapped in [`MontCoeff`] to prevent +//! accidental mixing with canonical values. +//! +//! The [`PrimeWidth`] trait abstracts over `i16` (R = 2^16, for primes < 2^14) +//! and `i32` (R = 2^32, for primes < 2^30). All NTT types are generic over +//! `W: PrimeWidth`; monomorphization produces optimal code for each width. + +use std::fmt; + +mod sealed { + pub trait Sealed {} + impl Sealed for i16 {} + impl Sealed for i32 {} +} + +/// Integer width abstraction for NTT prime arithmetic. +/// +/// Sealed with exactly two implementations: `i16` and `i32`. +pub trait PrimeWidth: + sealed::Sealed + Copy + Clone + Eq + Default + fmt::Debug + Send + Sync + 'static +{ + /// Double-width type for intermediate Montgomery products. + type Wide: Copy + Clone; + + /// log2(R) for Montgomery reduction: 16 for `i16`, 32 for `i32`. 
+ const R_LOG: u32; + + /// Widening multiply: `a * b` as `Wide`. + fn wide_mul(a: Self, b: Self) -> Self::Wide; + + /// Truncate wide value to narrow (low half, i.e. mod R). + fn truncate(w: Self::Wide) -> Self; + + /// Arithmetic right shift of wide value by `R_LOG` bits. + fn wide_shift(w: Self::Wide) -> Self; + + /// Wide subtraction (wrapping). + fn wide_sub(a: Self::Wide, b: Self::Wide) -> Self::Wide; + + /// Wrapping addition. + fn wrapping_add(self, rhs: Self) -> Self; + /// Wrapping subtraction. + fn wrapping_sub(self, rhs: Self) -> Self; + /// Wrapping multiplication. + fn wrapping_mul(self, rhs: Self) -> Self; + /// Wrapping negation. + fn wrapping_neg(self) -> Self; + + /// Arithmetic right shift by `BITS - 1`: all-1s if negative, all-0s otherwise. + fn sign_mask(self) -> Self; + + /// Bitwise AND. + fn bitand(self, rhs: Self) -> Self; + + /// Convert from `i64` (truncating). + fn from_i64(v: i64) -> Self; + /// Convert to `i64` (sign-extending). + fn to_i64(self) -> i64; +} + +impl PrimeWidth for i16 { + type Wide = i32; + const R_LOG: u32 = 16; + + #[inline] + fn wide_mul(a: Self, b: Self) -> i32 { + (a as i32) * (b as i32) + } + #[inline] + fn truncate(w: i32) -> Self { + w as i16 + } + #[inline] + fn wide_shift(w: i32) -> Self { + (w >> 16) as i16 + } + #[inline] + fn wide_sub(a: i32, b: i32) -> i32 { + a.wrapping_sub(b) + } + #[inline] + fn wrapping_add(self, rhs: Self) -> Self { + i16::wrapping_add(self, rhs) + } + #[inline] + fn wrapping_sub(self, rhs: Self) -> Self { + i16::wrapping_sub(self, rhs) + } + #[inline] + fn wrapping_mul(self, rhs: Self) -> Self { + i16::wrapping_mul(self, rhs) + } + #[inline] + fn wrapping_neg(self) -> Self { + i16::wrapping_neg(self) + } + #[inline] + fn sign_mask(self) -> Self { + self >> 15 + } + #[inline] + fn bitand(self, rhs: Self) -> Self { + self & rhs + } + #[inline] + fn from_i64(v: i64) -> Self { + v as i16 + } + #[inline] + fn to_i64(self) -> i64 { + self as i64 + } +} + +impl PrimeWidth for i32 { + type 
Wide = i64; + const R_LOG: u32 = 32; + + #[inline] + fn wide_mul(a: Self, b: Self) -> i64 { + (a as i64) * (b as i64) + } + #[inline] + fn truncate(w: i64) -> Self { + w as i32 + } + #[inline] + fn wide_shift(w: i64) -> Self { + (w >> 32) as i32 + } + #[inline] + fn wide_sub(a: i64, b: i64) -> i64 { + a.wrapping_sub(b) + } + #[inline] + fn wrapping_add(self, rhs: Self) -> Self { + i32::wrapping_add(self, rhs) + } + #[inline] + fn wrapping_sub(self, rhs: Self) -> Self { + i32::wrapping_sub(self, rhs) + } + #[inline] + fn wrapping_mul(self, rhs: Self) -> Self { + i32::wrapping_mul(self, rhs) + } + #[inline] + fn wrapping_neg(self) -> Self { + i32::wrapping_neg(self) + } + #[inline] + fn sign_mask(self) -> Self { + self >> 31 + } + #[inline] + fn bitand(self, rhs: Self) -> Self { + self & rhs + } + #[inline] + fn from_i64(v: i64) -> Self { + v as i32 + } + #[inline] + fn to_i64(self) -> i64 { + self as i64 + } +} + +/// A coefficient in Montgomery domain for an NTT prime. +/// +/// Wraps a `W` representing `a * R mod p` (where `R = 2^{W::R_LOG}`). +/// Use [`NttPrime::from_canonical`] to enter and [`NttPrime::to_canonical`] +/// to leave Montgomery domain. +#[derive(Clone, Copy, PartialEq, Eq, Default)] +#[repr(transparent)] +pub struct MontCoeff(W); + +impl MontCoeff { + /// Wrap a raw Montgomery-domain value. + #[inline] + pub fn from_raw(val: W) -> Self { + Self(val) + } + + /// Extract the raw value (still in Montgomery domain). + #[inline] + pub fn raw(self) -> W { + self.0 + } +} + +impl fmt::Debug for MontCoeff { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "Mont({:?})", self.0) + } +} + +/// Per-prime constants for NTT arithmetic. +/// +/// Generic over `W: PrimeWidth` — use `i16` for primes below 2^14 (R = 2^16), +/// or `i32` for primes below 2^30 (R = 2^32). +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct NttPrime { + /// Prime modulus. + pub p: W, + /// `p^{-1} mod R` (centered signed). Used in Montgomery reduction. 
+ pub pinv: W, + /// `R mod p` (centered signed). Montgomery form of 1. + pub mont: W, + /// `R^2 mod p` (centered signed). Used for canonical → Montgomery conversion. + pub montsq: W, +} + +impl NttPrime { + /// Derive all Montgomery constants from a raw prime value. + pub fn compute(p: W) -> Self { + let p_i64 = p.to_i64(); + debug_assert!(p_i64 > 1 && p_i64 % 2 == 1, "NTT prime must be odd and > 1"); + + // pinv via Newton's method: x_{n+1} = x_n * (2 - p * x_n). + // 5 iterations gives correctness mod 2^32 (sufficient for both i16 and i32). + let mut pinv: i64 = 1; + for _ in 0..5 { + pinv = pinv.wrapping_mul(2i64.wrapping_sub(p_i64.wrapping_mul(pinv))); + } + let pinv = W::from_i64(pinv); + + let half = p_i64 / 2; + let center = |x: i64| -> W { W::from_i64(if x > half { x - p_i64 } else { x }) }; + + let r_mod_p = ((1i128 << W::R_LOG) % (p_i64 as i128)) as i64; + let mont = center(r_mod_p); + + let rsq_mod_p = ((1i128 << (2 * W::R_LOG)) % (p_i64 as i128)) as i64; + let montsq = center(rsq_mod_p); + + Self { + p, + pinv, + mont, + montsq, + } + } + + /// Montgomery product: `a * b * R^{-1} mod p`. + #[inline] + pub fn mul(self, a: MontCoeff, b: MontCoeff) -> MontCoeff { + MontCoeff(self.mont_mul_raw(a.0, b.0)) + } + + /// Raw Montgomery multiply on bare `W` values. + #[inline] + pub(crate) fn mont_mul_raw(self, a: W, b: W) -> W { + let c = W::wide_mul(a, b); + let t = W::truncate(c).wrapping_mul(self.pinv); + let tp = W::wide_mul(t, self.p); + W::wide_shift(W::wide_sub(c, tp)) + } + + /// Conditionally subtract `p` if `a >= p` (branchless). + /// + /// Input may be in `(-2p, 2p)` during butterflies. Both i16 and i32 paths + /// widen to avoid signed overflow: i16→i32, i32→i64. 
+ #[inline] + pub fn csubp(self, a: MontCoeff) -> MontCoeff { + if W::R_LOG == 16 { + let ai = a.0.to_i64() as i32; + let pi = self.p.to_i64() as i32; + let diff = ai - pi; + let mask = diff >> 31; + MontCoeff(W::from_i64((diff + (mask & pi)) as i64)) + } else { + let ai = a.0.to_i64(); + let pi = self.p.to_i64(); + let diff = ai - pi; + let mask = diff >> 63; + MontCoeff(W::from_i64(diff + (mask & pi))) + } + } + + /// Conditionally add `p` if `a < 0` (branchless). + /// + /// Widened arithmetic mirrors `csubp` to avoid i16 overflow edge cases. + #[inline] + pub fn caddp(self, a: MontCoeff) -> MontCoeff { + if W::R_LOG == 16 { + let ai = a.0.to_i64() as i32; + let pi = self.p.to_i64() as i32; + let mask = ai >> 31; + MontCoeff(W::from_i64((ai + (mask & pi)) as i64)) + } else { + let ai = a.0.to_i64(); + let pi = self.p.to_i64(); + let mask = ai >> 63; + MontCoeff(W::from_i64(ai + (mask & pi))) + } + } + + /// Range-reduce from `(-2p, 2p)` to `(-p, p)`. + #[inline] + pub fn reduce_range(self, a: MontCoeff) -> MontCoeff { + self.caddp(self.csubp(a)) + } + + /// Fully normalize a Montgomery coefficient to `[0, p)`. + #[inline] + pub fn normalize(self, a: MontCoeff) -> MontCoeff { + self.csubp(self.caddp(a)) + } + + /// Convert a canonical value into Montgomery domain: `a ↦ a * R mod p`. + #[inline] + pub fn from_canonical(self, a: W) -> MontCoeff { + MontCoeff(self.mont_mul_raw(a, self.montsq)) + } + + /// Convert from Montgomery domain to canonical `[0, p)`. + #[inline] + pub fn to_canonical(self, a: MontCoeff) -> W { + let raw = MontCoeff(self.mont_mul_raw(a.0, W::from_i64(1))); + self.normalize(raw).0 + } + + /// Center a canonical value from approximately `(-p, p)` into `[-p/2, p/2)`. 
+ #[inline] + pub fn center(self, a: W) -> W { + let mask_neg = a.sign_mask(); + let canonical = a.wrapping_add(mask_neg.bitand(self.p)); + let half = W::from_i64(self.p.to_i64() / 2); + let needs_sub = half.wrapping_sub(canonical).sign_mask(); + canonical.wrapping_add(needs_sub.bitand(self.p.wrapping_neg())) + } + + /// Pointwise Montgomery multiplication of two coefficient slices. + /// + /// # Panics + /// + /// Panics if slices have different lengths. + #[inline] + pub fn pointwise_mul( + self, + out: &mut [MontCoeff], + lhs: &[MontCoeff], + rhs: &[MontCoeff], + ) { + assert_eq!(out.len(), lhs.len()); + assert_eq!(lhs.len(), rhs.len()); + for ((o, a), b) in out.iter_mut().zip(lhs.iter()).zip(rhs.iter()) { + *o = self.mul(*a, *b); + } + } + + /// In-place Montgomery scaling by a constant. + #[inline] + pub fn scale_in_place(self, coeffs: &mut [MontCoeff], scalar: MontCoeff) { + for c in coeffs { + *c = self.mul(*c, scalar); + } + } + + /// In-place range reduction on a coefficient slice. + #[inline] + pub fn reduce_range_in_place(self, coeffs: &mut [MontCoeff]) { + for c in coeffs { + *c = self.reduce_range(*c); + } + } + + /// In-place centering of canonical values to `[-p/2, p/2)`. + #[inline] + pub fn center_slice(self, coeffs: &mut [W]) { + for c in coeffs { + *c = self.center(*c); + } + } +} diff --git a/src/algebra/ntt/tables.rs b/src/algebra/ntt/tables.rs new file mode 100644 index 00000000..70e35bd1 --- /dev/null +++ b/src/algebra/ntt/tables.rs @@ -0,0 +1,215 @@ +//! Deterministic parameter presets for small-prime CRT arithmetic. +//! +//! Q32: `logq = 32` with six `i16` NTT-friendly primes (D ≤ 64). +//! Q64: `logq = 64` with `i32` NTT-friendly primes (D ≤ 1024). +//! Q128: `logq = 128` with five `i32` NTT-friendly primes (D ≤ 1024). + +use super::crt::GarnerData; +use super::prime::NttPrime; + +/// Polynomial degree for the base ring `Z_q[X]/(X^d + 1)`. +pub const RING_DEGREE: usize = 64; +/// Maximum ring degree covered by the i32 CRT parameter sets. 
+pub const MAX_CRT_RING_DEGREE: usize = 1024; + +/// Number of CRT primes for the `logq = 32` parameter set. +pub const Q32_NUM_PRIMES: usize = 6; + +/// The modulus `q = 2^32 - 99`. +pub const Q32_MODULUS: u64 = (1u64 << 32) - 99; + +/// CRT primes and per-prime Montgomery constants for `logq = 32`. +/// +/// All constants are for `R = 2^16` (i16 width). +pub const Q32_PRIMES: [NttPrime; Q32_NUM_PRIMES] = [ + NttPrime { + p: 13697, + pinv: 2689, + mont: -2949, + montsq: -994, + }, + NttPrime { + p: 13441, + pinv: 2945, + mont: -1669, + montsq: 3274, + }, + NttPrime { + p: 13313, + pinv: -13311, + mont: -1029, + montsq: -6199, + }, + NttPrime { + p: 12289, + pinv: -12287, + mont: 4091, + montsq: -1337, + }, + NttPrime { + p: 12161, + pinv: 4225, + mont: 4731, + montsq: -6040, + }, + NttPrime { + p: 11777, + pinv: -11775, + mont: -5126, + montsq: 1389, + }, +]; + +/// Garner CRT reconstruction constants for Q32. +pub fn q32_garner() -> GarnerData { + GarnerData::compute(&Q32_PRIMES) +} + +/// Number of CRT primes for the `logq = 64` fast profile (`P > q`). +pub const Q64_NUM_PRIMES_FAST: usize = 3; +/// Number of CRT primes for the `logq = 64` conservative profile (`P > 128*q^2`). +pub const Q64_NUM_PRIMES: usize = 5; + +/// The modulus `q = 2^64 - 59`. +pub const Q64_MODULUS: u64 = u64::MAX - 58; + +/// Number of CRT primes for the `logq = 128` parameter set. +pub const Q128_NUM_PRIMES: usize = 5; + +/// The modulus `q = 2^128 - 275`. +pub const Q128_MODULUS: u128 = u128::MAX - 274; + +/// Raw 30-bit primes for D≤1024, each satisfying `2048 | (p - 1)`. +/// +/// They are ordered descending by value. +pub const D1024_RAW_PRIMES: [i32; Q128_NUM_PRIMES] = + [1073707009, 1073698817, 1073692673, 1073682433, 1073668097]; + +/// Raw 30-bit primes for Q64 fast profile (`K=3`, `P > q`). 
+pub const Q64_RAW_PRIMES_FAST: [i32; Q64_NUM_PRIMES_FAST] = [ + D1024_RAW_PRIMES[0], + D1024_RAW_PRIMES[1], + D1024_RAW_PRIMES[2], +]; + +/// Raw 30-bit primes for Q64 conservative profile (`K=5`, `P > 128*q^2`). +pub const Q64_RAW_PRIMES: [i32; Q64_NUM_PRIMES] = D1024_RAW_PRIMES; + +/// Raw 30-bit primes for Q128, each satisfying `2048 | (p - 1)`. +pub const Q128_RAW_PRIMES: [i32; Q128_NUM_PRIMES] = D1024_RAW_PRIMES; + +/// CRT primes and per-prime Montgomery constants for `logq = 64` fast profile. +pub fn q64_primes_fast() -> [NttPrime; Q64_NUM_PRIMES_FAST] { + std::array::from_fn(|k| NttPrime::compute(Q64_RAW_PRIMES_FAST[k])) +} + +/// Garner CRT reconstruction constants for Q64 fast profile. +pub fn q64_garner_fast() -> GarnerData { + let primes = q64_primes_fast(); + GarnerData::compute(&primes) +} + +/// CRT primes and per-prime Montgomery constants for `logq = 64` conservative profile. +pub fn q64_primes() -> [NttPrime; Q64_NUM_PRIMES] { + std::array::from_fn(|k| NttPrime::compute(Q64_RAW_PRIMES[k])) +} + +/// Garner CRT reconstruction constants for Q64 conservative profile. +pub fn q64_garner() -> GarnerData { + let primes = q64_primes(); + GarnerData::compute(&primes) +} + +/// CRT primes and per-prime Montgomery constants for `logq = 128`. +pub fn q128_primes() -> [NttPrime; Q128_NUM_PRIMES] { + std::array::from_fn(|k| NttPrime::compute(Q128_RAW_PRIMES[k])) +} + +/// Garner CRT reconstruction constants for Q128. 
+pub fn q128_garner() -> GarnerData { + let primes = q128_primes(); + GarnerData::compute(&primes) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn verify_q32_prime_derived_constants() { + for prime in &Q32_PRIMES { + let recomputed = NttPrime::compute(prime.p); + assert_eq!( + prime.pinv, recomputed.pinv, + "pinv mismatch for p={}", + prime.p + ); + assert_eq!( + prime.mont, recomputed.mont, + "mont mismatch for p={}", + prime.p + ); + assert_eq!( + prime.montsq, recomputed.montsq, + "montsq mismatch for p={}", + prime.p + ); + } + } + + #[test] + fn verify_q128_primes_are_valid() { + let primes = q128_primes(); + for np in &primes { + let p = np.p as i64; + assert!(p > 1 && p % 2 == 1, "prime must be odd and > 1"); + assert_eq!( + (p - 1) % 2048, + 0, + "2048 must divide p-1 for D=1024 NTT (p={p})" + ); + // Verify pinv: p * pinv ≡ 1 (mod 2^32) + assert_eq!( + np.p.wrapping_mul(np.pinv), + 1, + "pinv verification failed for p={p}" + ); + } + } + + #[test] + fn verify_q64_primes_are_valid() { + let primes = q64_primes(); + for np in &primes { + let p = np.p as i64; + assert!(p > 1 && p % 2 == 1, "prime must be odd and > 1"); + assert_eq!( + (p - 1) % 2048, + 0, + "2048 must divide p-1 for D=1024 NTT (p={p})" + ); + assert_eq!( + np.p.wrapping_mul(np.pinv), + 1, + "pinv verification failed for p={p}" + ); + } + } + + #[test] + fn garner_data_is_consistent() { + let garner = q32_garner(); + for (i, &prime_i) in Q32_PRIMES.iter().enumerate().skip(1) { + let pi = prime_i.p as i64; + for (j, &prime_j) in Q32_PRIMES[..i].iter().enumerate() { + let pj = prime_j.p as i64; + let g = garner.gamma[i][j] as i64; + assert_eq!( + (pj * g) % pi, + 1, + "garner gamma[{i}][{j}] not inverse of p_{j} mod p_{i}" + ); + } + } + } +} diff --git a/src/algebra/poly.rs b/src/algebra/poly.rs new file mode 100644 index 00000000..2680c1a6 --- /dev/null +++ b/src/algebra/poly.rs @@ -0,0 +1,233 @@ +//! Polynomial containers and evaluation utilities. 
+ +use crate::algebra::fields::wide::{HasWide, ReduceTo}; +use crate::error::HachiError; +#[cfg(feature = "parallel")] +use crate::parallel::*; +use crate::primitives::serialization::{ + Compress, HachiDeserialize, HachiSerialize, SerializationError, Valid, Validate, +}; +use crate::protocol::sumcheck::eq_poly::EqPolynomial; +use crate::{cfg_fold_reduce, AdditiveGroup, FieldCore, FromSmallInt}; +use std::io::{Read, Write}; +use std::ops::{Add, Neg, Sub}; + +/// A degree-(pub [F; D]); + +impl Poly { + /// Construct the zero polynomial. + pub fn zero() -> Self { + Self([F::zero(); D]) + } +} + +impl Add for Poly { + type Output = Self; + fn add(self, rhs: Self) -> Self::Output { + let mut out = self.0; + for (dst, src) in out.iter_mut().zip(rhs.0.iter()) { + *dst += *src; + } + Self(out) + } +} + +impl Sub for Poly { + type Output = Self; + fn sub(self, rhs: Self) -> Self::Output { + let mut out = self.0; + for (dst, src) in out.iter_mut().zip(rhs.0.iter()) { + *dst -= *src; + } + Self(out) + } +} + +impl Neg for Poly { + type Output = Self; + fn neg(self) -> Self::Output { + let mut out = self.0; + for coeff in &mut out { + *coeff = -*coeff; + } + Self(out) + } +} + +impl Valid for Poly { + fn check(&self) -> Result<(), SerializationError> { + for x in self.0.iter() { + x.check()?; + } + Ok(()) + } +} + +impl HachiSerialize for Poly { + fn serialize_with_mode( + &self, + mut writer: W, + compress: Compress, + ) -> Result<(), SerializationError> { + for x in self.0.iter() { + x.serialize_with_mode(&mut writer, compress)?; + } + Ok(()) + } + + fn serialized_size(&self, compress: Compress) -> usize { + self.0.iter().map(|x| x.serialized_size(compress)).sum() + } +} + +impl HachiDeserialize for Poly { + fn deserialize_with_mode( + mut reader: R, + compress: Compress, + validate: Validate, + ) -> Result { + let mut arr = [F::zero(); D]; + for coeff in &mut arr { + *coeff = F::deserialize_with_mode(&mut reader, compress, validate)?; + } + let out = Self(arr); + if 
matches!(validate, Validate::Yes) { + out.check()?; + } + Ok(out) + } +} + +/// Evaluate the range-check polynomial `w · Π_{k=1}^{b−1} (w − k)(w + k)`. +/// +/// This polynomial vanishes exactly when `w ∈ {−(b−1), …, b−1}`. +/// Total degree in `w` is `2b − 1`. +pub fn range_check_eval(w: E, b: usize) -> E { + let s = w * w; + let mut acc = w; + for k in 1..b { + let k_e = E::from_u64(k as u64); + acc = acc * (s - k_e * k_e); + } + acc +} + +/// Evaluate a multilinear polynomial (given by boolean-hypercube evaluations in +/// little-endian bit order) at an arbitrary point via iterated folding. +/// +/// # Errors +/// +/// Returns an error if the evaluation table length is not a power of two or +/// does not match `2^point.len()`. +pub fn multilinear_eval(evals: &[E], point: &[E]) -> Result { + if !evals.len().is_power_of_two() { + return Err(HachiError::InvalidSize { + expected: 1 << point.len(), + actual: evals.len(), + }); + } + if evals.len() != 1 << point.len() { + return Err(HachiError::InvalidSize { + expected: 1 << point.len(), + actual: evals.len(), + }); + } + let mut current = evals.to_vec(); + for &r in point { + let half = current.len() / 2; + let mut next = Vec::with_capacity(half); + for i in 0..half { + next.push(current[2 * i] + r * (current[2 * i + 1] - current[2 * i])); + } + current = next; + } + Ok(current[0]) +} + +/// Fold an evaluation table in place by binding its first variable to `r`, +/// halving the table size. +/// +/// # Panics +/// +/// Panics if the evaluation table length is not a power of two or has fewer +/// than 2 elements. This is a prover-only helper where the caller guarantees +/// well-formed input. 
+#[tracing::instrument(skip_all, name = "fold_evals_in_place")] +pub fn fold_evals_in_place(evals: &mut Vec, r: E) { + assert!( + evals.len().is_power_of_two(), + "evals length must be a power of two" + ); + assert!(evals.len() >= 2, "evals must have at least 2 elements"); + let half = evals.len() / 2; + for i in 0..half { + evals[i] = evals[2 * i] + r * (evals[2 * i + 1] - evals[2 * i]); + } + evals.truncate(half); +} + +/// Evaluate a multilinear polynomial with small integer evaluations at a +/// field point, using the split-eq structure with unreduced accumulation. +/// +/// Uses `HasWide::mul_small_to_wide` in the inner loop: each eq table entry +/// is widened, scaled by the small witness value, and accumulated without +/// reduction. The inner sum is reduced once per outer iteration, then +/// multiplied by the outer eq factor and accumulated again in wide form. +/// +/// Overflow budget: each inner accumulation adds at most `0xFFFF * |small|` +/// to each i32 limb. For `|small| ≤ 128` (b ≤ 256), we can safely +/// accumulate 256 products before an i32 limb overflows. +/// +/// # Errors +/// +/// Returns an error if the table length does not match `2^point.len()`. +#[tracing::instrument(skip_all, name = "multilinear_eval_small")] +pub fn multilinear_eval_small( + evals_small: &[i8], + point: &[E], +) -> Result { + let n = point.len(); + if evals_small.len() != 1 << n { + return Err(HachiError::InvalidSize { + expected: 1 << n, + actual: evals_small.len(), + }); + } + if n == 0 { + return Ok(E::from_i64(evals_small[0] as i64)); + } + + let m = n / 2; + let (r_first, r_second) = point.split_at(m); + let eq_first = EqPolynomial::evals(r_first); + let eq_second = EqPolynomial::evals(r_second); + let in_len = eq_first.len(); + + // Max safe accumulations per chunk before i32 overflow. + // Limbs are 16-bit (0..0xFFFF), scaled by |small| ≤ 128 → 23-bit products. + // i32::MAX / (0xFFFF * 128) ≈ 256. 
+ const CHUNK: usize = 256; + + let outer_accum = cfg_fold_reduce!( + 0..eq_second.len(), + || E::Wide::ZERO, + |acc, x_out| { + let base = x_out * in_len; + let mut inner_field = E::zero(); + for chunk_start in (0..in_len).step_by(CHUNK) { + let chunk_end = (chunk_start + CHUNK).min(in_len); + let mut chunk_acc = E::Wide::ZERO; + for x_in in chunk_start..chunk_end { + chunk_acc += eq_first[x_in].mul_small_to_wide(evals_small[base + x_in] as i32); + } + inner_field += >::reduce(chunk_acc); + } + + acc + E::Wide::from(eq_second[x_out] * inner_field) + }, + |a, b| a + b + ); + Ok(>::reduce(outer_accum)) +} diff --git a/src/algebra/ring/crt_ntt_repr.rs b/src/algebra/ring/crt_ntt_repr.rs new file mode 100644 index 00000000..b936f713 --- /dev/null +++ b/src/algebra/ring/crt_ntt_repr.rs @@ -0,0 +1,527 @@ +//! CRT+NTT-domain representation of cyclotomic ring elements. + +use std::array::from_fn; + +use crate::algebra::backend::{CrtReconstruct, NttPrimeOps, NttTransform, ScalarBackend}; +use crate::algebra::ntt::butterfly::{ + forward_ntt, forward_ntt_cyclic, inverse_ntt_cyclic, NttTwiddles, +}; +use crate::algebra::ntt::crt::GarnerData; +use crate::algebra::ntt::prime::{MontCoeff, NttPrime, PrimeWidth}; +use crate::{CanonicalField, FieldCore}; + +use super::cyclotomic::CyclotomicRing; + +/// CRT+NTT-domain representation of a cyclotomic ring element. +/// +/// Stores `K` arrays of `D` [`MontCoeff`] values, one per CRT prime. +/// Multiplication is pointwise per prime — O(K*D) vs O(D^2) for coefficient form. +/// +/// Generic over: +/// - `W: PrimeWidth` — integer width (`i16` or `i32`) +/// - `K` — number of CRT primes +/// - `D` — polynomial degree +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct CyclotomicCrtNtt { + pub(crate) limbs: [[MontCoeff; D]; K], +} + +/// Field types that can convert to/from the CRT+NTT representation. +/// +/// Blanket-implemented for all `FieldCore + CanonicalField` types. 
+pub trait CrtNttConvertibleField: FieldCore + CanonicalField {} + +impl CrtNttConvertibleField for F {} + +/// Bundled CRT+NTT parameters for a fixed width/prime-count/degree tuple. +/// +/// Keeps primes/twiddles/Garner constants consistent and avoids passing them +/// independently at every call site. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct CrtNttParamSet { + /// CRT primes with Montgomery constants. + pub primes: [NttPrime; K], + /// Per-prime twiddle tables for forward/inverse NTT. + pub twiddles: [NttTwiddles; K], + /// Garner reconstruction constants for CRT lift-back. + pub garner: GarnerData, +} + +/// Precomputed Montgomery forms for small balanced digit values. +/// +/// Covers the full `{-8, ..., 7}` range (16 entries per CRT prime), +/// which is sufficient for any `log_basis <= 4`. Storing the Montgomery +/// representation eliminates one `from_canonical` (a Montgomery multiply) +/// per coefficient in the `from_i8` hot path. +#[derive(Debug, Clone)] +pub struct DigitMontLut { + vals: [[MontCoeff; 16]; K], +} + +const DIGIT_LUT_HALF_B: i16 = 8; + +impl DigitMontLut { + /// Build the lookup table from CRT primes. + /// + /// Covers digit values in `{-8, ..., 7}` (balanced representation for + /// `log_basis <= 4`). + pub fn new(params: &CrtNttParamSet) -> Self { + let mut vals = [[MontCoeff::from_raw(W::default()); 16]; K]; + for (k, prime) in params.primes.iter().enumerate() { + for v_idx in 0..16u8 { + let v = v_idx as i64 - DIGIT_LUT_HALF_B as i64; + vals[k][v_idx as usize] = prime.from_canonical(W::from_i64(v)); + } + } + Self { vals } + } + + /// Look up the Montgomery form of a balanced digit for CRT prime `k`. + #[inline(always)] + pub fn get(&self, k: usize, digit: i8) -> MontCoeff { + unsafe { + *self + .vals + .get_unchecked(k) + .get_unchecked((digit as i16 + DIGIT_LUT_HALF_B) as usize) + } + } +} + +impl CrtNttParamSet { + /// Build a full parameter set from CRT primes. 
+ /// + /// Computes per-prime twiddles and Garner reconstruction constants. + pub fn new(primes: [NttPrime; K]) -> Self { + let twiddles = from_fn(|k| NttTwiddles::compute(primes[k])); + let garner = GarnerData::compute(&primes); + Self { + primes, + twiddles, + garner, + } + } +} + +impl CyclotomicCrtNtt { + /// The additive identity (all zeros in every CRT limb). + pub fn zero() -> Self { + Self { + limbs: [[MontCoeff::from_raw(W::default()); D]; K], + } + } + + /// Convert a coefficient-form ring element into CRT+NTT domain + /// using the default scalar backend. + pub fn from_ring( + ring: &CyclotomicRing, + primes: &[NttPrime; K], + twiddles: &[NttTwiddles; K], + ) -> Self { + Self::from_ring_with_backend::(ring, primes, twiddles) + } + + /// Convert a coefficient-form ring element into CRT+NTT domain + /// using a bundled parameter set and the scalar backend. + pub fn from_ring_with_params( + ring: &CyclotomicRing, + params: &CrtNttParamSet, + ) -> Self { + Self::from_ring(ring, ¶ms.primes, ¶ms.twiddles) + } + + /// Convert a coefficient-form ring element into CRT+NTT domain + /// through an explicit backend implementation. + pub fn from_ring_with_backend< + F: CrtNttConvertibleField, + B: NttPrimeOps + NttTransform, + >( + ring: &CyclotomicRing, + primes: &[NttPrime; K], + twiddles: &[NttTwiddles; K], + ) -> Self { + let q = (-F::one()).to_canonical_u128() + 1; + let half_q = q / 2; + let centered_coeffs: [i128; D] = from_fn(|i| { + let canonical = ring.coeffs[i].to_canonical_u128(); + if canonical > half_q { + -((q - canonical) as i128) + } else { + canonical as i128 + } + }); + + let mut limbs = [[MontCoeff::from_raw(W::default()); D]; K]; + for ((limb, prime), tw) in limbs.iter_mut().zip(primes.iter()).zip(twiddles.iter()) { + // Interpret coefficients in centered form (-q/2, q/2] before reducing + // into the CRT primes. This makes the reduction map consistent with + // negacyclic subtraction (which naturally produces negative values). 
+ let p = prime.p.to_i64(); + let p_u64 = p as u64; + let r64 = ((1u128 << 64) % p_u64 as u128) as i64; + let half_p = p / 2; + for (dst, centered) in limb.iter_mut().zip(centered_coeffs.iter()) { + let c = *centered; + let lo = (c as u64 % p_u64) as i64; + let hi = ((c >> 64) as i64).rem_euclid(p); + let mut r = (lo + hi * r64) % p; + if r >= half_p { + r -= p; + } + *dst = B::from_canonical(*prime, W::from_i64(r)); + } + B::forward_ntt(limb, *prime, tw); + } + Self { limbs } + } + + /// Convert small integer coefficients (e.g. gadget digits) into + /// negacyclic CRT+NTT domain, bypassing Fp128 centering entirely. + pub fn from_i8_with_params(digits: &[i8; D], params: &CrtNttParamSet) -> Self { + Self::from_i8_negacyclic_backend::(digits, params) + } + + /// Like [`Self::from_i8_with_params`] but uses a precomputed + /// [`DigitMontLut`] to replace per-coefficient `from_canonical` + /// (Montgomery multiply) with a table lookup. + #[inline] + pub fn from_i8_with_lut( + digits: &[i8; D], + params: &CrtNttParamSet, + lut: &DigitMontLut, + ) -> Self { + let mut limbs = [[MontCoeff::from_raw(W::default()); D]; K]; + for (k, (limb, tw)) in limbs.iter_mut().zip(params.twiddles.iter()).enumerate() { + for (dst, &d) in limb.iter_mut().zip(digits.iter()) { + *dst = lut.get(k, d); + } + forward_ntt(limb, params.primes[k], tw); + } + Self { limbs } + } + + /// Like [`Self::from_i8_cyclic`] but uses a precomputed [`DigitMontLut`]. 
+ #[inline] + pub fn from_i8_cyclic_with_lut( + digits: &[i8; D], + params: &CrtNttParamSet, + lut: &DigitMontLut, + ) -> Self { + let mut limbs = [[MontCoeff::from_raw(W::default()); D]; K]; + for (k, (limb, tw)) in limbs.iter_mut().zip(params.twiddles.iter()).enumerate() { + for (dst, &d) in limb.iter_mut().zip(digits.iter()) { + *dst = lut.get(k, d); + } + forward_ntt_cyclic(limb, params.primes[k], tw); + } + Self { limbs } + } + + fn from_i8_negacyclic_backend + NttTransform>( + digits: &[i8; D], + params: &CrtNttParamSet, + ) -> Self { + let mut limbs = [[MontCoeff::from_raw(W::default()); D]; K]; + for ((limb, prime), tw) in limbs + .iter_mut() + .zip(params.primes.iter()) + .zip(params.twiddles.iter()) + { + for (dst, &d) in limb.iter_mut().zip(digits.iter()) { + *dst = B::from_canonical(*prime, W::from_i64(d as i64)); + } + B::forward_ntt(limb, *prime, tw); + } + Self { limbs } + } + + /// Convert small integer coefficients into cyclic CRT+NTT domain, + /// bypassing Fp128 centering entirely. + pub fn from_i8_cyclic(digits: &[i8; D], params: &CrtNttParamSet) -> Self { + Self::from_i8_cyclic_backend::(digits, params) + } + + fn from_i8_cyclic_backend>( + digits: &[i8; D], + params: &CrtNttParamSet, + ) -> Self { + let mut limbs = [[MontCoeff::from_raw(W::default()); D]; K]; + for ((limb, prime), tw) in limbs + .iter_mut() + .zip(params.primes.iter()) + .zip(params.twiddles.iter()) + { + for (dst, &d) in limb.iter_mut().zip(digits.iter()) { + *dst = B::from_canonical(*prime, W::from_i64(d as i64)); + } + forward_ntt_cyclic(limb, *prime, tw); + } + Self { limbs } + } + + /// Convert from CRT+NTT domain back to coefficient form + /// using the default scalar backend. 
+ pub fn to_ring( + &self, + primes: &[NttPrime; K], + twiddles: &[NttTwiddles; K], + garner: &GarnerData, + ) -> CyclotomicRing { + self.to_ring_with_backend::(primes, twiddles, garner) + } + + /// Convert from CRT+NTT domain back to coefficient form + /// using a bundled parameter set and the scalar backend. + pub fn to_ring_with_params( + &self, + params: &CrtNttParamSet, + ) -> CyclotomicRing { + self.to_ring(¶ms.primes, ¶ms.twiddles, ¶ms.garner) + } + + /// Convert from CRT+NTT domain back to coefficient form + /// through an explicit backend implementation. + pub fn to_ring_with_backend< + F: CrtNttConvertibleField, + B: NttPrimeOps + NttTransform + CrtReconstruct, + >( + &self, + primes: &[NttPrime; K], + twiddles: &[NttTwiddles; K], + garner: &GarnerData, + ) -> CyclotomicRing { + let mut canonical = [[W::default(); D]; K]; + for (k, ((can, prime), tw)) in canonical + .iter_mut() + .zip(primes.iter()) + .zip(twiddles.iter()) + .enumerate() + { + let mut limb = self.limbs[k]; + B::inverse_ntt(&mut limb, *prime, tw); + for (dst, src) in can.iter_mut().zip(limb.iter()) { + let canon = B::to_canonical(*prime, *src); + *dst = prime.center(canon); + } + } + + let coeffs = B::reconstruct::(primes, &canonical, garner); + + CyclotomicRing::from_coefficients(coeffs) + } + + /// Add another CRT+NTT element and reduce each coefficient with the matching + /// prime to maintain valid Montgomery ranges using the scalar backend. + pub fn add_reduced(&self, rhs: &Self, primes: &[NttPrime; K]) -> Self { + self.add_reduced_with_backend::(rhs, primes) + } + + /// Add another CRT+NTT element and reduce using a bundled parameter set. + pub fn add_reduced_with_params(&self, rhs: &Self, params: &CrtNttParamSet) -> Self { + self.add_reduced(rhs, ¶ms.primes) + } + + /// Add another CRT+NTT element and reduce each coefficient with the matching + /// prime through an explicit backend implementation. 
+ pub fn add_reduced_with_backend>( + &self, + rhs: &Self, + primes: &[NttPrime; K], + ) -> Self { + let mut out = self.clone(); + for (k, (limb, rhs_limb)) in out.limbs.iter_mut().zip(rhs.limbs.iter()).enumerate() { + let prime = primes[k]; + for (a, b) in limb.iter_mut().zip(rhs_limb.iter()) { + let sum = MontCoeff::from_raw(a.raw().wrapping_add(b.raw())); + *a = B::reduce_range(prime, sum); + } + } + out + } + + /// Subtract another CRT+NTT element and reduce using the scalar backend. + pub fn sub_reduced(&self, rhs: &Self, primes: &[NttPrime; K]) -> Self { + self.sub_reduced_with_backend::(rhs, primes) + } + + /// Subtract another CRT+NTT element and reduce using a bundled parameter set. + pub fn sub_reduced_with_params(&self, rhs: &Self, params: &CrtNttParamSet) -> Self { + self.sub_reduced(rhs, ¶ms.primes) + } + + /// Subtract another CRT+NTT element and reduce through an explicit backend. + pub fn sub_reduced_with_backend>( + &self, + rhs: &Self, + primes: &[NttPrime; K], + ) -> Self { + let mut out = self.clone(); + for (k, (limb, rhs_limb)) in out.limbs.iter_mut().zip(rhs.limbs.iter()).enumerate() { + let prime = primes[k]; + for (a, b) in limb.iter_mut().zip(rhs_limb.iter()) { + let diff = MontCoeff::from_raw(a.raw().wrapping_sub(b.raw())); + *a = B::reduce_range(prime, diff); + } + } + out + } + + /// Negate each CRT+NTT coefficient and reduce using the scalar backend. + pub fn neg_reduced(&self, primes: &[NttPrime; K]) -> Self { + self.neg_reduced_with_backend::(primes) + } + + /// Negate each CRT+NTT coefficient and reduce using a bundled parameter set. + pub fn neg_reduced_with_params(&self, params: &CrtNttParamSet) -> Self { + self.neg_reduced(¶ms.primes) + } + + /// Negate each CRT+NTT coefficient and reduce through an explicit backend. 
+ pub fn neg_reduced_with_backend>( + &self, + primes: &[NttPrime; K], + ) -> Self { + let mut out = self.clone(); + for (k, limb) in out.limbs.iter_mut().enumerate() { + let prime = primes[k]; + for a in limb.iter_mut() { + let neg = MontCoeff::from_raw(a.raw().wrapping_neg()); + *a = B::reduce_range(prime, neg); + } + } + out + } + + /// Convert a coefficient-form ring element into CRT+**cyclic** NTT domain. + /// + /// Evaluates at D-th roots of unity (X^D - 1) instead of X^D + 1. + /// Used together with `to_ring_cyclic` to compute unreduced polynomial products. + pub fn from_ring_cyclic( + ring: &CyclotomicRing, + params: &CrtNttParamSet, + ) -> Self { + Self::from_ring_cyclic_with_backend::(ring, params) + } + + /// Convert a coefficient-form ring element into CRT+**cyclic** NTT domain + /// through an explicit backend. + pub fn from_ring_cyclic_with_backend>( + ring: &CyclotomicRing, + params: &CrtNttParamSet, + ) -> Self { + let q = (-F::one()).to_canonical_u128() + 1; + let half_q = q / 2; + let centered_coeffs: [i128; D] = from_fn(|i| { + let canonical = ring.coeffs[i].to_canonical_u128(); + if canonical > half_q { + -((q - canonical) as i128) + } else { + canonical as i128 + } + }); + + let mut limbs = [[MontCoeff::from_raw(W::default()); D]; K]; + for ((limb, prime), tw) in limbs + .iter_mut() + .zip(params.primes.iter()) + .zip(params.twiddles.iter()) + { + let p = prime.p.to_i64(); + let p_u64 = p as u64; + let r64 = ((1u128 << 64) % p_u64 as u128) as i64; + let half_p = p / 2; + for (dst, centered) in limb.iter_mut().zip(centered_coeffs.iter()) { + let c = *centered; + let lo = (c as u64 % p_u64) as i64; + let hi = ((c >> 64) as i64).rem_euclid(p); + let mut r = (lo + hi * r64) % p; + if r >= half_p { + r -= p; + } + *dst = B::from_canonical(*prime, W::from_i64(r)); + } + forward_ntt_cyclic(limb, *prime, tw); + } + Self { limbs } + } + + /// Convert from CRT+**cyclic** NTT domain back to coefficient form. 
+ /// + /// Inverse of `from_ring_cyclic`: applies inverse cyclic NTT then CRT reconstruction. + pub fn to_ring_cyclic( + &self, + params: &CrtNttParamSet, + ) -> CyclotomicRing { + self.to_ring_cyclic_with_backend::(params) + } + + /// Convert from CRT+**cyclic** NTT domain back to coefficient form + /// through an explicit backend. + pub fn to_ring_cyclic_with_backend< + F: CrtNttConvertibleField, + B: NttPrimeOps + CrtReconstruct, + >( + &self, + params: &CrtNttParamSet, + ) -> CyclotomicRing { + let mut canonical = [[W::default(); D]; K]; + for (k, ((can, prime), tw)) in canonical + .iter_mut() + .zip(params.primes.iter()) + .zip(params.twiddles.iter()) + .enumerate() + { + let mut limb = self.limbs[k]; + inverse_ntt_cyclic(&mut limb, *prime, tw); + for (dst, src) in can.iter_mut().zip(limb.iter()) { + let canon = B::to_canonical(*prime, *src); + *dst = prime.center(canon); + } + } + let coeffs = B::reconstruct::(¶ms.primes, &canonical, ¶ms.garner); + CyclotomicRing::from_coefficients(coeffs) + } + + /// Pointwise multiplication in CRT+NTT domain using the scalar backend. + pub fn pointwise_mul(&self, rhs: &Self, primes: &[NttPrime; K]) -> Self { + self.pointwise_mul_with_backend::(rhs, primes) + } + + /// Pointwise multiplication in CRT+NTT domain using a bundled parameter set. + pub fn pointwise_mul_with_params(&self, rhs: &Self, params: &CrtNttParamSet) -> Self { + self.pointwise_mul(rhs, ¶ms.primes) + } + + /// Pointwise multiplication in CRT+NTT domain through an explicit backend. + pub fn pointwise_mul_with_backend>( + &self, + rhs: &Self, + primes: &[NttPrime; K], + ) -> Self { + let mut out = [[MontCoeff::from_raw(W::default()); D]; K]; + for (k, ((o, a), b)) in out + .iter_mut() + .zip(self.limbs.iter()) + .zip(rhs.limbs.iter()) + .enumerate() + { + let prime = primes[k]; + B::pointwise_mul(prime, o, a, b); + // Keep coefficients in a bounded range for subsequent inverse NTT. 
+ for c in o.iter_mut() { + *c = B::reduce_range(prime, *c); + } + } + Self { limbs: out } + } + + /// Apply `sigma_{-1}` directly in NTT domain (`slot[j] -> slot[D-1-j]`). + /// + /// This is a pure index permutation per CRT limb and does not negate values. + pub fn conjugation_automorphism_ntt(&self) -> Self { + let limbs = std::array::from_fn(|k| { + std::array::from_fn(|j| self.limbs[k][D.saturating_sub(1) - j]) + }); + Self { limbs } + } +} diff --git a/src/algebra/ring/cyclotomic.rs b/src/algebra/ring/cyclotomic.rs new file mode 100644 index 00000000..eca109cd --- /dev/null +++ b/src/algebra/ring/cyclotomic.rs @@ -0,0 +1,967 @@ +//! Cyclotomic ring `Z_q[X]/(X^D + 1)` in coefficient form. + +use super::sparse_challenge::SparseChallenge; +use crate::algebra::fields::wide::ReduceTo; +use crate::primitives::serialization::{ + Compress, HachiDeserialize, HachiSerialize, SerializationError, Valid, Validate, +}; +use crate::{AdditiveGroup, CanonicalField, FieldCore, FieldSampling}; +use rand_core::RngCore; +use std::array::from_fn; +use std::io::{Read, Write}; +use std::ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign}; + +/// Element of the cyclotomic ring `Z_q[X]/(X^D + 1)`. +/// +/// Stored as `D` coefficients in the base field `F`, representing +/// `a_0 + a_1*X + ... + a_{D-1}*X^{D-1}`. +/// +/// Multiplication is negacyclic convolution: `X^D = -1`, so a product +/// term at index `i + j >= D` wraps to index `(i + j) - D` with a sign flip. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +#[repr(transparent)] +pub struct CyclotomicRing { + pub(crate) coeffs: [F; D], +} + +impl CyclotomicRing { + /// Construct from a coefficient array. + #[inline] + pub fn from_coefficients(coeffs: [F; D]) -> Self { + Self { coeffs } + } + + /// Construct from a slice, zero-padding if shorter than `D`. + /// + /// Avoids creating a `[F; D]` stack temporary when `D` is large. 
+ #[inline] + pub fn from_slice(slice: &[F]) -> Self { + let mut coeffs = [F::zero(); D]; + let len = slice.len().min(D); + coeffs[..len].copy_from_slice(&slice[..len]); + Self { coeffs } + } + + /// Borrow the coefficient array. + #[inline] + pub fn coefficients(&self) -> &[F; D] { + &self.coeffs + } + + /// Mutably borrow the coefficient array. + #[inline] + pub fn coefficients_mut(&mut self) -> &mut [F; D] { + &mut self.coeffs + } + + /// The additive identity (all-zero polynomial). + #[inline] + pub fn zero() -> Self { + Self { + coeffs: [F::zero(); D], + } + } + + /// The multiplicative identity (`1 + 0*X + ... + 0*X^{D-1}`). + #[inline] + pub fn one() -> Self { + let mut coeffs = [F::zero(); D]; + coeffs[0] = F::one(); + Self { coeffs } + } + + /// The monomial `X` (i.e., `[0, 1, 0, ..., 0]`). + /// + /// # Panics + /// + /// Panics if `D < 2`. + #[inline] + pub fn x() -> Self { + assert!(D >= 2, "ring degree must be at least 2"); + let mut coeffs = [F::zero(); D]; + coeffs[1] = F::one(); + Self { coeffs } + } + + /// Scalar multiplication: multiply every coefficient by `k`. + #[inline] + pub fn scale(&self, k: &F) -> Self { + let mut out = self.coeffs; + for c in &mut out { + *c = *c * *k; + } + Self { coeffs: out } + } + + /// Apply the cyclotomic automorphism `sigma_k: X -> X^k` for odd `k`. + /// + /// In `Z_q[X]/(X^D + 1)`, this permutes/sign-flips coefficients using + /// exponent reduction modulo `2D`. + /// + /// # Panics + /// + /// Panics if `D == 0` or `k` is not odd modulo `2D`. 
+ pub fn sigma(&self, k: usize) -> Self { + assert!(D > 0, "ring degree must be non-zero"); + let two_d = 2 * D; + let k_mod = k % two_d; + assert!(k_mod % 2 == 1, "sigma_k requires odd k in Z_q[X]/(X^D + 1)"); + + let mut out = [F::zero(); D]; + for (j, coeff) in self.coeffs.iter().copied().enumerate() { + let idx = (j * k_mod) % two_d; + if idx < D { + out[idx] += coeff; + } else { + out[idx - D] -= coeff; + } + } + Self { coeffs: out } + } + + /// Apply `sigma_{-1}` (`X -> X^{-1} = X^{2D-1}` in this ring). + /// + /// # Panics + /// + /// Panics if `D == 0`. + pub fn sigma_m1(&self) -> Self { + assert!(D > 0, "ring degree must be non-zero"); + self.sigma(2 * D - 1) + } + + /// Multiply by `X^k` in `Z_q[X]/(X^D + 1)` via O(D) coefficient rotation. + /// + /// Since `X^D = -1`, coefficients that wrap past index `D` get negated. + #[inline] + pub fn negacyclic_shift(&self, k: usize) -> Self { + let k = k % D; + if k == 0 { + return *self; + } + let mut out = [F::zero(); D]; + for i in 0..D { + let target = i + k; + if target < D { + out[target] = self.coeffs[i]; + } else { + out[target - D] = -self.coeffs[i]; + } + } + Self { coeffs: out } + } + + /// Multiply `self` by a sum of monomials `X^{k_1} + X^{k_2} + ...` + /// + /// Each term is a negacyclic shift, so the total cost is + /// `O(positions.len() * D)` field additions with zero multiplications. + pub fn mul_by_monomial_sum(&self, nonzero_positions: &[usize]) -> Self { + let mut result = Self::zero(); + for &k in nonzero_positions { + self.shift_accumulate_into(&mut result, k); + } + result + } + + /// Fused negacyclic shift + accumulate: `dst += self * X^k`. + /// + /// Equivalent to `*dst += self.negacyclic_shift(k)` but avoids + /// allocating a temporary ring element. 
+ #[inline] + pub fn shift_accumulate_into(&self, dst: &mut Self, k: usize) { + let k = k % D; + if k == 0 { + for i in 0..D { + dst.coeffs[i] += self.coeffs[i]; + } + return; + } + for i in 0..D { + let target = i + k; + if target < D { + dst.coeffs[target] += self.coeffs[i]; + } else { + dst.coeffs[target - D] -= self.coeffs[i]; + } + } + } + + /// Fused negacyclic shift + subtract: `dst -= self * X^k`. + /// + /// Equivalent to `*dst -= self.negacyclic_shift(k)` but avoids + /// allocating a temporary ring element. + #[inline] + pub fn shift_sub_into(&self, dst: &mut Self, k: usize) { + let k = k % D; + if k == 0 { + for i in 0..D { + dst.coeffs[i] -= self.coeffs[i]; + } + return; + } + for i in 0..D { + let target = i + k; + if target < D { + dst.coeffs[target] -= self.coeffs[i]; + } else { + dst.coeffs[target - D] += self.coeffs[i]; + } + } + } + + /// Fused multiply-by-monomial-sum + accumulate: + /// `dst += self * (X^{k_1} + X^{k_2} + ...)`. + /// + /// Equivalent to `*dst += self.mul_by_monomial_sum(positions)` but avoids + /// all intermediate temporaries. + pub fn mul_by_monomial_sum_into(&self, dst: &mut Self, nonzero_positions: &[usize]) { + for &k in nonzero_positions { + self.shift_accumulate_into(dst, k); + } + } + + /// Multiply `self` by a sparse challenge element. + /// + /// Cost: `O(omega * D)` field additions instead of `O(D^2)` multiplications. + /// For `omega=23, D=256` this is 5,888 adds vs 65,536 muls. + pub fn mul_by_sparse(&self, challenge: &SparseChallenge) -> Self + where + F: CanonicalField, + { + let mut result = Self::zero(); + for (&pos, &coeff) in challenge.positions.iter().zip(challenge.coeffs.iter()) { + match coeff { + 1 => self.shift_accumulate_into(&mut result, pos as usize), + -1 => self.shift_sub_into(&mut result, pos as usize), + c => { + let shifted = self.negacyclic_shift(pos as usize); + result += shifted.scale(&F::from_i64(c as i64)); + } + } + } + result + } + + /// Check whether all coefficients are zero. 
+ #[inline] + pub fn is_zero(&self) -> bool { + self.coeffs.iter().all(|c| c.is_zero()) + } + + /// Count non-zero coefficients. + #[inline] + pub fn hamming_weight(&self) -> usize { + self.coeffs.iter().filter(|c| !c.is_zero()).count() + } + + /// Sample a sparse challenge with exactly `omega` non-zeros in `{+1, -1}`. + /// + /// # Panics + /// + /// Panics if `omega > D` or `D == 0` with non-zero `omega`. + pub fn sample_sparse_pm1(rng: &mut R, omega: usize) -> Self { + assert!(omega <= D, "omega must be <= ring degree"); + assert!(D > 0 || omega == 0, "ring degree must be non-zero"); + + let mut coeffs = [F::zero(); D]; + let mut placed = 0usize; + while placed < omega { + let idx = (rng.next_u64() % (D as u64)) as usize; + if coeffs[idx].is_zero() { + coeffs[idx] = if (rng.next_u32() & 1) == 0 { + F::one() + } else { + -F::one() + }; + placed += 1; + } + } + Self { coeffs } + } +} + +impl CyclotomicRing { + /// Balanced decomposition writing directly into a pre-allocated output slice. + /// + /// `out` must have length exactly `levels`. Each element receives one digit plane. + /// + /// # Panics + /// + /// Panics if `log_basis == 0`, `log_basis >= 128`, or `out.len() * log_basis > 128 + log_basis`. 
+ pub fn balanced_decompose_pow2_into(&self, out: &mut [Self], log_basis: u32) { + let levels = out.len(); + assert!(log_basis > 0 && log_basis < 128, "invalid log_basis"); + assert!( + (levels as u32).saturating_mul(log_basis) <= 128 + log_basis, + "levels * log_basis must be <= 128 + log_basis" + ); + + let half_b = 1i128 << (log_basis - 1); + let b = half_b << 1; + let mask = b - 1; + let q = (-F::one()).to_canonical_u128() + 1; + let half_q = q / 2; + + for plane in out.iter_mut() { + *plane = Self::zero(); + } + + for i in 0..D { + let canonical = self.coeffs[i].to_canonical_u128(); + let mut c: i128 = if canonical > half_q { + -((q - canonical) as i128) + } else { + canonical as i128 + }; + + for plane in out.iter_mut() { + let d = c & mask; + let balanced = if d >= half_b { d - b } else { d }; + c = (c - balanced) >> log_basis; + + plane.coeffs[i] = if balanced >= 0 { + F::from_canonical_u128_reduced(balanced as u128) + } else { + F::from_canonical_u128_reduced(q - ((-balanced) as u128)) + }; + } + } + } + + /// Squared Euclidean norm of centered integer coefficients. + /// + /// Coefficients are centered into `(-q/2, q/2]` and accumulated as + /// `sum_i c_i^2`, using saturating arithmetic. + #[inline] + pub fn coeff_norm_sq(&self) -> u128 + where + F: CanonicalField, + { + let q = (-F::one()).to_canonical_u128() + 1; + let half_q = q / 2; + self.coeffs.iter().fold(0u128, |acc, &coeff| { + let canonical = coeff.to_canonical_u128(); + let centered: i128 = if canonical > half_q { + -((q - canonical) as i128) + } else { + canonical as i128 + }; + let abs = centered.unsigned_abs(); + acc.saturating_add(abs.saturating_mul(abs)) + }) + } + + /// Functional gadget recomposition (`G * digits`) for base `2^log_basis`. + /// + /// Coefficients from each part are interpreted as one digit plane and + /// recombined back into canonical integers (then reduced into the field). 
+ /// + /// # Panics + /// + /// Panics if `log_basis == 0`, `log_basis >= 128`, or `parts.len() * log_basis > 128`. + pub fn gadget_recompose_pow2(parts: &[Self], log_basis: u32) -> Self { + if parts.is_empty() { + return Self::zero(); + } + + assert!(log_basis > 0 && log_basis < 128, "invalid log_basis"); + + let b = F::from_canonical_u128_reduced(1u128 << log_basis); + let coeffs = from_fn(|i| { + let mut acc = F::zero(); + let mut power = F::one(); + for part in parts.iter() { + acc += part.coeffs[i] * power; + power = power * b; + } + acc + }); + Self { coeffs } + } + + /// Recompose from i8 digit planes (output of `balanced_decompose_pow2_i8`). + /// + /// # Panics + /// + /// Panics if `log_basis` is zero or >= 128. + pub fn gadget_recompose_pow2_i8(digits: &[[i8; D]], log_basis: u32) -> Self + where + F: CanonicalField, + { + if digits.is_empty() { + return Self::zero(); + } + assert!(log_basis > 0 && log_basis < 128, "invalid log_basis"); + + let b = F::from_canonical_u128_reduced(1u128 << log_basis); + let coeffs = from_fn(|i| { + let mut acc = F::zero(); + let mut power = F::one(); + for plane in digits { + acc += F::from_i64(plane[i] as i64) * power; + power = power * b; + } + acc + }); + Self { coeffs } + } + + /// Balanced (centered) base-`2^log_basis` gadget decomposition: `G^{-1}`. + /// + /// Each coefficient `c` (centered into `(-q/2, q/2]`) is decomposed into + /// `levels` balanced digits `d_k ∈ [-b/2, b/2)` satisfying + /// `c ≡ Σ_k d_k · b^k (mod q)`. + /// + /// Negative digits are stored as their field representation (`q + d`). + /// + /// # Panics + /// + /// Panics if `log_basis == 0`, `log_basis >= 128`, or `levels * log_basis > 128`. 
+ pub fn balanced_decompose_pow2(&self, levels: usize, log_basis: u32) -> Vec { + assert!(log_basis > 0 && log_basis < 128, "invalid log_basis"); + assert!( + (levels as u32).saturating_mul(log_basis) <= 128 + log_basis, + "levels * log_basis must be <= 128 + log_basis" + ); + + let half_b = 1i128 << (log_basis - 1); + let b = half_b << 1; + let mask = b - 1; + let q = (-F::one()).to_canonical_u128() + 1; + let half_q = q / 2; + + let mut digit_planes: Vec<[F; D]> = (0..levels).map(|_| [F::zero(); D]).collect(); + + for i in 0..D { + let canonical = self.coeffs[i].to_canonical_u128(); + let mut c: i128 = if canonical > half_q { + -((q - canonical) as i128) + } else { + canonical as i128 + }; + + for plane in digit_planes.iter_mut() { + let d = c & mask; + let balanced = if d >= half_b { d - b } else { d }; + c = (c - balanced) >> log_basis; + + plane[i] = if balanced >= 0 { + F::from_canonical_u128_reduced(balanced as u128) + } else { + F::from_canonical_u128_reduced(q - ((-balanced) as u128)) + }; + } + } + + digit_planes + .into_iter() + .map(Self::from_coefficients) + .collect() + } + + /// Balanced gadget decomposition into native `i8` digits. + /// + /// Same semantics as [`balanced_decompose_pow2`](Self::balanced_decompose_pow2) + /// but stores each digit as `i8` instead of a field element, avoiding + /// the cost of `F::from_canonical_u128_reduced`. + /// + /// Requires `log_basis <= 7` so digits fit in `[-64, 63]` (i8 range). + /// + /// # Panics + /// + /// Panics if `log_basis` is 0 or > 7, or if `levels * log_basis > 128 + log_basis`. 
+ pub fn balanced_decompose_pow2_i8_into(&self, out: &mut [[i8; D]], log_basis: u32) + where + F: CanonicalField, + { + let levels = out.len(); + assert!( + log_basis > 0 && log_basis <= 7, + "log_basis must be in 1..=7 for i8 output" + ); + assert!( + (levels as u32).saturating_mul(log_basis) <= 128 + log_basis, + "levels * log_basis must be <= 128 + log_basis" + ); + + let half_b = 1i128 << (log_basis - 1); + let b = half_b << 1; + let mask = b - 1; + let q = (-F::one()).to_canonical_u128() + 1; + let half_q = q / 2; + + for i in 0..D { + let canonical = self.coeffs[i].to_canonical_u128(); + let mut c: i128 = if canonical > half_q { + -((q - canonical) as i128) + } else { + canonical as i128 + }; + + for plane in out.iter_mut() { + let d = c & mask; + let balanced = if d >= half_b { d - b } else { d }; + c = (c - balanced) >> log_basis; + plane[i] = balanced as i8; + } + } + } + + /// Allocating variant of [`balanced_decompose_pow2_i8_into`](Self::balanced_decompose_pow2_i8_into). + pub fn balanced_decompose_pow2_i8(&self, levels: usize, log_basis: u32) -> Vec<[i8; D]> + where + F: CanonicalField, + { + let mut digit_planes: Vec<[i8; D]> = vec![[0i8; D]; levels]; + self.balanced_decompose_pow2_i8_into(&mut digit_planes, log_basis); + digit_planes + } + + /// Balanced decomposition where the last digit carries the remainder. + /// + /// The first `levels-1` digits are balanced in `[-b/2, b/2)`, while the + /// final digit is the remaining (possibly larger) centered value. + /// + /// # Panics + /// + /// Panics if `levels` is zero, `log_basis` is zero or >= 128, or + /// `levels * log_basis > 128`. 
+ pub fn balanced_decompose_pow2_with_carry(&self, levels: usize, log_basis: u32) -> Vec + where + F: CanonicalField, + { + assert!(levels > 0, "levels must be positive"); + assert!(log_basis > 0 && log_basis < 128, "invalid log_basis"); + assert!( + (levels as u32).saturating_mul(log_basis) <= 128, + "levels * log_basis must be <= 128" + ); + + let b = 1i128 << log_basis; + let half_b = b / 2; + let q = (-F::one()).to_canonical_u128() + 1; + let half_q = q / 2; + + let mut digit_planes: Vec<[F; D]> = (0..levels).map(|_| [F::zero(); D]).collect(); + + for i in 0..D { + let canonical = self.coeffs[i].to_canonical_u128(); + let mut c: i128 = if canonical > half_q { + -((q - canonical) as i128) + } else { + canonical as i128 + }; + + for (plane_idx, plane) in digit_planes.iter_mut().enumerate() { + let balanced = if plane_idx + 1 == levels { + c + } else { + let d = c.rem_euclid(b); + let digit = if d >= half_b { d - b } else { d }; + c = (c - digit) / b; + digit + }; + + plane[i] = if balanced >= 0 { + F::from_canonical_u128_reduced(balanced as u128) + } else { + F::from_canonical_u128_reduced(q - ((-balanced) as u128)) + }; + } + } + + digit_planes + .into_iter() + .map(Self::from_coefficients) + .collect() + } +} + +impl CyclotomicRing { + /// Generate a random ring element. 
+ pub fn random(rng: &mut R) -> Self { + Self { + coeffs: from_fn(|_| F::sample(rng)), + } + } +} + +impl AddAssign for CyclotomicRing { + fn add_assign(&mut self, rhs: Self) { + for (dst, src) in self.coeffs.iter_mut().zip(rhs.coeffs.iter()) { + *dst = *dst + *src; + } + } +} + +impl SubAssign for CyclotomicRing { + fn sub_assign(&mut self, rhs: Self) { + for (dst, src) in self.coeffs.iter_mut().zip(rhs.coeffs.iter()) { + *dst = *dst - *src; + } + } +} + +impl Add for CyclotomicRing { + type Output = Self; + fn add(mut self, rhs: Self) -> Self { + self += rhs; + self + } +} + +impl Sub for CyclotomicRing { + type Output = Self; + fn sub(mut self, rhs: Self) -> Self { + self -= rhs; + self + } +} + +impl Neg for CyclotomicRing { + type Output = Self; + fn neg(self) -> Self { + let mut out = self.coeffs; + for c in &mut out { + *c = -*c; + } + Self { coeffs: out } + } +} + +impl MulAssign for CyclotomicRing { + fn mul_assign(&mut self, rhs: Self) { + *self = *self * rhs; + } +} + +impl<'a, F: FieldCore, const D: usize> Add<&'a Self> for CyclotomicRing { + type Output = Self; + fn add(self, rhs: &'a Self) -> Self { + self + *rhs + } +} + +impl<'a, F: FieldCore, const D: usize> Sub<&'a Self> for CyclotomicRing { + type Output = Self; + fn sub(self, rhs: &'a Self) -> Self { + self - *rhs + } +} + +impl<'a, F: FieldCore, const D: usize> Mul<&'a Self> for CyclotomicRing { + type Output = Self; + fn mul(self, rhs: &'a Self) -> Self { + self * *rhs + } +} + +/// Schoolbook negacyclic convolution: O(D^2). +/// +/// For each pair `(i, j)`: +/// - If `i + j < D`: accumulate `a_i * b_j` at index `i + j`. +/// - If `i + j >= D`: accumulate `-(a_i * b_j)` at index `(i + j) - D`. 
+impl Mul for CyclotomicRing { + type Output = Self; + fn mul(self, rhs: Self) -> Self { + let mut out = [F::zero(); D]; + for i in 0..D { + for j in 0..D { + let product = self.coeffs[i] * rhs.coeffs[j]; + let idx = i + j; + if idx < D { + out[idx] += product; + } else { + out[idx - D] -= product; + } + } + } + Self { coeffs: out } + } +} + +impl Valid for CyclotomicRing { + fn check(&self) -> Result<(), SerializationError> { + for x in self.coeffs.iter() { + x.check()?; + } + Ok(()) + } +} + +impl HachiSerialize for CyclotomicRing { + fn serialize_with_mode( + &self, + mut writer: W, + compress: Compress, + ) -> Result<(), SerializationError> { + for x in self.coeffs.iter() { + x.serialize_with_mode(&mut writer, compress)?; + } + Ok(()) + } + + fn serialized_size(&self, compress: Compress) -> usize { + self.coeffs + .iter() + .map(|x| x.serialized_size(compress)) + .sum() + } +} + +impl HachiDeserialize for CyclotomicRing { + fn deserialize_with_mode( + mut reader: R, + compress: Compress, + validate: Validate, + ) -> Result { + let mut coeffs = [F::zero(); D]; + for c in &mut coeffs { + *c = F::deserialize_with_mode(&mut reader, compress, validate)?; + } + let out = Self { coeffs }; + if matches!(validate, Validate::Yes) { + out.check()?; + } + Ok(out) + } +} + +impl Default for CyclotomicRing { + fn default() -> Self { + Self::zero() + } +} + +/// Wide (unreduced) cyclotomic ring element for carry-free accumulation. +/// +/// Coefficients are wide accumulators (`W: AdditiveGroup`) that support +/// addition/subtraction without modular reduction. After accumulation, +/// call [`reduce`](Self::reduce) to convert back to `CyclotomicRing`. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +#[repr(transparent)] +pub struct WideCyclotomicRing { + pub(crate) coeffs: [W; D], +} + +impl WideCyclotomicRing { + /// The additive identity (all-zero coefficients). + pub const ZERO: Self = Self { + coeffs: [W::ZERO; D], + }; + + /// Returns the zero ring element. 
+ #[inline] + pub fn zero() -> Self { + Self::ZERO + } + + /// Convert a reduced `CyclotomicRing` into wide form. + #[inline] + pub fn from_ring(ring: &CyclotomicRing) -> Self + where + W: From, + { + Self { + coeffs: from_fn(|i| W::from(ring.coeffs[i])), + } + } + + /// Reduce all coefficients back to canonical field form. + #[inline] + pub fn reduce(&self) -> CyclotomicRing + where + W: ReduceTo, + { + CyclotomicRing { + coeffs: from_fn(|i| self.coeffs[i].reduce()), + } + } + + /// Fused negacyclic shift + accumulate: `dst += self * X^k`. + #[inline] + pub fn shift_accumulate_into(&self, dst: &mut Self, k: usize) { + let k = k % D; + if k == 0 { + for i in 0..D { + dst.coeffs[i] += self.coeffs[i]; + } + return; + } + for i in 0..D { + let target = i + k; + if target < D { + dst.coeffs[target] += self.coeffs[i]; + } else { + dst.coeffs[target - D] -= self.coeffs[i]; + } + } + } + + /// Fused negacyclic shift + subtract: `dst -= self * X^k`. + #[inline] + pub fn shift_sub_into(&self, dst: &mut Self, k: usize) { + let k = k % D; + if k == 0 { + for i in 0..D { + dst.coeffs[i] -= self.coeffs[i]; + } + return; + } + for i in 0..D { + let target = i + k; + if target < D { + dst.coeffs[target] -= self.coeffs[i]; + } else { + dst.coeffs[target - D] += self.coeffs[i]; + } + } + } + + /// Fused multiply-by-monomial-sum + accumulate: + /// `dst += self * (X^{k_1} + X^{k_2} + ...)`. 
+ pub fn mul_by_monomial_sum_into(&self, dst: &mut Self, nonzero_positions: &[usize]) { + for &k in nonzero_positions { + self.shift_accumulate_into(dst, k); + } + } +} + +impl Add for WideCyclotomicRing { + type Output = Self; + fn add(mut self, rhs: Self) -> Self { + for i in 0..D { + self.coeffs[i] += rhs.coeffs[i]; + } + self + } +} + +impl AddAssign for WideCyclotomicRing { + fn add_assign(&mut self, rhs: Self) { + for i in 0..D { + self.coeffs[i] += rhs.coeffs[i]; + } + } +} + +impl Sub for WideCyclotomicRing { + type Output = Self; + fn sub(mut self, rhs: Self) -> Self { + for i in 0..D { + self.coeffs[i] -= rhs.coeffs[i]; + } + self + } +} + +impl SubAssign for WideCyclotomicRing { + fn sub_assign(&mut self, rhs: Self) { + for i in 0..D { + self.coeffs[i] -= rhs.coeffs[i]; + } + } +} + +impl Neg for WideCyclotomicRing { + type Output = Self; + fn neg(self) -> Self { + Self { + coeffs: from_fn(|i| -self.coeffs[i]), + } + } +} + +impl Default for WideCyclotomicRing { + fn default() -> Self { + Self::zero() + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::algebra::fields::{Fp128x8i32, Fp64, Fp64x4i32, Prime128M8M4M1M0}; + use rand::rngs::StdRng; + use rand::SeedableRng; + + type F64 = Fp64<4294967197>; + type F128 = Prime128M8M4M1M0; + const D: usize = 64; + + #[test] + fn wide_shift_accumulate_matches_narrow_fp64() { + let mut rng = StdRng::seed_from_u64(0x1234); + let src = CyclotomicRing::::random(&mut rng); + let initial = CyclotomicRing::::random(&mut rng); + + for k in [0, 1, 7, 31, 63] { + let mut narrow = initial; + src.shift_accumulate_into(&mut narrow, k); + + let wide_src = WideCyclotomicRing::::from_ring(&src); + let mut wide_dst = WideCyclotomicRing::::from_ring(&initial); + wide_src.shift_accumulate_into(&mut wide_dst, k); + let wide_reduced: CyclotomicRing = wide_dst.reduce(); + + assert_eq!(narrow, wide_reduced, "shift_accumulate k={k}"); + } + } + + #[test] + fn wide_shift_sub_matches_narrow_fp64() { + let mut rng = 
StdRng::seed_from_u64(0x5678); + let src = CyclotomicRing::::random(&mut rng); + let initial = CyclotomicRing::::random(&mut rng); + + for k in [0, 1, 15, 32, 63] { + let mut narrow = initial; + src.shift_sub_into(&mut narrow, k); + + let wide_src = WideCyclotomicRing::::from_ring(&src); + let mut wide_dst = WideCyclotomicRing::::from_ring(&initial); + wide_src.shift_sub_into(&mut wide_dst, k); + let wide_reduced: CyclotomicRing = wide_dst.reduce(); + + assert_eq!(narrow, wide_reduced, "shift_sub k={k}"); + } + } + + #[test] + fn wide_mul_by_monomial_sum_matches_narrow_fp64() { + let mut rng = StdRng::seed_from_u64(0xabcd); + let src = CyclotomicRing::::random(&mut rng); + let positions = vec![0, 5, 17, 42, 63]; + + let mut narrow = CyclotomicRing::::zero(); + src.mul_by_monomial_sum_into(&mut narrow, &positions); + + let wide_src = WideCyclotomicRing::::from_ring(&src); + let mut wide_dst = WideCyclotomicRing::::zero(); + wide_src.mul_by_monomial_sum_into(&mut wide_dst, &positions); + let wide_reduced: CyclotomicRing = wide_dst.reduce(); + + assert_eq!(narrow, wide_reduced); + } + + #[test] + fn wide_many_accumulations_fp128() { + let mut rng = StdRng::seed_from_u64(0xbeef); + let src = CyclotomicRing::::random(&mut rng); + + let mut narrow = CyclotomicRing::::zero(); + let wide_src = WideCyclotomicRing::::from_ring(&src); + let mut wide_dst = WideCyclotomicRing::::zero(); + + for k in 0..50 { + src.shift_accumulate_into(&mut narrow, k % D); + wide_src.shift_accumulate_into(&mut wide_dst, k % D); + } + for k in 0..30 { + src.shift_sub_into(&mut narrow, k % D); + wide_src.shift_sub_into(&mut wide_dst, k % D); + } + + let wide_reduced: CyclotomicRing = wide_dst.reduce(); + assert_eq!(narrow, wide_reduced); + } +} diff --git a/src/algebra/ring/mod.rs b/src/algebra/ring/mod.rs new file mode 100644 index 00000000..3a372a47 --- /dev/null +++ b/src/algebra/ring/mod.rs @@ -0,0 +1,11 @@ +//! Cyclotomic ring types and NTT representations. 
+ +pub mod crt_ntt_repr; +pub mod cyclotomic; +pub mod sparse_challenge; + +pub use crt_ntt_repr::{CrtNttConvertibleField, CrtNttParamSet, CyclotomicCrtNtt, DigitMontLut}; +pub use cyclotomic::{CyclotomicRing, WideCyclotomicRing}; +pub use sparse_challenge::{ + sample_quaternary, sample_ternary, SparseChallenge, SparseChallengeConfig, +}; diff --git a/src/algebra/ring/sparse_challenge.rs b/src/algebra/ring/sparse_challenge.rs new file mode 100644 index 00000000..5cf14b9a --- /dev/null +++ b/src/algebra/ring/sparse_challenge.rs @@ -0,0 +1,212 @@ +//! Sparse ring challenges for cyclotomic protocols. +//! +//! Many lattice protocols sample "short/sparse" ring challenges whose coefficients +//! are mostly zero and whose non-zero coefficients come from a tiny integer alphabet +//! (e.g. `{±1}` or `{±1,±2}`), with a fixed Hamming weight `ω`. +//! +//! This module provides a minimal representation that is: +//! - independent of any specific protocol (Hachi/Greyhound/SuperNeo, etc.), +//! - easy to sample deterministically from Fiat–Shamir at the protocol layer, +//! - and efficient to evaluate at a point `α` using precomputed powers. + +use super::CyclotomicRing; +use crate::algebra::fields::LiftBase; +use crate::{CanonicalField, FieldCore}; +use rand_core::RngCore; + +/// Configuration for sampling a sparse challenge. +/// +/// This intentionally avoids redundant knobs: the distribution is determined by: +/// - exact `weight` (Hamming weight), +/// - and a list of allowed **non-zero** integer coefficients. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct SparseChallengeConfig { + /// Exact Hamming weight ω. + pub weight: usize, + /// Allowed non-zero coefficients (small signed integers). + /// + /// Examples: + /// - `{±1}`: `vec![-1, 1]` + /// - `{±1,±2}`: `vec![-2, -1, 1, 2]` + pub nonzero_coeffs: Vec, +} + +impl SparseChallengeConfig { + /// Validate basic invariants for a given ring degree `D`. 
+ /// + /// # Errors + /// + /// Returns an error if `weight > D`, if `nonzero_coeffs` is empty, or if it + /// contains `0`. + pub fn validate(&self) -> Result<(), &'static str> { + if self.weight > D { + return Err("weight must be <= ring degree D"); + } + if self.nonzero_coeffs.is_empty() { + return Err("nonzero_coeffs must be non-empty"); + } + if self.nonzero_coeffs.contains(&0) { + return Err("nonzero_coeffs must not contain 0"); + } + Ok(()) + } +} + +/// Sparse polynomial in `F[X]/(X^D+1)` represented by its non-zero terms. +/// +/// Invariants: +/// - `positions.len() == coeffs.len()` +/// - all positions are `< D` +/// - positions are unique +/// - all coeffs are non-zero +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct SparseChallenge { + /// Coefficient indices (powers of `X`) where the polynomial is non-zero. + pub positions: Vec, + /// Small integer coefficients at the corresponding positions. + pub coeffs: Vec, +} + +impl SparseChallenge { + /// Construct an empty (all-zero) challenge. + #[inline] + pub fn zero() -> Self { + Self { + positions: Vec::new(), + coeffs: Vec::new(), + } + } + + /// Number of non-zero coefficients (Hamming weight). + #[inline] + pub fn hamming_weight(&self) -> usize { + debug_assert_eq!(self.positions.len(), self.coeffs.len()); + self.positions.len() + } + + /// ℓ₁ norm over integers: `Σ |coeff_i|`. + #[inline] + pub fn l1_norm(&self) -> u64 { + self.coeffs + .iter() + .map(|&c| (c as i32).unsigned_abs() as u64) + .sum() + } + + /// Validate structural invariants for a ring degree `D`. + /// + /// # Errors + /// + /// Returns an error if lengths mismatch, if any coefficient is zero, if any + /// position is out of range, or if positions contain duplicates. + pub fn validate(&self) -> Result<(), &'static str> { + if self.positions.len() != self.coeffs.len() { + return Err("positions and coeffs must have same length"); + } + // Check coeffs are non-zero and positions are in range + unique. 
+ let mut seen = vec![false; D]; + for (&pos, &c) in self.positions.iter().zip(self.coeffs.iter()) { + if c == 0 { + return Err("coeffs must not contain 0"); + } + let p = pos as usize; + if p >= D { + return Err("position out of range"); + } + if seen[p] { + return Err("positions must be unique"); + } + seen[p] = true; + } + Ok(()) + } + + /// Convert to a dense ring element by placing coefficients in the canonical + /// coefficient basis. + /// + /// # Errors + /// + /// Returns an error if the sparse representation violates structural invariants. + pub fn to_dense( + &self, + ) -> Result, &'static str> { + self.validate::()?; + let mut out = [F::zero(); D]; + for (&pos, &c) in self.positions.iter().zip(self.coeffs.iter()) { + out[pos as usize] += F::from_i64(c as i64); + } + Ok(CyclotomicRing::from_coefficients(out)) + } + + /// Evaluate this sparse polynomial at `α` in `E`, given precomputed powers + /// `[α^0, α^1, ..., α^{D-1}]`. + /// + /// This is `O(weight)` and is intended to be used for verifier-side oracles + /// where `D` may be large but `weight` is small. + /// + /// # Errors + /// + /// Returns an error if structural invariants fail or if `alpha_pows.len() != D`. + pub fn eval_at_alpha(&self, alpha_pows: &[E]) -> Result + where + F: FieldCore + CanonicalField, + E: FieldCore + LiftBase, + { + self.validate::()?; + if alpha_pows.len() != D { + return Err("alpha_pows length mismatch"); + } + let mut acc = E::zero(); + for (&pos, &c) in self.positions.iter().zip(self.coeffs.iter()) { + let coeff_f = F::from_i64(c as i64); + acc += E::lift_base(coeff_f) * alpha_pows[pos as usize]; + } + Ok(acc) + } +} + +/// Sample a dense ternary ring element with coefficients in `{-1, 0, 1}`. +/// +/// Distribution matches Labrador C's ternary nibble LUT (`0xA815`), yielding +/// probabilities `5/16, 6/16, 5/16` for `-1, 0, 1` respectively. 
+pub fn sample_ternary( + rng: &mut R, +) -> CyclotomicRing { + const LUT: u16 = 0xA815; + let mut coeffs = [F::zero(); D]; + let mut i = 0usize; + while i < D { + let byte = (rng.next_u32() & 0xFF) as u8; + let lo = (((LUT >> (byte & 0x0F)) & 0x3) as i16) - 1; + coeffs[i] = F::from_i64(lo as i64); + i += 1; + if i < D { + let hi = (((LUT >> (byte >> 4)) & 0x3) as i16) - 1; + coeffs[i] = F::from_i64(hi as i64); + i += 1; + } + } + CyclotomicRing::from_coefficients(coeffs) +} + +/// Sample a dense quaternary ring element with coefficients in `{-2, -1, 0, 1}`. +/// +/// Coefficients are sampled uniformly from two-bit chunks and shifted by `-2`. +pub fn sample_quaternary( + rng: &mut R, +) -> CyclotomicRing { + let mut coeffs = [F::zero(); D]; + let mut i = 0usize; + while i < D { + let bits = rng.next_u32(); + for lane in 0..16 { + if i >= D { + break; + } + let val = (((bits >> (2 * lane)) & 0x3) as i16) - 2; + coeffs[i] = F::from_i64(val as i64); + i += 1; + } + } + CyclotomicRing::from_coefficients(coeffs) +} diff --git a/src/lib.rs b/src/lib.rs index e3859c54..29b68713 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,3 +1,7 @@ +#![cfg_attr( + all(target_arch = "x86_64", target_feature = "avx512f"), + feature(stdarch_x86_avx512) +)] //! # hachi //! //! A high performance and modular implementation of the Hachi polynomial commitment scheme. @@ -17,7 +21,7 @@ //! ### Core Modules //! - [`primitives`] - Core traits and abstractions //! - [`primitives::arithmetic`] - Field and module traits for lattice arithmetic -//! - [`primitives::poly`] - Multilinear polynomial traits and operations +//! - [`primitives::poly`] - Multilinear polynomial utility functions //! - [`primitives::transcript`] - Fiat-Shamir transcript trait //! - [`primitives::serialization`] - Serialization abstractions //! 
- [`error`] - Error types @@ -35,7 +39,26 @@ pub mod error; /// Primitive traits and operations pub mod primitives; +/// Concrete algebra backends (prime fields, extensions, rings) +pub mod algebra; + +/// Protocol-layer transcript and commitment abstractions +pub mod protocol; + +/// Conditional parallelism utilities (`cfg_iter!`, `cfg_into_iter!`, etc.) +#[doc(hidden)] +pub mod parallel; + +/// Shared test configuration and helpers. +#[doc(hidden)] +pub mod test_utils; + pub use error::HachiError; -pub use primitives::arithmetic::{Field, HachiRoutines, Module}; -pub use primitives::poly::{MultilinearLagrange, Polynomial}; +pub use primitives::arithmetic::{ + AdditiveGroup, CanonicalField, FieldCore, FieldSampling, FromSmallInt, Invertible, Module, + PseudoMersenneField, +}; pub use primitives::serialization::{HachiDeserialize, HachiSerialize}; +pub use protocol::{ + BasisMode, CommitmentScheme, DensePoly, HachiPolyOps, OneHotIndex, OneHotPoly, Transcript, +}; diff --git a/src/parallel.rs b/src/parallel.rs new file mode 100644 index 00000000..1be93577 --- /dev/null +++ b/src/parallel.rs @@ -0,0 +1,74 @@ +//! Conditional parallelism utilities. +//! +//! When the `parallel` feature is enabled, the `cfg_iter!` family of macros +//! expand to rayon's parallel iterators. Otherwise they fall back to standard +//! sequential iterators. + +#[cfg(feature = "parallel")] +pub use rayon::prelude::*; + +/// Returns `.par_iter()` when `parallel` is enabled, `.iter()` otherwise. +#[macro_export] +macro_rules! cfg_iter { + ($e:expr) => {{ + #[cfg(feature = "parallel")] + let it = $e.par_iter(); + #[cfg(not(feature = "parallel"))] + let it = $e.iter(); + it + }}; +} + +/// Returns `.par_iter_mut()` when `parallel` is enabled, `.iter_mut()` otherwise. +#[macro_export] +macro_rules! 
cfg_iter_mut { + ($e:expr) => {{ + #[cfg(feature = "parallel")] + let it = $e.par_iter_mut(); + #[cfg(not(feature = "parallel"))] + let it = $e.iter_mut(); + it + }}; +} + +/// Returns `.into_par_iter()` when `parallel` is enabled, `.into_iter()` otherwise. +#[macro_export] +macro_rules! cfg_into_iter { + ($e:expr) => {{ + #[cfg(feature = "parallel")] + let it = $e.into_par_iter(); + #[cfg(not(feature = "parallel"))] + let it = $e.into_iter(); + it + }}; +} + +/// Returns `.par_chunks(n)` when `parallel` is enabled, `.chunks(n)` otherwise. +#[macro_export] +macro_rules! cfg_chunks { + ($e:expr, $n:expr) => {{ + #[cfg(feature = "parallel")] + let it = $e.par_chunks($n); + #[cfg(not(feature = "parallel"))] + let it = $e.chunks($n); + it + }}; +} + +/// Parallel fold-reduce over a range. +/// +/// With `parallel`: `range.into_par_iter().fold(identity, fold_op).reduce(identity, reduce_op)`. +/// Without: `range.into_iter().fold(identity(), fold_op)`. +#[macro_export] +macro_rules! cfg_fold_reduce { + ($range:expr, $identity:expr, $fold_op:expr, $reduce_op:expr) => {{ + #[cfg(feature = "parallel")] + let result = $range + .into_par_iter() + .fold($identity, $fold_op) + .reduce($identity, $reduce_op); + #[cfg(not(feature = "parallel"))] + let result = $range.into_iter().fold(($identity)(), $fold_op); + result + }}; +} diff --git a/src/primitives/arithmetic.rs b/src/primitives/arithmetic.rs index ccaaa4ae..655e68ae 100644 --- a/src/primitives/arithmetic.rs +++ b/src/primitives/arithmetic.rs @@ -3,26 +3,42 @@ use super::{HachiDeserialize, HachiSerialize}; use rand_core::RngCore; -/// Field trait for lattice-based arithmetic -pub trait Field: +/// Minimal additive group: add, sub, neg, zero. +/// +/// Satisfied by both reduced field elements (`FieldCore`) and wide unreduced +/// accumulators (`Fp128x8i32`, etc.), enabling generic shift-accumulate +/// operations on `WideCyclotomicRing`. 
+pub trait AdditiveGroup: Sized + Clone + Copy - + PartialEq + Send + Sync - + HachiSerialize - + HachiDeserialize + std::ops::Add + std::ops::Sub - + std::ops::Mul + std::ops::Neg + + std::ops::AddAssign + + std::ops::SubAssign +{ + /// Additive identity. + const ZERO: Self; +} + +/// Core field operations required across algebra backends. +pub trait FieldCore: + AdditiveGroup + + PartialEq + + HachiSerialize + + HachiDeserialize + + std::ops::Mul + for<'a> std::ops::Add<&'a Self, Output = Self> + for<'a> std::ops::Sub<&'a Self, Output = Self> + for<'a> std::ops::Mul<&'a Self, Output = Self> { - /// Additive identity - fn zero() -> Self; + /// Additive identity. + fn zero() -> Self { + Self::ZERO + } /// Multiplicative identity fn one() -> Self; @@ -30,26 +46,149 @@ pub trait Field: /// Check if element is zero fn is_zero(&self) -> bool; - /// Field addition - fn add(&self, rhs: &Self) -> Self; - - /// Field subtraction - fn sub(&self, rhs: &Self) -> Self; - - /// Field multiplication - fn mul(&self, rhs: &Self) -> Self; - - /// Field inversion + /// Field squaring. + /// + /// Default is `self * self`; extension fields override with specialized + /// formulas that use fewer base-field multiplications. + fn square(&self) -> Self { + *self * *self + } + + /// Field inversion. + /// + /// This API may branch on zero-check and is intended for public/non-secret + /// values. For secret-bearing paths, use [`Invertible::inv_or_zero`]. fn inv(self) -> Option; - /// Generate random field element - fn random(rng: &mut R) -> Self; + /// Multiplicative inverse of 2: `(p + 1) / 2` for odd-characteristic fields. + const TWO_INV: Self; +} - /// Convert from u64 +/// Constant-time inversion helper for secret-bearing code paths. +/// +/// Implementations return `0` when the input is `0`, and `x^{-1}` otherwise, +/// without branching on the input value. +pub trait Invertible: FieldCore { + /// Constant-time inversion with zero-mapping behavior. 
+ fn inv_or_zero(self) -> Self; +} + +/// Embed small integers into a field. +/// +/// Every field contains a copy of its prime subfield, and small integers embed +/// into it canonically via reduction modulo the characteristic. This trait is +/// implementable for ALL fields — base and extension alike. +/// +/// Only `from_u64` and `from_i64` need concrete implementations; the narrower +/// widths have default impls via lossless widening. +pub trait FromSmallInt: FieldCore { + /// Embed a `u8` into the field. + fn from_u8(val: u8) -> Self { + Self::from_u64(val as u64) + } + + /// Embed an `i8` into the field. + fn from_i8(val: i8) -> Self { + Self::from_i64(val as i64) + } + + /// Embed a `u16` into the field. + fn from_u16(val: u16) -> Self { + Self::from_u64(val as u64) + } + + /// Embed an `i16` into the field. + fn from_i16(val: i16) -> Self { + Self::from_i64(val as i64) + } + + /// Embed a `u32` into the field. + fn from_u32(val: u32) -> Self { + Self::from_u64(val as u64) + } + + /// Embed an `i32` into the field. + fn from_i32(val: i32) -> Self { + Self::from_i64(val as i64) + } + + /// Embed a `u64` into the field (reduce mod characteristic). fn from_u64(val: u64) -> Self; - /// Convert from i64 + /// Embed an `i64` into the field (reduce mod characteristic). fn from_i64(val: i64) -> Self; + + /// Embed an `i128` into the field. + /// + /// Default implementation splits into u64 limbs with field multiplication + /// by `2^64`. Override for base fields that have a direct path. + fn from_i128(val: i128) -> Self { + if val >= 0 { + let lo = val as u64; + let hi = (val >> 64) as u64; + if hi == 0 { + Self::from_u64(lo) + } else { + let two_64 = Self::from_u64(1u64 << 32) * Self::from_u64(1u64 << 32); + Self::from_u64(lo) + Self::from_u64(hi) * two_64 + } + } else { + -Self::from_i128(-val) + } + } + + /// Lookup table mapping balanced digit index → field element. 
+ /// + /// For `log_basis` in `1..=4`, returns a 16-entry table where + /// `table[i]` = `from_i64(i - b/2)` for `i < b = 2^log_basis`, + /// and zero for `i >= b`. + /// + /// Index a digit `d ∈ [-b/2, b/2)` as `table[(d + b/2) as usize]`. + fn digit_lut(log_basis: u32) -> [Self; 16] { + debug_assert!(log_basis > 0 && log_basis <= 4); + let b = 1usize << log_basis; + let half_b = (b >> 1) as i64; + std::array::from_fn(|i| { + if i < b { + Self::from_i64(i as i64 - half_b) + } else { + Self::zero() + } + }) + } +} + +/// Canonical integer representation for prime (base) field elements. +/// +/// Provides a bijection between field elements and integers in `[0, p)`. +/// Only meaningful for base prime fields where elements ARE residues mod p. +/// Extension fields should NOT implement this trait. +pub trait CanonicalField: FromSmallInt { + /// Return canonical integer representation as `u128`. + fn to_canonical_u128(self) -> u128; + + /// Construct from canonical value if it is in range. + fn from_canonical_u128_checked(val: u128) -> Option; + + /// Construct from canonical value reduced modulo the field modulus. + fn from_canonical_u128_reduced(val: u128) -> Self; +} + +/// Optional sampling support for field elements. +/// +/// This is intentionally separate from core field algebra and may evolve. +pub trait FieldSampling: FieldCore { + /// Generate a sampled field element. + fn sample(rng: &mut R) -> Self; +} + +/// Metadata for pseudo-Mersenne style moduli (`2^k - c`). +pub trait PseudoMersenneField: CanonicalField { + /// Exponent `k` in `2^k - c`. + const MODULUS_BITS: u32; + + /// Offset `c` in `2^k - c`. 
+ const MODULUS_OFFSET: u128; } /// Module trait for lattice-based algebraic structures @@ -73,24 +212,18 @@ pub trait Module: + for<'a> std::ops::Sub<&'a Self, Output = Self> { /// Scalar type (field/ring elements) - type Scalar: Field + type Scalar: FieldCore + + CanonicalField + + FieldSampling + std::ops::Mul + for<'a> std::ops::Mul<&'a Self, Output = Self>; /// Zero element fn zero() -> Self; - /// Addition - fn add(&self, rhs: &Self) -> Self; - - /// Negation - fn neg(&self) -> Self; - /// Scalar multiplication fn scale(&self, k: &Self::Scalar) -> Self; /// Generate random module element fn random(rng: &mut R) -> Self; } - -pub trait HachiRoutines {} diff --git a/src/primitives/poly.rs b/src/primitives/poly.rs index 1e9a8b63..6a4185b9 100644 --- a/src/primitives/poly.rs +++ b/src/primitives/poly.rs @@ -1,58 +1,6 @@ -//! Polynomial trait for multilinear polynomials +//! Multilinear polynomial utility functions. -use super::arithmetic::Field; - -/// Trait for multilinear Lagrange polynomial operations -pub trait MultilinearLagrange: Polynomial { - /// Compute multilinear Lagrange basis evaluations at a point - /// - /// For variables (r₀, r₁, ..., r_{n-1}), computes all 2^n basis polynomial evaluations. - /// The i-th basis polynomial evaluates to 1 at the i-th hypercube vertex and 0 elsewhere. - fn lagrange_basis(&self, output: &mut [F], point: &[F]) { - multilinear_lagrange_basis(output, point) - } - - /// Compute vector-matrix product: v = L^T * M - /// - /// Treats coefficients as a 2^nu × 2^sigma matrix. 
- /// For each column j: v\[j\] = Σ_i left_vec\[i\] * coefficients\[i\]\[j\] - fn vector_matrix_product(&self, left_vec: &[F], nu: usize, sigma: usize) -> Vec; - - /// Compute left and right vectors from evaluation point - /// - /// Given a point arranged for matrix evaluation, computes L and R such that: - /// polynomial_evaluation(point) = L^T × M × R - fn compute_evaluation_vectors(&self, point: &[F], nu: usize, sigma: usize) -> (Vec, Vec) { - compute_left_right_vectors(point, nu, sigma) - } -} - -/// Trait for multilinear polynomials -/// -/// Represents a polynomial in evaluation form (coefficients at hypercube points). -pub trait Polynomial { - /// Number of variables - fn num_vars(&self) -> usize; - - /// Total number of coefficients (2^num_vars) - fn len(&self) -> usize { - 1 << self.num_vars() - } - - /// Check if polynomial is empty - fn is_empty(&self) -> bool { - self.len() == 0 - } - - /// Evaluate polynomial at a point - /// - /// # Parameters - /// - `point`: Evaluation point (length must equal num_vars) - /// - /// # Returns - /// Polynomial evaluation result - fn evaluate(&self, point: &[F]) -> F; -} +use super::arithmetic::FieldCore; /// Compute multilinear Lagrange basis evaluations at a point /// @@ -62,7 +10,7 @@ pub trait Polynomial { /// Uses an iterative doubling approach: /// - Start with [1-r₀, r₀] /// - For each variable rᵢ, split each value v into [v*(1-rᵢ), v*rᵢ] -pub(crate) fn multilinear_lagrange_basis(output: &mut [F], point: &[F]) { +pub(crate) fn multilinear_lagrange_basis(output: &mut [F], point: &[F]) { assert!( output.len() <= (1 << point.len()), "Output length must be at most 2^point.len()" @@ -115,7 +63,7 @@ pub(crate) fn multilinear_lagrange_basis(output: &mut [F], point: &[F] /// polynomial_evaluation(point) = L^T × M × R /// /// Splits variables between rows and columns based on sigma and nu. 
-pub fn compute_left_right_vectors( +pub fn compute_left_right_vectors( point: &[F], nu: usize, sigma: usize, diff --git a/src/primitives/serialization.rs b/src/primitives/serialization.rs index f4a0bf1e..ea14bdc3 100644 --- a/src/primitives/serialization.rs +++ b/src/primitives/serialization.rs @@ -182,11 +182,43 @@ mod primitive_impls { impl_primitive_serialization!(u16, 2); impl_primitive_serialization!(u32, 4); impl_primitive_serialization!(u64, 8); - impl_primitive_serialization!(usize, std::mem::size_of::()); + impl_primitive_serialization!(u128, 16); impl_primitive_serialization!(i8, 1); impl_primitive_serialization!(i16, 2); impl_primitive_serialization!(i32, 4); impl_primitive_serialization!(i64, 8); + impl_primitive_serialization!(i128, 16); + + impl Valid for usize { + fn check(&self) -> Result<(), SerializationError> { + Ok(()) + } + } + + impl HachiSerialize for usize { + fn serialize_with_mode( + &self, + writer: W, + compress: Compress, + ) -> Result<(), SerializationError> { + (*self as u64).serialize_with_mode(writer, compress) + } + + fn serialized_size(&self, _compress: Compress) -> usize { + 8 + } + } + + impl HachiDeserialize for usize { + fn deserialize_with_mode( + reader: R, + compress: Compress, + validate: Validate, + ) -> Result { + let val = u64::deserialize_with_mode(reader, compress, validate)?; + Ok(val as usize) + } + } impl Valid for bool { fn check(&self) -> Result<(), SerializationError> { diff --git a/src/primitives/transcript.rs b/src/primitives/transcript.rs index 89ebc29a..ce2e8f53 100644 --- a/src/primitives/transcript.rs +++ b/src/primitives/transcript.rs @@ -2,13 +2,13 @@ #![allow(missing_docs)] -use crate::primitives::arithmetic::{Field, Module}; +use crate::primitives::arithmetic::{CanonicalField, FieldCore, Module}; use crate::primitives::HachiSerialize; /// Transcript for Fiat-Shamir transformations pub trait Transcript { /// Field type for challenges - type Field: Field; + type Field: FieldCore + CanonicalField; /// 
Append raw bytes to the transcript fn append_bytes(&mut self, label: &[u8], bytes: &[u8]); diff --git a/src/protocol/challenges/mod.rs b/src/protocol/challenges/mod.rs new file mode 100644 index 00000000..524941ed --- /dev/null +++ b/src/protocol/challenges/mod.rs @@ -0,0 +1,6 @@ +//! Protocol-level Fiat–Shamir challenge samplers. +//! +//! These utilities derive structured challenges (e.g. sparse ring elements) from +//! the transcript while keeping the low-level representations in the algebra layer. + +pub mod sparse; diff --git a/src/protocol/challenges/sparse.rs b/src/protocol/challenges/sparse.rs new file mode 100644 index 00000000..aca54aaf --- /dev/null +++ b/src/protocol/challenges/sparse.rs @@ -0,0 +1,121 @@ +//! Sparse challenge sampling via Fiat–Shamir. + +use crate::algebra::ring::{CyclotomicRing, SparseChallenge, SparseChallengeConfig}; +use crate::error::HachiError; +use crate::protocol::transcript::labels::{ABSORB_SPARSE_CHALLENGE, CHALLENGE_SPARSE_CHALLENGE}; +use crate::protocol::transcript::Transcript; +use crate::{CanonicalField, FieldCore}; + +/// Sample a sparse ring challenge (exact weight ω) from a transcript. +/// +/// This is intentionally deterministic and label-aware: +/// - first we absorb the sampling context under `ABSORB_SPARSE_CHALLENGE`, +/// - then we derive as many `CHALLENGE_SPARSE_CHALLENGE` scalars as needed. +/// +/// Notes: +/// - Indices are sampled with a simple `mod D` reduction. For the intended +/// regimes (small `D`, cryptographic transcript), any bias is negligible. +/// - Duplicate indices are rejected to enforce exact Hamming weight. +/// +/// # Errors +/// +/// Returns an error if the provided config is invalid for degree `D`. 
+pub fn sparse_challenge_from_transcript( + transcript: &mut T, + label: &[u8], + instance_idx: u64, + cfg: &SparseChallengeConfig, +) -> Result +where + F: FieldCore + CanonicalField, + T: Transcript, +{ + cfg.validate::() + .map_err(|e| HachiError::InvalidInput(format!("invalid sparse challenge config: {e}")))?; + + // Absorb domain-separating context so different call sites can't collide. + transcript.append_bytes(ABSORB_SPARSE_CHALLENGE, label); + transcript.append_bytes(ABSORB_SPARSE_CHALLENGE, &instance_idx.to_le_bytes()); + transcript.append_bytes(ABSORB_SPARSE_CHALLENGE, &(D as u64).to_le_bytes()); + transcript.append_bytes(ABSORB_SPARSE_CHALLENGE, &(cfg.weight as u64).to_le_bytes()); + // Include the coefficient alphabet (as little-endian i16 stream). + let mut coeff_bytes = Vec::with_capacity(cfg.nonzero_coeffs.len() * 2); + for &c in cfg.nonzero_coeffs.iter() { + coeff_bytes.extend_from_slice(&c.to_le_bytes()); + } + transcript.append_bytes(ABSORB_SPARSE_CHALLENGE, &coeff_bytes); + + let mut seen = vec![false; D]; + let mut positions = Vec::with_capacity(cfg.weight); + let mut coeffs = Vec::with_capacity(cfg.weight); + + while positions.len() < cfg.weight { + let r = transcript + .challenge_scalar(CHALLENGE_SPARSE_CHALLENGE) + .to_canonical_u128(); + let lo = r as u64; + let hi = (r >> 64) as u64; + + let pos = (lo % (D as u64)) as usize; + if seen[pos] { + continue; + } + seen[pos] = true; + positions.push(pos as u32); + + let coeff_idx = (hi % (cfg.nonzero_coeffs.len() as u64)) as usize; + let c = cfg.nonzero_coeffs[coeff_idx]; + debug_assert_ne!(c, 0); + coeffs.push(c); + } + + Ok(SparseChallenge { positions, coeffs }) +} + +/// Sample `n` sparse challenges from a transcript, returning the sparse +/// representation directly. +/// +/// # Errors +/// +/// Returns an error if challenge sampling fails. 
+pub fn sample_sparse_challenges( + transcript: &mut T, + label: &[u8], + n: usize, + cfg: &SparseChallengeConfig, +) -> Result, HachiError> +where + F: FieldCore + CanonicalField, + T: Transcript, +{ + (0..n) + .map(|i| sparse_challenge_from_transcript::(transcript, label, i as u64, cfg)) + .collect() +} + +/// Sample `n` sparse challenges from a transcript and convert them to dense +/// `CyclotomicRing` elements. +/// +/// # Errors +/// +/// Returns an error if challenge sampling or dense conversion fails. +pub fn sample_dense_challenges( + transcript: &mut T, + label: &[u8], + n: usize, + cfg: &SparseChallengeConfig, +) -> Result>, HachiError> +where + F: FieldCore + CanonicalField, + T: Transcript, +{ + (0..n) + .map(|i| { + let sparse = + sparse_challenge_from_transcript::(transcript, label, i as u64, cfg)?; + sparse + .to_dense::() + .map_err(|e| HachiError::InvalidInput(e.to_string())) + }) + .collect() +} diff --git a/src/protocol/commitment/commit.rs b/src/protocol/commitment/commit.rs new file mode 100644 index 00000000..8bfaa21c --- /dev/null +++ b/src/protocol/commitment/commit.rs @@ -0,0 +1,1106 @@ +//! Ring-native §4.1 commitment core implementation. 
+ +use super::config::{ + ensure_block_layout, ensure_matrix_shape_ge, ensure_supported_num_vars, + validate_and_derive_layout, HachiCommitmentLayout, +}; +use super::onehot::{inner_ajtai_onehot_wide, map_onehot_to_sparse_blocks}; +use super::scheme::{CommitWitness, RingCommitmentScheme}; +use super::types::RingCommitment; +#[cfg(feature = "disk-persistence")] +use super::utils::crt_ntt::build_ntt_slots; +use super::utils::crt_ntt::{build_ntt_slot, NttSlotCache}; +use super::utils::flat_matrix::FlatMatrix; +use super::utils::linear::{ + decompose_rows_i8, flatten_i8_blocks, mat_vec_mul_ntt_i8, mat_vec_mul_ntt_single_i8, +}; +use super::utils::matrix::{derive_public_matrix, sample_public_matrix_seed, PublicMatrixSeed}; +use super::CommitmentConfig; +use crate::algebra::fields::wide::HasWide; +use crate::algebra::CyclotomicRing; +use crate::error::HachiError; +#[cfg(feature = "parallel")] +use crate::parallel::*; +use crate::primitives::serialization::{ + Compress, HachiDeserialize, HachiSerialize, SerializationError, Valid, Validate, +}; +use crate::protocol::hachi_poly_ops::OneHotIndex; +use crate::protocol::ring_switch::w_commitment_layout; +use crate::{cfg_into_iter, cfg_iter, CanonicalField, FieldCore, FieldSampling}; +#[cfg(feature = "disk-persistence")] +use std::fs; +use std::io::{Read, Write}; +#[cfg(feature = "disk-persistence")] +use std::path::PathBuf; + +/// Seed-only stage for deterministic setup expansion. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct HachiSetupSeed { + /// Maximum supported variable count. + pub max_num_vars: usize, + /// Runtime commitment layout. + pub layout: HachiCommitmentLayout, + /// Public seed used to derive commitment matrices. + pub public_matrix_seed: PublicMatrixSeed, +} + +/// Expanded setup stage containing coefficient-form matrices stored as +/// D-agnostic flat field-element arrays. 
+/// +/// The same `HachiExpandedSetup` can be viewed at different ring dimensions by +/// calling [`FlatMatrix::view`] or [`FlatMatrix::row`] with the desired +/// const-generic `D`. +#[allow(non_snake_case)] +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct HachiExpandedSetup { + /// Setup seed and runtime layout metadata. + pub seed: HachiSetupSeed, + /// Inner matrix `A`. + pub A: FlatMatrix, + /// Outer matrix `B`. + pub B: FlatMatrix, + /// Prover matrix `D ∈ R_q^{n_D × δ·2^R}` (§4.2). + pub D_mat: FlatMatrix, +} + +/// Prover setup artifact (expanded setup + per-matrix NTT caches). +/// +/// The NTT caches are tied to a specific ring dimension D. +#[allow(non_snake_case)] +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct HachiProverSetup { + /// Expanded matrix stage used by both prover and verifier. + pub expanded: HachiExpandedSetup, + /// NTT cache for the A matrix. + pub ntt_A: NttSlotCache, + /// NTT cache for the B matrix. + pub ntt_B: NttSlotCache, + /// NTT cache for the D matrix. + pub ntt_D: NttSlotCache, +} + +/// Verifier setup artifact derived from prover setup. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct HachiVerifierSetup { + /// Expanded matrix stage used for verification. + pub expanded: HachiExpandedSetup, +} + +impl HachiExpandedSetup { + /// Runtime layout carried by this setup (the max-dimension layout). + pub fn layout(&self) -> HachiCommitmentLayout { + self.seed.layout + } +} + +impl HachiProverSetup { + /// Runtime layout carried by this setup (the max-dimension layout). + pub fn layout(&self) -> HachiCommitmentLayout { + self.expanded.layout() + } + + /// Panic if `layout`'s matrix dimensions exceed this setup's maximums. + /// + /// # Panics + /// + /// Panics if any of `layout`'s matrix widths (inner, outer, D) exceed + /// those of this setup. 
+ pub fn assert_layout_fits(&self, layout: &HachiCommitmentLayout) { + let max = &self.expanded.seed.layout; + assert!( + layout.inner_width <= max.inner_width, + "A matrix too narrow: need {} but setup has {}", + layout.inner_width, + max.inner_width + ); + assert!( + layout.outer_width <= max.outer_width, + "B matrix too narrow: need {} but setup has {}", + layout.outer_width, + max.outer_width + ); + assert!( + layout.d_matrix_width <= max.d_matrix_width, + "D matrix too narrow: need {} but setup has {}", + layout.d_matrix_width, + max.d_matrix_width + ); + } +} + +impl Valid for HachiSetupSeed { + fn check(&self) -> Result<(), SerializationError> { + self.layout.check() + } +} + +impl HachiSerialize for HachiSetupSeed { + fn serialize_with_mode( + &self, + mut writer: W, + compress: Compress, + ) -> Result<(), SerializationError> { + self.max_num_vars + .serialize_with_mode(&mut writer, compress)?; + self.layout.serialize_with_mode(&mut writer, compress)?; + writer.write_all(&self.public_matrix_seed)?; + Ok(()) + } + + fn serialized_size(&self, compress: Compress) -> usize { + self.max_num_vars.serialized_size(compress) + self.layout.serialized_size(compress) + 32 + } +} + +impl HachiDeserialize for HachiSetupSeed { + fn deserialize_with_mode( + mut reader: R, + compress: Compress, + validate: Validate, + ) -> Result { + let max_num_vars = usize::deserialize_with_mode(&mut reader, compress, validate)?; + let layout = HachiCommitmentLayout::deserialize_with_mode(&mut reader, compress, validate)?; + let mut public_matrix_seed = [0u8; 32]; + reader.read_exact(&mut public_matrix_seed)?; + let out = Self { + max_num_vars, + layout, + public_matrix_seed, + }; + if matches!(validate, Validate::Yes) { + out.check()?; + } + Ok(out) + } +} + +impl Valid for HachiExpandedSetup { + fn check(&self) -> Result<(), SerializationError> { + self.seed.check()?; + self.A.check()?; + self.B.check()?; + self.D_mat.check()?; + Ok(()) + } +} + +impl HachiSerialize for 
HachiExpandedSetup { + fn serialize_with_mode( + &self, + mut writer: W, + compress: Compress, + ) -> Result<(), SerializationError> { + self.seed.serialize_with_mode(&mut writer, compress)?; + self.A.serialize_with_mode(&mut writer, compress)?; + self.B.serialize_with_mode(&mut writer, compress)?; + self.D_mat.serialize_with_mode(&mut writer, compress)?; + Ok(()) + } + + fn serialized_size(&self, compress: Compress) -> usize { + self.seed.serialized_size(compress) + + self.A.serialized_size(compress) + + self.B.serialized_size(compress) + + self.D_mat.serialized_size(compress) + } +} + +impl HachiDeserialize for HachiExpandedSetup { + fn deserialize_with_mode( + mut reader: R, + compress: Compress, + validate: Validate, + ) -> Result { + let out = Self { + seed: HachiSetupSeed::deserialize_with_mode(&mut reader, compress, validate)?, + A: FlatMatrix::deserialize_with_mode(&mut reader, compress, validate)?, + B: FlatMatrix::deserialize_with_mode(&mut reader, compress, validate)?, + D_mat: FlatMatrix::deserialize_with_mode(&mut reader, compress, validate)?, + }; + if matches!(validate, Validate::Yes) { + out.check()?; + } + Ok(out) + } +} + +impl Valid for HachiProverSetup { + fn check(&self) -> Result<(), SerializationError> { + self.expanded.check() + } +} + +impl HachiSerialize for HachiProverSetup { + fn serialize_with_mode( + &self, + _writer: W, + _compress: Compress, + ) -> Result<(), SerializationError> { + Err(SerializationError::InvalidData( + "HachiProverSetup contains runtime NTT caches and is not serializable".into(), + )) + } + + fn serialized_size(&self, _compress: Compress) -> usize { + 0 + } +} + +impl Valid for HachiVerifierSetup { + fn check(&self) -> Result<(), SerializationError> { + self.expanded.check() + } +} + +impl HachiSerialize for HachiVerifierSetup { + fn serialize_with_mode( + &self, + writer: W, + compress: Compress, + ) -> Result<(), SerializationError> { + self.expanded.serialize_with_mode(writer, compress) + } + + fn 
serialized_size(&self, compress: Compress) -> usize { + self.expanded.serialized_size(compress) + } +} + +impl HachiDeserialize for HachiVerifierSetup { + fn deserialize_with_mode( + reader: R, + compress: Compress, + validate: Validate, + ) -> Result { + Ok(Self { + expanded: HachiExpandedSetup::deserialize_with_mode(reader, compress, validate)?, + }) + } +} + +#[cfg(feature = "disk-persistence")] +fn get_storage_path(max_num_vars: usize) -> Option { + let cache_directory = if let Ok(local_app_data) = std::env::var("LOCALAPPDATA") { + Some(PathBuf::from(local_app_data)) + } else if let Ok(home) = std::env::var("HOME") { + let mut path = PathBuf::from(&home); + let macos_cache = { + let mut test_path = PathBuf::from(&home); + test_path.push("Library"); + test_path.push("Caches"); + test_path.exists() + }; + if macos_cache { + path.push("Library"); + path.push("Caches"); + } else { + path.push(".cache"); + } + Some(path) + } else { + None + }; + + cache_directory.map(|mut path| { + path.push("hachi"); + path.push(format!("hachi_{max_num_vars}.setup")); + path + }) +} + +#[cfg(feature = "disk-persistence")] +fn save_expanded_setup(setup: &HachiExpandedSetup, max_num_vars: usize) { + let Some(storage_path) = get_storage_path(max_num_vars) else { + tracing::warn!("Could not determine storage directory; skipping setup save"); + return; + }; + + if let Some(parent) = storage_path.parent() { + fs::create_dir_all(parent) + .unwrap_or_else(|e| panic!("Failed to create storage directory: {e}")); + } + + tracing::info!("Saving setup to {}", storage_path.display()); + + let file = fs::File::create(&storage_path) + .unwrap_or_else(|e| panic!("Failed to create setup file: {e}")); + let mut writer = std::io::BufWriter::new(file); + + setup + .serialize_compressed(&mut writer) + .unwrap_or_else(|e| panic!("Failed to serialize setup: {e}")); + + tracing::info!("Successfully saved setup to disk"); +} + +#[cfg(feature = "disk-persistence")] +fn load_expanded_setup( + max_num_vars: 
usize, +) -> Result, HachiError> { + let storage_path = get_storage_path(max_num_vars).ok_or_else(|| { + HachiError::InvalidSetup("Failed to determine storage directory".to_string()) + })?; + + if !storage_path.exists() { + return Err(HachiError::InvalidSetup(format!( + "Setup file not found at {}", + storage_path.display() + ))); + } + + tracing::info!("Loading setup from {}", storage_path.display()); + + let file = fs::File::open(&storage_path) + .map_err(|e| HachiError::InvalidSetup(format!("Failed to open setup file: {e}")))?; + let mut reader = std::io::BufReader::new(file); + + let setup = HachiExpandedSetup::deserialize_compressed(&mut reader) + .map_err(|e| HachiError::InvalidSetup(format!("Failed to deserialize setup: {e}")))?; + + tracing::info!("Loaded setup for max_num_vars={max_num_vars}"); + Ok(setup) +} + +/// Build prover and verifier setup from a pre-existing expanded setup by +/// reconstructing the NTT caches. +#[cfg(feature = "disk-persistence")] +pub(crate) fn setup_from_expanded( + expanded: HachiExpandedSetup, +) -> Result<(HachiProverSetup, HachiVerifierSetup), HachiError> { + let (ntt_a, ntt_b, ntt_d) = build_ntt_slots( + expanded.A.view::(), + expanded.B.view::(), + expanded.D_mat.view::(), + )?; + let prover_setup = HachiProverSetup { + expanded: expanded.clone(), + ntt_A: ntt_a, + ntt_B: ntt_b, + ntt_D: ntt_d, + }; + let verifier_setup = HachiVerifierSetup { expanded }; + Ok((prover_setup, verifier_setup)) +} + +/// Concrete §4.1 commitment core. 
+#[derive(Clone, Copy, Default)] +pub struct HachiCommitmentCore; + +impl RingCommitmentScheme for HachiCommitmentCore +where + F: FieldCore + CanonicalField + FieldSampling + HasWide + Valid, + Cfg: CommitmentConfig, +{ + type ProverSetup = HachiProverSetup; + type VerifierSetup = HachiVerifierSetup; + type Commitment = RingCommitment; + + #[tracing::instrument(skip_all, name = "RingCommitmentScheme::setup")] + fn setup(max_num_vars: usize) -> Result<(Self::ProverSetup, Self::VerifierSetup), HachiError> { + let layout = validate_and_derive_layout::(max_num_vars)?; + ensure_supported_num_vars(max_num_vars, layout.required_num_vars::()?)?; + + #[cfg(feature = "disk-persistence")] + { + match load_expanded_setup::(max_num_vars) { + Ok(expanded) => { + tracing::info!("Loaded setup from disk, rebuilding NTT caches"); + return setup_from_expanded(expanded); + } + Err(HachiError::InvalidSetup(msg)) if msg.contains("not found") => { + tracing::debug!("Setup file not found, will generate new one"); + } + Err(e) => { + panic!("Failed to load setup from disk: {e}"); + } + } + } + + let w_layout = w_commitment_layout::(layout)?; + let a_cols = layout.inner_width.max(w_layout.inner_width); + let b_cols = layout.outer_width.max(w_layout.outer_width); + let d_cols = layout.d_matrix_width.max(w_layout.d_matrix_width); + + let public_matrix_seed = sample_public_matrix_seed(); + let a_matrix = derive_public_matrix::(Cfg::N_A, a_cols, &public_matrix_seed, b"A"); + let b_matrix = derive_public_matrix::(Cfg::N_B, b_cols, &public_matrix_seed, b"B"); + let d_matrix = derive_public_matrix::(Cfg::N_D, d_cols, &public_matrix_seed, b"D"); + + let a_flat = FlatMatrix::from_ring_matrix(&a_matrix); + let b_flat = FlatMatrix::from_ring_matrix(&b_matrix); + let d_flat = FlatMatrix::from_ring_matrix(&d_matrix); + + let ntt_a = build_ntt_slot(a_flat.view::())?; + let ntt_b = build_ntt_slot(b_flat.view::())?; + let ntt_d = build_ntt_slot(d_flat.view::())?; + let expanded = HachiExpandedSetup { + 
seed: HachiSetupSeed { + max_num_vars, + layout, + public_matrix_seed, + }, + A: a_flat, + B: b_flat, + D_mat: d_flat, + }; + + #[cfg(feature = "disk-persistence")] + save_expanded_setup(&expanded, max_num_vars); + + let prover_setup = HachiProverSetup { + expanded: expanded.clone(), + ntt_A: ntt_a, + ntt_B: ntt_b, + ntt_D: ntt_d, + }; + let verifier_setup = HachiVerifierSetup { expanded }; + ensure_matrix_shape_ge::( + &prover_setup.expanded.A, + Cfg::N_A, + layout.inner_width, + "A", + )?; + ensure_matrix_shape_ge::( + &prover_setup.expanded.B, + Cfg::N_B, + layout.outer_width, + "B", + )?; + ensure_matrix_shape_ge::( + &prover_setup.expanded.D_mat, + Cfg::N_D, + layout.d_matrix_width, + "D", + )?; + Ok((prover_setup, verifier_setup)) + } + + fn layout(setup: &Self::ProverSetup) -> Result { + Ok(setup.layout()) + } + + #[tracing::instrument(skip_all, name = "RingCommitmentScheme::commit_ring_blocks")] + fn commit_ring_blocks( + f_blocks: &[Vec>], + setup: &Self::ProverSetup, + ) -> Result, HachiError> { + let layout = setup.layout(); + ensure_supported_num_vars( + setup.expanded.seed.max_num_vars, + layout.required_num_vars::()?, + )?; + ensure_block_layout(f_blocks, layout)?; + ensure_matrix_shape_ge::(&setup.expanded.A, Cfg::N_A, layout.inner_width, "A")?; + ensure_matrix_shape_ge::(&setup.expanded.B, Cfg::N_B, layout.outer_width, "B")?; + + let depth_commit = layout.num_digits_commit; + let depth_open = layout.num_digits_open; + let log_basis = layout.log_basis; + let block_slices: Vec<&[CyclotomicRing]> = + f_blocks.iter().map(|b| b.as_slice()).collect(); + let t_all = mat_vec_mul_ntt_i8(&setup.ntt_A, &block_slices, depth_commit, log_basis); + let t_hat_all: Vec> = cfg_into_iter!(t_all) + .map(|t_i| decompose_rows_i8(&t_i, depth_open, log_basis)) + .collect(); + + let t_hat_flat = flatten_i8_blocks(&t_hat_all); + + let u: Vec> = mat_vec_mul_ntt_single_i8(&setup.ntt_B, &t_hat_flat); + Ok(CommitWitness::new(RingCommitment { u }, t_hat_all)) + } + + 
#[tracing::instrument(skip_all, name = "RingCommitmentScheme::commit_coeffs")] + fn commit_coeffs( + f_coeffs: &[CyclotomicRing], + setup: &Self::ProverSetup, + ) -> Result, HachiError> { + let layout = setup.layout(); + let num_blocks = layout.num_blocks; + let block_len = layout.block_len; + let max_len = num_blocks + .checked_mul(block_len) + .ok_or_else(|| HachiError::InvalidSetup("coefficient length overflow".to_string()))?; + if f_coeffs.len() > max_len { + return Err(HachiError::InvalidSize { + expected: max_len, + actual: f_coeffs.len(), + }); + } + + let depth_commit = layout.num_digits_commit; + let depth_open = layout.num_digits_open; + let log_basis = layout.log_basis; + let coeff_len = f_coeffs.len(); + + let block_slices: Vec<&[CyclotomicRing]> = (0..num_blocks) + .map(|i| { + let start = i * block_len; + if start >= coeff_len { + &[] as &[CyclotomicRing] + } else { + &f_coeffs[start..(start + block_len).min(coeff_len)] + } + }) + .collect(); + + let t_all = mat_vec_mul_ntt_i8(&setup.ntt_A, &block_slices, depth_commit, log_basis); + let t_hat_all: Vec> = cfg_into_iter!(t_all) + .map(|t_i| decompose_rows_i8(&t_i, depth_open, log_basis)) + .collect(); + + let t_hat_flat = flatten_i8_blocks(&t_hat_all); + + let u: Vec> = mat_vec_mul_ntt_single_i8(&setup.ntt_B, &t_hat_flat); + Ok(CommitWitness::new(RingCommitment { u }, t_hat_all)) + } + + #[tracing::instrument(skip_all, name = "RingCommitmentScheme::commit_onehot")] + fn commit_onehot( + onehot_k: usize, + indices: &[Option], + setup: &Self::ProverSetup, + ) -> Result, HachiError> { + let layout = setup.layout(); + ensure_supported_num_vars( + setup.expanded.seed.max_num_vars, + layout.required_num_vars::()?, + )?; + ensure_matrix_shape_ge::(&setup.expanded.A, Cfg::N_A, layout.inner_width, "A")?; + ensure_matrix_shape_ge::(&setup.expanded.B, Cfg::N_B, layout.outer_width, "B")?; + + let sparse_blocks = + map_onehot_to_sparse_blocks(onehot_k, indices, layout.r_vars, layout.m_vars, D)?; + + let depth_commit 
= layout.num_digits_commit; + let depth_open = layout.num_digits_open; + let log_basis = layout.log_basis; + let zero_block_len = Cfg::N_A.checked_mul(depth_open).unwrap(); + let a_view = setup.expanded.A.view::(); + let block_len = layout.block_len; + + let t_hat_all: Vec> = cfg_iter!(sparse_blocks) + .map(|block_entries| { + if block_entries.is_empty() { + vec![[0i8; D]; zero_block_len] + } else { + let t_i = + inner_ajtai_onehot_wide(&a_view, block_entries, block_len, depth_commit); + decompose_rows_i8(&t_i, depth_open, log_basis) + } + }) + .collect(); + + let t_hat_flat = flatten_i8_blocks(&t_hat_all); + + let u: Vec> = mat_vec_mul_ntt_single_i8(&setup.ntt_B, &t_hat_flat); + Ok(CommitWitness::new(RingCommitment { u }, t_hat_all)) + } +} + +impl HachiCommitmentCore { + fn layout_envelope( + max_num_vars: usize, + inner_width: usize, + outer_width: usize, + d_matrix_width: usize, + preferred_r_vars: usize, + num_digits_open: usize, + num_digits_fold: usize, + log_basis: u32, + ) -> Result { + let alpha = D.trailing_zeros() as usize; + let outer_vars = max_num_vars.checked_sub(alpha).ok_or_else(|| { + HachiError::InvalidSetup("max_num_vars is smaller than alpha".to_string()) + })?; + let r_vars = preferred_r_vars.min(outer_vars); + let m_vars = outer_vars - r_vars; + let num_blocks = 1usize + .checked_shl(r_vars as u32) + .ok_or_else(|| HachiError::InvalidSetup("num_blocks overflow".to_string()))?; + let block_len = 1usize + .checked_shl(m_vars as u32) + .ok_or_else(|| HachiError::InvalidSetup("block_len overflow".to_string()))?; + + Ok(HachiCommitmentLayout { + m_vars, + r_vars, + num_blocks, + block_len, + inner_width, + outer_width, + d_matrix_width, + // Setup metadata only tracks width envelopes; runtime commits/proofs + // carry their own exact decomposition parameters. 
+ num_digits_commit: 1, + num_digits_open, + num_digits_fold, + log_basis, + }) + } + + /// Create a setup with a caller-specified layout, bypassing + /// `CommitmentConfig::commitment_layout`. + /// + /// Use this when the desired `(m_vars, r_vars)` split differs from what + /// the config's heuristic would choose (e.g. mega-polynomial commitments + /// where each sub-polynomial occupies one block). + /// + /// # Errors + /// + /// Returns `HachiError` on invalid layout or matrix generation failures. + pub fn setup_with_layout( + layout: HachiCommitmentLayout, + ) -> Result<(HachiProverSetup, HachiVerifierSetup), HachiError> + where + F: FieldCore + CanonicalField + FieldSampling, + Cfg: CommitmentConfig, + { + let max_num_vars = layout.required_num_vars::()?; + let public_matrix_seed = sample_public_matrix_seed(); + Self::setup_with_layout_and_seed::(layout, max_num_vars, public_matrix_seed) + } + + /// Create a setup that supports any of the provided runtime layouts. + /// + /// This sizes the public matrices from the exact per-layout maxima + /// (including recursive `w` commitments) instead of inflating through a + /// synthetic max layout. + /// + /// # Errors + /// + /// Returns `HachiError` if `layouts` is empty, uses inconsistent + /// decomposition parameters, or matrix generation fails. 
+ pub fn setup_with_layouts( + layouts: &[HachiCommitmentLayout], + ) -> Result<(HachiProverSetup, HachiVerifierSetup), HachiError> + where + F: FieldCore + CanonicalField + FieldSampling, + Cfg: CommitmentConfig, + { + let Some((&first_layout, _)) = layouts.split_first() else { + return Err(HachiError::InvalidSetup( + "setup_with_layouts requires at least one layout".to_string(), + )); + }; + + let mut max_num_vars = 0usize; + let mut max_inner_width = 0usize; + let mut max_outer_width = 0usize; + let mut max_d_matrix_width = 0usize; + let mut max_r_vars = 0usize; + let mut max_num_digits_open = 0usize; + let mut max_num_digits_fold = 0usize; + + for &layout in layouts { + if layout.log_basis != first_layout.log_basis { + return Err(HachiError::InvalidSetup(format!( + "setup_with_layouts requires a shared log_basis (expected {}, got {})", + first_layout.log_basis, layout.log_basis + ))); + } + + max_num_vars = max_num_vars.max(layout.required_num_vars::()?); + max_inner_width = max_inner_width.max(layout.inner_width); + max_outer_width = max_outer_width.max(layout.outer_width); + max_d_matrix_width = max_d_matrix_width.max(layout.d_matrix_width); + max_r_vars = max_r_vars.max(layout.r_vars); + max_num_digits_open = max_num_digits_open.max(layout.num_digits_open); + max_num_digits_fold = max_num_digits_fold.max(layout.num_digits_fold); + + let w_layout = w_commitment_layout::(layout)?; + if std::env::var_os("HACHI_SETUP_DIAGNOSTICS").is_some() { + eprintln!("[hachi setup] layout={layout:?}"); + eprintln!("[hachi setup] w_layout={w_layout:?}"); + } + max_inner_width = max_inner_width.max(w_layout.inner_width); + max_outer_width = max_outer_width.max(w_layout.outer_width); + max_d_matrix_width = max_d_matrix_width.max(w_layout.d_matrix_width); + max_r_vars = max_r_vars.max(w_layout.r_vars); + max_num_digits_open = max_num_digits_open.max(w_layout.num_digits_open); + max_num_digits_fold = max_num_digits_fold.max(w_layout.num_digits_fold); + } + + let envelope_layout = 
Self::layout_envelope::( + max_num_vars, + max_inner_width, + max_outer_width, + max_d_matrix_width, + max_r_vars, + max_num_digits_open, + max_num_digits_fold, + first_layout.log_basis, + )?; + if std::env::var_os("HACHI_SETUP_DIAGNOSTICS").is_some() { + eprintln!("[hachi setup] envelope_layout={envelope_layout:?}"); + eprintln!("[hachi setup] max_num_vars={max_num_vars}"); + } + let public_matrix_seed = sample_public_matrix_seed(); + Self::setup_with_matrix_widths_and_seed::( + envelope_layout, + max_num_vars, + public_matrix_seed, + max_inner_width, + max_outer_width, + max_d_matrix_width, + ) + } + + /// Like `setup_with_layout` but reuses an existing setup's random seed and + /// A matrix (which depends only on `m_vars`). Only regenerates B and D + /// matrices for the new `r_vars`. + /// + /// Use this when creating a mega-polynomial setup that shares `m_vars` with + /// an individual polynomial setup — avoids re-deriving and NTT-transforming + /// the A matrix. + /// + /// # Errors + /// + /// Returns `HachiError` if the new layout is incompatible with the existing + /// setup or matrix shapes are inconsistent. 
+ pub fn setup_from_existing( + existing: &HachiExpandedSetup, + new_r_vars: usize, + ) -> Result<(HachiProverSetup, HachiVerifierSetup), HachiError> + where + F: FieldCore + CanonicalField + FieldSampling, + Cfg: CommitmentConfig, + { + let old_layout = existing.seed.layout; + let new_layout = HachiCommitmentLayout::new::( + old_layout.m_vars, + new_r_vars, + &Cfg::decomposition(), + )?; + + if new_layout.inner_width != old_layout.inner_width { + return Err(HachiError::InvalidSetup( + "setup_from_existing requires matching m_vars/inner_width".to_string(), + )); + } + + let w_layout = w_commitment_layout::(new_layout)?; + let a_width = existing.A.first_row_len::(); + if a_width < w_layout.inner_width { + return Err(HachiError::InvalidSetup(format!( + "existing A width {a_width} < w inner_width {}", + w_layout.inner_width + ))); + } + let b_cols = new_layout.outer_width.max(w_layout.outer_width); + + let max_num_vars = new_layout.required_num_vars::()?; + let seed = existing.seed.public_matrix_seed; + + let d_cols = new_layout.d_matrix_width.max(w_layout.d_matrix_width); + let b_matrix = derive_public_matrix::(Cfg::N_B, b_cols, &seed, b"B"); + let d_matrix = derive_public_matrix::(Cfg::N_D, d_cols, &seed, b"D"); + + let b_flat = FlatMatrix::from_ring_matrix(&b_matrix); + let d_flat = FlatMatrix::from_ring_matrix(&d_matrix); + + let ntt_a = build_ntt_slot(existing.A.view::())?; + let ntt_b = build_ntt_slot(b_flat.view::())?; + let ntt_d = build_ntt_slot(d_flat.view::())?; + let expanded = HachiExpandedSetup { + seed: HachiSetupSeed { + max_num_vars, + layout: new_layout, + public_matrix_seed: seed, + }, + A: existing.A.clone(), + B: b_flat, + D_mat: d_flat, + }; + let prover_setup = HachiProverSetup { + expanded: expanded.clone(), + ntt_A: ntt_a, + ntt_B: ntt_b, + ntt_D: ntt_d, + }; + let verifier_setup = HachiVerifierSetup { expanded }; + Ok((prover_setup, verifier_setup)) + } + + fn setup_with_layout_and_seed( + layout: HachiCommitmentLayout, + max_num_vars: usize, 
+ public_matrix_seed: PublicMatrixSeed, + ) -> Result<(HachiProverSetup, HachiVerifierSetup), HachiError> + where + F: FieldCore + CanonicalField + FieldSampling, + Cfg: CommitmentConfig, + { + let w_layout = w_commitment_layout::(layout)?; + let a_cols = layout.inner_width.max(w_layout.inner_width); + let b_cols = layout.outer_width.max(w_layout.outer_width); + let d_cols = layout.d_matrix_width.max(w_layout.d_matrix_width); + + Self::setup_with_matrix_widths_and_seed::( + layout, + max_num_vars, + public_matrix_seed, + a_cols, + b_cols, + d_cols, + ) + } + + fn setup_with_matrix_widths_and_seed( + layout: HachiCommitmentLayout, + max_num_vars: usize, + public_matrix_seed: PublicMatrixSeed, + a_cols: usize, + b_cols: usize, + d_cols: usize, + ) -> Result<(HachiProverSetup, HachiVerifierSetup), HachiError> + where + F: FieldCore + CanonicalField + FieldSampling, + Cfg: CommitmentConfig, + { + if std::env::var_os("HACHI_SETUP_DIAGNOSTICS").is_some() { + let ring_bytes = std::mem::size_of::>(); + let a_raw_mb = (Cfg::N_A * a_cols * ring_bytes) as f64 / (1024.0_f64 * 1024.0_f64); + let b_raw_mb = (Cfg::N_B * b_cols * ring_bytes) as f64 / (1024.0_f64 * 1024.0_f64); + let d_raw_mb = (Cfg::N_D * d_cols * ring_bytes) as f64 / (1024.0_f64 * 1024.0_f64); + eprintln!( + "[hachi setup] a_cols={a_cols}, b_cols={b_cols}, d_cols={d_cols}, ring_bytes={ring_bytes}" + ); + eprintln!( + "[hachi setup] raw_matrix_mb: A={a_raw_mb:.1}, B={b_raw_mb:.1}, D={d_raw_mb:.1}, total={:.1}", + a_raw_mb + b_raw_mb + d_raw_mb + ); + } + let a_matrix = derive_public_matrix::(Cfg::N_A, a_cols, &public_matrix_seed, b"A"); + let b_matrix = derive_public_matrix::(Cfg::N_B, b_cols, &public_matrix_seed, b"B"); + let d_matrix = derive_public_matrix::(Cfg::N_D, d_cols, &public_matrix_seed, b"D"); + + let a_flat = FlatMatrix::from_ring_matrix(&a_matrix); + let b_flat = FlatMatrix::from_ring_matrix(&b_matrix); + let d_flat = FlatMatrix::from_ring_matrix(&d_matrix); + + let ntt_a = 
build_ntt_slot(a_flat.view::())?; + let ntt_b = build_ntt_slot(b_flat.view::())?; + let ntt_d = build_ntt_slot(d_flat.view::())?; + let expanded = HachiExpandedSetup { + seed: HachiSetupSeed { + max_num_vars, + layout, + public_matrix_seed, + }, + A: a_flat, + B: b_flat, + D_mat: d_flat, + }; + let prover_setup = HachiProverSetup { + expanded: expanded.clone(), + ntt_A: ntt_a, + ntt_B: ntt_b, + ntt_D: ntt_d, + }; + let verifier_setup = HachiVerifierSetup { expanded }; + ensure_matrix_shape_ge::( + &prover_setup.expanded.A, + Cfg::N_A, + layout.inner_width, + "A", + )?; + ensure_matrix_shape_ge::( + &prover_setup.expanded.B, + Cfg::N_B, + layout.outer_width, + "B", + )?; + ensure_matrix_shape_ge::( + &prover_setup.expanded.D_mat, + Cfg::N_D, + layout.d_matrix_width, + "D", + )?; + Ok((prover_setup, verifier_setup)) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::primitives::{HachiDeserialize, HachiSerialize}; + use crate::test_utils::{TinyConfig, F as TestF}; + + #[test] + fn expanded_setup_roundtrips_and_derives_same_verifier() { + const TEST_D: usize = 64; + let (prover_setup, verifier_setup) = + >::setup(16) + .unwrap(); + + let mut bytes = Vec::new(); + prover_setup + .expanded + .serialize_compressed(&mut bytes) + .unwrap(); + let decoded = HachiExpandedSetup::::deserialize_compressed(&bytes[..]).unwrap(); + + assert_eq!(decoded, prover_setup.expanded); + + let derived_verifier = HachiVerifierSetup { + expanded: decoded.clone(), + }; + assert_eq!(derived_verifier, verifier_setup); + } + + #[test] + fn setup_with_layouts_uses_exact_width_envelope() { + const TEST_D: usize = 64; + + let layout_a = + HachiCommitmentLayout::new::(4, 2, &TinyConfig::decomposition()).unwrap(); + let layout_b = + HachiCommitmentLayout::new::(1, 6, &TinyConfig::decomposition()).unwrap(); + let w_layout_a = w_commitment_layout::(layout_a).unwrap(); + let w_layout_b = w_commitment_layout::(layout_b).unwrap(); + + let expected_inner = [ + layout_a.inner_width, + 
layout_b.inner_width, + w_layout_a.inner_width, + w_layout_b.inner_width, + ] + .into_iter() + .max() + .unwrap(); + let expected_outer = [ + layout_a.outer_width, + layout_b.outer_width, + w_layout_a.outer_width, + w_layout_b.outer_width, + ] + .into_iter() + .max() + .unwrap(); + let expected_d = [ + layout_a.d_matrix_width, + layout_b.d_matrix_width, + w_layout_a.d_matrix_width, + w_layout_b.d_matrix_width, + ] + .into_iter() + .max() + .unwrap(); + let expected_max_num_vars = [ + layout_a.required_num_vars::().unwrap(), + layout_b.required_num_vars::().unwrap(), + ] + .into_iter() + .max() + .unwrap(); + + let (setup, _) = HachiCommitmentCore::setup_with_layouts::(&[ + layout_a, layout_b, + ]) + .unwrap(); + let envelope = setup.layout(); + + assert_eq!(setup.expanded.seed.max_num_vars, expected_max_num_vars); + assert_eq!(envelope.inner_width, expected_inner); + assert_eq!(envelope.outer_width, expected_outer); + assert_eq!(envelope.d_matrix_width, expected_d); + assert_eq!(setup.expanded.A.first_row_len::(), expected_inner); + assert_eq!(setup.expanded.B.first_row_len::(), expected_outer); + assert_eq!(setup.expanded.D_mat.first_row_len::(), expected_d); + } + + #[cfg(feature = "disk-persistence")] + mod disk_persistence { + use super::*; + use std::fs; + + fn cleanup_setup_file(max_num_vars: usize) { + if let Some(path) = get_storage_path(max_num_vars) { + let _ = fs::remove_file(path); + } + } + + #[test] + fn save_and_load_roundtrips() { + const TEST_D: usize = 64; + const MAX_VARS: usize = 100; + + cleanup_setup_file(MAX_VARS); + + let (prover_setup, _) = + >::setup( + MAX_VARS, + ) + .unwrap(); + + let loaded = load_expanded_setup::(MAX_VARS).unwrap(); + assert_eq!(loaded, prover_setup.expanded); + + cleanup_setup_file(MAX_VARS); + } + + #[test] + fn setup_uses_cache_on_second_call() { + const TEST_D: usize = 64; + const MAX_VARS: usize = 101; + + cleanup_setup_file(MAX_VARS); + + let (first, _) = + >::setup( + MAX_VARS, + ) + .unwrap(); + + let (second, 
_) = + >::setup( + MAX_VARS, + ) + .unwrap(); + + assert_eq!(first.expanded, second.expanded); + + cleanup_setup_file(MAX_VARS); + } + + #[test] + fn ntt_caches_rebuilt_correctly_from_disk() { + use crate::algebra::CyclotomicRing; + + const TEST_D: usize = 64; + const MAX_VARS: usize = 102; + + cleanup_setup_file(MAX_VARS); + + let (fresh_setup, _) = + >::setup( + MAX_VARS, + ) + .unwrap(); + + let loaded_expanded = load_expanded_setup::(MAX_VARS).unwrap(); + let (disk_setup, _) = setup_from_expanded::(loaded_expanded).unwrap(); + + let layout = fresh_setup.layout(); + let num_coeffs = layout.num_blocks * layout.block_len; + let coeffs = vec![CyclotomicRing::::zero(); num_coeffs]; + + let fresh_commit = >::commit_coeffs(&coeffs, &fresh_setup) + .unwrap(); + let disk_commit = >::commit_coeffs(&coeffs, &disk_setup) + .unwrap(); + + assert_eq!(fresh_commit.commitment, disk_commit.commitment); + + cleanup_setup_file(MAX_VARS); + } + } +} diff --git a/src/protocol/commitment/config.rs b/src/protocol/commitment/config.rs new file mode 100644 index 00000000..7c666f49 --- /dev/null +++ b/src/protocol/commitment/config.rs @@ -0,0 +1,794 @@ +//! Configuration presets for ring-native commitment construction. + +use super::utils::flat_matrix::FlatMatrix; +use super::utils::math::checked_pow2; +use crate::algebra::ring::CyclotomicRing; +use crate::error::HachiError; +use crate::primitives::serialization::{ + Compress, HachiDeserialize, HachiSerialize, SerializationError, Valid, Validate, +}; +use crate::FieldCore; +use std::io::{Read, Write}; + +/// Parameters controlling the gadget decomposition depth (called δ in the paper). +/// +/// The gadget base is `b = 2^log_basis`. Each ring coefficient with centered +/// magnitude fitting in `log_commit_bound` bits is decomposed into +/// `ceil(log_commit_bound / log_basis)` balanced digits in `[-b/2, b/2)`. 
+/// +/// Smaller `log_commit_bound` (when polynomial coefficients are known to be +/// small) yields fewer decomposition levels, which proportionally shrinks the +/// witness vector, the commitment matrices, and the proving cost. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct DecompositionParams { + /// Base-2 logarithm of the gadget base (e.g., 3 for base-8 digits in [-4, 3]). + pub log_basis: u32, + + /// Bit-width of the largest coefficient that the *commitment* decomposition + /// must represent. Controls the commitment-side decomposition depth (δ in + /// the paper): `num_digits = ceil(log_commit_bound / log_basis)`. + /// + /// The centered representation maps each coefficient `c ∈ [0, q)` to the + /// signed value in `(-q/2, q/2]`. A value of `k` means the signed magnitude + /// fits in `k` bits, i.e., lies in `[-2^(k-1), 2^(k-1) - 1]`. + /// + /// Examples: + /// - Binary (0/1) polynomials: 1 + /// - Already range-checked digits in `[-b/2, b/2)`: `log_basis` (one digit) + /// - Arbitrary Fp128 elements: 128 + pub log_commit_bound: u32, + + /// Bit-width of the largest coefficient that the *opening* decomposition + /// must represent (ŵ = G⁻¹(w_folded)). + /// + /// During opening, `fold_blocks` computes inner products with arbitrary + /// field-element weights, so the result always has full-field-size + /// coefficients regardless of the original `log_commit_bound`. When `None`, + /// defaults to `log_commit_bound` (correct when `log_commit_bound` already + /// covers the full field, e.g. 128). Set to the field modulus bit-width + /// when `log_commit_bound` is smaller (e.g. for recursive w commitments + /// where entries are small but fold products are not). + pub log_open_bound: Option, +} + +/// Compute the gadget decomposition depth (δ in the paper) from a +/// coefficient bit-width bound. +/// +/// Returns `ceil(log_bound / log_basis)`, with an extra level when the +/// balanced-digit range would not cover the full bound. 
+/// +/// # Panics +/// +/// Panics if `log_basis` is 0 or >= 128. +pub fn compute_num_digits(log_bound: u32, log_basis: u32) -> usize { + assert!(log_basis > 0 && log_basis < 128, "invalid log_basis"); + if log_bound == 0 { + return 1; + } + let mut levels = (log_bound as usize).div_ceil(log_basis as usize); + + // When levels * log_basis > log_bound (i.e., not exactly aligned), the + // balanced digit range (b/2-1) * (b^levels - 1)/(b-1) always exceeds + // 2^(log_bound-1) for b >= 4 (log_basis >= 2). Only check when aligned. + let total_bits = (levels as u32).saturating_mul(log_basis); + if total_bits <= log_bound { + let b: u128 = 1u128 << log_basis; + let half_b_minus_1 = b / 2 - 1; + let b_minus_1 = b - 1; + let mut b_pow = 1u128; + for _ in 0..levels { + b_pow = b_pow.saturating_mul(b); + } + let max_positive = half_b_minus_1.saturating_mul(b_pow.saturating_sub(1) / b_minus_1); + let required = if log_bound > 128 { + u128::MAX / 2 + } else if log_bound == 0 { + 0 + } else { + (1u128 << (log_bound - 1)).saturating_sub(1) + }; + if max_positive < required { + levels += 1; + } + } + levels.max(1) +} + +/// Compute the decomposition depth for the folded witness `z_pre` +/// (τ in the paper). +/// +/// The folded witness satisfies `||z_pre||_inf <= β` where +/// `β = 2^r_vars * challenge_weight * 2^(log_basis - 1)`. +/// Returns enough gadget levels to represent values up to `β`. +pub fn compute_num_digits_fold(r_vars: usize, challenge_weight: usize, log_basis: u32) -> usize { + let shift = r_vars + (log_basis as usize) - 1; + if shift >= 127 || challenge_weight == 0 { + return compute_num_digits(128, log_basis); + } + let beta = (challenge_weight as u128).saturating_mul(1u128 << shift); + if beta == 0 { + return 1; + } + let log_beta = 128 - beta.leading_zeros(); + compute_num_digits(log_beta, log_basis) +} + +/// Find the `(m_vars, r_vars)` split that minimizes the level-0 +/// witness-to-polynomial ratio for a given config. 
+/// +/// The witness ring element count is dominated by: +/// ```text +/// w ≈ 2^r · (δ_open + N_A · δ_commit) + 2^m · δ_commit · δ_fold(r) +/// ``` +/// Multiplying the ratio by `2^(m+r)` (constant for fixed `reduced_vars`) +/// gives an equivalent integer cost: +/// ```text +/// C1 · 2^r + δ_commit · δ_fold(r) · 2^m +/// ``` +/// where `C1 = δ_open + N_A · δ_commit`. This function searches all valid +/// `(m, r)` pairs for the minimum using pure integer arithmetic (no +/// floating-point), so it is safe to run inside a zkVM guest. +/// +/// For the full-field config (`δ_commit = 43`), z_pre dominates and the +/// result is near-balanced (`m ≈ r`). For narrow configs (`δ_commit = 1`), +/// the w_hat/t_hat term matters more and the result skews to `m ≈ r + 4`. +pub fn optimal_m_r_split(reduced_vars: usize) -> (usize, usize) { + // Guard: for S >= 53, shifts could overflow u64. Fall back to balanced + // split (this threshold is far beyond any practical polynomial size). + if reduced_vars <= 2 || reduced_vars >= 53 { + let r = reduced_vars / 2; + return (reduced_vars - r, r); + } + + let decomp = Cfg::decomposition(); + let open_bound = decomp.log_open_bound.unwrap_or(decomp.log_commit_bound); + let delta_open = compute_num_digits(open_bound, decomp.log_basis) as u64; + let delta_commit = compute_num_digits(decomp.log_commit_bound, decomp.log_basis) as u64; + let c1 = delta_open + Cfg::N_A as u64 * delta_commit; + + let mut best_r = reduced_vars / 2; + let mut best_cost = u64::MAX; + + for r in 1..reduced_vars { + let m = reduced_vars - r; + let delta_fold = compute_num_digits_fold( + r, + Cfg::challenge_weight_for_ring_dim(Cfg::D), + decomp.log_basis, + ) as u64; + let cost = c1 * (1u64 << r) + delta_commit * delta_fold * (1u64 << m); + if cost < best_cost { + best_cost = cost; + best_r = r; + } + } + + (reduced_vars - best_r, best_r) +} + +/// Runtime commitment layout authority for ring-native commitments. 
+#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct HachiCommitmentLayout { + /// Number of variables inside each committed block (`2^m_vars` entries). + pub m_vars: usize, + /// Number of block-select variables (`2^r_vars` blocks). + pub r_vars: usize, + /// Number of committed blocks (`2^r_vars`). + pub num_blocks: usize, + /// Number of ring elements per block (`2^m_vars`). + pub block_len: usize, + /// Width of inner matrix `A` (`block_len * num_digits_commit`). + pub inner_width: usize, + /// Width of outer matrix `B` (`n_a * num_digits_open * num_blocks`). + pub outer_width: usize, + /// Width of prover matrix `D` (`num_digits_open * num_blocks`). + pub d_matrix_width: usize, + /// Number of gadget decomposition levels for commitment-time coefficients + /// (δ_commit in the paper). Controls how the original polynomial + /// coefficients are decomposed into balanced base-b digits for the Ajtai + /// commitment. + pub num_digits_commit: usize, + /// Number of gadget decomposition levels for opening-time folded + /// evaluations (δ_open in the paper). Folding inner-products with + /// arbitrary field-element weights produces full-field-size coefficients, + /// so this equals `num_digits_commit` when `log_commit_bound` covers + /// the full field, and is larger otherwise (e.g. recursive w witnesses). + pub num_digits_open: usize, + /// Number of gadget decomposition levels for the folded witness `z_pre` + /// (τ in the paper). Derived from the L∞ bound on `z_pre`. + pub num_digits_fold: usize, + /// Base-2 logarithm of gadget decomposition base. + pub log_basis: u32, +} + +impl HachiCommitmentLayout { + /// Build a layout from `(m_vars, r_vars)`, config constants, and decomposition + /// parameters. + /// + /// `num_digits_fold` (τ) is auto-derived from the beta bound + /// (`r_vars`, `challenge_weight`, `log_basis`). + /// + /// # Errors + /// + /// Returns an error when powers or derived widths overflow. 
+ pub fn new( + m_vars: usize, + r_vars: usize, + decomp: &DecompositionParams, + ) -> Result { + let depth_commit = compute_num_digits(decomp.log_commit_bound, decomp.log_basis); + let open_bound = decomp.log_open_bound.unwrap_or(decomp.log_commit_bound); + let depth_open = compute_num_digits(open_bound, decomp.log_basis); + let depth_fold = compute_num_digits_fold( + r_vars, + Cfg::challenge_weight_for_ring_dim(Cfg::D), + decomp.log_basis, + ); + Self::new_with_decomp( + m_vars, + r_vars, + Cfg::N_A, + depth_commit, + depth_open, + depth_fold, + decomp.log_basis, + ) + } + + /// Build a layout from explicit decomposition parameters (no config trait needed). + /// + /// # Errors + /// + /// Returns an error when parameters are invalid or derived widths overflow. + pub fn new_with_decomp( + m_vars: usize, + r_vars: usize, + n_a: usize, + num_digits_commit: usize, + num_digits_open: usize, + num_digits_fold: usize, + log_basis: u32, + ) -> Result { + if log_basis == 0 || log_basis >= 128 { + return Err(HachiError::InvalidSetup("invalid log_basis".to_string())); + } + let num_blocks = checked_pow2(r_vars)?; + let block_len = checked_pow2(m_vars)?; + let inner_width = block_len + .checked_mul(num_digits_commit) + .ok_or_else(|| HachiError::InvalidSetup("inner width overflow".to_string()))?; + let outer_width = n_a + .checked_mul(num_digits_open) + .and_then(|x| x.checked_mul(num_blocks)) + .ok_or_else(|| HachiError::InvalidSetup("outer width overflow".to_string()))?; + let d_matrix_width = num_digits_open + .checked_mul(num_blocks) + .ok_or_else(|| HachiError::InvalidSetup("D-matrix width overflow".to_string()))?; + Ok(Self { + m_vars, + r_vars, + num_blocks, + block_len, + inner_width, + outer_width, + d_matrix_width, + num_digits_commit, + num_digits_open, + num_digits_fold, + log_basis, + }) + } + + /// Total number of outer variables consumed by ring coefficients. + /// + /// # Errors + /// + /// Returns an error if the variable count overflows. 
+ pub fn outer_vars(&self) -> Result { + self.m_vars + .checked_add(self.r_vars) + .ok_or_else(|| HachiError::InvalidSetup("variable count overflow".to_string())) + } + + /// Required polynomial variable count for this layout (`outer + alpha`). + /// + /// # Errors + /// + /// Returns an error if the variable count overflows. + pub fn required_num_vars(&self) -> Result { + let alpha = D.trailing_zeros() as usize; + self.outer_vars()? + .checked_add(alpha) + .ok_or_else(|| HachiError::InvalidSetup("variable count overflow".to_string())) + } +} + +impl Valid for HachiCommitmentLayout { + fn check(&self) -> Result<(), SerializationError> { + if self.num_blocks == 0 || self.block_len == 0 { + return Err(SerializationError::InvalidData( + "invalid zero block layout".to_string(), + )); + } + Ok(()) + } +} + +impl HachiSerialize for HachiCommitmentLayout { + fn serialize_with_mode( + &self, + mut writer: W, + compress: Compress, + ) -> Result<(), SerializationError> { + self.m_vars.serialize_with_mode(&mut writer, compress)?; + self.r_vars.serialize_with_mode(&mut writer, compress)?; + self.num_blocks.serialize_with_mode(&mut writer, compress)?; + self.block_len.serialize_with_mode(&mut writer, compress)?; + self.inner_width + .serialize_with_mode(&mut writer, compress)?; + self.outer_width + .serialize_with_mode(&mut writer, compress)?; + self.d_matrix_width + .serialize_with_mode(&mut writer, compress)?; + self.num_digits_commit + .serialize_with_mode(&mut writer, compress)?; + self.num_digits_open + .serialize_with_mode(&mut writer, compress)?; + self.num_digits_fold + .serialize_with_mode(&mut writer, compress)?; + (self.log_basis as usize).serialize_with_mode(&mut writer, compress)?; + Ok(()) + } + + fn serialized_size(&self, compress: Compress) -> usize { + self.m_vars.serialized_size(compress) + + self.r_vars.serialized_size(compress) + + self.num_blocks.serialized_size(compress) + + self.block_len.serialized_size(compress) + + 
self.inner_width.serialized_size(compress) + + self.outer_width.serialized_size(compress) + + self.d_matrix_width.serialized_size(compress) + + self.num_digits_commit.serialized_size(compress) + + self.num_digits_open.serialized_size(compress) + + self.num_digits_fold.serialized_size(compress) + + (self.log_basis as usize).serialized_size(compress) + } +} + +impl HachiDeserialize for HachiCommitmentLayout { + fn deserialize_with_mode( + mut reader: R, + compress: Compress, + validate: Validate, + ) -> Result { + let out = Self { + m_vars: usize::deserialize_with_mode(&mut reader, compress, validate)?, + r_vars: usize::deserialize_with_mode(&mut reader, compress, validate)?, + num_blocks: usize::deserialize_with_mode(&mut reader, compress, validate)?, + block_len: usize::deserialize_with_mode(&mut reader, compress, validate)?, + inner_width: usize::deserialize_with_mode(&mut reader, compress, validate)?, + outer_width: usize::deserialize_with_mode(&mut reader, compress, validate)?, + d_matrix_width: usize::deserialize_with_mode(&mut reader, compress, validate)?, + num_digits_commit: usize::deserialize_with_mode(&mut reader, compress, validate)?, + num_digits_open: usize::deserialize_with_mode(&mut reader, compress, validate)?, + num_digits_fold: usize::deserialize_with_mode(&mut reader, compress, validate)?, + log_basis: usize::deserialize_with_mode(&mut reader, compress, validate)? as u32, + }; + if matches!(validate, Validate::Yes) { + out.check()?; + } + Ok(out) + } +} + +/// Parameter bundle for the ring-native commitment core (§4.1–§4.2). +/// +/// Security parameters (`N_A`, `N_B`, `N_D`, `CHALLENGE_WEIGHT`) are +/// compile-time constants fixed for a given security level. Decomposition +/// parameters (gadget depths, `log_basis`) are runtime values derived from +/// [`DecompositionParams`] and live in [`HachiCommitmentLayout`]. +pub trait CommitmentConfig: Clone + Send + Sync + 'static { + /// Ring degree used by `CyclotomicRing`. 
+ const D: usize; + /// Inner Ajtai matrix row count. + const N_A: usize; + /// Outer commitment matrix row count. + const N_B: usize; + /// Prover commitment matrix `D` row count (§4.2). + const N_D: usize; + /// Hamming weight of sparse challenges (`ω` in the paper). + const CHALLENGE_WEIGHT: usize; + + /// Decomposition parameters (gadget base and coefficient bounds). + fn decomposition() -> DecompositionParams; + + /// Choose the runtime commitment layout for `max_num_vars`. + /// + /// # Errors + /// + /// Returns an error if `max_num_vars` does not admit a valid layout. + fn commitment_layout(max_num_vars: usize) -> Result; + + /// Runtime L∞ bound for `z` (`β`) used by stage-1 folding checks. + /// + /// # Errors + /// + /// Returns an error on invalid parameters or arithmetic overflow. + fn beta_bound(layout: HachiCommitmentLayout) -> Result { + beta_linf_fold_bound( + layout.r_vars, + Self::challenge_weight_for_ring_dim(Self::D), + layout.log_basis, + ) + } + + /// Ring dimension to use at a given fold level. + /// + /// `level` is 0-indexed (level 0 is the initial polynomial). + /// `_w_num_vars` is the number of variables in the witness at this level. + /// + /// The default implementation returns `Self::D` at all levels (constant D). + /// Override for decreasing-D schedules. + fn d_at_level(_level: usize, _w_num_vars: usize) -> usize { + Self::D + } + + /// Module rank (inner Ajtai row count) at a given fold level. + /// + /// Must satisfy `d_at_level(level) * n_a_at_level(level) >= security_dim` + /// for the target security level. The default returns `Self::N_A` at all levels. + fn n_a_at_level(_level: usize) -> usize { + Self::N_A + } + + /// Challenge weight (Hamming weight ω) appropriate for ring dimension `d`. + /// + /// The default returns `Self::CHALLENGE_WEIGHT` for any `d`, which is + /// correct for constant-D configs. 
Override for varying-D schedules where + /// the optimal weight depends on the ring dimension (e.g., to maintain + /// ≥128 bits of challenge entropy as D decreases). + fn challenge_weight_for_ring_dim(_d: usize) -> usize { + Self::CHALLENGE_WEIGHT + } +} + +/// Deterministic upper bound for the stage-1 folded-witness infinity norm. +/// +/// This encodes the bound used in `QuadraticEquation::compute_z_hat`: +/// `||z||_inf <= 2^R * ω * (b/2)` where `b = 2^LOG_BASIS`. +/// +/// # Errors +/// +/// Returns an error when parameters are out of range or intermediate products +/// overflow `u128`. +pub(crate) fn beta_linf_fold_bound( + r: usize, + challenge_weight: usize, + log_basis: u32, +) -> Result { + if !(1..128).contains(&log_basis) { + return Err(HachiError::InvalidSetup("invalid LOG_BASIS".to_string())); + } + if r >= 128 { + return Err(HachiError::InvalidSetup("r_vars must be < 128".to_string())); + } + + let blocks = 1u128 << r; + let b = 1u128 << log_basis; + let half_b = b / 2; + + let term = blocks + .checked_mul(challenge_weight as u128) + .ok_or_else(|| HachiError::InvalidSetup("beta bound overflow".to_string()))?; + term.checked_mul(half_b) + .ok_or_else(|| HachiError::InvalidSetup("beta bound overflow".to_string())) +} + +/// Validate static config invariants and derive runtime dimensions. +/// +/// # Errors +/// +/// Returns an error when config constants are inconsistent or overflow. +pub(super) fn validate_and_derive_layout( + max_num_vars: usize, +) -> Result { + if D != Cfg::D { + return Err(HachiError::InvalidSetup(format!( + "const D={D} mismatches config D={}", + Cfg::D + ))); + } + Cfg::commitment_layout(max_num_vars) +} + +/// Ensure `max_num_vars` is sufficient for config dimensions. +/// +/// # Errors +/// +/// Returns an error when `max_num_vars < required_vars`. 
+pub(super) fn ensure_supported_num_vars( + max_num_vars: usize, + required_vars: usize, +) -> Result<(), HachiError> { + if max_num_vars < required_vars { + return Err(HachiError::InvalidSetup(format!( + "max_num_vars {max_num_vars} is smaller than required {required_vars}" + ))); + } + Ok(()) +} + +/// Ensure input blocks match the expected config-derived layout. +/// +/// # Errors +/// +/// Returns an error when block count or per-block size mismatch. +pub(super) fn ensure_block_layout( + f_blocks: &[Vec>], + layout: HachiCommitmentLayout, +) -> Result<(), HachiError> { + if f_blocks.len() != layout.num_blocks { + return Err(HachiError::InvalidSize { + expected: layout.num_blocks, + actual: f_blocks.len(), + }); + } + for block in f_blocks { + if block.len() != layout.block_len { + return Err(HachiError::InvalidSize { + expected: layout.block_len, + actual: block.len(), + }); + } + } + Ok(()) +} + +/// Ensure matrix has at least the expected dimensions. +/// +/// Matrices may be wider than the main layout requires when widened to +/// accommodate the w-commitment's column counts. +/// +/// # Errors +/// +/// Returns an error if row count mismatches or any row is too narrow. +pub(super) fn ensure_matrix_shape_ge( + mat: &FlatMatrix, + expected_rows: usize, + min_cols: usize, + name: &str, +) -> Result<(), HachiError> { + if mat.num_rows() != expected_rows { + return Err(HachiError::InvalidSize { + expected: expected_rows, + actual: mat.num_rows(), + }); + } + let actual_cols = mat.num_cols_at::(); + if actual_cols < min_cols { + return Err(HachiError::InvalidSetup(format!( + "{name} has width {actual_cols}, expected >= {min_cols}", + ))); + } + Ok(()) +} + +/// Small correctness-first config for tests and local benchmarks. +/// +/// Fixed layout (m_vars=4, r_vars=2) for fast test iteration. For larger +/// polynomials, use [`DynamicSmallTestCommitmentConfig`] instead. 
+#[derive(Clone, Copy, Debug, Default)] +pub struct SmallTestCommitmentConfig; + +impl CommitmentConfig for SmallTestCommitmentConfig { + const D: usize = 16; + const N_A: usize = 8; + const N_B: usize = 4; + const N_D: usize = 4; + const CHALLENGE_WEIGHT: usize = 3; + + fn decomposition() -> DecompositionParams { + DecompositionParams { + log_basis: 3, + log_commit_bound: 32, + log_open_bound: None, + } + } + + fn commitment_layout(_max_num_vars: usize) -> Result { + HachiCommitmentLayout::new::(4, 2, &Self::decomposition()) + } +} + +/// D=16 config with dynamic layout that adapts to polynomial size. +/// +/// Uses the same D=16 ring dimension as [`SmallTestCommitmentConfig`] but +/// derives `m_vars`/`r_vars` from `max_num_vars`, so it can commit +/// polynomials with an arbitrary number of variables. +#[derive(Clone, Copy, Debug, Default)] +pub struct DynamicSmallTestCommitmentConfig; + +impl CommitmentConfig for DynamicSmallTestCommitmentConfig { + const D: usize = 16; + const N_A: usize = 8; + const N_B: usize = 4; + const N_D: usize = 4; + const CHALLENGE_WEIGHT: usize = 3; + + fn decomposition() -> DecompositionParams { + DecompositionParams { + log_basis: 3, + log_commit_bound: 32, + log_open_bound: None, + } + } + + fn commitment_layout(max_num_vars: usize) -> Result { + let alpha = Self::D.trailing_zeros() as usize; + let reduced_vars = max_num_vars.checked_sub(alpha).ok_or_else(|| { + HachiError::InvalidSetup("max_num_vars is smaller than alpha".to_string()) + })?; + if reduced_vars == 0 { + return Err(HachiError::InvalidSetup( + "max_num_vars must leave at least one outer variable".to_string(), + )); + } + let (m_vars, r_vars) = optimal_m_r_split::(reduced_vars); + HachiCommitmentLayout::new::(m_vars, r_vars, &Self::decomposition()) + } +} + +/// Production-oriented profile for 128-bit base fields (`Fp128

`),
/// parameterized by the coefficient bound used at commit time.
///
/// This profile targets the `D = 512`, `n_A = n_B = n_D = 1` regime with
/// base-8 balanced decomposition (`log_basis = 3`) over ~128-bit moduli.
///
/// `LOG_COMMIT_BOUND` is the bit-width of the largest polynomial coefficient
/// the commitment decomposition must represent. Smaller bounds yield fewer
/// decomposition levels (`delta_commit = ceil(LOG_COMMIT_BOUND / log_basis)`)
/// and proportionally smaller witnesses.
///
/// Opening always uses the full field modulus (128 bits) because folding with
/// arbitrary field-element weights produces full-field-size coefficients.
///
/// # Aliases
///
/// - [`Fp128FullCommitmentConfig`] = `<128>` — arbitrary field-element polys
/// - [`Fp128OneHotCommitmentConfig`] = `<1>` — binary / one-hot polys
/// - [`Fp128LogBasisCommitmentConfig`] = `<3>` — balanced-digit witnesses
/// - [`Fp128CommitmentConfig`] — backward-compatible alias for `<128>`
///
/// # β derivation (stage-1 folded witness `z`)
///
/// - In `compute_z_hat`, each coordinate is `z[j] = Σ_i s_i[j].mul_by_sparse(c_i)`.
/// - `balanced_decompose_pow2` yields per-coefficient digits in `[-b/2, b/2)`
///   where `b = 2^LOG_BASIS`, so each input coefficient has `|·| <= b/2`.
/// - Challenges use exactly `ω = CHALLENGE_WEIGHT` nonzeros in `{±1}`.
/// - Therefore each `mul_by_sparse` output coefficient is bounded by `ω * (b/2)`.
/// - Summing over `2^R` blocks (R = r_vars) gives:
///   `||z||_inf <= 2^R * ω * (b/2)`.
// NOTE(review): the const-generic parameter was stripped by extraction;
// reconstructed from the `<128>`/`<1>`/`<3>` aliases below — confirm its
// declared integer type (`u32` assumed, matching `log_commit_bound`).
#[derive(Clone, Copy, Debug, Default)]
pub struct Fp128BoundedCommitmentConfig<const LOG_COMMIT_BOUND: u32>;

impl<const LOG_COMMIT_BOUND: u32> CommitmentConfig
    for Fp128BoundedCommitmentConfig<LOG_COMMIT_BOUND>
{
    const D: usize = 512;
    const N_A: usize = 1;
    const N_B: usize = 1;
    const N_D: usize = 1;
    const CHALLENGE_WEIGHT: usize = 19;

    fn decomposition() -> DecompositionParams {
        DecompositionParams {
            log_basis: 3,
            log_commit_bound: LOG_COMMIT_BOUND,
            // A tighter commit bound still opens at full field width (128 bits).
            log_open_bound: if LOG_COMMIT_BOUND < 128 {
                Some(128)
            } else {
                None
            },
        }
    }

    fn commitment_layout(max_num_vars: usize) -> Result<HachiCommitmentLayout, HachiError> {
        // alpha = log2(D): number of variables absorbed by the ring dimension.
        let alpha = Self::D.trailing_zeros() as usize;
        let reduced_vars = max_num_vars.checked_sub(alpha).ok_or_else(|| {
            HachiError::InvalidSetup("max_num_vars is smaller than alpha".to_string())
        })?;
        if reduced_vars == 0 {
            return Err(HachiError::InvalidSetup(
                "max_num_vars must leave at least one outer variable".to_string(),
            ));
        }
        // TODO(review): turbofish arguments lost in extraction — confirm.
        let (m_vars, r_vars) = optimal_m_r_split::<Self>(reduced_vars);
        HachiCommitmentLayout::new::<Self>(m_vars, r_vars, &Self::decomposition())
    }
}

/// Full-field (128-bit) coefficient bound for arbitrary field-element polynomials.
pub type Fp128FullCommitmentConfig = Fp128BoundedCommitmentConfig<128>;

/// Binary (1-bit) coefficient bound for one-hot or binary polynomials.
///
/// Reduces `delta_commit` from 43 to 1 compared to [`Fp128FullCommitmentConfig`],
/// shrinking the dominant `z_pre` witness component by ~43x.
pub type Fp128OneHotCommitmentConfig = Fp128BoundedCommitmentConfig<1>;

/// Log-basis (3-bit) coefficient bound for balanced-digit witnesses.
///
/// Functionally equivalent to `WCommitmentConfig<512, Fp128FullCommitmentConfig>`
/// for recursive w-openings.
pub type Fp128LogBasisCommitmentConfig = Fp128BoundedCommitmentConfig<3>;

/// Backward-compatible alias for [`Fp128FullCommitmentConfig`].
pub type Fp128CommitmentConfig = Fp128FullCommitmentConfig;

/// Halving-D commitment config for Fp128 (D=512 → 256 → 128).
+/// +/// Uses `d_at_level` and `n_a_at_level` to halve the ring dimension at each +/// fold level while doubling the module rank to maintain D×N_A ≥ 512 for +/// security. Stops halving at D=128, which is the minimum ring dimension +/// for which sparse ternary challenges provide sufficient security. +/// +/// Challenge weights are scaled per ring dimension to maintain ≥128 bits +/// of challenge entropy (log₂(C(D,ω) · 2^ω) ≥ 128): +/// D=512: ω=19 (~131 bits), D=256: ω=23 (~131 bits), D=128: ω=31 (~130 bits). +#[derive(Clone, Copy, Debug, Default)] +pub struct Fp128HalvingDCommitmentConfig; + +impl CommitmentConfig for Fp128HalvingDCommitmentConfig { + const D: usize = 512; + const N_A: usize = 1; + const N_B: usize = 1; + const N_D: usize = 1; + const CHALLENGE_WEIGHT: usize = 19; + + fn decomposition() -> DecompositionParams { + DecompositionParams { + log_basis: 3, + log_commit_bound: 128, + log_open_bound: None, + } + } + + fn commitment_layout(max_num_vars: usize) -> Result { + let alpha = Self::D.trailing_zeros() as usize; + let reduced_vars = max_num_vars.checked_sub(alpha).ok_or_else(|| { + HachiError::InvalidSetup("max_num_vars is smaller than alpha".to_string()) + })?; + if reduced_vars == 0 { + return Err(HachiError::InvalidSetup( + "max_num_vars must leave at least one outer variable".to_string(), + )); + } + let (m_vars, r_vars) = optimal_m_r_split::(reduced_vars); + HachiCommitmentLayout::new::(m_vars, r_vars, &Self::decomposition()) + } + + fn d_at_level(level: usize, _w_num_vars: usize) -> usize { + match level { + 0 => 512, + 1 => 256, + _ => 128, + } + } + + fn n_a_at_level(level: usize) -> usize { + match level { + 0 => 1, + 1 => 2, + _ => 4, + } + } + + fn challenge_weight_for_ring_dim(d: usize) -> usize { + match d { + 512 => 19, + 256 => 23, + 128 => 31, + _ => panic!("Fp128HalvingDCommitmentConfig: unsupported ring dim {d}"), + } + } +} diff --git a/src/protocol/commitment/mod.rs b/src/protocol/commitment/mod.rs new file mode 100644 index 
00000000..67953058 --- /dev/null +++ b/src/protocol/commitment/mod.rs @@ -0,0 +1,26 @@ +//! Protocol commitment abstraction layer. + +mod commit; +mod config; +pub mod onehot; +mod scheme; +mod transcript_append; +mod types; +pub mod utils; + +pub use commit::{ + HachiCommitmentCore, HachiExpandedSetup, HachiProverSetup, HachiSetupSeed, HachiVerifierSetup, +}; +pub use config::optimal_m_r_split; +pub use config::{ + compute_num_digits, compute_num_digits_fold, CommitmentConfig, DecompositionParams, + DynamicSmallTestCommitmentConfig, Fp128BoundedCommitmentConfig, Fp128CommitmentConfig, + Fp128FullCommitmentConfig, Fp128HalvingDCommitmentConfig, Fp128LogBasisCommitmentConfig, + Fp128OneHotCommitmentConfig, HachiCommitmentLayout, SmallTestCommitmentConfig, +}; +pub use onehot::{map_onehot_to_sparse_blocks, SparseBlockEntry}; +pub use scheme::{CommitWitness, CommitmentScheme, RingCommitmentScheme}; +pub use transcript_append::AppendToTranscript; +pub use types::{ + DummyProof, HachiCommitment, HachiOpeningClaim, HachiOpeningPoint, RingCommitment, +}; diff --git a/src/protocol/commitment/onehot.rs b/src/protocol/commitment/onehot.rs new file mode 100644 index 00000000..470d3bd4 --- /dev/null +++ b/src/protocol/commitment/onehot.rs @@ -0,0 +1,336 @@ +//! One-hot commitment path for regular one-hot ring elements. +//! +//! Exploits the sparsity of one-hot witnesses (coefficients in {0,1}) to +//! eliminate all inner ring multiplications. The inner Ajtai `t = A * s` +//! reduces to summing selected columns of `A` with negacyclic rotations. 
+ +use std::collections::BTreeMap; + +use crate::algebra::fields::wide::{HasWide, ReduceTo}; +use crate::algebra::ring::{CyclotomicRing, WideCyclotomicRing}; +use crate::error::HachiError; +use crate::protocol::commitment::utils::flat_matrix::RingMatrixView; +use crate::protocol::hachi_poly_ops::OneHotIndex; +use crate::{AdditiveGroup, CanonicalField, FieldCore}; + +/// Describes a nonzero ring element within one block of the commitment layout. +#[derive(Debug, Clone, PartialEq)] +pub struct SparseBlockEntry { + /// Position within the block (0..2^M). + pub pos_in_block: usize, + /// Coefficient indices that are 1 within this ring element. + pub nonzero_coeffs: Vec, +} + +/// Map a regular one-hot witness to sparse ring block entries. +/// +/// - `onehot_k`: chunk size K. The witness has T chunks of K field elements, +/// each chunk containing exactly one 1. +/// - `indices`: length-T slice where `indices[c]` is the hot position in +/// chunk `c` (must be in `[0, K)`). +/// - `r`, `m`: commitment config parameters (2^R blocks of 2^M ring elements). +/// - `D`: ring degree (const generic on caller side, passed as runtime here). +/// +/// Returns one `Vec` per block (outer len = 2^R). +/// +/// # Errors +/// +/// Returns an error if K and D are not "nicely matched" (one must divide +/// the other), if any index is out of range, or if the dimensions don't +/// fill the commitment layout. 
+pub fn map_onehot_to_sparse_blocks( + onehot_k: usize, + indices: &[Option], + r: usize, + m: usize, + d: usize, +) -> Result>, HachiError> { + if onehot_k == 0 || d == 0 { + return Err(HachiError::InvalidInput( + "onehot_k and D must be nonzero".into(), + )); + } + if !(onehot_k % d == 0 || d % onehot_k == 0) { + return Err(HachiError::InvalidInput(format!( + "K={onehot_k} and D={d} must be nicely matched (one divides the other)" + ))); + } + + let num_chunks = indices.len(); + let total_field_elems = num_chunks + .checked_mul(onehot_k) + .ok_or_else(|| HachiError::InvalidInput("T*K overflow".into()))?; + if total_field_elems % d != 0 { + return Err(HachiError::InvalidInput(format!( + "T*K={total_field_elems} is not divisible by D={d}" + ))); + } + let total_ring_elems = total_field_elems / d; + let num_blocks = 1usize << r; + let block_len = 1usize << m; + if total_ring_elems != num_blocks * block_len { + return Err(HachiError::InvalidSize { + expected: num_blocks * block_len, + actual: total_ring_elems, + }); + } + + let mut ring_elem_map: BTreeMap> = BTreeMap::new(); + for (c, opt) in indices.iter().enumerate() { + let Some(&idx_raw) = opt.as_ref() else { + continue; + }; + let idx = idx_raw.as_usize(); + if idx >= onehot_k { + return Err(HachiError::InvalidInput(format!( + "index {idx} out of range for chunk size K={onehot_k} at position {c}" + ))); + } + let field_pos = c * onehot_k + idx; + let ring_elem_idx = field_pos / d; + let coeff_idx = field_pos % d; + ring_elem_map + .entry(ring_elem_idx) + .or_default() + .push(coeff_idx); + } + + // Sequential block layout matching commit_coeffs: block i = ring elements + // [i*block_len, (i+1)*block_len). 
+ let mut blocks: Vec> = vec![Vec::new(); num_blocks]; + for (ring_elem_idx, nonzero_coeffs) in ring_elem_map { + let block_idx = ring_elem_idx / block_len; + let pos_in_block = ring_elem_idx % block_len; + blocks[block_idx].push(SparseBlockEntry { + pos_in_block, + nonzero_coeffs, + }); + } + + Ok(blocks) +} + +/// Sparse inner Ajtai: compute `t = A * s` for a one-hot block. +/// +/// Instead of materializing the full decomposed vector `s` and doing a dense +/// matvec, we accumulate only the nonzero contributions using fused +/// shift-accumulate (no intermediate temporaries): +/// +/// ```text +/// t[a] += A[a][entry.pos * num_digits] * (X^{k_1} + X^{k_2} + ...) +/// ``` +#[cfg(test)] +#[allow(non_snake_case)] +pub(crate) fn inner_ajtai_onehot_t_only( + A: &[Vec>], + sparse_entries: &[SparseBlockEntry], + _block_len: usize, + num_digits: usize, +) -> Vec> { + let n_a = A.len(); + + let mut t = vec![CyclotomicRing::::zero(); n_a]; + for entry in sparse_entries { + let col = entry.pos_in_block * num_digits; + for a in 0..n_a { + A[a][col].mul_by_monomial_sum_into(&mut t[a], &entry.nonzero_coeffs); + } + } + + t +} + +/// Wide-accumulator variant of [`inner_ajtai_onehot_t_only`]. +/// +/// Accumulates into `WideCyclotomicRing` (carry-free i32 additions), +/// then reduces once at the end. This avoids per-addition modular reduction. 
+#[allow(non_snake_case)] +pub(crate) fn inner_ajtai_onehot_wide( + A: &RingMatrixView<'_, F, D>, + sparse_entries: &[SparseBlockEntry], + _block_len: usize, + num_digits: usize, +) -> Vec> +where + F: FieldCore + CanonicalField + HasWide, + F::Wide: AdditiveGroup + From + ReduceTo, +{ + let n_a = A.num_rows(); + let mut t_wide = vec![WideCyclotomicRing::::zero(); n_a]; + + for entry in sparse_entries { + let col = entry.pos_in_block * num_digits; + for (a_idx, t_w) in t_wide.iter_mut().enumerate() { + let a_wide = WideCyclotomicRing::from_ring(&A.row(a_idx)[col]); + a_wide.mul_by_monomial_sum_into(t_w, &entry.nonzero_coeffs); + } + } + + t_wide.into_iter().map(|w| w.reduce()).collect() +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::algebra::fields::{Fp64, Prime128M8M4M1M0}; + use crate::protocol::commitment::utils::flat_matrix::FlatMatrix; + use rand::rngs::StdRng; + use rand::SeedableRng; + + #[test] + fn map_onehot_k_gt_d() { + // K=16, D=4, T=2 chunks => 32 field elements => 8 ring elements + // R=1 (2 blocks), M=2 (4 per block) => 8 ring elements total + let k = 16; + let d = 4; + let indices: Vec> = vec![Some(3), Some(10)]; + let blocks = map_onehot_to_sparse_blocks(k, &indices, 1, 2, d).unwrap(); + + assert_eq!(blocks.len(), 2); + let total_entries: usize = blocks.iter().map(|b| b.len()).sum(); + assert_eq!(total_entries, 2, "T=2 nonzero ring elements"); + + for block in &blocks { + for entry in block { + assert_eq!(entry.nonzero_coeffs.len(), 1, "K>D => single monomial"); + } + } + } + + #[test] + fn map_onehot_k_eq_d() { + // K=4, D=4, T=4 chunks => 16 field elements => 4 ring elements + // R=1 (2 blocks), M=1 (2 per block) + let k = 4; + let d = 4; + let indices: Vec> = vec![Some(0), Some(2), Some(3), Some(1)]; + let blocks = map_onehot_to_sparse_blocks(k, &indices, 1, 1, d).unwrap(); + + assert_eq!(blocks.len(), 2); + let total_entries: usize = blocks.iter().map(|b| b.len()).sum(); + assert_eq!(total_entries, 4, "K=D => every ring element 
is nonzero"); + + for block in &blocks { + for entry in block { + assert_eq!(entry.nonzero_coeffs.len(), 1, "K=D => single monomial"); + } + } + } + + #[test] + fn map_onehot_k_lt_d() { + // K=4, D=8, T=8 chunks => 32 field elements => 4 ring elements + // R=1 (2 blocks), M=1 (2 per block) + let k = 4; + let d = 8; + let indices: Vec> = vec![ + Some(0), + Some(2), + Some(3), + Some(1), + Some(0), + Some(0), + Some(3), + Some(3), + ]; + let blocks = map_onehot_to_sparse_blocks(k, &indices, 1, 1, d).unwrap(); + + assert_eq!(blocks.len(), 2); + let total_entries: usize = blocks.iter().map(|b| b.len()).sum(); + assert_eq!(total_entries, 4, "D>K => all ring elements nonzero"); + + for block in &blocks { + for entry in block { + assert_eq!( + entry.nonzero_coeffs.len(), + 2, + "D=2K => 2 nonzero coeffs per ring element" + ); + } + } + } + + #[test] + fn map_onehot_rejects_non_divisible() { + let result = map_onehot_to_sparse_blocks(3, &[Some(0usize), Some(1)], 0, 1, 4); + assert!(result.is_err()); + } + + #[test] + fn wide_matches_reference() { + type F = Fp64<4294967197>; + const D: usize = 64; + + let mut rng = StdRng::seed_from_u64(0xdead_beef); + let n_a = 3; + let block_len = 4; + let num_digits = 5; + let a_matrix: Vec>> = (0..n_a) + .map(|_| { + (0..block_len * num_digits) + .map(|_| CyclotomicRing::random(&mut rng)) + .collect() + }) + .collect(); + + let entries = vec![ + SparseBlockEntry { + pos_in_block: 0, + nonzero_coeffs: vec![1, 7, 15], + }, + SparseBlockEntry { + pos_in_block: 2, + nonzero_coeffs: vec![0, 63], + }, + ]; + + let a_flat = FlatMatrix::from_ring_matrix(&a_matrix); + let a_view = a_flat.view::(); + let ref_result = inner_ajtai_onehot_t_only(&a_matrix, &entries, block_len, num_digits); + let wide_result = inner_ajtai_onehot_wide(&a_view, &entries, block_len, num_digits); + + assert_eq!(ref_result.len(), wide_result.len()); + for (r, w) in ref_result.iter().zip(wide_result.iter()) { + assert_eq!(r, w, "wide result must match reference"); + } + } 
+ + #[test] + fn wide_matches_reference_fp128() { + type F = Prime128M8M4M1M0; + const D: usize = 64; + + let mut rng = StdRng::seed_from_u64(0xcafe_1234); + let n_a = 2; + let block_len = 2; + let num_digits = 3; + let a_matrix: Vec>> = (0..n_a) + .map(|_| { + (0..block_len * num_digits) + .map(|_| CyclotomicRing::random(&mut rng)) + .collect() + }) + .collect(); + + let entries = vec![ + SparseBlockEntry { + pos_in_block: 0, + nonzero_coeffs: vec![0, 5, 32, 63], + }, + SparseBlockEntry { + pos_in_block: 1, + nonzero_coeffs: vec![10], + }, + ]; + + let a_flat = FlatMatrix::from_ring_matrix(&a_matrix); + let a_view = a_flat.view::(); + let ref_result = inner_ajtai_onehot_t_only(&a_matrix, &entries, block_len, num_digits); + let wide_result = inner_ajtai_onehot_wide(&a_view, &entries, block_len, num_digits); + + assert_eq!(ref_result.len(), wide_result.len()); + for (r, w) in ref_result.iter().zip(wide_result.iter()) { + assert_eq!(r, w, "wide result must match reference (Fp128)"); + } + } +} diff --git a/src/protocol/commitment/scheme.rs b/src/protocol/commitment/scheme.rs new file mode 100644 index 00000000..eeebb313 --- /dev/null +++ b/src/protocol/commitment/scheme.rs @@ -0,0 +1,239 @@ +//! Commitment-scheme trait surface for Hachi protocol code. + +use super::config::{CommitmentConfig, HachiCommitmentLayout}; +use super::transcript_append::AppendToTranscript; +use crate::algebra::CyclotomicRing; +use crate::error::HachiError; +use crate::protocol::hachi_poly_ops::{HachiPolyOps, OneHotIndex}; +use crate::protocol::opening_point::BasisMode; +use crate::protocol::transcript::Transcript; +use crate::{CanonicalField, FieldCore}; + +/// Witness data produced alongside a ring-native commitment. +/// +/// Contains the commitment itself plus `t_hat` (basis-decomposed inner Ajtai +/// output) from the two-layer Ajtai construction (§4.1). The decomposed input +/// vectors `s` are NOT stored; they are recomputed from the polynomial during +/// proving via `HachiPolyOps`. 
+pub struct CommitWitness { + /// The ring commitment (outer Ajtai output `u = B · t̂`). + pub commitment: C, + /// Per-block basis-decomposed inner Ajtai output vectors as i8 digit planes. + pub t_hat: Vec>, + _marker: std::marker::PhantomData, +} + +impl CommitWitness { + /// Construct a new commit witness. + pub fn new(commitment: C, t_hat: Vec>) -> Self { + Self { + commitment, + t_hat, + _marker: std::marker::PhantomData, + } + } +} + +/// Commitment-scheme interface used by Hachi protocol code. +/// +/// Generic over field `F` and cyclotomic ring degree `D`. +/// Polynomials are provided as `impl HachiPolyOps`. +pub trait CommitmentScheme: Clone + Send + Sync + 'static +where + F: FieldCore + CanonicalField, +{ + /// Prover setup parameters. + type ProverSetup: Clone + Send + Sync; + /// Verifier setup parameters. + type VerifierSetup: Clone + Send + Sync; + /// Commitment object. + type Commitment: Clone + PartialEq + Send + Sync + AppendToTranscript; + /// Evaluation/opening proof object. + type Proof: Clone + Send + Sync; + /// Prover-side hint produced at commitment time. + type CommitHint: Clone + Send + Sync; + + /// Build prover setup for maximum polynomial dimension. + /// + /// # Panics + /// + /// Panics if internal setup fails (programming error, not adversarial input). + fn setup_prover(max_num_vars: usize) -> Self::ProverSetup; + + /// Derive verifier setup from prover setup. + fn setup_verifier(setup: &Self::ProverSetup) -> Self::VerifierSetup; + + /// Commit to one polynomial with a caller-specified layout. + /// + /// The layout's matrix dimensions must not exceed the setup's max dimensions. + /// Callers control `num_digits_commit` via the layout to reduce decomposition + /// depth for polynomials with bounded coefficients (e.g. delta=1 for {0,1}). + /// + /// # Errors + /// + /// Returns an error when setup/parameter constraints are not satisfied. 
+ fn commit>( + poly: &P, + setup: &Self::ProverSetup, + layout: &HachiCommitmentLayout, + ) -> Result<(Self::Commitment, Self::CommitHint), HachiError>; + + /// Produce an opening proof at `opening_point` with a caller-specified layout. + /// + /// The layout must match the one used during commitment. Recursive w-opening + /// levels derive their own layouts internally via `WCommitmentConfig`. + /// + /// `basis` selects the polynomial representation (see [`BasisMode`]). + /// + /// # Errors + /// + /// Returns an error if the opening point is invalid or proof generation fails. + #[allow(clippy::too_many_arguments)] + fn prove, P: HachiPolyOps>( + setup: &Self::ProverSetup, + poly: &P, + opening_point: &[F], + hint: Self::CommitHint, + transcript: &mut T, + commitment: &Self::Commitment, + basis: BasisMode, + layout: &HachiCommitmentLayout, + ) -> Result; + + /// Verify an opening proof with a caller-specified layout. + /// + /// The layout must be reconstructed deterministically by the verifier — + /// never deserialized from the proof. It must match the layout used by the + /// prover for commitment and proving. + /// + /// `basis` must match the mode used by the prover (see [`BasisMode`]). + /// + /// # Errors + /// + /// Returns an error when verification fails. + #[allow(clippy::too_many_arguments)] + fn verify>( + proof: &Self::Proof, + setup: &Self::VerifierSetup, + transcript: &mut T, + opening_point: &[F], + opening: &F, + commitment: &Self::Commitment, + basis: BasisMode, + layout: &HachiCommitmentLayout, + ) -> Result<(), HachiError>; + + /// Protocol identifier. + fn protocol_name() -> &'static [u8]; +} + +/// Ring-native commitment interface for §4.1 implementation work. +pub trait RingCommitmentScheme: Clone + Send + Sync + 'static +where + F: FieldCore + CanonicalField, + Cfg: CommitmentConfig, +{ + /// Prover setup parameters. + type ProverSetup: Clone + Send + Sync; + /// Verifier setup parameters. 
+ type VerifierSetup: Clone + Send + Sync; + /// Ring-native commitment type. + type Commitment: Clone + PartialEq + Send + Sync; + + /// Construct commitment setup for at most `max_num_vars` variables. + /// + /// # Errors + /// + /// Returns an error if dimensions are inconsistent with `Cfg`. + fn setup(max_num_vars: usize) -> Result<(Self::ProverSetup, Self::VerifierSetup), HachiError>; + + /// Read the runtime layout carried by `setup`. + /// + /// # Errors + /// + /// Returns an error when setup metadata is inconsistent. + fn layout(setup: &Self::ProverSetup) -> Result; + + /// Commit to ring blocks arranged as `2^R` vectors of length `2^M`. + /// + /// # Errors + /// + /// Returns an error if block layout mismatches config or commitment fails. + fn commit_ring_blocks( + f_blocks: &[Vec>], + setup: &Self::ProverSetup, + ) -> Result, HachiError>; + + /// Commit to a flat coefficient table `(f_i)_{i∈{0,1}^ℓ}` in ring form. + /// + /// # Errors + /// + /// Returns an error if `f_coeffs.len()` does not match the configured block + /// layout or if the underlying commitment routine fails. + fn commit_coeffs( + f_coeffs: &[CyclotomicRing], + setup: &Self::ProverSetup, + ) -> Result, HachiError> { + let layout = Self::layout(setup)?; + let num_blocks = layout.num_blocks; + let block_len = layout.block_len; + let expected_len = num_blocks + .checked_mul(block_len) + .ok_or_else(|| HachiError::InvalidSetup("coefficient length overflow".to_string()))?; + if f_coeffs.len() != expected_len { + return Err(HachiError::InvalidSize { + expected: expected_len, + actual: f_coeffs.len(), + }); + } + + let blocks: Vec>> = f_coeffs + .chunks_exact(block_len) + .map(|chunk| chunk.to_vec()) + .collect(); + + Self::commit_ring_blocks(&blocks, setup) + } + + /// Commit to a regular one-hot witness. + /// + /// # Errors + /// + /// Returns an error if dimensions are inconsistent or any index is out + /// of range. 
+ fn commit_onehot( + onehot_k: usize, + indices: &[Option], + setup: &Self::ProverSetup, + ) -> Result, HachiError> { + let num_chunks = indices.len(); + let total_field_elems = num_chunks + .checked_mul(onehot_k) + .ok_or_else(|| HachiError::InvalidInput("T*K overflow".into()))?; + if total_field_elems % D != 0 { + return Err(HachiError::InvalidInput(format!( + "T*K={total_field_elems} is not divisible by D={D}" + ))); + } + + let total_ring_elems = total_field_elems / D; + let mut ring_coeffs = vec![CyclotomicRing::::zero(); total_ring_elems]; + for (c, opt) in indices.iter().enumerate() { + let Some(&idx_raw) = opt.as_ref() else { + continue; + }; + let idx = idx_raw.as_usize(); + if idx >= onehot_k { + return Err(HachiError::InvalidInput(format!( + "index {idx} out of range for chunk size K={onehot_k} at position {c}" + ))); + } + let field_pos = c * onehot_k + idx; + let ring_idx = field_pos / D; + let coeff_idx = field_pos % D; + ring_coeffs[ring_idx].coeffs[coeff_idx] = F::one(); + } + + Self::commit_coeffs(&ring_coeffs, setup) + } +} diff --git a/src/protocol/commitment/transcript_append.rs b/src/protocol/commitment/transcript_append.rs new file mode 100644 index 00000000..6510ccb5 --- /dev/null +++ b/src/protocol/commitment/transcript_append.rs @@ -0,0 +1,13 @@ +//! Traits for appending commitment objects to protocol transcripts. + +use crate::protocol::transcript::Transcript; +use crate::{CanonicalField, FieldCore}; + +/// Protocol object that can be absorbed into a transcript. +pub trait AppendToTranscript +where + F: FieldCore + CanonicalField, +{ + /// Append this object to a transcript using the provided event label. + fn append_to_transcript>(&self, label: &[u8], transcript: &mut T); +} diff --git a/src/protocol/commitment/types.rs b/src/protocol/commitment/types.rs new file mode 100644 index 00000000..339a1a03 --- /dev/null +++ b/src/protocol/commitment/types.rs @@ -0,0 +1,157 @@ +//! Protocol commitment/opening wrapper types. 
+ +use super::transcript_append::AppendToTranscript; +use crate::algebra::ring::CyclotomicRing; +use crate::primitives::serialization::{ + Compress, HachiDeserialize, HachiSerialize, SerializationError, Valid, Validate, +}; +use crate::protocol::transcript::Transcript; +use crate::{CanonicalField, FieldCore}; +use std::io::{Read, Write}; + +/// A Hachi opening point represented as field coordinates. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct HachiOpeningPoint { + /// Point coordinates used for multilinear opening. + pub r: Vec, +} + +/// A Hachi opening claim `(point, value)`. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct HachiOpeningClaim { + /// Opening point. + pub point: HachiOpeningPoint, + /// Claimed value at `point`. + pub value: F, +} + +/// Minimal commitment wrapper used by protocol traits/tests. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] +pub struct HachiCommitment(pub u128); + +/// Minimal proof wrapper used by protocol trait stubs and tests. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] +pub struct DummyProof(pub u128); + +impl Valid for HachiCommitment { + fn check(&self) -> Result<(), SerializationError> { + Ok(()) + } +} + +impl HachiSerialize for HachiCommitment { + fn serialize_with_mode( + &self, + mut writer: W, + _compress: Compress, + ) -> Result<(), SerializationError> { + self.0.serialize_with_mode(&mut writer, Compress::No) + } + + fn serialized_size(&self, _compress: Compress) -> usize { + 16 + } +} + +impl HachiDeserialize for HachiCommitment { + fn deserialize_with_mode( + mut reader: R, + _compress: Compress, + validate: Validate, + ) -> Result { + let value = u128::deserialize_with_mode(&mut reader, Compress::No, validate)?; + Ok(Self(value)) + } +} + +impl Valid for DummyProof { + fn check(&self) -> Result<(), SerializationError> { + Ok(()) + } +} + +impl HachiSerialize for DummyProof { + fn serialize_with_mode( + &self, + mut writer: W, + _compress: Compress, + ) -> Result<(), SerializationError> 
{ + self.0.serialize_with_mode(&mut writer, Compress::No) + } + + fn serialized_size(&self, _compress: Compress) -> usize { + 16 + } +} + +impl HachiDeserialize for DummyProof { + fn deserialize_with_mode( + mut reader: R, + _compress: Compress, + validate: Validate, + ) -> Result { + let value = u128::deserialize_with_mode(&mut reader, Compress::No, validate)?; + Ok(Self(value)) + } +} + +impl AppendToTranscript for HachiCommitment +where + F: FieldCore + CanonicalField, +{ + fn append_to_transcript>(&self, label: &[u8], transcript: &mut T) { + transcript.append_serde(label, self); + } +} + +/// Ring-native commitment object `u in R_q^{n_B}` used by §4.1. +#[derive(Debug, Clone, PartialEq, Eq, Default)] +pub struct RingCommitment { + /// Outer commitment vector. + pub u: Vec>, +} + +impl Valid for RingCommitment { + fn check(&self) -> Result<(), SerializationError> { + self.u.check() + } +} + +impl HachiSerialize for RingCommitment { + fn serialize_with_mode( + &self, + mut writer: W, + compress: Compress, + ) -> Result<(), SerializationError> { + self.u.serialize_with_mode(&mut writer, compress) + } + + fn serialized_size(&self, compress: Compress) -> usize { + self.u.serialized_size(compress) + } +} + +impl HachiDeserialize for RingCommitment { + fn deserialize_with_mode( + mut reader: R, + compress: Compress, + validate: Validate, + ) -> Result { + let u = + Vec::>::deserialize_with_mode(&mut reader, compress, validate)?; + let out = Self { u }; + if matches!(validate, Validate::Yes) { + out.check()?; + } + Ok(out) + } +} + +impl AppendToTranscript for RingCommitment +where + F: FieldCore + CanonicalField, +{ + fn append_to_transcript>(&self, label: &[u8], transcript: &mut T) { + transcript.append_serde(label, self); + } +} diff --git a/src/protocol/commitment/utils/crt_ntt.rs b/src/protocol/commitment/utils/crt_ntt.rs new file mode 100644 index 00000000..5b056e91 --- /dev/null +++ b/src/protocol/commitment/utils/crt_ntt.rs @@ -0,0 +1,186 @@ +//! 
Protocol-facing CRT+NTT parameter dispatch and matrix caching. + +use crate::algebra::ntt::prime::PrimeWidth; +use crate::algebra::ntt::tables::{ + q128_primes, q64_primes, MAX_CRT_RING_DEGREE, Q128_MODULUS, Q128_NUM_PRIMES, Q32_MODULUS, + Q32_NUM_PRIMES, Q32_PRIMES, Q64_MODULUS, Q64_NUM_PRIMES, RING_DEGREE, +}; +use crate::algebra::ring::{CrtNttParamSet, CyclotomicCrtNtt}; +use crate::cfg_into_iter; +use crate::error::HachiError; +#[cfg(feature = "parallel")] +use crate::parallel::*; +use crate::{CanonicalField, FieldCore}; + +use super::flat_matrix::RingMatrixView; +use super::norm::detect_field_modulus; + +/// Supported protocol CRT+NTT parameter families. +#[derive(Clone)] +pub(crate) enum ProtocolCrtNttParams { + Q32(CrtNttParamSet), + Q64(CrtNttParamSet), + Q128(CrtNttParamSet), +} + +/// Select a CRT+NTT parameter set from field modulus and ring degree. +/// +/// Dispatch policy: +/// - `q <= 2^32-99` and `D <= 64`: Q32 (`i16`) +/// - `q <= 2^64-59` and `D <= 1024`: Q64 (`i32`, conservative K=5) +/// - `q == 2^128-275` and `D <= 1024`: Q128 (`i32`, K=5) +/// - otherwise: explicit setup error +pub(crate) fn select_crt_ntt_params( +) -> Result, HachiError> { + if !D.is_power_of_two() { + return Err(HachiError::InvalidSetup(format!( + "CRT+NTT requires power-of-two ring degree, got D={D}" + ))); + } + if D > MAX_CRT_RING_DEGREE { + return Err(HachiError::InvalidSetup(format!( + "CRT+NTT supports D <= {MAX_CRT_RING_DEGREE}, got D={D}" + ))); + } + + let modulus = detect_field_modulus::(); + + if modulus == Q128_MODULUS { + return Ok(ProtocolCrtNttParams::Q128(CrtNttParamSet::new( + q128_primes(), + ))); + } + + if modulus <= Q32_MODULUS as u128 { + if D <= RING_DEGREE { + return Ok(ProtocolCrtNttParams::Q32(CrtNttParamSet::new(Q32_PRIMES))); + } + return Ok(ProtocolCrtNttParams::Q64(CrtNttParamSet::new(q64_primes()))); + } + + if modulus <= Q64_MODULUS as u128 { + return Ok(ProtocolCrtNttParams::Q64(CrtNttParamSet::new(q64_primes()))); + } + + 
Err(HachiError::InvalidSetup(format!( + "no CRT+NTT parameter set for modulus {modulus} and D={D}; supported ranges: <= {Q64_MODULUS} (with Q32/Q64 dispatch) or exactly {Q128_MODULUS}" + ))) +} + +/// Pre-converted CRT+NTT cache for a single matrix, keyed by parameter family. +/// +/// Stores both negacyclic (for mat-vec) and cyclic (for quotient) representations +/// to avoid repeated coefficient-to-NTT conversion. +#[derive(Debug, Clone, PartialEq, Eq)] +#[allow(missing_docs)] +pub enum NttSlotCache { + /// 32-bit CRT primes. + Q32 { + neg: Vec>>, + cyc: Vec>>, + params: CrtNttParamSet, + }, + /// 64-bit CRT primes. + Q64 { + neg: Vec>>, + cyc: Vec>>, + params: CrtNttParamSet, + }, + /// 128-bit CRT primes. + Q128 { + neg: Vec>>, + cyc: Vec>>, + params: CrtNttParamSet, + }, +} + +fn convert_mat( + mat: RingMatrixView<'_, F, D>, + params: &CrtNttParamSet, +) -> Vec>> +where + F: FieldCore + CanonicalField, + W: PrimeWidth, +{ + cfg_into_iter!(0..mat.num_rows()) + .map(|i| { + mat.row(i) + .iter() + .map(|a| CyclotomicCrtNtt::from_ring_with_params(a, params)) + .collect() + }) + .collect() +} + +fn convert_mat_cyclic( + mat: RingMatrixView<'_, F, D>, + params: &CrtNttParamSet, +) -> Vec>> +where + F: FieldCore + CanonicalField, + W: PrimeWidth, +{ + cfg_into_iter!(0..mat.num_rows()) + .map(|i| { + mat.row(i) + .iter() + .map(|a| CyclotomicCrtNtt::from_ring_cyclic(a, params)) + .collect() + }) + .collect() +} + +/// Build an NTT slot cache for a single matrix. +/// +/// # Errors +/// +/// Returns an error if no CRT+NTT parameter set matches the field modulus and ring degree. 
+#[tracing::instrument(skip_all, name = "build_ntt_slot")] +pub fn build_ntt_slot( + mat: RingMatrixView<'_, F, D>, +) -> Result, HachiError> { + let params = select_crt_ntt_params::()?; + Ok(build_ntt_slot_from_params(mat, params)) +} + +fn build_ntt_slot_from_params( + mat: RingMatrixView<'_, F, D>, + params: ProtocolCrtNttParams, +) -> NttSlotCache { + match params { + ProtocolCrtNttParams::Q32(p) => NttSlotCache::Q32 { + neg: convert_mat(mat, &p), + cyc: convert_mat_cyclic(mat, &p), + params: p, + }, + ProtocolCrtNttParams::Q64(p) => NttSlotCache::Q64 { + neg: convert_mat(mat, &p), + cyc: convert_mat_cyclic(mat, &p), + params: p, + }, + ProtocolCrtNttParams::Q128(p) => NttSlotCache::Q128 { + neg: convert_mat(mat, &p), + cyc: convert_mat_cyclic(mat, &p), + params: p, + }, + } +} + +/// Build NTT slot caches for three matrices, computing CRT+NTT parameters once. +/// +/// # Errors +/// +/// Returns an error if no CRT+NTT parameter set matches the field modulus and ring degree. +#[tracing::instrument(skip_all, name = "build_ntt_slots")] +#[allow(non_snake_case)] +pub fn build_ntt_slots( + A: RingMatrixView<'_, F, D>, + B: RingMatrixView<'_, F, D>, + D_mat: RingMatrixView<'_, F, D>, +) -> Result<(NttSlotCache, NttSlotCache, NttSlotCache), HachiError> { + let params = select_crt_ntt_params::()?; + let slot_a = build_ntt_slot_from_params(A, params.clone()); + let slot_b = build_ntt_slot_from_params(B, params.clone()); + let slot_d = build_ntt_slot_from_params(D_mat, params); + Ok((slot_a, slot_b, slot_d)) +} diff --git a/src/protocol/commitment/utils/flat_matrix.rs b/src/protocol/commitment/utils/flat_matrix.rs new file mode 100644 index 00000000..47946a10 --- /dev/null +++ b/src/protocol/commitment/utils/flat_matrix.rs @@ -0,0 +1,426 @@ +//! D-agnostic flat matrix storage with typed ring-element views. +//! +//! [`FlatMatrix`] stores matrix entries as raw field elements, independent of +//! any ring dimension. A [`RingMatrixView`] borrows the flat data and +//! 
interprets it as `CyclotomicRing` slices, enabling the same +//! underlying matrix to be viewed at different ring dimensions. + +use crate::algebra::CyclotomicRing; +use crate::primitives::serialization::{ + Compress, HachiDeserialize, HachiSerialize, SerializationError, Valid, Validate, +}; +use crate::FieldCore; +use std::io::{Read, Write}; + +/// Row-major matrix of field elements, independent of ring dimension. +/// +/// Each row contains `cols_ring * gen_ring_dim` contiguous field elements, +/// where `cols_ring` is the number of ring elements per row at the dimension +/// (`gen_ring_dim`) used when the matrix was generated. +/// +/// To view with a smaller ring dimension D' (where D' divides `gen_ring_dim`), +/// each row is re-chunked into `cols_ring * gen_ring_dim / D'` ring elements. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct FlatMatrix { + data: Vec, + num_rows: usize, + /// Number of ring elements per row at the generation dimension. + cols_ring: usize, + /// Ring dimension used when generating (D_max). + gen_ring_dim: usize, +} + +impl FlatMatrix { + /// Number of rows. + #[inline] + pub fn num_rows(&self) -> usize { + self.num_rows + } + + /// Number of ring-element columns at the generation dimension. + #[inline] + pub fn cols_ring(&self) -> usize { + self.cols_ring + } + + /// Ring dimension used during generation. + #[inline] + pub fn gen_ring_dim(&self) -> usize { + self.gen_ring_dim + } + + /// Number of field elements per row. + #[inline] + pub fn row_field_len(&self) -> usize { + self.cols_ring * self.gen_ring_dim + } + + /// Build from a `Vec>>`, flattening ring elements + /// into contiguous field-element storage. 
+ pub fn from_ring_matrix(mat: &[Vec>]) -> Self { + let num_rows = mat.len(); + let cols_ring = if num_rows > 0 { mat[0].len() } else { 0 }; + let row_len = cols_ring * D; + let mut data = Vec::with_capacity(num_rows * row_len); + for row in mat { + debug_assert_eq!(row.len(), cols_ring); + for ring_elem in row { + data.extend_from_slice(&ring_elem.coeffs); + } + } + Self { + data, + num_rows, + cols_ring, + gen_ring_dim: D, + } + } + + /// Create a typed view at ring dimension D. + /// + /// D must divide `gen_ring_dim`. The view re-chunks each row so that + /// `cols_at_d = cols_ring * gen_ring_dim / D`. + /// + /// # Panics + /// + /// Panics if `D == 0`, D does not divide `gen_ring_dim`, or the matrix is + /// empty with inconsistent metadata. + pub fn view(&self) -> RingMatrixView<'_, F, D> { + assert!(D > 0, "ring dimension must be positive"); + assert!( + self.gen_ring_dim % D == 0, + "D={D} does not divide gen_ring_dim={}", + self.gen_ring_dim + ); + let scale = self.gen_ring_dim / D; + let cols_at_d = self.cols_ring * scale; + RingMatrixView { + data: &self.data, + num_rows: self.num_rows, + num_cols: cols_at_d, + } + } + + /// Borrow the raw field-element data. + #[inline] + pub fn raw_data(&self) -> &[F] { + &self.data + } + + /// Number of ring-element columns when viewed at dimension D. + #[inline] + pub fn num_cols_at(&self) -> usize { + debug_assert!(D > 0 && self.gen_ring_dim % D == 0); + self.cols_ring * (self.gen_ring_dim / D) + } + + /// Borrow a single row as a slice of ring elements at dimension D (zero-copy). + /// + /// # Panics + /// + /// Panics if `row >= num_rows` or D does not divide `gen_ring_dim`. 
+ #[inline] + pub fn row(&self, row: usize) -> &[CyclotomicRing] { + assert!(D > 0 && self.gen_ring_dim % D == 0); + assert!(row < self.num_rows, "row {row} out of bounds"); + let row_field_len = self.cols_ring * self.gen_ring_dim; + let start = row * row_field_len; + let field_slice = &self.data[start..start + row_field_len]; + let num_cols = row_field_len / D; + // SAFETY: CyclotomicRing is #[repr(transparent)] over [F; D]. + unsafe { + std::slice::from_raw_parts( + field_slice.as_ptr() as *const CyclotomicRing, + num_cols, + ) + } + } + + /// Whether the matrix has zero rows. + #[inline] + pub fn is_empty(&self) -> bool { + self.num_rows == 0 + } + + /// Convenience: number of ring-element columns in the first row at dimension D, + /// or 0 if empty. + #[inline] + pub fn first_row_len(&self) -> usize { + if self.is_empty() { + 0 + } else { + self.num_cols_at::() + } + } +} + +impl Valid for FlatMatrix { + fn check(&self) -> Result<(), SerializationError> { + for f in &self.data { + f.check()?; + } + Ok(()) + } +} + +impl HachiSerialize for FlatMatrix { + fn serialize_with_mode( + &self, + mut writer: W, + compress: Compress, + ) -> Result<(), SerializationError> { + self.num_rows.serialize_with_mode(&mut writer, compress)?; + self.cols_ring.serialize_with_mode(&mut writer, compress)?; + self.gen_ring_dim + .serialize_with_mode(&mut writer, compress)?; + for f in &self.data { + f.serialize_with_mode(&mut writer, compress)?; + } + Ok(()) + } + + fn serialized_size(&self, compress: Compress) -> usize { + 3 * std::mem::size_of::() + + self + .data + .iter() + .map(|f| f.serialized_size(compress)) + .sum::() + } +} + +impl HachiDeserialize for FlatMatrix { + fn deserialize_with_mode( + mut reader: R, + compress: Compress, + validate: Validate, + ) -> Result { + let num_rows = usize::deserialize_with_mode(&mut reader, compress, validate)?; + let cols_ring = usize::deserialize_with_mode(&mut reader, compress, validate)?; + let gen_ring_dim = 
usize::deserialize_with_mode(&mut reader, compress, validate)?; + let total = num_rows * cols_ring * gen_ring_dim; + let mut data = Vec::with_capacity(total); + for _ in 0..total { + data.push(F::deserialize_with_mode(&mut reader, compress, validate)?); + } + let out = Self { + data, + num_rows, + cols_ring, + gen_ring_dim, + }; + if matches!(validate, Validate::Yes) { + out.check()?; + } + Ok(out) + } +} + +/// Typed read-only view of a [`FlatMatrix`] at a specific ring dimension D. +/// +/// Provides zero-copy access to rows as `&[CyclotomicRing]` by +/// transmuting the underlying `&[F]` slice (safe because `CyclotomicRing` +/// is `#[repr(transparent)]` over `[F; D]`). +#[derive(Debug, Clone, Copy)] +pub struct RingMatrixView<'a, F: FieldCore, const D: usize> { + data: &'a [F], + num_rows: usize, + num_cols: usize, +} + +impl<'a, F: FieldCore, const D: usize> RingMatrixView<'a, F, D> { + /// Number of rows in the view. + #[inline] + pub fn num_rows(&self) -> usize { + self.num_rows + } + + /// Number of ring-element columns per row. + #[inline] + pub fn num_cols(&self) -> usize { + self.num_cols + } + + /// Borrow a single row as a slice of ring elements (zero-copy). + /// + /// # Panics + /// + /// Panics if `row >= num_rows`. + #[inline] + pub fn row(&self, row: usize) -> &'a [CyclotomicRing] { + assert!(row < self.num_rows, "row {row} out of bounds"); + let row_field_len = self.num_cols * D; + let start = row * row_field_len; + let field_slice = &self.data[start..start + row_field_len]; + // SAFETY: CyclotomicRing is #[repr(transparent)] over [F; D], + // so a contiguous &[F] of length num_cols*D has the same layout as + // &[CyclotomicRing] of length num_cols. + unsafe { + std::slice::from_raw_parts( + field_slice.as_ptr() as *const CyclotomicRing, + self.num_cols, + ) + } + } + + /// Iterate over all rows. 
+ pub fn rows(&self) -> impl Iterator]> + '_ { + (0..self.num_rows).map(move |i| self.row(i)) + } + + /// Take a sub-view: first `n_rows` rows, first `n_cols` ring-element columns. + /// + /// This cannot produce a contiguous sub-view because rows are not + /// contiguous after column truncation. Instead it returns a + /// [`SubMatrixView`] that copies on access. + /// + /// # Panics + /// + /// Panics if `n_rows > self.num_rows` or `n_cols > self.num_cols`. + pub fn submatrix(&self, n_rows: usize, n_cols: usize) -> SubMatrixView<'a, F, D> { + assert!(n_rows <= self.num_rows); + assert!(n_cols <= self.num_cols); + SubMatrixView { + parent: *self, + n_rows, + n_cols, + } + } + + /// Collect into the legacy `Vec>>` representation. + pub fn to_vec_vec(&self) -> Vec>> { + (0..self.num_rows).map(|i| self.row(i).to_vec()).collect() + } +} + +/// A non-contiguous sub-view that yields column-truncated rows. +#[derive(Debug, Clone, Copy)] +pub struct SubMatrixView<'a, F: FieldCore, const D: usize> { + parent: RingMatrixView<'a, F, D>, + n_rows: usize, + n_cols: usize, +} + +impl<'a, F: FieldCore, const D: usize> SubMatrixView<'a, F, D> { + /// Number of rows. + #[inline] + pub fn num_rows(&self) -> usize { + self.n_rows + } + + /// Number of ring-element columns. + #[inline] + pub fn num_cols(&self) -> usize { + self.n_cols + } + + /// Borrow a row, truncated to `n_cols` ring elements. + /// + /// # Panics + /// + /// Panics if `row >= n_rows`. + #[inline] + pub fn row(&self, row: usize) -> &'a [CyclotomicRing] { + assert!(row < self.n_rows, "row {row} out of bounds"); + &self.parent.row(row)[..self.n_cols] + } + + /// Iterate over rows. + pub fn rows(&self) -> impl Iterator]> + '_ { + (0..self.n_rows).map(move |i| self.row(i)) + } + + /// Collect into the legacy `Vec>>` representation. 
+ pub fn to_vec_vec(&self) -> Vec>> { + self.rows().map(|r| r.to_vec()).collect() + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::algebra::fields::Prime128M8M4M1M0; + use rand::rngs::StdRng; + use rand::SeedableRng; + + type F = Prime128M8M4M1M0; + + #[test] + fn roundtrip_from_ring_matrix_and_view() { + let mut rng = StdRng::seed_from_u64(42); + let rows = 3usize; + let cols = 5usize; + let mat: Vec>> = (0..rows) + .map(|_| { + (0..cols) + .map(|_| CyclotomicRing::random(&mut rng)) + .collect() + }) + .collect(); + + let flat = FlatMatrix::from_ring_matrix(&mat); + assert_eq!(flat.num_rows(), rows); + assert_eq!(flat.cols_ring(), cols); + assert_eq!(flat.gen_ring_dim(), 64); + + let view = flat.view::<64>(); + assert_eq!(view.num_rows(), rows); + assert_eq!(view.num_cols(), cols); + + for (i, orig_row) in mat.iter().enumerate() { + let view_row = view.row(i); + assert_eq!(view_row, orig_row.as_slice()); + } + } + + #[test] + fn view_at_smaller_d_rechunks_correctly() { + let mut rng = StdRng::seed_from_u64(99); + let rows = 2usize; + let cols = 4usize; + let mat: Vec>> = (0..rows) + .map(|_| { + (0..cols) + .map(|_| CyclotomicRing::random(&mut rng)) + .collect() + }) + .collect(); + + let flat = FlatMatrix::from_ring_matrix(&mat); + + // View at D=32: each D=64 element becomes 2 D=32 elements + let view32 = flat.view::<32>(); + assert_eq!(view32.num_rows(), rows); + assert_eq!(view32.num_cols(), cols * 2); + + // Verify field elements are the same + for r in 0..rows { + let ring32_row = view32.row(r); + let orig_row = flat.view::<64>().row(r); + for (j, orig_ring) in orig_row.iter().enumerate() { + let lo = &ring32_row[j * 2]; + let hi = &ring32_row[j * 2 + 1]; + assert_eq!(&orig_ring.coeffs[..32], lo.coefficients()); + assert_eq!(&orig_ring.coeffs[32..], hi.coefficients()); + } + } + } + + #[test] + fn submatrix_truncates_correctly() { + let mut rng = StdRng::seed_from_u64(7); + let mat: Vec>> = (0..4) + .map(|_| (0..8).map(|_| 
CyclotomicRing::random(&mut rng)).collect()) + .collect(); + + let flat = FlatMatrix::from_ring_matrix(&mat); + let view = flat.view::<64>(); + let sub = view.submatrix(2, 5); + + assert_eq!(sub.num_rows(), 2); + assert_eq!(sub.num_cols(), 5); + for (r, row) in mat.iter().enumerate().take(2) { + assert_eq!(sub.row(r), &row[..5]); + } + } +} diff --git a/src/protocol/commitment/utils/linear.rs b/src/protocol/commitment/utils/linear.rs new file mode 100644 index 00000000..21263c61 --- /dev/null +++ b/src/protocol/commitment/utils/linear.rs @@ -0,0 +1,1008 @@ +//! Linear algebra helpers for ring commitment. + +#[cfg(target_arch = "aarch64")] +use crate::algebra::ntt::neon; +use crate::algebra::ntt::{MontCoeff, PrimeWidth}; +use crate::algebra::{CrtNttParamSet, CyclotomicCrtNtt, CyclotomicRing, DigitMontLut}; +#[cfg(test)] +use crate::error::HachiError; +#[cfg(feature = "parallel")] +use crate::parallel::*; +use crate::{cfg_fold_reduce, cfg_into_iter, cfg_iter}; +use crate::{CanonicalField, FieldCore}; +use std::array::from_fn; +use std::mem::size_of; + +use super::crt_ntt::NttSlotCache; +#[cfg(test)] +use super::crt_ntt::{select_crt_ntt_params, ProtocolCrtNttParams}; + +#[cfg(test)] +pub(crate) fn mat_vec_mul_unchecked( + mat: &[Vec>], + vec: &[CyclotomicRing], +) -> Vec> { + let mut out = Vec::with_capacity(mat.len()); + for row in mat { + debug_assert_eq!(row.len(), vec.len()); + let mut acc = CyclotomicRing::::zero(); + for (a, x) in row.iter().zip(vec.iter()) { + acc += *a * *x; + } + out.push(acc); + } + out +} + +#[inline] +fn accumulate_pointwise_product_into( + acc: &mut CyclotomicCrtNtt, + lhs: &CyclotomicCrtNtt, + rhs: &CyclotomicCrtNtt, + params: &CrtNttParamSet, +) { + #[cfg(target_arch = "aarch64")] + if neon::use_neon_ntt() { + for k in 0..K { + let prime = params.primes[k]; + unsafe { + if size_of::() == size_of::() { + neon::pointwise_mul_acc_i32( + acc.limbs[k].as_mut_ptr() as *mut i32, + lhs.limbs[k].as_ptr() as *const i32, + rhs.limbs[k].as_ptr() as 
*const i32, + D, + prime.p.to_i64() as i32, + prime.pinv.to_i64() as i32, + ); + } else { + neon::pointwise_mul_acc_i16( + acc.limbs[k].as_mut_ptr() as *mut i16, + lhs.limbs[k].as_ptr() as *const i16, + rhs.limbs[k].as_ptr() as *const i16, + D, + prime.p.to_i64() as i16, + prime.pinv.to_i64() as i16, + ); + } + } + } + return; + } + + for k in 0..K { + let prime = params.primes[k]; + let acc_limb = &mut acc.limbs[k]; + let lhs_limb = &lhs.limbs[k]; + let rhs_limb = &rhs.limbs[k]; + for ((acc_coeff, lhs_coeff), rhs_coeff) in acc_limb + .iter_mut() + .zip(lhs_limb.iter()) + .zip(rhs_limb.iter()) + { + let prod = prime.mul(*lhs_coeff, *rhs_coeff); + let sum = MontCoeff::from_raw(acc_coeff.raw().wrapping_add(prod.raw())); + *acc_coeff = prime.reduce_range(sum); + } + } +} + +#[cfg(test)] +fn precompute_dense_mat_ntt_with_params< + F: FieldCore + CanonicalField, + W: PrimeWidth, + const K: usize, + const D: usize, +>( + mat: &[Vec>], + params: &CrtNttParamSet, +) -> Vec>> { + mat.iter() + .map(|row| { + row.iter() + .map(|a| CyclotomicCrtNtt::from_ring_with_params(a, params)) + .collect() + }) + .collect() +} + +#[cfg(test)] +fn mat_vec_mul_dense_with_params< + F: FieldCore + CanonicalField, + W: PrimeWidth, + const K: usize, + const D: usize, +>( + mat: &[Vec>], + vec: &[CyclotomicRing], + params: &CrtNttParamSet, +) -> Vec> { + let ntt_vec: Vec> = vec + .iter() + .map(|v| CyclotomicCrtNtt::from_ring_with_params(v, params)) + .collect(); + + mat.iter() + .map(|row| { + debug_assert_eq!(row.len(), ntt_vec.len()); + let mut acc = CyclotomicCrtNtt::::zero(); + for (a, x_ntt) in row.iter().zip(ntt_vec.iter()) { + let a_ntt = CyclotomicCrtNtt::from_ring_with_params(a, params); + accumulate_pointwise_product_into(&mut acc, &a_ntt, x_ntt, params); + } + acc.to_ring_with_params(params) + }) + .collect() +} + +#[cfg(test)] +fn mat_vec_mul_dense_many_with_params< + F: FieldCore + CanonicalField, + W: PrimeWidth, + const K: usize, + const D: usize, +>( + mat: &[Vec>], + vecs: 
&[Vec>], + params: &CrtNttParamSet, +) -> Vec>> { + let ntt_mat = precompute_dense_mat_ntt_with_params(mat, params); + vecs.iter() + .map(|vec| { + let ntt_vec: Vec> = vec + .iter() + .map(|v| CyclotomicCrtNtt::from_ring_with_params(v, params)) + .collect(); + + ntt_mat + .iter() + .map(|row_ntt| { + debug_assert_eq!(row_ntt.len(), ntt_vec.len()); + let mut acc = CyclotomicCrtNtt::::zero(); + for (a_ntt, x_ntt) in row_ntt.iter().zip(ntt_vec.iter()) { + accumulate_pointwise_product_into(&mut acc, a_ntt, x_ntt, params); + } + acc.to_ring_with_params(params) + }) + .collect() + }) + .collect() +} + +#[cfg(test)] +pub(crate) fn mat_vec_mul_crt_ntt( + mat: &[Vec>], + vec: &[CyclotomicRing], +) -> Result>, HachiError> { + let params = select_crt_ntt_params::()?; + let out = match ¶ms { + ProtocolCrtNttParams::Q32(p) => mat_vec_mul_dense_with_params(mat, vec, p), + ProtocolCrtNttParams::Q64(p) => mat_vec_mul_dense_with_params(mat, vec, p), + ProtocolCrtNttParams::Q128(p) => mat_vec_mul_dense_with_params(mat, vec, p), + }; + Ok(out) +} + +#[cfg(test)] +pub(crate) fn mat_vec_mul_crt_ntt_many( + mat: &[Vec>], + vecs: &[Vec>], +) -> Result>>, HachiError> { + let params = select_crt_ntt_params::()?; + let out = match ¶ms { + ProtocolCrtNttParams::Q32(p) => mat_vec_mul_dense_many_with_params(mat, vecs, p), + ProtocolCrtNttParams::Q64(p) => mat_vec_mul_dense_many_with_params(mat, vecs, p), + ProtocolCrtNttParams::Q128(p) => mat_vec_mul_dense_many_with_params(mat, vecs, p), + }; + Ok(out) +} + +fn unreduced_quotient_ntt( + ntt_row: &[CyclotomicCrtNtt], + cyc_row: &[CyclotomicCrtNtt], + vec_neg: &[CyclotomicCrtNtt], + vec_cyc: &[CyclotomicCrtNtt], + params: &CrtNttParamSet, +) -> CyclotomicRing +where + F: FieldCore + CanonicalField, + W: PrimeWidth, +{ + let n = ntt_row.len().min(vec_neg.len()); + + let mut acc_neg = CyclotomicCrtNtt::::zero(); + let mut acc_cyc = CyclotomicCrtNtt::::zero(); + + for j in 0..n { + accumulate_pointwise_product_into(&mut acc_neg, &ntt_row[j], 
&vec_neg[j], params); + accumulate_pointwise_product_into(&mut acc_cyc, &cyc_row[j], &vec_cyc[j], params); + } + + let neg_ring: CyclotomicRing = acc_neg.to_ring_with_params(params); + let cyc_ring: CyclotomicRing = acc_cyc.to_ring_cyclic(params); + + let neg_coeffs = neg_ring.coefficients(); + let cyc_coeffs = cyc_ring.coefficients(); + let quotient: [F; D] = from_fn(|k| (cyc_coeffs[k] - neg_coeffs[k]) * F::TWO_INV); + CyclotomicRing::from_coefficients(quotient) +} + +macro_rules! dispatch_slot_quotient { + ($slot:expr, $vec:expr, $convert_neg:ident, $convert_cyc:ident, $quotient_fn:ident) => {{ + match $slot { + NttSlotCache::Q32 { + neg, + cyc, + params: p, + } => { + let v = $vec; + let n = neg.first().map_or(0, |r| r.len().min(v.len())); + let v_neg: Vec<_> = cfg_iter!(v[..n]) + .map(|x| CyclotomicCrtNtt::$convert_neg(x, p)) + .collect(); + let v_cyc: Vec<_> = cfg_iter!(v[..n]) + .map(|x| CyclotomicCrtNtt::$convert_cyc(x, p)) + .collect(); + cfg_into_iter!(0..neg.len()) + .map(|i| $quotient_fn(&neg[i], &cyc[i], &v_neg, &v_cyc, p)) + .collect() + } + NttSlotCache::Q64 { + neg, + cyc, + params: p, + } => { + let v = $vec; + let n = neg.first().map_or(0, |r| r.len().min(v.len())); + let v_neg: Vec<_> = cfg_iter!(v[..n]) + .map(|x| CyclotomicCrtNtt::$convert_neg(x, p)) + .collect(); + let v_cyc: Vec<_> = cfg_iter!(v[..n]) + .map(|x| CyclotomicCrtNtt::$convert_cyc(x, p)) + .collect(); + cfg_into_iter!(0..neg.len()) + .map(|i| $quotient_fn(&neg[i], &cyc[i], &v_neg, &v_cyc, p)) + .collect() + } + NttSlotCache::Q128 { + neg, + cyc, + params: p, + } => { + let v = $vec; + let n = neg.first().map_or(0, |r| r.len().min(v.len())); + let v_neg: Vec<_> = cfg_iter!(v[..n]) + .map(|x| CyclotomicCrtNtt::$convert_neg(x, p)) + .collect(); + let v_cyc: Vec<_> = cfg_iter!(v[..n]) + .map(|x| CyclotomicCrtNtt::$convert_cyc(x, p)) + .collect(); + cfg_into_iter!(0..neg.len()) + .map(|i| $quotient_fn(&neg[i], &cyc[i], &v_neg, &v_cyc, p)) + .collect() + } + } + }}; +} + +/// Compute 
unreduced quotients for matrix rows against a witness vector. +/// +/// For each row: `r_i = high_part(sum_j row_ij * vec_j) = (cyc - neg) / 2`. +/// Vec NTT conversions and matrix cyclic NTT are precomputed once (not per-row). +pub fn unreduced_quotient_rows_ntt_cached( + slot: &NttSlotCache, + vec: &[CyclotomicRing], +) -> Vec> { + dispatch_slot_quotient!( + slot, + vec, + from_ring_with_params, + from_ring_cyclic, + unreduced_quotient_ntt + ) +} + +macro_rules! dispatch_slot { + ($slot:expr, $func:ident $(, $arg:expr)*) => {{ + match $slot { + NttSlotCache::Q32 { neg, params: p, .. } => $func(neg, $($arg,)* p), + NttSlotCache::Q64 { neg, params: p, .. } => $func(neg, $($arg,)* p), + NttSlotCache::Q128 { neg, params: p, .. } => $func(neg, $($arg,)* p), + } + }}; +} + +/// Flatten a nested `Vec>` into a contiguous `Vec<[i8; D]>` using +/// bulk memcpy per block, avoiding element-by-element iteration. +pub fn flatten_i8_blocks(blocks: &[Vec<[i8; D]>]) -> Vec<[i8; D]> { + let total: usize = blocks.iter().map(|b| b.len()).sum(); + let mut flat = Vec::with_capacity(total); + for block in blocks { + flat.extend_from_slice(block); + } + flat +} + +/// Basis-decompose a block of ring elements into `block.len() * num_digits` gadget components. +pub fn decompose_block( + block: &[CyclotomicRing], + num_digits: usize, + log_basis: u32, +) -> Vec> { + let mut out = vec![CyclotomicRing::::zero(); block.len() * num_digits]; + for (i, coeff_vec) in block.iter().enumerate() { + coeff_vec.balanced_decompose_pow2_into( + &mut out[i * num_digits..(i + 1) * num_digits], + log_basis, + ); + } + out +} + +/// Decompose each ring element in `rows` into `num_digits` gadget components. 
+pub fn decompose_rows( + rows: &[CyclotomicRing], + num_digits: usize, + log_basis: u32, +) -> Vec> { + let mut out = vec![CyclotomicRing::::zero(); rows.len() * num_digits]; + for (i, row) in rows.iter().enumerate() { + row.balanced_decompose_pow2_into(&mut out[i * num_digits..(i + 1) * num_digits], log_basis); + } + out +} + +/// Decompose each ring element where the last digit carries the remainder. +pub fn decompose_rows_with_carry( + rows: &[CyclotomicRing], + delta: usize, + log_basis: u32, +) -> Vec> { + let mut out = Vec::with_capacity(rows.len() * delta); + for row in rows { + out.extend(row.balanced_decompose_pow2_with_carry(delta, log_basis)); + } + out +} + +/// Like [`decompose_block`] but outputs `[i8; D]` digit planes instead of ring elements. +pub fn decompose_block_i8( + block: &[CyclotomicRing], + num_digits: usize, + log_basis: u32, +) -> Vec<[i8; D]> { + let mut out = Vec::with_capacity(block.len() * num_digits); + for coeff_vec in block { + out.extend(coeff_vec.balanced_decompose_pow2_i8(num_digits, log_basis)); + } + out +} + +/// Like [`decompose_rows`] but outputs `[i8; D]` digit planes instead of ring elements. 
+pub fn decompose_rows_i8( + rows: &[CyclotomicRing], + num_digits: usize, + log_basis: u32, +) -> Vec<[i8; D]> { + let mut out = Vec::with_capacity(rows.len() * num_digits); + for row in rows { + out.extend(row.balanced_decompose_pow2_i8(num_digits, log_basis)); + } + out +} + +#[inline] +fn is_zero_plane(plane: &[i8; D]) -> bool { + plane.iter().all(|&d| d == 0) +} + +#[cfg(target_arch = "aarch64")] +const TARGET_L2_CACHE_BYTES: usize = 4 * 1024 * 1024; +#[cfg(target_arch = "x86_64")] +const TARGET_L2_CACHE_BYTES: usize = 1024 * 1024; +#[cfg(not(any(target_arch = "aarch64", target_arch = "x86_64")))] +const TARGET_L2_CACHE_BYTES: usize = 1024 * 1024; + +#[inline] +#[allow(dead_code)] +fn add_ntt_into( + acc: &mut CyclotomicCrtNtt, + other: &CyclotomicCrtNtt, + params: &CrtNttParamSet, +) { + #[cfg(target_arch = "aarch64")] + if neon::use_neon_ntt() { + for k in 0..K { + let prime = params.primes[k]; + unsafe { + if size_of::() == size_of::() { + neon::add_reduce_i32( + acc.limbs[k].as_mut_ptr() as *mut i32, + other.limbs[k].as_ptr() as *const i32, + D, + prime.p.to_i64() as i32, + ); + } else { + neon::add_reduce_i16( + acc.limbs[k].as_mut_ptr() as *mut i16, + other.limbs[k].as_ptr() as *const i16, + D, + prime.p.to_i64() as i16, + ); + } + } + } + return; + } + + for k in 0..K { + let prime = params.primes[k]; + for d in 0..D { + let sum = + MontCoeff::from_raw(acc.limbs[k][d].raw().wrapping_add(other.limbs[k][d].raw())); + acc.limbs[k][d] = prime.reduce_range(sum); + } + } +} + +/// Column-tiled A*x across multiple blocks simultaneously. +/// +/// Each rayon thread owns one column tile of `ntt_mat` (sized to fit in L2 +/// cache) and iterates over all blocks, accumulating partial NTT results. +/// The matrix is loaded from DRAM exactly once. A final reduction sums +/// partial accumulators across tiles for each block. +/// +/// Accepts raw ring-coefficient slices per block. 
Decomposes to i8 digits +/// on-the-fly per tile to avoid materializing all digits at once. +/// Tile width is auto-computed from ring parameters and target L2 cache size. +#[tracing::instrument(skip_all, name = "mat_vec_mul_ntt_i8")] +pub fn mat_vec_mul_ntt_i8( + slot: &NttSlotCache, + blocks: &[&[CyclotomicRing]], + num_digits: usize, + log_basis: u32, +) -> Vec>> { + dispatch_slot!( + slot, + mat_vec_mul_i8_with_params, + blocks, + num_digits, + log_basis + ) +} + +/// Column-tiled A*x across multiple blocks of pre-decomposed i8 digit planes. +/// +/// This is the `num_digits_commit = 1` specialization of +/// [`mat_vec_mul_ntt_i8`]. It skips the `CyclotomicRing -> i8 digit plane` +/// decomposition entirely because the caller already holds each coefficient as a +/// balanced digit plane. +#[tracing::instrument(skip_all, name = "mat_vec_mul_ntt_digits_i8")] +pub fn mat_vec_mul_ntt_digits_i8( + slot: &NttSlotCache, + blocks: &[&[[i8; D]]], +) -> Vec>> { + dispatch_slot!(slot, mat_vec_mul_digits_i8_with_params, blocks) +} + +fn mat_vec_mul_digits_i8_with_params< + F: FieldCore + CanonicalField, + W: PrimeWidth, + const K: usize, + const D: usize, +>( + ntt_mat: &[Vec>], + blocks: &[&[[i8; D]]], + params: &CrtNttParamSet, +) -> Vec>> { + let num_blocks = blocks.len(); + if num_blocks == 0 { + return vec![]; + } + let n_a = ntt_mat.len(); + let inner_width = ntt_mat.first().map_or(0, |row| row.len()); + if inner_width == 0 || n_a == 0 { + return vec![vec![CyclotomicRing::::zero(); n_a]; num_blocks]; + } + + let lut = DigitMontLut::new(params); + let tw = (TARGET_L2_CACHE_BYTES / (K * D * size_of::())).max(1); + let num_tiles = inner_width.div_ceil(tw); + + let final_accs: Vec>> = cfg_fold_reduce!( + 0..num_tiles, + || vec![vec![CyclotomicCrtNtt::::zero(); n_a]; num_blocks], + |mut accs: Vec>>, tile_idx| { + let tile_start = tile_idx * tw; + let tile_end = (tile_start + tw).min(inner_width); + + for block_idx in 0..num_blocks { + let block = blocks[block_idx]; + if 
tile_start >= block.len() { + continue; + } + let block_tile_end = tile_end.min(block.len()); + for (j, digit) in block[tile_start..block_tile_end].iter().enumerate() { + if is_zero_plane(digit) { + continue; + } + let ntt_d = CyclotomicCrtNtt::from_i8_with_lut(digit, params, &lut); + for (acc, mat_row) in accs[block_idx].iter_mut().zip(ntt_mat.iter()) { + accumulate_pointwise_product_into( + acc, + &mat_row[tile_start + j], + &ntt_d, + params, + ); + } + } + } + accs + }, + |mut a: Vec>>, b| { + for block_idx in 0..num_blocks { + for row in 0..n_a { + add_ntt_into(&mut a[block_idx][row], &b[block_idx][row], params); + } + } + a + } + ); + + cfg_into_iter!(final_accs) + .map(|row_accs| { + row_accs + .into_iter() + .map(|acc| acc.to_ring_with_params(params)) + .collect() + }) + .collect() +} + +fn mat_vec_mul_i8_with_params< + F: FieldCore + CanonicalField, + W: PrimeWidth, + const K: usize, + const D: usize, +>( + ntt_mat: &[Vec>], + blocks: &[&[CyclotomicRing]], + num_digits: usize, + log_basis: u32, + params: &CrtNttParamSet, +) -> Vec>> { + let num_blocks = blocks.len(); + if num_blocks == 0 { + return vec![]; + } + let n_a = ntt_mat.len(); + let inner_width = ntt_mat.first().map_or(0, |row| row.len()); + if inner_width == 0 || n_a == 0 { + return vec![vec![CyclotomicRing::::zero(); n_a]; num_blocks]; + } + + let lut = DigitMontLut::new(params); + let tw = (TARGET_L2_CACHE_BYTES / (K * D * size_of::())).max(1); + let num_tiles = inner_width.div_ceil(tw); + + let final_accs: Vec>> = cfg_fold_reduce!( + 0..num_tiles, + || vec![vec![CyclotomicCrtNtt::::zero(); n_a]; num_blocks], + |mut accs: Vec>>, tile_idx| { + let tile_start = tile_idx * tw; + let tile_end = (tile_start + tw).min(inner_width); + let ring_start = tile_start / num_digits; + let ring_end = ((tile_end - 1) / num_digits) + 1; + let digit_offset = tile_start - ring_start * num_digits; + let tile_len = tile_end - tile_start; + + for block_idx in 0..num_blocks { + let block = blocks[block_idx]; + if 
ring_start >= block.len() { + continue; + } + let block_ring_end = ring_end.min(block.len()); + let partial_coeffs = &block[ring_start..block_ring_end]; + let all_digits = decompose_block_i8(partial_coeffs, num_digits, log_basis); + let available = all_digits.len().saturating_sub(digit_offset); + let n = tile_len.min(available); + + for (j, digit) in all_digits[digit_offset..digit_offset + n] + .iter() + .enumerate() + { + if is_zero_plane(digit) { + continue; + } + let ntt_d = CyclotomicCrtNtt::from_i8_with_lut(digit, params, &lut); + for (acc, mat_row) in accs[block_idx].iter_mut().zip(ntt_mat.iter()) { + accumulate_pointwise_product_into( + acc, + &mat_row[tile_start + j], + &ntt_d, + params, + ); + } + } + } + accs + }, + |mut a: Vec>>, b| { + for block_idx in 0..num_blocks { + for row in 0..n_a { + add_ntt_into(&mut a[block_idx][row], &b[block_idx][row], params); + } + } + a + } + ); + + cfg_into_iter!(final_accs) + .map(|row_accs| { + row_accs + .into_iter() + .map(|acc| acc.to_ring_with_params(params)) + .collect() + }) + .collect() +} + +/// Column-tiled mat-vec for a single pre-decomposed i8 digit vector. +/// +/// Same tiling strategy as [`mat_vec_mul_ntt_i8`] but for a single +/// input vector of i8 digit planes (already decomposed). Tiles the matrix +/// columns to keep each tile in L2, eliminating the full `ntt_vec` +/// materialization of the non-tiled path. +/// Tile width is auto-computed from ring parameters and target L2 cache size. 
+#[tracing::instrument(skip_all, name = "mat_vec_mul_ntt_single_i8")] +pub fn mat_vec_mul_ntt_single_i8( + slot: &NttSlotCache, + vec: &[[i8; D]], +) -> Vec> { + dispatch_slot!(slot, mat_vec_mul_single_i8_with_params, vec) +} + +fn mat_vec_mul_single_i8_with_params< + F: FieldCore + CanonicalField, + W: PrimeWidth, + const K: usize, + const D: usize, +>( + ntt_mat: &[Vec>], + vec: &[[i8; D]], + params: &CrtNttParamSet, +) -> Vec> { + let n_a = ntt_mat.len(); + let inner_width = ntt_mat.first().map_or(0, |row| row.len()); + if inner_width == 0 || n_a == 0 { + return vec![CyclotomicRing::::zero(); n_a]; + } + + let lut = DigitMontLut::new(params); + let vec_len = vec.len().min(inner_width); + let tw = (TARGET_L2_CACHE_BYTES / (K * D * size_of::())).max(1); + let num_tiles = vec_len.div_ceil(tw); + + let final_accs: Vec> = cfg_fold_reduce!( + 0..num_tiles, + || vec![CyclotomicCrtNtt::::zero(); n_a], + |mut accs: Vec>, tile_idx| { + let tile_start = tile_idx * tw; + let tile_end = (tile_start + tw).min(vec_len); + for (j, digit) in vec[tile_start..tile_end].iter().enumerate() { + if is_zero_plane(digit) { + continue; + } + let ntt_d = CyclotomicCrtNtt::from_i8_with_lut(digit, params, &lut); + for (acc, mat_row) in accs.iter_mut().zip(ntt_mat.iter()) { + accumulate_pointwise_product_into( + acc, + &mat_row[tile_start + j], + &ntt_d, + params, + ); + } + } + accs + }, + |mut a: Vec>, b| { + for row in 0..n_a { + add_ntt_into(&mut a[row], &b[row], params); + } + a + } + ); + + final_accs + .into_iter() + .map(|acc| acc.to_ring_with_params(params)) + .collect() +} + +/// Like [`unreduced_quotient_rows_ntt_cached`] but accepts i8 digit planes +/// instead of ring elements, using direct i8 -> CRT+NTT conversion. +/// Column-tiled with zero-skip for all-zero digit planes. 
+#[tracing::instrument(skip_all, name = "unreduced_quotient_rows_ntt_cached_i8")] +pub fn unreduced_quotient_rows_ntt_cached_i8( + slot: &NttSlotCache, + vec: &[[i8; D]], +) -> Vec> { + match slot { + NttSlotCache::Q32 { + neg, + cyc, + params: p, + } => quotient_single_i8_with_params(neg, cyc, vec, p), + NttSlotCache::Q64 { + neg, + cyc, + params: p, + } => quotient_single_i8_with_params(neg, cyc, vec, p), + NttSlotCache::Q128 { + neg, + cyc, + params: p, + } => quotient_single_i8_with_params(neg, cyc, vec, p), + } +} + +fn quotient_single_i8_with_params< + F: FieldCore + CanonicalField, + W: PrimeWidth, + const K: usize, + const D: usize, +>( + ntt_neg: &[Vec>], + ntt_cyc: &[Vec>], + vec: &[[i8; D]], + params: &CrtNttParamSet, +) -> Vec> { + let n_a = ntt_neg.len(); + let inner_width = ntt_neg.first().map_or(0, |row| row.len()); + if inner_width == 0 || n_a == 0 { + return vec![CyclotomicRing::::zero(); n_a]; + } + + let lut = DigitMontLut::new(params); + let vec_len = vec.len().min(inner_width); + let tw = (TARGET_L2_CACHE_BYTES / (K * D * size_of::())).max(1); + let num_tiles = vec_len.div_ceil(tw); + + let zero = CyclotomicCrtNtt::::zero(); + + let (final_neg, final_cyc): ( + Vec>, + Vec>, + ) = cfg_fold_reduce!( + 0..num_tiles, + || (vec![zero.clone(); n_a], vec![zero.clone(); n_a]), + |mut accs: ( + Vec>, + Vec> + ), + tile_idx| { + let tile_start = tile_idx * tw; + let tile_end = (tile_start + tw).min(vec_len); + for (j, digit) in vec[tile_start..tile_end].iter().enumerate() { + if is_zero_plane(digit) { + continue; + } + let ntt_d_neg = CyclotomicCrtNtt::from_i8_with_lut(digit, params, &lut); + let ntt_d_cyc = CyclotomicCrtNtt::from_i8_cyclic_with_lut(digit, params, &lut); + let col = tile_start + j; + for (row, (acc_neg, acc_cyc)) in + accs.0.iter_mut().zip(accs.1.iter_mut()).enumerate() + { + accumulate_pointwise_product_into( + acc_neg, + &ntt_neg[row][col], + &ntt_d_neg, + params, + ); + accumulate_pointwise_product_into( + acc_cyc, + 
&ntt_cyc[row][col], + &ntt_d_cyc, + params, + ); + } + } + accs + }, + |mut a: ( + Vec>, + Vec> + ), + b| { + for row in 0..n_a { + add_ntt_into(&mut a.0[row], &b.0[row], params); + add_ntt_into(&mut a.1[row], &b.1[row], params); + } + a + } + ); + + final_neg + .into_iter() + .zip(final_cyc) + .map(|(neg_acc, cyc_acc)| { + let neg_ring: CyclotomicRing = neg_acc.to_ring_with_params(params); + let cyc_ring: CyclotomicRing = cyc_acc.to_ring_cyclic(params); + let neg_c = neg_ring.coefficients(); + let cyc_c = cyc_ring.coefficients(); + let q: [F; D] = from_fn(|k| (cyc_c[k] - neg_c[k]) * F::TWO_INV); + CyclotomicRing::from_coefficients(q) + }) + .collect() +} + +#[cfg(test)] +mod tests { + use super::{ + mat_vec_mul_crt_ntt, mat_vec_mul_crt_ntt_many, mat_vec_mul_digits_i8_with_params, + mat_vec_mul_i8_with_params, mat_vec_mul_unchecked, precompute_dense_mat_ntt_with_params, + }; + use crate::algebra::{CyclotomicRing, Fp64}; + use crate::protocol::commitment::utils::crt_ntt::{ + select_crt_ntt_params, ProtocolCrtNttParams, + }; + use crate::FromSmallInt; + + #[test] + fn dense_mat_vec_matches_schoolbook_q32_d64() { + type F = Fp64<4294967197>; + const D: usize = 64; + let mat: Vec>> = (0..3) + .map(|i| { + (0..4) + .map(|j| { + let coeffs = std::array::from_fn(|k| { + F::from_u64((i as u64 * 10_000 + j as u64 * 100 + k as u64 + 1) % 97) + }); + CyclotomicRing::from_coefficients(coeffs) + }) + .collect() + }) + .collect(); + let vec: Vec> = (0..4) + .map(|j| { + let coeffs = + std::array::from_fn(|k| F::from_u64((j as u64 * 50 + k as u64 + 3) % 89)); + CyclotomicRing::from_coefficients(coeffs) + }) + .collect(); + + let schoolbook = mat_vec_mul_unchecked(&mat, &vec); + let crt_ntt = mat_vec_mul_crt_ntt(&mat, &vec).expect("Q32 dispatch should succeed"); + assert_eq!(schoolbook, crt_ntt); + } + + #[test] + fn dense_mat_vec_matches_schoolbook_q64_dispatch_for_large_d() { + type F = Fp64<4294967197>; + const D: usize = 128; + let mat: Vec>> = (0..2) + .map(|i| { + (0..2) + 
.map(|j| { + let coeffs = std::array::from_fn(|k| { + F::from_u64((i as u64 * 20_000 + j as u64 * 300 + k as u64 + 7) % 113) + }); + CyclotomicRing::from_coefficients(coeffs) + }) + .collect() + }) + .collect(); + let vec: Vec> = (0..2) + .map(|j| { + let coeffs = + std::array::from_fn(|k| F::from_u64((j as u64 * 70 + k as u64 + 11) % 101)); + CyclotomicRing::from_coefficients(coeffs) + }) + .collect(); + + let schoolbook = mat_vec_mul_unchecked(&mat, &vec); + let crt_ntt = mat_vec_mul_crt_ntt(&mat, &vec).expect("Q64 dispatch should succeed"); + assert_eq!(schoolbook, crt_ntt); + } + + #[test] + fn dense_mat_vec_many_matches_individual_crt_ntt_q32_d64() { + type F = Fp64<4294967197>; + const D: usize = 64; + let mat: Vec>> = (0..3) + .map(|i| { + (0..4) + .map(|j| { + let coeffs = std::array::from_fn(|k| { + F::from_u64((i as u64 * 10_000 + j as u64 * 100 + k as u64 + 1) % 97) + }); + CyclotomicRing::from_coefficients(coeffs) + }) + .collect() + }) + .collect(); + + let vecs: Vec>> = (0..3) + .map(|seed| { + (0..4) + .map(|j| { + let coeffs = std::array::from_fn(|k| { + F::from_u64((seed as u64 * 700 + j as u64 * 50 + k as u64 + 3) % 89) + }); + CyclotomicRing::from_coefficients(coeffs) + }) + .collect() + }) + .collect(); + + let expected: Vec>> = vecs + .iter() + .map(|v| mat_vec_mul_crt_ntt(&mat, v).expect("single CRT+NTT mat-vec should succeed")) + .collect(); + + let got = + mat_vec_mul_crt_ntt_many(&mat, &vecs).expect("batched CRT+NTT mat-vec should succeed"); + assert_eq!(expected, got); + } + + #[test] + fn mat_vec_mul_digits_i8_matches_num_digits_one_roundtrip() { + type F = Fp64<4294967197>; + const D: usize = 64; + let log_basis = 3; + + let mat: Vec>> = (0..3) + .map(|i| { + (0..6) + .map(|j| { + let coeffs = std::array::from_fn(|k| { + let raw = (i as i64 * 19 + j as i64 * 7 + k as i64) % 7; + F::from_i64(raw - 3) + }); + CyclotomicRing::from_coefficients(coeffs) + }) + .collect() + }) + .collect(); + + let digit_blocks: Vec> = vec![ + (0..6) + 
.map(|j| std::array::from_fn(|k| ((j + 2 * k) % 7) as i8 - 3)) + .collect(), + (0..4) + .map(|j| std::array::from_fn(|k| ((2 * j + k) % 7) as i8 - 3)) + .collect(), + vec![], + ]; + + let ring_blocks: Vec>> = digit_blocks + .iter() + .map(|block| { + block + .iter() + .map(|digit| { + let coeffs = std::array::from_fn(|k| F::from_i64(digit[k] as i64)); + CyclotomicRing::from_coefficients(coeffs) + }) + .collect() + }) + .collect(); + + let ring_block_slices: Vec<&[CyclotomicRing]> = + ring_blocks.iter().map(Vec::as_slice).collect(); + let digit_block_slices: Vec<&[[i8; D]]> = digit_blocks.iter().map(Vec::as_slice).collect(); + + match select_crt_ntt_params::().expect("CRT+NTT params should exist") { + ProtocolCrtNttParams::Q32(params) => { + let ntt_mat = precompute_dense_mat_ntt_with_params(&mat, ¶ms); + let via_roundtrip = + mat_vec_mul_i8_with_params(&ntt_mat, &ring_block_slices, 1, log_basis, ¶ms); + let direct = + mat_vec_mul_digits_i8_with_params(&ntt_mat, &digit_block_slices, ¶ms); + assert_eq!(via_roundtrip, direct); + } + _ => panic!("unexpected parameter family"), + } + } +} diff --git a/src/protocol/commitment/utils/math.rs b/src/protocol/commitment/utils/math.rs new file mode 100644 index 00000000..16cf8618 --- /dev/null +++ b/src/protocol/commitment/utils/math.rs @@ -0,0 +1,14 @@ +//! Small math helpers for commitment internals. + +use crate::error::HachiError; + +/// Compute `2^exp` with overflow checks. +/// +/// # Errors +/// +/// Returns `InvalidSetup` if `2^exp` does not fit in `usize`. +pub(in crate::protocol::commitment) fn checked_pow2(exp: usize) -> Result { + 1usize + .checked_shl(exp as u32) + .ok_or_else(|| HachiError::InvalidSetup(format!("2^{exp} does not fit usize"))) +} diff --git a/src/protocol/commitment/utils/matrix.rs b/src/protocol/commitment/utils/matrix.rs new file mode 100644 index 00000000..9be13b11 --- /dev/null +++ b/src/protocol/commitment/utils/matrix.rs @@ -0,0 +1,134 @@ +//! Matrix sampling helpers for setup. 
+ +use crate::algebra::ring::CyclotomicRing; +use crate::{FieldCore, FieldSampling}; +use rand_core::{CryptoRng, RngCore}; +use sha3::digest::{ExtendableOutput, Update, XofReader}; +use sha3::Shake256; + +/// Public seed used to derive commitment matrices. +pub(crate) type PublicMatrixSeed = [u8; 32]; + +const PUBLIC_MATRIX_DOMAIN: &[u8] = b"hachi/commitment/public-matrix"; + +/// Fixed public seed for deterministic, reproducible setup. +pub(crate) fn sample_public_matrix_seed() -> PublicMatrixSeed { + let mut seed = [0u8; 32]; + seed[..8].copy_from_slice(&0xDEAD_BEEF_CAFE_BABEu64.to_le_bytes()); + seed +} + +/// Derive a public matrix from a seed using domain-separated SHAKE expansion. +/// +/// This follows the same high-level pattern used in NIST lattice specs: +/// derive deterministic public structure from a seed + indices, then sample +/// coefficients via rejection-sampling at the field layer. +/// +/// NOTE: Potential future hardening: +/// move toward stricter ML-KEM/ML-DSA-style byte layout and parsing rules +/// (fixed-format seed/index encoding and scheme-specific expansion details) +/// if we decide to maximize standards-shape interoperability. +pub(crate) fn derive_public_matrix( + rows: usize, + cols: usize, + seed: &PublicMatrixSeed, + matrix_label: &[u8], +) -> Vec>> { + (0..rows) + .map(|r| { + (0..cols) + .map(|c| { + let mut entry_rng = ShakeXofRng::new(seed, matrix_label, rows, cols, r, c); + CyclotomicRing::random(&mut entry_rng) + }) + .collect() + }) + .collect() +} + +struct ShakeXofRng { + reader: Box, +} + +impl ShakeXofRng { + // Dimensions (`rows`, `cols`) are intentionally excluded from the domain + // separator so that a matrix derived at one size is a prefix of the same + // matrix derived at a larger size. Each entry is uniquely identified by + // `(seed, matrix_label, row, col)`, which is sufficient for collision + // resistance while enabling setup reuse across poly/mega-poly layouts. 
+ fn new( + seed: &PublicMatrixSeed, + matrix_label: &[u8], + _rows: usize, + _cols: usize, + row: usize, + col: usize, + ) -> Self { + let mut xof = Shake256::default(); + absorb_len_prefixed(&mut xof, b"domain", PUBLIC_MATRIX_DOMAIN); + absorb_len_prefixed(&mut xof, b"seed", seed); + absorb_len_prefixed(&mut xof, b"matrix", matrix_label); + absorb_len_prefixed(&mut xof, b"row", &(row as u64).to_le_bytes()); + absorb_len_prefixed(&mut xof, b"col", &(col as u64).to_le_bytes()); + Self { + reader: Box::new(xof.finalize_xof()), + } + } +} + +impl RngCore for ShakeXofRng { + fn next_u32(&mut self) -> u32 { + let mut buf = [0u8; 4]; + self.fill_bytes(&mut buf); + u32::from_le_bytes(buf) + } + + fn next_u64(&mut self) -> u64 { + let mut buf = [0u8; 8]; + self.fill_bytes(&mut buf); + u64::from_le_bytes(buf) + } + + fn fill_bytes(&mut self, dest: &mut [u8]) { + self.reader.read(dest); + } + + fn try_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), rand_core::Error> { + self.fill_bytes(dest); + Ok(()) + } +} + +impl CryptoRng for ShakeXofRng {} + +fn absorb_len_prefixed(xof: &mut Shake256, label: &[u8], data: &[u8]) { + xof.update(&(label.len() as u64).to_le_bytes()); + xof.update(label); + xof.update(&(data.len() as u64).to_le_bytes()); + xof.update(data); +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::algebra::Fp64; + + type F = Fp64<4294967197>; + const D: usize = 64; + + #[test] + fn matrix_derivation_is_deterministic_for_same_seed() { + let seed = [42u8; 32]; + let m1 = derive_public_matrix::(3, 5, &seed, b"A"); + let m2 = derive_public_matrix::(3, 5, &seed, b"A"); + assert_eq!(m1, m2); + } + + #[test] + fn matrix_derivation_domain_separates_labels() { + let seed = [7u8; 32]; + let a = derive_public_matrix::(2, 3, &seed, b"A"); + let b = derive_public_matrix::(2, 3, &seed, b"B"); + assert_ne!(a, b); + } +} diff --git a/src/protocol/commitment/utils/mod.rs b/src/protocol/commitment/utils/mod.rs new file mode 100644 index 00000000..6e768fb2 --- 
/dev/null +++ b/src/protocol/commitment/utils/mod.rs @@ -0,0 +1,9 @@ +//! Utility helpers for commitment internals. + +pub mod crt_ntt; +pub mod flat_matrix; +pub mod linear; +pub(crate) mod math; +pub(crate) mod matrix; +pub(crate) mod norm; +pub mod ntt_cache; diff --git a/src/protocol/commitment/utils/norm.rs b/src/protocol/commitment/utils/norm.rs new file mode 100644 index 00000000..cc01f9a7 --- /dev/null +++ b/src/protocol/commitment/utils/norm.rs @@ -0,0 +1,48 @@ +//! Infinity norm utilities for ring elements over Z_q. + +use crate::algebra::ring::CyclotomicRing; +use crate::CanonicalField; + +/// Detect the field modulus from the canonical representation. +/// +/// Uses the identity: the canonical form of `−1` in `Z_q` is `q − 1`. +pub(crate) fn detect_field_modulus() -> u128 { + (-F::one()).to_canonical_u128() + 1 +} + +/// Centered absolute value of a field element. +/// +/// Maps canonical representation `v ∈ [0, q)` to `min(v, q − v)`. +#[inline] +pub(crate) fn centered_abs(x: F, modulus: u128) -> u128 { + let v = x.to_canonical_u128(); + let half = modulus / 2; + if v <= half { + v + } else { + modulus - v + } +} + +/// L∞ norm of a single ring element (maximum centered coefficient magnitude). +pub(crate) fn ring_inf_norm( + r: &CyclotomicRing, + modulus: u128, +) -> u128 { + r.coefficients() + .iter() + .map(|c| centered_abs(*c, modulus)) + .max() + .unwrap_or(0) +} + +/// L∞ norm of a vector of ring elements. +pub(crate) fn vec_inf_norm( + v: &[CyclotomicRing], + modulus: u128, +) -> u128 { + v.iter() + .map(|r| ring_inf_norm(r, modulus)) + .max() + .unwrap_or(0) +} diff --git a/src/protocol/commitment/utils/ntt_cache.rs b/src/protocol/commitment/utils/ntt_cache.rs new file mode 100644 index 00000000..6e29bd26 --- /dev/null +++ b/src/protocol/commitment/utils/ntt_cache.rs @@ -0,0 +1,105 @@ +//! Multi-D NTT cache management. +//! +//! Wraps per-D [`NttSlotCache`] bundles with lazy computation and memoization. +//! 
A single [`MultiDNttCaches`] can hold NTT caches for any subset of supported +//! ring dimensions, built on demand from a shared [`FlatMatrix`]. + +use super::crt_ntt::{build_ntt_slot, NttSlotCache}; +use super::flat_matrix::FlatMatrix; +use crate::error::HachiError; +use crate::{CanonicalField, FieldCore}; + +/// Per-matrix NTT caches for multiple ring dimensions. +/// +/// Each field is lazily populated by the `get_or_build_*` methods. +/// Fields use `Box>` to keep the struct's inline size +/// small: `NttSlotCache<1024>` alone is ~80 KB due to inline twiddle +/// arrays, so storing them unboxed would make this struct ~155 KB. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct MultiDNttCaches { + /// Cache for D=64. + pub d64: Option>>, + /// Cache for D=128. + pub d128: Option>>, + /// Cache for D=256. + pub d256: Option>>, + /// Cache for D=512. + pub d512: Option>>, + /// Cache for D=1024. + pub d1024: Option>>, +} + +macro_rules! impl_get_or_build { + ($fn_name:ident, $field:ident, $d_val:expr) => { + /// Get (or build and memoize) the NTT cache for this ring dimension. + /// + /// # Errors + /// + /// Returns an error if no CRT+NTT parameter set matches the field and D. + pub fn $fn_name( + &mut self, + mat: &FlatMatrix, + ) -> Result<&NttSlotCache<$d_val>, HachiError> { + if self.$field.is_none() { + self.$field = Some(Box::new(build_ntt_slot(mat.view::<$d_val>())?)); + } + Ok(self.$field.as_deref().unwrap()) + } + }; +} + +impl MultiDNttCaches { + /// Empty cache set. + pub fn new() -> Self { + Self { + d64: None, + d128: None, + d256: None, + d512: None, + d1024: None, + } + } + + impl_get_or_build!(get_or_build_64, d64, 64); + impl_get_or_build!(get_or_build_128, d128, 128); + impl_get_or_build!(get_or_build_256, d256, 256); + impl_get_or_build!(get_or_build_512, d512, 512); + impl_get_or_build!(get_or_build_1024, d1024, 1024); + + /// Check if a cache for dimension `d` is already populated. 
+ pub fn has(&self, d: usize) -> bool { + match d { + 64 => self.d64.is_some(), + 128 => self.d128.is_some(), + 256 => self.d256.is_some(), + 512 => self.d512.is_some(), + 1024 => self.d1024.is_some(), + _ => false, + } + } +} + +impl Default for MultiDNttCaches { + fn default() -> Self { + Self::new() + } +} + +/// Bundle of three multi-D NTT caches for the A, B, and D matrices. +#[derive(Debug, Clone, Default, PartialEq, Eq)] +#[allow(non_snake_case)] +pub struct MultiDNttBundle { + /// NTT caches for the A matrix at various ring dimensions. + pub A: MultiDNttCaches, + /// NTT caches for the B matrix at various ring dimensions. + pub B: MultiDNttCaches, + /// NTT caches for the D matrix at various ring dimensions. + pub D_mat: MultiDNttCaches, +} + +impl MultiDNttBundle { + /// Empty bundle. + pub fn new() -> Self { + Self::default() + } +} diff --git a/src/protocol/commitment_scheme.rs b/src/protocol/commitment_scheme.rs new file mode 100644 index 00000000..52ac9bd4 --- /dev/null +++ b/src/protocol/commitment_scheme.rs @@ -0,0 +1,1481 @@ +//! Commitment scheme trait implementation. 
+ +use crate::algebra::fields::wide::HasWide; +use crate::algebra::fields::HasUnreducedOps; +use crate::algebra::CyclotomicRing; +#[cfg(debug_assertions)] +use crate::algebra::SparseChallenge; +use crate::error::HachiError; +use crate::primitives::poly::multilinear_lagrange_basis; +use crate::primitives::serialization::Valid; +use crate::protocol::commitment::utils::crt_ntt::NttSlotCache; +use crate::protocol::commitment::utils::linear::{flatten_i8_blocks, mat_vec_mul_ntt_single_i8}; +use crate::protocol::commitment::utils::ntt_cache::MultiDNttBundle; +use crate::protocol::commitment::{ + AppendToTranscript, CommitmentConfig, CommitmentScheme, HachiCommitmentCore, + HachiCommitmentLayout, HachiExpandedSetup, HachiProverSetup, HachiVerifierSetup, + RingCommitment, RingCommitmentScheme, +}; +use crate::protocol::hachi_poly_ops::{BalancedDigitPoly, HachiPolyOps}; +use crate::protocol::opening_point::{BasisMode, RingOpeningPoint}; +use crate::protocol::proof::{ + FlatCommitmentHint, FlatRingVec, HachiCommitmentHint, HachiLevelProof, HachiProof, PackedDigits, +}; +#[cfg(any(test, debug_assertions))] +use crate::protocol::quadratic_equation::compute_m_a_reference; +use crate::protocol::quadratic_equation::QuadraticEquation; +use crate::protocol::ring_switch::eval_ring_at; +#[cfg(debug_assertions)] +use crate::protocol::ring_switch::m_row_count; +#[cfg(test)] +use crate::protocol::ring_switch::{build_alpha_evals_y, compute_m_evals_x}; +use crate::protocol::ring_switch::{ + build_w_evals, commit_w, ring_switch_build_w, ring_switch_finalize, ring_switch_verifier, + w_ring_element_count, RingSwitchOutput, WCommitmentConfig, +}; +use crate::protocol::sumcheck::eq_poly::EqPolynomial; +use crate::protocol::sumcheck::hachi_sumcheck::{HachiSumcheckProver, HachiSumcheckVerifier}; +#[cfg(debug_assertions)] +use crate::protocol::sumcheck::{multilinear_eval, range_check_eval}; +use crate::protocol::sumcheck::{multilinear_eval_small, prove_sumcheck, verify_sumcheck}; +use 
crate::protocol::transcript::labels::{ + ABSORB_COMMITMENT, ABSORB_EVALUATION_CLAIMS, CHALLENGE_SUMCHECK_BATCH, CHALLENGE_SUMCHECK_ROUND, +}; +use crate::protocol::transcript::Transcript; +use crate::{dispatch_ring_dim, dispatch_with_ntt}; +use crate::{CanonicalField, FieldCore, FieldSampling, FromSmallInt}; +#[cfg(debug_assertions)] +use std::iter; +use std::marker::PhantomData; +use std::time::Instant; + +#[cfg(test)] +use crate::protocol::ring_switch::expand_m_a; +#[cfg(test)] +use crate::protocol::transcript::labels::{ + ABSORB_SUMCHECK_W, CHALLENGE_RING_SWITCH, DOMAIN_HACHI_PROTOCOL, +}; +#[cfg(test)] +use crate::protocol::transcript::Blake2bTranscript; +#[cfg(test)] +use crate::protocol::SmallTestCommitmentConfig; + +/// Minimum w vector length (in field elements) below which further folding +/// is not beneficial. When `w.len() <= MIN_W_LEN_FOR_FOLDING`, the prover +/// sends `w` directly instead of recursing. +const MIN_W_LEN_FOR_FOLDING: usize = 4096; + +/// Minimum shrink ratio (next_w / prev_w) below which further folding +/// stops being worthwhile. If the w vector doesn't shrink by at least +/// this factor, the overhead of another fold level outweighs the saving. +const MIN_SHRINK_RATIO: f64 = 0.5; + +#[inline] +fn relation_claim_from_rows( + tau1: &[F], + alpha: F, + v: &[CyclotomicRing], + u: &[CyclotomicRing], + y_ring: &CyclotomicRing, +) -> F { + let eq_tau1 = EqPolynomial::evals(tau1); + let mut acc = F::zero(); + let mut row_idx = 0usize; + + for r in v { + if row_idx >= eq_tau1.len() { + return acc; + } + acc += eq_tau1[row_idx] * eval_ring_at(r, &alpha); + row_idx += 1; + } + for r in u { + if row_idx >= eq_tau1.len() { + return acc; + } + acc += eq_tau1[row_idx] * eval_ring_at(r, &alpha); + row_idx += 1; + } + if row_idx < eq_tau1.len() { + acc += eq_tau1[row_idx] * eval_ring_at(y_ring, &alpha); + } + acc +} + +/// End-to-end PCS wrapper, generic over ring degree `D` and config `Cfg`. 
+#[derive(Clone, Copy, Debug, Default)] +pub struct HachiCommitmentScheme { + _cfg: PhantomData, +} + +/// Output from a single prove level, needed to chain into the next level. +/// +/// D-agnostic: ring elements are erased into [`HachiLevelProof`] and +/// the commitment hint is stored as [`FlatCommitmentHint`]. +struct ProveLevelOutput { + level_proof: HachiLevelProof, + w: Vec, + w_hint: FlatCommitmentHint, + sumcheck_challenges: Vec, + num_u: usize, + num_l: usize, +} + +/// Prove one fold level: quad_eq -> ring_switch -> sumcheck. +/// +/// Generic over the commitment config so it works for both the original +/// polynomial (using `Cfg`) and recursive w-openings (using `WCommitmentConfig`). +type CommitFn<'a, F> = + Box Result<(FlatRingVec, FlatCommitmentHint), HachiError> + 'a>; + +#[cfg(debug_assertions)] +#[allow(clippy::too_many_arguments)] +#[inline(never)] +fn prove_level_diagnostic( + expanded: &HachiExpandedSetup, + opening_point: &RingOpeningPoint, + challenges: &[SparseChallenge], + rs: &RingSwitchOutput, + v: &[CyclotomicRing], + u: &[CyclotomicRing], + y_ring: &CyclotomicRing, + layout: HachiCommitmentLayout, + level: usize, +) where + F: FieldCore + CanonicalField + FieldSampling, + Cfg: CommitmentConfig, +{ + let m_a = + compute_m_a_reference::(expanded, opening_point, challenges, &rs.alpha, layout) + .expect("compute_m_a diagnostic failed"); + + let x_len = 1usize << rs.num_u; + let d = D; + + let mut w_at_alpha = vec![F::zero(); x_len]; + for (x, w_at_alpha_x) in w_at_alpha.iter_mut().enumerate() { + let mut val = F::zero(); + for y in 0..d { + let idx = x + y * x_len; + if idx < rs.w_evals.len() { + val += rs.alpha_evals_y[y] * F::from_i64(rs.w_evals[idx] as i64); + } + } + *w_at_alpha_x = val; + } + + let num_rows = m_row_count::(); + let y_full: Vec = v + .iter() + .chain(u.iter()) + .chain(iter::once(y_ring)) + .map(|r| eval_ring_at(r, &rs.alpha)) + .collect(); + + eprintln!( + " [hachi prove L{level}] per-row M*w=y diagnostic 
(num_rows={num_rows}, x_len={x_len}, m_a_cols={}):", + m_a.first().map_or(0, |r| r.len()), + ); + for i in 0..num_rows { + let mw_i: F = m_a[i] + .iter() + .enumerate() + .fold(F::zero(), |acc, (x, &m_ix)| { + acc + m_ix * w_at_alpha.get(x).copied().unwrap_or(F::zero()) + }); + let y_i = if i < y_full.len() { + y_full[i] + } else { + F::zero() + }; + let residual = mw_i - y_i; + let row_name = match i { + _ if i < Cfg::N_D => "D", + _ if i < Cfg::N_D + Cfg::N_B => "B", + _ if i == Cfg::N_D + Cfg::N_B => "bTw", + _ if i == Cfg::N_D + Cfg::N_B + 1 => "challenge_fold", + _ => "A", + }; + eprintln!( + " row {i} ({row_name}): match={}, residual_is_zero={}, mw_is_zero={}, y_is_zero={}", + residual.is_zero(), + residual.is_zero(), + mw_i.is_zero(), + y_i.is_zero(), + ); + } + + let verifier_claim = relation_claim_from_rows::(&rs.tau1, rs.alpha, v, u, y_ring); + let x_mask = x_len - 1; + let mut prover_claim = F::zero(); + for (idx, &w) in rs.w_evals.iter().enumerate() { + prover_claim += + F::from_i64(w as i64) * rs.alpha_evals_y[idx >> rs.num_u] * rs.m_evals_x[idx & x_mask]; + } + eprintln!( + " [hachi prove L{level}] relation_claim cross-check: match={}, prover_is_zero={}, verifier_is_zero={}", + verifier_claim == prover_claim, + prover_claim.is_zero(), + verifier_claim.is_zero(), + ); +} + +#[cfg(debug_assertions)] +#[allow(clippy::too_many_arguments)] +#[inline(never)] +fn prove_level_selfcheck( + tau0: &[F], + sumcheck_challenges: &[F], + w_eval: F, + b: usize, + batching_coeff: F, + alpha_evals_y: &[F], + m_evals_x: &[F], + num_u: usize, + final_claim: F, + level: usize, +) { + let eq_val = EqPolynomial::mle(tau0, sumcheck_challenges); + let norm_oracle = eq_val * range_check_eval(w_eval, b); + let (x_ch, y_ch) = sumcheck_challenges.split_at(num_u); + let alpha_val = multilinear_eval(alpha_evals_y, y_ch).unwrap(); + let m_val = multilinear_eval(m_evals_x, x_ch).unwrap(); + let relation_oracle = w_eval * alpha_val * m_val; + let prover_expected = batching_coeff * 
norm_oracle + relation_oracle; + if prover_expected != final_claim { + eprintln!(" [hachi prove L{level}] PROVER self-check FAILED: expected != final_claim"); + } else { + eprintln!(" [hachi prove L{level}] PROVER self-check OK"); + } +} + +#[allow(clippy::too_many_arguments)] +#[inline(never)] +fn prove_one_level( + expanded: &HachiExpandedSetup, + ntt_a: &NttSlotCache, + ntt_b: &NttSlotCache, + ntt_d: &NttSlotCache, + commit_w_fn: CommitFn<'_, F>, + poly: &P, + opening_point: &[F], + hint: HachiCommitmentHint, + transcript: &mut T, + commitment: &RingCommitment, + basis: BasisMode, + level: usize, + layout: HachiCommitmentLayout, +) -> Result, HachiError> +where + F: FieldCore + CanonicalField + FieldSampling + HasUnreducedOps + HasWide, + T: Transcript, + Cfg: CommitmentConfig, + P: HachiPolyOps, +{ + { + let x: u8 = 0; + eprintln!( + " [prove_one_level L{level}] stack ~= {:#x}", + &x as *const u8 as usize + ); + } + let alpha = Cfg::D.trailing_zeros() as usize; + if opening_point.len() < alpha { + return Err(HachiError::InvalidPointDimension { + expected: alpha, + actual: opening_point.len(), + }); + } + let target_num_vars = layout.m_vars + layout.r_vars + alpha; + let mut padded_point = opening_point.to_vec(); + padded_point.resize(target_num_vars, F::zero()); + let outer_point = &padded_point[alpha..]; + + let ring_opening_point = { + let _span = tracing::info_span!("ring_opening_point", level).entered(); + ring_opening_point_from_field::(outer_point, layout.r_vars, layout.m_vars, basis)? 
+ }; + + let t0 = Instant::now(); + let fold_scalars = &ring_opening_point.a; + let eval_outer_scalars = &ring_opening_point.b; + let (y_ring, w_folded) = { + let _span = tracing::info_span!("evaluate_and_fold", level).entered(); + poly.evaluate_and_fold(eval_outer_scalars, fold_scalars, layout.block_len) + }; + eprintln!( + " [hachi prove L{level}] evaluate_and_fold: {:.2}s (num_ring_elems={})", + t0.elapsed().as_secs_f64(), + poly.num_ring_elems() + ); + + commitment.append_to_transcript(ABSORB_COMMITMENT, transcript); + for pt in &padded_point { + transcript.append_field(ABSORB_EVALUATION_CLAIMS, pt); + } + transcript.append_serde(ABSORB_EVALUATION_CLAIMS, &y_ring); + + let t1 = Instant::now(); + let mut quad_eq = Box::new(QuadraticEquation::::new_prover( + ntt_d, + ring_opening_point, + poly, + w_folded, + hint, + transcript, + commitment, + &y_ring, + layout, + )?); + eprintln!( + " [hachi prove L{level}] quad_eq new_prover: {:.2}s", + t1.elapsed().as_secs_f64() + ); + + let t2 = Instant::now(); + let w = + ring_switch_build_w::(&mut quad_eq, expanded, ntt_a, ntt_b, ntt_d, layout)?; + eprintln!( + " [hachi prove L{level}] ring_switch_build_w: {:.2}s (w.len()={})", + t2.elapsed().as_secs_f64(), + w.len() + ); + + let t_cw = Instant::now(); + let (w_commitment_flat, w_hint_flat) = commit_w_fn(&w)?; + eprintln!( + " [hachi prove L{level}] commit_w: {:.2}s (ring_dim={})", + t_cw.elapsed().as_secs_f64(), + w_commitment_flat.ring_dim() + ); + + let rs = ring_switch_finalize::( + &quad_eq, + expanded, + transcript, + w, + w_commitment_flat, + w_hint_flat, + layout, + )?; + eprintln!( + " [hachi prove L{level}] ring_switch_finalize: {:.2}s (num_u={}, num_l={})", + t2.elapsed().as_secs_f64(), + rs.num_u, + rs.num_l + ); + + let batching_coeff: F = transcript.challenge_scalar(CHALLENGE_SUMCHECK_BATCH); + + #[cfg(debug_assertions)] + prove_level_diagnostic::( + expanded, + quad_eq.opening_point(), + &quad_eq.challenges, + &rs, + &quad_eq.v, + &commitment.u, + &y_ring, + 
layout, + level, + ); + + let t3 = Instant::now(); + let relation_claim = + relation_claim_from_rows::(&rs.tau1, rs.alpha, &quad_eq.v, &commitment.u, &y_ring); + let RingSwitchOutput { + w, + w_commitment, + w_hint, + w_evals, + w_evals_field: _, + m_evals_x, + alpha_evals_y, + num_u, + num_l, + tau0, + tau1: _, + b, + alpha: _, + } = rs; + let w_evals_small = w_evals.clone(); + #[cfg(debug_assertions)] + let alpha_evals_y_debug = alpha_evals_y.clone(); + #[cfg(debug_assertions)] + let m_evals_x_debug = m_evals_x.clone(); + let mut fused_prover = HachiSumcheckProver::new( + batching_coeff, + w_evals, + &tau0, + b, + alpha_evals_y, + m_evals_x, + num_u, + num_l, + relation_claim, + ); + + let (sumcheck_proof, sumcheck_challenges, _final_claim) = + prove_sumcheck::(&mut fused_prover, transcript, |tr| { + tr.challenge_scalar(CHALLENGE_SUMCHECK_ROUND) + })?; + eprintln!( + " [hachi prove L{level}] fused sumcheck: {:.2}s", + t3.elapsed().as_secs_f64() + ); + + let w_eval = { + let _span = tracing::info_span!("multilinear_eval", level).entered(); + multilinear_eval_small(&w_evals_small, &sumcheck_challenges)? + }; + + #[cfg(debug_assertions)] + prove_level_selfcheck( + &tau0, + &sumcheck_challenges, + w_eval, + b, + batching_coeff, + &alpha_evals_y_debug, + &m_evals_x_debug, + num_u, + _final_claim, + level, + ); + + Ok(ProveLevelOutput { + level_proof: HachiLevelProof::new::( + y_ring, + quad_eq.v, + sumcheck_proof, + w_commitment, + w_eval, + ), + w, + w_hint, + sumcheck_challenges, + num_u, + num_l, + }) +} + +/// Whether the prover should stop folding and send `w` directly. +/// +/// `prev_w_len` is the polynomial length at the previous level (or the +/// original polynomial's field-element count for level 0). 
+fn should_stop_folding(w_len: usize, prev_w_len: usize) -> bool { + if w_len <= MIN_W_LEN_FOR_FOLDING { + return true; + } + let ratio = w_len as f64 / prev_w_len as f64; + ratio > MIN_SHRINK_RATIO +} + +/// Derive the opening point for the next fold level from the sumcheck +/// challenges of the current level. +/// +/// Sumcheck challenges are ordered `[x_0..x_{num_u-1}, y_0..y_{num_l-1}]` +/// where x selects ring elements and y selects coefficients. +/// The PCS opening point is `[inner, outer]` = `[y, x]`. +fn next_level_opening_point( + sumcheck_challenges: &[F], + num_u: usize, + num_l: usize, +) -> Vec { + let (x, y) = sumcheck_challenges.split_at(num_u); + debug_assert_eq!(y.len(), num_l); + let mut point = Vec::with_capacity(num_u + num_l); + point.extend_from_slice(y); + point.extend_from_slice(x); + point +} + +/// Dispatch a commit-w operation to the correct ring dimension. +/// +/// Each match arm builds NTT caches for the target D and calls `commit_w`. +/// `#[inline(never)]` isolates the match arms in their own stack frame, +/// preventing debug-mode stack bloat from monomorphized arms. +#[allow(clippy::too_many_arguments)] +#[inline(never)] +fn dispatch_commit( + commit_d: usize, + commit_ntt_bundle: &mut MultiDNttBundle, + expanded: &HachiExpandedSetup, + w: &[i8], +) -> Result<(FlatRingVec, FlatCommitmentHint), HachiError> +where + F: FieldCore + CanonicalField + FieldSampling, + Cfg: CommitmentConfig, +{ + dispatch_with_ntt!( + commit_d, + commit_ntt_bundle, + expanded, + |D_COMMIT, ca, cb, _cd| { + let (wc, wh) = + commit_w::>(w, ca, cb)?; + Ok(( + FlatRingVec::from_commitment(&wc), + FlatCommitmentHint::from_typed(wh), + )) + } + ) +} + +/// Dispatch a prove-level operation to the correct ring dimension. +/// +/// Handles the fast-path (`level_d == D`) and the dynamic dispatch path. +/// `#[inline(never)]` isolates the monomorphized match arms in their own +/// stack frame. 
+#[allow(clippy::too_many_arguments)] +#[inline(never)] +fn dispatch_prove_level( + level_d: usize, + ntt_bundle: &mut MultiDNttBundle, + expanded: &HachiExpandedSetup, + setup_ntt_a: &NttSlotCache, + setup_ntt_b: &NttSlotCache, + setup_ntt_d: &NttSlotCache, + commit_ntt_bundle: &mut MultiDNttBundle, + commit_d: usize, + current_w: &[i8], + current_hint: &FlatCommitmentHint, + current_challenges: &[F], + current_num_u: usize, + current_num_l: usize, + last_w_commitment: &FlatRingVec, + last_w_eval: F, + transcript: &mut T, + level: usize, +) -> Result, HachiError> +where + F: FieldCore + CanonicalField + FieldSampling + HasUnreducedOps + HasWide, + T: Transcript, + Cfg: CommitmentConfig, +{ + if level_d == D { + prove_subsequent_level::( + expanded, + setup_ntt_a, + setup_ntt_b, + setup_ntt_d, + commit_ntt_bundle, + commit_d, + current_w, + current_hint, + current_challenges, + current_num_u, + current_num_l, + last_w_commitment, + last_w_eval, + transcript, + level, + ) + } else { + dispatch_with_ntt!( + level_d, + ntt_bundle, + expanded, + |D_LEVEL, ntt_a, ntt_b, ntt_d| { + prove_subsequent_level::( + expanded, + ntt_a, + ntt_b, + ntt_d, + commit_ntt_bundle, + commit_d, + current_w, + current_hint, + current_challenges, + current_num_u, + current_num_l, + last_w_commitment, + last_w_eval, + transcript, + level, + ) + } + ) + } +} + +/// Dispatch a verify-level operation to the correct ring dimension. +/// +/// Each match arm converts the D-erased commitment to a typed one, +/// derives the w-commitment layout, and calls `verify_one_level`. +/// `#[inline(never)]` isolates the monomorphized match arms. 
+#[allow(clippy::too_many_arguments)] +#[inline(never)] +fn dispatch_verify_level( + level_d: usize, + level_proof: &HachiLevelProof, + setup: &HachiVerifierSetup, + transcript: &mut T, + opening_point: &[F], + opening: &F, + current_commitment: &FlatRingVec, + basis: BasisMode, + is_last: bool, + final_w: Option<&[F]>, +) -> Result, HachiError> +where + F: FieldCore + CanonicalField + FieldSampling, + T: Transcript, + Cfg: CommitmentConfig, +{ + dispatch_ring_dim!(level_d, |D_LEVEL| { + let typed_commitment: RingCommitment = + current_commitment.to_ring_commitment(); + let w_layout = + >::commitment_layout(opening_point.len())?; + verify_one_level::>( + level_proof, + setup, + transcript, + opening_point, + opening, + &typed_commitment, + basis, + is_last, + final_w, + w_layout, + ) + }) +} + +/// Single subsequent (recursive) prove level, extracted so that the +/// dispatch match arms contain only a function call. +#[allow(clippy::too_many_arguments)] +#[inline(never)] +fn prove_subsequent_level( + expanded: &HachiExpandedSetup, + ntt_a: &NttSlotCache, + ntt_b: &NttSlotCache, + ntt_d: &NttSlotCache, + commit_ntt_bundle: &mut MultiDNttBundle, + commit_d: usize, + current_w: &[i8], + current_hint: &FlatCommitmentHint, + current_challenges: &[F], + current_num_u: usize, + current_num_l: usize, + last_w_commitment: &FlatRingVec, + #[cfg_attr(not(debug_assertions), allow(unused_variables))] last_w_eval: F, + transcript: &mut T, + level: usize, +) -> Result, HachiError> +where + F: FieldCore + CanonicalField + FieldSampling + HasUnreducedOps + HasWide, + T: Transcript, + Cfg: CommitmentConfig, +{ + let w_poly = BalancedDigitPoly::::from_i8_digits(current_w)?; + let opening_point = next_level_opening_point(current_challenges, current_num_u, current_num_l); + + #[cfg(debug_assertions)] + { + let mut field_evals: Vec = current_w.iter().map(|&d| F::from_i8(d)).collect(); + field_evals.resize(w_poly.num_ring_elems() * D_LEVEL, F::zero()); + let direct_eval = 
multilinear_eval(&field_evals, &opening_point).unwrap(); + if last_w_eval != direct_eval { + eprintln!(" [hachi prove L{level}] BUG: w_eval mismatch! prev_level w_eval != w_poly eval at opening_point"); + eprintln!( + " w_poly ring_elems={}, field_len={}, opening_point.len()={}", + w_poly.num_ring_elems(), + field_evals.len(), + opening_point.len() + ); + } else { + eprintln!(" [hachi prove L{level}] w_eval consistency OK"); + } + } + + let w_commitment: RingCommitment = last_w_commitment.to_ring_commitment(); + let typed_hint: HachiCommitmentHint = current_hint.to_typed(); + + let commit_fn: CommitFn<'_, F> = Box::new( + |w: &[i8]| -> Result<(FlatRingVec, FlatCommitmentHint), HachiError> { + if commit_d == D_LEVEL { + let (wc, wh) = commit_w::>( + w, ntt_a, ntt_b, + )?; + Ok(( + FlatRingVec::from_commitment(&wc), + FlatCommitmentHint::from_typed(wh), + )) + } else { + dispatch_commit::(commit_d, commit_ntt_bundle, expanded, w) + } + }, + ); + + let w_layout = >::commitment_layout(opening_point.len())?; + prove_one_level::, _>( + expanded, + ntt_a, + ntt_b, + ntt_d, + commit_fn, + &w_poly, + &opening_point, + typed_hint, + transcript, + &w_commitment, + BasisMode::Lagrange, + level, + w_layout, + ) +} + +impl CommitmentScheme for HachiCommitmentScheme +where + F: FieldCore + CanonicalField + FieldSampling + HasWide + HasUnreducedOps + Valid, + Cfg: CommitmentConfig, +{ + type ProverSetup = HachiProverSetup; + type VerifierSetup = HachiVerifierSetup; + type Commitment = RingCommitment; + type Proof = HachiProof; + type CommitHint = HachiCommitmentHint; + + #[tracing::instrument(skip_all, name = "HachiCommitmentScheme::setup_prover")] + fn setup_prover(max_num_vars: usize) -> Self::ProverSetup { + let (setup, _) = + >::setup(max_num_vars) + .expect("commitment setup failed"); + setup + } + + fn setup_verifier(setup: &Self::ProverSetup) -> Self::VerifierSetup { + HachiVerifierSetup { + expanded: setup.expanded.clone(), + } + } + + #[tracing::instrument(skip_all, name 
= "HachiCommitmentScheme::commit")] + fn commit>( + poly: &P, + setup: &Self::ProverSetup, + layout: &HachiCommitmentLayout, + ) -> Result<(Self::Commitment, Self::CommitHint), HachiError> { + setup.assert_layout_fits(layout); + let t_hat_all = poly.commit_inner( + &setup.expanded.A, + &setup.ntt_A, + layout.block_len, + layout.num_digits_commit, + layout.num_digits_open, + layout.log_basis, + )?; + let t_hat_flat = flatten_i8_blocks(&t_hat_all); + let u: Vec> = mat_vec_mul_ntt_single_i8(&setup.ntt_B, &t_hat_flat); + let hint = HachiCommitmentHint::new(t_hat_all); + Ok((RingCommitment { u }, hint)) + } + + #[tracing::instrument(skip_all, name = "HachiCommitmentScheme::prove")] + fn prove, P: HachiPolyOps>( + setup: &Self::ProverSetup, + poly: &P, + opening_point: &[F], + hint: Self::CommitHint, + transcript: &mut T, + commitment: &Self::Commitment, + basis: BasisMode, + layout: &HachiCommitmentLayout, + ) -> Result { + let t_prove_total = Instant::now(); + let mut levels = Vec::new(); + + let mut ntt_bundle = MultiDNttBundle::new(); + let mut commit_ntt_bundle = MultiDNttBundle::new(); + + // Level 0: original polynomial with caller-provided layout. + // The w-commitment is produced at the NEXT level's D. 
+ let commit_d_0 = Cfg::d_at_level(1, 0); + let commit_fn_0: CommitFn<'_, F> = if commit_d_0 == D { + Box::new( + |w: &[i8]| -> Result<(FlatRingVec, FlatCommitmentHint), HachiError> { + let (wc, wh) = commit_w::(w, &setup.ntt_A, &setup.ntt_B)?; + Ok(( + FlatRingVec::from_commitment(&wc), + FlatCommitmentHint::from_typed(wh), + )) + }, + ) + } else { + Box::new( + |w: &[i8]| -> Result<(FlatRingVec, FlatCommitmentHint), HachiError> { + dispatch_commit::( + commit_d_0, + &mut commit_ntt_bundle, + &setup.expanded, + w, + ) + }, + ) + }; + let out = prove_one_level::( + &setup.expanded, + &setup.ntt_A, + &setup.ntt_B, + &setup.ntt_D, + commit_fn_0, + poly, + opening_point, + hint, + transcript, + commitment, + basis, + 0, + *layout, + )?; + levels.push(out.level_proof); + + let mut prev_poly_len = poly.num_ring_elems() * D; + let mut current_w = out.w; + let mut current_hint = out.w_hint; + let mut current_challenges = out.sumcheck_challenges; + let mut current_num_u = out.num_u; + let mut current_num_l = out.num_l; + let mut level = 1usize; + + // Subsequent levels: recursive w-opening with WCommitmentConfig. + // Each level dispatches to the ring dimension from Cfg::d_at_level. + // The w-commitment is produced at the NEXT level's D. 
+    while !should_stop_folding(current_w.len(), prev_poly_len) {
+        let level_d = Cfg::d_at_level(level, current_w.len());
+        let commit_d = Cfg::d_at_level(level + 1, 0);
+
+        let last_w_eval = levels.last().unwrap().w_eval;
+        let last_w_commitment = &levels.last().unwrap().w_commitment;
+        let out = dispatch_prove_level::(
+            level_d,
+            &mut ntt_bundle,
+            &setup.expanded,
+            &setup.ntt_A,
+            &setup.ntt_B,
+            &setup.ntt_D,
+            &mut commit_ntt_bundle,
+            commit_d,
+            &current_w,
+            &current_hint,
+            &current_challenges,
+            current_num_u,
+            current_num_l,
+            last_w_commitment,
+            last_w_eval,
+            transcript,
+            level,
+        )?;
+
+        levels.push(out.level_proof);
+
+        prev_poly_len = current_w.len();
+        current_w = out.w;
+        current_hint = out.w_hint;
+        current_challenges = out.sumcheck_challenges;
+        current_num_u = out.num_u;
+        current_num_l = out.num_l;
+        level += 1;
+    }
+
+    eprintln!(
+        "  [hachi prove] total ({level} levels): {:.2}s",
+        t_prove_total.elapsed().as_secs_f64()
+    );
+
+    let log_basis = Cfg::decomposition().log_basis;
+    let final_w = PackedDigits::from_i8_digits(&current_w, log_basis);
+
+    Ok(HachiProof { levels, final_w })
+    }
+
+    #[tracing::instrument(skip_all, name = "HachiCommitmentScheme::verify")]
+    fn verify>(
+        proof: &Self::Proof,
+        setup: &Self::VerifierSetup,
+        transcript: &mut T,
+        opening_point: &[F],
+        opening: &F,
+        commitment: &Self::Commitment,
+        basis: BasisMode,
+        layout: &HachiCommitmentLayout,
+    ) -> Result<(), HachiError> {
+        if proof.levels.is_empty() {
+            return Err(HachiError::InvalidProof);
+        }
+
+        let num_levels = proof.levels.len();
+        let final_w_elems: Vec = proof.final_w.to_field_elems();
+
+        // State carried between levels.
+        // Commitment is D-erased so the loop can handle varying D per level.
+        let mut current_point = opening_point.to_vec();
+        let mut current_opening = *opening;
+        let mut current_commitment = FlatRingVec::from_commitment(commitment);
+        let mut current_basis = basis;
+
+        for (i, level_proof) in proof.levels.iter().enumerate() {
+            let is_last = i == num_levels - 1;
+            let level_d = Cfg::d_at_level(i, current_point.len());
+            eprintln!(
+                "  [verify] level {i}, is_last={is_last}, point_len={}, D={level_d}",
+                current_point.len()
+            );
+
+            let challenges = if i == 0 {
+                let typed_commitment: RingCommitment =
+                    current_commitment.to_ring_commitment();
+                verify_one_level::(
+                    level_proof,
+                    setup,
+                    transcript,
+                    &current_point,
+                    &current_opening,
+                    &typed_commitment,
+                    current_basis,
+                    is_last,
+                    if is_last { Some(&final_w_elems) } else { None },
+                    *layout,
+                )?
+            } else {
+                dispatch_verify_level::(
+                    level_d,
+                    level_proof,
+                    setup,
+                    transcript,
+                    &current_point,
+                    &current_opening,
+                    &current_commitment,
+                    current_basis,
+                    is_last,
+                    if is_last { Some(&final_w_elems) } else { None },
+                )?
+            };
+
+            if !is_last {
+                let alpha_bits = level_d.trailing_zeros() as usize;
+                let num_l = alpha_bits;
+                let num_u = challenges.len() - num_l;
+
+                current_point = next_level_opening_point(&challenges, num_u, num_l);
+                current_opening = level_proof.w_eval;
+                current_commitment = level_proof.w_commitment.clone();
+                current_basis = BasisMode::Lagrange;
+            }
+        }
+
+        Ok(())
+    }
+
+    fn protocol_name() -> &'static [u8] {
+        unimplemented!()
+    }
+}
+
+/// Verify one fold level.
+///
+/// At the final level, `final_w` is provided and the verifier checks w_val
+/// from it directly. At intermediate levels, `level_proof.w_eval` is used.
+///
+/// Returns the sumcheck challenges for chaining into the next level.
+#[allow(clippy::too_many_arguments)] +#[inline(never)] +fn verify_one_level( + level_proof: &HachiLevelProof, + setup: &HachiVerifierSetup, + transcript: &mut T, + opening_point: &[F], + opening: &F, + commitment: &RingCommitment, + basis: BasisMode, + is_last: bool, + final_w: Option<&[F]>, + layout: HachiCommitmentLayout, +) -> Result, HachiError> +where + F: FieldCore + CanonicalField + FieldSampling, + T: Transcript, + Cfg: CommitmentConfig, +{ + let y_ring: CyclotomicRing = level_proof.y_ring_typed(); + let v_typed: Vec> = level_proof.v_typed(); + + let alpha_bits = Cfg::D.trailing_zeros() as usize; + if opening_point.len() < alpha_bits { + return Err(HachiError::InvalidSetup( + "opening point length underflow".to_string(), + )); + } + let target_num_vars = layout.m_vars + layout.r_vars + alpha_bits; + let mut padded_point = opening_point.to_vec(); + padded_point.resize(target_num_vars, F::zero()); + let inner_point = &padded_point[..alpha_bits]; + let reduced_opening_point = &padded_point[alpha_bits..]; + + commitment.append_to_transcript(ABSORB_COMMITMENT, transcript); + for pt in &padded_point { + transcript.append_field(ABSORB_EVALUATION_CLAIMS, pt); + } + transcript.append_serde(ABSORB_EVALUATION_CLAIMS, &y_ring); + + let v = reduce_inner_openings_to_ring_elements::(inner_point, basis)?; + let d = F::from_u64(Cfg::D as u64); + let trace_lhs = trace::(&(y_ring * v.sigma_m1())); + let trace_rhs = d * *opening; + if trace_lhs != trace_rhs { + return Err(HachiError::InvalidProof); + } + + let ring_opening_point = ring_opening_point_from_field::( + reduced_opening_point, + layout.r_vars, + layout.m_vars, + basis, + )?; + let quad_eq = Box::new(QuadraticEquation::::new_verifier( + ring_opening_point, + v_typed.clone(), + transcript, + commitment, + &y_ring, + layout, + )?); + + let w_len = if is_last { + final_w.map_or(0, |fw| fw.len()) + } else { + w_ring_element_count::(layout) * D + }; + eprintln!(" [verify] w_len={w_len}, is_last={is_last}"); + + let rs = 
ring_switch_verifier::( + &quad_eq, + &setup.expanded, + w_len, + &level_proof.w_commitment, + transcript, + layout, + )?; + + let batching_coeff: F = transcript.challenge_scalar(CHALLENGE_SUMCHECK_BATCH); + + let fused_verifier = if is_last { + let fw = final_w.ok_or(HachiError::InvalidProof)?; + let (w_evals_full, _, _) = build_w_evals(fw, Cfg::D)?; + HachiSumcheckVerifier::new( + batching_coeff, + w_evals_full, + rs.tau0, + rs.b, + rs.alpha_evals_y, + rs.m_evals_x, + rs.tau1, + v_typed, + commitment.u.clone(), + y_ring, + rs.alpha, + rs.num_u, + rs.num_l, + ) + } else { + HachiSumcheckVerifier::new( + batching_coeff, + Vec::new(), + rs.tau0, + rs.b, + rs.alpha_evals_y, + rs.m_evals_x, + rs.tau1, + v_typed, + commitment.u.clone(), + y_ring, + rs.alpha, + rs.num_u, + rs.num_l, + ) + .with_w_val_override(level_proof.w_eval) + }; + + let challenges = verify_sumcheck::( + &level_proof.sumcheck_proof, + &fused_verifier, + transcript, + |tr| tr.challenge_scalar(CHALLENGE_SUMCHECK_ROUND), + )?; + + Ok(challenges) +} + +/// Re-derive the ring-switch challenge `alpha` and the expanded `M_a` vector +/// by replaying the transcript from the proof data and setup, exactly as the +/// verifier does. 
+#[cfg(test)] +pub(crate) fn rederive_alpha_and_m_a( + proof: &HachiProof, + setup: &HachiVerifierSetup, + opening_point: &[F], + commitment: &RingCommitment, +) -> Result<(F, Vec), HachiError> +where + F: FieldCore + CanonicalField + FieldSampling + 'static, + Cfg: CommitmentConfig, +{ + let level0 = proof.levels.first().ok_or(HachiError::InvalidProof)?; + let y_ring: CyclotomicRing = level0.y_ring_typed(); + let v_typed: Vec> = level0.v_typed(); + + let alpha_bits = Cfg::D.trailing_zeros() as usize; + if opening_point.len() < alpha_bits { + return Err(HachiError::InvalidSetup( + "opening point length underflow".to_string(), + )); + } + let layout = Cfg::commitment_layout(opening_point.len())?; + let ring_opening_point = ring_opening_point_from_field::( + &opening_point[alpha_bits..], + layout.r_vars, + layout.m_vars, + BasisMode::Lagrange, + )?; + let mut transcript = Blake2bTranscript::::new(DOMAIN_HACHI_PROTOCOL); + + commitment.append_to_transcript(ABSORB_COMMITMENT, &mut transcript); + for pt in opening_point { + transcript.append_field(ABSORB_EVALUATION_CLAIMS, pt); + } + transcript.append_serde(ABSORB_EVALUATION_CLAIMS, &y_ring); + + let quad_eq = QuadraticEquation::::new_verifier( + ring_opening_point, + v_typed, + &mut transcript, + commitment, + &y_ring, + layout, + )?; + transcript.append_serde(ABSORB_SUMCHECK_W, &level0.w_commitment); + let alpha: F = transcript.challenge_scalar(CHALLENGE_RING_SWITCH); + let m_a = compute_m_a_reference::( + &setup.expanded, + quad_eq.opening_point(), + &quad_eq.challenges, + &alpha, + layout, + )?; + let m_a_vec = expand_m_a::(&m_a, alpha, layout.log_basis)?; + Ok((alpha, m_a_vec)) +} + +/// Re-derive the ring-switch challenge `alpha` and the fused `m_evals_x` +/// table by replaying the verifier transcript from the proof data and setup. 
+#[cfg(test)] +pub(crate) fn rederive_alpha_and_m_evals_x( + proof: &HachiProof, + setup: &HachiVerifierSetup, + opening_point: &[F], + commitment: &RingCommitment, + tau1: &[F], +) -> Result<(F, Vec), HachiError> +where + F: FieldCore + CanonicalField + FieldSampling + 'static, + Cfg: CommitmentConfig, +{ + let level0 = proof.levels.first().ok_or(HachiError::InvalidProof)?; + let y_ring: CyclotomicRing = level0.y_ring_typed(); + let v_typed: Vec> = level0.v_typed(); + + let alpha_bits = Cfg::D.trailing_zeros() as usize; + if opening_point.len() < alpha_bits { + return Err(HachiError::InvalidSetup( + "opening point length underflow".to_string(), + )); + } + let layout = Cfg::commitment_layout(opening_point.len())?; + let ring_opening_point = ring_opening_point_from_field::( + &opening_point[alpha_bits..], + layout.r_vars, + layout.m_vars, + BasisMode::Lagrange, + )?; + let mut transcript = Blake2bTranscript::::new(DOMAIN_HACHI_PROTOCOL); + + commitment.append_to_transcript(ABSORB_COMMITMENT, &mut transcript); + for pt in opening_point { + transcript.append_field(ABSORB_EVALUATION_CLAIMS, pt); + } + transcript.append_serde(ABSORB_EVALUATION_CLAIMS, &y_ring); + + let quad_eq = QuadraticEquation::::new_verifier( + ring_opening_point, + v_typed, + &mut transcript, + commitment, + &y_ring, + layout, + )?; + transcript.append_serde(ABSORB_SUMCHECK_W, &level0.w_commitment); + let alpha: F = transcript.challenge_scalar(CHALLENGE_RING_SWITCH); + let alpha_evals_y = build_alpha_evals_y(alpha, D); + let m_evals_x = compute_m_evals_x::( + &setup.expanded, + quad_eq.opening_point(), + &quad_eq.challenges, + alpha, + &alpha_evals_y, + layout, + tau1, + )?; + Ok((alpha, m_evals_x)) +} + +fn lagrange_weights(point: &[F]) -> Vec { + let len = 1usize << point.len(); + let mut weights = vec![F::zero(); len]; + multilinear_lagrange_basis(&mut weights, point); + weights +} + +/// Multilinear monomial weights: `⊗ᵢ (1, xᵢ)`. +/// +/// The j-th entry is `∏_{i ∈ bits(j)} point[i]`. 
+fn monomial_weights(point: &[F]) -> Vec { + let len = 1usize << point.len(); + let mut weights = vec![F::zero(); len]; + weights[0] = F::one(); + for (level, &p) in point.iter().enumerate() { + let k = 1usize << level; + for i in (0..k).rev() { + weights[i + k] = weights[i] * p; + } + } + weights +} + +fn basis_weights(point: &[F], mode: BasisMode) -> Vec { + match mode { + BasisMode::Lagrange => lagrange_weights(point), + BasisMode::Monomial => monomial_weights(point), + } +} + +fn ring_opening_point_from_field( + opening_point: &[F], + r_vars: usize, + m_vars: usize, + basis: BasisMode, +) -> Result, HachiError> { + let expected_len = r_vars + .checked_add(m_vars) + .ok_or_else(|| HachiError::InvalidSetup("opening point length overflow".to_string()))?; + if opening_point.len() != expected_len { + return Err(HachiError::InvalidPointDimension { + expected: expected_len, + actual: opening_point.len(), + }); + } + + // Sequential ordering: M variables (position in block) come first, + // R variables (block selection) come second. 
+ let a = basis_weights(&opening_point[..m_vars], basis); + let b = basis_weights(&opening_point[m_vars..], basis); + Ok(RingOpeningPoint { a, b }) +} + +fn reduce_inner_openings_to_ring_elements( + inner_point: &[F], + basis: BasisMode, +) -> Result, HachiError> { + let weights = basis_weights(inner_point, basis); + if weights.len() != D { + return Err(HachiError::InvalidInput(format!( + "inner basis length {} does not match D={D}", + weights.len() + ))); + } + Ok(CyclotomicRing::from_slice(&weights)) +} + +fn trace(u: &CyclotomicRing) -> F { + let d = F::from_u64(D as u64); + u.coefficients()[0] * d +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::protocol::commitment::CommitmentConfig; + use crate::protocol::hachi_poly_ops::DensePoly; + use crate::protocol::transcript::Blake2bTranscript; + use crate::test_utils::F; + use crate::{CommitmentScheme, FromSmallInt}; + + type Cfg = SmallTestCommitmentConfig; + const D: usize = Cfg::D; + type Scheme = HachiCommitmentScheme; + + fn make_dense_poly(num_vars: usize) -> (DensePoly, Vec) { + let len = 1usize << num_vars; + let evals: Vec = (0..len).map(|i| F::from_u64(i as u64)).collect(); + let poly = DensePoly::::from_field_evals(num_vars, &evals).unwrap(); + (poly, evals) + } + + #[test] + fn verify_passes_for_consistent_opening() { + let alpha = D.trailing_zeros() as usize; + let layout = Cfg::commitment_layout(16).unwrap(); + let num_vars = layout.m_vars + layout.r_vars + alpha; + + let (poly, evals) = make_dense_poly(num_vars); + + let setup = >::setup_prover(num_vars); + let verifier_setup = >::setup_verifier(&setup); + + let (commitment, hint) = + >::commit(&poly, &setup, &layout).unwrap(); + + let opening_point: Vec = (0..num_vars).map(|i| F::from_u64((i + 2) as u64)).collect(); + let lw = lagrange_weights(&opening_point); + let opening: F = evals + .iter() + .zip(lw.iter()) + .fold(F::zero(), |a, (&c, &w)| a + c * w); + + let mut prover_transcript = Blake2bTranscript::::new(b"test/prove"); + let proof 
= >::prove( + &setup, + &poly, + &opening_point, + hint, + &mut prover_transcript, + &commitment, + BasisMode::Lagrange, + &layout, + ) + .unwrap(); + + let mut verifier_transcript = Blake2bTranscript::::new(b"test/prove"); + let result = >::verify( + &proof, + &verifier_setup, + &mut verifier_transcript, + &opening_point, + &opening, + &commitment, + BasisMode::Lagrange, + &layout, + ); + + assert!(result.is_ok()); + } + + #[test] + fn verify_rejects_wrong_opening() { + let alpha = D.trailing_zeros() as usize; + let layout = Cfg::commitment_layout(16).unwrap(); + let num_vars = layout.m_vars + layout.r_vars + alpha; + + let (poly, evals) = make_dense_poly(num_vars); + + let setup = >::setup_prover(num_vars); + let verifier_setup = >::setup_verifier(&setup); + + let (commitment, hint) = + >::commit(&poly, &setup, &layout).unwrap(); + + let opening_point: Vec = (0..num_vars).map(|i| F::from_u64((i + 2) as u64)).collect(); + let lw = lagrange_weights(&opening_point); + let opening: F = evals + .iter() + .zip(lw.iter()) + .fold(F::zero(), |a, (&c, &w)| a + c * w); + + let mut prover_transcript = Blake2bTranscript::::new(b"test/prove"); + let proof = >::prove( + &setup, + &poly, + &opening_point, + hint, + &mut prover_transcript, + &commitment, + BasisMode::Lagrange, + &layout, + ) + .unwrap(); + + let wrong_opening = opening + F::one(); + let mut verifier_transcript = Blake2bTranscript::::new(b"test/prove"); + let result = >::verify( + &proof, + &verifier_setup, + &mut verifier_transcript, + &opening_point, + &wrong_opening, + &commitment, + BasisMode::Lagrange, + &layout, + ); + + assert!( + result.is_err(), + "verify must reject an incorrect opening value" + ); + } + + #[test] + fn monomial_basis_prove_verify_round_trip() { + let alpha = D.trailing_zeros() as usize; + let layout = Cfg::commitment_layout(16).unwrap(); + let num_vars = layout.m_vars + layout.r_vars + alpha; + let len = 1usize << num_vars; + + let coeffs: Vec = (0..len).map(|i| F::from_u64(i as 
u64)).collect(); + let poly = DensePoly::::from_field_evals(num_vars, &coeffs).unwrap(); + + let setup = >::setup_prover(num_vars); + let verifier_setup = >::setup_verifier(&setup); + + let (commitment, hint) = + >::commit(&poly, &setup, &layout).unwrap(); + + let opening_point: Vec = (0..num_vars).map(|i| F::from_u64((i + 2) as u64)).collect(); + + let mw = monomial_weights(&opening_point); + let opening: F = coeffs + .iter() + .zip(mw.iter()) + .fold(F::zero(), |acc, (&c, &w)| acc + c * w); + + let mut prover_transcript = Blake2bTranscript::::new(b"test/monomial"); + let proof = >::prove( + &setup, + &poly, + &opening_point, + hint, + &mut prover_transcript, + &commitment, + BasisMode::Monomial, + &layout, + ) + .unwrap(); + + let mut verifier_transcript = Blake2bTranscript::::new(b"test/monomial"); + let result = >::verify( + &proof, + &verifier_setup, + &mut verifier_transcript, + &opening_point, + &opening, + &commitment, + BasisMode::Monomial, + &layout, + ); + + assert!( + result.is_ok(), + "monomial-basis proof should verify: {result:?}" + ); + } +} diff --git a/src/protocol/dispatch.rs b/src/protocol/dispatch.rs new file mode 100644 index 00000000..62db9f51 --- /dev/null +++ b/src/protocol/dispatch.rs @@ -0,0 +1,137 @@ +//! Runtime-to-const-generic dispatch for ring dimension D. +//! +//! The supported D values (all powers of 2 that admit a CRT+NTT decomposition) +//! are: 64, 128, 256, 512, 1024. + +/// Bridge a runtime `d: usize` to a const-generic `D` context. +/// +/// Calls `$body` with the matched const `D`. Inside `$body`, `D` is available +/// as a const generic parameter (via the generated function). +/// +/// # Supported dimensions +/// +/// 64, 128, 256, 512, 1024. +/// +/// # Panics +/// +/// Panics at runtime if `d` is not one of the supported values. 
+/// +/// # Examples +/// +/// ``` +/// use hachi_pcs::dispatch_ring_dim; +/// let ring_dim: usize = 256; +/// let result = dispatch_ring_dim!(ring_dim, |D| D * 2); +/// assert_eq!(result, 512); +/// ``` +#[macro_export] +macro_rules! dispatch_ring_dim { + ($d:expr, |$D:ident| $body:expr) => {{ + let __d = $d; + match __d { + 64 => { + const $D: usize = 64; + $body + } + 128 => { + const $D: usize = 128; + $body + } + 256 => { + const $D: usize = 256; + $body + } + 512 => { + const $D: usize = 512; + $body + } + 1024 => { + const $D: usize = 1024; + $body + } + _ => panic!("unsupported ring dimension: {__d}"), + } + }}; +} + +/// Like [`dispatch_ring_dim!`] but also lazily builds NTT caches for the +/// matched ring dimension from a [`crate::protocol::commitment::utils::ntt_cache::MultiDNttBundle`] and +/// [`crate::protocol::commitment::HachiExpandedSetup`]. +/// +/// Inside the body, `$D` is a const ring dimension and `$ntt_a`, `$ntt_b`, +/// `$ntt_d` are `&NttSlotCache` references. +/// +/// # Panics +/// +/// Panics at runtime if `d` is not one of the supported values. +#[macro_export] +macro_rules! 
dispatch_with_ntt { + ($d:expr, $ntt:expr, $expanded:expr, + |$D:ident, $ntt_a:ident, $ntt_b:ident, $ntt_d:ident| $body:expr) => {{ + let __d = $d; + match __d { + 64 => { + const $D: usize = 64; + let $ntt_a = ($ntt).A.get_or_build_64(&($expanded).A)?; + let $ntt_b = ($ntt).B.get_or_build_64(&($expanded).B)?; + let $ntt_d = ($ntt).D_mat.get_or_build_64(&($expanded).D_mat)?; + $body + } + 128 => { + const $D: usize = 128; + let $ntt_a = ($ntt).A.get_or_build_128(&($expanded).A)?; + let $ntt_b = ($ntt).B.get_or_build_128(&($expanded).B)?; + let $ntt_d = ($ntt).D_mat.get_or_build_128(&($expanded).D_mat)?; + $body + } + 256 => { + const $D: usize = 256; + let $ntt_a = ($ntt).A.get_or_build_256(&($expanded).A)?; + let $ntt_b = ($ntt).B.get_or_build_256(&($expanded).B)?; + let $ntt_d = ($ntt).D_mat.get_or_build_256(&($expanded).D_mat)?; + $body + } + 512 => { + const $D: usize = 512; + let $ntt_a = ($ntt).A.get_or_build_512(&($expanded).A)?; + let $ntt_b = ($ntt).B.get_or_build_512(&($expanded).B)?; + let $ntt_d = ($ntt).D_mat.get_or_build_512(&($expanded).D_mat)?; + $body + } + 1024 => { + const $D: usize = 1024; + let $ntt_a = ($ntt).A.get_or_build_1024(&($expanded).A)?; + let $ntt_b = ($ntt).B.get_or_build_1024(&($expanded).B)?; + let $ntt_d = ($ntt).D_mat.get_or_build_1024(&($expanded).D_mat)?; + $body + } + _ => panic!("unsupported ring dimension: {__d}"), + } + }}; +} + +/// The set of supported ring dimensions for [`dispatch_ring_dim!`]. +pub const SUPPORTED_RING_DIMS: &[usize] = &[64, 128, 256, 512, 1024]; + +/// Returns true if `d` is one of the [`SUPPORTED_RING_DIMS`]. 
+#[inline] +pub fn is_supported_ring_dim(d: usize) -> bool { + SUPPORTED_RING_DIMS.contains(&d) +} + +#[cfg(test)] +mod tests { + #[test] + fn dispatch_ring_dim_basic() { + for &d in super::SUPPORTED_RING_DIMS { + let result = dispatch_ring_dim!(d, |D| D); + assert_eq!(result, d); + } + } + + #[test] + #[should_panic(expected = "unsupported ring dimension")] + fn dispatch_ring_dim_unsupported_panics() { + let _ = dispatch_ring_dim!(42, |D| D); + } +} diff --git a/src/protocol/greyhound/eval.rs b/src/protocol/greyhound/eval.rs new file mode 100644 index 00000000..a09bb7ba --- /dev/null +++ b/src/protocol/greyhound/eval.rs @@ -0,0 +1,611 @@ +//! Greyhound prover-side evaluation reduction. +//! +//! Produces a 4-row witness matching the C reference structure (adapted for +//! multilinear evaluation) and 5 constraints via `greyhound_reduce`. + +use crate::algebra::ring::CyclotomicRing; +use crate::error::HachiError; +use crate::primitives::poly::multilinear_lagrange_basis; +use crate::protocol::commitment::utils::linear::decompose_rows_with_carry; +use crate::protocol::greyhound::reduce::greyhound_reduce; +use crate::protocol::greyhound::types::GreyhoundEvalProof; +use crate::protocol::labrador::comkey::{derive_extendable_comkey_matrix, LabradorComKeySeed}; +use crate::protocol::labrador::select_config; +use crate::protocol::labrador::transcript::{ + absorb_greyhound_eval_claim, absorb_greyhound_eval_context, absorb_greyhound_u2, + sample_greyhound_fold_challenge, GreyhoundEvalTranscriptContext, +}; +use crate::protocol::labrador::types::{LabradorStatement, LabradorWitness}; +use crate::protocol::labrador::utils::mat_vec_mul; +use crate::protocol::prg::MatrixPrgBackendChoice; +use crate::protocol::transcript::Transcript; +use crate::{CanonicalField, FieldCore, FieldSampling}; + +/// Output of `greyhound_eval`: proof, witness, and reduced statement. 
+pub type GreyhoundEvalResult = ( + GreyhoundEvalProof, + LabradorWitness, + LabradorStatement, +); + +/// Build Greyhound evaluation proof and reduced Labrador witness/statement. +/// +/// The witness has 4 rows matching the C reference: +/// row0: z_low (m*f elements) — low part of decomposed amortized z +/// row1: z_high (m*f elements) — high part (z = z_low + 2^bu * z_high) +/// row2: t_hat (kappa*fu*n elements) — decomposed inner commitments +/// row3: v_hat (fu*n elements) — decomposed partial evaluations +/// +/// # Errors +/// +/// Returns an error if reshaping, config selection, or commitment fails. +pub fn greyhound_eval( + witness_coeffs: &[F], + eval_point: &[F], + eval_value: F, + w_commitment_u1: &[CyclotomicRing], + comkey_seed: &LabradorComKeySeed, + backend: MatrixPrgBackendChoice, + transcript: &mut T, +) -> Result, HachiError> +where + F: FieldCore + CanonicalField + FieldSampling, + T: Transcript, +{ + let ring_witness = pack_coefficients_to_ring::(witness_coeffs); + if ring_witness.is_empty() { + return Err(HachiError::InvalidInput( + "greyhound_eval requires non-empty witness".to_string(), + )); + } + let (m_rows, n_cols, inner_vars) = choose_dimensions(ring_witness.len()); + if eval_point.len() < inner_vars { + return Err(HachiError::InvalidPointDimension { + expected: inner_vars, + actual: eval_point.len(), + }); + } + + let inner_point = &eval_point[eval_point.len() - inner_vars..]; + let mut inner_basis = vec![F::zero(); 1usize << inner_vars]; + multilinear_lagrange_basis(&mut inner_basis, inner_point); + + let matrix = reshape_columns(&ring_witness, m_rows, n_cols); + let partial_evals = partial_evaluate_columns(&matrix, &inner_basis); + + // Select Labrador config from columns (pre-amortization dimensions). + let column_witness = columns_to_witness(&matrix); + let cfg = select_config(&column_witness)?; + + // Decompose partial evaluations v → v_hat (group 3). 
+ let v_hat = decompose_rows_with_carry(&partial_evals, cfg.fu, cfg.bu as u32); + + // Commit v_hat → u2 (outer commitment to evaluation witness). + let u2 = if cfg.kappa1 > 0 { + let b_eval = derive_extendable_comkey_matrix::( + cfg.kappa1, + v_hat.len(), + comkey_seed, + b"greyhound/comkey/B_eval", + backend, + ); + mat_vec_mul(&b_eval, &v_hat) + } else { + v_hat.clone() + }; + + // Transcript: absorb context, claim, u2. + absorb_greyhound_eval_context( + transcript, + &GreyhoundEvalTranscriptContext { + m_rows, + n_cols, + inner_vars, + eval_point_len: eval_point.len(), + prg_backend_id: backend as u8, + }, + )?; + absorb_greyhound_eval_claim(transcript, eval_point, &eval_value); + absorb_greyhound_u2(transcript, &u2); + + // Sample n_cols fold challenges from transcript. + let fold_challenges: Vec = (0..n_cols) + .map(|_| sample_greyhound_fold_challenge(transcript)) + .collect(); + + // Amortize columns: z[j] = sum_col c_col * column[col][j]. + let mut z = vec![CyclotomicRing::::zero(); m_rows]; + for (col_idx, column) in matrix.iter().enumerate() { + let c = fold_challenges[col_idx]; + for (j, elem) in column.iter().enumerate() { + z[j] += elem.scale(&c); + } + } + + // Decompose z → groups 0 (z_low) and 1 (z_high). + // First: decompose with (f, b), then split each part into low/high with bu. + let z_first = decompose_rows_with_carry(&z, cfg.f, cfg.b as u32); + let z_uniform = decompose_rows_with_carry(&z_first, 2, cfg.bu as u32); + let mut z_low = Vec::with_capacity(z_first.len()); + let mut z_high = Vec::with_capacity(z_first.len()); + for i in 0..z_first.len() { + z_low.push(z_uniform[2 * i]); + z_high.push(z_uniform[2 * i + 1]); + } + + // Compute inner commitments t_j = A * column_j, decompose → t_hat (group 2). 
+ let mut t_hat_flat = Vec::new(); + for column in &matrix { + let a = derive_extendable_comkey_matrix::( + cfg.kappa, + column.len(), + comkey_seed, + b"labrador/comkey/A", + backend, + ); + let t_j = mat_vec_mul(&a, column); + t_hat_flat.extend(decompose_rows_with_carry(&t_j, cfg.fu, cfg.bu as u32)); + } + + let greyhound_witness = LabradorWitness::new_unchecked(vec![z_low, z_high, t_hat_flat, v_hat]); + + let proof = GreyhoundEvalProof { + u2: u2.clone(), + m_rows, + n_cols, + inner_vars, + config: cfg, + }; + + let mut statement = greyhound_reduce( + &proof, + w_commitment_u1, + eval_point, + eval_value, + &fold_challenges, + comkey_seed, + backend, + )?; + statement.beta_sq = greyhound_witness.norm(); + + Ok((proof, greyhound_witness, statement)) +} + +fn pack_coefficients_to_ring( + coeffs: &[F], +) -> Vec> { + if coeffs.is_empty() { + return Vec::new(); + } + let mut out = Vec::with_capacity(coeffs.len().div_ceil(D)); + for chunk in coeffs.chunks(D) { + let ring = CyclotomicRing::from_coefficients(std::array::from_fn(|i| { + chunk.get(i).copied().unwrap_or_else(F::zero) + })); + out.push(ring); + } + out +} + +fn choose_dimensions(num_ring_elements: usize) -> (usize, usize, usize) { + let n = num_ring_elements.max(1).next_power_of_two(); + let k_total = n.trailing_zeros() as usize; + let inner_vars = k_total / 2; + let outer_vars = k_total - inner_vars; + (1usize << inner_vars, 1usize << outer_vars, inner_vars) +} + +fn reshape_columns( + ring_witness: &[CyclotomicRing], + m_rows: usize, + n_cols: usize, +) -> Vec>> { + (0..n_cols) + .map(|col| { + (0..m_rows) + .map(|row| { + let idx = col * m_rows + row; + ring_witness + .get(idx) + .copied() + .unwrap_or_else(CyclotomicRing::::zero) + }) + .collect::>() + }) + .collect() +} + +fn partial_evaluate_columns( + columns: &[Vec>], + inner_basis: &[F], +) -> Vec> { + columns + .iter() + .map(|col| { + let mut acc = CyclotomicRing::::zero(); + for (elem, &basis) in col.iter().zip(inner_basis.iter()) { + acc += 
elem.scale(&basis); + } + acc + }) + .collect() +} + +/// Build a temporary witness from columns for config selection. +fn columns_to_witness( + matrix: &[Vec>], +) -> LabradorWitness { + LabradorWitness::new_unchecked(matrix.to_vec()) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::algebra::fields::Fp64; + use crate::protocol::greyhound::greyhound_verify_stage1; + use crate::protocol::labrador::{prove_level, prove_with_config, verify, LabradorProof}; + use crate::protocol::transcript::labels::DOMAIN_GREYHOUND_EVAL; + use crate::protocol::transcript::labels::DOMAIN_LABRADOR_PROTOCOL; + use crate::protocol::transcript::Blake2bTranscript; + use crate::FromSmallInt; + + type F = Fp64<4294967197>; + const D: usize = 64; + + #[test] + fn eval_outputs_four_row_witness_and_five_constraints() { + let coeffs: Vec = (0..256).map(|i| F::from_i64((i as i64 % 13) - 6)).collect(); + let eval_point: Vec = (0..8).map(|i| F::from_i64(i as i64 + 1)).collect(); + let eval_value = F::from_i64(9); + let u1 = vec![CyclotomicRing::::one(), CyclotomicRing::::one()]; + let mut transcript = Blake2bTranscript::::new(DOMAIN_GREYHOUND_EVAL); + let (proof, witness, statement) = greyhound_eval( + &coeffs, + &eval_point, + eval_value, + &u1, + &[8u8; 32], + MatrixPrgBackendChoice::Shake256, + &mut transcript, + ) + .unwrap(); + assert_eq!(proof.u2, statement.u2); + assert_eq!(witness.rows().len(), 4); + assert_eq!(statement.constraints.len(), 5); + } + + #[test] + fn stage1_constraints_verify_with_full_witness() { + let backend = MatrixPrgBackendChoice::Shake256; + let comkey_seed = [42u8; 32]; + + let ring_elems = 16; + let coeffs = vec![F::zero(); ring_elems * D]; + + let ring_witness = pack_coefficients_to_ring::(&coeffs); + let (m_rows, n_cols, inner_vars) = choose_dimensions(ring_witness.len()); + let outer_vars = n_cols.trailing_zeros() as usize; + let eval_point: Vec = (0..(inner_vars + outer_vars)) + .map(|i| F::from_i64(i as i64 + 2)) + .collect(); + + let inner_point = 
&eval_point[eval_point.len() - inner_vars..]; + let mut inner_basis = vec![F::zero(); 1usize << inner_vars]; + multilinear_lagrange_basis(&mut inner_basis, inner_point); + let matrix = reshape_columns(&ring_witness, m_rows, n_cols); + let partial_evals = partial_evaluate_columns(&matrix, &inner_basis); + + let mut outer_basis = vec![F::zero(); 1usize << outer_vars]; + multilinear_lagrange_basis(&mut outer_basis, &eval_point[..outer_vars]); + let mut eval_ring = CyclotomicRing::::zero(); + for (v, basis) in partial_evals.iter().zip(outer_basis.iter()) { + eval_ring += v.scale(basis); + } + assert!(eval_ring.coefficients()[1..].iter().all(|c| c.is_zero())); + let eval_value = eval_ring.coefficients()[0]; + + let mut transcript = Blake2bTranscript::::new(DOMAIN_GREYHOUND_EVAL); + let (proof, witness, _statement) = greyhound_eval( + &coeffs, + &eval_point, + eval_value, + &[], + &comkey_seed, + backend, + &mut transcript, + ) + .unwrap(); + + let cfg = &proof.config; + assert!(cfg.kappa1 > 0); + let t_hat = &witness.rows()[2]; + let b_mat = derive_extendable_comkey_matrix::( + cfg.kappa1, + t_hat.len(), + &comkey_seed, + b"labrador/comkey/B", + backend, + ); + let u1 = mat_vec_mul(&b_mat, t_hat); + + let z_norm_sq = witness.rows()[0] + .iter() + .chain(witness.rows()[1].iter()) + .map(|ring| ring.coeff_norm_sq()) + .fold(0u128, |acc, v| acc.saturating_add(v)); + let mut verifier_transcript = Blake2bTranscript::::new(DOMAIN_GREYHOUND_EVAL); + greyhound_verify_stage1( + &proof, + &u1, + &eval_point, + eval_value, + &witness, + z_norm_sq, + &comkey_seed, + backend, + &mut verifier_transcript, + ) + .unwrap(); + } + + #[test] + fn stage2_single_labrador_fold_verifies() { + let backend = MatrixPrgBackendChoice::Shake256; + let comkey_seed = [42u8; 32]; + let jl_seed = [7u8; 16]; + + let ring_elems = 16; + let coeffs = vec![F::zero(); ring_elems * D]; + + let ring_witness = pack_coefficients_to_ring::(&coeffs); + let (m_rows, n_cols, inner_vars) = 
choose_dimensions(ring_witness.len()); + let outer_vars = n_cols.trailing_zeros() as usize; + let eval_point: Vec = (0..(inner_vars + outer_vars)) + .map(|i| F::from_i64(i as i64 + 2)) + .collect(); + + let inner_point = &eval_point[eval_point.len() - inner_vars..]; + let mut inner_basis = vec![F::zero(); 1usize << inner_vars]; + multilinear_lagrange_basis(&mut inner_basis, inner_point); + let matrix = reshape_columns(&ring_witness, m_rows, n_cols); + let partial_evals = partial_evaluate_columns(&matrix, &inner_basis); + + let mut outer_basis = vec![F::zero(); 1usize << outer_vars]; + multilinear_lagrange_basis(&mut outer_basis, &eval_point[..outer_vars]); + let mut eval_ring = CyclotomicRing::::zero(); + for (v, basis) in partial_evals.iter().zip(outer_basis.iter()) { + eval_ring += v.scale(basis); + } + let eval_value = eval_ring.coefficients()[0]; + + let mut gh_transcript = Blake2bTranscript::::new(DOMAIN_GREYHOUND_EVAL); + let (proof, witness, _statement) = greyhound_eval( + &coeffs, + &eval_point, + eval_value, + &[], + &comkey_seed, + backend, + &mut gh_transcript, + ) + .unwrap(); + + let z_norm_sq = witness.rows()[0] + .iter() + .chain(witness.rows()[1].iter()) + .map(|ring| ring.coeff_norm_sq()) + .fold(0u128, |acc, v| acc.saturating_add(v)); + let t_hat = &witness.rows()[2]; + let b_mat = derive_extendable_comkey_matrix::( + proof.config.kappa1, + t_hat.len(), + &comkey_seed, + b"labrador/comkey/B", + backend, + ); + let u1 = mat_vec_mul(&b_mat, t_hat); + let mut gh_verify_transcript = Blake2bTranscript::::new(DOMAIN_GREYHOUND_EVAL); + greyhound_verify_stage1( + &proof, + &u1, + &eval_point, + eval_value, + &witness, + z_norm_sq, + &comkey_seed, + backend, + &mut gh_verify_transcript, + ) + .unwrap(); + + let mut transcript_replay = Blake2bTranscript::::new(DOMAIN_GREYHOUND_EVAL); + absorb_greyhound_eval_context( + &mut transcript_replay, + &GreyhoundEvalTranscriptContext { + m_rows: proof.m_rows, + n_cols: proof.n_cols, + inner_vars: proof.inner_vars, + 
eval_point_len: eval_point.len(), + prg_backend_id: backend as u8, + }, + ) + .unwrap(); + absorb_greyhound_eval_claim(&mut transcript_replay, &eval_point, &eval_value); + absorb_greyhound_u2(&mut transcript_replay, &proof.u2); + let fold_challenges: Vec = (0..proof.n_cols) + .map(|_| sample_greyhound_fold_challenge(&mut transcript_replay)) + .collect(); + let mut statement = greyhound_reduce( + &proof, + &u1, + &eval_point, + eval_value, + &fold_challenges, + &comkey_seed, + backend, + ) + .unwrap(); + statement.beta_sq = witness.norm(); + + let mut prover_transcript = Blake2bTranscript::::new(DOMAIN_LABRADOR_PROTOCOL); + let fold = prove_level( + &witness, + &statement, + &proof.config, + &comkey_seed, + &jl_seed, + backend, + 0, + &mut prover_transcript, + ) + .unwrap(); + + let labrador_proof = LabradorProof { + levels: vec![fold.level_proof.clone()], + final_opening_witness: fold.next_witness.clone(), + }; + let mut verify_transcript = Blake2bTranscript::::new(DOMAIN_LABRADOR_PROTOCOL); + let verify_result = verify( + &statement, + &labrador_proof, + &comkey_seed, + &jl_seed, + backend, + &mut verify_transcript, + ) + .unwrap(); + assert_eq!(verify_result.final_opening_witness, fold.next_witness); + assert_eq!(verify_result.terminal_statement, fold.statement); + } + + #[test] + fn stage3_full_labrador_recursion_verifies() { + let backend = MatrixPrgBackendChoice::Shake256; + let comkey_seed = [42u8; 32]; + let jl_seed = [7u8; 16]; + + let ring_elems = 16; + let coeffs = vec![F::zero(); ring_elems * D]; + + let ring_witness = pack_coefficients_to_ring::(&coeffs); + let (m_rows, n_cols, inner_vars) = choose_dimensions(ring_witness.len()); + let outer_vars = n_cols.trailing_zeros() as usize; + let eval_point: Vec = (0..(inner_vars + outer_vars)) + .map(|i| F::from_i64(i as i64 + 3)) + .collect(); + + let inner_point = &eval_point[eval_point.len() - inner_vars..]; + let mut inner_basis = vec![F::zero(); 1usize << inner_vars]; + multilinear_lagrange_basis(&mut 
inner_basis, inner_point); + let matrix = reshape_columns(&ring_witness, m_rows, n_cols); + let partial_evals = partial_evaluate_columns(&matrix, &inner_basis); + + let mut outer_basis = vec![F::zero(); 1usize << outer_vars]; + multilinear_lagrange_basis(&mut outer_basis, &eval_point[..outer_vars]); + let mut eval_ring = CyclotomicRing::::zero(); + for (v, basis) in partial_evals.iter().zip(outer_basis.iter()) { + eval_ring += v.scale(basis); + } + let eval_value = eval_ring.coefficients()[0]; + + let mut gh_transcript = Blake2bTranscript::::new(DOMAIN_GREYHOUND_EVAL); + let (proof, witness, _statement) = greyhound_eval( + &coeffs, + &eval_point, + eval_value, + &[], + &comkey_seed, + backend, + &mut gh_transcript, + ) + .unwrap(); + + let z_norm_sq = witness.rows()[0] + .iter() + .chain(witness.rows()[1].iter()) + .map(|ring| ring.coeff_norm_sq()) + .fold(0u128, |acc, v| acc.saturating_add(v)); + let t_hat = &witness.rows()[2]; + let b_mat = derive_extendable_comkey_matrix::( + proof.config.kappa1, + t_hat.len(), + &comkey_seed, + b"labrador/comkey/B", + backend, + ); + let u1 = mat_vec_mul(&b_mat, t_hat); + let mut gh_verify_transcript = Blake2bTranscript::::new(DOMAIN_GREYHOUND_EVAL); + greyhound_verify_stage1( + &proof, + &u1, + &eval_point, + eval_value, + &witness, + z_norm_sq, + &comkey_seed, + backend, + &mut gh_verify_transcript, + ) + .unwrap(); + + let mut transcript_replay = Blake2bTranscript::::new(DOMAIN_GREYHOUND_EVAL); + absorb_greyhound_eval_context( + &mut transcript_replay, + &GreyhoundEvalTranscriptContext { + m_rows: proof.m_rows, + n_cols: proof.n_cols, + inner_vars: proof.inner_vars, + eval_point_len: eval_point.len(), + prg_backend_id: backend as u8, + }, + ) + .unwrap(); + absorb_greyhound_eval_claim(&mut transcript_replay, &eval_point, &eval_value); + absorb_greyhound_u2(&mut transcript_replay, &proof.u2); + let fold_challenges: Vec = (0..proof.n_cols) + .map(|_| sample_greyhound_fold_challenge(&mut transcript_replay)) + .collect(); + let 
mut statement = greyhound_reduce( + &proof, + &u1, + &eval_point, + eval_value, + &fold_challenges, + &comkey_seed, + backend, + ) + .unwrap(); + statement.beta_sq = witness.norm(); + + let mut prover_transcript = Blake2bTranscript::::new(DOMAIN_LABRADOR_PROTOCOL); + let labrador_proof = prove_with_config( + witness, + &statement, + &proof.config, + &comkey_seed, + &jl_seed, + backend, + &mut prover_transcript, + ) + .unwrap(); + assert!( + !labrador_proof.levels.is_empty(), + "expected Labrador recursion to run at least one level" + ); + + let mut verify_transcript = Blake2bTranscript::::new(DOMAIN_LABRADOR_PROTOCOL); + let verify_result = verify( + &statement, + &labrador_proof, + &comkey_seed, + &jl_seed, + backend, + &mut verify_transcript, + ) + .unwrap(); + assert_eq!( + verify_result.final_opening_witness, + labrador_proof.final_opening_witness + ); + } +} diff --git a/src/protocol/greyhound/mod.rs b/src/protocol/greyhound/mod.rs new file mode 100644 index 00000000..625f15ba --- /dev/null +++ b/src/protocol/greyhound/mod.rs @@ -0,0 +1,11 @@ +//! Greyhound evaluation reduction layer. + +pub mod eval; +pub mod reduce; +pub mod types; +pub mod verify; + +pub use eval::greyhound_eval; +pub use reduce::greyhound_reduce; +pub use types::{GreyhoundDimensions, GreyhoundEvalProof}; +pub use verify::greyhound_verify_stage1; diff --git a/src/protocol/greyhound/reduce.rs b/src/protocol/greyhound/reduce.rs new file mode 100644 index 00000000..6412426d --- /dev/null +++ b/src/protocol/greyhound/reduce.rs @@ -0,0 +1,332 @@ +//! Greyhound verifier-side reduction to Labrador statement. +//! +//! Builds 5 constraints matching the C reference, adapted for multilinear +//! evaluation. The fold challenges are passed in (sampled by the caller from +//! the transcript) so this function is transcript-free. 
+ +use crate::algebra::ring::CyclotomicRing; +use crate::error::HachiError; +use crate::primitives::poly::multilinear_lagrange_basis; +use crate::protocol::greyhound::types::GreyhoundEvalProof; +use crate::protocol::labrador::comkey::{derive_extendable_comkey_matrix, LabradorComKeySeed}; +use crate::protocol::labrador::types::{LabradorConstraint, LabradorStatement}; +use crate::protocol::prg::MatrixPrgBackendChoice; +use crate::{CanonicalField, FieldCore, FieldSampling}; + +/// Rebuild a Labrador statement from Greyhound proof data and fold challenges. +/// +/// The 5 constraints encode (multilinear adaptation of C's `polcom_reduce`): +/// 0. Outer commitment: B · group2 = u1 +/// 1. Eval-witness commitment: B_eval · group3 = u2 +/// 2. Amortization consistency: = +/// 3. Inner commitment relation: A · z = (mult=kappa) +/// 4. Evaluation check: = eval_value +/// +/// # Errors +/// +/// Returns an error if dimensions are invalid. +pub fn greyhound_reduce( + eval_proof: &GreyhoundEvalProof, + w_commitment_u1: &[CyclotomicRing], + eval_point: &[F], + eval_value: F, + fold_challenges: &[F], + comkey_seed: &LabradorComKeySeed, + backend: MatrixPrgBackendChoice, +) -> Result, HachiError> +where + F: FieldCore + CanonicalField + FieldSampling, +{ + let m_rows = eval_proof.m_rows; + let n_cols = eval_proof.n_cols; + if n_cols == 0 || m_rows == 0 { + return Err(HachiError::InvalidInput( + "greyhound proof has zero dimensions".to_string(), + )); + } + if eval_point.len() < eval_proof.inner_vars { + return Err(HachiError::InvalidPointDimension { + expected: eval_proof.inner_vars, + actual: eval_point.len(), + }); + } + if fold_challenges.len() != n_cols { + return Err(HachiError::InvalidInput(format!( + "expected {} fold challenges, got {}", + n_cols, + fold_challenges.len() + ))); + } + + let outer_vars = eval_point.len() - eval_proof.inner_vars; + let mut outer_basis = vec![F::zero(); 1usize << outer_vars]; + multilinear_lagrange_basis(&mut outer_basis, 
&eval_point[..outer_vars]); + let mut inner_basis = vec![F::zero(); 1usize << eval_proof.inner_vars]; + if eval_proof.inner_vars > 0 { + multilinear_lagrange_basis( + &mut inner_basis, + &eval_point[eval_point.len() - eval_proof.inner_vars..], + ); + } else if !inner_basis.is_empty() { + inner_basis[0] = F::one(); + } + + let constraints = build_constraints( + eval_proof, + w_commitment_u1, + &outer_basis, + &inner_basis, + eval_value, + fold_challenges, + comkey_seed, + backend, + ); + + Ok(LabradorStatement { + u1: w_commitment_u1.to_vec(), + u2: eval_proof.u2.clone(), + challenges: Vec::new(), + constraints, + beta_sq: 0, + hash: [0u8; 16], + }) +} + +/// Build the 5 constraints for the 4-row Greyhound witness. +/// +/// Witness layout (element-major decomposition ordering): +/// row0: z_low — m*f elements, z_low[j*f+k] = low part of k-th decomp of z[j] +/// row1: z_high — m*f elements, z_high[j*f+k] = high part +/// row2: t_hat — kappa*fu*n elements, per-column decomposed inner commitments +/// row3: v_hat — fu*n elements, decomposed partial evaluations +/// +/// Reconstruction: +/// z[j] = sum_k 2^{kb} * (z_low[j*f+k] + 2^bu * z_high[j*f+k]) +/// t_col_i[c] = sum_l 2^{l*bu} * t_hat[i*kappa*fu + c*fu + l] +/// v[i] = sum_l 2^{l*bu} * v_hat[i*fu + l] +#[allow(clippy::too_many_arguments)] +fn build_constraints( + proof: &GreyhoundEvalProof, + u1: &[CyclotomicRing], + outer_basis: &[F], + inner_basis: &[F], + eval_value: F, + fold_challenges: &[F], + comkey_seed: &LabradorComKeySeed, + backend: MatrixPrgBackendChoice, +) -> Vec> { + let m = proof.m_rows; + let n = proof.n_cols; + let cfg = &proof.config; + let f = cfg.f; + let b = cfg.b; + let fu = cfg.fu; + let bu = cfg.bu; + let kappa = cfg.kappa; + let kappa1 = cfg.kappa1; + + let scalar_ring = + |s: F| -> CyclotomicRing { + CyclotomicRing::from_coefficients(std::array::from_fn(|k| { + if k == 0 { + s + } else { + F::zero() + } + })) + }; + + let pow2 = |exp: usize| -> F { + let mut v = F::one(); + for _ in 
0..exp { + v = v + v; + } + v + }; + + // cnst0: B · row2 = u1 (outer commitment of decomposed inner commitments) + let num_rows = 4; // z_low, z_high, t_hat, v_hat + let t_hat_len = kappa * fu * n; + let v_hat_len = fu * n; + let z_group_len = m * f; + + // cnst0: B · row2 = u1 + let c0 = if kappa1 > 0 { + let b_mat = derive_extendable_comkey_matrix::( + kappa1, + t_hat_len, + comkey_seed, + b"labrador/comkey/B", + backend, + ); + let coeffs: Vec> = b_mat.into_iter().flatten().collect(); + let mut coefficients = vec![vec![]; num_rows]; + coefficients[2] = coeffs; + LabradorConstraint { + coefficients, + target: u1.to_vec(), + } + } else { + let one = CyclotomicRing::::one(); + let mut coefficients = vec![vec![]; num_rows]; + coefficients[2] = vec![one; u1.len()]; + LabradorConstraint { + coefficients, + target: u1.to_vec(), + } + }; + + // cnst1: B_eval · row3 = u2 + let c1 = if kappa1 > 0 { + let b_eval = derive_extendable_comkey_matrix::( + kappa1, + v_hat_len, + comkey_seed, + b"greyhound/comkey/B_eval", + backend, + ); + let coeffs: Vec> = b_eval.into_iter().flatten().collect(); + let mut coefficients = vec![vec![]; num_rows]; + coefficients[3] = coeffs; + LabradorConstraint { + coefficients, + target: proof.u2.clone(), + } + } else { + let one = CyclotomicRing::::one(); + let mut coefficients = vec![vec![]; num_rows]; + coefficients[3] = vec![one; proof.u2.len()]; + LabradorConstraint { + coefficients, + target: proof.u2.clone(), + } + }; + + // cnst2: amortization consistency + let mut phi0 = vec![CyclotomicRing::::zero(); z_group_len]; + let mut phi1 = vec![CyclotomicRing::::zero(); z_group_len]; + let bu_scale = pow2(bu); + for j in 0..m { + for k in 0..f { + let w = scalar_ring(inner_basis[j] * pow2(k * b)); + phi0[j * f + k] = w; + phi1[j * f + k] = w.scale(&bu_scale); + } + } + let mut phi_v = vec![CyclotomicRing::::zero(); v_hat_len]; + for i in 0..n { + for l in 0..fu { + phi_v[i * fu + l] = scalar_ring(-(fold_challenges[i] * pow2(l * bu))); + } + } + 
let c2 = LabradorConstraint { + coefficients: vec![phi0, phi1, vec![], phi_v], + target: vec![CyclotomicRing::::zero()], + }; + + // cnst3: inner commitment relation A·z - c·t = 0 + let a_mat = derive_extendable_comkey_matrix::( + kappa, + m, + comkey_seed, + b"labrador/comkey/A", + backend, + ); + let mut phi_z0 = vec![CyclotomicRing::::zero(); kappa * z_group_len]; + let mut phi_z1 = vec![CyclotomicRing::::zero(); kappa * z_group_len]; + for r in 0..kappa { + for j in 0..m { + for k in 0..f { + let w = a_mat[r][j].scale(&pow2(k * b)); + phi_z0[r * z_group_len + j * f + k] = w; + phi_z1[r * z_group_len + j * f + k] = w.scale(&bu_scale); + } + } + } + let t_hat_per_col = kappa * fu; + let mut phi_t = vec![CyclotomicRing::::zero(); kappa * t_hat_len]; + for i in 0..n { + for l in 0..fu { + let neg_ci_scale = scalar_ring(-(fold_challenges[i] * pow2(l * bu))); + for r in 0..kappa { + phi_t[r * t_hat_len + i * t_hat_per_col + r * fu + l] = neg_ci_scale; + } + } + } + let c3 = LabradorConstraint { + coefficients: vec![phi_z0, phi_z1, phi_t, vec![]], + target: vec![CyclotomicRing::::zero(); kappa], + }; + + // cnst4: evaluation check + let mut phi_eval = vec![CyclotomicRing::::zero(); v_hat_len]; + for i in 0..n { + let ob = outer_basis.get(i).copied().unwrap_or_else(F::zero); + for l in 0..fu { + phi_eval[i * fu + l] = scalar_ring(ob * pow2(l * bu)); + } + } + let mut coefficients = vec![vec![]; num_rows]; + coefficients[3] = phi_eval; + let c4 = LabradorConstraint { + coefficients, + target: vec![scalar_ring(eval_value)], + }; + + vec![c0, c1, c2, c3, c4] +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::algebra::fields::Fp64; + use crate::protocol::labrador::types::LabradorReductionConfig; + use crate::FromSmallInt; + + type F = Fp64<4294967197>; + const D: usize = 64; + + #[test] + fn reduce_builds_five_constraints() { + let cfg = LabradorReductionConfig { + f: 1, + b: 8, + fu: 2, + bu: 10, + kappa: 3, + kappa1: 2, + tail: false, + }; + let proof = 
GreyhoundEvalProof { + u2: vec![CyclotomicRing::::one(), CyclotomicRing::::one()], + m_rows: 4, + n_cols: 4, + inner_vars: 2, + config: cfg, + }; + let u1 = vec![CyclotomicRing::::one(), CyclotomicRing::::one()]; + let eval_point = vec![ + F::from_i64(1), + F::from_i64(2), + F::from_i64(3), + F::from_i64(4), + ]; + let fold_challenges = vec![ + F::from_i64(1), + F::from_i64(2), + F::from_i64(3), + F::from_i64(4), + ]; + let st = greyhound_reduce( + &proof, + &u1, + &eval_point, + F::from_i64(7), + &fold_challenges, + &[8u8; 32], + MatrixPrgBackendChoice::Shake256, + ) + .unwrap(); + assert_eq!(st.constraints.len(), 5); + } +} diff --git a/src/protocol/greyhound/scheme.rs b/src/protocol/greyhound/scheme.rs new file mode 100644 index 00000000..e69de29b diff --git a/src/protocol/greyhound/types.rs b/src/protocol/greyhound/types.rs new file mode 100644 index 00000000..c808bff7 --- /dev/null +++ b/src/protocol/greyhound/types.rs @@ -0,0 +1,52 @@ +//! Greyhound evaluation proof types. + +use crate::algebra::ring::CyclotomicRing; +use crate::protocol::labrador::types::LabradorReductionConfig; +use crate::FieldCore; + +/// Shape metadata for reshaped witness matrices. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct GreyhoundDimensions { + /// Number of matrix rows (`2^{k_inner}`). + pub m_rows: usize, + /// Number of matrix columns (`2^{k_outer}`). + pub n_cols: usize, + /// Number of inner variables (`k_inner`). + pub inner_vars: usize, +} + +/// Greyhound evaluation proof payload. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct GreyhoundEvalProof { + /// Outer commitment to decomposed partial evaluations. + pub u2: Vec>, + /// Matrix row count. + pub m_rows: usize, + /// Matrix column count. + pub n_cols: usize, + /// Split point for `r = (r_outer, r_inner)`. + pub inner_vars: usize, + /// Labrador config agreed between prover and verifier. 
+ pub config: LabradorReductionConfig, +} + +impl GreyhoundEvalProof { + /// Construct an empty proof (used when Greyhound is disabled). + pub fn empty() -> Self { + Self { + u2: Vec::new(), + m_rows: 0, + n_cols: 0, + inner_vars: 0, + config: LabradorReductionConfig { + f: 1, + b: 1, + fu: 1, + bu: 1, + kappa: 1, + kappa1: 0, + tail: false, + }, + } + } +} diff --git a/src/protocol/greyhound/verify.rs b/src/protocol/greyhound/verify.rs new file mode 100644 index 00000000..7f1a7373 --- /dev/null +++ b/src/protocol/greyhound/verify.rs @@ -0,0 +1,299 @@ +//! Greyhound verifier-side checks (stage 1, no Labrador recursion). + +use crate::algebra::ring::CyclotomicRing; +use crate::error::HachiError; +use crate::primitives::poly::multilinear_lagrange_basis; +use crate::protocol::greyhound::types::GreyhoundEvalProof; +use crate::protocol::labrador::comkey::{derive_extendable_comkey_matrix, LabradorComKeySeed}; +use crate::protocol::labrador::transcript::{ + absorb_greyhound_eval_claim, absorb_greyhound_eval_context, absorb_greyhound_u2, + sample_greyhound_fold_challenge, GreyhoundEvalTranscriptContext, +}; +use crate::protocol::labrador::types::LabradorWitness; +use crate::protocol::labrador::utils::mat_vec_mul; +use crate::protocol::prg::MatrixPrgBackendChoice; +use crate::protocol::transcript::Transcript; +use crate::{CanonicalField, FieldCore, FieldSampling}; + +/// Verify Greyhound evaluation proof using the full auxiliary witness. +/// +/// This stage performs direct checks of the linear system `Pz = h` and a +/// smallness bound on the transmitted `z` witness. Labrador recursion is +/// intentionally skipped. +/// +/// # Errors +/// +/// Returns [`HachiError::InvalidInput`] on dimension mismatches, norm bound +/// violations, commitment mismatches, or constraint failures. +/// Propagates transcript replay failures from Fiat-Shamir operations. 
+#[allow(clippy::too_many_arguments)] +pub fn greyhound_verify_stage1( + eval_proof: &GreyhoundEvalProof, + w_commitment_u1: &[CyclotomicRing], + eval_point: &[F], + eval_value: F, + witness: &LabradorWitness, + z_beta_sq: u128, + comkey_seed: &LabradorComKeySeed, + backend: MatrixPrgBackendChoice, + transcript: &mut T, +) -> Result<(), HachiError> +where + F: FieldCore + CanonicalField + FieldSampling, + T: Transcript, +{ + let m = eval_proof.m_rows; + let n = eval_proof.n_cols; + if m == 0 || n == 0 { + return Err(HachiError::InvalidInput( + "greyhound: zero-dimension proof".to_string(), + )); + } + if eval_point.len() < eval_proof.inner_vars { + return Err(HachiError::InvalidPointDimension { + expected: eval_proof.inner_vars, + actual: eval_point.len(), + }); + } + + let cfg = &eval_proof.config; + let z_group_len = m * cfg.f; + let t_hat_len = cfg.kappa * cfg.fu * n; + let v_hat_len = cfg.fu * n; + + let rows = witness.rows(); + if rows.len() != 4 { + return Err(HachiError::InvalidInput( + "greyhound: expected 4 witness rows".to_string(), + )); + } + let z_low = &rows[0]; + let z_high = &rows[1]; + let t_hat = &rows[2]; + let v_hat = &rows[3]; + if z_low.len() != z_group_len + || z_high.len() != z_group_len + || t_hat.len() != t_hat_len + || v_hat.len() != v_hat_len + { + return Err(HachiError::InvalidInput( + "greyhound: witness row lengths mismatch".to_string(), + )); + } + + // Check smallness of the transmitted z witness (rows 0 and 1). + let z_norm_sq = z_low + .iter() + .chain(z_high.iter()) + .map(|ring| ring.coeff_norm_sq()) + .fold(0u128, |acc, v| acc.saturating_add(v)); + if z_norm_sq > z_beta_sq { + return Err(HachiError::InvalidInput( + "greyhound: z norm exceeds bound".to_string(), + )); + } + + // Commitment checks: u1 (inner commitments) and u2 (evaluation witness). 
+ let u1_expected = if cfg.kappa1 > 0 { + let b_mat = derive_extendable_comkey_matrix::( + cfg.kappa1, + t_hat_len, + comkey_seed, + b"labrador/comkey/B", + backend, + ); + mat_vec_mul(&b_mat, t_hat) + } else { + t_hat.to_vec() + }; + if u1_expected != w_commitment_u1 { + return Err(HachiError::InvalidInput( + "greyhound: u1 commitment mismatch".to_string(), + )); + } + + let u2_expected = if cfg.kappa1 > 0 { + let b_eval = derive_extendable_comkey_matrix::( + cfg.kappa1, + v_hat_len, + comkey_seed, + b"greyhound/comkey/B_eval", + backend, + ); + mat_vec_mul(&b_eval, v_hat) + } else { + v_hat.to_vec() + }; + if u2_expected != eval_proof.u2 { + return Err(HachiError::InvalidInput( + "greyhound: u2 commitment mismatch".to_string(), + )); + } + + // Transcript replay to obtain fold challenges. + absorb_greyhound_eval_context( + transcript, + &GreyhoundEvalTranscriptContext { + m_rows: m, + n_cols: n, + inner_vars: eval_proof.inner_vars, + eval_point_len: eval_point.len(), + prg_backend_id: backend as u8, + }, + )?; + absorb_greyhound_eval_claim(transcript, eval_point, &eval_value); + absorb_greyhound_u2(transcript, &eval_proof.u2); + let fold_challenges: Vec = (0..n) + .map(|_| sample_greyhound_fold_challenge(transcript)) + .collect(); + + // Basis vectors. + let outer_vars = eval_point.len() - eval_proof.inner_vars; + let mut outer_basis = vec![F::zero(); 1usize << outer_vars]; + multilinear_lagrange_basis(&mut outer_basis, &eval_point[..outer_vars]); + let mut inner_basis = vec![F::zero(); 1usize << eval_proof.inner_vars]; + if eval_proof.inner_vars > 0 { + multilinear_lagrange_basis( + &mut inner_basis, + &eval_point[eval_point.len() - eval_proof.inner_vars..], + ); + } else if !inner_basis.is_empty() { + inner_basis[0] = F::one(); + } + + let z = reconstruct_z(z_low, z_high, m, cfg.f, cfg.b, cfg.bu); + let v = reconstruct_v(v_hat, n, cfg.fu, cfg.bu); + let t_cols = reconstruct_t_cols(t_hat, n, cfg.kappa, cfg.fu, cfg.bu); + + // Constraint 2: = sum_i c_i * v_i. 
+ let mut lhs = CyclotomicRing::::zero(); + for (j, basis) in inner_basis.iter().enumerate() { + let z_j = z.get(j).copied().unwrap_or_else(CyclotomicRing::zero); + lhs += z_j.scale(basis); + } + let mut rhs = CyclotomicRing::::zero(); + for (i, c_i) in fold_challenges.iter().enumerate() { + let v_i = v.get(i).copied().unwrap_or_else(CyclotomicRing::zero); + rhs += v_i.scale(c_i); + } + if lhs != rhs { + return Err(HachiError::InvalidInput( + "greyhound: amortization constraint failed".to_string(), + )); + } + + // Constraint 3: A * z = sum_i c_i * t_i. + let a_mat = derive_extendable_comkey_matrix::( + cfg.kappa, + m, + comkey_seed, + b"labrador/comkey/A", + backend, + ); + let lhs_vec = mat_vec_mul(&a_mat, &z); + let mut rhs_vec = vec![CyclotomicRing::::zero(); cfg.kappa]; + for (i, c_i) in fold_challenges.iter().enumerate() { + if let Some(t_i) = t_cols.get(i) { + for (r, t_ir) in t_i.iter().enumerate() { + rhs_vec[r] += t_ir.scale(c_i); + } + } + } + if lhs_vec != rhs_vec { + return Err(HachiError::InvalidInput( + "greyhound: inner commitment constraint failed".to_string(), + )); + } + + // Constraint 4: = eval_value. 
+ let mut eval_check = CyclotomicRing::::zero(); + for (i, basis) in outer_basis.iter().enumerate() { + let v_i = v.get(i).copied().unwrap_or_else(CyclotomicRing::zero); + eval_check += v_i.scale(basis); + } + if eval_check != scalar_ring(eval_value) { + return Err(HachiError::InvalidInput( + "greyhound: evaluation constraint failed".to_string(), + )); + } + + Ok(()) +} + +fn pow2(exp: usize) -> F { + let mut v = F::one(); + for _ in 0..exp { + v = v + v; + } + v +} + +fn scalar_ring(s: F) -> CyclotomicRing { + CyclotomicRing::from_coefficients(std::array::from_fn(|i| if i == 0 { s } else { F::zero() })) +} + +fn reconstruct_z( + z_low: &[CyclotomicRing], + z_high: &[CyclotomicRing], + m: usize, + f: usize, + b: usize, + bu: usize, +) -> Vec> { + let mut out = vec![CyclotomicRing::::zero(); m]; + let bu_scale = pow2::(bu); + for (j, out_elem) in out.iter_mut().enumerate() { + let mut acc = CyclotomicRing::::zero(); + for k in 0..f { + let idx = j * f + k; + let mut digit = z_low[idx]; + digit += z_high[idx].scale(&bu_scale); + let scale = pow2::(k * b); + acc += digit.scale(&scale); + } + *out_elem = acc; + } + out +} + +fn reconstruct_v( + v_hat: &[CyclotomicRing], + n: usize, + fu: usize, + bu: usize, +) -> Vec> { + let mut out = vec![CyclotomicRing::::zero(); n]; + for (i, out_elem) in out.iter_mut().enumerate() { + let mut acc = CyclotomicRing::::zero(); + for l in 0..fu { + let idx = i * fu + l; + let scale = pow2::(l * bu); + acc += v_hat[idx].scale(&scale); + } + *out_elem = acc; + } + out +} + +fn reconstruct_t_cols( + t_hat: &[CyclotomicRing], + n: usize, + kappa: usize, + fu: usize, + bu: usize, +) -> Vec>> { + let mut out = vec![vec![CyclotomicRing::::zero(); kappa]; n]; + let per_col = kappa * fu; + for (i, out_row) in out.iter_mut().enumerate() { + for (r, out_elem) in out_row.iter_mut().enumerate() { + let mut acc = CyclotomicRing::::zero(); + for l in 0..fu { + let idx = i * per_col + r * fu + l; + let scale = pow2::(l * bu); + acc += 
t_hat[idx].scale(&scale); + } + *out_elem = acc; + } + } + out +} diff --git a/src/protocol/hachi_poly_ops/decompose_fold_neon.rs b/src/protocol/hachi_poly_ops/decompose_fold_neon.rs new file mode 100644 index 00000000..0ee84394 --- /dev/null +++ b/src/protocol/hachi_poly_ops/decompose_fold_neon.rs @@ -0,0 +1,143 @@ +//! AArch64 NEON kernel for sparse-multiply-accumulate in `decompose_fold`. +//! +//! Rotates an i8 digit plane by each challenge position and accumulates +//! into an i32 accumulator using widening add/sub (`SADDW` / `SSUBW`). + +use std::arch::aarch64::*; + +/// NEON sparse-multiply-accumulate. +/// +/// For each challenge term `(pos, coeff)`, rotates the `digit_plane` by `pos` +/// positions in the negacyclic ring (X^D + 1) and adds or subtracts the +/// widened i8 values into the i32 `acc`. +/// +/// # Safety +/// +/// - `digit_plane` must point to at least `d` valid i8 values. +/// - `acc` must point to at least `d` valid i32 values. +/// - `d` must be a multiple of 16. +#[target_feature(enable = "neon")] +pub(super) unsafe fn sparse_mul_acc_neon( + digit_plane: *const i8, + acc: *mut i32, + d: usize, + positions: &[u32], + coeffs: &[i16], +) { + debug_assert!(d % 16 == 0); + + for (&pos, &coeff) in positions.iter().zip(coeffs.iter()) { + let p = pos as usize; + let split = d - p; + + if coeff > 0 { + acc_rotated_add(digit_plane, acc, d, p, split); + } else { + acc_rotated_sub(digit_plane, acc, d, p, split); + } + } +} + +/// Add rotated digit plane: acc[i+p] += digits[i] for i in [0, split), +/// acc[i-split] -= digits[i] for i in [split, D) (negacyclic wrap). 
+#[inline(always)] +unsafe fn acc_rotated_add(digits: *const i8, acc: *mut i32, d: usize, p: usize, split: usize) { + // First segment: digits[0..split] -> acc[p..D], ADD + acc_segment_add(digits, acc.add(p), split); + // Second segment: digits[split..D] -> acc[0..p], SUB (negacyclic) + if p > 0 { + acc_segment_sub(digits.add(split), acc, p); + } + let _ = d; +} + +/// Sub rotated digit plane: acc[i+p] -= digits[i] for i in [0, split), +/// acc[i-split] += digits[i] for i in [split, D) (negacyclic wrap). +#[inline(always)] +unsafe fn acc_rotated_sub(digits: *const i8, acc: *mut i32, d: usize, p: usize, split: usize) { + // First segment: digits[0..split] -> acc[p..D], SUB + acc_segment_sub(digits, acc.add(p), split); + // Second segment: digits[split..D] -> acc[0..p], ADD (negacyclic) + if p > 0 { + acc_segment_add(digits.add(split), acc, p); + } + let _ = d; +} + +/// Widen i8 source values to i32 and ADD into accumulator. +/// Handles arbitrary length (processes 16 at a time, then remainder). 
+#[inline(always)] +unsafe fn acc_segment_add(src: *const i8, dst: *mut i32, len: usize) { + let chunks = len / 16; + let rem = len % 16; + + for i in 0..chunks { + let offset = i * 16; + let v = vld1q_s8(src.add(offset)); + + let lo8 = vget_low_s8(v); + let hi8 = vget_high_s8(v); + let lo16 = vmovl_s8(lo8); + let hi16 = vmovl_s8(hi8); + + let s0 = vmovl_s16(vget_low_s16(lo16)); + let s1 = vmovl_s16(vget_high_s16(lo16)); + let s2 = vmovl_s16(vget_low_s16(hi16)); + let s3 = vmovl_s16(vget_high_s16(hi16)); + + let d0 = vld1q_s32(dst.add(offset)); + let d1 = vld1q_s32(dst.add(offset + 4)); + let d2 = vld1q_s32(dst.add(offset + 8)); + let d3 = vld1q_s32(dst.add(offset + 12)); + + vst1q_s32(dst.add(offset), vaddq_s32(d0, s0)); + vst1q_s32(dst.add(offset + 4), vaddq_s32(d1, s1)); + vst1q_s32(dst.add(offset + 8), vaddq_s32(d2, s2)); + vst1q_s32(dst.add(offset + 12), vaddq_s32(d3, s3)); + } + + let base = chunks * 16; + for i in 0..rem { + let val = *src.add(base + i) as i32; + *dst.add(base + i) += val; + } +} + +/// Widen i8 source values to i32 and SUB from accumulator. +/// Handles arbitrary length (processes 16 at a time, then remainder). 
+#[inline(always)] +unsafe fn acc_segment_sub(src: *const i8, dst: *mut i32, len: usize) { + let chunks = len / 16; + let rem = len % 16; + + for i in 0..chunks { + let offset = i * 16; + let v = vld1q_s8(src.add(offset)); + + let lo8 = vget_low_s8(v); + let hi8 = vget_high_s8(v); + let lo16 = vmovl_s8(lo8); + let hi16 = vmovl_s8(hi8); + + let s0 = vmovl_s16(vget_low_s16(lo16)); + let s1 = vmovl_s16(vget_high_s16(lo16)); + let s2 = vmovl_s16(vget_low_s16(hi16)); + let s3 = vmovl_s16(vget_high_s16(hi16)); + + let d0 = vld1q_s32(dst.add(offset)); + let d1 = vld1q_s32(dst.add(offset + 4)); + let d2 = vld1q_s32(dst.add(offset + 8)); + let d3 = vld1q_s32(dst.add(offset + 12)); + + vst1q_s32(dst.add(offset), vsubq_s32(d0, s0)); + vst1q_s32(dst.add(offset + 4), vsubq_s32(d1, s1)); + vst1q_s32(dst.add(offset + 8), vsubq_s32(d2, s2)); + vst1q_s32(dst.add(offset + 12), vsubq_s32(d3, s3)); + } + + let base = chunks * 16; + for i in 0..rem { + let val = *src.add(base + i) as i32; + *dst.add(base + i) -= val; + } +} diff --git a/src/protocol/hachi_poly_ops/mod.rs b/src/protocol/hachi_poly_ops/mod.rs new file mode 100644 index 00000000..d7dfa74f --- /dev/null +++ b/src/protocol/hachi_poly_ops/mod.rs @@ -0,0 +1,1050 @@ +//! Operation-centric polynomial trait for the Hachi commitment scheme. +//! +//! [`HachiPolyOps`] exposes the four operations the Hachi commit/prove paths +//! need from a polynomial, rather than raw coefficient access. Each +//! implementation handles every operation in its own optimal way: +//! +//! - [`DensePoly`] — standard dense algorithms (decompose + NTT matvec). +//! - [`OneHotPoly`] — sparse monomial tricks, avoids all inner ring +//! multiplications. +//! +//! # Extensibility +//! +//! This trait is coupled to power-of-2 cyclotomic rings +//! ([`CyclotomicRing`]). When non-power-of-2 rings are added, the trait +//! signature will change. Additional operation methods may be added as the +//! protocol evolves. 
+ +use crate::algebra::fields::wide::HasWide; +use crate::algebra::ring::sparse_challenge::SparseChallenge; +use crate::algebra::CyclotomicRing; +use crate::error::HachiError; +#[cfg(feature = "parallel")] +use crate::parallel::*; +use crate::protocol::commitment::onehot::{ + inner_ajtai_onehot_wide, map_onehot_to_sparse_blocks, SparseBlockEntry, +}; +use crate::protocol::commitment::utils::crt_ntt::NttSlotCache; +use crate::protocol::commitment::utils::flat_matrix::FlatMatrix; +use crate::protocol::commitment::utils::linear::{ + decompose_rows_i8, mat_vec_mul_ntt_digits_i8, mat_vec_mul_ntt_i8, +}; +use crate::{cfg_fold_reduce, cfg_into_iter, cfg_iter, CanonicalField, FieldCore}; +use std::array::from_fn; +use std::marker::PhantomData; + +#[cfg(target_arch = "aarch64")] +use crate::algebra::ntt::neon; + +#[cfg(target_arch = "aarch64")] +mod decompose_fold_neon; + +/// Precomputed constants for balanced base-b decomposition. +struct DecomposeParams { + half_q: u128, + q: u128, + mask: i128, + half_b: i128, + b_val: i128, + log_basis: u32, +} + +/// Decompose all D coefficients of a ring element into balanced base-b digits, +/// storing results in digit-major order for subsequent SIMD scatter. +/// +/// Uses K=3 interleaved carry chains to saturate ALU throughput (3x ILP gain +/// over processing one coefficient at a time on out-of-order cores). +/// +/// `digit_buf` is `[num_digits][D]` in i8, OVERWRITTEN (not accumulated). 
+#[inline(never)] +fn decompose_ring_interleaved( + coeffs: &[u128; D], + digit_buf: &mut [Vec], + num_digits: usize, + p: &DecomposeParams, +) { + let bulk_end = D - (D % 3); + + for base in (0..bulk_end).step_by(3) { + let mut c0 = to_signed(coeffs[base], p); + let mut c1 = to_signed(coeffs[base + 1], p); + let mut c2 = to_signed(coeffs[base + 2], p); + + for plane in digit_buf.iter_mut().take(num_digits) { + let d0 = extract_balanced_digit(&mut c0, p); + let d1 = extract_balanced_digit(&mut c1, p); + let d2 = extract_balanced_digit(&mut c2, p); + plane[base] = d0 as i8; + plane[base + 1] = d1 as i8; + plane[base + 2] = d2 as i8; + } + } + + for idx in bulk_end..D { + let mut c = to_signed(coeffs[idx], p); + for plane in digit_buf.iter_mut().take(num_digits) { + plane[idx] = extract_balanced_digit(&mut c, p) as i8; + } + } +} + +#[inline(always)] +fn to_signed(canonical: u128, p: &DecomposeParams) -> i128 { + if canonical > p.half_q { + -((p.q - canonical) as i128) + } else { + canonical as i128 + } +} + +#[inline(always)] +fn extract_balanced_digit(c: &mut i128, p: &DecomposeParams) -> i32 { + let d = *c & p.mask; + let balanced = if d >= p.half_b { d - p.b_val } else { d }; + *c = (*c - balanced) >> p.log_basis; + balanced as i32 +} + +/// Scalar sparse-multiply-accumulate: accumulate `challenge * digit_plane` +/// into `acc` using the rotate-and-add formulation. +/// +/// `digit_plane` is `[i8; D]`, `acc` is `[i32; D]`. +/// Each challenge term rotates the digit plane and adds/subtracts contiguously. 
+fn sparse_mul_acc_scalar( + digit_plane: &[i8], + challenge: &SparseChallenge, + acc: &mut [i32; D], +) { + for (&pos, &coeff) in challenge.positions.iter().zip(challenge.coeffs.iter()) { + let p = pos as usize; + let split = D - p; + if coeff > 0 { + for i in 0..split { + acc[i + p] += digit_plane[i] as i32; + } + for i in split..D { + acc[i - split] -= digit_plane[i] as i32; + } + } else { + for i in 0..split { + acc[i + p] -= digit_plane[i] as i32; + } + for i in split..D { + acc[i - split] += digit_plane[i] as i32; + } + } + } +} + +/// Dispatch to NEON or scalar sparse-multiply-accumulate. +#[inline(always)] +fn sparse_mul_acc( + digit_plane: &[i8], + challenge: &SparseChallenge, + acc: &mut [i32; D], +) { + #[cfg(target_arch = "aarch64")] + { + if neon::use_neon_ntt() { + unsafe { + decompose_fold_neon::sparse_mul_acc_neon( + digit_plane.as_ptr(), + acc.as_mut_ptr(), + D, + &challenge.positions, + &challenge.coeffs, + ); + } + return; + } + } + sparse_mul_acc_scalar::(digit_plane, challenge, acc); +} + +/// Operations the Hachi commitment scheme needs from a polynomial. +/// +/// The four methods correspond to the four places in commit/prove that consume +/// polynomial data. Implementations decide *how* to carry out each operation +/// (dense decompose + NTT, sparse monomial tricks, streaming, etc.). +pub trait HachiPolyOps: Clone + Send + Sync { + /// Per-polynomial cache type for the A-matrix commit path. + /// + /// `DensePoly` uses `NttSlotCache` (CRT+NTT of A for dense mat-vec). + /// `OneHotPoly` uses `()` (one-hot commit bypasses NTT entirely). + type CommitCache: Send + Sync; + + /// Total number of ring elements in the polynomial. + fn num_ring_elems(&self) -> usize; + + /// **Op 1 — prove: ring-space evaluation.** + /// + /// Computes the global weighted sum `y = Σᵢ scalars[i] · self[i]`. + /// + /// `scalars` has length >= `num_ring_elems`; excess entries are ignored. 
+ fn evaluate_ring(&self, scalars: &[F]) -> CyclotomicRing; + + /// **Op 2 — prove: per-block fold.** + /// + /// For each contiguous block of `block_len` ring elements, computes + /// `Σⱼ scalars[j] · self[i·block_len + j]`. + /// + /// Returns one ring element per block (total `ceil(num_ring_elems / block_len)`). + /// `scalars` has length `block_len`. + fn fold_blocks(&self, scalars: &[F], block_len: usize) -> Vec>; + + /// Fused fold + evaluation in a single pass over the polynomial. + /// + /// `eval_outer_scalars` is the per-block weight vector `b` (size `num_blocks`). + /// `fold_scalars` is the per-element-in-block weight vector `a` (size `block_len`). + /// + /// The full evaluation scalars factor as `outer_weights[i*block_len + j] = b[i] * a[j]`, + /// so `eval = Σ_i b[i] * fold(a)[i]` — derived from the fold result without + /// materializing the full `2^(m_vars + r_vars)` weight vector. + fn evaluate_and_fold( + &self, + eval_outer_scalars: &[F], + fold_scalars: &[F], + block_len: usize, + ) -> (CyclotomicRing, Vec>) { + let folded = self.fold_blocks(fold_scalars, block_len); + let eval = folded + .iter() + .zip(eval_outer_scalars.iter()) + .fold(CyclotomicRing::::zero(), |acc, (f_i, s_i)| { + acc + f_i.scale(s_i) + }); + (eval, folded) + } + + /// **Op 3 — prove: decompose + challenge-fold.** + /// + /// For each block of `block_len` ring elements: + /// 1. Decompose: `sᵢ = G⁻¹(blockᵢ)` via `balanced_decompose_pow2(num_digits, log_basis)`. + /// 2. Accumulate: `z += cᵢ ⊗ sᵢ` (sparse challenge multiplication). + /// + /// Returns `z` of length `block_len · num_digits`. + fn decompose_fold( + &self, + challenges: &[SparseChallenge], + block_len: usize, + num_digits: usize, + log_basis: u32, + ) -> Vec>; + + /// **Op 4 — commit: per-block inner Ajtai.** + /// + /// For each block of `block_len` ring elements: + /// 1. `sᵢ = G⁻¹(blockᵢ)` with `num_digits_commit` levels. + /// 2. `tᵢ = A · sᵢ` (matrix-vector multiply via NTT cache or sparse path). + /// 3. 
`t̂ᵢ = G⁻¹(tᵢ)` with `num_digits_open` levels (t has full-field + /// coefficients regardless of s's digit count). + /// + /// Returns one `t̂ᵢ` vector per block as `[i8; D]` digit planes. + /// + /// # Errors + /// + /// Returns an error if the cached matrix-vector multiply fails. + fn commit_inner( + &self, + a_matrix: &FlatMatrix, + ntt_a: &NttSlotCache, + block_len: usize, + num_digits_commit: usize, + num_digits_open: usize, + log_basis: u32, + ) -> Result>, HachiError>; +} + +/// Dense polynomial: all ring coefficients materialized in memory. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct DensePoly { + /// Ring coefficients in sequential block order. + pub coeffs: Vec>, +} + +impl DensePoly { + /// Pack field-element evaluations into ring elements. + /// + /// The first `α = log₂(D)` variables become coefficient slots within each + /// ring element; the remaining variables index ring elements. + /// + /// # Errors + /// + /// Returns an error if `D` is not a power of two, `num_vars < log₂(D)`, or + /// `evals.len() != 2^num_vars`. + pub fn from_field_evals(num_vars: usize, evals: &[F]) -> Result { + if D == 0 || !D.is_power_of_two() { + return Err(HachiError::InvalidInput(format!( + "ring degree D={D} is not a power of two" + ))); + } + let alpha = D.trailing_zeros() as usize; + if num_vars < alpha { + return Err(HachiError::InvalidInput(format!( + "num_vars {num_vars} is smaller than alpha {alpha}" + ))); + } + let expected_len = 1usize + .checked_shl(num_vars as u32) + .ok_or_else(|| HachiError::InvalidInput(format!("2^{num_vars} does not fit usize")))?; + if evals.len() != expected_len { + return Err(HachiError::InvalidSize { + expected: expected_len, + actual: evals.len(), + }); + } + + let outer_len = expected_len / D; + let coeffs: Vec> = (0..outer_len) + .map(|i| CyclotomicRing::from_slice(&evals[i * D..(i + 1) * D])) + .collect(); + Ok(Self { coeffs }) + } + + /// Wrap an existing vector of ring elements. 
+ pub fn from_ring_coeffs(coeffs: Vec>) -> Self { + Self { coeffs } + } +} + +impl HachiPolyOps for DensePoly +where + F: FieldCore + CanonicalField, +{ + type CommitCache = NttSlotCache; + + fn num_ring_elems(&self) -> usize { + self.coeffs.len() + } + + fn evaluate_ring(&self, scalars: &[F]) -> CyclotomicRing { + #[cfg(feature = "parallel")] + { + self.coeffs + .par_iter() + .zip(scalars.par_iter()) + .fold( + || CyclotomicRing::::zero(), + |acc, (f_i, w_i)| acc + f_i.scale(w_i), + ) + .reduce(|| CyclotomicRing::::zero(), |a, b| a + b) + } + #[cfg(not(feature = "parallel"))] + { + self.coeffs + .iter() + .zip(scalars.iter()) + .fold(CyclotomicRing::::zero(), |acc, (f_i, w_i)| { + acc + f_i.scale(w_i) + }) + } + } + + fn fold_blocks(&self, scalars: &[F], block_len: usize) -> Vec> { + let n = self.coeffs.len(); + let num_blocks = n.div_ceil(block_len); + cfg_into_iter!(0..num_blocks) + .map(|i| { + let start = i * block_len; + let end = (start + block_len).min(n); + let block = &self.coeffs[start..end]; + let mut acc = CyclotomicRing::::zero(); + for (b_j, &a_j) in block.iter().zip(scalars.iter()) { + acc += b_j.scale(&a_j); + } + acc + }) + .collect() + } + + fn decompose_fold( + &self, + challenges: &[SparseChallenge], + block_len: usize, + num_digits: usize, + log_basis: u32, + ) -> Vec> { + let n = self.coeffs.len(); + let coeffs = &self.coeffs; + + let q = (-F::one()).to_canonical_u128() + 1; + let params = DecomposeParams { + half_q: q / 2, + q, + mask: (1i128 << log_basis) - 1, + half_b: 1i128 << (log_basis - 1), + b_val: 1i128 << log_basis, + log_basis, + }; + + // Two-phase approach: decompose ring element coefficients into i8 digit + // planes, then scatter via sparse polynomial multiply. + // + // Phase 1 (decompose): K=3 interleaved carry chains for ILP (~3x over + // single-chain). Writes into a digit_buf [num_digits][D] in i8 (~16 KB, + // L1-resident). 
+ // + // Phase 2 (scatter): rotate-and-add formulation — contiguous NEON + // SADDW/SSUBW on aarch64, scalar fallback elsewhere. Accumulates into + // z_local [num_digits][D] in i32 (~66 KB, L2-resident). + let z_chunks: Vec>> = cfg_into_iter!(0..block_len) + .map(|elem_idx| { + let mut z_local: Vec<[i32; D]> = vec![[0i32; D]; num_digits]; + let mut digit_buf: Vec> = vec![vec![0i8; D]; num_digits]; + + for (block_idx, c_i) in challenges.iter().enumerate() { + let global_idx = block_idx * block_len + elem_idx; + if global_idx >= n { + continue; + } + let ring = &coeffs[global_idx]; + + let canonical: [u128; D] = from_fn(|k| ring.coeffs[k].to_canonical_u128()); + decompose_ring_interleaved::( + &canonical, + &mut digit_buf, + num_digits, + ¶ms, + ); + + for digit in 0..num_digits { + sparse_mul_acc::(&digit_buf[digit], c_i, &mut z_local[digit]); + } + } + + let q = params.q; + z_local + .into_iter() + .map(|arr| { + let field_coeffs: [F; D] = from_fn(|k| { + let v = arr[k]; + if v >= 0 { + F::from_canonical_u128_reduced(v as u128) + } else { + F::from_canonical_u128_reduced(q - ((-v) as u128)) + } + }); + CyclotomicRing::from_coefficients(field_coeffs) + }) + .collect() + }) + .collect(); + + let mut z = Vec::with_capacity(block_len * num_digits); + for chunk in z_chunks { + z.extend(chunk); + } + z + } + + #[tracing::instrument(skip_all, name = "DensePoly::commit_inner")] + fn commit_inner( + &self, + _a_matrix: &FlatMatrix, + ntt_a: &NttSlotCache, + block_len: usize, + num_digits_commit: usize, + num_digits_open: usize, + log_basis: u32, + ) -> Result>, HachiError> { + let n = self.coeffs.len(); + let num_blocks = n.div_ceil(block_len); + + let block_slices: Vec<&[CyclotomicRing]> = (0..num_blocks) + .map(|i| { + let start = i * block_len; + if start >= n { + &[] as &[CyclotomicRing] + } else { + &self.coeffs[start..(start + block_len).min(n)] + } + }) + .collect(); + + let t_all = mat_vec_mul_ntt_i8(ntt_a, &block_slices, num_digits_commit, log_basis); + + let 
results: Vec> = cfg_into_iter!(t_all) + .map(|t_i| decompose_rows_i8(&t_i, num_digits_open, log_basis)) + .collect(); + + Ok(results) + } +} + +/// Ring polynomial whose coefficients are already balanced base-`2^log_basis` +/// digits. +/// +/// This is the recursive `w` witness used by Hachi's later prove levels. Unlike +/// [`DensePoly`], it can skip the `i8 -> field -> dense ring` round-trip and +/// operate on the digit planes directly. +#[derive(Debug, Clone)] +pub(crate) struct BalancedDigitPoly<'a, F: FieldCore, const D: usize> { + coeffs: &'a [[i8; D]], + padded_ring_elems: usize, + _marker: PhantomData, +} + +impl<'a, F: FieldCore, const D: usize> BalancedDigitPoly<'a, F, D> { + /// Wrap a flat digit vector laid out as consecutive ring coefficients. + pub(crate) fn from_i8_digits(digits: &'a [i8]) -> Result { + let (coeffs, remainder) = digits.as_chunks::(); + if !remainder.is_empty() { + return Err(HachiError::InvalidSize { + expected: D, + actual: digits.len(), + }); + } + + Ok(Self { + coeffs, + padded_ring_elems: coeffs.len().next_power_of_two().max(1), + _marker: PhantomData, + }) + } + + #[inline] + fn block_slice(&self, block_idx: usize, block_len: usize) -> &'a [[i8; D]] { + let start = block_idx * block_len; + if start >= self.coeffs.len() { + &[] + } else { + &self.coeffs[start..(start + block_len).min(self.coeffs.len())] + } + } +} + +impl<'a, F, const D: usize> HachiPolyOps for BalancedDigitPoly<'a, F, D> +where + F: FieldCore + CanonicalField, +{ + type CommitCache = NttSlotCache; + + fn num_ring_elems(&self) -> usize { + self.padded_ring_elems + } + + fn evaluate_ring(&self, scalars: &[F]) -> CyclotomicRing { + let total = cfg_fold_reduce!( + 0..self.coeffs.len().min(scalars.len()), + || [F::zero(); D], + |mut acc: [F; D], idx| { + let scalar = scalars[idx]; + let digit = &self.coeffs[idx]; + for (coeff, &d) in acc.iter_mut().zip(digit.iter()) { + if d != 0 { + *coeff += scalar * F::from_i8(d); + } + } + acc + }, + |mut a: [F; D], b: [F; D]| 
{ + for (a_coeff, b_coeff) in a.iter_mut().zip(b.iter()) { + *a_coeff += *b_coeff; + } + a + } + ); + CyclotomicRing::from_coefficients(total) + } + + fn fold_blocks(&self, scalars: &[F], block_len: usize) -> Vec> { + let num_blocks = self.num_ring_elems().div_ceil(block_len); + cfg_into_iter!(0..num_blocks) + .map(|block_idx| { + let mut acc = [F::zero(); D]; + for (ring, &scalar) in self + .block_slice(block_idx, block_len) + .iter() + .zip(scalars.iter()) + { + for (coeff, &d) in acc.iter_mut().zip(ring.iter()) { + if d != 0 { + *coeff += scalar * F::from_i8(d); + } + } + } + CyclotomicRing::from_coefficients(acc) + }) + .collect() + } + + fn decompose_fold( + &self, + challenges: &[SparseChallenge], + block_len: usize, + num_digits: usize, + _log_basis: u32, + ) -> Vec> { + let inner_width = block_len * num_digits; + let num_blocks = self.num_ring_elems().div_ceil(block_len); + + let q = (-F::one()).to_canonical_u128() + 1; + cfg_fold_reduce!( + 0..challenges.len().min(num_blocks), + || vec![[0i32; D]; inner_width], + |mut z_local: Vec<[i32; D]>, block_idx| { + let challenge = &challenges[block_idx]; + for (elem_idx, digit_plane) in + self.block_slice(block_idx, block_len).iter().enumerate() + { + sparse_mul_acc::( + digit_plane, + challenge, + &mut z_local[elem_idx * num_digits], + ); + } + z_local + }, + |mut a: Vec<[i32; D]>, b: Vec<[i32; D]>| { + for (ai, bi) in a.iter_mut().zip(b.iter()) { + for (a_coeff, b_coeff) in ai.iter_mut().zip(bi.iter()) { + *a_coeff += *b_coeff; + } + } + a + } + ) + .into_iter() + .map(|arr| { + let coeffs = from_fn(|k| { + let v = arr[k]; + if v >= 0 { + F::from_canonical_u128_reduced(v as u128) + } else { + F::from_canonical_u128_reduced(q - ((-v) as u128)) + } + }); + CyclotomicRing::from_coefficients(coeffs) + }) + .collect() + } + + fn commit_inner( + &self, + _a_matrix: &FlatMatrix, + ntt_a: &NttSlotCache, + block_len: usize, + num_digits_commit: usize, + num_digits_open: usize, + log_basis: u32, + ) -> Result>, HachiError> 
{ + let num_blocks = self.num_ring_elems().div_ceil(block_len); + let coeff_len = self.coeffs.len(); + + let t_all = if num_digits_commit == 1 { + let block_slices: Vec<&[[i8; D]]> = (0..num_blocks) + .map(|block_idx| self.block_slice(block_idx, block_len)) + .collect(); + mat_vec_mul_ntt_digits_i8(ntt_a, &block_slices) + } else { + let ring_elems: Vec> = self + .coeffs + .iter() + .map(|digit| { + let coeffs = from_fn(|k| F::from_i8(digit[k])); + CyclotomicRing::from_coefficients(coeffs) + }) + .collect(); + let block_slices: Vec<&[CyclotomicRing]> = (0..num_blocks) + .map(|block_idx| { + let start = block_idx * block_len; + if start >= coeff_len { + &[] as &[CyclotomicRing] + } else { + &ring_elems[start..(start + block_len).min(coeff_len)] + } + }) + .collect(); + mat_vec_mul_ntt_i8(ntt_a, &block_slices, num_digits_commit, log_basis) + }; + + let results = cfg_into_iter!(t_all) + .map(|t_i| decompose_rows_i8(&t_i, num_digits_open, log_basis)) + .collect(); + Ok(results) + } +} + +/// Types usable as one-hot position indices. +/// +/// Implemented for `u8`, `u16`, `u32`, and `usize`. +pub trait OneHotIndex: Copy + Send + Sync + std::fmt::Debug + 'static { + /// Convert to `usize` for indexing. + fn as_usize(self) -> usize; +} + +impl OneHotIndex for u8 { + #[inline] + fn as_usize(self) -> usize { + self as usize + } +} + +impl OneHotIndex for u16 { + #[inline] + fn as_usize(self) -> usize { + self as usize + } +} + +impl OneHotIndex for u32 { + #[inline] + fn as_usize(self) -> usize { + self as usize + } +} + +impl OneHotIndex for usize { + #[inline] + fn as_usize(self) -> usize { + self + } +} + +/// One-hot polynomial: sparse witness with at most one nonzero field element +/// per chunk of size `onehot_k`. +/// +/// Exploits sparsity in all four operations, avoiding inner ring +/// multiplications during commit and decomposing only nonzero monomials. +/// +/// Generic over `I`: the index type stored per chunk. 
Use `u8` when +/// `onehot_k <= 256` to cut per-entry memory from 16 bytes to 2 bytes. +#[derive(Debug, Clone)] +pub struct OneHotPoly { + onehot_k: usize, + indices: Vec>, + m_vars: usize, + sparse_blocks: Vec>, + _marker: PhantomData, +} + +impl OneHotPoly { + /// Build a one-hot polynomial from chunk size and hot-position indices. + /// + /// `indices[c]` is the hot position in chunk `c` (`None` for all-zero chunks). + /// + /// # Errors + /// + /// Returns an error if dimensions are inconsistent or any index is out of range. + pub fn new( + onehot_k: usize, + indices: Vec>, + r_vars: usize, + m_vars: usize, + ) -> Result { + let sparse_blocks = map_onehot_to_sparse_blocks(onehot_k, &indices, r_vars, m_vars, D)?; + Ok(Self { + onehot_k, + indices, + m_vars, + sparse_blocks, + _marker: PhantomData, + }) + } + + fn total_ring_elems(&self) -> usize { + let total_field = self.indices.len() * self.onehot_k; + total_field / D + } +} + +impl HachiPolyOps for OneHotPoly +where + F: FieldCore + CanonicalField + HasWide, +{ + type CommitCache = NttSlotCache; + + fn num_ring_elems(&self) -> usize { + self.total_ring_elems() + } + + fn evaluate_ring(&self, scalars: &[F]) -> CyclotomicRing { + let block_len = 1usize << self.m_vars; + cfg_fold_reduce!( + 0..self.sparse_blocks.len(), + || CyclotomicRing::::zero(), + |mut acc: CyclotomicRing, block_idx: usize| { + let block_offset = block_idx * block_len; + for entry in &self.sparse_blocks[block_idx] { + let ring_idx = block_offset + entry.pos_in_block; + if ring_idx < scalars.len() { + let s = scalars[ring_idx]; + for &ci in &entry.nonzero_coeffs { + acc.coeffs[ci] += s; + } + } + } + acc + }, + |a, b| a + b + ) + } + + fn fold_blocks(&self, scalars: &[F], block_len: usize) -> Vec> { + cfg_iter!(self.sparse_blocks) + .map(|entries| { + let mut coeffs_acc = [F::zero(); D]; + for entry in entries { + if entry.pos_in_block < scalars.len() && entry.pos_in_block < block_len { + let s = scalars[entry.pos_in_block]; + for &ci in 
&entry.nonzero_coeffs { + coeffs_acc[ci] += s; + } + } + } + CyclotomicRing::from_coefficients(coeffs_acc) + }) + .collect() + } + + fn decompose_fold( + &self, + challenges: &[SparseChallenge], + block_len: usize, + num_digits: usize, + _log_basis: u32, + ) -> Vec> { + let inner_width = block_len * num_digits; + let num_blocks = self.sparse_blocks.len(); + + // One-hot coefficients are {0,1}: balanced_decompose_pow2 produces + // nonzero output only in digit plane 0 (the value itself). + // + // Direct sparse-sparse multiply: for each nonzero coefficient at + // position `ci` and each challenge term `(pos, coeff)`, we know the + // exact output position in the cyclotomic ring (X^D + 1). + // O(omega * |nonzero_coeffs|) per entry instead of O(omega * D). + cfg_fold_reduce!( + 0..challenges.len().min(num_blocks), + || vec![CyclotomicRing::::zero(); inner_width], + |mut z: Vec>, i: usize| { + let c_i = &challenges[i]; + for entry in &self.sparse_blocks[i] { + let j = entry.pos_in_block * num_digits; + let z_coeffs = &mut z[j].coeffs; + for (&pos, &coeff) in c_i.positions.iter().zip(c_i.coeffs.iter()) { + let c_val = F::from_i64(coeff as i64); + for &ci in &entry.nonzero_coeffs { + let target = ci + pos as usize; + if target < D { + z_coeffs[target] += c_val; + } else { + z_coeffs[target - D] -= c_val; + } + } + } + } + z + }, + |mut a: Vec>, b: Vec>| { + for (ai, bi) in a.iter_mut().zip(b.into_iter()) { + *ai += bi; + } + a + } + ) + } + + #[tracing::instrument(skip_all, name = "OneHotPoly::commit_inner")] + fn commit_inner( + &self, + a_matrix: &FlatMatrix, + _ntt_a: &NttSlotCache, + block_len: usize, + num_digits_commit: usize, + num_digits_open: usize, + log_basis: u32, + ) -> Result>, HachiError> { + let a_view = a_matrix.view::(); + let n_a = a_view.num_rows(); + let zero_block_len = n_a.checked_mul(num_digits_open).unwrap(); + + let t_hat_all: Vec> = cfg_iter!(self.sparse_blocks) + .map(|block_entries| { + if block_entries.is_empty() { + vec![[0i8; D]; 
zero_block_len] + } else { + let t_i = inner_ajtai_onehot_wide( + &a_view, + block_entries, + block_len, + num_digits_commit, + ); + decompose_rows_i8(&t_i, num_digits_open, log_basis) + } + }) + .collect(); + + Ok(t_hat_all) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::protocol::commitment::{ + CommitmentConfig, HachiCommitmentCore, RingCommitmentScheme, + }; + use crate::protocol::ring_switch::w_commitment_layout; + use crate::test_utils::{TinyConfig, D as TestD, F as TestF}; + use crate::FromSmallInt; + + #[test] + fn dense_poly_from_field_evals_roundtrip() { + let num_vars = 10; + let len = 1usize << num_vars; + let evals: Vec = (0..len).map(|i| TestF::from_u64(i as u64)).collect(); + let poly = DensePoly::::from_field_evals(num_vars, &evals).unwrap(); + assert_eq!(poly.num_ring_elems(), len / TestD); + } + + #[test] + fn dense_commit_inner_matches_ring_commit() { + let (setup, _) = + >::setup(16) + .unwrap(); + let layout = setup.layout(); + let num_ring = layout.num_blocks * layout.block_len; + let evals: Vec = (0..num_ring * TestD) + .map(|i| TestF::from_u64(i as u64)) + .collect(); + + let alpha = TestD.trailing_zeros() as usize; + let num_vars = alpha + layout.m_vars + layout.r_vars; + let poly = DensePoly::::from_field_evals(num_vars, &evals).unwrap(); + + let t_hat_poly = poly + .commit_inner( + &setup.expanded.A, + &setup.ntt_A, + layout.block_len, + layout.num_digits_commit, + layout.num_digits_open, + layout.log_basis, + ) + .unwrap(); + + let w = + >::commit_coeffs( + &poly.coeffs, + &setup, + ) + .unwrap(); + + assert_eq!(t_hat_poly, w.t_hat); + } + + #[test] + fn onehot_commit_inner_matches_ring_commit_onehot() { + let (setup, _) = + >::setup(16) + .unwrap(); + let layout = setup.layout(); + let total_ring = layout.num_blocks * layout.block_len; + let onehot_k = TestD; + let num_chunks = total_ring; + let indices: Vec> = (0..num_chunks).map(|i| Some(i % onehot_k)).collect(); + + let poly = OneHotPoly::::new( + onehot_k, + 
indices.clone(), + layout.r_vars, + layout.m_vars, + ) + .unwrap(); + + let t_hat_poly = poly + .commit_inner( + &setup.expanded.A, + &setup.ntt_A, + layout.block_len, + layout.num_digits_commit, + layout.num_digits_open, + layout.log_basis, + ) + .unwrap(); + + let w = + >::commit_onehot( + onehot_k, &indices, &setup, + ) + .unwrap(); + + assert_eq!(t_hat_poly, w.t_hat); + } + + #[test] + fn balanced_digit_poly_matches_dense_recursive_w_ops() { + let log_basis = TinyConfig::decomposition().log_basis; + let digits: Vec = (0..(3 * TestD)).map(|i| (i % 7) as i8 - 3).collect(); + let field_evals: Vec = digits.iter().map(|&d| TestF::from_i64(d as i64)).collect(); + let total_coeffs = digits.len().next_power_of_two().max(TestD); + let mut padded = field_evals.clone(); + padded.resize(total_coeffs, TestF::zero()); + + let dense = DensePoly::::from_field_evals( + total_coeffs.trailing_zeros() as usize, + &padded, + ) + .unwrap(); + let digit_poly = BalancedDigitPoly::::from_i8_digits(&digits).unwrap(); + + assert_eq!(digit_poly.num_ring_elems(), dense.num_ring_elems()); + + let eval_scalars: Vec = (0..digit_poly.num_ring_elems()) + .map(|i| TestF::from_u64((i + 2) as u64)) + .collect(); + assert_eq!( + digit_poly.evaluate_ring(&eval_scalars), + dense.evaluate_ring(&eval_scalars) + ); + + let block_len = 2; + let fold_scalars: Vec = (0..block_len) + .map(|i| TestF::from_u64((i + 5) as u64)) + .collect(); + assert_eq!( + digit_poly.fold_blocks(&fold_scalars, block_len), + dense.fold_blocks(&fold_scalars, block_len) + ); + + let num_blocks = digit_poly.num_ring_elems().div_ceil(block_len); + let challenges: Vec = (0..num_blocks) + .map(|i| SparseChallenge { + positions: vec![0u32, ((i + 3) % TestD) as u32], + coeffs: vec![1, -1], + }) + .collect(); + assert_eq!( + digit_poly.decompose_fold(&challenges, block_len, 1, log_basis), + dense.decompose_fold(&challenges, block_len, 1, log_basis) + ); + + let (setup, _) = + >::setup(16) + .unwrap(); + let w_layout = 
w_commitment_layout::(setup.layout()).unwrap(); + let digit_commit = digit_poly + .commit_inner( + &setup.expanded.A, + &setup.ntt_A, + w_layout.block_len, + w_layout.num_digits_commit, + w_layout.num_digits_open, + w_layout.log_basis, + ) + .unwrap(); + let dense_commit = dense + .commit_inner( + &setup.expanded.A, + &setup.ntt_A, + w_layout.block_len, + w_layout.num_digits_commit, + w_layout.num_digits_open, + w_layout.log_basis, + ) + .unwrap(); + + assert_eq!(digit_commit, dense_commit); + } +} diff --git a/src/protocol/labrador/challenge.rs b/src/protocol/labrador/challenge.rs new file mode 100644 index 00000000..efdedecf --- /dev/null +++ b/src/protocol/labrador/challenge.rs @@ -0,0 +1,231 @@ +//! Labrador challenge sampler (C-parity oriented). +//! +//! This ports the `polyvec_challenge` rejection sampler from the C reference. + +use crate::algebra::ring::CyclotomicRing; +use crate::error::HachiError; +use crate::protocol::labrador::guardrails::{ + checked_add, checked_mul, ensure_power_of_two, ensure_temp_allocation_limit, + LABRADOR_MAX_CHALLENGE_POLYS, +}; +use crate::{CanonicalField, FieldCore, FromSmallInt}; +use sha3::digest::{ExtendableOutput, Update, XofReader}; +use sha3::Shake128; + +/// Number of `±1` coefficients in a challenge polynomial. +pub const LABRADOR_TAU1: usize = 32; +/// Number of `±2` coefficients in a challenge polynomial. +pub const LABRADOR_TAU2: usize = 8; +/// Operator norm bound used by C's challenge rejection sampler. +pub const LABRADOR_CHALLENGE_OPNORM_BOUND: f64 = 14.0; + +const SHAKE128_RATE: usize = 168; + +/// Sample Labrador challenge polynomials as signed coefficient arrays. +/// +/// The output follows C `polyvec_challenge`: each polynomial has exactly +/// `LABRADOR_TAU1` coefficients in `{±1}`, `LABRADOR_TAU2` coefficients in +/// `{±2}`, all other coefficients 0, and must satisfy operator-norm bound. +/// +/// # Errors +/// +/// Returns an error if ring parameters are incompatible with the C algorithm. 
+pub fn sample_labrador_challenge_coeffs( + len: usize, + seed: &[u8; 16], + nonce: u64, +) -> Result, HachiError> { + validate_challenge_params::()?; + if len > LABRADOR_MAX_CHALLENGE_POLYS { + return Err(HachiError::InvalidInput(format!( + "requested too many challenge polynomials: {len} (max {LABRADOR_MAX_CHALLENGE_POLYS})" + ))); + } + + let mut xof = Shake128::default(); + xof.update(seed); + xof.update(&nonce.to_le_bytes()); + let mut reader = xof.finalize_xof(); + + let mut out = Vec::with_capacity(len); + let mut remaining = len; + + while remaining >= 10 { + let bytes = checked_mul(17, SHAKE128_RATE, "challenge block bytes")?; + ensure_temp_allocation_limit(bytes, "challenge sampler")?; + let mut buf = vec![0u8; bytes]; + reader.read(&mut buf); + let produced = consume_challenge_buffer::(&mut out, 10, &buf); + remaining -= produced; + } + + while remaining > 0 { + let scaled = checked_mul(remaining, 17, "scaled tail blocks numerator")?; + let scaled = checked_add(scaled, 9, "tail blocks numerator rounding")?; + let blocks = scaled / 10; + let bytes = checked_mul(blocks, SHAKE128_RATE, "tail block bytes")?; + ensure_temp_allocation_limit(bytes, "challenge sampler tail")?; + let mut buf = vec![0u8; bytes]; + reader.read(&mut buf); + let produced = consume_challenge_buffer::(&mut out, remaining, &buf); + remaining -= produced; + } + + Ok(out) +} + +/// Sample Labrador challenge polynomials as dense ring elements. +/// +/// # Errors +/// +/// Returns an error if parameter checks fail. 
+pub fn sample_labrador_challenges( + len: usize, + seed: &[u8; 16], + nonce: u64, +) -> Result>, HachiError> +where + F: FieldCore + CanonicalField + FromSmallInt, +{ + let coeffs = sample_labrador_challenge_coeffs::(len, seed, nonce)?; + Ok(coeffs + .into_iter() + .map(|poly| { + CyclotomicRing::from_coefficients(std::array::from_fn(|i| F::from_i64(poly[i] as i64))) + }) + .collect()) +} + +fn validate_challenge_params() -> Result<(), HachiError> { + ensure_power_of_two(D, "challenge sampler degree D")?; + if D > 256 { + return Err(HachiError::InvalidInput(format!( + "challenge sampler expects D <= 256, got {D}" + ))); + } + if LABRADOR_TAU1 + LABRADOR_TAU2 > D { + return Err(HachiError::InvalidInput(format!( + "tau1 + tau2 exceeds ring degree: {LABRADOR_TAU1} + {LABRADOR_TAU2} > {D}" + ))); + } + Ok(()) +} + +fn consume_challenge_buffer( + out: &mut Vec<[i16; D]>, + target_len: usize, + buf: &[u8], +) -> usize { + let sign_bytes = (LABRADOR_TAU1 + LABRADOR_TAU2).div_ceil(8); + let min_bytes = LABRADOR_TAU1 + LABRADOR_TAU2 + sign_bytes; + let mut produced = 0usize; + let mut cursor = 0usize; + + while produced < target_len && cursor <= buf.len().saturating_sub(min_bytes) { + let mut signs = 0u64; + for k in 0..sign_bytes { + signs |= (buf[cursor] as u64) << (8 * k); + cursor += 1; + } + + let mut poly = [0i16; D]; + let mut k = D - LABRADOR_TAU1 - LABRADOR_TAU2; + while k < D && cursor < buf.len() { + let b = (buf[cursor] as usize) & (D - 1); + cursor += 1; + if b <= k { + poly[k] = poly[b]; + let mut value = if k < D - LABRADOR_TAU2 { 1 } else { 2 }; + if (signs & 1) == 1 { + value = -value; + } + poly[b] = value; + signs >>= 1; + k += 1; + } + } + + if k == D && challenge_operator_norm::(&poly) <= LABRADOR_CHALLENGE_OPNORM_BOUND { + out.push(poly); + produced += 1; + } + } + + produced +} + +fn challenge_operator_norm(coeffs: &[i16; D]) -> f64 { + let mut max_norm = 0.0f64; + let d_f = D as f64; + for i in 0..D { + let theta = ((2 * i + 1) as f64) * 
std::f64::consts::PI / d_f; + let mut re = 0.0f64; + let mut im = 0.0f64; + for (j, &coeff) in coeffs.iter().enumerate() { + let angle = theta * (j as f64); + let c = coeff as f64; + re += c * angle.cos(); + im += c * angle.sin(); + } + let norm = (re * re + im * im).sqrt(); + if norm > max_norm { + max_norm = norm; + } + } + max_norm +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::algebra::fields::Fp32; + + type F = Fp32<4294967197>; + const D: usize = 64; + + // Fixed test seeds and nonces for deterministic replay. + const TEST_SEED_A: [u8; 16] = [7u8; 16]; + const TEST_SEED_B: [u8; 16] = [11u8; 16]; + const TEST_SEED_C: [u8; 16] = [5u8; 16]; + const TEST_NONCE_A: u64 = 9; + const TEST_NONCE_B: u64 = 17; + const TEST_NONCE_C: u64 = 4; + const TEST_NONCE_REF: u64 = 7; + + #[test] + fn challenge_sampler_is_deterministic() { + let c1 = sample_labrador_challenge_coeffs::(3, &TEST_SEED_A, TEST_NONCE_A).unwrap(); + let c2 = sample_labrador_challenge_coeffs::(3, &TEST_SEED_A, TEST_NONCE_A).unwrap(); + assert_eq!(c1, c2); + } + + #[test] + fn challenge_sampler_obeys_operator_norm_bound() { + let samples = sample_labrador_challenge_coeffs::(8, &TEST_SEED_B, TEST_NONCE_B).unwrap(); + assert_eq!(samples.len(), 8); + for poly in &samples { + assert!(challenge_operator_norm(poly) <= LABRADOR_CHALLENGE_OPNORM_BOUND); + } + } + + #[test] + fn challenge_sampler_supports_dense_ring_conversion() { + let dense = sample_labrador_challenges::(2, &TEST_SEED_C, TEST_NONCE_C).unwrap(); + assert_eq!(dense.len(), 2); + } + + #[test] + fn challenge_sampler_matches_transliterated_reference_vector() { + // Captured from the C-reference algorithm semantics (`polyvec_challenge`) + // for seed = [0,1,2,...,15], nonce = 7, len = 1. 
+ let seed: [u8; 16] = std::array::from_fn(|i| i as u8); + let coeffs = sample_labrador_challenge_coeffs::(1, &seed, TEST_NONCE_REF).unwrap(); + let got = coeffs[0]; + let expected: [i16; D] = [ + 1, 1, 0, 1, 0, 0, 2, -1, 0, 0, 2, 1, 1, -1, -1, 1, -2, 0, 1, 0, -1, -1, 1, 0, 1, -1, 1, + 1, 0, -1, 0, -1, 2, 1, 1, -1, -2, 0, 0, 1, 0, 0, 1, 1, -2, 1, 0, 0, 0, 0, 0, 0, 1, 0, + -1, -1, 2, -1, 0, 1, -2, 1, 0, 0, + ]; + assert_eq!(got, expected); + } +} diff --git a/src/protocol/labrador/comkey.rs b/src/protocol/labrador/comkey.rs new file mode 100644 index 00000000..e7249e68 --- /dev/null +++ b/src/protocol/labrador/comkey.rs @@ -0,0 +1,78 @@ +//! Prefix-stable extendable commitment-key derivation for Labrador. +//! +//! Unlike setup matrices that bind full `(rows, cols)` shape, this derivation +//! binds only `(matrix_label, row, col)` so extending dimensions preserves the +//! previously derived prefix exactly. + +use crate::algebra::ring::CyclotomicRing; +use crate::protocol::prg::{MatrixPrgBackendChoice, MatrixPrgContext}; +use crate::{FieldCore, FieldSampling}; + +/// Public seed used to derive extendable Labrador commitment keys. +pub type LabradorComKeySeed = [u8; 32]; + +/// Derive a prefix-stable matrix for Labrador commitment keys. +/// +/// Prefix-stable means: if `M_small = derive(rows, cols)` and +/// `M_large = derive(rows2, cols2)` with `rows2 >= rows`, `cols2 >= cols`, +/// then `M_large[r][c] == M_small[r][c]` for all `r < rows`, `c < cols`. +pub fn derive_extendable_comkey_matrix( + rows: usize, + cols: usize, + seed: &LabradorComKeySeed, + matrix_label: &[u8], + backend: MatrixPrgBackendChoice, +) -> Vec>> { + (0..rows) + .map(|r| { + (0..cols) + .map(|c| { + // Dedicated key path: keep shape fields constant, bind only + // entry indices and matrix label. 
+ let context = MatrixPrgContext { + seed, + matrix_label, + rows: 0, + cols: 0, + row: r, + col: c, + }; + let mut rng = backend.entry_rng(&context); + CyclotomicRing::random(&mut rng) + }) + .collect() + }) + .collect() +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::algebra::Fp64; + + type F = Fp64<4294967197>; + const D: usize = 64; + + #[test] + fn extendable_derivation_has_prefix_stability() { + let seed = [19u8; 32]; + let backend = MatrixPrgBackendChoice::Shake256; + let small = derive_extendable_comkey_matrix::(3, 4, &seed, b"comkey/A", backend); + let large = derive_extendable_comkey_matrix::(5, 7, &seed, b"comkey/A", backend); + + for r in 0..3 { + for c in 0..4 { + assert_eq!(small[r][c], large[r][c]); + } + } + } + + #[test] + fn extendable_derivation_domain_separates_labels() { + let seed = [7u8; 32]; + let backend = MatrixPrgBackendChoice::Aes128Ctr; + let a = derive_extendable_comkey_matrix::(2, 3, &seed, b"comkey/A", backend); + let b = derive_extendable_comkey_matrix::(2, 3, &seed, b"comkey/B", backend); + assert_ne!(a, b); + } +} diff --git a/src/protocol/labrador/commit.rs b/src/protocol/labrador/commit.rs new file mode 100644 index 00000000..e3b05469 --- /dev/null +++ b/src/protocol/labrador/commit.rs @@ -0,0 +1,183 @@ +//! Two-tier Ajtai commitment helpers for Labrador (linear-only mode). + +use crate::algebra::ring::CyclotomicRing; +use crate::error::HachiError; +use crate::protocol::commitment::utils::linear::decompose_rows_with_carry; +use crate::protocol::labrador::comkey::{derive_extendable_comkey_matrix, LabradorComKeySeed}; +use crate::protocol::labrador::types::{LabradorReductionConfig, LabradorWitness}; +use crate::protocol::labrador::utils::mat_vec_mul; +use crate::protocol::prg::MatrixPrgBackendChoice; +use crate::{CanonicalField, FieldCore, FieldSampling}; + +/// Commitment artifacts needed by downstream Labrador/Greyhound flows. 
+#[derive(Debug, Clone, PartialEq, Eq)] +pub struct LabradorCommitmentArtifacts { + /// Per-row inner commitments. + pub u_inner: Vec>>, + /// First outer commitment (`u1`). + pub u1: Vec>, + /// Second outer commitment (`u2`) from linear garbage terms. + pub u2: Vec>, + /// Decomposed witness rows. + pub decomposed_witness: Vec>>, + /// Decomposed inner commitments. + pub decomposed_inner: Vec>>, + /// Linear garbage terms `h_{ij}` (always present in linear-only mode). + pub linear_garbage: Vec>, +} + +/// Commit witness rows in linear-only Labrador mode. +/// +/// # Errors +/// +/// Returns an error if dimensions/config are invalid. +pub fn commit_linear_only( + witness: &LabradorWitness, + config: &LabradorReductionConfig, + comkey_seed: &LabradorComKeySeed, + backend: MatrixPrgBackendChoice, +) -> Result, HachiError> +where + F: FieldCore + CanonicalField + FieldSampling, +{ + if witness.rows().is_empty() { + return Err(HachiError::InvalidInput( + "cannot commit empty Labrador witness".to_string(), + )); + } + if config.fu == 0 || config.bu == 0 || config.kappa == 0 { + return Err(HachiError::InvalidInput( + "invalid Labrador commitment config".to_string(), + )); + } + + let mut decomposed_witness = Vec::with_capacity(witness.rows().len()); + let mut u_inner = Vec::with_capacity(witness.rows().len()); + let mut decomposed_inner = Vec::with_capacity(witness.rows().len()); + + for (row_idx, row) in witness.rows().iter().enumerate() { + let a = derive_extendable_comkey_matrix::( + config.kappa, + row.len(), + comkey_seed, + b"labrador/comkey/A", + backend, + ); + let t = mat_vec_mul(&a, row); + if t.is_empty() { + return Err(HachiError::InvalidInput(format!( + "inner commitment row {row_idx} produced empty vector" + ))); + } + let t_hat = decompose_rows_with_carry(&t, config.fu, config.bu as u32); + let s_hat = decompose_rows_with_carry(row, config.f, config.b as u32); + decomposed_witness.push(s_hat); + decomposed_inner.push(t_hat); + u_inner.push(t); + } + + let 
mut t_hat_flat = Vec::new(); + for t_hat in &decomposed_inner { + t_hat_flat.extend(t_hat.iter().copied()); + } + + let u1 = if config.tail || config.kappa1 == 0 { + u_inner.iter().flat_map(|v| v.iter().copied()).collect() + } else { + let b = derive_extendable_comkey_matrix::( + config.kappa1, + t_hat_flat.len(), + comkey_seed, + b"labrador/comkey/B", + backend, + ); + mat_vec_mul(&b, &t_hat_flat) + }; + + let linear_garbage = build_linear_garbage(witness); + let u2 = if config.tail || config.kappa1 == 0 { + linear_garbage.clone() + } else { + let b2 = derive_extendable_comkey_matrix::( + config.kappa1, + linear_garbage.len(), + comkey_seed, + b"labrador/comkey/U2", + backend, + ); + mat_vec_mul(&b2, &linear_garbage) + }; + + Ok(LabradorCommitmentArtifacts { + u_inner, + u1, + u2, + decomposed_witness, + decomposed_inner, + linear_garbage, + }) +} + +fn build_linear_garbage( + witness: &LabradorWitness, +) -> Vec> { + let mut out = Vec::with_capacity( + (witness.rows().len() * witness.rows().len() + witness.rows().len()) / 2, + ); + for i in 0..witness.rows().len() { + for j in i..witness.rows().len() { + let len = witness.rows()[i].len().min(witness.rows()[j].len()); + let mut acc = CyclotomicRing::::zero(); + for k in 0..len { + acc += witness.rows()[i][k] * witness.rows()[j][k]; + } + out.push(acc); + } + } + out +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::algebra::fields::Fp64; + use crate::protocol::labrador::types::LabradorReductionConfig; + use crate::FromSmallInt; + + type F = Fp64<4294967197>; + const D: usize = 64; + + fn sample_witness() -> LabradorWitness { + let row = |len: usize| -> Vec> { + (0..len) + .map(|i| { + CyclotomicRing::from_coefficients(std::array::from_fn(|j| { + F::from_i64(((i + j) as i64 % 9) - 4) + })) + }) + .collect() + }; + LabradorWitness::new(vec![row(4), row(4), row(4)]) + } + + #[test] + fn commit_linear_only_is_deterministic() { + let witness = sample_witness(); + let cfg = LabradorReductionConfig { + f: 1, + 
b: 8, + fu: 2, + bu: 10, + kappa: 3, + kappa1: 2, + tail: false, + }; + let seed = [3u8; 32]; + let a = + commit_linear_only(&witness, &cfg, &seed, MatrixPrgBackendChoice::Shake256).unwrap(); + let b = + commit_linear_only(&witness, &cfg, &seed, MatrixPrgBackendChoice::Shake256).unwrap(); + assert_eq!(a, b); + assert!(!a.u2.is_empty(), "linear garbage commitment u2 must exist"); + } +} diff --git a/src/protocol/labrador/config.rs b/src/protocol/labrador/config.rs new file mode 100644 index 00000000..1dacf4c2 --- /dev/null +++ b/src/protocol/labrador/config.rs @@ -0,0 +1,164 @@ +//! Labrador parameter-selection and security checks. + +use crate::error::HachiError; +use crate::protocol::labrador::types::{LabradorReductionConfig, LabradorWitness}; +use crate::{CanonicalField, FieldCore}; + +pub(crate) const LABRADOR_LOGQ_BITS: usize = 32; +pub(crate) const JL_LIFTS: usize = 128_usize.div_ceil(LABRADOR_LOGQ_BITS); + +const LABRADOR_LOGQ: f64 = 32.0; +const LABRADOR_N: f64 = 64.0; +const LABRADOR_LOGDELTA: f64 = 0.00639138757765197; // log2(1.00444) +const LABRADOR_T: f64 = 14.0; +const LABRADOR_SLACK: f64 = 2.0; +const LABRADOR_TAU1: f64 = 32.0; +const LABRADOR_TAU2: f64 = 8.0; + +/// Module-SIS security check used by the C reference. +/// +/// Returns `true` when `log2(norm) < min(LOGQ, 2*sqrt(LOGQ*LOGDELTA*N)*sqrt(rank))`. +pub fn sis_secure(rank: usize, norm: f64) -> bool { + if rank == 0 || !norm.is_finite() || norm <= 0.0 { + return false; + } + let mut maxlog = + 2.0 * (LABRADOR_LOGQ * LABRADOR_LOGDELTA * LABRADOR_N).sqrt() * (rank as f64).sqrt(); + maxlog = maxlog.min(LABRADOR_LOGQ); + norm.log2() < maxlog +} + +/// Select a linear-only Labrador reduction config. +/// +/// This is a simplified, Rust-native port of the C `init_proof` parameter +/// selection path with `quadratic=0` and non-tail mode. +/// +/// # Errors +/// +/// Returns an error if witness metadata is empty/invalid or if no secure +/// commitment ranks are found within supported bounds. 
+pub fn select_config( + witness: &LabradorWitness, +) -> Result { + if witness.rows().is_empty() { + return Err(HachiError::InvalidInput( + "cannot select config for empty Labrador witness".to_string(), + )); + } + + let row_count = witness.rows().len() as f64; + let total_len: usize = witness.rows().iter().map(|r| r.len()).sum(); + if total_len == 0 { + return Err(HachiError::InvalidInput( + "cannot select config for zero-length Labrador witness".to_string(), + )); + } + let nn = (total_len as f64) / row_count; + let norm_sum: f64 = witness.norm() as f64; + let mut varz = norm_sum / ((total_len as f64) * (D as f64)); + varz *= LABRADOR_TAU1 + 4.0 * LABRADOR_TAU2; + varz = varz.max(1.0); + + let decompose = !sis_secure( + 13, + 6.0 * LABRADOR_T + * LABRADOR_SLACK + * (2.0 * (LABRADOR_TAU1 + 4.0 * LABRADOR_TAU2) * varz * nn * (D as f64)).sqrt(), + ) || 64.0 * varz > (1u64 << 28) as f64; + + let f = if decompose { 2usize } else { 1usize }; + let mut b = if decompose { + ((12.0f64.log2() + varz.log2()) / 4.0).round() as isize + } else { + ((12.0f64.log2() + varz.log2()) / 2.0).round() as isize + }; + b = b.clamp(1, LABRADOR_LOGQ as isize); + + let fu = (((LABRADOR_LOGQ as usize) + 2 * (b as usize) / 3) / (b as usize)).max(1); + let bu = (((LABRADOR_LOGQ as usize) + fu / 2) / fu).max(1); + + let rr = witness.rows().len() as f64; + let mut selected: Option<(usize, usize)> = None; + + for kappa in 1..=32usize { + let mut normsq = ((2f64.powi(2 * b as i32) / 12.0) * ((f - 1) as f64) + + varz / 2f64.powi((2 * (f - 1) as isize * b) as i32)) + * nn; + normsq += ((2f64.powi(2 * bu as i32) * ((fu - 1) as f64) + + 2f64.powi((2 * ((LABRADOR_LOGQ as usize) - (fu - 1) * bu)) as i32)) + / 12.0) + * (rr * (kappa as f64) + (rr * rr + rr) / 2.0); + normsq *= D as f64; + + let inner_ok = sis_secure( + kappa, + 6.0 * LABRADOR_T + * LABRADOR_SLACK + * 2f64.powi(((f - 1) * (b as usize)) as i32) + * normsq.sqrt(), + ); + if !inner_ok { + continue; + } + + let kappa1 = 
(1..=32usize).find(|&k1| sis_secure(k1, 2.0 * LABRADOR_SLACK * normsq.sqrt())); + if let Some(k1) = kappa1 { + selected = Some((kappa, k1)); + break; + } + } + + let (kappa, kappa1) = selected.ok_or_else(|| { + HachiError::InvalidInput("failed to find secure Labrador commitment ranks".to_string()) + })?; + + Ok(LabradorReductionConfig { + f, + b: b as usize, + fu, + bu, + kappa, + kappa1, + tail: false, + }) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::algebra::fields::Fp64; + use crate::algebra::ring::CyclotomicRing; + use crate::FromSmallInt; + + type F = Fp64<4294967197>; + const D: usize = 64; + + fn row(len: usize) -> Vec> { + (0..len) + .map(|i| { + CyclotomicRing::from_coefficients(std::array::from_fn(|j| { + F::from_i64(((i + j) as i64 % 5) - 2) + })) + }) + .collect() + } + + #[test] + fn sis_secure_rejects_non_positive_norm() { + assert!(!sis_secure(4, 0.0)); + assert!(!sis_secure(4, -1.0)); + } + + #[test] + fn select_config_returns_valid_ranges() { + let witness = LabradorWitness::new(vec![row(32), row(32), row(32)]); + let cfg = select_config::(&witness).unwrap(); + assert!(cfg.f >= 1 && cfg.f <= 2); + assert!(cfg.b > 0); + assert!(cfg.fu > 0); + assert!(cfg.bu > 0); + assert!((1..=32).contains(&cfg.kappa)); + assert!((1..=32).contains(&cfg.kappa1)); + assert!(!cfg.tail); + } +} diff --git a/src/protocol/labrador/fold.rs b/src/protocol/labrador/fold.rs new file mode 100644 index 00000000..a446c4e5 --- /dev/null +++ b/src/protocol/labrador/fold.rs @@ -0,0 +1,1087 @@ +//! Labrador amortization transitions (standard and tail levels). 
+ +use crate::algebra::ring::CyclotomicRing; +use crate::error::HachiError; +use crate::protocol::commitment::utils::linear::decompose_rows_with_carry; +use crate::protocol::labrador::comkey::{derive_extendable_comkey_matrix, LabradorComKeySeed}; +use crate::protocol::labrador::johnson_lindenstrauss::{ + collapse, project, zero_constant_term_for_proof, LabradorJlMatrix, +}; +use crate::protocol::labrador::transcript::{ + absorb_labrador_jl_nonce, absorb_labrador_jl_projection, absorb_labrador_level_context, + LabradorLevelTranscriptContext, +}; +use crate::protocol::labrador::types::{ + LabradorConstraint, LabradorLevelProof, LabradorReductionConfig, LabradorStatement, + LabradorWitness, +}; +use crate::protocol::labrador::utils::mat_vec_mul; +use crate::protocol::prg::MatrixPrgBackendChoice; +use crate::protocol::transcript::labels; +use crate::protocol::transcript::{challenge_ring_element_rejection_sampled, Transcript}; +use crate::{CanonicalField, FieldCore, FieldSampling, FromSmallInt}; + +/// Output of one Labrador fold transition. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct LabradorFoldResult { + /// Next witness after amortization. + pub next_witness: LabradorWitness, + /// Replay-complete level proof record. + pub level_proof: LabradorLevelProof, + /// Reduced statement consumed by the next verifier step. + pub statement: LabradorStatement, +} + +use crate::protocol::labrador::config::JL_LIFTS; + +/// Perform one Labrador fold level (standard or tail, determined by `config.tail`). +/// +/// Follows the C Labrador protocol phases: +/// 1. Commit: inner + outer Ajtai commitment → u1 +/// 2. Project: JL projection → p\[256\], nonce +/// 3. LIFTS × (collapse + lift): build linear constraints from JL +/// 4. Amortize: absorb into transcript, sample ring-element challenges, +/// fold z = sum_i c_i * s_i, decompose z → output witness +/// +/// # Errors +/// +/// Returns `HachiError::InvalidInput` if the witness is empty or `config.f` is zero. 
+/// Propagates errors from commitment, projection, or hashing. +#[allow(clippy::too_many_arguments)] +pub fn prove_level( + witness: &LabradorWitness, + statement: &LabradorStatement, + config: &LabradorReductionConfig, + comkey_seed: &LabradorComKeySeed, + jl_seed: &[u8; 16], + backend: MatrixPrgBackendChoice, + level_index: usize, + transcript: &mut T, +) -> Result, HachiError> +where + F: FieldCore + CanonicalField + FieldSampling + FromSmallInt, + T: Transcript, +{ + if witness.rows().is_empty() { + return Err(HachiError::InvalidInput( + "cannot fold empty Labrador witness".to_string(), + )); + } + if config.f == 0 { + return Err(HachiError::InvalidInput( + "Labrador fold requires f > 0".to_string(), + )); + } + let r = witness.rows().len(); + let row_lengths: Vec = witness.rows().iter().map(|row| row.len()).collect(); + let max_len = row_lengths.iter().copied().max().unwrap_or(0); + + // Phase 1: Inner commitments (t_i) and outer commitment u1. + let a = derive_extendable_comkey_matrix::( + config.kappa, + max_len, + comkey_seed, + b"labrador/comkey/A", + backend, + ); + let mut t_hat = Vec::new(); + for row in witness.rows() { + let mut padded = Vec::with_capacity(max_len); + padded.extend_from_slice(row); + padded.resize(max_len, CyclotomicRing::::zero()); + let t = mat_vec_mul(&a, &padded); + t_hat.extend(decompose_rows_with_carry(&t, config.fu, config.bu as u32)); + } + + let u1 = if config.kappa1 > 0 && !config.tail { + let b = derive_extendable_comkey_matrix::( + config.kappa1, + t_hat.len(), + comkey_seed, + b"labrador/comkey/B", + backend, + ); + mat_vec_mul(&b, &t_hat) + } else { + t_hat.clone() + }; + + // Phase 2: JL Projection + let (jl_projection, jl_nonce) = project(witness, jl_seed, backend)?; + + // Transcript: absorb level context, commitments, JL. 
+ absorb_labrador_level_context( + transcript, + &LabradorLevelTranscriptContext { + level_index, + tail: config.tail, + input_row_lengths: row_lengths.clone(), + input_row_chunks: vec![1usize; r], + f: config.f, + b: config.b, + fu: config.fu, + bu: config.bu, + kappa: config.kappa, + kappa1: config.kappa1, + prg_backend_id: backend as u8, + }, + )?; + transcript.append_serde(labels::ABSORB_LABRADOR_U1, &u1); + absorb_labrador_jl_projection(transcript, &jl_projection); + absorb_labrador_jl_nonce(transcript, jl_nonce); + + // Phase 3: JL lift constraints and aggregation. + let (phi_jl, b_jl, bb) = aggregate_jl_constraints_prover( + witness, + &jl_projection, + jl_seed, + jl_nonce, + backend, + transcript, + )?; + + // Aggregate statement constraints (after JL lifts). + let (phi_stmt, b_stmt) = + aggregate_statement_constraints(&statement.constraints, &row_lengths, transcript)?; + + let mut phi_total = phi_stmt; + add_phi_in_place(&mut phi_total, &phi_jl)?; + let b_total = b_stmt + b_jl; + + // Linear garbage h_ij from aggregated phi and witness. + let h = compute_linear_garbage(&phi_total, witness)?; + let h_hat = decompose_rows_with_carry(&h, config.fu, config.bu as u32); + + let u2 = if config.kappa1 > 0 && !config.tail { + let b2 = derive_extendable_comkey_matrix::( + config.kappa1, + h_hat.len(), + comkey_seed, + b"labrador/comkey/U2", + backend, + ); + mat_vec_mul(&b2, &h_hat) + } else { + h_hat.clone() + }; + + // Absorb u2 before amortization challenges. + transcript.append_serde(labels::ABSORB_LABRADOR_U2, &u2); + + // Phase 4: Amortize — sample r challenge ring-elements from transcript, fold. 
+ let mut challenges = Vec::with_capacity(r); + for _ in 0..r { + challenges.push(challenge_ring_element_rejection_sampled( + transcript, + labels::CHALLENGE_LABRADOR_AMORTIZE, + )?); + } + + let z = amortize_witness(witness, &challenges, max_len); + let decomposed_z = decompose_rows_with_carry(&z, config.f, config.b as u32); + let z_rows = split_decomposed_rows(&decomposed_z, config.f, z.len())?; + + let mut output_rows: Vec>> = z_rows; + + if !config.tail { + let mut aux = Vec::with_capacity(t_hat.len() + h_hat.len()); + aux.extend_from_slice(&t_hat); + aux.extend_from_slice(&h_hat); + output_rows.push(aux); + } + + let next_witness = LabradorWitness::new_unchecked(output_rows); + let out_norm_sq: u128 = next_witness.norm(); + + let next_constraints = if config.tail { + Vec::new() + } else { + build_next_constraints( + &phi_total, + &b_total, + &challenges, + &row_lengths, + max_len, + config, + &u1, + &u2, + comkey_seed, + backend, + )? + }; + + let level_proof = LabradorLevelProof { + tail: config.tail, + input_row_lengths: row_lengths, + input_row_chunks: vec![1usize; r], + config: *config, + u1: u1.clone(), + u2: u2.clone(), + jl_projection, + jl_nonce, + bb, + norm_sq: out_norm_sq, + }; + + // NOTE: Recursive statement update is not implemented yet. 
+ let statement = LabradorStatement { + u1, + u2, + challenges: challenges.clone(), + constraints: next_constraints, + beta_sq: out_norm_sq, + hash: [0u8; 16], + }; + + Ok(LabradorFoldResult { + next_witness, + level_proof, + statement, + }) +} + +fn split_decomposed_rows( + flat: &[CyclotomicRing], + parts: usize, + len: usize, +) -> Result>>, HachiError> { + if parts == 0 { + return Err(HachiError::InvalidInput( + "cannot split decomposition with zero parts".to_string(), + )); + } + if flat.len() != len * parts { + return Err(HachiError::InvalidInput(format!( + "decomposition length mismatch: got {}, expected {}", + flat.len(), + len * parts + ))); + } + let mut rows = vec![Vec::with_capacity(len); parts]; + for idx in 0..len { + for part in 0..parts { + rows[part].push(flat[idx * parts + part]); + } + } + Ok(rows) +} + +fn add_phi_in_place( + acc: &mut [Vec>], + other: &[Vec>], +) -> Result<(), HachiError> { + if acc.len() != other.len() { + return Err(HachiError::InvalidInput( + "phi row count mismatch".to_string(), + )); + } + for (row_acc, row_other) in acc.iter_mut().zip(other.iter()) { + if row_acc.len() != row_other.len() { + return Err(HachiError::InvalidInput( + "phi row length mismatch".to_string(), + )); + } + for (a, b) in row_acc.iter_mut().zip(row_other.iter()) { + *a += *b; + } + } + Ok(()) +} + +fn dot_product( + lhs: &[CyclotomicRing], + rhs: &[CyclotomicRing], +) -> CyclotomicRing { + let mut acc = CyclotomicRing::::zero(); + let len = lhs.len().min(rhs.len()); + for i in 0..len { + acc += lhs[i] * rhs[i]; + } + acc +} + +#[allow(clippy::type_complexity)] +fn aggregate_statement_constraints( + constraints: &[LabradorConstraint], + row_lengths: &[usize], + transcript: &mut T, +) -> Result<(Vec>>, CyclotomicRing), HachiError> +where + F: FieldCore + CanonicalField + FromSmallInt, + T: Transcript, +{ + let mut phi_total: Vec>> = row_lengths + .iter() + .map(|&len| vec![CyclotomicRing::zero(); len]) + .collect(); + let mut b_total = 
CyclotomicRing::::zero(); + + if constraints.is_empty() { + return Ok((phi_total, b_total)); + } + + for cnst in constraints { + let outputs = cnst.target.len().max(1); + for out_idx in 0..outputs { + let alpha = challenge_ring_element_rejection_sampled( + transcript, + labels::CHALLENGE_LABRADOR_AGGREGATION, + )?; + let target = cnst + .target + .get(out_idx) + .copied() + .unwrap_or_else(CyclotomicRing::::zero); + b_total += alpha * target; + + for (row_idx, coeffs) in cnst.coefficients.iter().enumerate() { + if coeffs.is_empty() { + continue; + } + if row_idx >= phi_total.len() { + return Err(HachiError::InvalidInput( + "constraint row index out of bounds".to_string(), + )); + } + let row_len = coeffs.len() / outputs; + let coeff_start = out_idx * row_len; + let coeff_slice = &coeffs[coeff_start..coeff_start + row_len]; + + for (j, coeff) in coeff_slice.iter().enumerate() { + phi_total[row_idx][j] += alpha * *coeff; + } + } + } + } + + Ok((phi_total, b_total)) +} + +fn flatten_witness( + witness: &LabradorWitness, +) -> (Vec>, Vec<(usize, usize)>) { + let mut flat = Vec::new(); + let mut ranges = Vec::with_capacity(witness.rows().len()); + let mut cursor = 0usize; + for row in witness.rows() { + let start = cursor; + flat.extend(row.iter().copied()); + cursor += row.len(); + ranges.push((start, cursor)); + } + (flat, ranges) +} + +fn sample_jl_collapse_challenge(transcript: &mut T) -> [i64; 256] +where + F: FieldCore + CanonicalField, + T: Transcript, +{ + let q = (-F::one()).to_canonical_u128() + 1; + let half_q = q / 2; + std::array::from_fn(|_| { + let s = transcript.challenge_scalar(labels::CHALLENGE_LABRADOR_JL_COLLAPSE); + let c = s.to_canonical_u128(); + if c > half_q { + -((q - c) as i64) + } else { + c as i64 + } + }) +} + +fn jl_collapse_phi_from_weights( + matrix: &LabradorJlMatrix, + omega: &[i64; 256], +) -> Result>, HachiError> { + if matrix.cols % D != 0 { + return Err(HachiError::InvalidInput( + "JL matrix cols not divisible by ring 
degree".to_string(), + )); + } + let mut weights = vec![0i64; matrix.cols]; + for (row_idx, row) in matrix.signs.iter().enumerate() { + let alpha = omega[row_idx]; + for (col_idx, &sign) in row.iter().enumerate() { + weights[col_idx] += alpha * (sign as i64); + } + } + + let ring_elems = matrix.cols / D; + let mut phi = Vec::with_capacity(ring_elems); + for idx in 0..ring_elems { + let coeffs = std::array::from_fn(|k| { + let w = weights[idx * D + k]; + F::from_i64(w) + }); + phi.push(CyclotomicRing::from_coefficients(coeffs).sigma_m1()); + } + Ok(phi) +} + +#[allow(clippy::type_complexity)] +fn aggregate_jl_constraints_prover( + witness: &LabradorWitness, + jl_projection: &[i32; 256], + jl_seed: &[u8; 16], + jl_nonce: u64, + backend: MatrixPrgBackendChoice, + transcript: &mut T, +) -> Result< + ( + Vec>>, + CyclotomicRing, + Vec>, + ), + HachiError, +> +where + F: FieldCore + CanonicalField + FromSmallInt, + T: Transcript, +{ + let (flat, ranges) = flatten_witness(witness); + let cols = flat + .len() + .checked_mul(D) + .ok_or_else(|| HachiError::InvalidInput("JL column count overflow".into()))?; + if cols == 0 { + return Err(HachiError::InvalidInput( + "JL collapse requires non-empty witness".to_string(), + )); + } + + let matrix = LabradorJlMatrix::generate(jl_seed, jl_nonce, cols, backend)?; + let mut phi_total: Vec>> = witness + .rows() + .iter() + .map(|row| vec![CyclotomicRing::zero(); row.len()]) + .collect(); + let mut b_total = CyclotomicRing::::zero(); + let mut bb = Vec::with_capacity(JL_LIFTS); + + for _ in 0..JL_LIFTS { + let omega = sample_jl_collapse_challenge::(transcript); + let phi_flat = jl_collapse_phi_from_weights::(&matrix, &omega)?; + let b_full = dot_product(&phi_flat, &flat); + let target = collapse(jl_projection, &omega); + let expected_c0 = F::from_i64(target); + if b_full.coefficients()[0] != expected_c0 { + return Err(HachiError::InvalidProof); + } + let (b_tx, _c0) = zero_constant_term_for_proof(b_full); + bb.push(b_tx); + 
transcript.append_serde(labels::ABSORB_LABRADOR_BB, &b_tx); + + let beta = challenge_ring_element_rejection_sampled( + transcript, + labels::CHALLENGE_LABRADOR_AGGREGATION, + )?; + b_total += beta * b_full; + + for (row_idx, (start, end)) in ranges.iter().enumerate() { + let row = &phi_flat[*start..*end]; + for (j, elem) in row.iter().enumerate() { + phi_total[row_idx][j] += beta * *elem; + } + } + } + + Ok((phi_total, b_total, bb)) +} + +fn compute_linear_garbage( + phi: &[Vec>], + witness: &LabradorWitness, +) -> Result>, HachiError> { + if phi.len() != witness.rows().len() { + return Err(HachiError::InvalidInput( + "phi row count mismatch".to_string(), + )); + } + let mut out = Vec::with_capacity((witness.rows().len() * (witness.rows().len() + 1)) / 2); + for i in 0..witness.rows().len() { + if phi[i].len() != witness.rows()[i].len() { + return Err(HachiError::InvalidInput( + "phi row length mismatch".to_string(), + )); + } + for j in i..witness.rows().len() { + if phi[j].len() != witness.rows()[j].len() { + return Err(HachiError::InvalidInput( + "phi row length mismatch".to_string(), + )); + } + let entry = if i == j { + dot_product(&phi[i], &witness.rows()[i]) + } else { + let lhs = dot_product(&phi[i], &witness.rows()[j]); + let rhs = dot_product(&phi[j], &witness.rows()[i]); + lhs + rhs + }; + out.push(entry); + } + } + Ok(out) +} + +#[allow(clippy::too_many_arguments)] +fn build_next_constraints< + F: FieldCore + CanonicalField + FieldSampling + FromSmallInt, + const D: usize, +>( + phi_total: &[Vec>], + b_total: &CyclotomicRing, + challenges: &[CyclotomicRing], + row_lengths: &[usize], + max_len: usize, + config: &LabradorReductionConfig, + u1: &[CyclotomicRing], + u2: &[CyclotomicRing], + comkey_seed: &LabradorComKeySeed, + backend: MatrixPrgBackendChoice, +) -> Result>, HachiError> { + let r = row_lengths.len(); + if r == 0 || challenges.len() != r { + return Err(HachiError::InvalidInput( + "challenge row count mismatch".to_string(), + )); + } + if 
config.f == 0 { + return Err(HachiError::InvalidInput( + "cannot build next constraints with f=0".to_string(), + )); + } + + let pow_b: Vec = (0..config.f) + .map(|idx| pow2_field::(config.b * idx)) + .collect(); + let pow_bu: Vec = (0..config.fu) + .map(|idx| pow2_field::(config.bu * idx)) + .collect(); + + let mut combined_phi = vec![CyclotomicRing::::zero(); max_len]; + for (row_idx, row_phi) in phi_total.iter().enumerate() { + let c = challenges[row_idx]; + for (j, elem) in row_phi.iter().enumerate() { + combined_phi[j] += c * *elem; + } + } + + let mut constraints = Vec::new(); + let t_hat_len = r * config.kappa * config.fu; + let h_len = r * (r + 1) / 2; + let h_hat_len = h_len * config.fu; + let aux_row = config.f; + let aux_row_len = t_hat_len + h_hat_len; + let num_rows = config.f + 1; + + if config.kappa1 > 0 { + if u1.len() != config.kappa1 || u2.len() != config.kappa1 { + return Err(HachiError::InvalidInput( + "u1/u2 length mismatch for next statement".to_string(), + )); + } + + // B · t_hat = u1 + let b = derive_extendable_comkey_matrix::( + config.kappa1, + t_hat_len, + comkey_seed, + b"labrador/comkey/B", + backend, + ); + let mut aux_coeffs = vec![CyclotomicRing::::zero(); config.kappa1 * aux_row_len]; + for (out_idx, b_row) in b.iter().enumerate() { + let start = out_idx * aux_row_len; + for (j, val) in b_row.iter().enumerate() { + aux_coeffs[start + j] = *val; + } + } + let mut coefficients = vec![vec![]; num_rows]; + coefficients[aux_row] = aux_coeffs; + constraints.push(LabradorConstraint { + coefficients, + target: u1.to_vec(), + }); + + // B2 · h_hat = u2 + let b2 = derive_extendable_comkey_matrix::( + config.kappa1, + h_hat_len, + comkey_seed, + b"labrador/comkey/U2", + backend, + ); + let mut aux_coeffs = vec![CyclotomicRing::::zero(); config.kappa1 * aux_row_len]; + for (out_idx, b2_row) in b2.iter().enumerate() { + let start = out_idx * aux_row_len + t_hat_len; + for (j, val) in b2_row.iter().enumerate() { + aux_coeffs[start + j] = *val; + 
} + } + let mut coefficients = vec![vec![]; num_rows]; + coefficients[aux_row] = aux_coeffs; + constraints.push(LabradorConstraint { + coefficients, + target: u2.to_vec(), + }); + } + + // A·z - c·t = 0 (inner commitment relation) + let a = derive_extendable_comkey_matrix::( + config.kappa, + max_len, + comkey_seed, + b"labrador/comkey/A", + backend, + ); + let mut az_coefficients = vec![vec![]; num_rows]; + for part_idx in 0..config.f { + let scale = pow_b[part_idx]; + let mut coeffs = Vec::with_capacity(config.kappa * max_len); + for a_row in &a { + for elem in a_row.iter() { + coeffs.push(elem.scale(&scale)); + } + } + az_coefficients[part_idx] = coeffs; + } + + let mut t_coeffs = vec![CyclotomicRing::::zero(); config.kappa * t_hat_len]; + for (row_idx, challenge) in challenges.iter().enumerate() { + for (part_idx, &scale) in pow_bu.iter().enumerate() { + let scaled = challenge.scale(&scale); + for k in 0..config.kappa { + let idx = row_idx * config.kappa * config.fu + k * config.fu + part_idx; + let slot = k * t_hat_len + idx; + t_coeffs[slot] = -scaled; + } + } + } + let mut aux_az = vec![CyclotomicRing::::zero(); config.kappa * aux_row_len]; + for k in 0..config.kappa { + let src_start = k * t_hat_len; + let dst_start = k * aux_row_len; + aux_az[dst_start..dst_start + t_hat_len] + .copy_from_slice(&t_coeffs[src_start..src_start + t_hat_len]); + } + az_coefficients[aux_row] = aux_az; + constraints.push(LabradorConstraint { + coefficients: az_coefficients, + target: vec![CyclotomicRing::::zero(); config.kappa], + }); + + // linear garbage constraint + let mut lg_coefficients = vec![vec![]; num_rows]; + for part_idx in 0..config.f { + let scale = pow_b[part_idx]; + let coeffs: Vec> = + combined_phi.iter().map(|elem| elem.scale(&scale)).collect(); + lg_coefficients[part_idx] = coeffs; + } + let mut h_coeffs = vec![CyclotomicRing::::zero(); h_hat_len]; + for i in 0..r { + for j in i..r { + let coeff = challenges[i] * challenges[j]; + let pair = pair_index(i, j, 
r); + for (part_idx, &scale) in pow_bu.iter().enumerate() { + let idx = pair * config.fu + part_idx; + h_coeffs[idx] = -(coeff.scale(&scale)); + } + } + } + let mut aux_lg = vec![CyclotomicRing::::zero(); aux_row_len]; + aux_lg[t_hat_len..t_hat_len + h_hat_len].copy_from_slice(&h_coeffs); + lg_coefficients[aux_row] = aux_lg; + constraints.push(LabradorConstraint { + coefficients: lg_coefficients, + target: vec![CyclotomicRing::::zero()], + }); + + // diagonal (norm) constraint + let mut diag_coeffs = vec![CyclotomicRing::::zero(); aux_row_len]; + for i in 0..r { + let pair = pair_index(i, i, r); + for (part_idx, &scale) in pow_bu.iter().enumerate() { + let idx = pair * config.fu + part_idx; + diag_coeffs[t_hat_len + idx] = constant_poly(scale); + } + } + let mut diag_coefficients = vec![vec![]; num_rows]; + diag_coefficients[aux_row] = diag_coeffs; + constraints.push(LabradorConstraint { + coefficients: diag_coefficients, + target: vec![*b_total], + }); + + Ok(constraints) +} + +fn pow2_field(exp: usize) -> F { + let two = F::from_u64(2); + let mut acc = F::one(); + for _ in 0..exp { + acc = acc * two; + } + acc +} + +fn constant_poly(value: F) -> CyclotomicRing { + CyclotomicRing::from_coefficients(std::array::from_fn( + |i| { + if i == 0 { + value + } else { + F::zero() + } + }, + )) +} + +fn pair_index(i: usize, j: usize, r: usize) -> usize { + debug_assert!(i <= j && j < r); + i * (2 * r - i + 1) / 2 + (j - i) +} + +/// Compute z = sum_i c_i * s_i (all-row linear combination). 
+fn amortize_witness( + witness: &LabradorWitness, + challenges: &[CyclotomicRing], + max_len: usize, +) -> Vec> { + let mut z = vec![CyclotomicRing::::zero(); max_len]; + for (row, challenge) in witness.rows().iter().zip(challenges.iter()) { + for (j, elem) in row.iter().enumerate() { + z[j] += *challenge * *elem; + } + } + z +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::algebra::fields::Fp64; + use crate::protocol::labrador::types::LabradorReductionConfig; + use crate::protocol::labrador::{verify, LabradorProof}; + use crate::protocol::transcript::labels::DOMAIN_LABRADOR_PROTOCOL; + use crate::protocol::transcript::Blake2bTranscript; + use crate::FromSmallInt; + + type F = Fp64<4294967197>; + const D: usize = 64; + + fn sample_witness() -> LabradorWitness { + let row = |len: usize| -> Vec> { + (0..len) + .map(|i| { + CyclotomicRing::from_coefficients(std::array::from_fn(|j| { + F::from_i64(((i + j) as i64 % 5) - 2) + })) + }) + .collect() + }; + LabradorWitness::new(vec![row(4), row(4), row(4)]) + } + + #[test] + fn standard_fold_produces_decomposed_output() { + let witness = sample_witness(); + let statement = LabradorStatement { + u1: Vec::new(), + u2: Vec::new(), + challenges: Vec::new(), + constraints: Vec::new(), + beta_sq: 1 << 20, + hash: [0u8; 16], + }; + let cfg = LabradorReductionConfig { + f: 1, + b: 8, + fu: 2, + bu: 10, + kappa: 3, + kappa1: 2, + tail: false, + }; + let seed = [1u8; 32]; + let jl_seed = [2u8; 16]; + let mut transcript = Blake2bTranscript::::new(DOMAIN_LABRADOR_PROTOCOL); + let out = prove_level( + &witness, + &statement, + &cfg, + &seed, + &jl_seed, + MatrixPrgBackendChoice::Shake256, + 0, + &mut transcript, + ) + .unwrap(); + assert!( + !out.next_witness.rows().is_empty(), + "fold must produce output witness" + ); + assert_eq!(out.next_witness.rows().len(), cfg.f + 1); + assert!(!out.level_proof.u2.is_empty()); + } + + #[test] + fn tail_fold_produces_decomposed_output() { + let witness = sample_witness(); + let 
statement = LabradorStatement { + u1: Vec::new(), + u2: Vec::new(), + challenges: Vec::new(), + constraints: Vec::new(), + beta_sq: 1 << 20, + hash: [0u8; 16], + }; + let cfg = LabradorReductionConfig { + f: 1, + b: 8, + fu: 1, + bu: 32, + kappa: 2, + kappa1: 0, + tail: true, + }; + let seed = [3u8; 32]; + let jl_seed = [4u8; 16]; + let mut transcript = Blake2bTranscript::::new(DOMAIN_LABRADOR_PROTOCOL); + let out = prove_level( + &witness, + &statement, + &cfg, + &seed, + &jl_seed, + MatrixPrgBackendChoice::Shake256, + 1, + &mut transcript, + ) + .unwrap(); + assert!( + !out.next_witness.rows().is_empty(), + "tail fold must produce output" + ); + assert_eq!(out.next_witness.rows().len(), cfg.f); + assert!(out.level_proof.tail); + } + + #[test] + fn amortize_is_linear_combination() { + let witness = sample_witness(); + let one = CyclotomicRing::::one(); + let challenges = vec![one; witness.rows().len()]; + let max_len = witness.rows().iter().map(|r| r.len()).max().unwrap(); + let z = amortize_witness(&witness, &challenges, max_len); + + for (j, z_elem) in z.iter().enumerate().take(max_len) { + let expected = witness + .rows() + .iter() + .map(|row| { + row.get(j) + .copied() + .unwrap_or_else(CyclotomicRing::::zero) + }) + .fold(CyclotomicRing::::zero(), |a, b| a + b); + assert_eq!(*z_elem, expected); + } + } + + #[test] + fn standard_fold_roundtrip_verifies() { + let mk_ring = |c: i64| { + CyclotomicRing::::from_coefficients(std::array::from_fn(|i| { + if i == 0 { + F::from_i64(c) + } else { + F::zero() + } + })) + }; + let witness = LabradorWitness::new(vec![ + vec![mk_ring(1), mk_ring(2)], + vec![mk_ring(3), mk_ring(-1)], + ]); + let target = witness.rows()[0][0] + witness.rows()[1][1]; + let statement = LabradorStatement { + u1: Vec::new(), + u2: Vec::new(), + challenges: Vec::new(), + constraints: vec![LabradorConstraint { + coefficients: vec![vec![mk_ring(1), mk_ring(0)], vec![mk_ring(0), mk_ring(1)]], + target: vec![target], + }], + beta_sq: 1 << 40, + hash: 
[0u8; 16], + }; + let cfg = LabradorReductionConfig { + f: 4, + b: 8, + fu: 4, + bu: 8, + kappa: 2, + kappa1: 2, + tail: false, + }; + let comkey_seed = [9u8; 32]; + let jl_seed = [7u8; 16]; + let mut transcript = Blake2bTranscript::::new(DOMAIN_LABRADOR_PROTOCOL); + let fold = prove_level( + &witness, + &statement, + &cfg, + &comkey_seed, + &jl_seed, + MatrixPrgBackendChoice::Shake256, + 0, + &mut transcript, + ) + .unwrap(); + + let proof = LabradorProof { + levels: vec![fold.level_proof], + final_opening_witness: fold.next_witness, + }; + let mut verify_transcript = Blake2bTranscript::::new(DOMAIN_LABRADOR_PROTOCOL); + verify( + &statement, + &proof, + &comkey_seed, + &jl_seed, + MatrixPrgBackendChoice::Shake256, + &mut verify_transcript, + ) + .unwrap(); + + let base_proof = LabradorProof { + levels: Vec::new(), + final_opening_witness: proof.final_opening_witness.clone(), + }; + let mut base_transcript = Blake2bTranscript::::new(DOMAIN_LABRADOR_PROTOCOL); + verify( + &fold.statement, + &base_proof, + &comkey_seed, + &jl_seed, + MatrixPrgBackendChoice::Shake256, + &mut base_transcript, + ) + .unwrap(); + } + + #[test] + fn two_level_fold_roundtrip_verifies() { + let mk_ring = |c: i64| { + CyclotomicRing::::from_coefficients(std::array::from_fn(|i| { + if i == 0 { + F::from_i64(c) + } else { + F::zero() + } + })) + }; + let witness = LabradorWitness::new(vec![ + vec![mk_ring(1), mk_ring(2)], + vec![mk_ring(3), mk_ring(-1)], + ]); + let target = witness.rows()[0][0] + witness.rows()[1][1]; + let statement = LabradorStatement { + u1: Vec::new(), + u2: Vec::new(), + challenges: Vec::new(), + constraints: vec![LabradorConstraint { + coefficients: vec![vec![mk_ring(1), mk_ring(0)], vec![mk_ring(0), mk_ring(1)]], + target: vec![target], + }], + beta_sq: 1 << 40, + hash: [0u8; 16], + }; + let cfg = LabradorReductionConfig { + f: 4, + b: 8, + fu: 4, + bu: 8, + kappa: 2, + kappa1: 2, + tail: false, + }; + let comkey_seed = [9u8; 32]; + let jl_seed = [7u8; 16]; + let mut 
transcript = Blake2bTranscript::::new(DOMAIN_LABRADOR_PROTOCOL); + let fold1 = prove_level( + &witness, + &statement, + &cfg, + &comkey_seed, + &jl_seed, + MatrixPrgBackendChoice::Shake256, + 0, + &mut transcript, + ) + .unwrap(); + let fold2 = prove_level( + &fold1.next_witness, + &fold1.statement, + &cfg, + &comkey_seed, + &jl_seed, + MatrixPrgBackendChoice::Shake256, + 1, + &mut transcript, + ) + .unwrap(); + + let r = fold2.level_proof.input_row_lengths.len(); + let challenges = &fold2.statement.challenges; + let aux_row = &fold2.next_witness.rows()[cfg.f]; + let t_hat_len = r * cfg.kappa * cfg.fu; + let t_hat = &aux_row[..t_hat_len]; + let mut t_flat = Vec::with_capacity(r * cfg.kappa); + for chunk in t_hat.chunks(cfg.fu) { + t_flat.push(CyclotomicRing::gadget_recompose_pow2(chunk, cfg.bu as u32)); + } + let z_parts: Vec>> = fold2.next_witness.rows()[..cfg.f].to_vec(); + let mut z = Vec::with_capacity(z_parts[0].len()); + for idx in 0..z_parts[0].len() { + let mut slice = Vec::with_capacity(cfg.f); + for part in &z_parts { + slice.push(part[idx]); + } + z.push(CyclotomicRing::gadget_recompose_pow2(&slice, cfg.b as u32)); + } + let a = derive_extendable_comkey_matrix::( + cfg.kappa, + z.len(), + &comkey_seed, + b"labrador/comkey/A", + MatrixPrgBackendChoice::Shake256, + ); + let az = mat_vec_mul(&a, &z); + let mut rhs = vec![CyclotomicRing::::zero(); cfg.kappa]; + for (row_idx, t_row) in t_flat.chunks(cfg.kappa).enumerate() { + let c = challenges[row_idx]; + for k in 0..cfg.kappa { + rhs[k] += c * t_row[k]; + } + } + assert_eq!(az, rhs); + + let proof = LabradorProof { + levels: vec![fold1.level_proof, fold2.level_proof], + final_opening_witness: fold2.next_witness, + }; + let mut verify_transcript = Blake2bTranscript::::new(DOMAIN_LABRADOR_PROTOCOL); + verify( + &statement, + &proof, + &comkey_seed, + &jl_seed, + MatrixPrgBackendChoice::Shake256, + &mut verify_transcript, + ) + .unwrap(); + } +} diff --git a/src/protocol/labrador/guardrails.rs 
b/src/protocol/labrador/guardrails.rs new file mode 100644 index 00000000..7d776dac --- /dev/null +++ b/src/protocol/labrador/guardrails.rs @@ -0,0 +1,89 @@ +//! Guardrails for Labrador/Greyhound protocol plumbing. + +use crate::error::HachiError; + +/// Maximum recursion levels accepted by the protocol. +/// +/// Mirrors the fixed upper bound used by the C reference (`proof *pi[16]`). +pub const LABRADOR_MAX_LEVELS: usize = 4; +/// Upper bound for JL nonce search attempts. +pub const LABRADOR_MAX_JL_NONCE_RETRIES: u64 = 1 << 20; +/// Upper bound on challenge polynomials sampled per call. +pub const LABRADOR_MAX_CHALLENGE_POLYS: usize = 1 << 12; +/// Upper bound for temporary byte allocations in Labrador helpers. +pub const LABRADOR_MAX_TEMP_BYTES: usize = 1 << 27; // 128 MiB + +/// Checked conversion from `usize` to `u64`. +/// +/// # Errors +/// +/// Returns an error when `value` does not fit into `u64`. +pub fn checked_usize_to_u64(value: usize, what: &'static str) -> Result { + u64::try_from(value) + .map_err(|_| HachiError::InvalidInput(format!("{what} does not fit in u64: {value}"))) +} + +/// Ensure a value is a power of two. +/// +/// # Errors +/// +/// Returns an error if `value` is not a power of two. +pub fn ensure_power_of_two(value: usize, what: &'static str) -> Result<(), HachiError> { + if !value.is_power_of_two() { + return Err(HachiError::InvalidInput(format!( + "{what} must be a power of two, got {value}" + ))); + } + Ok(()) +} + +/// Checked `a * b` for allocation sizing. +/// +/// # Errors +/// +/// Returns an error if multiplication overflows `usize`. +pub fn checked_mul(a: usize, b: usize, what: &'static str) -> Result { + a.checked_mul(b) + .ok_or_else(|| HachiError::InvalidInput(format!("overflow while computing {what}"))) +} + +/// Checked `a + b` for allocation sizing. +/// +/// # Errors +/// +/// Returns an error if addition overflows `usize`. 
+pub fn checked_add(a: usize, b: usize, what: &'static str) -> Result { + a.checked_add(b) + .ok_or_else(|| HachiError::InvalidInput(format!("overflow while computing {what}"))) +} + +/// Validate temporary allocation size against guardrail cap. +/// +/// # Errors +/// +/// Returns an error if `bytes > LABRADOR_MAX_TEMP_BYTES`. +pub fn ensure_temp_allocation_limit(bytes: usize, what: &'static str) -> Result<(), HachiError> { + if bytes > LABRADOR_MAX_TEMP_BYTES { + return Err(HachiError::InvalidInput(format!( + "{what} temporary allocation too large: {bytes} bytes (max {LABRADOR_MAX_TEMP_BYTES})" + ))); + } + Ok(()) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn checked_mul_detects_overflow() { + let err = checked_mul(usize::MAX, 2, "overflow-test").unwrap_err(); + assert!(matches!(err, HachiError::InvalidInput(_))); + } + + #[test] + fn temp_limit_enforced() { + let err = ensure_temp_allocation_limit(LABRADOR_MAX_TEMP_BYTES + 1, "tmp").unwrap_err(); + assert!(matches!(err, HachiError::InvalidInput(_))); + } +} diff --git a/src/protocol/labrador/johnson_lindenstrauss.rs b/src/protocol/labrador/johnson_lindenstrauss.rs new file mode 100644 index 00000000..3bed9b4d --- /dev/null +++ b/src/protocol/labrador/johnson_lindenstrauss.rs @@ -0,0 +1,243 @@ +//! Johnson-Lindenstrauss helpers for Labrador reduction. + +use crate::algebra::ring::CyclotomicRing; +use crate::error::HachiError; +use crate::protocol::labrador::guardrails::LABRADOR_MAX_JL_NONCE_RETRIES; +use crate::protocol::labrador::types::LabradorWitness; +use crate::protocol::prg::{MatrixPrgBackendChoice, MatrixPrgContext}; +use crate::{CanonicalField, FieldCore}; +use rand_core::RngCore; +use sha3::digest::{ExtendableOutput, Update, XofReader}; +use sha3::Shake256; + +const JL_ROWS: usize = 256; + +/// Binary JL matrix with entries in `{-1, +1}`. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct LabradorJlMatrix { + /// Number of rows (fixed at 256 in Labrador). 
+ pub rows: usize, + /// Number of columns. + pub cols: usize, + /// Matrix entries as `-1/+1`. + pub signs: Vec>, +} + +impl LabradorJlMatrix { + /// Deterministically generate a JL matrix from seed/nonce. + /// + /// # Errors + /// + /// Returns an error if `cols` is zero. + pub fn generate( + seed: &[u8; 16], + nonce: u64, + cols: usize, + backend: MatrixPrgBackendChoice, + ) -> Result { + if cols == 0 { + return Err(HachiError::InvalidInput( + "JL matrix requires non-zero column count".to_string(), + )); + } + let prg_seed = derive_jl_prg_seed(seed, nonce); + let byte_len = cols.div_ceil(8); + let mut signs = Vec::with_capacity(JL_ROWS); + for row in 0..JL_ROWS { + let context = MatrixPrgContext { + seed: &prg_seed, + matrix_label: b"labrador/jl", + rows: JL_ROWS, + cols, + row, + col: 0, + }; + let mut rng = backend.entry_rng(&context); + let mut bytes = vec![0u8; byte_len]; + rng.fill_bytes(&mut bytes); + let row_signs = (0..cols) + .map(|c| { + let bit = (bytes[c / 8] >> (c % 8)) & 1; + if bit == 0 { + -1 + } else { + 1 + } + }) + .collect(); + signs.push(row_signs); + } + Ok(Self { + rows: JL_ROWS, + cols, + signs, + }) + } +} + +/// Project a witness into 256 JL coordinates and return the nonce used. +/// +/// Nonce search starts from `1` and stops at the first projection that fits +/// signed 32-bit coordinates, up to `LABRADOR_MAX_JL_NONCE_RETRIES`. +/// +/// # Errors +/// +/// Returns an error if the witness is empty or if no valid projection is found +/// within the nonce search limit. 
+pub fn project( + witness: &LabradorWitness, + seed: &[u8; 16], + backend: MatrixPrgBackendChoice, +) -> Result<([i32; 256], u64), HachiError> { + let vector = flatten_witness_coeffs(witness); + if vector.is_empty() { + return Err(HachiError::InvalidInput( + "cannot JL-project empty witness".to_string(), + )); + } + for nonce in 1..=LABRADOR_MAX_JL_NONCE_RETRIES { + let matrix = LabradorJlMatrix::generate(seed, nonce, vector.len(), backend)?; + if let Some(proj) = project_with_matrix(&matrix, &vector) { + return Ok((proj, nonce)); + } + } + Err(HachiError::InvalidInput(format!( + "failed JL projection nonce search after {LABRADOR_MAX_JL_NONCE_RETRIES} attempts" + ))) +} + +/// Collapse a JL projection with challenge coefficients. +/// +/// Returns the linear target value `sum_i alpha[i] * projection[i]`. +pub fn collapse(projection: &[i32; 256], alpha: &[i64; 256]) -> i64 { + projection + .iter() + .zip(alpha.iter()) + .fold(0i128, |acc, (&p, &a)| acc + (p as i128) * (a as i128)) + .clamp(i64::MIN as i128, i64::MAX as i128) as i64 +} + +/// Zero out a polynomial constant term for proof transmission. +/// +/// Returns the modified polynomial and the removed constant term. +pub fn zero_constant_term_for_proof( + mut poly: CyclotomicRing, +) -> (CyclotomicRing, F) { + let coeffs = poly.coefficients_mut(); + let c0 = coeffs[0]; + coeffs[0] = F::zero(); + (poly, c0) +} + +/// Restore a polynomial constant term during verifier-side reduction. 
+pub fn restore_constant_term( + mut transmitted: CyclotomicRing, + constant: F, +) -> CyclotomicRing { + transmitted.coefficients_mut()[0] = constant; + transmitted +} + +fn derive_jl_prg_seed(seed: &[u8; 16], nonce: u64) -> [u8; 32] { + let mut xof = Shake256::default(); + xof.update(b"hachi/labrador/jl"); + xof.update(seed); + xof.update(&nonce.to_le_bytes()); + let mut out = [0u8; 32]; + xof.finalize_xof().read(&mut out); + out +} + +fn flatten_witness_coeffs( + witness: &LabradorWitness, +) -> Vec { + let q = (-F::one()).to_canonical_u128() + 1; + let half_q = q / 2; + witness + .rows() + .iter() + .flat_map(|row| row.iter()) + .flat_map(|ring| ring.coefficients().iter()) + .map(|coeff| { + let c = coeff.to_canonical_u128(); + if c > half_q { + -((q - c) as i64) + } else { + c as i64 + } + }) + .collect() +} + +fn project_with_matrix(matrix: &LabradorJlMatrix, vector: &[i64]) -> Option<[i32; 256]> { + if matrix.cols != vector.len() || matrix.rows != JL_ROWS { + return None; + } + let mut out = [0i32; 256]; + for (row_idx, row) in matrix.signs.iter().enumerate() { + let mut acc = 0i128; + for (&sign, &value) in row.iter().zip(vector.iter()) { + acc += (sign as i128) * (value as i128); + } + if acc < i32::MIN as i128 || acc > i32::MAX as i128 { + return None; + } + out[row_idx] = acc as i32; + } + Some(out) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::algebra::fields::Fp64; + use crate::FromSmallInt; + + type F = Fp64<4294967197>; + const D: usize = 64; + + fn sample_witness() -> LabradorWitness { + let row = |len: usize| -> Vec> { + (0..len) + .map(|i| { + CyclotomicRing::from_coefficients(std::array::from_fn(|j| { + F::from_i64(((i + j) as i64 % 7) - 3) + })) + }) + .collect() + }; + LabradorWitness::new(vec![row(4), row(4)]) + } + + #[test] + fn project_is_deterministic_and_replayable() { + let witness = sample_witness(); + let seed = [9u8; 16]; + let (p1, n1) = project(&witness, &seed, MatrixPrgBackendChoice::Shake256).unwrap(); + let (p2, 
n2) = project(&witness, &seed, MatrixPrgBackendChoice::Shake256).unwrap(); + assert_eq!(p1, p2); + assert_eq!(n1, n2); + } + + #[test] + fn collapse_matches_dot_product() { + let projection = std::array::from_fn(|i| i as i32 - 10); + let alpha = std::array::from_fn(|i| (2 * i as i64) - 7); + let got = collapse(&projection, &alpha); + let expected = projection + .iter() + .zip(alpha.iter()) + .fold(0i64, |acc, (&p, &a)| acc + (p as i64) * a); + assert_eq!(got, expected); + } + + #[test] + fn lift_zero_and_restore_constant_term() { + let poly: CyclotomicRing = + CyclotomicRing::from_coefficients(std::array::from_fn(|i| F::from_i64(i as i64 - 5))); + let (tx, c0) = zero_constant_term_for_proof(poly); + assert!(tx.coefficients()[0].is_zero()); + let restored = restore_constant_term(tx, c0); + assert_eq!(restored, poly); + } +} diff --git a/src/protocol/labrador/mod.rs b/src/protocol/labrador/mod.rs new file mode 100644 index 00000000..edd14f0c --- /dev/null +++ b/src/protocol/labrador/mod.rs @@ -0,0 +1,30 @@ +//! Labrador recursive proof sub-protocol. +//! +//! This module will host the Greyhound/Labrador integration used by Hachi's +//! recursive handoff path. 
+ +pub mod challenge; +pub mod comkey; +pub mod commit; +pub mod config; +pub mod fold; +pub mod guardrails; +pub mod johnson_lindenstrauss; +pub mod prover; +pub mod transcript; +pub mod types; +pub mod utils; +pub mod verifier; + +pub use commit::{commit_linear_only, LabradorCommitmentArtifacts}; +pub use config::{select_config, sis_secure}; +pub use fold::{prove_level, LabradorFoldResult}; +pub use johnson_lindenstrauss::{ + collapse, project, restore_constant_term, zero_constant_term_for_proof, LabradorJlMatrix, +}; +pub use prover::{prove, prove_with_config}; +pub use types::{ + LabradorConstraint, LabradorLevelProof, LabradorProof, LabradorReductionConfig, + LabradorStatement, LabradorWitness, +}; +pub use verifier::{verify, LabradorVerifyResult}; diff --git a/src/protocol/labrador/prover.rs b/src/protocol/labrador/prover.rs new file mode 100644 index 00000000..b9f4e621 --- /dev/null +++ b/src/protocol/labrador/prover.rs @@ -0,0 +1,332 @@ +//! Labrador prover loop. + +use crate::error::HachiError; +use crate::protocol::labrador::comkey::LabradorComKeySeed; +use crate::protocol::labrador::fold::prove_level; +use crate::protocol::labrador::guardrails::LABRADOR_MAX_LEVELS; +use crate::protocol::labrador::select_config; +use crate::protocol::labrador::types::{LabradorProof, LabradorStatement, LabradorWitness}; +use crate::protocol::labrador::LabradorReductionConfig; +use crate::protocol::prg::MatrixPrgBackendChoice; +use crate::protocol::transcript::Transcript; +use crate::{CanonicalField, FieldCore, FieldSampling, FromSmallInt}; + +const ESTIMATED_LOGQ_BITS: usize = 32; + +/// Build a recursive Labrador proof with optional tail acceptance. +/// +/// Standard levels are applied while witness size decreases. Tail mode is then +/// attempted once and accepted only if total `(proof + witness)` size improves. +/// +/// # Errors +/// +/// Returns an error if folding fails or if recursion limits are exceeded. 
+pub fn prove( + initial_witness: LabradorWitness, + initial_statement: &LabradorStatement, + comkey_seed: &LabradorComKeySeed, + jl_seed: &[u8; 16], + backend: MatrixPrgBackendChoice, + transcript: &mut T, +) -> Result, HachiError> +where + F: FieldCore + CanonicalField + FieldSampling + FromSmallInt, + T: Transcript, +{ + if initial_witness.rows().is_empty() { + return Err(HachiError::InvalidInput( + "cannot prove with empty Labrador witness".to_string(), + )); + } + + let mut levels = Vec::new(); + let mut witness = initial_witness; + let mut _statement = initial_statement.clone(); + let mut level_idx = 0usize; + + while level_idx + 1 < LABRADOR_MAX_LEVELS { + let before_size = witness_size_bits::(&witness); + if before_size == 0 || witness.rows().len() <= 1 { + break; + } + + let cfg = select_config(&witness)?; + let fold = prove_level( + &witness, + &_statement, + &cfg, + comkey_seed, + jl_seed, + backend, + level_idx, + transcript, + )?; + let after_size = witness_size_bits::(&fold.next_witness); + if after_size >= before_size { + break; + } + levels.push(fold.level_proof); + _statement = fold.statement; + witness = fold.next_witness; + level_idx += 1; + } + + if level_idx + 1 < LABRADOR_MAX_LEVELS { + let mut tail_cfg = select_config(&witness)?; + tail_cfg = LabradorReductionConfig { + tail: true, + kappa1: 0, + fu: 1, + bu: ESTIMATED_LOGQ_BITS, + ..tail_cfg + }; + + let baseline_bits = witness_size_bits::(&witness) + + levels + .iter() + .map(level_payload_size_bits::) + .sum::(); + + // Clone transcript so we can roll back if tail doesn't help. 
+ let mut tail_transcript = transcript.clone(); + if let Ok(tail) = prove_level( + &witness, + &_statement, + &tail_cfg, + comkey_seed, + jl_seed, + backend, + level_idx, + &mut tail_transcript, + ) { + let candidate_bits = witness_size_bits::(&tail.next_witness) + + levels + .iter() + .map(level_payload_size_bits::) + .sum::() + + level_payload_size_bits::(&tail.level_proof); + if candidate_bits < baseline_bits { + levels.push(tail.level_proof); + _statement = tail.statement; + witness = tail.next_witness; + *transcript = tail_transcript; + } + } + } + + Ok(LabradorProof { + levels, + final_opening_witness: witness, + }) +} + +/// Build a recursive Labrador proof using a caller-supplied initial config. +/// +/// Falls back to the provided config if `select_config` fails for a level. +/// +/// # Errors +/// +/// Returns [`HachiError`] if any fold level fails (e.g. empty witness, +/// invalid config, or transcript errors). +pub fn prove_with_config( + initial_witness: LabradorWitness, + initial_statement: &LabradorStatement, + initial_config: &LabradorReductionConfig, + comkey_seed: &LabradorComKeySeed, + jl_seed: &[u8; 16], + backend: MatrixPrgBackendChoice, + transcript: &mut T, +) -> Result, HachiError> +where + F: FieldCore + CanonicalField + FieldSampling + FromSmallInt, + T: Transcript, +{ + if initial_witness.rows().is_empty() { + return Err(HachiError::InvalidInput( + "cannot prove with empty Labrador witness".to_string(), + )); + } + + let mut levels = Vec::new(); + let mut witness = initial_witness; + let mut statement = initial_statement.clone(); + let mut level_idx = 0usize; + let mut fallback_cfg = *initial_config; + let mut force_first_level = true; + + while level_idx + 1 < LABRADOR_MAX_LEVELS { + let before_size = witness_size_bits::(&witness); + if before_size == 0 || witness.rows().len() <= 1 { + break; + } + + let cfg = select_config(&witness).unwrap_or(fallback_cfg); + let fold = prove_level( + &witness, + &statement, + &cfg, + comkey_seed, + 
jl_seed, + backend, + level_idx, + transcript, + )?; + let after_size = witness_size_bits::(&fold.next_witness); + if after_size >= before_size && !force_first_level { + break; + } + + levels.push(fold.level_proof); + statement = fold.statement; + witness = fold.next_witness; + fallback_cfg = cfg; + level_idx += 1; + force_first_level = false; + } + + if level_idx + 1 < LABRADOR_MAX_LEVELS { + let mut tail_cfg = select_config(&witness).unwrap_or(fallback_cfg); + tail_cfg = LabradorReductionConfig { + tail: true, + kappa1: 0, + fu: 1, + bu: ESTIMATED_LOGQ_BITS, + ..tail_cfg + }; + + let baseline_bits = witness_size_bits::(&witness) + + levels + .iter() + .map(level_payload_size_bits::) + .sum::(); + + let mut tail_transcript = transcript.clone(); + if let Ok(tail) = prove_level( + &witness, + &statement, + &tail_cfg, + comkey_seed, + jl_seed, + backend, + level_idx, + &mut tail_transcript, + ) { + let candidate_bits = witness_size_bits::(&tail.next_witness) + + levels + .iter() + .map(level_payload_size_bits::) + .sum::() + + level_payload_size_bits::(&tail.level_proof); + if candidate_bits < baseline_bits { + levels.push(tail.level_proof); + witness = tail.next_witness; + *transcript = tail_transcript; + } + } + } + + Ok(LabradorProof { + levels, + final_opening_witness: witness, + }) +} + +fn witness_size_bits(witness: &LabradorWitness) -> usize { + witness + .rows() + .iter() + .map(|row| row.len() * D * ESTIMATED_LOGQ_BITS) + .sum() +} + +fn level_payload_size_bits( + level: &crate::protocol::labrador::LabradorLevelProof, +) -> usize { + let ring_elems = level.u1.len() + level.u2.len() + level.bb.len(); + let ring_bits = ring_elems * D * ESTIMATED_LOGQ_BITS; + let jl_bits = level.jl_projection.len() * 32; + ring_bits + jl_bits + 64 +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::algebra::fields::Fp64; + use crate::algebra::ring::CyclotomicRing; + use crate::protocol::transcript::labels::DOMAIN_LABRADOR_PROTOCOL; + use 
crate::protocol::transcript::Blake2bTranscript; + use crate::FromSmallInt; + + type F = Fp64<4294967197>; + const D: usize = 64; + + fn sample_witness() -> LabradorWitness { + let row = |len: usize| -> Vec> { + (0..len) + .map(|i| { + CyclotomicRing::from_coefficients(std::array::from_fn(|j| { + F::from_i64(((i + j) as i64 % 7) - 3) + })) + }) + .collect() + }; + LabradorWitness::new(vec![row(6), row(6), row(6)]) + } + + #[test] + fn prover_loop_returns_final_opening_witness() { + let statement = crate::protocol::labrador::types::LabradorStatement { + u1: Vec::new(), + u2: Vec::new(), + challenges: Vec::new(), + constraints: Vec::new(), + beta_sq: 1024, + hash: [0u8; 16], + }; + let mut transcript = Blake2bTranscript::::new(DOMAIN_LABRADOR_PROTOCOL); + let proof = prove( + sample_witness(), + &statement, + &[1u8; 32], + &[2u8; 16], + MatrixPrgBackendChoice::Shake256, + &mut transcript, + ) + .unwrap(); + assert!(!proof.final_opening_witness.rows().is_empty()); + assert!(proof.levels.len() <= LABRADOR_MAX_LEVELS); + } + + #[test] + fn prover_proof_verifies() { + let statement = crate::protocol::labrador::types::LabradorStatement { + u1: Vec::new(), + u2: Vec::new(), + challenges: Vec::new(), + constraints: Vec::new(), + beta_sq: 1 << 30, + hash: [0u8; 16], + }; + let mut transcript = Blake2bTranscript::::new(DOMAIN_LABRADOR_PROTOCOL); + let proof = prove( + sample_witness(), + &statement, + &[1u8; 32], + &[2u8; 16], + MatrixPrgBackendChoice::Shake256, + &mut transcript, + ) + .unwrap(); + + let mut verify_transcript = Blake2bTranscript::::new(DOMAIN_LABRADOR_PROTOCOL); + crate::protocol::labrador::verify( + &statement, + &proof, + &[1u8; 32], + &[2u8; 16], + MatrixPrgBackendChoice::Shake256, + &mut verify_transcript, + ) + .unwrap(); + } +} diff --git a/src/protocol/labrador/transcript.rs b/src/protocol/labrador/transcript.rs new file mode 100644 index 00000000..6dc9fcb5 --- /dev/null +++ b/src/protocol/labrador/transcript.rs @@ -0,0 +1,359 @@ +//! 
Canonical transcript schedule helpers for Greyhound/Labrador. +//! +//! These helpers centralize byte-level encoding for prover/verifier replay: +//! dimension binding, backend binding, and nonce encoding. + +use crate::error::HachiError; +use crate::protocol::labrador::guardrails::checked_usize_to_u64; +use crate::protocol::transcript::labels; +use crate::protocol::transcript::Transcript; +use crate::{CanonicalField, FieldCore, HachiSerialize}; + +/// Greyhound evaluation transcript context. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct GreyhoundEvalTranscriptContext { + /// Matrix rows for reshaped witness. + pub m_rows: usize, + /// Matrix columns for reshaped witness. + pub n_cols: usize, + /// Number of "inner" multilinear variables. + pub inner_vars: usize, + /// Length of the evaluation point vector. + pub eval_point_len: usize, + /// Matrix-PRG backend id bound into Fiat-Shamir. + pub prg_backend_id: u8, +} + +/// Labrador level transcript context. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct LabradorLevelTranscriptContext { + /// Zero-based recursion level index. + pub level_index: usize, + /// Whether this level is in tail mode. + pub tail: bool, + /// Input witness row lengths (`n[i]` in the C reference). + pub input_row_lengths: Vec, + /// Input row chunk counts (`nu[i]` in the C reference). + pub input_row_chunks: Vec, + /// Witness decomposition parts. + pub f: usize, + /// Witness decomposition basis log2. + pub b: usize, + /// Commitment decomposition parts. + pub fu: usize, + /// Commitment decomposition basis log2. + pub bu: usize, + /// Inner commitment rank. + pub kappa: usize, + /// Outer commitment rank. + pub kappa1: usize, + /// Matrix-PRG backend id bound into Fiat-Shamir. 
+ pub prg_backend_id: u8, +} + +fn append_u64_le(buf: &mut Vec, value: u64) { + buf.extend_from_slice(&value.to_le_bytes()); +} + +fn encode_usize_slice(buf: &mut Vec, values: &[usize]) -> Result<(), HachiError> { + append_u64_le(buf, checked_usize_to_u64(values.len(), "slice length")?); + for &v in values { + append_u64_le(buf, checked_usize_to_u64(v, "slice element")?); + } + Ok(()) +} + +fn encode_greyhound_eval_context( + ctx: &GreyhoundEvalTranscriptContext, +) -> Result, HachiError> { + let mut bytes = Vec::with_capacity(2 + 8 * 4); + // Versioned payload for deterministic replay stability. + bytes.push(1u8); + bytes.push(ctx.prg_backend_id); + append_u64_le(&mut bytes, checked_usize_to_u64(ctx.m_rows, "m_rows")?); + append_u64_le(&mut bytes, checked_usize_to_u64(ctx.n_cols, "n_cols")?); + append_u64_le( + &mut bytes, + checked_usize_to_u64(ctx.inner_vars, "inner_vars")?, + ); + append_u64_le( + &mut bytes, + checked_usize_to_u64(ctx.eval_point_len, "eval_point_len")?, + ); + Ok(bytes) +} + +fn encode_labrador_level_context( + ctx: &LabradorLevelTranscriptContext, +) -> Result, HachiError> { + let mut bytes = + Vec::with_capacity(4 + 8 * (8 + ctx.input_row_lengths.len() + ctx.input_row_chunks.len())); + // Versioned payload for deterministic replay stability. 
+ bytes.push(1u8); + bytes.push(u8::from(ctx.tail)); + bytes.push(ctx.prg_backend_id); + bytes.push(0u8); // reserved + append_u64_le( + &mut bytes, + checked_usize_to_u64(ctx.level_index, "level_index")?, + ); + append_u64_le(&mut bytes, checked_usize_to_u64(ctx.f, "f")?); + append_u64_le(&mut bytes, checked_usize_to_u64(ctx.b, "b")?); + append_u64_le(&mut bytes, checked_usize_to_u64(ctx.fu, "fu")?); + append_u64_le(&mut bytes, checked_usize_to_u64(ctx.bu, "bu")?); + append_u64_le(&mut bytes, checked_usize_to_u64(ctx.kappa, "kappa")?); + append_u64_le(&mut bytes, checked_usize_to_u64(ctx.kappa1, "kappa1")?); + encode_usize_slice(&mut bytes, &ctx.input_row_lengths)?; + encode_usize_slice(&mut bytes, &ctx.input_row_chunks)?; + Ok(bytes) +} + +/// Absorb canonical Greyhound evaluation context bytes. +/// +/// # Errors +/// +/// Returns an error if any dimension does not fit in `u64`. +pub fn absorb_greyhound_eval_context( + transcript: &mut T, + ctx: &GreyhoundEvalTranscriptContext, +) -> Result<(), HachiError> +where + F: FieldCore + CanonicalField, + T: Transcript, +{ + let bytes = encode_greyhound_eval_context(ctx)?; + transcript.append_bytes(labels::ABSORB_GREYHOUND_EVAL_CONTEXT, &bytes); + Ok(()) +} + +/// Absorb canonical Greyhound evaluation claim bytes (`r` and `v`). +pub fn absorb_greyhound_eval_claim(transcript: &mut T, eval_point: &[F], eval_value: &F) +where + F: FieldCore + CanonicalField, + T: Transcript, +{ + for coord in eval_point { + transcript.append_field(labels::ABSORB_GREYHOUND_EVAL_POINT, coord); + } + transcript.append_field(labels::ABSORB_GREYHOUND_EVAL_VALUE, eval_value); +} + +/// Absorb Greyhound commitment payload `u2`. +pub fn absorb_greyhound_u2(transcript: &mut T, u2: &S) +where + F: FieldCore + CanonicalField, + T: Transcript, + S: HachiSerialize, +{ + transcript.append_serde(labels::ABSORB_GREYHOUND_U2, u2); +} + +/// Sample a Greyhound fold challenge. 
+pub fn sample_greyhound_fold_challenge(transcript: &mut T) -> F +where + F: FieldCore + CanonicalField, + T: Transcript, +{ + transcript.challenge_scalar(labels::CHALLENGE_GREYHOUND_FOLD) +} + +/// Absorb canonical Labrador level context bytes. +/// +/// # Errors +/// +/// Returns an error if any dimension does not fit in `u64`. +pub fn absorb_labrador_level_context( + transcript: &mut T, + ctx: &LabradorLevelTranscriptContext, +) -> Result<(), HachiError> +where + F: FieldCore + CanonicalField, + T: Transcript, +{ + let bytes = encode_labrador_level_context(ctx)?; + transcript.append_bytes(labels::ABSORB_LABRADOR_LEVEL_CONTEXT, &bytes); + Ok(()) +} + +/// Absorb Labrador JL projection vector bytes (`i32` little-endian). +pub fn absorb_labrador_jl_projection(transcript: &mut T, projection: &[i32; 256]) +where + F: FieldCore + CanonicalField, + T: Transcript, +{ + let mut bytes = Vec::with_capacity(256 * std::mem::size_of::()); + for coeff in projection { + bytes.extend_from_slice(&coeff.to_le_bytes()); + } + transcript.append_bytes(labels::ABSORB_LABRADOR_JL_PROJECTION, &bytes); +} + +/// Absorb Labrador JL nonce (`u64` little-endian). +pub fn absorb_labrador_jl_nonce(transcript: &mut T, nonce: u64) +where + F: FieldCore + CanonicalField, + T: Transcript, +{ + transcript.append_bytes(labels::ABSORB_LABRADOR_JL_NONCE, &nonce.to_le_bytes()); +} + +/// Sample a Labrador aggregation challenge. +pub fn sample_labrador_aggregation_challenge(transcript: &mut T) -> F +where + F: FieldCore + CanonicalField, + T: Transcript, +{ + transcript.challenge_scalar(labels::CHALLENGE_LABRADOR_AGGREGATION) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::algebra::fields::Fp64; + use crate::protocol::transcript::Blake2bTranscript; + use crate::FromSmallInt; + + type F = Fp64<4294967197>; + + // Fixed test nonces for deterministic replay. 
+ const TEST_NONCE_LOW: u64 = 1; + const TEST_NONCE_HIGH: u64 = 2; + const TEST_NONCE_REPLAY: u64 = 42; + + #[test] + fn greyhound_context_replay_is_deterministic() { + let ctx = GreyhoundEvalTranscriptContext { + m_rows: 64, + n_cols: 128, + inner_vars: 6, + eval_point_len: 13, + prg_backend_id: 1, + }; + let eval_point: Vec = (0..13).map(|i| F::from_u64((i + 3) as u64)).collect(); + let eval_value = F::from_u64(77); + let u2 = vec![F::from_u64(9), F::from_u64(11), F::from_u64(13)]; + + let mut t1 = Blake2bTranscript::::new(labels::DOMAIN_GREYHOUND_EVAL); + absorb_greyhound_eval_context::(&mut t1, &ctx).unwrap(); + absorb_greyhound_eval_claim::(&mut t1, &eval_point, &eval_value); + absorb_greyhound_u2::(&mut t1, &u2); + let c1 = sample_greyhound_fold_challenge::(&mut t1); + + let mut t2 = Blake2bTranscript::::new(labels::DOMAIN_GREYHOUND_EVAL); + absorb_greyhound_eval_context::(&mut t2, &ctx).unwrap(); + absorb_greyhound_eval_claim::(&mut t2, &eval_point, &eval_value); + absorb_greyhound_u2::(&mut t2, &u2); + let c2 = sample_greyhound_fold_challenge::(&mut t2); + + assert_eq!(c1, c2, "same transcript schedule must replay identically"); + } + + #[test] + fn greyhound_context_binds_dimensions() { + let eval_point: Vec = (0..10).map(|i| F::from_u64((i + 5) as u64)).collect(); + let eval_value = F::from_u64(17); + let u2 = vec![F::from_u64(1), F::from_u64(2)]; + + let mut t1 = Blake2bTranscript::::new(labels::DOMAIN_GREYHOUND_EVAL); + absorb_greyhound_eval_context::( + &mut t1, + &GreyhoundEvalTranscriptContext { + m_rows: 32, + n_cols: 32, + inner_vars: 5, + eval_point_len: 10, + prg_backend_id: 1, + }, + ) + .unwrap(); + absorb_greyhound_eval_claim::(&mut t1, &eval_point, &eval_value); + absorb_greyhound_u2::(&mut t1, &u2); + let c1 = sample_greyhound_fold_challenge::(&mut t1); + + let mut t2 = Blake2bTranscript::::new(labels::DOMAIN_GREYHOUND_EVAL); + absorb_greyhound_eval_context::( + &mut t2, + &GreyhoundEvalTranscriptContext { + m_rows: 32, + n_cols: 64, // 
dimension changed + inner_vars: 5, + eval_point_len: 10, + prg_backend_id: 1, + }, + ) + .unwrap(); + absorb_greyhound_eval_claim::(&mut t2, &eval_point, &eval_value); + absorb_greyhound_u2::(&mut t2, &u2); + let c2 = sample_greyhound_fold_challenge::(&mut t2); + + assert_ne!( + c1, c2, + "dimension changes must affect transcript challenges" + ); + } + + #[test] + fn labrador_context_and_nonce_replay_is_deterministic() { + let ctx = LabradorLevelTranscriptContext { + level_index: 2, + tail: false, + input_row_lengths: vec![1024, 2048, 128, 64], + input_row_chunks: vec![16, 32, 4, 2], + f: 2, + b: 8, + fu: 3, + bu: 10, + kappa: 12, + kappa1: 6, + prg_backend_id: 1, + }; + let projection = std::array::from_fn(|i| i as i32 - 127); + let nonce = TEST_NONCE_REPLAY; + + let mut t1 = Blake2bTranscript::::new(labels::DOMAIN_LABRADOR_PROTOCOL); + absorb_labrador_level_context::(&mut t1, &ctx).unwrap(); + absorb_labrador_jl_projection::(&mut t1, &projection); + absorb_labrador_jl_nonce::(&mut t1, nonce); + let c1 = sample_labrador_aggregation_challenge::(&mut t1); + + let mut t2 = Blake2bTranscript::::new(labels::DOMAIN_LABRADOR_PROTOCOL); + absorb_labrador_level_context::(&mut t2, &ctx).unwrap(); + absorb_labrador_jl_projection::(&mut t2, &projection); + absorb_labrador_jl_nonce::(&mut t2, nonce); + let c2 = sample_labrador_aggregation_challenge::(&mut t2); + + assert_eq!(c1, c2, "identical schedule must be replay deterministic"); + } + + #[test] + fn labrador_nonce_binding_changes_challenge() { + let ctx = LabradorLevelTranscriptContext { + level_index: 0, + tail: true, + input_row_lengths: vec![64, 32], + input_row_chunks: vec![2, 1], + f: 1, + b: 8, + fu: 2, + bu: 10, + kappa: 4, + kappa1: 0, + prg_backend_id: 0, + }; + let projection = [0i32; 256]; + + let mut t1 = Blake2bTranscript::::new(labels::DOMAIN_LABRADOR_PROTOCOL); + absorb_labrador_level_context::(&mut t1, &ctx).unwrap(); + absorb_labrador_jl_projection::(&mut t1, &projection); + 
absorb_labrador_jl_nonce::(&mut t1, TEST_NONCE_LOW); + let c1 = sample_labrador_aggregation_challenge::(&mut t1); + + let mut t2 = Blake2bTranscript::::new(labels::DOMAIN_LABRADOR_PROTOCOL); + absorb_labrador_level_context::(&mut t2, &ctx).unwrap(); + absorb_labrador_jl_projection::(&mut t2, &projection); + absorb_labrador_jl_nonce::(&mut t2, TEST_NONCE_HIGH); + let c2 = sample_labrador_aggregation_challenge::(&mut t2); + + assert_ne!(c1, c2, "nonce must be transcript-binding"); + } +} diff --git a/src/protocol/labrador/types.rs b/src/protocol/labrador/types.rs new file mode 100644 index 00000000..a4e49f5a --- /dev/null +++ b/src/protocol/labrador/types.rs @@ -0,0 +1,171 @@ +//! Core Labrador witness/statement/proof types. + +use crate::algebra::ring::CyclotomicRing; +use crate::{CanonicalField, FieldCore}; + +/// Witness object for a Labrador statement, holding the `s_i` row vectors. +#[derive(Debug, Clone, PartialEq, Eq, Default)] +pub struct LabradorWitness { + rows: Vec>>, +} + +impl LabradorWitness { + /// Build a witness from row vectors, all of which must share the same length. + /// + /// # Panics + /// + /// Panics if any two rows differ in length. + pub fn new(rows: Vec>>) -> Self { + if let Some(first_len) = rows.first().map(|r| r.len()) { + assert!( + rows.iter().all(|r| r.len() == first_len), + "all witness rows must have the same length" + ); + } + Self { rows } + } + + /// Build a witness without asserting uniform row length. + /// + /// Use only where the protocol produces rows of mixed length + /// (e.g. z-decomposition rows plus an auxiliary row). + pub(crate) fn new_unchecked(rows: Vec>>) -> Self { + Self { rows } + } + + /// Borrow the underlying row slices. + pub(crate) fn rows(&self) -> &[Vec>] { + &self.rows + } +} + +impl LabradorWitness { + /// Squared coefficient norm summed over every ring element in the witness. 
+ pub fn norm(&self) -> u128 { + self.rows + .iter() + .flat_map(|row| row.iter()) + .map(|ring| ring.coeff_norm_sq()) + .fold(0u128, |acc, v| acc.saturating_add(v)) + } +} + +/// Linear constraint: `sum_i = target`. +/// +/// `coefficients[i]` holds the φ_i vector for witness row `i`. +/// For multi-output constraints (`target.len() > 1`), each coefficient +/// vector is `outputs * row_len` long: output `k` occupies +/// `coefficients[i][k*row_len..(k+1)*row_len]`. +/// An empty inner vec means the row does not participate. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct LabradorConstraint { + /// Per-row coefficient vectors (one per witness row). + pub coefficients: Vec>>, + /// Right-hand side target vector. + pub target: Vec>, +} + +/// Public statement reduced to Labrador recursion. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct LabradorStatement { + /// Outer commitment for opening relation. + pub u1: Vec>, + /// Outer commitment for linear-garbage relation. + pub u2: Vec>, + /// Amortization challenges (per input witness row). + pub challenges: Vec>, + /// Sparse constraints checked by reducer/verifier. + pub constraints: Vec>, + /// Squared norm bound. + pub beta_sq: u128, + /// Statement hash binding. + pub hash: [u8; 16], +} + +/// Per-level reduction parameters. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct LabradorReductionConfig { + /// Witness decomposition parts. + pub f: usize, + /// Witness decomposition basis log2. + pub b: usize, + /// Commitment decomposition parts. + pub fu: usize, + /// Commitment decomposition basis log2. + pub bu: usize, + /// Inner commitment rank. + pub kappa: usize, + /// Outer commitment rank (`0` in tail mode). + pub kappa1: usize, + /// Tail-mode marker. + pub tail: bool, +} + +/// One recursive Labrador level proof payload. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct LabradorLevelProof { + /// Whether this level uses tail semantics. + pub tail: bool, + /// Input row lengths (`n[i]` in C). 
+ pub input_row_lengths: Vec, + /// Input row chunk counts (`nu[i]` in C). + pub input_row_chunks: Vec, + /// Configuration selected for this level. + pub config: LabradorReductionConfig, + /// First outer commitment. + pub u1: Vec>, + /// Second outer commitment. + pub u2: Vec>, + /// JL projection vector. + pub jl_projection: [i32; 256], + /// JL nonce used to regenerate projection matrix. + pub jl_nonce: u64, + /// Lift polynomials (constant term zeroed in proof). + pub bb: Vec>, + /// Output witness norm bound after reduction. + pub norm_sq: u128, +} + +/// Full recursive Labrador proof plus final clear opening witness. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct LabradorProof { + /// Recursive level payloads. + pub levels: Vec>, + /// Final clear witness opened at recursion termination. + pub final_opening_witness: LabradorWitness, +} + +impl LabradorLevelProof { + /// Serialized size of this level in bytes. + pub fn size(&self) -> usize { + let ring_bytes = std::mem::size_of::>(); + let ring_count = self.u1.len() + self.u2.len() + self.bb.len(); + ring_count * ring_bytes + + self.jl_projection.len() * std::mem::size_of::() + + std::mem::size_of::() // jl_nonce + + std::mem::size_of::() // norm_sq + } +} + +impl LabradorProof { + /// Construct an empty proof (used when Labrador is disabled). + pub fn empty() -> Self { + Self { + levels: Vec::new(), + final_opening_witness: LabradorWitness { rows: Vec::new() }, + } + } + + /// Total serialized size of the proof in bytes. 
+ pub fn size(&self) -> usize { + let ring_bytes = std::mem::size_of::>(); + let levels_size: usize = self.levels.iter().map(|l| l.size()).sum(); + let witness_rings: usize = self + .final_opening_witness + .rows + .iter() + .map(|r| r.len()) + .sum(); + levels_size + witness_rings * ring_bytes + } +} diff --git a/src/protocol/labrador/utils.rs b/src/protocol/labrador/utils.rs new file mode 100644 index 00000000..b84cedc8 --- /dev/null +++ b/src/protocol/labrador/utils.rs @@ -0,0 +1,20 @@ +//! Shared utility helpers for the Labrador sub-protocol. + +use crate::algebra::ring::CyclotomicRing; +use crate::FieldCore; + +pub(crate) fn mat_vec_mul( + mat: &[Vec>], + vec: &[CyclotomicRing], +) -> Vec> { + mat.iter() + .map(|row| { + debug_assert_eq!(row.len(), vec.len()); + let mut acc = CyclotomicRing::::zero(); + for (a, x) in row.iter().zip(vec.iter()) { + acc += *a * *x; + } + acc + }) + .collect() +} diff --git a/src/protocol/labrador/verifier.rs b/src/protocol/labrador/verifier.rs new file mode 100644 index 00000000..ceb3f625 --- /dev/null +++ b/src/protocol/labrador/verifier.rs @@ -0,0 +1,1137 @@ +//! Labrador verifier/reducer loop. 
+ +use crate::algebra::ring::CyclotomicRing; +use crate::error::HachiError; +use crate::protocol::labrador::comkey::{derive_extendable_comkey_matrix, LabradorComKeySeed}; +use crate::protocol::labrador::guardrails::LABRADOR_MAX_LEVELS; +use crate::protocol::labrador::johnson_lindenstrauss::{ + collapse, restore_constant_term, LabradorJlMatrix, +}; +use crate::protocol::labrador::transcript::{ + absorb_labrador_jl_nonce, absorb_labrador_jl_projection, absorb_labrador_level_context, + LabradorLevelTranscriptContext, +}; +use crate::protocol::labrador::types::{ + LabradorConstraint, LabradorLevelProof, LabradorProof, LabradorReductionConfig, + LabradorStatement, LabradorWitness, +}; +use crate::protocol::labrador::utils::mat_vec_mul; +use crate::protocol::prg::MatrixPrgBackendChoice; +use crate::protocol::transcript::labels; +use crate::protocol::transcript::{challenge_ring_element_rejection_sampled, Transcript}; +use crate::{CanonicalField, FieldCore, FieldSampling, FromSmallInt}; + +/// Output of verifier-side Labrador reduction. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct LabradorVerifyResult { + /// Statement after replaying all reduction levels. + pub terminal_statement: LabradorStatement, + /// Final clear opening witness from the proof payload. + pub final_opening_witness: LabradorWitness, +} + +use crate::protocol::labrador::config::JL_LIFTS; + +/// Verify Labrador proof and return terminal reduction state. +/// +/// Currently supports a single Labrador level; recursive reduction is +/// intentionally deferred until the folding statement update is implemented. +/// +/// # Errors +/// +/// Returns [`HachiError::InvalidProof`] on structural inconsistencies, +/// norm bound violations, or constraint failures. 
+pub fn verify( + initial_statement: &LabradorStatement, + proof: &LabradorProof, + comkey_seed: &LabradorComKeySeed, + jl_seed: &[u8; 16], + backend: MatrixPrgBackendChoice, + transcript: &mut T, +) -> Result, HachiError> +where + F: FieldCore + CanonicalField + FieldSampling + FromSmallInt, + T: Transcript, +{ + if proof.levels.len() > LABRADOR_MAX_LEVELS || proof.final_opening_witness.rows().is_empty() { + return Err(HachiError::InvalidProof); + } + + if proof.levels.is_empty() { + let final_norm = proof.final_opening_witness.norm(); + if final_norm > initial_statement.beta_sq { + return Err(HachiError::InvalidProof); + } + verify_constraints(&initial_statement.constraints, &proof.final_opening_witness)?; + return Ok(LabradorVerifyResult { + terminal_statement: initial_statement.clone(), + final_opening_witness: proof.final_opening_witness.clone(), + }); + } + + let mut statement = initial_statement.clone(); + let last_idx = proof.levels.len() - 1; + for (idx, level) in proof.levels.iter().enumerate() { + if level.tail { + if idx != last_idx { + return Err(HachiError::InvalidProof); + } + verify_tail_level( + &statement, + level, + &proof.final_opening_witness, + comkey_seed, + jl_seed, + backend, + transcript, + idx, + )?; + return Ok(LabradorVerifyResult { + terminal_statement: statement, + final_opening_witness: proof.final_opening_witness.clone(), + }); + } + statement = reduce_statement( + &statement, + level, + comkey_seed, + jl_seed, + backend, + transcript, + idx, + )?; + } + + let final_norm = proof.final_opening_witness.norm(); + if final_norm > statement.beta_sq { + return Err(HachiError::InvalidProof); + } + verify_constraints(&statement.constraints, &proof.final_opening_witness)?; + + Ok(LabradorVerifyResult { + terminal_statement: statement, + final_opening_witness: proof.final_opening_witness.clone(), + }) +} + +fn reduce_statement( + statement: &LabradorStatement, + level: &LabradorLevelProof, + comkey_seed: &LabradorComKeySeed, + jl_seed: &[u8; 
16], + backend: MatrixPrgBackendChoice, + transcript: &mut T, + level_index: usize, +) -> Result, HachiError> +where + F: FieldCore + CanonicalField + FieldSampling + FromSmallInt, + T: Transcript, +{ + if level.tail { + return Err(HachiError::InvalidProof); + } + let r = level.input_row_lengths.len(); + if r == 0 || level.input_row_chunks.len() != r { + return Err(HachiError::InvalidProof); + } + if level.config.f == 0 || level.config.fu == 0 { + return Err(HachiError::InvalidProof); + } + let max_len = level.input_row_lengths.iter().copied().max().unwrap_or(0); + + absorb_labrador_level_context( + transcript, + &LabradorLevelTranscriptContext { + level_index, + tail: level.tail, + input_row_lengths: level.input_row_lengths.clone(), + input_row_chunks: level.input_row_chunks.clone(), + f: level.config.f, + b: level.config.b, + fu: level.config.fu, + bu: level.config.bu, + kappa: level.config.kappa, + kappa1: level.config.kappa1, + prg_backend_id: backend as u8, + }, + )?; + transcript.append_serde(labels::ABSORB_LABRADOR_U1, &level.u1); + absorb_labrador_jl_projection(transcript, &level.jl_projection); + absorb_labrador_jl_nonce(transcript, level.jl_nonce); + + let (phi_jl, b_jl) = aggregate_jl_constraints_verifier( + &level.input_row_lengths, + &level.jl_projection, + jl_seed, + level.jl_nonce, + &level.bb, + backend, + transcript, + )?; + let (phi_stmt, b_stmt) = aggregate_statement_constraints( + &statement.constraints, + &level.input_row_lengths, + transcript, + )?; + + let mut phi_total = phi_stmt; + add_phi_in_place(&mut phi_total, &phi_jl)?; + let b_total = b_stmt + b_jl; + + transcript.append_serde(labels::ABSORB_LABRADOR_U2, &level.u2); + let mut challenges = Vec::with_capacity(r); + for _ in 0..r { + challenges.push(challenge_ring_element_rejection_sampled( + transcript, + labels::CHALLENGE_LABRADOR_AMORTIZE, + )?); + } + + let next_constraints = build_next_constraints( + &phi_total, + &b_total, + &challenges, + &level.input_row_lengths, + max_len, + 
&level.config, + &level.u1, + &level.u2, + comkey_seed, + backend, + )?; + + Ok(LabradorStatement { + u1: level.u1.clone(), + u2: level.u2.clone(), + challenges, + constraints: next_constraints, + beta_sq: level.norm_sq, + hash: [0u8; 16], + }) +} + +#[allow(clippy::too_many_lines, clippy::too_many_arguments)] +fn verify_tail_level( + statement: &LabradorStatement, + level: &LabradorLevelProof, + witness: &LabradorWitness, + comkey_seed: &LabradorComKeySeed, + jl_seed: &[u8; 16], + backend: MatrixPrgBackendChoice, + transcript: &mut T, + level_index: usize, +) -> Result<(), HachiError> +where + F: FieldCore + CanonicalField + FieldSampling + FromSmallInt, + T: Transcript, +{ + if !level.tail { + return Err(HachiError::InvalidProof); + } + let r = level.input_row_lengths.len(); + if r == 0 || level.input_row_chunks.len() != r { + return Err(HachiError::InvalidProof); + } + if level.config.f == 0 || level.config.fu == 0 { + return Err(HachiError::InvalidProof); + } + let max_len = level.input_row_lengths.iter().copied().max().unwrap_or(0); + if witness.rows().len() != level.config.f { + return Err(HachiError::InvalidProof); + } + for row in witness.rows() { + if row.len() != max_len { + return Err(HachiError::InvalidProof); + } + } + + let t_hat_len = r * level.config.kappa * level.config.fu; + let h_hat_len = r * (r + 1) / 2 * level.config.fu; + if level.u1.len() != t_hat_len || level.u2.len() != h_hat_len { + return Err(HachiError::InvalidProof); + } + let t_hat = &level.u1; + let h_hat = &level.u2; + + absorb_labrador_level_context( + transcript, + &LabradorLevelTranscriptContext { + level_index, + tail: level.tail, + input_row_lengths: level.input_row_lengths.clone(), + input_row_chunks: level.input_row_chunks.clone(), + f: level.config.f, + b: level.config.b, + fu: level.config.fu, + bu: level.config.bu, + kappa: level.config.kappa, + kappa1: level.config.kappa1, + prg_backend_id: backend as u8, + }, + )?; + transcript.append_serde(labels::ABSORB_LABRADOR_U1, 
&level.u1); + absorb_labrador_jl_projection(transcript, &level.jl_projection); + absorb_labrador_jl_nonce(transcript, level.jl_nonce); + + let (phi_jl, b_jl) = aggregate_jl_constraints_verifier( + &level.input_row_lengths, + &level.jl_projection, + jl_seed, + level.jl_nonce, + &level.bb, + backend, + transcript, + )?; + let (phi_stmt, b_stmt) = aggregate_statement_constraints( + &statement.constraints, + &level.input_row_lengths, + transcript, + )?; + let mut phi_total = phi_stmt; + add_phi_in_place(&mut phi_total, &phi_jl)?; + let b_total = b_stmt + b_jl; + + transcript.append_serde(labels::ABSORB_LABRADOR_U2, &level.u2); + let mut challenges = Vec::with_capacity(r); + for _ in 0..r { + challenges.push(challenge_ring_element_rejection_sampled( + transcript, + labels::CHALLENGE_LABRADOR_AMORTIZE, + )?); + } + + let z_parts: Vec>> = witness.rows().to_vec(); + let z = recompose_from_parts(&z_parts, level.config.b as u32)?; + let t_flat = recompose_flat(t_hat, level.config.fu, level.config.bu as u32)?; + let h_flat = recompose_flat(h_hat, level.config.fu, level.config.bu as u32)?; + if t_flat.len() != r * level.config.kappa || h_flat.len() != r * (r + 1) / 2 { + return Err(HachiError::InvalidProof); + } + + let computed_norm = witness.norm(); + if computed_norm > level.norm_sq { + return Err(HachiError::InvalidProof); + } + if projection_norm_sq(&level.jl_projection) > 128u128.saturating_mul(statement.beta_sq) { + return Err(HachiError::InvalidProof); + } + + let a = derive_extendable_comkey_matrix::( + level.config.kappa, + z.len(), + comkey_seed, + b"labrador/comkey/A", + backend, + ); + let az = mat_vec_mul(&a, &z); + let mut rhs = vec![CyclotomicRing::::zero(); level.config.kappa]; + for (i, t_row) in t_flat.chunks(level.config.kappa).enumerate() { + let c = challenges[i]; + for k in 0..level.config.kappa { + rhs[k] += c * t_row[k]; + } + } + if az != rhs { + return Err(HachiError::InvalidProof); + } + + let mut combined_phi = vec![CyclotomicRing::::zero(); 
max_len]; + for (i, phi_row) in phi_total.iter().enumerate() { + let c = challenges[i]; + for (j, elem) in phi_row.iter().enumerate() { + combined_phi[j] += c * *elem; + } + } + let lhs = dot_product(&combined_phi, &z); + let mut rhs = CyclotomicRing::::zero(); + let mut idx = 0usize; + for i in 0..r { + for j in i..r { + rhs += challenges[i] * challenges[j] * h_flat[idx]; + idx += 1; + } + } + if lhs != rhs { + return Err(HachiError::InvalidProof); + } + + let mut diag_sum = CyclotomicRing::::zero(); + for i in 0..r { + let idx = diag_index(i, r); + diag_sum += h_flat[idx]; + } + if diag_sum - b_total != CyclotomicRing::::zero() { + return Err(HachiError::InvalidProof); + } + + Ok(()) +} + +#[allow(clippy::too_many_lines)] +#[allow(dead_code)] +fn verify_single_level( + statement: &LabradorStatement, + level: &LabradorLevelProof, + witness: &LabradorWitness, + comkey_seed: &LabradorComKeySeed, + jl_seed: &[u8; 16], + backend: MatrixPrgBackendChoice, + transcript: &mut T, +) -> Result<(), HachiError> +where + F: FieldCore + CanonicalField + FieldSampling + FromSmallInt, + T: Transcript, +{ + if level.tail { + return Err(HachiError::InvalidProof); + } + let r = level.input_row_lengths.len(); + if r == 0 || level.input_row_chunks.len() != r { + return Err(HachiError::InvalidProof); + } + if level.config.f == 0 || level.config.fu == 0 { + return Err(HachiError::InvalidProof); + } + + let max_len = level.input_row_lengths.iter().copied().max().unwrap_or(0); + let expected_rows = level.config.f + 1; + if witness.rows().len() != expected_rows { + return Err(HachiError::InvalidProof); + } + for row in witness.rows().iter().take(level.config.f) { + if row.len() != max_len { + return Err(HachiError::InvalidProof); + } + } + + let t_hat_len = r * level.config.kappa * level.config.fu; + let h_hat_len = r * (r + 1) / 2 * level.config.fu; + let aux = &witness.rows()[level.config.f]; + if aux.len() != t_hat_len + h_hat_len { + return Err(HachiError::InvalidProof); + } + let 
(t_hat, h_hat) = aux.split_at(t_hat_len); + + // Transcript: absorb level context, commitments, JL. + absorb_labrador_level_context( + transcript, + &LabradorLevelTranscriptContext { + level_index: 0, + tail: level.tail, + input_row_lengths: level.input_row_lengths.clone(), + input_row_chunks: level.input_row_chunks.clone(), + f: level.config.f, + b: level.config.b, + fu: level.config.fu, + bu: level.config.bu, + kappa: level.config.kappa, + kappa1: level.config.kappa1, + prg_backend_id: backend as u8, + }, + )?; + transcript.append_serde(labels::ABSORB_LABRADOR_U1, &level.u1); + absorb_labrador_jl_projection(transcript, &level.jl_projection); + absorb_labrador_jl_nonce(transcript, level.jl_nonce); + + let (phi_jl, b_jl) = aggregate_jl_constraints_verifier( + &level.input_row_lengths, + &level.jl_projection, + jl_seed, + level.jl_nonce, + &level.bb, + backend, + transcript, + )?; + let (phi_stmt, b_stmt) = aggregate_statement_constraints( + &statement.constraints, + &level.input_row_lengths, + transcript, + )?; + + let mut phi_total = phi_stmt; + add_phi_in_place(&mut phi_total, &phi_jl)?; + let b_total = b_stmt + b_jl; + + transcript.append_serde(labels::ABSORB_LABRADOR_U2, &level.u2); + + let mut challenges = Vec::with_capacity(r); + for _ in 0..r { + challenges.push(challenge_ring_element_rejection_sampled( + transcript, + labels::CHALLENGE_LABRADOR_AMORTIZE, + )?); + } + + let z_parts: Vec>> = witness + .rows() + .iter() + .take(level.config.f) + .cloned() + .collect(); + let z = recompose_from_parts(&z_parts, level.config.b as u32)?; + + let t_flat = recompose_flat(t_hat, level.config.fu, level.config.bu as u32)?; + let h_flat = recompose_flat(h_hat, level.config.fu, level.config.bu as u32)?; + if t_flat.len() != r * level.config.kappa { + return Err(HachiError::InvalidProof); + } + if h_flat.len() != r * (r + 1) / 2 { + return Err(HachiError::InvalidProof); + } + let mut t_by_row = Vec::with_capacity(r); + for chunk in t_flat.chunks(level.config.kappa) { + 
t_by_row.push(chunk.to_vec()); + } + + if !statement.u1.is_empty() && statement.u1 != level.u1 { + return Err(HachiError::InvalidProof); + } + if !statement.u2.is_empty() && statement.u2 != level.u2 { + return Err(HachiError::InvalidProof); + } + + if level.config.kappa1 > 0 { + let b = derive_extendable_comkey_matrix::( + level.config.kappa1, + t_hat.len(), + comkey_seed, + b"labrador/comkey/B", + backend, + ); + let u1_check = mat_vec_mul(&b, t_hat); + if u1_check != level.u1 { + return Err(HachiError::InvalidProof); + } + let b2 = derive_extendable_comkey_matrix::( + level.config.kappa1, + h_hat.len(), + comkey_seed, + b"labrador/comkey/U2", + backend, + ); + let u2_check = mat_vec_mul(&b2, h_hat); + if u2_check != level.u2 { + return Err(HachiError::InvalidProof); + } + } else { + if level.u1 != t_hat { + return Err(HachiError::InvalidProof); + } + if level.u2 != h_hat { + return Err(HachiError::InvalidProof); + } + } + + let computed_norm = witness.norm(); + if computed_norm > level.norm_sq { + return Err(HachiError::InvalidProof); + } + + if projection_norm_sq(&level.jl_projection) > 128u128.saturating_mul(statement.beta_sq) { + return Err(HachiError::InvalidProof); + } + + let a = derive_extendable_comkey_matrix::( + level.config.kappa, + z.len(), + comkey_seed, + b"labrador/comkey/A", + backend, + ); + let az = mat_vec_mul(&a, &z); + let mut rhs = vec![CyclotomicRing::::zero(); level.config.kappa]; + for (i, t_row) in t_by_row.iter().enumerate() { + let c = challenges[i]; + for k in 0..level.config.kappa { + rhs[k] += c * t_row[k]; + } + } + if az != rhs { + return Err(HachiError::InvalidProof); + } + + let mut combined_phi = vec![CyclotomicRing::::zero(); max_len]; + for (i, phi_row) in phi_total.iter().enumerate() { + let c = challenges[i]; + for (j, elem) in phi_row.iter().enumerate() { + combined_phi[j] += c * *elem; + } + } + let lhs = dot_product(&combined_phi, &z); + let mut rhs = CyclotomicRing::::zero(); + let mut idx = 0usize; + for i in 0..r { + 
for j in i..r { + rhs += challenges[i] * challenges[j] * h_flat[idx]; + idx += 1; + } + } + if lhs != rhs { + return Err(HachiError::InvalidProof); + } + + let mut diag_sum = CyclotomicRing::::zero(); + for i in 0..r { + let idx = diag_index(i, r); + diag_sum += h_flat[idx]; + } + if diag_sum - b_total != CyclotomicRing::::zero() { + return Err(HachiError::InvalidProof); + } + + Ok(()) +} + +fn diag_index(i: usize, r: usize) -> usize { + i * (2 * r - i + 1) / 2 +} + +fn projection_norm_sq(projection: &[i32; 256]) -> u128 { + projection.iter().fold(0u128, |acc, &v| { + let x = v as i128; + let sq = x * x; + acc.saturating_add(sq as u128) + }) +} + +fn recompose_from_parts( + parts: &[Vec>], + log_basis: u32, +) -> Result>, HachiError> { + if parts.is_empty() { + return Err(HachiError::InvalidProof); + } + let len = parts[0].len(); + for row in parts.iter().skip(1) { + if row.len() != len { + return Err(HachiError::InvalidProof); + } + } + let mut out = Vec::with_capacity(len); + for idx in 0..len { + let mut slice = Vec::with_capacity(parts.len()); + for part in parts { + slice.push(part[idx]); + } + out.push(CyclotomicRing::gadget_recompose_pow2(&slice, log_basis)); + } + Ok(out) +} + +fn recompose_flat( + flat: &[CyclotomicRing], + parts: usize, + log_basis: u32, +) -> Result>, HachiError> { + if parts == 0 || flat.len() % parts != 0 { + return Err(HachiError::InvalidProof); + } + let mut out = Vec::with_capacity(flat.len() / parts); + for chunk in flat.chunks(parts) { + out.push(CyclotomicRing::gadget_recompose_pow2(chunk, log_basis)); + } + Ok(out) +} + +#[allow(clippy::too_many_arguments)] +fn build_next_constraints< + F: FieldCore + CanonicalField + FieldSampling + FromSmallInt, + const D: usize, +>( + phi_total: &[Vec>], + b_total: &CyclotomicRing, + challenges: &[CyclotomicRing], + row_lengths: &[usize], + max_len: usize, + config: &LabradorReductionConfig, + u1: &[CyclotomicRing], + u2: &[CyclotomicRing], + comkey_seed: &LabradorComKeySeed, + backend: 
MatrixPrgBackendChoice, +) -> Result>, HachiError> { + let r = row_lengths.len(); + if r == 0 || challenges.len() != r { + return Err(HachiError::InvalidProof); + } + if config.f == 0 { + return Err(HachiError::InvalidProof); + } + + let pow_b: Vec = (0..config.f) + .map(|idx| pow2_field::(config.b * idx)) + .collect(); + let pow_bu: Vec = (0..config.fu) + .map(|idx| pow2_field::(config.bu * idx)) + .collect(); + + let mut combined_phi = vec![CyclotomicRing::::zero(); max_len]; + for (row_idx, row_phi) in phi_total.iter().enumerate() { + let c = challenges[row_idx]; + for (j, elem) in row_phi.iter().enumerate() { + combined_phi[j] += c * *elem; + } + } + + let mut constraints = Vec::new(); + let t_hat_len = r * config.kappa * config.fu; + let h_len = r * (r + 1) / 2; + let h_hat_len = h_len * config.fu; + let aux_row = config.f; + let aux_row_len = t_hat_len + h_hat_len; + let num_rows = config.f + 1; + + if config.kappa1 > 0 { + if u1.len() != config.kappa1 || u2.len() != config.kappa1 { + return Err(HachiError::InvalidProof); + } + + // B · t_hat = u1 + let b = derive_extendable_comkey_matrix::( + config.kappa1, + t_hat_len, + comkey_seed, + b"labrador/comkey/B", + backend, + ); + let mut aux_coeffs = vec![CyclotomicRing::::zero(); config.kappa1 * aux_row_len]; + for (out_idx, b_row) in b.iter().enumerate() { + let start = out_idx * aux_row_len; + for (j, val) in b_row.iter().enumerate() { + aux_coeffs[start + j] = *val; + } + } + let mut coefficients = vec![vec![]; num_rows]; + coefficients[aux_row] = aux_coeffs; + constraints.push(LabradorConstraint { + coefficients, + target: u1.to_vec(), + }); + + // B2 · h_hat = u2 + let b2 = derive_extendable_comkey_matrix::( + config.kappa1, + h_hat_len, + comkey_seed, + b"labrador/comkey/U2", + backend, + ); + let mut aux_coeffs = vec![CyclotomicRing::::zero(); config.kappa1 * aux_row_len]; + for (out_idx, b2_row) in b2.iter().enumerate() { + let start = out_idx * aux_row_len + t_hat_len; + for (j, val) in 
b2_row.iter().enumerate() { + aux_coeffs[start + j] = *val; + } + } + let mut coefficients = vec![vec![]; num_rows]; + coefficients[aux_row] = aux_coeffs; + constraints.push(LabradorConstraint { + coefficients, + target: u2.to_vec(), + }); + } + + // A·z - c·t = 0 + let a = derive_extendable_comkey_matrix::( + config.kappa, + max_len, + comkey_seed, + b"labrador/comkey/A", + backend, + ); + let mut az_coefficients = vec![vec![]; num_rows]; + for part_idx in 0..config.f { + let scale = pow_b[part_idx]; + let mut coeffs = Vec::with_capacity(config.kappa * max_len); + for a_row in &a { + for elem in a_row.iter() { + coeffs.push(elem.scale(&scale)); + } + } + az_coefficients[part_idx] = coeffs; + } + + let mut t_coeffs = vec![CyclotomicRing::::zero(); config.kappa * t_hat_len]; + for (row_idx, challenge) in challenges.iter().enumerate() { + for (part_idx, &scale) in pow_bu.iter().enumerate() { + let scaled = challenge.scale(&scale); + for k in 0..config.kappa { + let idx = row_idx * config.kappa * config.fu + k * config.fu + part_idx; + let slot = k * t_hat_len + idx; + t_coeffs[slot] = -scaled; + } + } + } + let mut aux_az = vec![CyclotomicRing::::zero(); config.kappa * aux_row_len]; + for k in 0..config.kappa { + let src_start = k * t_hat_len; + let dst_start = k * aux_row_len; + aux_az[dst_start..dst_start + t_hat_len] + .copy_from_slice(&t_coeffs[src_start..src_start + t_hat_len]); + } + az_coefficients[aux_row] = aux_az; + constraints.push(LabradorConstraint { + coefficients: az_coefficients, + target: vec![CyclotomicRing::::zero(); config.kappa], + }); + + // linear garbage constraint + let mut lg_coefficients = vec![vec![]; num_rows]; + for part_idx in 0..config.f { + let scale = pow_b[part_idx]; + let coeffs: Vec> = + combined_phi.iter().map(|elem| elem.scale(&scale)).collect(); + lg_coefficients[part_idx] = coeffs; + } + let mut h_coeffs = vec![CyclotomicRing::::zero(); h_hat_len]; + for i in 0..r { + for j in i..r { + let coeff = challenges[i] * 
challenges[j];
            let pair = pair_index(i, j, r);
            for (part_idx, &scale) in pow_bu.iter().enumerate() {
                let idx = pair * config.fu + part_idx;
                h_coeffs[idx] = -(coeff.scale(&scale));
            }
        }
    }
    let mut aux_lg = vec![CyclotomicRing::<F, D>::zero(); aux_row_len];
    aux_lg[t_hat_len..t_hat_len + h_hat_len].copy_from_slice(&h_coeffs);
    lg_coefficients[aux_row] = aux_lg;
    constraints.push(LabradorConstraint {
        coefficients: lg_coefficients,
        target: vec![CyclotomicRing::<F, D>::zero()],
    });

    // Diagonal (norm) constraint: the basis-2^bu recomposition of the
    // diagonal garbage terms h_{ii} must sum to the aggregated target.
    let mut diag_coeffs = vec![CyclotomicRing::<F, D>::zero(); aux_row_len];
    for i in 0..r {
        let pair = pair_index(i, i, r);
        for (part_idx, &scale) in pow_bu.iter().enumerate() {
            let idx = pair * config.fu + part_idx;
            diag_coeffs[t_hat_len + idx] = constant_poly(scale);
        }
    }
    let mut diag_coefficients = vec![vec![]; num_rows];
    diag_coefficients[aux_row] = diag_coeffs;
    constraints.push(LabradorConstraint {
        coefficients: diag_coefficients,
        target: vec![*b_total],
    });

    Ok(constraints)
}

/// Compute `2^exp` in the field `F` by square-and-multiply.
///
/// `exp` is a bit offset (`log_basis * part_index`), so this uses
/// O(log exp) field multiplications instead of a naive O(exp) loop.
fn pow2_field<F: FieldCore>(exp: usize) -> F {
    let mut base = F::from_u64(2);
    let mut acc = F::one();
    let mut e = exp;
    while e > 0 {
        if e & 1 == 1 {
            acc = acc * base;
        }
        base = base * base;
        e >>= 1;
    }
    acc
}

/// Embed a field scalar as the constant polynomial `value + 0·X + …`.
fn constant_poly<F: FieldCore, const D: usize>(value: F) -> CyclotomicRing<F, D> {
    CyclotomicRing::from_coefficients(std::array::from_fn(|i| {
        if i == 0 {
            value
        } else {
            F::zero()
        }
    }))
}

/// Flat index of the pair `(i, j)` (with `i <= j < r`) in the row-major
/// upper-triangular packing used for the garbage terms `h`.
fn pair_index(i: usize, j: usize, r: usize) -> usize {
    debug_assert!(i <= j && j < r);
    i * (2 * r - i + 1) / 2 + (j - i)
}

/// Element-wise accumulate `other` into `acc`.
///
/// Both operands must have an identical row/column shape; a mismatch is
/// treated as a malformed proof.
fn add_phi_in_place<F: FieldCore, const D: usize>(
    acc: &mut [Vec<CyclotomicRing<F, D>>],
    other: &[Vec<CyclotomicRing<F, D>>],
) -> Result<(), HachiError> {
    if acc.len() != other.len() {
        return Err(HachiError::InvalidProof);
    }
    for (row_acc, row_other) in acc.iter_mut().zip(other.iter()) {
        if row_acc.len() != row_other.len() {
            return Err(HachiError::InvalidProof);
        }
        for (a, b) in row_acc.iter_mut().zip(row_other.iter()) {
            *a += *b;
        }
    }
    Ok(())
}

/// Inner product of two ring-element vectors, truncated to the shorter
/// of the two lengths.
fn dot_product<F: FieldCore, const D: usize>(
    lhs: &[CyclotomicRing<F, D>],
    rhs: &[CyclotomicRing<F, D>],
) -> CyclotomicRing<F, D> {
    let mut
acc = CyclotomicRing::::zero(); + let len = lhs.len().min(rhs.len()); + for i in 0..len { + acc += lhs[i] * rhs[i]; + } + acc +} + +#[allow(clippy::type_complexity)] +fn aggregate_statement_constraints( + constraints: &[LabradorConstraint], + row_lengths: &[usize], + transcript: &mut T, +) -> Result<(Vec>>, CyclotomicRing), HachiError> +where + F: FieldCore + CanonicalField + FromSmallInt, + T: Transcript, +{ + let mut phi_total: Vec>> = row_lengths + .iter() + .map(|&len| vec![CyclotomicRing::zero(); len]) + .collect(); + let mut b_total = CyclotomicRing::::zero(); + + if constraints.is_empty() { + return Ok((phi_total, b_total)); + } + + for cnst in constraints { + let outputs = cnst.target.len().max(1); + for out_idx in 0..outputs { + let alpha = challenge_ring_element_rejection_sampled( + transcript, + labels::CHALLENGE_LABRADOR_AGGREGATION, + )?; + let target = cnst + .target + .get(out_idx) + .copied() + .unwrap_or_else(CyclotomicRing::::zero); + b_total += alpha * target; + + for (row_idx, coeffs) in cnst.coefficients.iter().enumerate() { + if coeffs.is_empty() { + continue; + } + if row_idx >= phi_total.len() { + return Err(HachiError::InvalidProof); + } + let row_len = coeffs.len() / outputs; + let coeff_start = out_idx * row_len; + let coeff_slice = &coeffs[coeff_start..coeff_start + row_len]; + for (j, coeff) in coeff_slice.iter().enumerate() { + phi_total[row_idx][j] += alpha * *coeff; + } + } + } + } + + Ok((phi_total, b_total)) +} + +fn sample_jl_collapse_challenge(transcript: &mut T) -> [i64; 256] +where + F: FieldCore + CanonicalField, + T: Transcript, +{ + let q = (-F::one()).to_canonical_u128() + 1; + let half_q = q / 2; + std::array::from_fn(|_| { + let s = transcript.challenge_scalar(labels::CHALLENGE_LABRADOR_JL_COLLAPSE); + let c = s.to_canonical_u128(); + if c > half_q { + -((q - c) as i64) + } else { + c as i64 + } + }) +} + +fn jl_collapse_phi_from_weights( + matrix: &LabradorJlMatrix, + omega: &[i64; 256], +) -> Result>, HachiError> { + if 
matrix.cols % D != 0 { + return Err(HachiError::InvalidProof); + } + let mut weights = vec![0i64; matrix.cols]; + for (row_idx, row) in matrix.signs.iter().enumerate() { + let alpha = omega[row_idx]; + for (col_idx, &sign) in row.iter().enumerate() { + weights[col_idx] += alpha * (sign as i64); + } + } + + let ring_elems = matrix.cols / D; + let mut phi = Vec::with_capacity(ring_elems); + for idx in 0..ring_elems { + let coeffs = std::array::from_fn(|k| { + let w = weights[idx * D + k]; + F::from_i64(w) + }); + phi.push(CyclotomicRing::from_coefficients(coeffs).sigma_m1()); + } + Ok(phi) +} + +#[allow(clippy::type_complexity)] +fn aggregate_jl_constraints_verifier( + row_lengths: &[usize], + jl_projection: &[i32; 256], + jl_seed: &[u8; 16], + jl_nonce: u64, + bb: &[CyclotomicRing], + backend: MatrixPrgBackendChoice, + transcript: &mut T, +) -> Result<(Vec>>, CyclotomicRing), HachiError> +where + F: FieldCore + CanonicalField + FromSmallInt, + T: Transcript, +{ + if bb.len() != JL_LIFTS { + return Err(HachiError::InvalidProof); + } + let total_len: usize = row_lengths.iter().sum(); + let cols = total_len.checked_mul(D).ok_or(HachiError::InvalidProof)?; + if cols == 0 { + return Err(HachiError::InvalidProof); + } + let mut ranges = Vec::with_capacity(row_lengths.len()); + let mut cursor = 0usize; + for &len in row_lengths { + let start = cursor; + cursor += len; + ranges.push((start, cursor)); + } + + let matrix = LabradorJlMatrix::generate(jl_seed, jl_nonce, cols, backend)?; + + let mut phi_total: Vec>> = row_lengths + .iter() + .map(|&len| vec![CyclotomicRing::zero(); len]) + .collect(); + let mut b_total = CyclotomicRing::::zero(); + + for bb_lift in bb.iter() { + let omega = sample_jl_collapse_challenge::(transcript); + let phi_flat = jl_collapse_phi_from_weights::(&matrix, &omega)?; + let target = collapse(jl_projection, &omega); + let b_full = restore_constant_term(*bb_lift, F::from_i64(target)); + transcript.append_serde(labels::ABSORB_LABRADOR_BB, bb_lift); + 
let beta = challenge_ring_element_rejection_sampled( + transcript, + labels::CHALLENGE_LABRADOR_AGGREGATION, + )?; + b_total += beta * b_full; + for (row_idx, (start, end)) in ranges.iter().enumerate() { + let row = &phi_flat[*start..*end]; + for (j, elem) in row.iter().enumerate() { + phi_total[row_idx][j] += beta * *elem; + } + } + } + + Ok((phi_total, b_total)) +} + +fn verify_constraints( + constraints: &[LabradorConstraint], + witness: &LabradorWitness, +) -> Result<(), HachiError> { + for (idx, cnst) in constraints.iter().enumerate() { + let outputs = cnst.target.len().max(1); + let mut lhs = vec![CyclotomicRing::::zero(); outputs]; + + for (row_idx, coeffs) in cnst.coefficients.iter().enumerate() { + if coeffs.is_empty() { + continue; + } + if row_idx >= witness.rows().len() { + return Err(HachiError::InvalidProof); + } + let row = &witness.rows()[row_idx]; + let row_len = coeffs.len() / outputs; + for (out_idx, lhs_elem) in lhs.iter_mut().enumerate() { + let coeff_start = out_idx * row_len; + let coeff_slice = &coeffs[coeff_start..coeff_start + row_len]; + let mut inner = CyclotomicRing::::zero(); + for (j, coeff) in coeff_slice.iter().enumerate() { + let w_elem = row + .get(j) + .copied() + .unwrap_or_else(CyclotomicRing::::zero); + inner += *coeff * w_elem; + } + *lhs_elem += inner; + } + } + + for (out_idx, lhs_elem) in lhs.iter().enumerate() { + let target = cnst + .target + .get(out_idx) + .copied() + .unwrap_or_else(CyclotomicRing::::zero); + if *lhs_elem != target { + return Err(HachiError::InvalidInput(format!( + "Labrador constraint {idx} not satisfied" + ))); + } + } + } + Ok(()) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::algebra::fields::Fp64; + use crate::algebra::ring::CyclotomicRing; + use crate::protocol::labrador::types::LabradorConstraint; + use crate::protocol::transcript::labels::DOMAIN_LABRADOR_PROTOCOL; + use crate::protocol::transcript::Blake2bTranscript; + use crate::FromSmallInt; + + type F = Fp64<4294967197>; + 
const D: usize = 64; + + #[test] + fn verify_accepts_basic_linear_constraint() { + let row = vec![CyclotomicRing::::from_coefficients( + std::array::from_fn(|i| if i == 0 { F::from_i64(3) } else { F::zero() }), + )]; + let witness = LabradorWitness::new(vec![row.clone()]); + let coeff = vec![CyclotomicRing::one()]; + let target = vec![CyclotomicRing::::from_coefficients( + std::array::from_fn(|i| if i == 0 { F::from_i64(3) } else { F::zero() }), + )]; + let constraint = LabradorConstraint { + coefficients: vec![coeff], + target, + }; + let statement = LabradorStatement { + u1: Vec::new(), + u2: Vec::new(), + challenges: Vec::new(), + constraints: vec![constraint], + beta_sq: 1000, + hash: [0u8; 16], + }; + let proof = LabradorProof { + levels: Vec::new(), + final_opening_witness: witness.clone(), + }; + let mut transcript = Blake2bTranscript::::new(DOMAIN_LABRADOR_PROTOCOL); + let out = verify( + &statement, + &proof, + &[1u8; 32], + &[2u8; 16], + MatrixPrgBackendChoice::Shake256, + &mut transcript, + ) + .unwrap(); + assert_eq!(out.final_opening_witness, witness); + } +} diff --git a/src/protocol/mod.rs b/src/protocol/mod.rs new file mode 100644 index 00000000..d4081280 --- /dev/null +++ b/src/protocol/mod.rs @@ -0,0 +1,40 @@ +//! Protocol-layer transcript and commitment abstractions. +//! +//! This module defines the Hachi-native protocol interfaces used by higher-level +//! proof logic. It intentionally stays independent from external integration +//! details (for example, Jolt wiring). 
+ +pub mod challenges; +pub mod commitment; +pub mod commitment_scheme; +pub mod dispatch; +pub mod greyhound; +pub mod hachi_poly_ops; +pub mod labrador; +pub mod opening_point; +pub mod prg; +pub mod proof; +pub mod quadratic_equation; +pub mod ring_switch; +pub mod sumcheck; +pub mod transcript; + +pub use commitment::{ + optimal_m_r_split, AppendToTranscript, CommitmentConfig, CommitmentScheme, DummyProof, + DynamicSmallTestCommitmentConfig, Fp128BoundedCommitmentConfig, Fp128CommitmentConfig, + Fp128FullCommitmentConfig, Fp128HalvingDCommitmentConfig, Fp128LogBasisCommitmentConfig, + Fp128OneHotCommitmentConfig, HachiCommitment, HachiCommitmentCore, HachiCommitmentLayout, + HachiExpandedSetup, HachiOpeningClaim, HachiOpeningPoint, HachiProverSetup, HachiSetupSeed, + HachiVerifierSetup, RingCommitment, RingCommitmentScheme, SmallTestCommitmentConfig, +}; +pub use commitment_scheme::HachiCommitmentScheme; +pub use hachi_poly_ops::{DensePoly, HachiPolyOps, OneHotIndex, OneHotPoly}; +pub use opening_point::{BasisMode, RingOpeningPoint}; +pub use proof::{FlatCommitmentHint, FlatRingVec, HachiLevelProof, HachiProof, PackedDigits}; +pub use quadratic_equation::QuadraticEquation; +pub use sumcheck::batched_sumcheck::{prove_batched_sumcheck, verify_batched_sumcheck}; +pub use sumcheck::{ + prove_sumcheck, verify_sumcheck, CompressedUniPoly, SumcheckInstanceProver, + SumcheckInstanceVerifier, SumcheckProof, UniPoly, +}; +pub use transcript::{sample_ext_challenge, Blake2bTranscript, KeccakTranscript, Transcript}; diff --git a/src/protocol/opening_point.rs b/src/protocol/opening_point.rs new file mode 100644 index 00000000..4d80563d --- /dev/null +++ b/src/protocol/opening_point.rs @@ -0,0 +1,41 @@ +//! Ring-native opening point for the Hachi protocol. + +use crate::FieldCore; + +/// Polynomial basis mode for the evaluation relation. +/// +/// Determines how the polynomial's values are interpreted during an opening +/// proof. 
The commitment itself is basis-agnostic; the basis only affects +/// the tensor-product weights used in `prove` and `verify`. +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub enum BasisMode { + /// Evaluations over the boolean hypercube. + /// + /// The weight vector is `⊗ᵢ (1 − xᵢ, xᵢ)` (multilinear Lagrange basis). + /// Use when the committed values are `f(b)` for `b ∈ {0,1}^n`. + Lagrange, + + /// Coefficients of multilinear monomials. + /// + /// The weight vector is `⊗ᵢ (1, xᵢ)`. + /// Use when the committed values are the coefficients `c_S` such that + /// `f(x) = Σ_S c_S · ∏_{i ∈ S} x_i`. + Monomial, +} + +/// Ring-native opening point storing field scalars. +/// +/// Contains the two vectors used by the §4.2 prover: +/// - `a`: evaluation vector of length `2^m` (inner-block coordinates). +/// - `b`: block-select vector of length `2^r` (outer coordinates). +/// +/// These are raw field scalars, not ring elements — they originate from +/// basis weight evaluations (Lagrange or monomial) and are always constant +/// (scalar) ring elements when embedded into the ring. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct RingOpeningPoint { + /// Evaluation vector of length `2^m` (field scalars). + pub a: Vec, + /// Block-select vector of length `2^r` (field scalars). + pub b: Vec, +} diff --git a/src/protocol/prg.rs b/src/protocol/prg.rs new file mode 100644 index 00000000..f1eb6581 --- /dev/null +++ b/src/protocol/prg.rs @@ -0,0 +1,361 @@ +//! Matrix PRG backends shared by commitment/JL derivation. +//! +//! The PRG is keyed per matrix entry using domain-separated context bytes. 
+ +use aes::Aes128; +use ctr::cipher::{KeyIvInit, StreamCipher}; +use rand_core::{CryptoRng, RngCore}; +use sha3::digest::{ExtendableOutput, Update, XofReader}; +use sha3::Shake256; + +const MATRIX_PRG_DOMAIN: &[u8] = b"hachi/matrix-prg"; +const MATRIX_PRG_SHAKE_DOMAIN: &[u8] = b"hachi/matrix-prg/shake256"; +const MATRIX_PRG_AES_DOMAIN: &[u8] = b"hachi/matrix-prg/aes128-ctr"; + +/// Stable backend identifiers for transcript/context binding. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +#[repr(u8)] +pub enum MatrixPrgBackendId { + /// SHAKE256 XOF backend. + Shake256 = 0, + /// AES-128-CTR backend. + Aes128Ctr = 1, +} + +impl TryFrom for MatrixPrgBackendId { + type Error = crate::error::HachiError; + + fn try_from(value: u8) -> Result { + match value { + 0 => Ok(Self::Shake256), + 1 => Ok(Self::Aes128Ctr), + _ => Err(crate::error::HachiError::InvalidInput(format!( + "unknown matrix PRG backend id: {value}" + ))), + } + } +} + +impl From for u8 { + fn from(value: MatrixPrgBackendId) -> Self { + value as u8 + } +} + +/// Input context used for deterministic matrix-entry sampling. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct MatrixPrgContext<'a> { + /// Public seed. + pub seed: &'a [u8; 32], + /// Matrix label (`A`, `B`, `D`, etc.). + pub matrix_label: &'a [u8], + /// Matrix row count. + pub rows: usize, + /// Matrix column count. + pub cols: usize, + /// Matrix-entry row index. + pub row: usize, + /// Matrix-entry column index. + pub col: usize, +} + +/// Backend trait for matrix-entry PRG streams. +pub trait MatrixPrgBackend: Clone + Send + Sync + 'static { + /// Stable backend identifier. + fn backend_id(&self) -> MatrixPrgBackendId; + /// Construct a stream RNG for one matrix entry. + fn entry_rng(&self, context: &MatrixPrgContext<'_>) -> MatrixPrgRng; +} + +/// Runtime backend selector. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum MatrixPrgBackendChoice { + /// SHAKE256 XOF stream. + Shake256, + /// AES-128-CTR stream. 
+ Aes128Ctr, +} + +impl MatrixPrgBackendChoice { + /// Return the stable backend id. + pub fn backend_id(self) -> MatrixPrgBackendId { + match self { + Self::Shake256 => MatrixPrgBackendId::Shake256, + Self::Aes128Ctr => MatrixPrgBackendId::Aes128Ctr, + } + } + + /// Construct a stream RNG for one matrix entry. + pub fn entry_rng(self, context: &MatrixPrgContext<'_>) -> MatrixPrgRng { + match self { + Self::Shake256 => Shake256Backend.entry_rng(context), + Self::Aes128Ctr => Aes128CtrBackend.entry_rng(context), + } + } +} + +impl Default for MatrixPrgBackendChoice { + fn default() -> Self { + Self::Shake256 + } +} + +/// SHAKE256 backend implementation. +#[derive(Debug, Clone, Copy, Default)] +pub struct Shake256Backend; + +impl MatrixPrgBackend for Shake256Backend { + fn backend_id(&self) -> MatrixPrgBackendId { + MatrixPrgBackendId::Shake256 + } + + fn entry_rng(&self, context: &MatrixPrgContext<'_>) -> MatrixPrgRng { + MatrixPrgRng::Shake(ShakeEntryRng::new(context)) + } +} + +/// AES-128-CTR backend implementation. +#[derive(Debug, Clone, Copy, Default)] +pub struct Aes128CtrBackend; + +impl MatrixPrgBackend for Aes128CtrBackend { + fn backend_id(&self) -> MatrixPrgBackendId { + MatrixPrgBackendId::Aes128Ctr + } + + fn entry_rng(&self, context: &MatrixPrgContext<'_>) -> MatrixPrgRng { + let (key, iv) = derive_aes_key_iv(context); + // On aarch64, the `aes` crate uses target-feature intrinsics when + // available; we still gate this branch for explicit architecture intent. + #[cfg(target_arch = "aarch64")] + { + if std::arch::is_aarch64_feature_detected!("aes") { + return MatrixPrgRng::AesCtr(Aes128CtrEntryRng::new(&key, &iv)); + } + } + // TODO(x86_64): add explicit AES-NI runtime path selection once CI has + // dedicated hardware coverage. Today we use the `aes` crate default. 
+ #[cfg(target_arch = "x86_64")] + { + let _ = std::arch::is_x86_feature_detected!("aes"); + } + MatrixPrgRng::AesCtr(Aes128CtrEntryRng::new(&key, &iv)) + } +} + +/// Matrix-entry RNG wrapper over supported PRG backends. +#[allow(clippy::large_enum_variant)] +pub enum MatrixPrgRng { + /// SHAKE256 XOF-backed RNG. + Shake(ShakeEntryRng), + /// AES-128-CTR-backed RNG. + AesCtr(Aes128CtrEntryRng), +} + +impl RngCore for MatrixPrgRng { + fn next_u32(&mut self) -> u32 { + let mut buf = [0u8; 4]; + self.fill_bytes(&mut buf); + u32::from_le_bytes(buf) + } + + fn next_u64(&mut self) -> u64 { + let mut buf = [0u8; 8]; + self.fill_bytes(&mut buf); + u64::from_le_bytes(buf) + } + + fn fill_bytes(&mut self, dest: &mut [u8]) { + match self { + Self::Shake(rng) => rng.fill_bytes(dest), + Self::AesCtr(rng) => rng.fill_bytes(dest), + } + } + + fn try_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), rand_core::Error> { + self.fill_bytes(dest); + Ok(()) + } +} + +impl CryptoRng for MatrixPrgRng {} + +/// SHAKE256-backed matrix-entry RNG. +pub struct ShakeEntryRng { + reader: Box, +} + +impl ShakeEntryRng { + fn new(context: &MatrixPrgContext<'_>) -> Self { + let mut xof = Shake256::default(); + absorb_matrix_context(&mut xof, MATRIX_PRG_SHAKE_DOMAIN, context); + Self { + reader: Box::new(xof.finalize_xof()), + } + } +} + +impl RngCore for ShakeEntryRng { + fn next_u32(&mut self) -> u32 { + let mut buf = [0u8; 4]; + self.fill_bytes(&mut buf); + u32::from_le_bytes(buf) + } + + fn next_u64(&mut self) -> u64 { + let mut buf = [0u8; 8]; + self.fill_bytes(&mut buf); + u64::from_le_bytes(buf) + } + + fn fill_bytes(&mut self, dest: &mut [u8]) { + self.reader.read(dest); + } + + fn try_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), rand_core::Error> { + self.fill_bytes(dest); + Ok(()) + } +} + +impl CryptoRng for ShakeEntryRng {} + +type AesCtrCipher = ctr::Ctr128BE; + +/// AES-128-CTR-backed matrix-entry RNG. 
+pub struct Aes128CtrEntryRng { + cipher: AesCtrCipher, +} + +impl Aes128CtrEntryRng { + fn new(key: &[u8; 16], iv: &[u8; 16]) -> Self { + Self { + cipher: AesCtrCipher::new(key.into(), iv.into()), + } + } +} + +impl RngCore for Aes128CtrEntryRng { + fn next_u32(&mut self) -> u32 { + let mut buf = [0u8; 4]; + self.fill_bytes(&mut buf); + u32::from_le_bytes(buf) + } + + fn next_u64(&mut self) -> u64 { + let mut buf = [0u8; 8]; + self.fill_bytes(&mut buf); + u64::from_le_bytes(buf) + } + + fn fill_bytes(&mut self, dest: &mut [u8]) { + dest.fill(0u8); + self.cipher.apply_keystream(dest); + } + + fn try_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), rand_core::Error> { + self.fill_bytes(dest); + Ok(()) + } +} + +impl CryptoRng for Aes128CtrEntryRng {} + +fn derive_aes_key_iv(context: &MatrixPrgContext<'_>) -> ([u8; 16], [u8; 16]) { + let mut xof = Shake256::default(); + absorb_matrix_context(&mut xof, MATRIX_PRG_AES_DOMAIN, context); + let mut out = [0u8; 32]; + xof.finalize_xof().read(&mut out); + let key: [u8; 16] = out[..16].try_into().expect("XOF produced 32 bytes"); + let iv: [u8; 16] = out[16..].try_into().expect("XOF produced 32 bytes"); + (key, iv) +} + +fn absorb_matrix_context( + xof: &mut Shake256, + backend_domain: &[u8], + context: &MatrixPrgContext<'_>, +) { + absorb_len_prefixed(xof, b"domain", MATRIX_PRG_DOMAIN); + absorb_len_prefixed(xof, b"backend", backend_domain); + absorb_len_prefixed(xof, b"seed", context.seed); + absorb_len_prefixed(xof, b"matrix", context.matrix_label); + absorb_len_prefixed(xof, b"rows", &(context.rows as u64).to_le_bytes()); + absorb_len_prefixed(xof, b"cols", &(context.cols as u64).to_le_bytes()); + absorb_len_prefixed(xof, b"row", &(context.row as u64).to_le_bytes()); + absorb_len_prefixed(xof, b"col", &(context.col as u64).to_le_bytes()); +} + +fn absorb_len_prefixed(xof: &mut Shake256, label: &[u8], data: &[u8]) { + xof.update(&(label.len() as u64).to_le_bytes()); + xof.update(label); + xof.update(&(data.len() as 
u64).to_le_bytes()); + xof.update(data); +} + +#[cfg(test)] +mod tests { + use super::*; + + fn context<'a>(seed: &'a [u8; 32], row: usize, col: usize) -> MatrixPrgContext<'a> { + MatrixPrgContext { + seed, + matrix_label: b"A", + rows: 4, + cols: 5, + row, + col, + } + } + + #[test] + fn shake_backend_is_deterministic() { + let seed = [42u8; 32]; + let ctx = context(&seed, 1, 3); + let mut rng1 = Shake256Backend.entry_rng(&ctx); + let mut rng2 = Shake256Backend.entry_rng(&ctx); + let mut a = [0u8; 96]; + let mut b = [0u8; 96]; + rng1.fill_bytes(&mut a); + rng2.fill_bytes(&mut b); + assert_eq!(a, b); + } + + #[test] + fn aes_backend_is_deterministic() { + let seed = [7u8; 32]; + let ctx = context(&seed, 0, 2); + let mut rng1 = Aes128CtrBackend.entry_rng(&ctx); + let mut rng2 = Aes128CtrBackend.entry_rng(&ctx); + let mut a = [0u8; 96]; + let mut b = [0u8; 96]; + rng1.fill_bytes(&mut a); + rng2.fill_bytes(&mut b); + assert_eq!(a, b); + } + + #[test] + fn row_col_changes_separate_streams() { + let seed = [9u8; 32]; + let mut rng_a = Shake256Backend.entry_rng(&context(&seed, 0, 0)); + let mut rng_b = Shake256Backend.entry_rng(&context(&seed, 0, 1)); + let mut a = [0u8; 64]; + let mut b = [0u8; 64]; + rng_a.fill_bytes(&mut a); + rng_b.fill_bytes(&mut b); + assert_ne!(a, b); + } + + #[test] + fn backend_choice_changes_stream() { + let seed = [5u8; 32]; + let ctx = context(&seed, 2, 4); + let mut shake = MatrixPrgBackendChoice::Shake256.entry_rng(&ctx); + let mut aes = MatrixPrgBackendChoice::Aes128Ctr.entry_rng(&ctx); + let mut a = [0u8; 64]; + let mut b = [0u8; 64]; + shake.fill_bytes(&mut a); + aes.fill_bytes(&mut b); + assert_ne!(a, b); + } +} diff --git a/src/protocol/proof.rs b/src/protocol/proof.rs new file mode 100644 index 00000000..ea085175 --- /dev/null +++ b/src/protocol/proof.rs @@ -0,0 +1,677 @@ +//! Proof structures for the Hachi protocol. 
+ +use crate::algebra::CyclotomicRing; +use crate::primitives::serialization::{Compress, SerializationError}; +use crate::primitives::serialization::{Valid, Validate}; +use crate::protocol::commitment::RingCommitment; +use crate::protocol::sumcheck::SumcheckProof; +use crate::{FieldCore, FromSmallInt, HachiDeserialize, HachiSerialize}; +use std::io::{Read, Write}; +use std::marker::PhantomData; + +/// Bit-packed balanced digits for the final-level witness vector. +/// +/// Each element is a signed value in `[-b/2, b/2)` where `b = 2^bits_per_elem`, +/// stored in two's-complement using exactly `bits_per_elem` bits per value. +/// This reduces proof size by ~32x compared to storing full field elements. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct PackedDigits { + /// Number of logical elements. + pub num_elems: usize, + /// Bits per element (= `log_basis` from the commitment config). + pub bits_per_elem: u32, + /// Bit-packed two's-complement data. + pub data: Vec, +} + +/// Precomputed lookup table mapping balanced digit index → field element. +/// +/// Wraps `FromSmallInt::digit_lut` with convenient signed-digit indexing. +/// Index a digit `d ∈ [-b/2, b/2)` via [`get`](DigitLut::get). +pub(crate) struct DigitLut { + table: [F; 16], + half_b: i8, +} + +impl DigitLut { + #[inline] + pub(crate) fn new(log_basis: u32) -> Self { + let half_b = 1i8 << (log_basis - 1); + Self { + table: F::digit_lut(log_basis), + half_b, + } + } + + #[inline(always)] + pub(crate) fn get(&self, d: i8) -> F { + self.table[(d + self.half_b) as usize] + } +} + +impl PackedDigits { + /// Pack balanced i8 digits into bit-packed form. + /// + /// Each element must be in `[-b/2, b/2)` where `b = 2^log_basis`. + /// + /// # Panics + /// + /// Panics (in debug) if any element does not fit in `log_basis` bits. 
+ pub fn from_i8_digits(w: &[i8], log_basis: u32) -> Self { + assert!(log_basis > 0 && log_basis <= 7, "log_basis out of range"); + let half_b = 1i8 << (log_basis - 1); + + let bits = log_basis as usize; + let total_bits = w.len() * bits; + let num_bytes = total_bits.div_ceil(8); + let mut data = vec![0u8; num_bytes]; + + for (i, &signed) in w.iter().enumerate() { + debug_assert!( + signed >= -half_b && signed < half_b, + "digit {signed} out of range for log_basis={log_basis}" + ); + let unsigned = (signed as u8) & ((1u8 << bits) - 1); + let bit_offset = i * bits; + let byte_idx = bit_offset / 8; + let bit_idx = bit_offset % 8; + data[byte_idx] |= unsigned << bit_idx; + if bit_idx + bits > 8 { + data[byte_idx + 1] |= unsigned >> (8 - bit_idx); + } + } + + Self { + num_elems: w.len(), + bits_per_elem: log_basis, + data, + } + } + + /// Unpack to field elements using a precomputed lookup table. + pub fn to_field_elems(&self) -> Vec { + let bits = self.bits_per_elem as usize; + let mask = (1u8 << bits) - 1; + let sign_bit = 1u8 << (bits - 1); + let lut = DigitLut::::new(self.bits_per_elem); + + let mut out = Vec::with_capacity(self.num_elems); + for i in 0..self.num_elems { + let bit_offset = i * bits; + let byte_idx = bit_offset / 8; + let bit_idx = bit_offset % 8; + let mut raw = (self.data[byte_idx] >> bit_idx) & mask; + if bit_idx + bits > 8 { + raw |= (self.data[byte_idx + 1] << (8 - bit_idx)) & mask; + } + let signed = if raw & sign_bit != 0 { + raw as i8 | !(mask as i8) + } else { + raw as i8 + }; + out.push(lut.get(signed)); + } + out + } + + /// Number of packed data bytes. 
+ pub fn packed_byte_len(&self) -> usize { + self.data.len() + } +} + +impl HachiSerialize for PackedDigits { + fn serialize_with_mode( + &self, + mut writer: W, + compress: Compress, + ) -> Result<(), SerializationError> { + (self.num_elems as u64).serialize_with_mode(&mut writer, compress)?; + (self.bits_per_elem as u8).serialize_with_mode(&mut writer, compress)?; + writer.write_all(&self.data)?; + Ok(()) + } + + fn serialized_size(&self, _compress: Compress) -> usize { + 8 + 1 + self.data.len() + } +} + +impl Valid for PackedDigits { + fn check(&self) -> Result<(), SerializationError> { + if self.bits_per_elem == 0 || self.bits_per_elem > 7 { + return Err(SerializationError::InvalidData( + "bits_per_elem out of range".to_string(), + )); + } + let expected_bytes = (self.num_elems * self.bits_per_elem as usize).div_ceil(8); + if self.data.len() != expected_bytes { + return Err(SerializationError::InvalidData( + "packed data length mismatch".to_string(), + )); + } + Ok(()) + } +} + +impl HachiDeserialize for PackedDigits { + fn deserialize_with_mode( + mut reader: R, + compress: Compress, + validate: Validate, + ) -> Result { + let num_elems = u64::deserialize_with_mode(&mut reader, compress, validate)? as usize; + let bits_per_elem = u8::deserialize_with_mode(&mut reader, compress, validate)? as u32; + let num_bytes = (num_elems * bits_per_elem as usize).div_ceil(8); + let mut data = vec![0u8; num_bytes]; + reader.read_exact(&mut data)?; + let out = Self { + num_elems, + bits_per_elem, + data, + }; + if matches!(validate, Validate::Yes) { + out.check()?; + } + Ok(out) + } +} + +/// D-erased storage for a sequence of ring elements as raw field-element +/// coefficients. +/// +/// Each ring element of dimension `ring_dim` is stored as `ring_dim` +/// contiguous field elements in `coeffs`. The total number of ring elements +/// is `coeffs.len() / ring_dim`. 
+#[derive(Debug, Clone, PartialEq, Eq)] +pub struct FlatRingVec { + coeffs: Vec, + ring_dim: usize, +} + +impl FlatRingVec { + /// Wrap a single ring element. + pub fn from_single(r: &CyclotomicRing) -> Self { + Self { + coeffs: r.coefficients().to_vec(), + ring_dim: D, + } + } + + /// Wrap a slice of ring elements. + pub fn from_ring_elems(elems: &[CyclotomicRing]) -> Self { + let mut coeffs = Vec::with_capacity(elems.len() * D); + for e in elems { + coeffs.extend_from_slice(e.coefficients()); + } + Self { + coeffs, + ring_dim: D, + } + } + + /// Wrap a `RingCommitment`. + pub fn from_commitment(c: &RingCommitment) -> Self { + Self::from_ring_elems(&c.u) + } + + /// Ring dimension (number of field-element coefficients per ring element). + pub fn ring_dim(&self) -> usize { + self.ring_dim + } + + /// Number of ring elements stored. + pub fn count(&self) -> usize { + if self.ring_dim == 0 { + 0 + } else { + self.coeffs.len() / self.ring_dim + } + } + + /// Raw coefficient slice. + pub fn coeffs(&self) -> &[F] { + &self.coeffs + } + + /// Reconstruct a single ring element. + /// + /// # Panics + /// + /// Panics if `D != ring_dim` or `count() != 1`. + pub fn to_single(&self) -> CyclotomicRing { + assert_eq!(D, self.ring_dim, "D mismatch in to_single"); + assert_eq!(self.count(), 1, "expected exactly one ring element"); + CyclotomicRing::from_slice(&self.coeffs) + } + + /// Reconstruct a vector of ring elements. + /// + /// # Panics + /// + /// Panics if `D != ring_dim`. + pub fn to_vec(&self) -> Vec> { + assert_eq!(D, self.ring_dim, "D mismatch in to_vec"); + self.coeffs + .chunks_exact(D) + .map(CyclotomicRing::from_slice) + .collect() + } + + /// Reconstruct a `RingCommitment`. + /// + /// # Panics + /// + /// Panics if `D != ring_dim`. 
+ pub fn to_ring_commitment(&self) -> RingCommitment { + RingCommitment { u: self.to_vec() } + } +} + +impl HachiSerialize for FlatRingVec { + fn serialize_with_mode( + &self, + mut writer: W, + compress: Compress, + ) -> Result<(), SerializationError> { + (self.ring_dim as u32).serialize_with_mode(&mut writer, compress)?; + self.coeffs.serialize_with_mode(&mut writer, compress) + } + + fn serialized_size(&self, compress: Compress) -> usize { + 4 + self.coeffs.serialized_size(compress) + } +} + +impl Valid for FlatRingVec { + fn check(&self) -> Result<(), SerializationError> { + if self.ring_dim == 0 { + return Err(SerializationError::InvalidData( + "ring_dim must be > 0".to_string(), + )); + } + if self.coeffs.len() % self.ring_dim != 0 { + return Err(SerializationError::InvalidData( + "coeffs length not a multiple of ring_dim".to_string(), + )); + } + Ok(()) + } +} + +impl HachiDeserialize for FlatRingVec { + fn deserialize_with_mode( + mut reader: R, + compress: Compress, + validate: Validate, + ) -> Result { + let ring_dim = u32::deserialize_with_mode(&mut reader, compress, validate)? as usize; + let coeffs = Vec::::deserialize_with_mode(&mut reader, compress, validate)?; + let out = Self { coeffs, ring_dim }; + if matches!(validate, Validate::Yes) { + out.check()?; + } + Ok(out) + } +} + +/// D-erased commitment hint for cross-level storage. +/// +/// Stores the decomposed `t̂_i` blocks as a flat `Vec` with metadata +/// about block sizes and ring dimension. Convert to/from the typed +/// [`HachiCommitmentHint`] via [`from_typed`](Self::from_typed) and +/// [`to_typed`](Self::to_typed). +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct FlatCommitmentHint { + data: Vec, + block_sizes: Vec, + ring_dim: usize, +} + +impl FlatCommitmentHint { + /// Convert from a typed hint, consuming it. 
+ pub fn from_typed(hint: HachiCommitmentHint) -> Self { + let block_sizes: Vec = hint.t_hat.iter().map(|b| b.len()).collect(); + let total_planes: usize = block_sizes.iter().sum(); + let mut data = Vec::with_capacity(total_planes * D); + for block in &hint.t_hat { + for plane in block { + data.extend_from_slice(plane); + } + } + Self { + data, + block_sizes, + ring_dim: D, + } + } + + /// Reconstruct a typed hint. + /// + /// # Panics + /// + /// Panics if `D != ring_dim`. + pub fn to_typed(&self) -> HachiCommitmentHint { + assert_eq!(D, self.ring_dim, "D mismatch in to_typed"); + let mut t_hat = Vec::with_capacity(self.block_sizes.len()); + let mut offset = 0; + for &block_size in &self.block_sizes { + let mut block = Vec::with_capacity(block_size); + for _ in 0..block_size { + let mut plane = [0i8; D]; + plane.copy_from_slice(&self.data[offset..offset + D]); + offset += D; + block.push(plane); + } + t_hat.push(block); + } + HachiCommitmentHint::new(t_hat) + } + + /// Ring dimension stored in this hint. + pub fn ring_dim(&self) -> usize { + self.ring_dim + } + + /// Empty hint (verifier side, where hint data is not available). + pub fn empty() -> Self { + Self { + data: Vec::new(), + block_sizes: Vec::new(), + ring_dim: 0, + } + } +} + +/// Prover-side hint produced at commitment time. +/// +/// Contains the decomposed inner-Ajtai outputs `t̂_i` needed by the +/// ring-switch step of the prover. The polynomial itself (ring coefficients) +/// is passed separately to `prove` via `HachiPolyOps`. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct HachiCommitmentHint { + /// Decomposed `t̂_i` blocks from the commitment phase as i8 digit planes. + pub t_hat: Vec>, + _marker: PhantomData, +} + +impl HachiCommitmentHint { + /// Construct a new hint from i8 digit plane blocks. + pub fn new(t_hat: Vec>) -> Self { + Self { + t_hat, + _marker: PhantomData, + } + } +} + +/// Proof for a single fold level (quad_eq + ring_switch + sumcheck). 
+/// +/// D-agnostic: ring elements are stored as [`FlatRingVec`] with their +/// ring dimension recorded. Use [`y_ring_typed`](Self::y_ring_typed), +/// [`v_typed`](Self::v_typed), and +/// [`w_commitment_typed`](Self::w_commitment_typed) to reconstruct +/// typed ring elements. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct HachiLevelProof { + /// `y_ring` from the §3.1 reduction (ring dim = current level's D). + pub y_ring: FlatRingVec, + /// `v = D · ŵ` (ring dim = current level's D). + pub v: FlatRingVec, + /// Batched sumcheck proof (F_0 norm + F_α relation, §4.3). + pub sumcheck_proof: SumcheckProof, + /// Commitment to the sumcheck witness `w` + /// (ring dim = next level's D, may differ from y_ring/v). + pub w_commitment: FlatRingVec, + /// Claimed evaluation of w at the sumcheck challenge point. + pub w_eval: F, +} + +impl HachiLevelProof { + /// Construct from typed ring elements for the current level and a + /// pre-erased `FlatRingVec` for the w-commitment (which may be at a + /// different D). + pub fn new( + y_ring: CyclotomicRing, + v: Vec>, + sumcheck_proof: SumcheckProof, + w_commitment: FlatRingVec, + w_eval: F, + ) -> Self { + Self { + y_ring: FlatRingVec::from_single(&y_ring), + v: FlatRingVec::from_ring_elems(&v), + sumcheck_proof, + w_commitment, + w_eval, + } + } + + /// Ring dimension of y_ring and v (current level). + pub fn level_d(&self) -> usize { + self.y_ring.ring_dim() + } + + /// Ring dimension of the w_commitment (next level). + pub fn w_commit_d(&self) -> usize { + self.w_commitment.ring_dim() + } + + /// Reconstruct typed `y_ring`. + /// + /// # Panics + /// + /// Panics if `D` does not match the stored ring dimension. + pub fn y_ring_typed(&self) -> CyclotomicRing { + self.y_ring.to_single() + } + + /// Reconstruct typed `v`. + /// + /// # Panics + /// + /// Panics if `D` does not match the stored ring dimension. + pub fn v_typed(&self) -> Vec> { + self.v.to_vec() + } + + /// Reconstruct typed `w_commitment`. 
+ /// + /// # Panics + /// + /// Panics if `D` does not match the stored ring dimension. + pub fn w_commitment_typed(&self) -> RingCommitment { + self.w_commitment.to_ring_commitment() + } +} + +/// Hachi PCS proof with multi-level folding. +/// +/// Each level runs the full protocol (quadratic equation, ring switch, +/// sumcheck) on the previous level's witness `w`. The final level sends +/// `w` directly for the verifier to check, packed as balanced digits. +/// +/// D-agnostic: per-level ring dimensions are recorded in each +/// [`HachiLevelProof`]. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct HachiProof { + /// Per-level proofs, from the original polynomial (level 0) through + /// recursive w-openings. + pub levels: Vec>, + /// The witness vector at the deepest fold level, bit-packed as balanced + /// digits in `[-b/2, b/2)`. Use [`PackedDigits::to_field_elems`] to + /// reconstruct `Vec`. + pub final_w: PackedDigits, +} + +impl HachiProof { + /// Returns the proof size in bytes (uncompressed). 
+ pub fn size(&self) -> usize { + let levels_size: usize = self + .levels + .iter() + .map(|lp| { + lp.y_ring.serialized_size(Compress::No) + + lp.v.serialized_size(Compress::No) + + lp.sumcheck_proof.serialized_size(Compress::No) + + lp.w_commitment.serialized_size(Compress::No) + + lp.w_eval.serialized_size(Compress::No) + }) + .sum(); + levels_size + self.final_w.serialized_size(Compress::No) + } +} + +impl HachiSerialize for HachiCommitmentHint { + fn serialize_with_mode( + &self, + mut writer: W, + compress: Compress, + ) -> Result<(), SerializationError> { + (self.t_hat.len() as u64).serialize_with_mode(&mut writer, compress)?; + for block in &self.t_hat { + (block.len() as u64).serialize_with_mode(&mut writer, compress)?; + for plane in block { + let bytes: &[u8] = + unsafe { std::slice::from_raw_parts(plane.as_ptr().cast::(), D) }; + writer.write_all(bytes)?; + } + } + Ok(()) + } + fn serialized_size(&self, _compress: Compress) -> usize { + 8 + self + .t_hat + .iter() + .map(|block| 8 + block.len() * D) + .sum::() + } +} + +impl Valid for HachiCommitmentHint { + fn check(&self) -> Result<(), SerializationError> { + Ok(()) + } +} + +impl HachiDeserialize for HachiCommitmentHint { + fn deserialize_with_mode( + mut reader: R, + compress: Compress, + validate: Validate, + ) -> Result { + let num_blocks = u64::deserialize_with_mode(&mut reader, compress, validate)? as usize; + let mut t_hat = Vec::with_capacity(num_blocks); + for _ in 0..num_blocks { + let block_len = u64::deserialize_with_mode(&mut reader, compress, validate)? 
as usize; + let mut block = Vec::with_capacity(block_len); + for _ in 0..block_len { + let mut plane = [0i8; D]; + let bytes: &mut [u8] = + unsafe { std::slice::from_raw_parts_mut(plane.as_mut_ptr().cast::(), D) }; + reader.read_exact(bytes)?; + block.push(plane); + } + t_hat.push(block); + } + Ok(Self::new(t_hat)) + } +} + +impl HachiSerialize for HachiLevelProof { + fn serialize_with_mode( + &self, + mut writer: W, + compress: Compress, + ) -> Result<(), SerializationError> { + self.y_ring.serialize_with_mode(&mut writer, compress)?; + self.v.serialize_with_mode(&mut writer, compress)?; + self.sumcheck_proof + .serialize_with_mode(&mut writer, compress)?; + self.w_commitment + .serialize_with_mode(&mut writer, compress)?; + self.w_eval.serialize_with_mode(&mut writer, compress) + } + fn serialized_size(&self, compress: Compress) -> usize { + self.y_ring.serialized_size(compress) + + self.v.serialized_size(compress) + + self.sumcheck_proof.serialized_size(compress) + + self.w_commitment.serialized_size(compress) + + self.w_eval.serialized_size(compress) + } +} + +impl Valid for HachiLevelProof { + fn check(&self) -> Result<(), SerializationError> { + self.y_ring.check()?; + self.v.check()?; + self.w_commitment.check() + } +} + +impl HachiDeserialize for HachiLevelProof { + fn deserialize_with_mode( + mut reader: R, + compress: Compress, + validate: Validate, + ) -> Result { + Ok(Self { + y_ring: FlatRingVec::deserialize_with_mode(&mut reader, compress, validate)?, + v: FlatRingVec::deserialize_with_mode(&mut reader, compress, validate)?, + sumcheck_proof: SumcheckProof::deserialize_with_mode(&mut reader, compress, validate)?, + w_commitment: FlatRingVec::deserialize_with_mode(&mut reader, compress, validate)?, + w_eval: F::deserialize_with_mode(&mut reader, compress, validate)?, + }) + } +} + +impl HachiSerialize for HachiProof { + fn serialize_with_mode( + &self, + mut writer: W, + compress: Compress, + ) -> Result<(), SerializationError> { + (self.levels.len() 
as u32).serialize_with_mode(&mut writer, compress)?; + for level in &self.levels { + level.serialize_with_mode(&mut writer, compress)?; + } + self.final_w.serialize_with_mode(&mut writer, compress) + } + fn serialized_size(&self, compress: Compress) -> usize { + 4 + self + .levels + .iter() + .map(|l| l.serialized_size(compress)) + .sum::() + + self.final_w.serialized_size(compress) + } +} + +impl Valid for HachiProof { + fn check(&self) -> Result<(), SerializationError> { + for lp in &self.levels { + lp.check()?; + } + self.final_w.check() + } +} + +impl HachiDeserialize for HachiProof { + fn deserialize_with_mode( + mut reader: R, + compress: Compress, + validate: Validate, + ) -> Result { + let num_levels = u32::deserialize_with_mode(&mut reader, compress, validate)? as usize; + let mut levels = Vec::with_capacity(num_levels); + for _ in 0..num_levels { + levels.push(HachiLevelProof::deserialize_with_mode( + &mut reader, + compress, + validate, + )?); + } + let final_w = PackedDigits::deserialize_with_mode(&mut reader, compress, validate)?; + Ok(Self { levels, final_w }) + } +} diff --git a/src/protocol/quadratic_equation.rs b/src/protocol/quadratic_equation.rs new file mode 100644 index 00000000..ff00b01e --- /dev/null +++ b/src/protocol/quadratic_equation.rs @@ -0,0 +1,944 @@ +//! Quadratic equation builder for the Hachi PCS (§4.2). +//! +//! This module encapsulates the stage-1 prover logic and the generation of +//! the quadratic equation components M, y, z, and v. 
+ +use crate::algebra::{CyclotomicRing, SparseChallenge, SparseChallengeConfig}; +#[cfg(any(test, debug_assertions))] +use crate::cfg_into_iter; +use crate::error::HachiError; +#[cfg(all(feature = "parallel", any(test, debug_assertions)))] +use crate::parallel::*; +use crate::protocol::challenges::sparse::sample_sparse_challenges; +use crate::protocol::commitment::utils::crt_ntt::NttSlotCache; +use crate::protocol::commitment::utils::linear::{ + flatten_i8_blocks, mat_vec_mul_ntt_single_i8, unreduced_quotient_rows_ntt_cached, + unreduced_quotient_rows_ntt_cached_i8, +}; +use crate::protocol::commitment::utils::norm::{detect_field_modulus, vec_inf_norm}; +use crate::protocol::commitment::{ + CommitmentConfig, HachiCommitmentLayout, HachiExpandedSetup, RingCommitment, +}; +use crate::protocol::hachi_poly_ops::HachiPolyOps; +use crate::protocol::opening_point::RingOpeningPoint; +use crate::protocol::proof::HachiCommitmentHint; +#[cfg(any(test, debug_assertions))] +use crate::protocol::ring_switch::eval_ring_at; +use crate::protocol::transcript::labels::{ABSORB_PROVER_V, CHALLENGE_STAGE1_FOLD}; +use crate::protocol::transcript::Transcript; +use crate::{CanonicalField, FieldCore}; +use std::iter::repeat_n; +use std::marker::PhantomData; +use std::time::Instant; + +/// **Step 4.** Compute `v = D · ŵ` (first prover message). +fn compute_v( + ntt_d: &NttSlotCache, + w_hat_flat: &[[i8; D]], +) -> Vec> { + mat_vec_mul_ntt_single_i8(ntt_d, w_hat_flat) +} + +fn flatten_w_hat(w_hat: &[Vec<[i8; D]>]) -> Vec<[i8; D]> { + w_hat.iter().flat_map(|v| v.iter().copied()).collect() +} + +/// **Steps 7–9.** Fold `z_pre = Σ c_i · s_i` and check `‖z_pre‖_∞ ≤ β`. +/// +/// Uses `HachiPolyOps::decompose_fold` to carry out the decompose + fold +/// in whatever way the polynomial implementation prefers. 
+fn compute_z_pre( + poly: &P, + challenges: &[SparseChallenge], + layout: HachiCommitmentLayout, +) -> Result>, HachiError> +where + F: FieldCore + CanonicalField, + Cfg: CommitmentConfig, + P: HachiPolyOps, +{ + let z = poly.decompose_fold( + challenges, + layout.block_len, + layout.num_digits_commit, + layout.log_basis, + ); + + let modulus = detect_field_modulus::(); + let norm = vec_inf_norm(&z, modulus); + let beta = Cfg::beta_bound(layout)?; + if norm > beta { + return Err(HachiError::InvalidInput(format!( + "prover abort: ||z||_inf = {norm} > beta = {beta}" + ))); + } + + Ok(z) +} + +/// Stage-1 quadratic equation state for the Hachi protocol. +/// +/// Encapsulates the relation $M(x) \cdot z = y(x) + (X^D + 1) \cdot r(x)$ +/// along with intermediate prover witness data (`w_hat`, `z_pre`, `hint`). +/// +/// M and z are never materialized on the hot path — split-eq factoring computes +/// their products on-the-fly via `compute_r_split_eq`, while debug/test code +/// can reconstruct reference `M_a` rows when needed. +pub struct QuadraticEquation { + /// Stage-1 proof vector `v = D · ŵ`. + pub v: Vec>, + /// Stage-1 folding challenges (sparse representation). + pub challenges: Vec, + /// Vector `y`. + y: Vec>, + /// Opening point (a, b) Lagrange weights. + opening_point: RingOpeningPoint, + /// Pre-decomposition folded witness `z_pre = Σ c_i · s_i` (prover only). + /// Replaces both `z_hat` and `z`: `z_hat = J^{-1}(z_pre)`. + z_pre: Option>>, + /// Decomposed `ŵ_i = G_1^{-1}(w_i)` as i8 digit planes (prover only). + w_hat: Option>>, + /// Flattened `w_hat` as i8 digit planes (prover only, computed once and reused). + w_hat_flat: Option>, + /// Pre-decomposition folded ring elements (prover only, avoids recompose roundtrip). + w_folded: Option>>, + /// Commitment hint (prover only). 
+ hint: Option>, + + _marker: PhantomData, +} + +impl QuadraticEquation +where + F: FieldCore + CanonicalField, + Cfg: CommitmentConfig, +{ + /// Prover constructor: runs §4.2 stage 1 and builds all equation components. + /// + /// `poly` provides the ring-level polynomial data for fold/decompose ops. + /// `hint` carries `t_hat` from the commitment phase. + /// + /// # Errors + /// + /// Returns an error if the norm check, challenge sampling, or matrix + /// generation fails. + #[allow(clippy::too_many_arguments)] + #[tracing::instrument(skip_all, name = "QuadraticEquation::new_prover")] + #[inline(never)] + pub fn new_prover, P: HachiPolyOps>( + ntt_d: &NttSlotCache, + ring_opening_point: RingOpeningPoint, + poly: &P, + pre_folded: Vec>, + hint: HachiCommitmentHint, + transcript: &mut T, + commitment: &RingCommitment, + y_ring: &CyclotomicRing, + layout: HachiCommitmentLayout, + ) -> Result { + { + let x: u8 = 0; + eprintln!( + " [QuadraticEquation::new_prover] stack ~= {:#x}", + &x as *const u8 as usize + ); + } + let t_wh = Instant::now(); + let (w_hat, w_hat_flat) = { + let _span = tracing::info_span!("decompose_w_hat").entered(); + let depth_open = layout.num_digits_open; + let log_basis = layout.log_basis; + let w_hat: Vec> = pre_folded + .iter() + .map(|w_i| w_i.balanced_decompose_pow2_i8(depth_open, log_basis)) + .collect(); + let w_hat_flat = flatten_w_hat(&w_hat); + (w_hat, w_hat_flat) + }; + eprintln!( + " [quad_eq] decompose_w_hat+flatten: {:.2}s (blocks={}, depth={})", + t_wh.elapsed().as_secs_f64(), + w_hat.len(), + w_hat.first().map_or(0, |v| v.len()) + ); + + let t_v = Instant::now(); + let v = { + let _span = tracing::info_span!("compute_v").entered(); + compute_v(ntt_d, &w_hat_flat) + }; + eprintln!( + " [quad_eq] compute_v (D*w_hat): {:.2}s (w_hat_flat_len={})", + t_v.elapsed().as_secs_f64(), + w_hat_flat.len() + ); + + transcript.append_serde(ABSORB_PROVER_V, &v); + + let challenge_cfg = SparseChallengeConfig { + weight: 
Cfg::challenge_weight_for_ring_dim(D), + nonzero_coeffs: vec![-1, 1], + }; + let challenges = sample_sparse_challenges::( + transcript, + CHALLENGE_STAGE1_FOLD, + layout.num_blocks, + &challenge_cfg, + )?; + + let t_zp = Instant::now(); + let z_pre = { + let _span = tracing::info_span!("compute_z_pre").entered(); + compute_z_pre::(poly, &challenges, layout)? + }; + eprintln!( + " [quad_eq] compute_z_pre: {:.2}s (z_pre_len={})", + t_zp.elapsed().as_secs_f64(), + z_pre.len() + ); + + let y = generate_y::(&v, &commitment.u, y_ring, Cfg::N_D, Cfg::N_B, Cfg::N_A)?; + + Ok(Self { + v, + challenges, + y, + opening_point: ring_opening_point, + z_pre: Some(z_pre), + w_hat: Some(w_hat), + w_hat_flat: Some(w_hat_flat), + w_folded: Some(pre_folded), + hint: Some(hint), + _marker: PhantomData, + }) + } + + /// Verifier constructor: Derives challenges and computes M and y. + /// + /// # Errors + /// + /// Returns an error if challenge derivation fails. + #[tracing::instrument(skip_all, name = "QuadraticEquation::new_verifier")] + #[inline(never)] + pub fn new_verifier>( + ring_opening_point: RingOpeningPoint, + v: Vec>, + transcript: &mut T, + commitment: &RingCommitment, + y_ring: &CyclotomicRing, + layout: HachiCommitmentLayout, + ) -> Result { + let challenges = + derive_stage1_challenges::(transcript, &v, layout.num_blocks)?; + let y = generate_y::(&v, &commitment.u, y_ring, Cfg::N_D, Cfg::N_B, Cfg::N_A)?; + + Ok(Self { + v, + challenges, + y, + opening_point: ring_opening_point, + z_pre: None, + w_hat: None, + w_hat_flat: None, + w_folded: None, + hint: None, + _marker: PhantomData, + }) + } + + /// Get the vector y. + pub fn y(&self) -> &[CyclotomicRing] { + &self.y + } + + /// Get the vector v. + pub fn v(&self) -> &[CyclotomicRing] { + &self.v + } + + /// Get the opening point (a, b) Lagrange weights. + pub fn opening_point(&self) -> &RingOpeningPoint { + &self.opening_point + } + + /// Get the pre-decomposition folded witness `z_pre` (prover only). 
+ pub fn z_pre(&self) -> Option<&[CyclotomicRing]> { + self.z_pre.as_deref() + } + + /// Take ownership of `z_pre`, leaving `None` in its place. + pub fn take_z_pre(&mut self) -> Option>> { + self.z_pre.take() + } + + /// Get the decomposed witness `ŵ` as i8 digit planes (prover only). + pub fn w_hat(&self) -> Option<&[Vec<[i8; D]>]> { + self.w_hat.as_deref() + } + + /// Get the pre-flattened `w_hat` as i8 digit planes (prover only). + pub fn w_hat_flat(&self) -> Option<&[[i8; D]]> { + self.w_hat_flat.as_deref() + } + + /// Take ownership of `w_hat`, leaving `None` in its place. + pub fn take_w_hat(&mut self) -> Option>> { + self.w_hat.take() + } + + /// Get the pre-decomposition folded ring elements (prover only). + pub fn w_folded(&self) -> Option<&[CyclotomicRing]> { + self.w_folded.as_deref() + } + + /// Get the commitment hint (prover only). + pub fn hint(&self) -> Option<&HachiCommitmentHint> { + self.hint.as_ref() + } + + /// Take ownership of the hint, leaving `None` in its place. + pub fn take_hint(&mut self) -> Option> { + self.hint.take() + } +} + +pub(crate) fn derive_stage1_challenges( + transcript: &mut T, + v: &Vec>, + num_blocks: usize, +) -> Result, HachiError> +where + F: FieldCore + CanonicalField, + T: Transcript, +{ + let challenge_cfg = SparseChallengeConfig { + weight: Cfg::challenge_weight_for_ring_dim(D), + nonzero_coeffs: vec![-1, 1], + }; + transcript.append_serde(ABSORB_PROVER_V, v); + sample_sparse_challenges::( + transcript, + CHALLENGE_STAGE1_FOLD, + num_blocks, + &challenge_cfg, + ) +} + +#[cfg(any(test, debug_assertions))] +fn gadget_row_scalars(levels: usize, log_basis: u32) -> Vec { + let base = F::from_canonical_u128_reduced(1u128 << log_basis); + let mut out = Vec::with_capacity(levels); + let mut power = F::one(); + for _ in 0..levels { + out.push(power); + power = power * base; + } + out +} + +/// Add scalar * ring_element into the low-D coefficients of `poly`. 
+/// scalar * ring produces degree D-1, so no high-half contribution. +fn add_scalar_ring_product( + poly: &mut [F], + scalar: &F, + ring: &CyclotomicRing, +) { + for (k, coeff) in ring.coefficients().iter().enumerate() { + poly[k] += *scalar * *coeff; + } +} + +/// Subtract scalar * ring_element from the low-D coefficients of `poly`. +fn sub_scalar_ring_product( + poly: &mut [F], + scalar: &F, + ring: &CyclotomicRing, +) { + for (k, coeff) in ring.coefficients().iter().enumerate() { + poly[k] -= *scalar * *coeff; + } +} + +/// Add sparse_challenge * ring_element as unreduced product into `poly`. +/// +/// Exploits sparsity: O(weight * D) instead of O(D^2) schoolbook. +fn add_sparse_ring_product( + poly: &mut [F], + challenge: &SparseChallenge, + ring: &CyclotomicRing, +) { + let rc = ring.coefficients(); + for (&pos, &coeff) in challenge.positions.iter().zip(challenge.coeffs.iter()) { + let c = F::from_i64(coeff as i64); + let p = pos as usize; + for (s, &r_s) in rc.iter().enumerate() { + poly[p + s] += c * r_s; + } + } +} + +/// Split-eq replacement for `generate_m` + `compute_r_via_poly_division`. +/// +/// Computes `r` such that `M·z = y + (X^D+1)·r` without materializing M or z. +/// Uses split-eq factoring: `kron(left, gadget) · decomposed = left · pre_decomp`. 
+#[allow(clippy::too_many_arguments, clippy::needless_borrow)] +#[tracing::instrument(skip_all, name = "compute_r_split_eq")] +#[allow(clippy::too_many_arguments)] +pub(crate) fn compute_r_split_eq( + _setup: &HachiExpandedSetup, + opening_point: &RingOpeningPoint, + challenges: &[SparseChallenge], + w_hat_flat: &[[i8; D]], + t_hat: &[Vec<[i8; D]>], + w_folded: &[CyclotomicRing], + z_pre: &[CyclotomicRing], + y: &[CyclotomicRing], + ntt_a: &NttSlotCache, + ntt_b: &NttSlotCache, + ntt_d: &NttSlotCache, + layout: HachiCommitmentLayout, +) -> Result>, HachiError> +where + F: FieldCore + CanonicalField, + Cfg: CommitmentConfig, +{ + { + let x: u8 = 0; + eprintln!( + " [compute_r_split_eq] stack ~= {:#x}", + &x as *const u8 as usize + ); + } + let decomp_commit = layout.num_digits_commit; + let decomp_open = layout.num_digits_open; + let log_basis = layout.log_basis; + let poly_len = 2 * D - 1; + let num_rows = Cfg::N_D + Cfg::N_B + 1 + 1 + Cfg::N_A; + + let t_hat_flat = flatten_i8_blocks(t_hat); + + // NTT-accelerated D, B, and A rows: compute quotient = (cyc - neg) / 2 + let t_d = Instant::now(); + let d_quotients = { + let _span = tracing::info_span!("D_rows_ntt").entered(); + unreduced_quotient_rows_ntt_cached_i8(ntt_d, w_hat_flat) + }; + let d_time = t_d.elapsed().as_secs_f64(); + + let t_b = Instant::now(); + let b_quotients = { + let _span = tracing::info_span!("B_rows_ntt").entered(); + unreduced_quotient_rows_ntt_cached_i8(ntt_b, &t_hat_flat) + }; + let b_time = t_b.elapsed().as_secs_f64(); + + let t_a = Instant::now(); + let a_quotients = { + let _span = tracing::info_span!("A_rows_ntt").entered(); + unreduced_quotient_rows_ntt_cached(ntt_a, z_pre) + }; + let a_time = t_a.elapsed().as_secs_f64(); + + let mut result = Vec::with_capacity(num_rows); + let mut other_time = 0.0f64; + let mut poly_buf = vec![F::zero(); poly_len]; + let mut quotient_buf = vec![F::zero(); D]; + + for (row_idx, _y_i) in y.iter().enumerate().take(num_rows) { + if row_idx < Cfg::N_D { + 
result.push(d_quotients[row_idx]); + } else if row_idx < Cfg::N_D + Cfg::N_B { + result.push(b_quotients[row_idx - Cfg::N_D]); + } else if row_idx >= Cfg::N_D + Cfg::N_B + 2 { + // A-rows: NTT-accelerated A*z_pre + sparse challenge terms + let t_row = Instant::now(); + let _span = tracing::info_span!("A_row").entered(); + let a_idx = row_idx - (Cfg::N_D + Cfg::N_B + 2); + + poly_buf.fill(F::zero()); + for (i, t_hat_i) in t_hat.iter().enumerate() { + let start = a_idx * decomp_open; + let end = start + decomp_open; + if end <= t_hat_i.len() { + let t_recomp = + CyclotomicRing::gadget_recompose_pow2_i8(&t_hat_i[start..end], log_basis); + add_sparse_ring_product(&mut poly_buf, &challenges[i], &t_recomp); + } + } + + let a_q = a_quotients[a_idx].coefficients(); + quotient_buf.fill(F::zero()); + quotient_buf[..(poly_len - D)].copy_from_slice(&poly_buf[D..poly_len]); + for k in 0..D { + quotient_buf[k] -= a_q[k]; + } + result.push(CyclotomicRing::from_slice("ient_buf)); + other_time += t_row.elapsed().as_secs_f64(); + } else { + // bTw_row and challenge_fold_row: schoolbook (cheap) + let t_row = Instant::now(); + poly_buf.fill(F::zero()); + + if row_idx == Cfg::N_D + Cfg::N_B { + let _span = tracing::info_span!("bTw_row").entered(); + for (i, w_f) in w_folded.iter().enumerate() { + add_scalar_ring_product(&mut poly_buf, &opening_point.b[i], w_f); + } + } else { + let _span = tracing::info_span!("challenge_fold_row").entered(); + for (i, w_f) in w_folded.iter().enumerate() { + add_sparse_ring_product(&mut poly_buf, &challenges[i], w_f); + } + let block_len = opening_point.a.len(); + for i in 0..block_len { + let start = i * decomp_commit; + let end = start + decomp_commit; + if end <= z_pre.len() { + let z_pre_recomp = + CyclotomicRing::gadget_recompose_pow2(&z_pre[start..end], log_basis); + sub_scalar_ring_product(&mut poly_buf, &opening_point.a[i], &z_pre_recomp); + } + } + } + + let y_coeffs = _y_i.coefficients(); + for k in 0..D { + poly_buf[k] -= y_coeffs[k]; + } + + 
quotient_buf.fill(F::zero()); + for k in (D..poly_len).rev() { + let q = poly_buf[k]; + quotient_buf[k - D] = q; + poly_buf[k - D] -= q; + } + result.push(CyclotomicRing::from_slice("ient_buf)); + other_time += t_row.elapsed().as_secs_f64(); + } + } + + eprintln!( + " [compute_r] D(NTT): {d_time:.2}s, B(NTT): {b_time:.2}s, A(NTT): {a_time:.2}s, other: {other_time:.2}s", + ); + + Ok(result) +} + +/// Reference helper for tests/debug diagnostics: split-eq replacement for +/// `generate_m` + `eval_ring_matrix_at`. +/// +/// Computes the field-element evaluations of each M entry at `alpha`, +/// organized as rows of field elements, without materializing ring-valued `M`. +#[cfg(any(test, debug_assertions))] +#[tracing::instrument(skip_all, name = "compute_m_a_reference")] +pub(crate) fn compute_m_a_reference( + setup: &HachiExpandedSetup, + opening_point: &RingOpeningPoint, + challenges: &[SparseChallenge], + alpha: &F, + layout: HachiCommitmentLayout, +) -> Result>, HachiError> +where + F: FieldCore + CanonicalField, + Cfg: CommitmentConfig, +{ + let depth_commit = layout.num_digits_commit; + let depth_open = layout.num_digits_open; + let depth_fold = layout.num_digits_fold; + let log_basis = layout.log_basis; + let num_blocks = opening_point.b.len(); + let block_len = layout.block_len; + let w_len = depth_open * num_blocks; + let t_len = depth_open * Cfg::N_A * num_blocks; + let z_len = depth_fold * depth_commit * block_len; + let total_cols = w_len + t_len + z_len; + + let g1_open = gadget_row_scalars::(depth_open, log_basis); + let g1_commit = gadget_row_scalars::(depth_commit, log_basis); + let j1 = gadget_row_scalars::(depth_fold, log_basis); + + let c_alphas: Vec = challenges + .iter() + .map(|c| eval_ring_at(&c.to_dense::().expect("valid challenge"), alpha)) + .collect(); + + let d_view = setup.D_mat.view::(); + let b_view = setup.B.view::(); + + let d_rows: Vec> = cfg_into_iter!(0..d_view.num_rows()) + .map(|i| { + let d_row = d_view.row(i); + let mut full = 
vec![F::zero(); total_cols]; + for (j, ring) in d_row.iter().take(w_len).enumerate() { + full[j] = eval_ring_at(ring, alpha); + } + full + }) + .collect(); + + let b_rows: Vec> = cfg_into_iter!(0..b_view.num_rows()) + .map(|i| { + let b_row = b_view.row(i); + let mut full = vec![F::zero(); total_cols]; + for (j, ring) in b_row.iter().take(t_len).enumerate() { + full[w_len + j] = eval_ring_at(ring, alpha); + } + full + }) + .collect(); + + let mut rows = Vec::with_capacity(Cfg::N_D + Cfg::N_B + 1 + 1 + Cfg::N_A); + rows.extend(d_rows); + rows.extend(b_rows); + + // Row 3: b^T · G · ŵ = y_ring (ŵ uses delta_open) + { + let mut full = vec![F::zero(); total_cols]; + for (i, &b_i) in opening_point.b.iter().enumerate() { + for (d, &g) in g1_open.iter().enumerate() { + full[i * depth_open + d] = b_i * g; + } + } + rows.push(full); + } + + // Row 4: (c^T ⊗ G) · ŵ = a^T · G · J · ẑ + { + let mut full = vec![F::zero(); total_cols]; + for (i, &c_alpha) in c_alphas.iter().enumerate() { + for (d, &g) in g1_open.iter().enumerate() { + full[i * depth_open + d] = c_alpha * g; + } + } + let z_offset = w_len + t_len; + for (i, &a_i) in opening_point.a.iter().enumerate() { + for (d, &g) in g1_commit.iter().enumerate() { + let ag = a_i * g; + for (t, &j) in j1.iter().enumerate() { + let idx = (i * depth_commit + d) * depth_fold + t; + full[z_offset + idx] = -(ag * j); + } + } + } + rows.push(full); + } + + // Row 5: (c^T ⊗ G_open) · t̂ = A · J · ẑ + // t̂ uses delta_open (t = A*s has full-field coefficients); ẑ uses delta_commit + for a_idx in 0..Cfg::N_A { + let mut full = vec![F::zero(); total_cols]; + for (i, &c_alpha) in c_alphas.iter().enumerate() { + for (d, &g) in g1_open.iter().enumerate() { + let t_idx = i * (Cfg::N_A * depth_open) + a_idx * depth_open + d; + full[w_len + t_idx] = c_alpha * g; + } + } + let z_offset = w_len + t_len; + let a_view = setup.A.view::(); + let a_row = a_view.row(a_idx); + let inner_width = block_len * depth_commit; + for (k, ring) in 
a_row.iter().take(inner_width).enumerate() { + let ring_alpha = eval_ring_at(ring, alpha); + for (t, &j) in j1.iter().enumerate() { + full[z_offset + k * depth_fold + t] = -(ring_alpha * j); + } + } + rows.push(full); + } + + Ok(rows) +} + +pub(crate) fn generate_y( + v: &[CyclotomicRing], + u: &[CyclotomicRing], + u_eval: &CyclotomicRing, + n_d: usize, + n_b: usize, + n_a: usize, +) -> Result>, HachiError> +where + F: FieldCore, +{ + if v.len() != n_d { + return Err(HachiError::InvalidSize { + expected: n_d, + actual: v.len(), + }); + } + if u.len() != n_b { + return Err(HachiError::InvalidSize { + expected: n_b, + actual: u.len(), + }); + } + let mut out = Vec::with_capacity(n_d + n_b + 1 + 1 + n_a); + out.extend_from_slice(v); + out.extend_from_slice(u); + out.push(*u_eval); + out.push(CyclotomicRing::::zero()); + out.extend(repeat_n(CyclotomicRing::::zero(), n_a)); + Ok(out) +} + +#[cfg(test)] +mod tests { + use super::*; + use std::array::from_fn; + + use crate::algebra::{CyclotomicRing, SparseChallengeConfig}; + use crate::protocol::challenges::sparse::sample_sparse_challenges; + use crate::protocol::commitment::HachiProverSetup; + use crate::protocol::commitment::{HachiCommitmentCore, RingCommitmentScheme}; + use crate::protocol::hachi_poly_ops::DensePoly; + use crate::protocol::proof::HachiCommitmentHint; + use crate::protocol::transcript::Blake2bTranscript; + use crate::test_utils::*; + use crate::FromSmallInt; + use crate::Transcript; + + const TRANSCRIPT_SEED: &[u8] = b"test/prover-relation"; + + fn replay_challenges(v: &Vec>) -> Vec> { + let mut transcript = Blake2bTranscript::::new(TRANSCRIPT_SEED); + transcript.append_serde(ABSORB_PROVER_V, v); + + let challenge_cfg = SparseChallengeConfig { + weight: TinyConfig::CHALLENGE_WEIGHT, + nonzero_coeffs: vec![-1, 1], + }; + let sparse = sample_sparse_challenges::, D>( + &mut transcript, + CHALLENGE_STAGE1_FOLD, + NUM_BLOCKS, + &challenge_cfg, + ) + .unwrap(); + sparse + .iter() + .map(|c| 
c.to_dense::().unwrap()) + .collect() + } + + struct Fixture { + setup: HachiProverSetup, + commitment_u: Vec>, + point: RingOpeningPoint, + blocks: Vec>>, + quad_eq: QuadraticEquation, + /// Challenges re-derived via transcript replay (cross-check). + challenges: Vec>, + } + + fn build_fixture() -> Fixture { + let (setup, _) = + >::setup(16).unwrap(); + + let blocks = sample_blocks(); + let w = + >::commit_ring_blocks( + &blocks, &setup, + ) + .unwrap(); + + let point = RingOpeningPoint { + a: sample_a(), + b: sample_b(), + }; + + let ring_coeffs: Vec> = + blocks.iter().flat_map(|b| b.iter().copied()).collect(); + let poly = DensePoly::from_ring_coeffs(ring_coeffs); + let hint = HachiCommitmentHint::new(w.t_hat); + let mut transcript = Blake2bTranscript::::new(TRANSCRIPT_SEED); + let y_ring = CyclotomicRing::::zero(); + let layout = setup.layout(); + let w_folded = poly.fold_blocks(&point.a, layout.block_len); + let quad_eq = QuadraticEquation::::new_prover( + &setup.ntt_D, + point.clone(), + &poly, + w_folded, + hint, + &mut transcript, + &w.commitment, + &y_ring, + layout, + ) + .unwrap(); + + let challenges = replay_challenges(&quad_eq.v); + + Fixture { + setup, + commitment_u: w.commitment.u.clone(), + point, + blocks, + quad_eq, + challenges, + } + } + + fn i8_to_ring(digits: &[[i8; D]]) -> Vec> { + digits + .iter() + .map(|d| { + let coeffs: [F; D] = from_fn(|i| F::from_i64(d[i] as i64)); + CyclotomicRing::from_coefficients(coeffs) + }) + .collect() + } + + /// Row 1: D · ŵ = v + #[test] + fn row1_d_times_w_hat_equals_v() { + let f = build_fixture(); + + let w_hat = f.quad_eq.w_hat().unwrap(); + let w_hat_flat: Vec> = i8_to_ring( + &w_hat + .iter() + .flat_map(|v| v.iter().copied()) + .collect::>(), + ); + let lhs = mat_vec_mul(&f.setup.expanded.D_mat, &w_hat_flat); + + assert_eq!(lhs, f.quad_eq.v(), "Row 1 failed: D · ŵ ≠ v"); + } + + /// Row 2: B · t̂ = u (commitment vector) + #[test] + fn row2_b_times_t_hat_equals_u_commitment() { + let f = 
build_fixture(); + + let hint = f.quad_eq.hint().unwrap(); + let t_hat_flat_ring: Vec> = hint + .t_hat + .iter() + .flat_map(|v| v.iter()) + .map(|plane| { + let coeffs: [F; D] = from_fn(|k| F::from_i64(plane[k] as i64)); + CyclotomicRing::from_coefficients(coeffs) + }) + .collect(); + let lhs = mat_vec_mul(&f.setup.expanded.B, &t_hat_flat_ring); + + assert_eq!(lhs, f.commitment_u, "Row 2 failed: B · t̂ ≠ u"); + } + + /// Row 3: b^T · G_{2^r} · ŵ = u_eval + #[test] + fn row3_bt_gadget_w_hat_equals_u_eval() { + let f = build_fixture(); + + let w_hat = f.quad_eq.w_hat().unwrap(); + let w_recomposed: Vec> = w_hat + .iter() + .map(|w_hat_i| CyclotomicRing::gadget_recompose_pow2_i8(w_hat_i, log_basis())) + .collect(); + + let u_eval = w_recomposed + .iter() + .zip(f.point.b.iter()) + .fold(CyclotomicRing::::zero(), |acc, (w_i, b_i)| { + acc + w_i.scale(b_i) + }); + + let u_eval_direct = f.blocks.iter().zip(f.point.b.iter()).fold( + CyclotomicRing::::zero(), + |acc, (block_i, b_i)| { + let inner: CyclotomicRing = block_i + .iter() + .zip(f.point.a.iter()) + .fold(CyclotomicRing::::zero(), |acc2, (f_ij, a_j)| { + acc2 + f_ij.scale(a_j) + }); + acc + inner.scale(b_i) + }, + ); + + assert_eq!( + u_eval, u_eval_direct, + "Row 3 failed: b^T G ŵ ≠ Σ b_i (a^T f_i)" + ); + } + + /// Derive z_hat from z_pre for test assertions. 
+ fn derive_z_hat(z_pre: &[CyclotomicRing]) -> Vec> { + z_pre + .iter() + .flat_map(|z_j| z_j.balanced_decompose_pow2(num_digits_fold(), log_basis())) + .collect() + } + + /// Row 4: (c^T ⊗ G_1) · ŵ = a^T · G_{2^m} · J · ẑ + #[test] + fn row4_challenge_fold_w_equals_a_gadget_j_z_hat() { + let f = build_fixture(); + + let w_hat = f.quad_eq.w_hat().unwrap(); + let w: Vec> = w_hat + .iter() + .map(|w_hat_i| CyclotomicRing::gadget_recompose_pow2_i8(w_hat_i, log_basis())) + .collect(); + + let lhs = f + .challenges + .iter() + .zip(w.iter()) + .fold(CyclotomicRing::::zero(), |acc, (c_i, w_i)| { + acc + (*c_i * *w_i) + }); + + let z_hat = derive_z_hat(f.quad_eq.z_pre().unwrap()); + let z_recovered = recompose_z_hat(&z_hat); + let rhs = a_transpose_gadget_times_vec(&f.point.a, &z_recovered); + + assert_eq!(lhs, rhs, "Row 4 failed: (c^T ⊗ G_1)ŵ ≠ a^T G J ẑ"); + } + + /// Row 5: (c^T ⊗ G_{n_A}) · t̂ = A · J · ẑ + #[test] + fn row5_challenge_fold_t_equals_a_j_z_hat() { + let f = build_fixture(); + + let hint = f.quad_eq.hint().unwrap(); + let mut lhs = vec![CyclotomicRing::::zero(); N_A]; + for (c_i, t_hat_i) in f.challenges.iter().zip(hint.t_hat.iter()) { + let t_i = gadget_recompose_vec_i8(t_hat_i); + assert_eq!(t_i.len(), N_A); + for (lhs_j, t_ij) in lhs.iter_mut().zip(t_i.iter()) { + *lhs_j += *c_i * *t_ij; + } + } + + let z_hat = derive_z_hat(f.quad_eq.z_pre().unwrap()); + let z_recovered = recompose_z_hat(&z_hat); + let rhs = mat_vec_mul(&f.setup.expanded.A, &z_recovered); + + assert_eq!(lhs, rhs, "Row 5 failed: (c^T ⊗ G_nA)t̂ ≠ A · J · ẑ"); + } + + #[test] + fn prove_output_shapes_are_correct() { + let f = build_fixture(); + + assert_eq!(f.quad_eq.v().len(), TinyConfig::N_D); + + let w_hat = f.quad_eq.w_hat().unwrap(); + assert_eq!(w_hat.len(), NUM_BLOCKS); + assert!(w_hat.iter().all(|v| v.len() == num_digits_open())); + + let hint = f.quad_eq.hint().unwrap(); + assert_eq!(hint.t_hat.len(), NUM_BLOCKS); + assert!(hint + .t_hat + .iter() + .all(|v| v.len() == N_A * 
num_digits_open())); + + assert_eq!( + f.quad_eq.z_pre().unwrap().len(), + BLOCK_LEN * num_digits_commit() + ); + } +} diff --git a/src/protocol/ring_switch.rs b/src/protocol/ring_switch.rs new file mode 100644 index 00000000..e087f766 --- /dev/null +++ b/src/protocol/ring_switch.rs @@ -0,0 +1,1125 @@ +//! Ring switching logic for the Hachi PCS (Section 4.3). +//! +//! Handles the transition from the ring-based quadratic equation to field-based +//! sumcheck instances by expanding the ring elements into their coefficient +//! vectors and setting up the evaluation tables. + +use crate::algebra::{CyclotomicRing, SparseChallenge}; +use crate::cfg_into_iter; +use crate::error::HachiError; +#[cfg(feature = "parallel")] +use crate::parallel::*; +use crate::protocol::commitment::utils::crt_ntt::NttSlotCache; +use crate::protocol::commitment::utils::linear::{ + decompose_rows_i8, flatten_i8_blocks, mat_vec_mul_ntt_digits_i8, mat_vec_mul_ntt_i8, + mat_vec_mul_ntt_single_i8, +}; +use crate::protocol::commitment::utils::norm::detect_field_modulus; +use crate::protocol::commitment::{ + optimal_m_r_split, CommitmentConfig, DecompositionParams, HachiCommitmentLayout, + HachiExpandedSetup, RingCommitment, +}; +use crate::protocol::opening_point::RingOpeningPoint; +use crate::protocol::proof::{DigitLut, FlatCommitmentHint, FlatRingVec, HachiCommitmentHint}; +use crate::protocol::quadratic_equation::{compute_r_split_eq, QuadraticEquation}; +use crate::protocol::sumcheck::eq_poly::EqPolynomial; +use crate::protocol::transcript::labels::{ + ABSORB_SUMCHECK_W, CHALLENGE_RING_SWITCH, CHALLENGE_TAU0, CHALLENGE_TAU1, +}; +use crate::protocol::transcript::Transcript; +use crate::{CanonicalField, FieldCore, FieldSampling, FromSmallInt}; +#[cfg(test)] +use std::array::from_fn; +use std::marker::PhantomData; +use std::time::Instant; + +/// D-agnostic output of the ring switch protocol, containing everything +/// needed for sumchecks and level chaining. 
+pub struct RingSwitchOutput { + /// The witness vector w as balanced digits in `[-b/2, b/2)`. + pub w: Vec, + /// D-erased commitment to w. + pub w_commitment: FlatRingVec, + /// D-erased prover hint for the w-commitment. + pub w_hint: FlatCommitmentHint, + /// Compact evaluation table of w (all entries in [-b/2, b/2), reordered for sumcheck). + /// Populated by the prover; empty on the verifier side. + pub w_evals: Vec, + /// Field-element evaluation table of w (same reordering as `w_evals`). + /// Produced alongside `w_evals` in a single pass to avoid a duplicate scan. + pub w_evals_field: Vec, + /// Evaluation table of M_alpha(x) (tau1-weighted). + pub m_evals_x: Vec, + /// Evaluation table of alpha powers (y dimension). + pub alpha_evals_y: Vec, + /// Number of upper variable bits. + pub num_u: usize, + /// Number of lower variable bits. + pub num_l: usize, + /// Challenge tau0 for F_0 sumcheck. + pub tau0: Vec, + /// Challenge tau1 for F_alpha sumcheck. + pub tau1: Vec, + /// Basis size b = 2^LOG_BASIS. + pub b: usize, + /// Ring-switch challenge alpha. + pub alpha: F, +} + +/// Build the witness vector `w` from the quadratic equation state. +/// +/// This is the first half of the ring switch: it computes `r` and assembles +/// `w` as a flat `Vec`. The resulting `w` is D-agnostic and can be +/// committed at any ring dimension via [`commit_w`]. +/// +/// # Errors +/// +/// Returns an error if the quadratic equation is missing prover-side data. 
+#[tracing::instrument(skip_all, name = "ring_switch_build_w")] +#[allow(clippy::too_many_arguments)] +#[inline(never)] +pub fn ring_switch_build_w( + quad_eq: &mut QuadraticEquation, + setup: &HachiExpandedSetup, + ntt_a: &NttSlotCache, + ntt_b: &NttSlotCache, + ntt_d: &NttSlotCache, + layout: HachiCommitmentLayout, +) -> Result, HachiError> +where + F: FieldCore + CanonicalField + FieldSampling, + Cfg: CommitmentConfig, +{ + { + let x: u8 = 0; + eprintln!( + " [ring_switch_build_w] stack ~= {:#x}", + &x as *const u8 as usize + ); + } + let w_hat = quad_eq + .w_hat() + .ok_or_else(|| HachiError::InvalidInput("missing w_hat in prover".to_string()))?; + let w_hat_flat = quad_eq + .w_hat_flat() + .ok_or_else(|| HachiError::InvalidInput("missing w_hat_flat in prover".to_string()))?; + let z_pre = quad_eq + .z_pre() + .ok_or_else(|| HachiError::InvalidInput("missing z_pre in prover".to_string()))?; + let hint = quad_eq + .hint() + .ok_or_else(|| HachiError::InvalidInput("missing hint in prover".to_string()))?; + let t_hat = &hint.t_hat; + let w_folded = quad_eq + .w_folded() + .ok_or_else(|| HachiError::InvalidInput("missing w_folded in prover".to_string()))?; + + let t_rs = Instant::now(); + let r = compute_r_split_eq::( + setup, + quad_eq.opening_point(), + &quad_eq.challenges, + w_hat_flat, + t_hat, + w_folded, + z_pre, + quad_eq.y(), + ntt_a, + ntt_b, + ntt_d, + layout, + )?; + eprintln!( + " [ring_switch] compute_r_split_eq: {:.2}s", + t_rs.elapsed().as_secs_f64() + ); + let t_wc = Instant::now(); + let w = { + let _span = tracing::info_span!("build_w_coeffs").entered(); + build_w_coeffs::(w_hat, t_hat, z_pre, &r, layout) + }; + eprintln!( + " [ring_switch] build_w_coeffs: {:.2}s", + t_wc.elapsed().as_secs_f64() + ); + Ok(w) +} + +/// Complete the ring switch after `w` has been committed. 
+/// +/// Takes the already-committed `w` (with its D-erased commitment and hint) +/// and finishes the protocol: absorbs the commitment into the transcript, +/// samples challenges, and builds the evaluation tables for the fused sumcheck. +/// +/// Only the current level's `D` is needed (for M_alpha expansion and +/// alpha_evals_y). The commitment's ring dimension is encoded in the +/// `FlatRingVec` and does not require a separate const generic. +/// +/// # Errors +/// +/// Returns an error if matrix expansion or evaluation-table construction fails. +#[tracing::instrument(skip_all, name = "ring_switch_finalize")] +#[allow(clippy::too_many_arguments)] +#[inline(never)] +pub fn ring_switch_finalize( + quad_eq: &QuadraticEquation, + setup: &HachiExpandedSetup, + transcript: &mut T, + w: Vec, + w_commitment: FlatRingVec, + w_hint: FlatCommitmentHint, + layout: HachiCommitmentLayout, +) -> Result, HachiError> +where + F: FieldCore + CanonicalField + FieldSampling, + T: Transcript, + Cfg: CommitmentConfig, +{ + transcript.append_serde(ABSORB_SUMCHECK_W, &w_commitment); + + let alpha: F = transcript.challenge_scalar(CHALLENGE_RING_SWITCH); + + let num_l = D.trailing_zeros() as usize; + let num_ring_elems = w.len() / D; + let num_u = num_ring_elems.next_power_of_two().trailing_zeros() as usize; + let m_rows = m_row_count::(); + let num_sc_vars = num_u + num_l; + let num_i = m_rows.next_power_of_two().trailing_zeros() as usize; + + let tau0 = sample_tau::(transcript, CHALLENGE_TAU0, num_sc_vars); + let tau1 = sample_tau::(transcript, CHALLENGE_TAU1, num_i); + let alpha_evals_y = build_alpha_evals_y(alpha, D); + + let t_par = Instant::now(); + let opening_point = quad_eq.opening_point(); + let challenges = &quad_eq.challenges; + + #[cfg(feature = "parallel")] + let (m_evals_x_result, w_result) = rayon::join( + || { + compute_m_evals_x::( + setup, + opening_point, + challenges, + alpha, + &alpha_evals_y, + layout, + &tau1, + ) + }, + || build_w_evals_dual::(&w, D, 
layout.log_basis), + ); + #[cfg(not(feature = "parallel"))] + let (m_evals_x_result, w_result) = { + let m_evals_x = compute_m_evals_x::( + setup, + opening_point, + challenges, + alpha, + &alpha_evals_y, + layout, + &tau1, + )?; + let w_dual = build_w_evals_dual::(&w, D, layout.log_basis); + (Ok(m_evals_x), w_dual) + }; + + let m_evals_x = m_evals_x_result?; + let (w_evals, w_evals_field, _, _) = w_result?; + eprintln!( + " [ring_switch] m_evals_x+w_evals parallel: {:.2}s", + t_par.elapsed().as_secs_f64() + ); + + Ok(RingSwitchOutput { + w, + w_commitment, + w_hint, + w_evals, + w_evals_field, + m_evals_x, + alpha_evals_y, + num_u, + num_l, + tau0, + tau1, + b: 1usize << layout.log_basis, + alpha, + }) +} + +/// Execute the prover side of the ring switching protocol (Section 4.3). +/// +/// Convenience wrapper that calls [`ring_switch_build_w`], [`commit_w`], and +/// [`ring_switch_finalize`] in sequence, all at the same ring dimension `D`. +/// +/// # Errors +/// +/// Returns an error if z_pre/w_hat is missing, commitment fails, or matrix expansion fails. 
+#[tracing::instrument(skip_all, name = "ring_switch_prover")] +#[allow(clippy::too_many_arguments)] +#[inline(never)] +pub fn ring_switch_prover( + quad_eq: &mut QuadraticEquation, + setup: &HachiExpandedSetup, + transcript: &mut T, + ntt_a: &NttSlotCache, + ntt_b: &NttSlotCache, + ntt_d: &NttSlotCache, + layout: HachiCommitmentLayout, +) -> Result, HachiError> +where + F: FieldCore + CanonicalField + FieldSampling, + T: Transcript, + Cfg: CommitmentConfig, +{ + let w = ring_switch_build_w::(quad_eq, setup, ntt_a, ntt_b, ntt_d, layout)?; + + let t_cw = Instant::now(); + let (w_commitment, w_hint) = commit_w::(&w, ntt_a, ntt_b)?; + eprintln!( + " [ring_switch] commit_w: {:.2}s (w_len={})", + t_cw.elapsed().as_secs_f64(), + w.len() + ); + + let w_commitment_flat = FlatRingVec::from_commitment(&w_commitment); + let w_hint_flat = FlatCommitmentHint::from_typed(w_hint); + + ring_switch_finalize::( + quad_eq, + setup, + transcript, + w, + w_commitment_flat, + w_hint_flat, + layout, + ) +} + +/// Replay the verifier side of ring switching to reconstruct evaluation tables. +/// +/// Takes the w-commitment as a [`FlatRingVec`] so the verifier does not need +/// to know D_COMMIT (the commitment's ring dimension). +/// +/// # Errors +/// +/// Returns an error if matrix expansion fails. 
+#[tracing::instrument(skip_all, name = "ring_switch_verifier")] +#[inline(never)] +pub fn ring_switch_verifier( + quad_eq: &QuadraticEquation, + setup: &HachiExpandedSetup, + w_len: usize, + w_commitment: &FlatRingVec, + transcript: &mut T, + layout: HachiCommitmentLayout, +) -> Result, HachiError> +where + F: FieldCore + CanonicalField + FieldSampling, + T: Transcript, + Cfg: CommitmentConfig, +{ + transcript.append_serde(ABSORB_SUMCHECK_W, w_commitment); + + let alpha: F = transcript.challenge_scalar(CHALLENGE_RING_SWITCH); + + let num_ring_elems = w_len / D; + let num_u = num_ring_elems.next_power_of_two().trailing_zeros() as usize; + let num_l = D.trailing_zeros() as usize; + let m_rows = m_row_count::(); + let num_sc_vars = num_u + num_l; + let num_i = m_rows.next_power_of_two().trailing_zeros() as usize; + + let tau0 = sample_tau::(transcript, CHALLENGE_TAU0, num_sc_vars); + let tau1 = sample_tau::(transcript, CHALLENGE_TAU1, num_i); + let alpha_evals_y = build_alpha_evals_y(alpha, D); + + let m_evals_x = compute_m_evals_x::( + setup, + quad_eq.opening_point(), + &quad_eq.challenges, + alpha, + &alpha_evals_y, + layout, + &tau1, + )?; + + Ok(RingSwitchOutput { + w: Vec::new(), + w_commitment: w_commitment.clone(), + w_hint: FlatCommitmentHint::empty(), + w_evals: Vec::new(), + w_evals_field: Vec::new(), + m_evals_x, + alpha_evals_y, + num_u, + num_l, + tau0, + tau1, + b: 1usize << layout.log_basis, + alpha, + }) +} + +#[cfg(test)] +pub(crate) fn compute_r_via_poly_division( + m: &[Vec>], + z: &[CyclotomicRing], + y: &[CyclotomicRing], +) -> Result>, HachiError> { + let poly_len = 2 * D - 1; + let out = m + .iter() + .zip(y.iter()) + .map(|(row, y_i)| { + let column_contribution = + |m_ij: &CyclotomicRing, z_j: &CyclotomicRing| -> Vec { + let mut local = vec![F::zero(); poly_len]; + if m_ij.is_zero() { + return local; + } + let a = m_ij.coefficients(); + let b = z_j.coefficients(); + let is_scalar = a[1..].iter().all(|c| c.is_zero()); + if is_scalar { + let 
scalar = a[0]; + for s in 0..D { + local[s] = scalar * b[s]; + } + } else { + for t in 0..D { + for s in 0..D { + local[t + s] += a[t] * b[s]; + } + } + } + local + }; + + let pointwise_add = |mut a: Vec, b: Vec| -> Vec { + for (ai, bi) in a.iter_mut().zip(b.iter()) { + *ai += *bi; + } + a + }; + + #[cfg(feature = "parallel")] + let mut poly = row + .par_iter() + .zip(z.par_iter()) + .fold( + || vec![F::zero(); poly_len], + |acc, (m_ij, z_j)| pointwise_add(acc, column_contribution(m_ij, z_j)), + ) + .reduce(|| vec![F::zero(); poly_len], pointwise_add); + + #[cfg(not(feature = "parallel"))] + let mut poly = row + .iter() + .zip(z.iter()) + .fold(vec![F::zero(); poly_len], |acc, (m_ij, z_j)| { + pointwise_add(acc, column_contribution(m_ij, z_j)) + }); + let y_coeffs = y_i.coefficients(); + for k in 0..D { + poly[k] -= y_coeffs[k]; + } + let mut quotient = vec![F::zero(); D]; + for k in (D..poly_len).rev() { + let q = poly[k]; + quotient[k - D] = q; + poly[k - D] -= q; + } + let coeffs: [F; D] = from_fn(|k| quotient[k]); + CyclotomicRing::from_coefficients(coeffs) + }) + .collect(); + Ok(out) +} + +/// Derived commitment config for recursive w-openings. +/// +/// Sets `log_commit_bound = log_basis` (w's entries are balanced digits) and +/// `log_open_bound = parent's open bound` (opening folds produce full-field +/// coefficients). +/// +/// For `D=512, Cfg=Fp128FullCommitmentConfig`, this is equivalent to +/// [`Fp128LogBasisCommitmentConfig`](super::commitment::Fp128LogBasisCommitmentConfig). 
+#[derive(Clone, Copy, Debug)] +pub(crate) struct WCommitmentConfig { + _cfg: PhantomData, +} + +impl CommitmentConfig for WCommitmentConfig { + const D: usize = D; + const N_A: usize = Cfg::N_A; + const N_B: usize = Cfg::N_B; + const N_D: usize = Cfg::N_D; + const CHALLENGE_WEIGHT: usize = Cfg::CHALLENGE_WEIGHT; + + fn challenge_weight_for_ring_dim(d: usize) -> usize { + Cfg::challenge_weight_for_ring_dim(d) + } + + fn decomposition() -> DecompositionParams { + let parent = Cfg::decomposition(); + let parent_open = parent.log_open_bound.unwrap_or(parent.log_commit_bound); + DecompositionParams { + log_basis: parent.log_basis, + // w's entries are balanced digits in [-b/2, b/2), so commitment + // decomposition needs only one level. + log_commit_bound: parent.log_basis, + // Opening folds w with arbitrary field-element weights, producing + // full-field-size coefficients that need the same decomposition + // depth as the parent's opening bound. + log_open_bound: Some(parent_open), + } + } + + fn commitment_layout(max_num_vars: usize) -> Result { + let alpha = D.trailing_zeros() as usize; + let reduced_vars = max_num_vars.checked_sub(alpha).ok_or_else(|| { + HachiError::InvalidSetup("max_num_vars is smaller than alpha".to_string()) + })?; + if reduced_vars == 0 { + return Err(HachiError::InvalidSetup( + "max_num_vars must leave at least one outer variable".to_string(), + )); + } + let (m_vars, r_vars) = optimal_m_r_split::(reduced_vars); + HachiCommitmentLayout::new::(m_vars, r_vars, &Self::decomposition()) + } +} + +/// Total ring elements in the w polynomial, computed from the main layout. +/// +/// Components: w_hat + t_hat + decomposed z_pre + decomposed r. 
+pub(crate) fn w_ring_element_count( + layout: HachiCommitmentLayout, +) -> usize { + let w_hat_count = layout.num_blocks * layout.num_digits_open; + let t_hat_count = layout.num_blocks * Cfg::N_A * layout.num_digits_open; + let z_pre_count = layout.inner_width * layout.num_digits_fold; + let r_count = m_row_count::() * r_decomp_levels::(layout.log_basis); + w_hat_count + t_hat_count + z_pre_count + r_count +} + +/// Compute the w-commitment layout from the main layout. +pub(crate) fn w_commitment_layout( + main_layout: HachiCommitmentLayout, +) -> Result { + let total = w_ring_element_count::(main_layout) + .next_power_of_two() + .max(1); + let alpha = D.trailing_zeros() as usize; + let m_vars = total.trailing_zeros() as usize; + let max_num_vars = m_vars + alpha; + WCommitmentConfig::::commitment_layout(max_num_vars) +} + +/// Commit the witness vector `w` (D-agnostic `Vec`) into `D`-sized ring +/// elements and compute the ring commitment. +/// +/// This is the **D-boundary** in the protocol: the ring switch at level k +/// produces `w` using D_k operations, but `commit_w` re-chunks `w` into +/// D_{k+1}-sized ring elements and commits using D_{k+1} NTT caches. +/// +/// For constant-D configs, D_k = D_{k+1} = D and the distinction is moot. +/// +/// # Errors +/// +/// Returns an error if the commitment layout derivation or NTT mat-vec fails. 
+#[tracing::instrument(skip_all, name = "commit_w")] +#[inline(never)] +pub fn commit_w( + w: &[i8], + ntt_a: &NttSlotCache, + ntt_b: &NttSlotCache, +) -> Result<(RingCommitment, HachiCommitmentHint), HachiError> +where + F: FieldCore + CanonicalField + FieldSampling, + Cfg: CommitmentConfig, +{ + let (w_digits, remainder) = w.as_chunks::(); + if !remainder.is_empty() { + return Err(HachiError::InvalidSize { + expected: D, + actual: w.len(), + }); + } + + let total = w_digits.len().next_power_of_two().max(1); + let alpha = D.trailing_zeros() as usize; + let m_vars_total = total.trailing_zeros() as usize; + let max_num_vars = m_vars_total + alpha; + let w_layout = WCommitmentConfig::::commitment_layout(max_num_vars)?; + + let num_blocks = w_layout.num_blocks; + let block_len = w_layout.block_len; + let depth_commit = w_layout.num_digits_commit; + let depth_open = w_layout.num_digits_open; + let log_basis = w_layout.log_basis; + let coeff_len = w_digits.len(); + + let t_all = if depth_commit == 1 { + // `build_w_coeffs` already emits balanced base-`2^log_basis` digits, so + // the recursive w-commitment can skip the field conversion and feed those + // planes directly into the tiled NTT mat-vec. 
+ let block_slices: Vec<&[[i8; D]]> = (0..num_blocks) + .map(|i| { + let start = i * block_len; + if start >= coeff_len { + &[] as &[[i8; D]] + } else { + &w_digits[start..(start + block_len).min(coeff_len)] + } + }) + .collect(); + mat_vec_mul_ntt_digits_i8(ntt_a, &block_slices) + } else { + let lut = DigitLut::::new(log_basis); + let ring_elems: Vec> = w_digits + .iter() + .map(|digit| { + let coeffs = std::array::from_fn(|k| lut.get(digit[k])); + CyclotomicRing::from_coefficients(coeffs) + }) + .collect(); + let block_slices: Vec<&[CyclotomicRing]> = (0..num_blocks) + .map(|i| { + let start = i * block_len; + if start >= coeff_len { + &[] as &[CyclotomicRing] + } else { + &ring_elems[start..(start + block_len).min(coeff_len)] + } + }) + .collect(); + mat_vec_mul_ntt_i8(ntt_a, &block_slices, depth_commit, log_basis) + }; + let t_hat_per_block: Vec> = cfg_into_iter!(t_all) + .map(|t_i| decompose_rows_i8(&t_i, depth_open, log_basis)) + .collect(); + + let t_hat_flat = flatten_i8_blocks(&t_hat_per_block); + let u: Vec> = mat_vec_mul_ntt_single_i8(ntt_b, &t_hat_flat); + let hint = HachiCommitmentHint::new(t_hat_per_block); + Ok((RingCommitment { u }, hint)) +} + +pub(crate) fn eval_ring_at(r: &CyclotomicRing, alpha: &F) -> F { + let mut acc = F::zero(); + let mut power = F::one(); + for coeff in r.coefficients() { + acc += *coeff * power; + power = power * *alpha; + } + acc +} + +#[inline] +fn eval_ring_at_pows( + r: &CyclotomicRing, + alpha_pows: &[F], +) -> F { + debug_assert_eq!(alpha_pows.len(), D); + r.coefficients() + .iter() + .zip(alpha_pows.iter()) + .fold(F::zero(), |acc, (coeff, alpha_pow)| { + acc + *coeff * *alpha_pow + }) +} + +#[inline] +fn eval_sparse_challenge_at_pows( + challenge: &SparseChallenge, + alpha_pows: &[F], +) -> Result { + if alpha_pows.len() != D { + return Err(HachiError::InvalidSize { + expected: D, + actual: alpha_pows.len(), + }); + } + + debug_assert_eq!(challenge.positions.len(), challenge.coeffs.len()); + + let mut acc = 
F::zero(); + for (&pos, &coeff) in challenge.positions.iter().zip(challenge.coeffs.iter()) { + let idx = pos as usize; + debug_assert!(idx < D); + debug_assert_ne!(coeff, 0); + acc += F::from_i64(coeff as i64) * alpha_pows[idx]; + } + Ok(acc) +} + +#[inline] +fn gadget_row_scalars(levels: usize, log_basis: u32) -> Vec { + let base = F::from_canonical_u128_reduced(1u128 << log_basis); + let mut out = Vec::with_capacity(levels); + let mut power = F::one(); + for _ in 0..levels { + out.push(power); + power = power * base; + } + out +} + +pub(crate) fn r_decomp_levels(log_basis: u32) -> usize { + let modulus = detect_field_modulus::(); + let bits = 128 - (modulus.saturating_sub(1)).leading_zeros() as usize; + let lb = log_basis as usize; + let mut levels = (bits + lb.saturating_sub(1)) / lb.max(1); + if levels == 0 { + levels = 1; + } + + let total_bits = levels * lb; + if total_bits <= bits { + let b = 1u128 << log_basis; + let half_q = modulus / 2; + let half_b_minus_1 = b / 2 - 1; + let b_minus_1 = b - 1; + let mut b_pow = 1u128; + for _ in 0..levels { + b_pow = b_pow.saturating_mul(b); + } + let max_positive = half_b_minus_1.saturating_mul((b_pow - 1) / b_minus_1); + if max_positive < half_q { + levels += 1; + } + } + + levels +} + +#[cfg(test)] +pub(crate) fn expand_m_a( + m_a: &[Vec], + alpha: F, + log_basis: u32, +) -> Result, HachiError> { + if m_a.is_empty() { + return Ok(Vec::new()); + } + let rows = m_a.len(); + let cols = m_a[0].len(); + if cols == 0 { + return Ok(vec![F::zero(); rows]); + } + for row in m_a.iter() { + if row.len() != cols { + return Err(HachiError::InvalidSize { + expected: cols, + actual: row.len(), + }); + } + } + + let levels = r_decomp_levels::(log_basis); + let total_cols = cols + .checked_add( + rows.checked_mul(levels) + .ok_or_else(|| HachiError::InvalidSetup("expanded M width overflow".to_string()))?, + ) + .ok_or_else(|| HachiError::InvalidSetup("expanded M width overflow".to_string()))?; + + let base = 
F::from_canonical_u128_reduced(1u128 << log_basis); + let mut gadget_row = Vec::with_capacity(levels); + let mut power = F::one(); + for _ in 0..levels { + gadget_row.push(power); + power = power * base; + } + + let mut alpha_pow = F::one(); + for _ in 0..D { + alpha_pow = alpha_pow * alpha; + } + let denom = alpha_pow + F::one(); + + let mut out = vec![F::zero(); rows * total_cols]; + for (i, m_a_row) in m_a.iter().enumerate() { + let row_start = i * total_cols; + out[row_start..row_start + cols].copy_from_slice(m_a_row); + let r_start = row_start + cols + i * levels; + for (j, g) in gadget_row.iter().enumerate() { + out[r_start + j] = -denom * *g; + } + } + Ok(out) +} + +/// # Errors +/// +/// Returns an error if `w.len()` is not a multiple of `d`. +pub(crate) fn build_w_evals( + w: &[F], + d: usize, +) -> Result<(Vec, usize, usize), HachiError> { + if d == 0 || w.len() % d != 0 { + return Err(HachiError::InvalidSize { + expected: d, + actual: w.len(), + }); + } + let num_l = d.trailing_zeros() as usize; + let num_ring_elems = w.len() / d; + let num_u = num_ring_elems.next_power_of_two().trailing_zeros() as usize; + let x_len = 1usize << num_u; + let n = x_len << num_l; + + let evals: Vec = cfg_into_iter!(0..n) + .map(|dst| { + let x = dst & (x_len - 1); + let y = dst >> num_u; + let src = y + (x << num_l); + if src < w.len() { + w[src] + } else { + F::zero() + } + }) + .collect(); + Ok((evals, num_u, num_l)) +} + +/// Produce both compact `Vec` and field `Vec` eval tables in one pass +/// over `w`, sharing the index computation. 
+pub(crate) fn build_w_evals_dual( + w: &[i8], + d: usize, + log_basis: u32, +) -> Result<(Vec, Vec, usize, usize), HachiError> { + if d == 0 || w.len() % d != 0 { + return Err(HachiError::InvalidSize { + expected: d, + actual: w.len(), + }); + } + let num_l = d.trailing_zeros() as usize; + let num_ring_elems = w.len() / d; + let num_u = num_ring_elems.next_power_of_two().trailing_zeros() as usize; + let x_len = 1usize << num_u; + let n = x_len << num_l; + + let lut = DigitLut::::new(log_basis); + let (compact, field): (Vec, Vec) = cfg_into_iter!(0..n) + .map(|dst| { + let x = dst & (x_len - 1); + let y = dst >> num_u; + let src = y + (x << num_l); + if src < w.len() { + let d = w[src]; + (d, lut.get(d)) + } else { + (0i8, F::zero()) + } + }) + .unzip(); + Ok((compact, field, num_u, num_l)) +} + +pub(crate) fn m_row_count() -> usize { + Cfg::N_D + Cfg::N_B + 1 + 1 + Cfg::N_A +} + +pub(crate) fn compute_m_evals_x( + setup: &HachiExpandedSetup, + opening_point: &RingOpeningPoint, + challenges: &[SparseChallenge], + alpha: F, + alpha_pows: &[F], + layout: HachiCommitmentLayout, + tau1: &[F], +) -> Result, HachiError> +where + Cfg: CommitmentConfig, +{ + if alpha_pows.len() != D { + return Err(HachiError::InvalidSize { + expected: D, + actual: alpha_pows.len(), + }); + } + + let depth_commit = layout.num_digits_commit; + let depth_open = layout.num_digits_open; + let depth_fold = layout.num_digits_fold; + let log_basis = layout.log_basis; + let num_blocks = opening_point.b.len(); + let block_len = layout.block_len; + let w_len = depth_open * num_blocks; + let t_len = depth_open * Cfg::N_A * num_blocks; + let inner_width = block_len * depth_commit; + let z_len = depth_fold * inner_width; + let rows = m_row_count::(); + let levels = r_decomp_levels::(log_basis); + let total_cols = w_len + .checked_add(t_len) + .and_then(|cols| cols.checked_add(z_len)) + .and_then(|cols| cols.checked_add(rows.checked_mul(levels)?)) + .ok_or_else(|| HachiError::InvalidSetup("expanded M 
width overflow".to_string()))?; + + let eq_tau1 = EqPolynomial::evals(tau1); + if eq_tau1.len() < rows { + return Err(HachiError::InvalidSize { + expected: rows, + actual: eq_tau1.len(), + }); + } + + let g1_open = gadget_row_scalars::(depth_open, log_basis); + let g1_commit = gadget_row_scalars::(depth_commit, log_basis); + let fold_gadget = gadget_row_scalars::(depth_fold, log_basis); + let r_gadget = gadget_row_scalars::(levels, log_basis); + let x_len = total_cols.next_power_of_two(); + let mut out = Vec::with_capacity(x_len); + + let c_alphas: Vec = challenges + .iter() + .map(|challenge| eval_sparse_challenge_at_pows::(challenge, alpha_pows)) + .collect::>()?; + + let d_view = setup.D_mat.view::(); + let b_view = setup.B.view::(); + let a_view = setup.A.view::(); + + let row3_weight = eq_tau1[Cfg::N_D + Cfg::N_B]; + let row4_weight = eq_tau1[Cfg::N_D + Cfg::N_B + 1]; + let a_weights = &eq_tau1[(Cfg::N_D + Cfg::N_B + 2)..rows]; + + let w_segment: Vec = cfg_into_iter!(0..w_len) + .map(|x| { + let block_idx = x / depth_open; + let digit_idx = x % depth_open; + let mut acc = (row3_weight * opening_point.b[block_idx] + + row4_weight * c_alphas[block_idx]) + * g1_open[digit_idx]; + for row_idx in 0..Cfg::N_D { + let eq_i = eq_tau1[row_idx]; + if !eq_i.is_zero() { + acc += eq_i * eval_ring_at_pows(&d_view.row(row_idx)[x], alpha_pows); + } + } + acc + }) + .collect(); + out.extend(w_segment); + + let t_segment: Vec = cfg_into_iter!(0..t_len) + .map(|x| { + let block_idx = x / (Cfg::N_A * depth_open); + let rem = x % (Cfg::N_A * depth_open); + let a_idx = rem / depth_open; + let digit_idx = rem % depth_open; + let mut acc = a_weights[a_idx] * c_alphas[block_idx] * g1_open[digit_idx]; + for row_idx in 0..Cfg::N_B { + let eq_i = eq_tau1[Cfg::N_D + row_idx]; + if !eq_i.is_zero() { + acc += eq_i * eval_ring_at_pows(&b_view.row(row_idx)[x], alpha_pows); + } + } + acc + }) + .collect(); + out.extend(t_segment); + + let z_base: Vec = cfg_into_iter!(0..inner_width) + .map(|k| 
{ + let block_idx = k / depth_commit; + let digit_idx = k % depth_commit; + let mut acc = row4_weight * opening_point.a[block_idx] * g1_commit[digit_idx]; + for (a_idx, eq_i) in a_weights.iter().enumerate() { + if !eq_i.is_zero() { + acc += *eq_i * eval_ring_at_pows(&a_view.row(a_idx)[k], alpha_pows); + } + } + acc + }) + .collect(); + + let z_segment: Vec = cfg_into_iter!(0..z_len) + .map(|idx| { + let k = idx / depth_fold; + let fold_idx = idx % depth_fold; + -(z_base[k] * fold_gadget[fold_idx]) + }) + .collect(); + out.extend(z_segment); + + let alpha_pow_d = alpha_pows[D - 1] * alpha; + let denom = alpha_pow_d + F::one(); + let r_tail_len = rows * levels; + let r_tail: Vec = cfg_into_iter!(0..r_tail_len) + .map(|idx| { + let row_idx = idx / levels; + let level_idx = idx % levels; + -(eq_tau1[row_idx] * denom * r_gadget[level_idx]) + }) + .collect(); + out.extend(r_tail); + out.resize(x_len, F::zero()); + Ok(out) +} + +pub(crate) fn build_alpha_evals_y(alpha: F, d: usize) -> Vec { + let mut out = vec![F::zero(); d]; + let mut power = F::one(); + for val in out.iter_mut() { + *val = power; + power = power * alpha; + } + out +} + +pub(crate) fn sample_tau>( + transcript: &mut T, + label: &[u8], + n: usize, +) -> Vec { + (0..n).map(|_| transcript.challenge_scalar(label)).collect() +} + +pub(crate) fn build_w_coeffs( + w_hat: &[Vec<[i8; D]>], + t_hat: &[Vec<[i8; D]>], + z_pre: &[CyclotomicRing], + r: &[CyclotomicRing], + layout: HachiCommitmentLayout, +) -> Vec { + let log_basis = layout.log_basis; + let num_digits_fold = layout.num_digits_fold; + let levels = r_decomp_levels::(log_basis); + + let t_hat_flat = t_hat.iter().flat_map(|v| v.iter()); + + let w_hat_planes: usize = w_hat.iter().map(|v| v.len()).sum(); + let t_hat_planes: usize = t_hat.iter().map(|v| v.len()).sum(); + let z_count = w_hat_planes + t_hat_planes + z_pre.len() * num_digits_fold; + let r_hat_count = r.len() * levels; + eprintln!( + " [build_w_coeffs] w_hat_planes={w_hat_planes}, 
t_hat_planes={t_hat_planes}, z_pre_elems={}, z_pre_planes={}, r_elems={}, r_planes={r_hat_count}, total_ring={}, total_field={}", + z_pre.len(), z_pre.len() * num_digits_fold, r.len(), z_count + r_hat_count, (z_count + r_hat_count) * D, + ); + let mut out = Vec::with_capacity((z_count + r_hat_count) * D); + for block in w_hat { + for digits in block { + out.extend_from_slice(digits); + } + } + for digits in t_hat_flat { + out.extend_from_slice(digits); + } + for z_j in z_pre { + for plane in z_j.balanced_decompose_pow2_i8(num_digits_fold, log_basis) { + out.extend_from_slice(&plane); + } + } + for ri in r { + for plane in ri.balanced_decompose_pow2_i8(levels, log_basis) { + out.extend_from_slice(&plane); + } + } + out +} + +#[cfg(test)] +mod tests { + use super::compute_r_via_poly_division; + use crate::algebra::{CyclotomicRing, Fp64}; + use std::array::from_fn; + + use crate::{FieldCore, FromSmallInt}; + + fn compute_r_schoolbook( + m: &[Vec>], + z: &[CyclotomicRing], + y: &[CyclotomicRing], + ) -> Vec> { + let poly_len = 2 * D - 1; + m.iter() + .zip(y.iter()) + .map(|(row, y_i)| { + let mut poly = vec![F::zero(); poly_len]; + for (m_ij, z_j) in row.iter().zip(z.iter()) { + if m_ij.is_zero() { + continue; + } + let a = m_ij.coefficients(); + let b = z_j.coefficients(); + let is_scalar = a[1..].iter().all(|c| c.is_zero()); + if is_scalar { + let scalar = a[0]; + for s in 0..D { + poly[s] += scalar * b[s]; + } + } else { + for t in 0..D { + for s in 0..D { + poly[t + s] += a[t] * b[s]; + } + } + } + } + let y_coeffs = y_i.coefficients(); + for k in 0..D { + poly[k] -= y_coeffs[k]; + } + let mut quotient = vec![F::zero(); D]; + for k in (D..poly_len).rev() { + let q = poly[k]; + quotient[k - D] = q; + poly[k - D] -= q; + } + let coeffs: [F; D] = from_fn(|k| quotient[k]); + CyclotomicRing::from_coefficients(coeffs) + }) + .collect() + } + + #[test] + fn compute_r_matches_schoolbook_reference() { + type F = Fp64<4294967197>; + const D: usize = 64; + + let m: Vec>> = 
(0..3) + .map(|i| { + (0..4) + .map(|j| { + if (i + j) % 3 == 0 { + let mut coeffs = [F::zero(); D]; + coeffs[0] = F::from_u64((i * 5 + j + 1) as u64); + CyclotomicRing::from_coefficients(coeffs) + } else { + let coeffs = from_fn(|k| { + F::from_u64((i as u64 * 1000 + j as u64 * 100 + k as u64 + 1) % 97) + }); + CyclotomicRing::from_coefficients(coeffs) + } + }) + .collect() + }) + .collect(); + let z: Vec> = (0..4) + .map(|j| { + let coeffs = from_fn(|k| F::from_u64((j as u64 * 37 + k as u64 + 5) % 89)); + CyclotomicRing::from_coefficients(coeffs) + }) + .collect(); + let y: Vec> = (0..3) + .map(|i| { + let coeffs = from_fn(|k| F::from_u64((i as u64 * 29 + k as u64 + 7) % 83)); + CyclotomicRing::from_coefficients(coeffs) + }) + .collect(); + + let expected = compute_r_schoolbook(&m, &z, &y); + let got = compute_r_via_poly_division::(&m, &z, &y) + .expect("ring-switch CRT+NTT path should dispatch for D=64"); + assert_eq!(got, expected); + } +} diff --git a/src/protocol/sumcheck/batched_sumcheck.rs b/src/protocol/sumcheck/batched_sumcheck.rs new file mode 100644 index 00000000..e3ee7fc0 --- /dev/null +++ b/src/protocol/sumcheck/batched_sumcheck.rs @@ -0,0 +1,347 @@ +//! Batched sumcheck protocol. +//! +//! Implements the standard technique for batching parallel sumchecks to reduce +//! verifier cost and proof size. +//! +//! For details, refer to Jim Posen's ["Perspectives on Sumcheck Batching"](https://hackmd.io/s/HyxaupAAA). +//! We do what they describe as "front-loaded" batch sumcheck. +//! +//! Adapted from Jolt's `BatchedSumcheck` implementation. 
use super::{SumcheckInstanceProver, SumcheckInstanceVerifier, SumcheckProof, UniPoly};
use crate::error::HachiError;
use crate::protocol::transcript::labels;
use crate::protocol::transcript::Transcript;
use crate::{CanonicalField, FieldCore, FromSmallInt};

/// Multiply `x` by `2^k` via `k` doublings (avoids needing a `from_u64`).
fn mul_pow_2(x: E, k: usize) -> E {
    let mut result = x;
    for _ in 0..k {
        result = result + result;
    }
    result
}

/// Coefficient-wise linear combination `Σᵢ coeffs[i] · polys[i]`, sized to
/// the longest input polynomial (shorter polys contribute zeros at the tail).
fn linear_combination(polys: &[UniPoly], coeffs: &[E]) -> UniPoly {
    let max_len = polys.iter().map(|p| p.coeffs.len()).max().unwrap_or(0);
    let mut result = vec![E::zero(); max_len];
    for (poly, coeff) in polys.iter().zip(coeffs.iter()) {
        for (i, c) in poly.coeffs.iter().enumerate() {
            result[i] += *c * *coeff;
        }
    }
    UniPoly::from_coeffs(result)
}

/// Verifier-side output of the batched sumcheck round replay.
///
/// This carries all transcript-derived values needed for the final oracle check,
/// which is intentionally split out so callers can compute the expected output
/// claim through an external reduction (e.g. Greyhound) before enforcing
/// equality.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct BatchedSumcheckRoundResult {
    /// Final claim produced by replaying all sumcheck rounds.
    pub output_claim: E,
    /// Challenge vector sampled during replay.
    pub r_sumcheck: Vec,
    /// Front-loaded batching coefficient per verifier instance.
    pub batching_coeffs: Vec,
    /// Maximum number of rounds among batched instances.
    pub max_num_rounds: usize,
}

/// Produce a batched sumcheck proof for multiple instances sharing the same
/// variable space, driving the Fiat–Shamir transcript.
///
/// This function:
/// - absorbs each instance's initial claim,
/// - samples batching coefficients (one per instance),
/// - computes a single batched round polynomial per round as a linear
///   combination of the individual round polynomials,
/// - returns a single [`SumcheckProof`] and the derived challenge vector.
///
/// Instances with fewer rounds than the maximum are padded with constant
/// "dummy" round polynomials (the Jolt "front-loaded" approach).
///
/// # Errors
///
/// Returns [`HachiError::InvalidInput`] if `instances` is empty. (This is a
/// recoverable error, not a panic.)
#[tracing::instrument(skip_all, name = "prove_batched_sumcheck")]
pub fn prove_batched_sumcheck(
    mut instances: Vec<&mut dyn SumcheckInstanceProver>,
    transcript: &mut T,
    mut sample_challenge: S,
) -> Result<(SumcheckProof, Vec), HachiError>
where
    F: FieldCore + CanonicalField,
    T: Transcript,
    E: FieldCore + FromSmallInt,
    S: FnMut(&mut T) -> E,
{
    if instances.is_empty() {
        return Err(HachiError::InvalidInput(
            "no sumcheck instances provided".into(),
        ));
    }

    let max_num_rounds = instances
        .iter()
        .map(|inst| inst.num_rounds())
        .max()
        .unwrap(); // safe: non-empty checked above

    // Absorb individual input claims.
    for inst in instances.iter() {
        let claim = inst.input_claim();
        transcript.append_serde(labels::ABSORB_SUMCHECK_CLAIM, &claim);
    }

    // Sample one batching coefficient per instance.
    let batching_coeffs: Vec = (0..instances.len())
        .map(|_| sample_challenge(transcript))
        .collect();

    // To see why we may need to scale by a power of two, consider a batch of
    // two sumchecks:
    //   claim_a = \sum_x P(x) where x \in {0, 1}^M
    //   claim_b = \sum_{x, y} Q(x, y) where x \in {0, 1}^M, y \in {0, 1}^N
    // Then the batched sumcheck is:
    //   \sum_{x, y} A * P(x) + B * Q(x, y) where A and B are batching coefficients
    //   = A * \sum_y \sum_x P(x) + B * \sum_{x, y} Q(x, y)
    //   = A * \sum_y claim_a + B * claim_b
    //   = A * 2^N * claim_a + B * claim_b
    let mut individual_claims: Vec = instances
        .iter()
        .map(|inst| {
            let n = inst.num_rounds();
            let claim = inst.input_claim();
            mul_pow_2(claim, max_num_rounds - n)
        })
        .collect();

    let mut round_polys = Vec::with_capacity(max_num_rounds);
    let mut challenges = Vec::with_capacity(max_num_rounds);

    for round in 0..max_num_rounds {
        // One univariate per instance; inactive (front-padded) instances
        // contribute a constant polynomial summing to their current claim.
        let univariate_polys: Vec> = instances
            .iter_mut()
            .zip(individual_claims.iter())
            .map(|(inst, previous_claim)| {
                let n = inst.num_rounds();
                let offset = max_num_rounds - n;
                let active = round >= offset && round < offset + n;
                if active {
                    inst.compute_round_univariate(round - offset, *previous_claim)
                } else {
                    // Constant c with c(0) + c(1) = previous_claim, hence TWO_INV.
                    UniPoly::from_coeffs(vec![*previous_claim * E::TWO_INV])
                }
            })
            .collect();

        let batched_poly = linear_combination(&univariate_polys, &batching_coeffs);

        #[cfg(debug_assertions)]
        {
            let g0 = batched_poly.evaluate(&E::zero());
            let g1 = batched_poly.evaluate(&E::one());
            let batched_claim: E = individual_claims
                .iter()
                .zip(batching_coeffs.iter())
                .map(|(c, b)| *c * *b)
                .fold(E::zero(), |a, v| a + v);
            debug_assert!(
                g0 + g1 == batched_claim,
                "round {round}: H(0) + H(1) != batched claim"
            );
        }

        let compressed = batched_poly.compress();
        transcript.append_serde(labels::ABSORB_SUMCHECK_ROUND, &compressed);
        let r_j = sample_challenge(transcript);
        challenges.push(r_j);

        // Update individual claims from each instance's own univariate.
        for (claim, poly) in individual_claims.iter_mut().zip(univariate_polys.iter()) {
            *claim = poly.evaluate(&r_j);
        }

        // Ingest challenge into each active instance.
        for inst in instances.iter_mut() {
            let n = inst.num_rounds();
            let offset = max_num_rounds - n;
            let active = round >= offset && round < offset + n;
            if active {
                inst.ingest_challenge(round - offset, r_j);
            }
        }

        round_polys.push(compressed);
    }

    for inst in instances.iter_mut() {
        inst.finalize();
    }

    Ok((SumcheckProof { round_polys }, challenges))
}

/// Verify a batched sumcheck proof.
///
/// This function:
/// - absorbs each verifier instance's initial claim,
/// - re-derives the batching coefficients,
/// - computes the batched initial claim,
/// - verifies the proof against the batched claim.
///
/// Returns transcript-derived verifier data for the caller to perform the final
/// expected-output equality check.
///
/// # Errors
///
/// Returns [`HachiError::InvalidInput`] if `verifiers` is empty; otherwise
/// propagates per-round verification errors.
pub fn verify_batched_sumcheck_rounds(
    proof: &SumcheckProof,
    verifiers: Vec<&dyn SumcheckInstanceVerifier>,
    transcript: &mut T,
    mut sample_challenge: S,
) -> Result, HachiError>
where
    F: FieldCore + CanonicalField,
    T: Transcript,
    E: FieldCore,
    S: FnMut(&mut T) -> E,
{
    if verifiers.is_empty() {
        return Err(HachiError::InvalidInput(
            "no sumcheck instances provided".into(),
        ));
    }

    let max_degree = verifiers.iter().map(|v| v.degree_bound()).max().unwrap(); // safe: non-empty
    let max_num_rounds = verifiers.iter().map(|v| v.num_rounds()).max().unwrap(); // safe: non-empty

    // Absorb individual input claims.
    for v in verifiers.iter() {
        let claim = v.input_claim();
        transcript.append_serde(labels::ABSORB_SUMCHECK_CLAIM, &claim);
    }

    // Re-derive batching coefficients.
    let batching_coeffs: Vec = (0..verifiers.len())
        .map(|_| sample_challenge(transcript))
        .collect();

    // Compute the combined initial claim with power-of-two scaling.
    let batched_claim: E = verifiers
        .iter()
        .zip(batching_coeffs.iter())
        .map(|(v, coeff)| {
            let n = v.num_rounds();
            let claim = v.input_claim();
            mul_pow_2(claim, max_num_rounds - n) * *coeff
        })
        .fold(E::zero(), |a, v| a + v);

    let (output_claim, r_sumcheck) = proof.verify::(
        batched_claim,
        max_num_rounds,
        max_degree,
        transcript,
        &mut sample_challenge,
    )?;

    Ok(BatchedSumcheckRoundResult {
        output_claim,
        r_sumcheck,
        batching_coeffs,
        max_num_rounds,
    })
}

/// Compute the expected batched output claim from verifier instances and
/// transcript-derived batching data.
///
/// # Errors
///
/// Propagates errors from verifier `expected_output_claim` calls.
pub fn compute_batched_expected_output_claim(
    verifiers: Vec<&dyn SumcheckInstanceVerifier>,
    batching_coeffs: &[E],
    max_num_rounds: usize,
    r_sumcheck: &[E],
) -> Result {
    let expected_output_claim: E = verifiers
        .iter()
        .zip(batching_coeffs.iter())
        .map(|(v, coeff)| {
            // Each instance only sees the challenges for its own rounds
            // (front-loaded padding lives before `offset`).
            let offset = max_num_rounds - v.num_rounds();
            let r_slice = &r_sumcheck[offset..offset + v.num_rounds()];
            v.expected_output_claim(r_slice).map(|val| val * *coeff)
        })
        .try_fold(E::zero(), |a, v| v.map(|val| a + val))?;

    Ok(expected_output_claim)
}

/// Enforce final batched output-claim equality.
///
/// # Errors
///
/// Returns an error if `output_claim != expected_output_claim`.
pub fn check_batched_output_claim(
    output_claim: E,
    expected_output_claim: E,
) -> Result<(), HachiError> {
    if output_claim != expected_output_claim {
        return Err(HachiError::InvalidProof);
    }

    Ok(())
}

/// Verify a batched sumcheck proof, including final expected-output equality.
///
/// This convenience wrapper preserves the previous behavior. Callers that need
/// to inject an external reduction should use [`verify_batched_sumcheck_rounds`]
/// and [`check_batched_output_claim`] directly.
///
/// # Errors
///
/// Propagates errors from round verification and output-claim equality check
/// (including the empty-`verifiers` error from the round replay).
#[tracing::instrument(skip_all, name = "verify_batched_sumcheck")]
pub fn verify_batched_sumcheck(
    proof: &SumcheckProof,
    verifiers: Vec<&dyn SumcheckInstanceVerifier>,
    transcript: &mut T,
    mut sample_challenge: S,
) -> Result, HachiError>
where
    F: FieldCore + CanonicalField,
    T: Transcript,
    E: FieldCore,
    S: FnMut(&mut T) -> E,
{
    let round_result = verify_batched_sumcheck_rounds::(
        proof,
        verifiers.clone(),
        transcript,
        &mut sample_challenge,
    )?;
    let expected_output_claim = compute_batched_expected_output_claim(
        verifiers,
        &round_result.batching_coeffs,
        round_result.max_num_rounds,
        &round_result.r_sumcheck,
    )?;
    check_batched_output_claim(round_result.output_claim, expected_output_claim)?;
    Ok(round_result.r_sumcheck)
}
diff --git a/src/protocol/sumcheck/eq_poly.rs b/src/protocol/sumcheck/eq_poly.rs
new file mode 100644
index 00000000..8b14b248
--- /dev/null
+++ b/src/protocol/sumcheck/eq_poly.rs
@@ -0,0 +1,229 @@
//! Utilities for the equality polynomial `eq(x, y) = Πᵢ (xᵢ yᵢ + (1 − xᵢ)(1 − yᵢ))`.
//!
//! The equality polynomial evaluates to 1 when `x = y` (over the boolean hypercube)
//! and 0 otherwise. Its multilinear extension (MLE) is used throughout sumcheck
//! protocols.
//!
//! Adapted from Jolt's `EqPolynomial` implementation.
//!
//! ## Bit / index order: Little-endian
//!
//! The evaluation tables produced by this module use **little-endian** bit order:
//! entry `b` (as an integer index) corresponds to the boolean vector where
//! bit `k` of `b` equals `x[k]`. In other words, `r[0]` corresponds to the
//! **least-significant bit** (bit 0) and `r[n-1]` to the MSB.
use crate::FieldCore;
use std::marker::PhantomData;

/// Utilities for the equality polynomial `eq(x, y) = Πᵢ (xᵢ yᵢ + (1 − xᵢ)(1 − yᵢ))`.
pub struct EqPolynomial(PhantomData);

impl EqPolynomial {
    /// Compute the MLE of the equality polynomial at two points:
    /// `eq(x, y) = Πᵢ (xᵢ yᵢ + (1 − xᵢ)(1 − yᵢ))`.
    ///
    /// # Panics
    ///
    /// Panics if `x.len() != y.len()`.
    pub fn mle(x: &[E], y: &[E]) -> E {
        assert_eq!(x.len(), y.len());
        x.iter()
            .zip(y.iter())
            .map(|(&x_i, &y_i)| x_i * y_i + (E::one() - x_i) * (E::one() - y_i))
            .fold(E::one(), |acc, v| acc * v)
    }

    /// Compute the zero selector: `eq(r, 0) = Πᵢ (1 − rᵢ)`.
    pub fn zero_selector(r: &[E]) -> E {
        r.iter().fold(E::one(), |acc, &r_i| acc * (E::one() - r_i))
    }

    /// Compute the full evaluation table `{ eq(r, x) : x ∈ {0,1}^n }`.
    ///
    /// Uses **little-endian** bit order: entry `b` has bit `k` of `b`
    /// corresponding to `r[k]`.
    ///
    /// For a scaled table, use [`Self::evals_with_scaling`].
    pub fn evals(r: &[E]) -> Vec {
        Self::evals_with_scaling(r, None)
    }

    /// Compute the full evaluation table with optional scaling:
    /// `scaling_factor · eq(r, x)` for all `x ∈ {0,1}^n`.
    ///
    /// Uses the same **little-endian** index order as [`Self::evals`].
    /// If `scaling_factor` is `None`, defaults to 1 (no scaling).
    pub fn evals_with_scaling(r: &[E], scaling_factor: Option) -> Vec {
        #[cfg(feature = "parallel")]
        {
            // Below this many variables the rayon overhead outweighs the win.
            const PARALLEL_THRESHOLD: usize = 16;
            if r.len() > PARALLEL_THRESHOLD {
                return Self::evals_parallel(r, scaling_factor);
            }
        }
        Self::evals_serial(r, scaling_factor)
    }

    /// Serial (single-threaded) version of [`Self::evals_with_scaling`].
    ///
    /// Uses **little-endian** index order.
    pub fn evals_serial(r: &[E], scaling_factor: Option) -> Vec {
        let size = 1usize << r.len();
        let mut evals = vec![E::zero(); size];
        evals[0] = scaling_factor.unwrap_or(E::one());
        let mut len = 1usize;
        // Fold variables in from the back of `r`; the in-place doubling
        // (reverse j) interleaves entries so that r[0] ends up at bit 0.
        for &t in r.iter().rev() {
            let one_minus_t = E::one() - t;
            for j in (0..len).rev() {
                evals[2 * j + 1] = evals[j] * t;
                evals[2 * j] = evals[j] * one_minus_t;
            }
            len *= 2;
        }
        evals
    }

    /// Compute eq evaluations and cache intermediate tables.
    ///
    /// Returns `result` where `result[j]` contains evaluations over the
    /// **last** `j` coordinates of `r` (variables are folded in from the
    /// back, matching [`Self::evals_serial`]):
    /// `result[j][x] = eq(r[n-j..], x)` for `x ∈ {0,1}^j`, with bit `k`
    /// of `x` corresponding to `r[n-j+k]`.
    ///
    /// So `result[0] = [1]`, `result[1]` has 2 entries, ..., and `result[n]`
    /// equals [`Self::evals(r)`].
    pub fn evals_cached(r: &[E]) -> Vec> {
        Self::evals_cached_with_scaling(r, None)
    }

    /// Like [`Self::evals_cached`], but with optional scaling.
    ///
    /// Same layer semantics as [`Self::evals_cached`]: layer `j` covers the
    /// last `j` coordinates of `r`, and the final layer equals the (scaled)
    /// full table.
    pub fn evals_cached_with_scaling(r: &[E], scaling_factor: Option) -> Vec> {
        let mut result: Vec> = (0..r.len() + 1).map(|i| vec![E::zero(); 1 << i]).collect();
        result[0][0] = scaling_factor.unwrap_or(E::one());
        for j in 0..r.len() {
            // Variables are consumed from the back: layer j+1 adds r[n-1-j].
            let idx = r.len() - 1 - j;
            let t = r[idx];
            let one_minus_t = E::one() - t;
            let prev_len = 1 << j;
            for i in (0..prev_len).rev() {
                result[j + 1][2 * i + 1] = result[j][i] * t;
                result[j + 1][2 * i] = result[j][i] * one_minus_t;
            }
        }
        result
    }

    /// Parallel version of [`Self::evals_with_scaling`].
    ///
    /// Uses rayon to compute the largest layers of the DP tree in parallel.
    /// Uses the same **little-endian** index order as [`Self::evals`].
    #[cfg(feature = "parallel")]
    pub fn evals_parallel(r: &[E], scaling_factor: Option) -> Vec {
        use rayon::prelude::*;

        let final_size = 1usize << r.len();
        let mut evals = vec![E::zero(); final_size];
        evals[0] = scaling_factor.unwrap_or(E::one());
        let mut size = 1;

        // Forward iteration (r[0] first) produces little-endian ordering.
        for &r_i in r.iter() {
            let (evals_left, evals_right) = evals.split_at_mut(size);
            let (evals_right, _) = evals_right.split_at_mut(size);

            evals_left
                .par_iter_mut()
                .zip(evals_right.par_iter_mut())
                .for_each(|(x, y)| {
                    *y = *x * r_i;
                    *x -= *y;
                });

            size *= 2;
        }

        evals
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::algebra::Fp64;
    use crate::{FieldSampling, FromSmallInt};
    use rand::rngs::StdRng;
    use rand::SeedableRng;

    type F = Fp64<4294967197>;

    #[test]
    fn evals_matches_mle_pointwise() {
        let mut rng = StdRng::seed_from_u64(0xEE);
        for n in 1..8 {
            let r: Vec = (0..n).map(|_| F::sample(&mut rng)).collect();
            let table = EqPolynomial::evals(&r);
            assert_eq!(table.len(), 1 << n);
            for (idx, &val) in table.iter().enumerate() {
                let bits: Vec = (0..n)
                    .map(|k| {
                        if (idx >> k) & 1 == 1 {
                            F::one()
                        } else {
                            F::zero()
                        }
                    })
                    .collect();
                let expected = EqPolynomial::mle(&r, &bits);
                assert_eq!(val, expected, "n={n} idx={idx}");
            }
        }
    }

    #[test]
    fn evals_with_scaling_scales_uniformly() {
        let mut rng = StdRng::seed_from_u64(0xAB);
        let r: Vec = (0..5).map(|_| F::sample(&mut rng)).collect();
        let scale = F::from_u64(7);
        let unscaled = EqPolynomial::evals(&r);
        let scaled = EqPolynomial::evals_with_scaling(&r, Some(scale));
        for (u, s) in unscaled.iter().zip(scaled.iter()) {
            assert_eq!(*s, *u * scale);
        }
    }

    #[test]
    fn evals_cached_last_matches_evals() {
        let mut rng = StdRng::seed_from_u64(0xCD);
        for n in 1..8 {
            let r: Vec = (0..n).map(|_| F::sample(&mut rng)).collect();
            let table = EqPolynomial::evals(&r);
            let cached = EqPolynomial::evals_cached(&r);
            assert_eq!(cached.len(), n + 1);
            assert_eq!(cached[0], vec![F::one()]);
            assert_eq!(*cached.last().unwrap(), table);
        }
    }

    #[test]
    fn zero_selector_matches_mle_at_origin() {
        let mut rng = StdRng::seed_from_u64(0x00);
        for n in 1..8 {
            let r: Vec = (0..n).map(|_| F::sample(&mut rng)).collect();
            let zeros = vec![F::zero(); n];
            let expected = EqPolynomial::mle(&r, &zeros);
            let actual = EqPolynomial::zero_selector(&r);
            assert_eq!(actual, expected, "n={n}");
        }
    }

    #[cfg(feature = "parallel")]
    #[test]
    fn evals_parallel_matches_serial() {
        let mut rng = StdRng::seed_from_u64(0xFF);
        for n in 1..20 {
            let r: Vec = (0..n).map(|_| F::sample(&mut rng)).collect();
            let serial = EqPolynomial::evals_serial(&r, None);
            let parallel = EqPolynomial::evals_parallel(&r, None);
            assert_eq!(serial, parallel, "n={n}");
        }
    }
}
diff --git a/src/protocol/sumcheck/hachi_sumcheck.rs b/src/protocol/sumcheck/hachi_sumcheck.rs
new file mode 100644
index 00000000..7d448cf9
--- /dev/null
+++ b/src/protocol/sumcheck/hachi_sumcheck.rs
@@ -0,0 +1,734 @@
//! Fused norm+relation sumcheck prover/verifier for the Hachi PCS.
//!
//! Eliminates the redundant `w_evals` clone by sharing a single `w_table`
//! across both the norm (F_0) and relation (F_α) sumcheck computations.
//! Supports compact `Vec` storage for round 0 (all entries in [-b/2, b/2)),
//! transitioning to `Vec` at half size after the first fold.
+ +use super::eq_poly::EqPolynomial; +use super::norm_sumcheck::{ + choose_round_kernel, compute_entry_coeffs, compute_entry_coeffs_x4, field_from_i128, + range_check_eval_i128, range_check_eval_precomputed, trim_trailing_zeros, NormRoundKernel, + PointEvalPrecomp, RangeAffinePrecomp, MAX_AFFINE_COEFFS, +}; +use super::split_eq::GruenSplitEq; +use super::{fold_evals_in_place, multilinear_eval, range_check_eval}; +use super::{SumcheckInstanceProver, SumcheckInstanceVerifier, UniPoly}; +use crate::algebra::fields::HasUnreducedOps; +use crate::algebra::CyclotomicRing; +use crate::error::HachiError; +#[cfg(feature = "parallel")] +use crate::parallel::*; +use crate::protocol::ring_switch::eval_ring_at; +use std::marker::PhantomData; + +use crate::{cfg_fold_reduce, cfg_into_iter}; +use std::iter; +use std::mem; +use std::time::Instant; + +use crate::{AdditiveGroup, CanonicalField, FieldCore, FromSmallInt}; + +enum WTable { + Compact(Vec), + Full(Vec), +} + +/// Fused norm+relation sumcheck prover. +/// +/// Holds a single `w_table` shared by both sumcheck instances, weighted +/// by `batching_coeff`. The round polynomial is +/// `batching_coeff * norm_round(t) + relation_round(t)`. +/// +/// Alpha and m are stored in compact form (sizes `2^num_l` and `2^num_u` +/// respectively) and folded only during rounds where their variables are active. +pub struct HachiSumcheckProver { + w_table: WTable, + batching_coeff: E, + + // Norm state + split_eq: GruenSplitEq, + round_kernel: NormRoundKernel, + point_precomp: Option>, + range_precomp: Option>, + b: usize, + + // Relation state (compact — not expanded to full domain) + alpha_compact: Vec, + m_compact: Vec, + num_u: usize, + + num_vars: usize, + relation_claim: E, + + norm_time_total: f64, + relation_time_total: f64, + fold_time_total: f64, + rounds_completed: usize, +} + +impl HachiSumcheckProver { + /// Create a fused norm+relation sumcheck prover. 
+ /// + /// # Panics + /// + /// Panics if table sizes are inconsistent with `num_u` and `num_l`. + #[allow(clippy::too_many_arguments)] + #[tracing::instrument(skip_all, name = "HachiSumcheckProver::new")] + pub fn new( + batching_coeff: E, + w_evals_compact: Vec, + tau0: &[E], + b: usize, + alpha_evals_y: Vec, + m_evals_x: Vec, + num_u: usize, + num_l: usize, + relation_claim: E, + ) -> Self { + assert!(b >= 1, "b must be at least 1"); + let num_vars = num_u + num_l; + let n = 1usize << num_vars; + assert_eq!(w_evals_compact.len(), n); + assert_eq!(tau0.len(), num_vars); + assert_eq!(alpha_evals_y.len(), 1 << num_l); + assert_eq!(m_evals_x.len(), 1 << num_u); + + let round_kernel = choose_round_kernel(b); + let point_precomp = match round_kernel { + NormRoundKernel::PointEvalInterpolation => Some(PointEvalPrecomp::new(b)), + NormRoundKernel::AffineCoeffComposition => None, + }; + let range_precomp = match round_kernel { + NormRoundKernel::PointEvalInterpolation => None, + NormRoundKernel::AffineCoeffComposition => Some(RangeAffinePrecomp::new(b)), + }; + + Self { + w_table: WTable::Compact(w_evals_compact), + batching_coeff, + split_eq: GruenSplitEq::new(tau0), + round_kernel, + point_precomp, + range_precomp, + b, + alpha_compact: alpha_evals_y, + m_compact: m_evals_x, + num_u, + num_vars, + relation_claim, + norm_time_total: 0.0, + relation_time_total: 0.0, + fold_time_total: 0.0, + rounds_completed: 0, + } + } + + /// Accumulate `am * w_int` into split pos/neg accumulators. + /// `accum[pos_idx]` gets the product when `w_int >= 0`, + /// `accum[pos_idx + 1]` gets it when `w_int < 0`. + #[inline] + fn accum_signed_mul(accum: &mut [E::MulU64Accum], pos_idx: usize, am: E, w_int: i32) { + let prod = am.mul_u64_unreduced(w_int.unsigned_abs() as u64); + if w_int < 0 { + accum[pos_idx + 1] += prod; + } else { + accum[pos_idx] += prod; + } + } + + /// Reduce a (positive, negative) accumulator pair to a single field element. 
+ #[inline] + fn reduce_signed_accum(pos: E::MulU64Accum, neg: E::MulU64Accum) -> E { + E::reduce_mul_u64_accum(pos) - E::reduce_mul_u64_accum(neg) + } + + /// Fused compact round 0: computes both the norm and relation round + /// polynomials in a single pass over `w_compact`, using i128/LUT + /// arithmetic for the norm and unreduced small-int multiplies for the + /// relation. Relation uses split pos/neg accumulators to avoid + /// wrapping-neg overflow in the unsigned limbed accumulators. + #[tracing::instrument(skip_all, name = "HachiSumcheckProver::compute_round_compact_fused")] + fn compute_round_compact_fused(&self, w_compact: &[i8]) -> (UniPoly, UniPoly) { + let half = w_compact.len() / 2; + let (e_first, e_second) = self.split_eq.remaining_eq_tables(); + let num_first = e_first.len(); + let first_bits = num_first.trailing_zeros(); + let current_x_width = self.num_u.saturating_sub(self.rounds_completed); + let current_x_mask = (1usize << current_x_width).wrapping_sub(1); + let alpha_compact = &self.alpha_compact; + let m_compact = &self.m_compact; + let b = self.b; + + // 6-element array: [pos0, neg0, pos1, neg1, pos2, neg2] + type RelAccum = [::MulU64Accum; 6]; + let rel_zero = || -> RelAccum { [E::MulU64Accum::ZERO; 6] }; + #[allow(unused_variables)] + let rel_combine = |a: &mut RelAccum, b: &RelAccum| { + for i in 0..6 { + a[i] += b[i]; + } + }; + let rel_reduce = |r: RelAccum| -> [E; 3] { + [ + Self::reduce_signed_accum(r[0], r[1]), + Self::reduce_signed_accum(r[2], r[3]), + Self::reduce_signed_accum(r[4], r[5]), + ] + }; + + match self.round_kernel { + NormRoundKernel::PointEvalInterpolation if b <= 10 => { + let degree_q = 2 * b - 1; + let num_points_q = degree_q + 1; + + let _span = tracing::info_span!("fused_compact_point_eval").entered(); + let (q_evals, rel_accum) = cfg_fold_reduce!( + 0..half, + || (vec![E::zero(); num_points_q], rel_zero()), + |(mut norm_evals, mut rel), j| { + let w0_i = w_compact[2 * j] as i32; + let w1_i = w_compact[2 * j + 
1] as i32; + let delta_i = w1_i - w0_i; + + let j_low = j & (num_first - 1); + let j_high = j >> first_bits; + let eq_rem = e_first[j_low] * e_second[j_high]; + let mut w_t_i = w0_i; + for eval in norm_evals.iter_mut() { + let rc = range_check_eval_i128(w_t_i, b); + *eval += eq_rem * field_from_i128::(rc); + w_t_i += delta_i; + } + + let a_0 = alpha_compact[(2 * j) >> current_x_width]; + let a_1 = alpha_compact[(2 * j + 1) >> current_x_width]; + let m_0 = m_compact[(2 * j) & current_x_mask]; + let m_1 = m_compact[(2 * j + 1) & current_x_mask]; + let am_0 = a_0 * m_0; + let am_1 = a_1 * m_1; + let w2_i = 2 * w1_i - w0_i; + let am_2 = (a_1 + a_1 - a_0) * (m_1 + m_1 - m_0); + + Self::accum_signed_mul(&mut rel, 0, am_0, w0_i); + Self::accum_signed_mul(&mut rel, 2, am_1, w1_i); + Self::accum_signed_mul(&mut rel, 4, am_2, w2_i); + + (norm_evals, rel) + }, + |(mut na, mut ra), (nb, rb)| { + for (ai, bi) in na.iter_mut().zip(nb.iter()) { + *ai += *bi; + } + rel_combine(&mut ra, &rb); + (na, ra) + } + ); + + let q_poly = UniPoly::from_evals(&q_evals); + let norm_poly = self.split_eq.gruen_mul(&q_poly); + let rel_evals = rel_reduce(rel_accum); + (norm_poly, UniPoly::from_evals(&rel_evals)) + } + NormRoundKernel::AffineCoeffComposition => { + let rp = self.range_precomp.as_ref().unwrap(); + let num_coeffs_q = rp.degree_q + 1; + + let _span = tracing::info_span!("fused_compact_affine_coeff").entered(); + let (mut q_coeffs, rel_accum) = cfg_fold_reduce!( + 0..e_second.len(), + || (vec![E::ProductAccum::ZERO; num_coeffs_q], rel_zero()), + |(mut outer_accum, mut rel), j_high| { + debug_assert!(num_coeffs_q <= MAX_AFFINE_COEFFS); + let mut inner_accum = [E::ProductAccum::ZERO; MAX_AFFINE_COEFFS]; + for (j_low, &e_in) in e_first.iter().enumerate() { + let j = j_high * num_first + j_low; + let w0_int = w_compact[2 * j]; + let w1_int = w_compact[2 * j + 1]; + + let w_1 = E::from_i64(w1_int as i64); + let a = w_1 - E::from_i64(w0_int as i64); + let mut a_pow = E::one(); + for (i, acc) 
in inner_accum[..num_coeffs_q].iter_mut().enumerate() { + let h_i_w0 = rp.h_i_lut(w0_int, i); + let val = a_pow * h_i_w0; + *acc += e_in.mul_to_product_accum(val); + a_pow = a_pow * a; + } + + let a_0 = alpha_compact[(2 * j) >> current_x_width]; + let a_1 = alpha_compact[(2 * j + 1) >> current_x_width]; + let m_0 = m_compact[(2 * j) & current_x_mask]; + let m_1 = m_compact[(2 * j + 1) & current_x_mask]; + let am_0 = a_0 * m_0; + let am_1 = a_1 * m_1; + let w2_i = 2 * w1_int as i32 - w0_int as i32; + let am_2 = (a_1 + a_1 - a_0) * (m_1 + m_1 - m_0); + + Self::accum_signed_mul(&mut rel, 0, am_0, w0_int as i32); + Self::accum_signed_mul(&mut rel, 2, am_1, w1_int as i32); + Self::accum_signed_mul(&mut rel, 4, am_2, w2_i); + } + let e_out = e_second[j_high]; + for k in 0..num_coeffs_q { + let inner_reduced = E::reduce_product_accum(inner_accum[k]); + outer_accum[k] += e_out.mul_to_product_accum(inner_reduced); + } + (outer_accum, rel) + }, + |(mut ca, mut ra), (cb, rb)| { + for (ai, bi) in ca.iter_mut().zip(cb.iter()) { + *ai += *bi; + } + rel_combine(&mut ra, &rb); + (ca, ra) + } + ); + + let q_coeffs_reduced: Vec = + q_coeffs.drain(..).map(E::reduce_product_accum).collect(); + let mut q_coeffs = q_coeffs_reduced; + trim_trailing_zeros(&mut q_coeffs); + let q_poly = UniPoly::from_coeffs(q_coeffs); + let norm_poly = self.split_eq.gruen_mul(&q_poly); + let rel_evals = rel_reduce(rel_accum); + (norm_poly, UniPoly::from_evals(&rel_evals)) + } + _ => { + // b > 10 with point-eval: fall back to separate passes + let _span = tracing::info_span!("compact_fallback").entered(); + use super::norm_sumcheck::compute_norm_round_poly_compact; + let np = compute_norm_round_poly_compact( + &self.split_eq, + w_compact, + b, + self.round_kernel, + self.point_precomp.as_ref(), + self.range_precomp.as_ref(), + ); + let pair = |j: usize| { + ( + E::from_i64(w_compact[2 * j] as i64), + E::from_i64(w_compact[2 * j + 1] as i64), + ) + }; + let rel_evals = cfg_fold_reduce!( + 0..half, + || 
[E::zero(); 3], + |mut evals, j| { + let (w_0, w_1) = pair(j); + let a_0 = alpha_compact[(2 * j) >> current_x_width]; + let a_1 = alpha_compact[(2 * j + 1) >> current_x_width]; + let m_0 = m_compact[(2 * j) & current_x_mask]; + let m_1 = m_compact[(2 * j + 1) & current_x_mask]; + evals[0] += w_0 * a_0 * m_0; + evals[1] += w_1 * a_1 * m_1; + let w_2 = w_1 + w_1 - w_0; + let a_2 = a_1 + a_1 - a_0; + let m_2 = m_1 + m_1 - m_0; + evals[2] += w_2 * a_2 * m_2; + evals + }, + |mut a, b| { + for (ai, bi) in a.iter_mut().zip(b.iter()) { + *ai += *bi; + } + a + } + ); + (np, UniPoly::from_evals(&rel_evals)) + } + } + } + + fn fold_compact_to_full(w_compact: &[i8], r: E) -> Vec { + cfg_into_iter!(0..w_compact.len() / 2) + .map(|j| { + let w_0 = E::from_i64(w_compact[2 * j] as i64); + let delta = w_compact[2 * j + 1] as i32 - w_compact[2 * j] as i32; + let delta_abs = delta.unsigned_abs() as u64; + let r_delta = E::reduce_mul_u64_accum(r.mul_u64_unreduced(delta_abs)); + if delta < 0 { + w_0 - r_delta + } else { + w_0 + r_delta + } + }) + .collect() + } +} + +impl SumcheckInstanceProver + for HachiSumcheckProver +{ + fn num_rounds(&self) -> usize { + self.num_vars + } + + fn degree_bound(&self) -> usize { + 2 * self.b + } + + fn input_claim(&self) -> E { + self.relation_claim + } + + fn compute_round_univariate(&mut self, _round: usize, _previous_claim: E) -> UniPoly { + let t_norm = Instant::now(); + let (norm_poly, relation_poly) = match &self.w_table { + WTable::Compact(w_compact) => { + let result = self.compute_round_compact_fused(w_compact); + self.norm_time_total += t_norm.elapsed().as_secs_f64(); + result + } + WTable::Full(w_full) => { + let half = w_full.len() / 2; + let (e_first, e_second) = self.split_eq.remaining_eq_tables(); + let num_first = e_first.len(); + let first_bits = num_first.trailing_zeros(); + let current_x_width = self.num_u.saturating_sub(self.rounds_completed); + let current_x_mask = (1usize << current_x_width).wrapping_sub(1); + let alpha_compact = 
&self.alpha_compact; + let m_compact = &self.m_compact; + + let _span = tracing::info_span!("fused_norm_relation").entered(); + + let (np, rp) = match self.round_kernel { + NormRoundKernel::PointEvalInterpolation => { + let degree_q = 2 * self.b - 1; + let num_points_q = degree_q + 1; + let offsets_sq = &self.point_precomp.as_ref().unwrap().range_offsets_sq; + + let (q_evals, rel_evals) = cfg_fold_reduce!( + 0..half, + || (vec![E::zero(); num_points_q], [E::zero(); 3]), + |(mut norm_evals, mut rel_evals), j| { + let w_0 = w_full[2 * j]; + let w_1 = w_full[2 * j + 1]; + + let j_low = j & (num_first - 1); + let j_high = j >> first_bits; + let eq_rem = e_first[j_low] * e_second[j_high]; + let delta = w_1 - w_0; + let mut w_t = w_0; + for eval in norm_evals.iter_mut() { + *eval += eq_rem * range_check_eval_precomputed(w_t, offsets_sq); + w_t += delta; + } + + let a_0 = alpha_compact[(2 * j) >> current_x_width]; + let a_1 = alpha_compact[(2 * j + 1) >> current_x_width]; + let m_0 = m_compact[(2 * j) & current_x_mask]; + let m_1 = m_compact[(2 * j + 1) & current_x_mask]; + rel_evals[0] += w_0 * a_0 * m_0; + rel_evals[1] += w_1 * a_1 * m_1; + let w_2 = w_1 + w_1 - w_0; + let a_2 = a_1 + a_1 - a_0; + let m_2 = m_1 + m_1 - m_0; + rel_evals[2] += w_2 * a_2 * m_2; + + (norm_evals, rel_evals) + }, + |(mut na, mut ra), (nb, rb)| { + for (ai, bi) in na.iter_mut().zip(nb.iter()) { + *ai += *bi; + } + for (ai, bi) in ra.iter_mut().zip(rb.iter()) { + *ai += *bi; + } + (na, ra) + } + ); + + let q_poly = UniPoly::from_evals(&q_evals); + ( + self.split_eq.gruen_mul(&q_poly), + UniPoly::from_evals(&rel_evals), + ) + } + NormRoundKernel::AffineCoeffComposition => { + let range_pc = self.range_precomp.as_ref().unwrap(); + let num_coeffs_q = range_pc.degree_q + 1; + debug_assert!(num_coeffs_q <= MAX_AFFINE_COEFFS); + + let (mut q_coeffs, rel_evals) = cfg_fold_reduce!( + 0..e_second.len(), + || (vec![E::ProductAccum::ZERO; num_coeffs_q], [E::zero(); 3]), + |(mut outer_accum, mut 
rel_evals), j_high| { + let mut inner_accum = [E::ProductAccum::ZERO; MAX_AFFINE_COEFFS]; + let base_j = j_high * num_first; + let full_chunks = num_first / 4; + let mut batch_out = [[E::zero(); MAX_AFFINE_COEFFS]; 4]; + + for chunk in 0..full_chunks { + let jl = chunk * 4; + let w = [ + (w_full[2 * (base_j + jl)], w_full[2 * (base_j + jl) + 1]), + ( + w_full[2 * (base_j + jl + 1)], + w_full[2 * (base_j + jl + 1) + 1], + ), + ( + w_full[2 * (base_j + jl + 2)], + w_full[2 * (base_j + jl + 2) + 1], + ), + ( + w_full[2 * (base_j + jl + 3)], + w_full[2 * (base_j + jl + 3) + 1], + ), + ]; + compute_entry_coeffs_x4( + &mut batch_out, + range_pc, + [w[0].0, w[1].0, w[2].0, w[3].0], + [ + w[0].1 - w[0].0, + w[1].1 - w[1].0, + w[2].1 - w[2].0, + w[3].1 - w[3].0, + ], + ); + for (b, bo) in batch_out.iter().enumerate() { + let e_in = e_first[jl + b]; + for (acc, &entry) in inner_accum[..num_coeffs_q] + .iter_mut() + .zip(bo[..num_coeffs_q].iter()) + { + *acc += e_in.mul_to_product_accum(entry); + } + } + for (b, &(w_0, w_1)) in w.iter().enumerate() { + let j = base_j + jl + b; + let a_0 = alpha_compact[(2 * j) >> current_x_width]; + let a_1 = alpha_compact[(2 * j + 1) >> current_x_width]; + let m_0 = m_compact[(2 * j) & current_x_mask]; + let m_1 = m_compact[(2 * j + 1) & current_x_mask]; + rel_evals[0] += w_0 * a_0 * m_0; + rel_evals[1] += w_1 * a_1 * m_1; + let w_2 = w_1 + w_1 - w_0; + let a_2 = a_1 + a_1 - a_0; + let m_2 = m_1 + m_1 - m_0; + rel_evals[2] += w_2 * a_2 * m_2; + } + } + + let mut entry_buf = [E::zero(); MAX_AFFINE_COEFFS]; + let mut w_pows_buf = [E::zero(); MAX_AFFINE_COEFFS]; + for (tail_idx, &e_in) in + e_first[full_chunks * 4..].iter().enumerate() + { + let j = base_j + full_chunks * 4 + tail_idx; + let w_0 = w_full[2 * j]; + let w_1 = w_full[2 * j + 1]; + compute_entry_coeffs( + &mut entry_buf, + &mut w_pows_buf, + range_pc, + w_0, + w_1 - w_0, + ); + for (acc, &entry) in inner_accum[..num_coeffs_q] + .iter_mut() + .zip(entry_buf[..num_coeffs_q].iter()) + 
{ + *acc += e_in.mul_to_product_accum(entry); + } + let a_0 = alpha_compact[(2 * j) >> current_x_width]; + let a_1 = alpha_compact[(2 * j + 1) >> current_x_width]; + let m_0 = m_compact[(2 * j) & current_x_mask]; + let m_1 = m_compact[(2 * j + 1) & current_x_mask]; + rel_evals[0] += w_0 * a_0 * m_0; + rel_evals[1] += w_1 * a_1 * m_1; + let w_2 = w_1 + w_1 - w_0; + let a_2 = a_1 + a_1 - a_0; + let m_2 = m_1 + m_1 - m_0; + rel_evals[2] += w_2 * a_2 * m_2; + } + + let e_out = e_second[j_high]; + for k in 0..num_coeffs_q { + let inner_reduced = E::reduce_product_accum(inner_accum[k]); + outer_accum[k] += e_out.mul_to_product_accum(inner_reduced); + } + (outer_accum, rel_evals) + }, + |(mut ca, mut ra), (cb, rb)| { + for (ai, bi) in ca.iter_mut().zip(cb.iter()) { + *ai += *bi; + } + for (ai, bi) in ra.iter_mut().zip(rb.iter()) { + *ai += *bi; + } + (ca, ra) + } + ); + + let mut q_coeffs: Vec = + q_coeffs.drain(..).map(E::reduce_product_accum).collect(); + trim_trailing_zeros(&mut q_coeffs); + let q_poly = UniPoly::from_coeffs(q_coeffs); + ( + self.split_eq.gruen_mul(&q_poly), + UniPoly::from_evals(&rel_evals), + ) + } + }; + + self.norm_time_total += t_norm.elapsed().as_secs_f64(); + (np, rp) + } + }; + + let max_len = norm_poly.coeffs.len().max(relation_poly.coeffs.len()); + let mut combined = vec![E::zero(); max_len]; + for (i, c) in norm_poly.coeffs.iter().enumerate() { + combined[i] += self.batching_coeff * *c; + } + for (i, c) in relation_poly.coeffs.iter().enumerate() { + combined[i] += *c; + } + UniPoly::from_coeffs(combined) + } + + fn ingest_challenge(&mut self, _round: usize, r: E) { + let t_fold = Instant::now(); + let _span = tracing::info_span!("fold_round").entered(); + self.split_eq.bind(r); + + self.w_table = match mem::replace(&mut self.w_table, WTable::Full(Vec::new())) { + WTable::Compact(w_compact) => WTable::Full(Self::fold_compact_to_full(&w_compact, r)), + WTable::Full(mut w_full) => { + fold_evals_in_place(&mut w_full, r); + WTable::Full(w_full) 
+ } + }; + + if self.rounds_completed < self.num_u { + fold_evals_in_place(&mut self.m_compact, r); + } else { + fold_evals_in_place(&mut self.alpha_compact, r); + } + + drop(_span); + self.fold_time_total += t_fold.elapsed().as_secs_f64(); + self.rounds_completed += 1; + + if self.rounds_completed == self.num_vars { + eprintln!( + " [fused_sc] {} rounds: norm={:.2}s, relation={:.2}s, fold={:.2}s", + self.num_vars, self.norm_time_total, self.relation_time_total, self.fold_time_total + ); + } + } +} + +/// Fused norm+relation sumcheck verifier. +pub struct HachiSumcheckVerifier { + batching_coeff: F, + w_evals: Vec, + /// When set, overrides the `w_val` computed from `w_evals` in + /// `expected_output_claim`. Used at intermediate fold levels where + /// the full w vector is not available. + w_val_override: Option, + tau0: Vec, + b: usize, + alpha_evals_y: Vec, + m_evals_x: Vec, + num_u: usize, + num_l: usize, + relation_claim: F, + _marker: PhantomData<[F; D]>, +} + +impl HachiSumcheckVerifier { + /// Create a fused verifier for the norm + relation sumcheck. 
+ #[allow(clippy::too_many_arguments)] + #[tracing::instrument(skip_all, name = "HachiSumcheckVerifier::new")] + pub fn new( + batching_coeff: F, + w_evals: Vec, + tau0: Vec, + b: usize, + alpha_evals_y: Vec, + m_evals_x: Vec, + tau1: Vec, + v: Vec>, + u: Vec>, + y_ring: CyclotomicRing, + alpha: F, + num_u: usize, + num_l: usize, + ) -> Self { + let y_a: Vec = v + .iter() + .chain(u.iter()) + .chain(iter::once(&y_ring)) + .map(|r| eval_ring_at(r, &alpha)) + .collect(); + let eq_tau1 = EqPolynomial::evals(&tau1); + let mut relation_claim = F::zero(); + for (i, eq_i) in eq_tau1.iter().enumerate() { + let y_i = if i < y_a.len() { y_a[i] } else { F::zero() }; + relation_claim += *eq_i * y_i; + } + + Self { + batching_coeff, + w_evals, + w_val_override: None, + tau0, + b, + alpha_evals_y, + m_evals_x, + num_u, + num_l, + relation_claim, + _marker: PhantomData, + } + } + + /// Set the w_val override for intermediate fold levels where the + /// full w vector is not available. + pub fn with_w_val_override(mut self, w_val: F) -> Self { + self.w_val_override = Some(w_val); + self + } +} + +impl SumcheckInstanceVerifier + for HachiSumcheckVerifier +{ + fn num_rounds(&self) -> usize { + self.num_u + self.num_l + } + + fn degree_bound(&self) -> usize { + 2 * self.b + } + + fn input_claim(&self) -> F { + self.relation_claim + } + + fn expected_output_claim(&self, challenges: &[F]) -> Result { + let eq_val = EqPolynomial::mle(&self.tau0, challenges); + let w_val = match self.w_val_override { + Some(v) => v, + None => multilinear_eval(&self.w_evals, challenges)?, + }; + let norm_oracle = eq_val * range_check_eval(w_val, self.b); + + let (x_challenges, y_challenges) = challenges.split_at(self.num_u); + let alpha_val = multilinear_eval(&self.alpha_evals_y, y_challenges)?; + let m_val = multilinear_eval(&self.m_evals_x, x_challenges)?; + let relation_oracle = w_val * alpha_val * m_val; + + eprintln!( + " [expected_output] num_u={}, num_l={}, w_override={}, b={}, tau0.len={}, 
m_evals_x.len={}, alpha_evals_y.len={}", + self.num_u, self.num_l, self.w_val_override.is_some(), self.b, + self.tau0.len(), self.m_evals_x.len(), self.alpha_evals_y.len() + ); + + Ok(self.batching_coeff * norm_oracle + relation_oracle) + } +} diff --git a/src/protocol/sumcheck/mod.rs b/src/protocol/sumcheck/mod.rs new file mode 100644 index 00000000..bd3b88e5 --- /dev/null +++ b/src/protocol/sumcheck/mod.rs @@ -0,0 +1,229 @@ +//! Sumcheck protocol: traits, proof driver, and concrete instances. +//! +//! Types (`UniPoly`, `CompressedUniPoly`, `SumcheckProof`) live in the +//! [`types`] submodule. Polynomial evaluation utilities (`multilinear_eval`, +//! `fold_evals_in_place`, `range_check_eval`) live in [`crate::algebra::poly`]. +//! +//! ## Temporary duplication notice (Jolt integration) +//! +//! Jolt already has a mature, streaming-aware sumcheck implementation. Long-term, we +//! expect to extract the common sumcheck machinery into a dedicated crate and depend +//! on it from both Hachi and Jolt. Until that exists, this module intentionally +//! duplicates the essential sumcheck data types and transcript-driving logic as a +//! pragmatic workaround. + +pub mod batched_sumcheck; +pub mod eq_poly; +pub mod hachi_sumcheck; +pub mod norm_sumcheck; +pub mod relation_sumcheck; +pub mod split_eq; +pub mod types; + +use crate::error::HachiError; +use crate::primitives::serialization::Compress; +use crate::protocol::transcript::labels; +use crate::protocol::transcript::Transcript; +use crate::{CanonicalField, FieldCore}; + +pub use crate::algebra::poly::{ + fold_evals_in_place, multilinear_eval, multilinear_eval_small, range_check_eval, +}; +pub use types::{CompressedUniPoly, SumcheckProof, UniPoly}; + +/// Prover-side sumcheck instance interface. +/// +/// This trait encapsulates the protocol-specific logic required to compute each +/// per-round univariate polynomial `g_j(X)` and to update (fold) internal state +/// after receiving the verifier challenge `r_j`. 
+/// +/// Hachi §4.3 will implement concrete instances for `H_0` and `H_α`. +pub trait SumcheckInstanceProver: Send + Sync { + /// Number of rounds (i.e. number of variables bound by sumcheck). + fn num_rounds(&self) -> usize; + + /// Maximum allowed degree for any round univariate polynomial. + fn degree_bound(&self) -> usize; + + /// The initial claimed sum that this sumcheck instance is proving. + fn input_claim(&self) -> E; + + /// Compute the prover message `g_round(X)` given the previous running claim. + /// + /// In standard sumcheck, `previous_claim` is the expected value of the + /// remaining sum after binding previous challenges, and must satisfy: + /// + /// `g_round(0) + g_round(1) == previous_claim`. + fn compute_round_univariate(&mut self, round: usize, previous_claim: E) -> UniPoly; + + /// Ingest the verifier challenge `r_round` to fold/bind the current variable. + fn ingest_challenge(&mut self, round: usize, r_round: E); + + /// Optional end-of-protocol hook after the last challenge has been ingested. + fn finalize(&mut self) {} +} + +/// Verifier-side sumcheck instance interface. +/// +/// Implementations provide the initial claim and the oracle evaluation at the +/// challenge point, enabling the verifier to perform the final consistency check. +pub trait SumcheckInstanceVerifier: Send + Sync { + /// Number of rounds (i.e. number of variables bound by sumcheck). + fn num_rounds(&self) -> usize; + + /// Maximum allowed degree for any round univariate polynomial. + fn degree_bound(&self) -> usize; + + /// The initial claimed sum that this sumcheck instance is proving. + fn input_claim(&self) -> E; + + /// Compute the expected final evaluation `f(r_0, ..., r_{n-1})` at the + /// challenge point derived during the protocol. + /// + /// # Errors + /// + /// May return an error if internal evaluations fail (e.g., malformed + /// evaluation tables from untrusted proof data). 
+ fn expected_output_claim(&self, challenges: &[E]) -> Result; +} + +/// Produce a sumcheck proof for a single instance, driving the Fiat-Shamir transcript. +/// +/// This method: +/// - does **not** absorb the initial claim into the transcript (callers should do so), +/// - appends each round message under `labels::ABSORB_SUMCHECK_ROUND`, +/// - samples one challenge per round via `sample_challenge`, +/// - updates the running claim using the per-round hint (`g(0)+g(1)`). +/// +/// It returns the proof, the derived point `r`, and the final claimed value at `r`. +/// +/// # Errors +/// +/// Returns an error if any per-round polynomial exceeds the instance's degree bound. +#[tracing::instrument(skip_all, name = "prove_sumcheck")] +#[inline(never)] +pub fn prove_sumcheck( + instance: &mut Inst, + transcript: &mut T, + mut sample_challenge: S, +) -> Result<(SumcheckProof, Vec, E), HachiError> +where + F: FieldCore + CanonicalField, + T: Transcript, + E: FieldCore, + S: FnMut(&mut T) -> E, + Inst: SumcheckInstanceProver, +{ + let mut claim = instance.input_claim(); + { + let mut buf = Vec::new(); + claim.serialize_with_mode(&mut buf, Compress::No).ok(); + eprintln!( + " [prove_sumcheck] input_claim is_zero={}, bytes={:02x?}, num_rounds={}", + claim.is_zero(), + &buf[..buf.len().min(16)], + instance.num_rounds() + ); + } + transcript.append_serde(labels::ABSORB_SUMCHECK_CLAIM, &claim); + + let num_rounds = instance.num_rounds(); + let degree_bound = instance.degree_bound(); + + let mut round_polys = Vec::with_capacity(num_rounds); + let mut r = Vec::with_capacity(num_rounds); + + for round in 0..num_rounds { + let g = instance.compute_round_univariate(round, claim); + debug_assert!( + g.evaluate(&E::zero()) + g.evaluate(&E::one()) == claim, + "sumcheck round univariate does not match previous claim hint" + ); + + let compressed = g.compress(); + if compressed.degree() > degree_bound { + return Err(HachiError::InvalidInput(format!( + "sumcheck round poly degree {} 
exceeds bound {}", + compressed.degree(), + degree_bound + ))); + } + + transcript.append_serde(labels::ABSORB_SUMCHECK_ROUND, &compressed); + let r_i = sample_challenge(transcript); + r.push(r_i); + + claim = compressed.eval_from_hint(&claim, &r_i); + + instance.ingest_challenge(round, r_i); + round_polys.push(compressed); + } + + instance.finalize(); + + let proof = SumcheckProof { round_polys }; + Ok((proof, r, claim)) +} + +/// Verify a single-instance sumcheck proof. +/// +/// This function: +/// - absorbs the initial claim into the transcript, +/// - delegates round-by-round verification to [`SumcheckProof::verify`], +/// - performs the final oracle check: `final_claim == verifier.expected_output_claim(r)`. +/// +/// Returns the challenge point `r` on success. +/// +/// # Errors +/// +/// Returns [`HachiError::InvalidProof`] if the final sumcheck claim does not +/// match the oracle evaluation, or propagates any error from the per-round +/// verification (e.g. degree-bound violation, round-count mismatch). 
+#[tracing::instrument(skip_all, name = "verify_sumcheck")] +#[inline(never)] +pub fn verify_sumcheck( + proof: &SumcheckProof, + verifier: &V, + transcript: &mut T, + sample_challenge: S, +) -> Result, HachiError> +where + F: FieldCore + CanonicalField, + T: Transcript, + E: FieldCore, + S: FnMut(&mut T) -> E, + V: SumcheckInstanceVerifier, +{ + let claim = verifier.input_claim(); + { + let mut buf = Vec::new(); + claim.serialize_with_mode(&mut buf, Compress::No).ok(); + eprintln!( + " [verify_sumcheck] input_claim is_zero={}, bytes={:02x?}, num_rounds={}", + claim.is_zero(), + &buf[..buf.len().min(16)], + verifier.num_rounds() + ); + } + transcript.append_serde(labels::ABSORB_SUMCHECK_CLAIM, &claim); + let (final_claim, challenges) = proof.verify::( + claim, + verifier.num_rounds(), + verifier.degree_bound(), + transcript, + sample_challenge, + )?; + + let expected = verifier.expected_output_claim(&challenges)?; + if final_claim != expected { + eprintln!( + "[verify_sumcheck] MISMATCH: rounds={}, degree_bound={}", + verifier.num_rounds(), + verifier.degree_bound(), + ); + eprintln!(" diff_is_zero = {}", (final_claim - expected).is_zero()); + return Err(HachiError::InvalidProof); + } + + Ok(challenges) +} diff --git a/src/protocol/sumcheck/norm_sumcheck.rs b/src/protocol/sumcheck/norm_sumcheck.rs new file mode 100644 index 00000000..4e2ff2f5 --- /dev/null +++ b/src/protocol/sumcheck/norm_sumcheck.rs @@ -0,0 +1,1136 @@ +//! Norm (range-check) sumcheck instance (F_0). +//! +//! **F_{0,τ₀}(x, y)** = ẽq(τ₀,(x,y)) · w̃(x,y) · (w̃−1)(w̃+1)···(w̃−b+1)(w̃+b−1) +//! +//! Proves that all entries of w̃ lie in {−(b−1), …, b−1}; the sum over the +//! boolean hypercube should equal zero when the range constraint holds. 
+ +use super::eq_poly::EqPolynomial; +use super::split_eq::GruenSplitEq; +use super::{fold_evals_in_place, multilinear_eval, range_check_eval}; +use super::{SumcheckInstanceProver, SumcheckInstanceVerifier, UniPoly}; +use crate::algebra::fields::HasUnreducedOps; +use crate::error::HachiError; +#[cfg(feature = "parallel")] +use crate::parallel::*; +use crate::{cfg_fold_reduce, AdditiveGroup, CanonicalField, FieldCore, FromSmallInt}; + +/// Max number of affine coefficient rows (degree_q + 1) for `b <= 8`. +pub(crate) const MAX_AFFINE_COEFFS: usize = 16; + +/// Which kernel to use for the norm sumcheck accumulation loop. +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub enum NormRoundKernel { + /// Evaluate the range-check polynomial at `degree_q+1` points, then interpolate. + PointEvalInterpolation, + /// Directly accumulate polynomial coefficients via affine substitution. + AffineCoeffComposition, +} + +/// Select the norm kernel for a given `b`. +/// +/// Override with env var `HACHI_NORM_KERNEL=point_eval` or `affine_coeff`. +pub fn choose_round_kernel(b: usize) -> NormRoundKernel { + if let Ok(v) = std::env::var("HACHI_NORM_KERNEL") { + match v.as_str() { + "point_eval" => return NormRoundKernel::PointEvalInterpolation, + "affine_coeff" => return NormRoundKernel::AffineCoeffComposition, + _ => {} + } + } + if b <= 8 { + NormRoundKernel::AffineCoeffComposition + } else { + NormRoundKernel::PointEvalInterpolation + } +} + +/// A nonzero coefficient entry in the affine decomposition polynomial. +#[derive(Clone, Copy, Debug)] +pub(crate) struct SparseCoeffEntry { + /// Power index: which `w_0^k` this coefficient multiplies. + pub k: u8, + /// Absolute value of the mixed coefficient (fits u64 for b <= 8). + pub abs_coeff: u64, + /// Sign: true if the coefficient is negative. + pub is_neg: bool, +} + +#[derive(Clone)] +pub(crate) struct RangeAffinePrecomp { + /// Flat storage of nonzero `coeff_mix[i][k]` entries. 
+ sparse_entries: Vec, + /// `sparse_row_offsets[i]..sparse_row_offsets[i+1]` indexes into `sparse_entries`. + sparse_row_offsets: Vec, + pub(crate) degree_q: usize, + /// Precomputed `h_i(w_0)` for all small-integer `w_0 ∈ {-(b-1),...,b-1}`. + /// Indexed as `small_w_lut[(w_0 + b - 1) * num_rows + i]`. + small_w_lut: Vec, + b: usize, +} + +/// Integer version of `range_check_coeffs`: returns the polynomial coefficients +/// of `R(w) = w * Π_{k=1}^{b-1}(w² - k²)` as exact i64 values. +fn range_check_coeffs_int(b: usize) -> Vec { + assert!(b >= 1, "b must be at least 1"); + let mut coeffs: Vec = vec![0, 1]; + for k in 1..b as i64 { + let k_sq = k * k; + let mut next = vec![0i64; coeffs.len() + 2]; + for (idx, &c) in coeffs.iter().enumerate() { + next[idx] -= c * k_sq; + next[idx + 2] += c; + } + coeffs = next; + } + coeffs +} + +impl RangeAffinePrecomp { + pub(crate) fn new(b: usize) -> Self { + assert!(b >= 1, "b must be at least 1"); + + let range_coeffs = range_check_coeffs_int(b); + let degree_q = range_coeffs.len() - 1; + let num_rows = degree_q + 1; + + // Build dense integer coeff_mix and sparse entries simultaneously. 
+ let total_elems = num_rows * (num_rows + 1) / 2; + let mut dense_int = Vec::with_capacity(total_elems); + let mut dense_row_offsets = Vec::with_capacity(num_rows + 1); + let mut sparse_entries = Vec::new(); + let mut sparse_row_offsets = Vec::with_capacity(num_rows + 1); + + for i in 0..num_rows { + dense_row_offsets.push(dense_int.len()); + sparse_row_offsets.push(sparse_entries.len()); + let row_len = degree_q - i + 1; + let mut binom: i64 = 1; // binom(i, i) = 1 + for k in 0..row_len { + let m = i + k; + let coeff = range_coeffs[m] * binom; + dense_int.push(coeff); + if coeff != 0 { + sparse_entries.push(SparseCoeffEntry { + k: k as u8, + abs_coeff: coeff.unsigned_abs(), + is_neg: coeff < 0, + }); + } + if k + 1 < row_len { + binom = binom * (m as i64 + 1) / (k as i64 + 1); + } + } + } + dense_row_offsets.push(dense_int.len()); + sparse_row_offsets.push(sparse_entries.len()); + + // Precompute LUT using i128 integer Horner. + let num_w_vals = 2 * b - 1; + let mut small_w_lut = vec![E::zero(); num_w_vals * num_rows]; + for (w_idx, w_0_int) in (-(b as i64 - 1)..=(b as i64 - 1)).enumerate() { + for i in 0..num_rows { + let row = &dense_int[dense_row_offsets[i]..dense_row_offsets[i + 1]]; + let mut h: i128 = 0; + for &c in row.iter().rev() { + h = h * w_0_int as i128 + c as i128; + } + small_w_lut[w_idx * num_rows + i] = E::from_i128(h); + } + } + + Self { + sparse_entries, + sparse_row_offsets, + degree_q, + small_w_lut, + b, + } + } +} + +impl RangeAffinePrecomp { + #[inline] + pub(crate) fn sparse_row(&self, i: usize) -> &[SparseCoeffEntry] { + &self.sparse_entries[self.sparse_row_offsets[i]..self.sparse_row_offsets[i + 1]] + } + + pub(crate) fn num_rows(&self) -> usize { + self.degree_q + 1 + } + + #[inline] + pub(crate) fn h_i_lut(&self, w_0_int: i8, i: usize) -> E { + let w_idx = (w_0_int as i16 + self.b as i16 - 1) as usize; + self.small_w_lut[w_idx * self.num_rows() + i] + } +} + +#[derive(Clone)] +pub(crate) struct PointEvalPrecomp { + pub(crate) 
range_offsets_sq: Vec, +} + +impl PointEvalPrecomp { + pub(crate) fn new(b: usize) -> Self { + assert!(b >= 1, "b must be at least 1"); + let range_offsets_sq = (1..b) + .map(|k| { + let k_e = E::from_u64(k as u64); + k_e * k_e + }) + .collect(); + Self { range_offsets_sq } + } +} + +/// Evaluate `R(w) = w · Π_{k=1}^{b-1}(w² - k²)` in native `i128` arithmetic. +/// +/// Only valid for `b <= 10` (intermediates fit i128; verified up to ~2^117 for b=8). +/// Panics in debug mode if an intermediate overflows. +#[inline] +pub(crate) fn range_check_eval_i128(w: i32, b: usize) -> i128 { + debug_assert!(b <= 10, "i128 range-check only valid for b <= 10"); + let s = (w as i128) * (w as i128); + let mut acc = w as i128; + for k in 1..b as i128 { + acc = acc + .checked_mul(s - k * k) + .expect("i128 overflow in range-check"); + } + acc +} + +/// Convert an `i128` to a field element via `CanonicalField::from_canonical_u128_reduced`. +#[inline] +pub(crate) fn field_from_i128(val: i128) -> E { + if val >= 0 { + E::from_canonical_u128_reduced(val as u128) + } else { + -E::from_canonical_u128_reduced(val.unsigned_abs()) + } +} + +pub(crate) fn range_check_eval_precomputed(w: E, offsets_sq: &[E]) -> E { + let s = w * w; + let mut acc = w; + for &k_sq in offsets_sq { + acc = acc * (s - k_sq); + } + acc +} + +/// Compute per-entry affine range-check coefficients using power-table + +/// sparse unreduced dot product. Writes `a^i · h_i(w_0)` into `out[i]` +/// for `i ∈ 0..precomp.num_rows()`. +/// +/// `w_pows` is a caller-provided scratch buffer of length >= `degree_q + 1`. 
+#[inline] +pub(crate) fn compute_entry_coeffs( + out: &mut [E], + w_pows: &mut [E], + precomp: &RangeAffinePrecomp, + w_0: E, + a: E, +) { + let deg = precomp.degree_q; + let num_rows = precomp.num_rows(); + debug_assert!(out.len() >= num_rows); + debug_assert!(w_pows.len() > deg); + + w_pows[0] = E::one(); + for k in 1..=deg { + w_pows[k] = w_pows[k - 1] * w_0; + } + + let mut a_pow = E::one(); + for (i, out_i) in out.iter_mut().enumerate().take(num_rows) { + let entries = precomp.sparse_row(i); + let mut pos = E::MulU64Accum::ZERO; + let mut neg = E::MulU64Accum::ZERO; + for entry in entries { + let prod = w_pows[entry.k as usize].mul_u64_unreduced(entry.abs_coeff); + if entry.is_neg { + neg += prod; + } else { + pos += prod; + } + } + let h_i = E::reduce_mul_u64_accum(pos) - E::reduce_mul_u64_accum(neg); + *out_i = a_pow * h_i; + a_pow = a_pow * a; + } +} + +/// Batched version: processes 4 entries simultaneously to expose ILP across +/// independent power-table and sparse-dot-product chains. 
+#[inline] +pub(crate) fn compute_entry_coeffs_x4( + out: &mut [[E; MAX_AFFINE_COEFFS]; 4], + precomp: &RangeAffinePrecomp, + w_0: [E; 4], + a: [E; 4], +) { + let deg = precomp.degree_q; + let num_rows = precomp.num_rows(); + + let mut pw = [[E::zero(); MAX_AFFINE_COEFFS]; 4]; + for p in &mut pw { + p[0] = E::one(); + } + for k in 1..=deg { + pw[0][k] = pw[0][k - 1] * w_0[0]; + pw[1][k] = pw[1][k - 1] * w_0[1]; + pw[2][k] = pw[2][k - 1] * w_0[2]; + pw[3][k] = pw[3][k - 1] * w_0[3]; + } + + let mut ap = [E::one(); 4]; + for i in 0..num_rows { + let entries = precomp.sparse_row(i); + + let mut pos0 = E::MulU64Accum::ZERO; + let mut neg0 = E::MulU64Accum::ZERO; + let mut pos1 = E::MulU64Accum::ZERO; + let mut neg1 = E::MulU64Accum::ZERO; + let mut pos2 = E::MulU64Accum::ZERO; + let mut neg2 = E::MulU64Accum::ZERO; + let mut pos3 = E::MulU64Accum::ZERO; + let mut neg3 = E::MulU64Accum::ZERO; + + for entry in entries { + let k = entry.k as usize; + let c = entry.abs_coeff; + let p0 = pw[0][k].mul_u64_unreduced(c); + let p1 = pw[1][k].mul_u64_unreduced(c); + let p2 = pw[2][k].mul_u64_unreduced(c); + let p3 = pw[3][k].mul_u64_unreduced(c); + if entry.is_neg { + neg0 += p0; + neg1 += p1; + neg2 += p2; + neg3 += p3; + } else { + pos0 += p0; + pos1 += p1; + pos2 += p2; + pos3 += p3; + } + } + + let h0 = E::reduce_mul_u64_accum(pos0) - E::reduce_mul_u64_accum(neg0); + let h1 = E::reduce_mul_u64_accum(pos1) - E::reduce_mul_u64_accum(neg1); + let h2 = E::reduce_mul_u64_accum(pos2) - E::reduce_mul_u64_accum(neg2); + let h3 = E::reduce_mul_u64_accum(pos3) - E::reduce_mul_u64_accum(neg3); + + out[0][i] = ap[0] * h0; + out[1][i] = ap[1] * h1; + out[2][i] = ap[2] * h2; + out[3][i] = ap[3] * h3; + + ap[0] = ap[0] * a[0]; + ap[1] = ap[1] * a[1]; + ap[2] = ap[2] * a[2]; + ap[3] = ap[3] * a[3]; + } +} + +pub(crate) fn trim_trailing_zeros(coeffs: &mut Vec) { + while coeffs.len() > 1 && coeffs.last().is_some_and(|c| c.is_zero()) { + coeffs.pop(); + } +} + +/// Centralized norm round 
polynomial computation (full field-element path). +/// +/// Both `NormSumcheckProver` and `HachiSumcheckProver` delegate here. +pub(crate) fn compute_norm_round_poly( + split_eq: &GruenSplitEq, + half: usize, + b: usize, + round_kernel: NormRoundKernel, + point_precomp: Option<&PointEvalPrecomp>, + range_precomp: Option<&RangeAffinePrecomp>, + w_pair: impl Fn(usize) -> (E, E) + Sync, +) -> UniPoly { + let (e_first, e_second) = split_eq.remaining_eq_tables(); + let num_first = e_first.len(); + let first_bits = num_first.trailing_zeros(); + + match round_kernel { + NormRoundKernel::PointEvalInterpolation => { + let degree_q = 2 * b - 1; + let num_points_q = degree_q + 1; + let offsets_sq = &point_precomp.unwrap().range_offsets_sq; + + let q_evals = { + let _span = tracing::info_span!("norm_accumulate", kernel = "point_eval").entered(); + cfg_fold_reduce!( + 0..half, + || vec![E::zero(); num_points_q], + |mut evals, j| { + let j_low = j & (num_first - 1); + let j_high = j >> first_bits; + let eq_rem = e_first[j_low] * e_second[j_high]; + let (w_0, w_1) = w_pair(j); + let delta = w_1 - w_0; + let mut w_t = w_0; + for eval in evals.iter_mut() { + *eval += eq_rem * range_check_eval_precomputed(w_t, offsets_sq); + w_t += delta; + } + evals + }, + |mut a, b_vec| { + for (ai, bi) in a.iter_mut().zip(b_vec.iter()) { + *ai += *bi; + } + a + } + ) + }; + + let q_poly = UniPoly::from_evals(&q_evals); + split_eq.gruen_mul(&q_poly) + } + NormRoundKernel::AffineCoeffComposition => { + let rp = range_precomp.unwrap(); + let num_coeffs_q = rp.degree_q + 1; + + let mut q_coeffs = { + let _span = + tracing::info_span!("norm_accumulate", kernel = "affine_coeff").entered(); + + cfg_fold_reduce!( + 0..e_second.len(), + || vec![E::ProductAccum::ZERO; num_coeffs_q], + |mut outer_accum, j_high| { + debug_assert!(num_coeffs_q <= MAX_AFFINE_COEFFS); + let mut inner_accum = [E::ProductAccum::ZERO; MAX_AFFINE_COEFFS]; + let base_j = j_high * num_first; + let full_chunks = e_first.len() / 4; + 
let mut batch_out = [[E::zero(); MAX_AFFINE_COEFFS]; 4]; + + for chunk in 0..full_chunks { + let jl = chunk * 4; + let pairs = [ + w_pair(base_j + jl), + w_pair(base_j + jl + 1), + w_pair(base_j + jl + 2), + w_pair(base_j + jl + 3), + ]; + compute_entry_coeffs_x4( + &mut batch_out, + rp, + [pairs[0].0, pairs[1].0, pairs[2].0, pairs[3].0], + [ + pairs[0].1 - pairs[0].0, + pairs[1].1 - pairs[1].0, + pairs[2].1 - pairs[2].0, + pairs[3].1 - pairs[3].0, + ], + ); + for (b, bo) in batch_out.iter().enumerate() { + let e_in = e_first[jl + b]; + for (acc, &entry) in inner_accum[..num_coeffs_q] + .iter_mut() + .zip(bo[..num_coeffs_q].iter()) + { + *acc += e_in.mul_to_product_accum(entry); + } + } + } + + let mut entry_buf = [E::zero(); MAX_AFFINE_COEFFS]; + let mut w_pows_buf = [E::zero(); MAX_AFFINE_COEFFS]; + for (tail_idx, &e_in) in e_first[full_chunks * 4..].iter().enumerate() { + let j = base_j + full_chunks * 4 + tail_idx; + let (w_0, w_1) = w_pair(j); + compute_entry_coeffs( + &mut entry_buf, + &mut w_pows_buf, + rp, + w_0, + w_1 - w_0, + ); + for (acc, &entry) in inner_accum[..num_coeffs_q] + .iter_mut() + .zip(entry_buf[..num_coeffs_q].iter()) + { + *acc += e_in.mul_to_product_accum(entry); + } + } + + let e_out = e_second[j_high]; + for k in 0..num_coeffs_q { + let inner_reduced = E::reduce_product_accum(inner_accum[k]); + outer_accum[k] += e_out.mul_to_product_accum(inner_reduced); + } + outer_accum + }, + |mut a, b_vec| { + for (ai, bi) in a.iter_mut().zip(b_vec.iter()) { + *ai += *bi; + } + a + } + ) + } + .into_iter() + .map(E::reduce_product_accum) + .collect::>(); + + trim_trailing_zeros(&mut q_coeffs); + let q_poly = UniPoly::from_coeffs(q_coeffs); + split_eq.gruen_mul(&q_poly) + } + } +} + +/// Compact round-0 variant: uses native i128 arithmetic (point-eval, b<=10) +/// or precomputed LUT (affine-coeff) when w values are small integers. 
+pub(crate) fn compute_norm_round_poly_compact< + E: FieldCore + FromSmallInt + CanonicalField + HasUnreducedOps, +>( + split_eq: &GruenSplitEq, + w_compact: &[i8], + b: usize, + round_kernel: NormRoundKernel, + point_precomp: Option<&PointEvalPrecomp>, + range_precomp: Option<&RangeAffinePrecomp>, +) -> UniPoly { + let half = w_compact.len() / 2; + let (e_first, e_second) = split_eq.remaining_eq_tables(); + let num_first = e_first.len(); + let first_bits = num_first.trailing_zeros(); + + match round_kernel { + NormRoundKernel::PointEvalInterpolation if b <= 10 => { + let degree_q = 2 * b - 1; + let num_points_q = degree_q + 1; + + let q_evals = { + let _span = + tracing::info_span!("norm_accumulate", kernel = "point_eval_i128").entered(); + cfg_fold_reduce!( + 0..half, + || vec![E::zero(); num_points_q], + |mut evals, j| { + let j_low = j & (num_first - 1); + let j_high = j >> first_bits; + let eq_rem = e_first[j_low] * e_second[j_high]; + let w0_i = w_compact[2 * j] as i32; + let delta_i = w_compact[2 * j + 1] as i32 - w0_i; + let mut w_t_i = w0_i; + for eval in evals.iter_mut() { + let rc = range_check_eval_i128(w_t_i, b); + *eval += eq_rem * field_from_i128::(rc); + w_t_i += delta_i; + } + evals + }, + |mut a, b_vec| { + for (ai, bi) in a.iter_mut().zip(b_vec.iter()) { + *ai += *bi; + } + a + } + ) + }; + + let q_poly = UniPoly::from_evals(&q_evals); + split_eq.gruen_mul(&q_poly) + } + NormRoundKernel::AffineCoeffComposition => { + let rp = range_precomp.unwrap(); + let num_coeffs_q = rp.degree_q + 1; + + let mut q_coeffs = { + let _span = + tracing::info_span!("norm_accumulate", kernel = "affine_coeff_lut").entered(); + + cfg_fold_reduce!( + 0..e_second.len(), + || vec![E::ProductAccum::ZERO; num_coeffs_q], + |mut outer_accum, j_high| { + debug_assert!(num_coeffs_q <= MAX_AFFINE_COEFFS); + let mut inner_accum = [E::ProductAccum::ZERO; MAX_AFFINE_COEFFS]; + for (j_low, &e_in) in e_first.iter().enumerate() { + let j = j_high * num_first + j_low; + let w_0_int = 
w_compact[2 * j]; + let w_1 = E::from_i64(w_compact[2 * j + 1] as i64); + let a = w_1 - E::from_i64(w_0_int as i64); + let mut a_pow = E::one(); + for (i, acc) in inner_accum[..num_coeffs_q].iter_mut().enumerate() { + let h_i_w0 = rp.h_i_lut(w_0_int, i); + let val = a_pow * h_i_w0; + *acc += e_in.mul_to_product_accum(val); + a_pow = a_pow * a; + } + } + let e_out = e_second[j_high]; + for k in 0..num_coeffs_q { + let inner_reduced = E::reduce_product_accum(inner_accum[k]); + outer_accum[k] += e_out.mul_to_product_accum(inner_reduced); + } + outer_accum + }, + |mut a, b_vec| { + for (ai, bi) in a.iter_mut().zip(b_vec.iter()) { + *ai += *bi; + } + a + } + ) + } + .into_iter() + .map(E::reduce_product_accum) + .collect::>(); + + trim_trailing_zeros(&mut q_coeffs); + let q_poly = UniPoly::from_coeffs(q_coeffs); + split_eq.gruen_mul(&q_poly) + } + _ => { + // b > 10 with point-eval: fall back to field-element path + let pair = |j: usize| { + ( + E::from_i64(w_compact[2 * j] as i64), + E::from_i64(w_compact[2 * j + 1] as i64), + ) + }; + compute_norm_round_poly( + split_eq, + half, + b, + round_kernel, + point_precomp, + range_precomp, + pair, + ) + } + } +} + +/// Prover for `F_{0,τ₀}(x,y) = ẽq(τ₀,(x,y)) · w̃(x,y) · range_check(w̃(x,y), b)`. +/// +/// Uses the Gruen/Dao-Thaler optimization: the eq polynomial is factored into +/// a running scalar and split tables instead of being stored as a full table +/// and folded each round. The round polynomial is computed as `l(X) · q(X)` +/// where `l(X)` is the linear eq factor and `q(X)` is the inner sum without +/// the current-variable eq contribution. +pub struct NormSumcheckProver { + split_eq: GruenSplitEq, + w_table: Vec, + round_kernel: NormRoundKernel, + point_precomp: Option>, + range_precomp: Option>, + num_vars: usize, + b: usize, +} + +impl NormSumcheckProver { + /// Create a new norm (range-check) sumcheck prover. + /// + /// # Panics + /// + /// Panics if `w_evals.len() != 2^tau.len()`. 
+ pub fn new(tau: &[E], w_evals: Vec, b: usize) -> Self { + Self::new_with_kernel(tau, w_evals, b, choose_round_kernel(b)) + } + + fn new_with_kernel( + tau: &[E], + w_evals: Vec, + b: usize, + round_kernel: NormRoundKernel, + ) -> Self { + assert!(b >= 1, "b must be at least 1"); + let num_vars = tau.len(); + assert_eq!(w_evals.len(), 1 << num_vars); + let point_precomp = match round_kernel { + NormRoundKernel::PointEvalInterpolation => Some(PointEvalPrecomp::new(b)), + NormRoundKernel::AffineCoeffComposition => None, + }; + let range_precomp = match round_kernel { + NormRoundKernel::PointEvalInterpolation => None, + NormRoundKernel::AffineCoeffComposition => Some(RangeAffinePrecomp::new(b)), + }; + Self { + split_eq: GruenSplitEq::new(tau), + w_table: w_evals, + round_kernel, + point_precomp, + range_precomp, + num_vars, + b, + } + } +} + +impl SumcheckInstanceProver + for NormSumcheckProver +{ + fn num_rounds(&self) -> usize { + self.num_vars + } + + fn degree_bound(&self) -> usize { + 2 * self.b + } + + fn input_claim(&self) -> E { + E::zero() + } + + fn compute_round_univariate(&mut self, _round: usize, _previous_claim: E) -> UniPoly { + let half = self.w_table.len() / 2; + let w_table = &self.w_table; + compute_norm_round_poly( + &self.split_eq, + half, + self.b, + self.round_kernel, + self.point_precomp.as_ref(), + self.range_precomp.as_ref(), + |j| (w_table[2 * j], w_table[2 * j + 1]), + ) + } + + fn ingest_challenge(&mut self, _round: usize, r: E) { + self.split_eq.bind(r); + fold_evals_in_place(&mut self.w_table, r); + } +} + +/// Verifier for the norm (range-check) sumcheck `F_{0,τ₀}`. +pub struct NormSumcheckVerifier { + tau: Vec, + w_evals: Vec, + num_vars: usize, + b: usize, +} + +impl NormSumcheckVerifier { + /// Create a new norm (range-check) sumcheck verifier. + /// + /// # Panics + /// + /// Panics if `w_evals.len() != 2^tau.len()`. 
+ pub fn new(tau: Vec, w_evals: Vec, b: usize) -> Self { + let num_vars = tau.len(); + assert_eq!(w_evals.len(), 1 << num_vars); + Self { + tau, + w_evals, + num_vars, + b, + } + } +} + +impl SumcheckInstanceVerifier for NormSumcheckVerifier { + fn num_rounds(&self) -> usize { + self.num_vars + } + + fn degree_bound(&self) -> usize { + 2 * self.b + } + + fn input_claim(&self) -> E { + E::zero() + } + + fn expected_output_claim(&self, challenges: &[E]) -> Result { + let eq_val = EqPolynomial::mle(&self.tau, challenges); + let w_val = multilinear_eval(&self.w_evals, challenges)?; + Ok(eq_val * range_check_eval(w_val, self.b)) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::algebra::fields::ext::Ext2; + use crate::algebra::fields::lift::LiftBase; + use crate::algebra::ring::CyclotomicRing; + use crate::algebra::Fp64; + use crate::protocol::hachi_poly_ops::DensePoly; + use crate::protocol::opening_point::BasisMode; + use crate::protocol::ring_switch::r_decomp_levels; + use crate::protocol::sumcheck::eq_poly::EqPolynomial; + use crate::protocol::sumcheck::multilinear_eval; + use crate::protocol::transcript::labels; + use crate::protocol::{ + prove_sumcheck, verify_sumcheck, Blake2bTranscript, CommitmentConfig, CommitmentScheme, + HachiCommitmentScheme, SmallTestCommitmentConfig, Transcript, + }; + use crate::{FieldCore, FromSmallInt}; + use rand::rngs::StdRng; + use rand::SeedableRng; + use std::array::from_fn; + + type F = Fp64<4294967197>; + type Cfg = SmallTestCommitmentConfig; + const D: usize = Cfg::D; + type Scheme = HachiCommitmentScheme; + + struct PointEvalReferenceNormSumcheckProver { + split_eq: GruenSplitEq, + w_table: Vec, + num_vars: usize, + b: usize, + } + + impl PointEvalReferenceNormSumcheckProver { + fn new(tau: &[E], w_evals: Vec, b: usize) -> Self { + let num_vars = tau.len(); + assert_eq!(w_evals.len(), 1 << num_vars); + Self { + split_eq: GruenSplitEq::new(tau), + w_table: w_evals, + num_vars, + b, + } + } + } + + impl 
SumcheckInstanceProver + for PointEvalReferenceNormSumcheckProver + { + fn num_rounds(&self) -> usize { + self.num_vars + } + + fn degree_bound(&self) -> usize { + 2 * self.b + } + + fn input_claim(&self) -> E { + E::zero() + } + + fn compute_round_univariate(&mut self, _round: usize, _previous_claim: E) -> UniPoly { + let half = self.w_table.len() / 2; + let degree_q = 2 * self.b - 1; + let num_points_q = degree_q + 1; + + let (e_first, e_second) = self.split_eq.remaining_eq_tables(); + let num_first = e_first.len(); + let first_bits = num_first.trailing_zeros(); + let b = self.b; + + let mut q_evals = vec![E::zero(); num_points_q]; + for j in 0..half { + let j_low = j & (num_first - 1); + let j_high = j >> first_bits; + let eq_rem = e_first[j_low] * e_second[j_high]; + let w_0 = self.w_table[2 * j]; + let w_1 = self.w_table[2 * j + 1]; + for (t, eval) in q_evals.iter_mut().enumerate() { + let t_e = E::from_u64(t as u64); + let w_t = w_0 + t_e * (w_1 - w_0); + *eval += eq_rem * range_check_eval(w_t, b); + } + } + + let q_poly = UniPoly::from_evals(&q_evals); + self.split_eq.gruen_mul(&q_poly) + } + + fn ingest_challenge(&mut self, _round: usize, r: E) { + self.split_eq.bind(r); + fold_evals_in_place(&mut self.w_table, r); + } + } + + fn ring_with_small_coeff(value: u64) -> CyclotomicRing { + let coeffs = from_fn(|_| F::from_u64(value)); + CyclotomicRing::from_coefficients(coeffs) + } + + #[test] + fn norm_sumcheck_runtime_dispatch_matches_reference_kernels() { + let mut rng = StdRng::seed_from_u64(0xC0FFEE); + for (case_idx, b) in [4usize, 8, 16].into_iter().enumerate() { + let case_idx = case_idx as u64; + let num_vars = 6; + let n = 1usize << num_vars; + let w_evals: Vec = (0..n) + .map(|i| F::from_u64((i as u64 * 31 + case_idx * 17) % b as u64)) + .collect(); + let tau: Vec = (0..num_vars) + .map(|_| F::from_u64(rand::Rng::gen_range(&mut rng, 1u64..=257))) + .collect(); + + let mut dispatched = NormSumcheckProver::new(&tau, w_evals.clone(), b); + let mut 
point_eval = NormSumcheckProver::new_with_kernel( + &tau, + w_evals.clone(), + b, + NormRoundKernel::PointEvalInterpolation, + ); + let use_affine = b <= 8; + let mut affine_coeff = if use_affine { + Some(NormSumcheckProver::new_with_kernel( + &tau, + w_evals.clone(), + b, + NormRoundKernel::AffineCoeffComposition, + )) + } else { + None + }; + let mut reference = PointEvalReferenceNormSumcheckProver::new(&tau, w_evals, b); + + let mut claim_dispatched = F::zero(); + let mut claim_point = F::zero(); + let mut claim_affine = F::zero(); + let mut claim_reference = F::zero(); + for round in 0..num_vars { + let g_dispatch = dispatched.compute_round_univariate(round, claim_dispatched); + let g_point = point_eval.compute_round_univariate(round, claim_point); + let g_affine = affine_coeff + .as_mut() + .map(|p| p.compute_round_univariate(round, claim_affine)); + let g_ref = reference.compute_round_univariate(round, claim_reference); + + assert_eq!( + g_point, g_ref, + "point-eval mismatch for case {case_idx} round {round}" + ); + if let Some(ref ga) = g_affine { + assert_eq!( + *ga, g_ref, + "affine-coeff mismatch for case {case_idx} round {round}" + ); + } + match choose_round_kernel(b) { + NormRoundKernel::PointEvalInterpolation => { + assert_eq!( + g_dispatch, g_point, + "dispatch mismatch for case {case_idx} round {round}" + ); + } + NormRoundKernel::AffineCoeffComposition => { + assert_eq!( + g_dispatch, + g_affine.as_ref().unwrap().clone(), + "dispatch mismatch for case {case_idx} round {round}" + ); + } + } + + assert_eq!( + g_dispatch.evaluate(&F::zero()) + g_dispatch.evaluate(&F::one()), + claim_dispatched, + "dispatched hint mismatch for case {case_idx} round {round}" + ); + assert_eq!( + g_ref.evaluate(&F::zero()) + g_ref.evaluate(&F::one()), + claim_reference, + "reference hint mismatch for case {case_idx} round {round}" + ); + + let r = F::from_u64(rand::Rng::gen_range(&mut rng, 1u64..=257)); + claim_dispatched = g_dispatch.evaluate(&r); + claim_point = 
g_point.evaluate(&r); + if let Some(ref ga) = g_affine { + claim_affine = ga.evaluate(&r); + } + claim_reference = g_ref.evaluate(&r); + dispatched.ingest_challenge(round, r); + point_eval.ingest_challenge(round, r); + if let Some(ref mut p) = affine_coeff { + p.ingest_challenge(round, r); + } + reference.ingest_challenge(round, r); + } + assert_eq!( + claim_dispatched, claim_reference, + "final dispatched claim mismatch for case {case_idx}" + ); + assert_eq!( + claim_point, claim_reference, + "final point claim mismatch for case {case_idx}" + ); + if use_affine { + assert_eq!( + claim_affine, claim_reference, + "final affine claim mismatch for case {case_idx}" + ); + } + } + } + + #[test] + fn norm_sumcheck_uses_commitment_w_evals() { + let z = [ + ring_with_small_coeff(1), + ring_with_small_coeff(2), + ring_with_small_coeff(3), + ]; + let r = [ring_with_small_coeff(0), ring_with_small_coeff(1)]; + let log_basis = SmallTestCommitmentConfig::decomposition().log_basis; + let levels = r_decomp_levels::(log_basis); + let r_hat: Vec> = r + .iter() + .flat_map(|ri| ri.balanced_decompose_pow2(levels, log_basis)) + .collect(); + let mut w_evals: Vec = z + .iter() + .chain(r_hat.iter()) + .flat_map(|elem| elem.coefficients().iter().copied()) + .collect(); + + let target_len = w_evals.len().next_power_of_two(); + w_evals.resize(target_len, F::zero()); + let num_vars = target_len.trailing_zeros() as usize; + let tau: Vec = (0..num_vars).map(|i| F::from_u64((i + 2) as u64)).collect(); + let b = 1usize << SmallTestCommitmentConfig::decomposition().log_basis; + + let eq_table = EqPolynomial::evals(&tau); + let _claim: F = (0..w_evals.len()) + .map(|i| eq_table[i] * range_check_eval(w_evals[i], b)) + .fold(F::zero(), |a, v| a + v); + + let mut prover = NormSumcheckProver::new(&tau, w_evals.clone(), b); + let mut pt = Blake2bTranscript::::new(labels::DOMAIN_HACHI_PROTOCOL); + let (proof, prover_challenges, final_claim) = + prove_sumcheck::(&mut prover, &mut pt, |tr| { + 
tr.challenge_scalar(labels::CHALLENGE_SUMCHECK_ROUND) + }) + .unwrap(); + + let oracle = EqPolynomial::mle(&tau, &prover_challenges) + * range_check_eval(multilinear_eval(&w_evals, &prover_challenges).unwrap(), b); + assert_eq!(final_claim, oracle, "prover final claim != oracle eval"); + + let verifier = NormSumcheckVerifier::new(tau, w_evals, b); + let mut vt = Blake2bTranscript::::new(labels::DOMAIN_HACHI_PROTOCOL); + let verifier_challenges = + verify_sumcheck::(&proof, &verifier, &mut vt, |tr| { + tr.challenge_scalar(labels::CHALLENGE_SUMCHECK_ROUND) + }) + .unwrap(); + + assert_eq!(prover_challenges, verifier_challenges); + } + + #[test] + fn norm_sumcheck_uses_prove_w_evals() { + let alpha = SmallTestCommitmentConfig::D.trailing_zeros() as usize; + let layout = SmallTestCommitmentConfig::commitment_layout(8).unwrap(); + let num_vars = layout.m_vars + layout.r_vars + alpha; + let len = 1usize << num_vars; + let evals: Vec = (0..len).map(|i| F::from_u64(i as u64)).collect(); + let poly = DensePoly::::from_field_evals(num_vars, &evals).unwrap(); + + let setup = Scheme::setup_prover(num_vars); + let (commitment, hint) = Scheme::commit(&poly, &setup, &layout).unwrap(); + + let opening_point: Vec = (0..num_vars).map(|i| F::from_u64((i + 2) as u64)).collect(); + let mut prover_transcript = Blake2bTranscript::::new(labels::DOMAIN_HACHI_PROTOCOL); + let proof = Scheme::prove( + &setup, + &poly, + &opening_point, + hint, + &mut prover_transcript, + &commitment, + BasisMode::Lagrange, + &layout, + ) + .unwrap(); + + let mut w_evals: Vec = proof.final_w.to_field_elems(); + let target_len = w_evals.len().next_power_of_two(); + w_evals.resize(target_len, F::zero()); + let num_sumcheck_vars = target_len.trailing_zeros() as usize; + let tau: Vec = (0..num_sumcheck_vars) + .map(|i| F::from_u64((i + 3) as u64)) + .collect(); + let b = 1usize << SmallTestCommitmentConfig::decomposition().log_basis; + + let eq_table = EqPolynomial::evals(&tau); + let _claim: F = 
(0..w_evals.len()) + .map(|i| eq_table[i] * range_check_eval(w_evals[i], b)) + .fold(F::zero(), |a, v| a + v); + + let mut prover = NormSumcheckProver::new(&tau, w_evals.clone(), b); + let mut pt = Blake2bTranscript::::new(labels::DOMAIN_HACHI_PROTOCOL); + let (proof_sc, prover_challenges, final_claim) = + prove_sumcheck::(&mut prover, &mut pt, |tr| { + tr.challenge_scalar(labels::CHALLENGE_SUMCHECK_ROUND) + }) + .unwrap(); + + let oracle = EqPolynomial::mle(&tau, &prover_challenges) + * range_check_eval(multilinear_eval(&w_evals, &prover_challenges).unwrap(), b); + assert_eq!(final_claim, oracle, "prover final claim != oracle eval"); + + let verifier = NormSumcheckVerifier::new(tau, w_evals, b); + let mut vt = Blake2bTranscript::::new(labels::DOMAIN_HACHI_PROTOCOL); + let verifier_challenges = + verify_sumcheck::(&proof_sc, &verifier, &mut vt, |tr| { + tr.challenge_scalar(labels::CHALLENGE_SUMCHECK_ROUND) + }) + .unwrap(); + + assert_eq!(prover_challenges, verifier_challenges); + } + + #[test] + fn norm_sumcheck_over_ext2() { + type E2 = Ext2; + + let num_vars = 3; + let n = 1usize << num_vars; + let b = 2; + let w_evals_f: Vec = (0..n).map(|i| F::from_u64(i as u64 % b as u64)).collect(); + let tau_f: Vec = (0..num_vars).map(|i| F::from_u64((i + 2) as u64)).collect(); + + let w_evals_e: Vec = w_evals_f.iter().map(|&f| E2::lift_base(f)).collect(); + let tau_e: Vec = tau_f.iter().map(|&f| E2::lift_base(f)).collect(); + + let mut prover = NormSumcheckProver::new(&tau_e, w_evals_e.clone(), b); + + let mut pt = Blake2bTranscript::::new(labels::DOMAIN_HACHI_PROTOCOL); + let (proof, prover_challenges, final_claim) = + prove_sumcheck::(&mut prover, &mut pt, |tr| { + E2::lift_base(tr.challenge_scalar(labels::CHALLENGE_SUMCHECK_ROUND)) + }) + .unwrap(); + + let oracle = EqPolynomial::mle(&tau_e, &prover_challenges) + * range_check_eval(multilinear_eval(&w_evals_e, &prover_challenges).unwrap(), b); + assert_eq!(final_claim, oracle, "E2 prover final claim != oracle eval"); + 
+ let verifier = NormSumcheckVerifier::new(tau_e, w_evals_e, b); + let mut vt = Blake2bTranscript::::new(labels::DOMAIN_HACHI_PROTOCOL); + let verifier_challenges = + verify_sumcheck::(&proof, &verifier, &mut vt, |tr| { + E2::lift_base(tr.challenge_scalar(labels::CHALLENGE_SUMCHECK_ROUND)) + }) + .unwrap(); + + assert_eq!(prover_challenges, verifier_challenges); + } + + #[test] + fn range_check_eval_i128_matches_field() { + for b in [2, 4, 8, 10] { + for w in -(b as i32 - 1)..=(b as i32 - 1) { + let i128_val = range_check_eval_i128(w, b); + let field_val: F = range_check_eval(F::from_i64(w as i64), b); + let field_from_i128_val: F = field_from_i128(i128_val); + assert_eq!( + field_from_i128_val, field_val, + "i128 range-check mismatch for b={b}, w={w}: \ + i128={i128_val}, field_from_i128={field_from_i128_val:?}, field={field_val:?}" + ); + } + } + } +} diff --git a/src/protocol/sumcheck/relation_sumcheck.rs b/src/protocol/sumcheck/relation_sumcheck.rs new file mode 100644 index 00000000..09c38f31 --- /dev/null +++ b/src/protocol/sumcheck/relation_sumcheck.rs @@ -0,0 +1,440 @@ +//! Evaluation-relation sumcheck instance (F_α). +//! +//! **F_{α,τ₁}(x, y)** = w̃(x,y) · α̃(y) · m(x) +//! where m(x) = Σ_i ẽq(τ₁,i) · M̃_α(i,x). +//! +//! Proves the evaluation relation; sum equals `a = Σ_i ẽq(τ₁,i) · y_i(α)`. + +use super::eq_poly::EqPolynomial; +use super::{fold_evals_in_place, multilinear_eval}; +use super::{SumcheckInstanceProver, SumcheckInstanceVerifier, UniPoly}; +use crate::algebra::ring::CyclotomicRing; +use crate::error::HachiError; +#[cfg(feature = "parallel")] +use crate::parallel::*; +use crate::protocol::ring_switch::eval_ring_at; +use crate::{FieldCore, FromSmallInt}; +use std::iter; + +/// Prover for `F_{α,τ₁}(x,y) = w̃(x,y) · α̃(y) · m(x)`. +/// +/// Alpha and m are stored in compact form (sizes `2^num_l` and `2^num_u`) +/// and folded only during rounds where their variables are active. 
+/// +/// Round polynomial degree is 2 (product of at most two multilinear factors +/// depending on any single variable). +pub struct RelationSumcheckProver { + w_table: Vec, + alpha_compact: Vec, + m_compact: Vec, + num_u: usize, + num_vars: usize, + rounds_completed: usize, +} + +impl RelationSumcheckProver { + /// Construct from the three constituent evaluation tables. + /// + /// - `w_evals`: evaluations of `w̃` over `{0,1}^{num_u + num_l}` (full domain). + /// - `alpha_evals_y`: evaluations of `α̃` over `{0,1}^{num_l}` (compact). + /// - `m_evals_x`: evaluations of `m` over `{0,1}^{num_u}` (compact). + /// + /// # Panics + /// + /// Panics if table sizes don't match `2^num_u`, `2^num_l`, or `2^(num_u+num_l)`. + pub fn new( + w_evals: Vec, + alpha_evals_y: &[E], + m_evals_x: &[E], + num_u: usize, + num_l: usize, + ) -> Self { + let num_vars = num_u + num_l; + let n = 1usize << num_vars; + assert_eq!(w_evals.len(), n); + assert_eq!(alpha_evals_y.len(), 1 << num_l); + assert_eq!(m_evals_x.len(), 1 << num_u); + + Self { + w_table: w_evals, + alpha_compact: alpha_evals_y.to_vec(), + m_compact: m_evals_x.to_vec(), + num_u, + num_vars, + rounds_completed: 0, + } + } +} + +impl SumcheckInstanceProver for RelationSumcheckProver { + fn num_rounds(&self) -> usize { + self.num_vars + } + + fn degree_bound(&self) -> usize { + 2 + } + + fn input_claim(&self) -> E { + let x_mask = (1usize << self.num_u) - 1; + let alpha_compact = &self.alpha_compact; + let m_compact = &self.m_compact; + let num_u = self.num_u; + + #[cfg(feature = "parallel")] + { + self.w_table + .par_iter() + .enumerate() + .fold( + || E::zero(), + |acc, (idx, &w)| { + acc + w * alpha_compact[idx >> num_u] * m_compact[idx & x_mask] + }, + ) + .reduce(|| E::zero(), |a, b| a + b) + } + #[cfg(not(feature = "parallel"))] + { + self.w_table + .iter() + .enumerate() + .fold(E::zero(), |acc, (idx, &w)| { + acc + w * alpha_compact[idx >> num_u] * m_compact[idx & x_mask] + }) + } + } + + fn 
compute_round_univariate(&mut self, _round: usize, _previous_claim: E) -> UniPoly { + let half = self.w_table.len() / 2; + let num_points = 3; + let current_x_width = self.num_u.saturating_sub(self.rounds_completed); + let current_x_mask = (1usize << current_x_width).wrapping_sub(1); + let alpha_compact = &self.alpha_compact; + let m_compact = &self.m_compact; + + #[cfg(feature = "parallel")] + let round_evals = { + (0..half) + .into_par_iter() + .fold( + || vec![E::zero(); num_points], + |mut evals, j| { + let w_0 = self.w_table[2 * j]; + let w_1 = self.w_table[2 * j + 1]; + let a_0 = alpha_compact[(2 * j) >> current_x_width]; + let a_1 = alpha_compact[(2 * j + 1) >> current_x_width]; + let m_0 = m_compact[(2 * j) & current_x_mask]; + let m_1 = m_compact[(2 * j + 1) & current_x_mask]; + for (t, eval) in evals.iter_mut().enumerate() { + let t_e = E::from_u64(t as u64); + let w_t = w_0 + t_e * (w_1 - w_0); + let a_t = a_0 + t_e * (a_1 - a_0); + let m_t = m_0 + t_e * (m_1 - m_0); + *eval += w_t * a_t * m_t; + } + evals + }, + ) + .reduce( + || vec![E::zero(); num_points], + |mut a, b| { + for (ai, bi) in a.iter_mut().zip(b.iter()) { + *ai += *bi; + } + a + }, + ) + }; + #[cfg(not(feature = "parallel"))] + let round_evals = { + let mut evals = vec![E::zero(); num_points]; + for j in 0..half { + let w_0 = self.w_table[2 * j]; + let w_1 = self.w_table[2 * j + 1]; + let a_0 = alpha_compact[(2 * j) >> current_x_width]; + let a_1 = alpha_compact[(2 * j + 1) >> current_x_width]; + let m_0 = m_compact[(2 * j) & current_x_mask]; + let m_1 = m_compact[(2 * j + 1) & current_x_mask]; + for (t, eval) in evals.iter_mut().enumerate() { + let t_e = E::from_u64(t as u64); + let w_t = w_0 + t_e * (w_1 - w_0); + let a_t = a_0 + t_e * (a_1 - a_0); + let m_t = m_0 + t_e * (m_1 - m_0); + *eval += w_t * a_t * m_t; + } + } + evals + }; + + UniPoly::from_evals(&round_evals) + } + + fn ingest_challenge(&mut self, _round: usize, r: E) { + fold_evals_in_place(&mut self.w_table, r); + if 
self.rounds_completed < self.num_u { + fold_evals_in_place(&mut self.m_compact, r); + } else { + fold_evals_in_place(&mut self.alpha_compact, r); + } + self.rounds_completed += 1; + } +} + +/// Verifier for the evaluation-relation sumcheck `F_{α,τ₁}`. +pub struct RelationSumcheckVerifier { + w_evals: Vec, + alpha_evals_y: Vec, + m_evals_x: Vec, + tau: Vec, + v: Vec>, + u: Vec>, + y_ring: CyclotomicRing, + alpha: F, + num_u: usize, + num_l: usize, +} + +impl RelationSumcheckVerifier { + /// Create a new evaluation-relation sumcheck verifier. + /// + /// # Panics + /// + /// Panics if table sizes don't match `2^num_u`, `2^num_l`, or `2^(num_u+num_l)`. + #[allow(clippy::too_many_arguments)] + pub fn new( + w_evals: Vec, + alpha_evals_y: Vec, + m_evals_x: Vec, + tau: Vec, + v: Vec>, + u: Vec>, + y_ring: CyclotomicRing, + alpha: F, + num_u: usize, + num_l: usize, + ) -> Self { + assert_eq!(w_evals.len(), 1 << (num_u + num_l)); + assert_eq!(alpha_evals_y.len(), 1 << num_l); + assert_eq!(m_evals_x.len(), 1 << num_u); + Self { + w_evals, + alpha_evals_y, + m_evals_x, + tau, + v, + u, + y_ring, + alpha, + num_u, + num_l, + } + } +} + +impl SumcheckInstanceVerifier for RelationSumcheckVerifier { + fn num_rounds(&self) -> usize { + self.num_u + self.num_l + } + + fn degree_bound(&self) -> usize { + 2 + } + + fn input_claim(&self) -> F { + let y_a: Vec = self + .v + .iter() + .chain(self.u.iter()) + .chain(iter::once(&self.y_ring)) + .map(|r| eval_ring_at(r, &self.alpha)) + .collect(); + + let eq_tau = EqPolynomial::evals(&self.tau); + let mut acc = F::zero(); + for (i, eq_i) in eq_tau.iter().enumerate() { + let y_i = if i < y_a.len() { y_a[i] } else { F::zero() }; + acc += *eq_i * y_i; + } + acc + } + + fn expected_output_claim(&self, challenges: &[F]) -> Result { + let (x_challenges, y_challenges) = challenges.split_at(self.num_u); + let w_val = multilinear_eval(&self.w_evals, challenges)?; + let alpha_val = multilinear_eval(&self.alpha_evals_y, y_challenges)?; + let m_val = 
multilinear_eval(&self.m_evals_x, x_challenges)?; + Ok(w_val * alpha_val * m_val) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::algebra::Fp64; + use crate::protocol::commitment_scheme::{ + rederive_alpha_and_m_a, rederive_alpha_and_m_evals_x, + }; + use crate::protocol::hachi_poly_ops::DensePoly; + use crate::protocol::opening_point::BasisMode; + use crate::protocol::sumcheck::eq_poly::EqPolynomial; + use crate::protocol::transcript::labels; + use crate::protocol::{ + prove_sumcheck, verify_sumcheck, Blake2bTranscript, CommitmentConfig, CommitmentScheme, + HachiCommitmentScheme, SmallTestCommitmentConfig, Transcript, + }; + use crate::{FieldCore, FromSmallInt}; + + type F = Fp64<4294967197>; + type Cfg = SmallTestCommitmentConfig; + const D: usize = Cfg::D; + type Scheme = HachiCommitmentScheme; + + #[test] + fn relation_sumcheck_uses_prove_w_evals() { + let alpha_bits = D.trailing_zeros() as usize; + let layout = SmallTestCommitmentConfig::commitment_layout(8).unwrap(); + let num_vars = layout.m_vars + layout.r_vars + alpha_bits; + let len = 1usize << num_vars; + let evals: Vec = (0..len).map(|i| F::from_u64(i as u64)).collect(); + let poly = DensePoly::::from_field_evals(num_vars, &evals).unwrap(); + + let setup = Scheme::setup_prover(num_vars); + let (commitment, hint) = Scheme::commit(&poly, &setup, &layout).unwrap(); + + let opening_point: Vec = (0..num_vars).map(|i| F::from_u64((i + 2) as u64)).collect(); + let mut prover_transcript = Blake2bTranscript::::new(labels::DOMAIN_HACHI_PROTOCOL); + let proof = Scheme::prove( + &setup, + &poly, + &opening_point, + hint, + &mut prover_transcript, + &commitment, + BasisMode::Lagrange, + &layout, + ) + .unwrap(); + + let (alpha, m_a_vec) = rederive_alpha_and_m_a::( + &proof, + &Scheme::setup_verifier(&setup), + &opening_point, + &commitment, + ) + .unwrap(); + + let final_w: Vec = proof.final_w.to_field_elems(); + let d = SmallTestCommitmentConfig::D; + assert_eq!(final_w.len() % d, 0); + let w_u = 
final_w.len() / d; + let rows = SmallTestCommitmentConfig::N_D + + SmallTestCommitmentConfig::N_B + + 1 + + 1 + + SmallTestCommitmentConfig::N_A; + assert!(rows > 0); + assert_eq!(m_a_vec.len() % rows, 0); + let cols = m_a_vec.len() / rows; + assert_eq!(w_u, cols); + + let num_u = cols.next_power_of_two().trailing_zeros() as usize; + let num_l = alpha_bits; + let n = 1usize << (num_u + num_l); + + let mut w_evals = vec![F::zero(); n]; + let y_len = 1usize << num_l; + let x_len = 1usize << num_u; + for x in 0..x_len { + for y in 0..y_len { + let src = y + (x << num_l); + if src < final_w.len() { + let dst = x + (y << num_u); + w_evals[dst] = final_w[src]; + } + } + } + + let num_i = rows.next_power_of_two().trailing_zeros() as usize; + let tau1: Vec = (0..num_i).map(|i| F::from_u64((i + 5) as u64)).collect(); + let eq_tau1 = EqPolynomial::evals(&tau1); + + let mut m_evals_x_reference = vec![F::zero(); x_len]; + for x in 0..x_len { + let mut acc = F::zero(); + for i in 0..(1usize << num_i) { + let row_val = if i < rows && x < cols { + m_a_vec[i * cols + x] + } else { + F::zero() + }; + acc += eq_tau1[i] * row_val; + } + m_evals_x_reference[x] = acc; + } + + let (alpha_check, m_evals_x) = rederive_alpha_and_m_evals_x::( + &proof, + &Scheme::setup_verifier(&setup), + &opening_point, + &commitment, + &tau1, + ) + .unwrap(); + assert_eq!(alpha_check, alpha); + assert_eq!( + m_evals_x, m_evals_x_reference, + "fused m_evals_x should match expanded-matrix contraction", + ); + + let mut alpha_evals_y = vec![F::zero(); y_len]; + let mut power = F::one(); + for val in alpha_evals_y.iter_mut() { + *val = power; + power *= alpha; + } + + let x_mask = x_len - 1; + let alpha_full: Vec = (0..n).map(|idx| alpha_evals_y[idx >> num_u]).collect(); + let m_full: Vec = (0..n).map(|idx| m_evals_x[idx & x_mask]).collect(); + let _claim: F = (0..n) + .map(|i| w_evals[i] * alpha_full[i] * m_full[i]) + .fold(F::zero(), |a, v| a + v); + + let mut prover = + 
RelationSumcheckProver::new(w_evals.clone(), &alpha_evals_y, &m_evals_x, num_u, num_l); + let mut pt = Blake2bTranscript::::new(labels::DOMAIN_HACHI_PROTOCOL); + let (proof_sc, prover_challenges, final_claim) = + prove_sumcheck::(&mut prover, &mut pt, |tr| { + tr.challenge_scalar(labels::CHALLENGE_SUMCHECK_ROUND) + }) + .unwrap(); + + let (x_ch, y_ch) = prover_challenges.split_at(num_u); + let oracle = multilinear_eval(&w_evals, &prover_challenges).unwrap() + * multilinear_eval(&alpha_evals_y, y_ch).unwrap() + * multilinear_eval(&m_evals_x, x_ch).unwrap(); + assert_eq!(final_claim, oracle, "prover final claim != oracle eval"); + + let verifier = RelationSumcheckVerifier::new( + w_evals, + alpha_evals_y, + m_evals_x, + tau1, + proof.levels[0].v_typed::(), + commitment.u.clone(), + proof.levels[0].y_ring_typed::(), + alpha, + num_u, + num_l, + ); + let mut vt = Blake2bTranscript::::new(labels::DOMAIN_HACHI_PROTOCOL); + let verifier_challenges = + verify_sumcheck::(&proof_sc, &verifier, &mut vt, |tr| { + tr.challenge_scalar(labels::CHALLENGE_SUMCHECK_ROUND) + }) + .unwrap(); + + assert_eq!(prover_challenges, verifier_challenges); + } +} diff --git a/src/protocol/sumcheck/split_eq.rs b/src/protocol/sumcheck/split_eq.rs new file mode 100644 index 00000000..254cc67c --- /dev/null +++ b/src/protocol/sumcheck/split_eq.rs @@ -0,0 +1,213 @@ +//! Gruen/Dao-Thaler split equality polynomial for efficient sumcheck. +//! +//! Factors `eq(τ, x)` into a running scalar, a linear factor for the current +//! variable, and precomputed split tables for the remaining variables. This +//! avoids maintaining and folding a full-size eq table during sumcheck. +//! +//! For details, see . +//! +//! Adapted from Jolt's `GruenSplitEqPolynomial`. +//! +//! ## Variable Layout (forward binding, little-endian) +//! +//! ```text +//! τ = [τ_current, τ_first_half, τ_second_half] +//! 1 var m vars (n-1-m) vars +//! ``` +//! +//! where `m = (n-1) / 2` and `n = τ.len()`. +//! +//! 
After binding `τ_current`, the next variable comes from `τ_first_half`, +//! then from `τ_second_half`. Suffix-cached eq tables for each half enable +//! O(1) pops per round instead of an O(2^n) fold. + +use super::eq_poly::EqPolynomial; +use super::UniPoly; +use crate::FieldCore; + +/// Split equality polynomial with Gruen scalar accumulation. +/// +/// Instead of storing and folding a full eq table each round, this struct +/// maintains: +/// - `current_scalar`: accumulated `eq(τ_bound, r_bound)` from already-bound +/// variables +/// - `E_first` / `E_second`: suffix-cached eq tables for two halves of the +/// remaining (unbound, non-current) variables +/// +/// The eq contribution for a pair index `j` in the inner sum is: +/// ```text +/// eq_remaining(j) = E_first[j_low] · E_second[j_high] +/// ``` +/// and the full round polynomial is `l(X) · q(X)` where `l(X)` is the linear +/// eq factor for the current variable. +#[allow(non_snake_case)] +pub struct GruenSplitEq { + tau: Vec, + current_round: usize, + current_scalar: E, + /// Suffix-cached eq tables for the first half of remaining variables. + /// `E_first[k]` = `eq(τ[split-k..split], ·)` with `2^k` entries. + /// Invariant: never empty; `E_first[0] = [1]`. + E_first: Vec>, + /// Suffix-cached eq tables for the second half of remaining variables. + /// `E_second[k]` = `eq(τ[n-k..n], ·)` with `2^k` entries. + /// Invariant: never empty; `E_second[0] = [1]`. + E_second: Vec>, +} + +#[allow(non_snake_case)] +impl GruenSplitEq { + /// Create a new split-eq from the full challenge vector `τ`. + /// + /// Precomputes suffix-cached eq tables for two halves of `τ[1..n]`. + /// + /// # Panics + /// + /// Panics if `tau` is empty. 
+ pub fn new(tau: &[E]) -> Self { + let n = tau.len(); + assert!(n >= 1); + let m = (n - 1) / 2; + let split = 1 + m; + let first_half = &tau[1..split]; + let second_half = &tau[split..n]; + let E_first = EqPolynomial::evals_cached(first_half); + let E_second = EqPolynomial::evals_cached(second_half); + Self { + tau: tau.to_vec(), + current_round: 0, + current_scalar: E::one(), + E_first, + E_second, + } + } + + /// The accumulated scalar `Π_{k < current_round} eq(τ[k], r[k])`. + pub fn current_scalar(&self) -> E { + self.current_scalar + } + + /// The τ value for the variable about to be bound. + pub fn current_tau(&self) -> E { + self.tau[self.current_round] + } + + /// Return the current top-level split-eq tables `(E_first, E_second)`. + /// + /// For a pair index `j` in the inner sum, the eq factor for the + /// remaining (non-current) variables is: + /// ```text + /// eq_remaining(j) = E_first[j & (E_first.len()-1)] + /// · E_second[j >> E_first.len().trailing_zeros()] + /// ``` + /// + /// # Panics + /// + /// Panics if either `E_first` or `E_second` is empty (invariant violation). + pub fn remaining_eq_tables(&self) -> (&[E], &[E]) { + ( + self.E_first.last().expect("E_first is never empty"), + self.E_second.last().expect("E_second is never empty"), + ) + } + + /// Bind the current variable to challenge `r`, advancing to the next round. + /// + /// Updates `current_scalar` with `eq(τ[current_round], r)` and pops the + /// appropriate split table level. + pub fn bind(&mut self, r: E) { + let tau_k = self.tau[self.current_round]; + self.current_scalar = + self.current_scalar * (tau_k * r + (E::one() - tau_k) * (E::one() - r)); + self.current_round += 1; + if self.E_first.len() > 1 { + self.E_first.pop(); + } else if self.E_second.len() > 1 { + self.E_second.pop(); + } + } + + /// Compute the round polynomial `s(X) = l(X) · q(X)` from the inner + /// polynomial `q` (given as evaluations at integer points `0, 1, ..., d`). 
+ /// + /// `l(X) = current_scalar · eq(τ_current, X)` is the linear eq factor + /// for the current variable. The result has degree `d + 1`. + pub fn gruen_mul(&self, q_poly: &UniPoly) -> UniPoly { + let tau_k = self.current_tau(); + let scalar = self.current_scalar(); + let l_0 = scalar * (E::one() - tau_k); + let l_1 = scalar * tau_k; + let slope = l_1 - l_0; + let mut coeffs = vec![E::zero(); q_poly.coeffs.len() + 1]; + for (i, &c) in q_poly.coeffs.iter().enumerate() { + coeffs[i] += c * l_0; + coeffs[i + 1] += c * slope; + } + UniPoly::from_coeffs(coeffs) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::algebra::Fp64; + use crate::protocol::sumcheck::fold_evals_in_place; + use crate::{FieldSampling, FromSmallInt}; + use rand::rngs::StdRng; + use rand::SeedableRng; + + type F = Fp64<4294967197>; + + #[test] + fn gruen_eq_matches_full_eq_table() { + let mut rng = StdRng::seed_from_u64(0xBB); + for n in 1..10 { + let tau: Vec = (0..n).map(|_| F::sample(&mut rng)).collect(); + let mut full_eq = EqPolynomial::evals(&tau); + let mut split_eq = GruenSplitEq::new(&tau); + + for _round in 0..n { + let half = full_eq.len() / 2; + let (e_first, e_second) = split_eq.remaining_eq_tables(); + let num_first = e_first.len(); + + for j in 0..half { + let j_low = j & (num_first - 1); + let j_high = j >> num_first.trailing_zeros(); + let eq_rem = e_first[j_low] * e_second[j_high]; + + let tau_k = split_eq.current_tau(); + let scalar = split_eq.current_scalar(); + let eq_0 = scalar * (F::one() - tau_k) * eq_rem; + let eq_1 = scalar * tau_k * eq_rem; + + assert_eq!(eq_0, full_eq[2 * j], "n={n} round={_round} j={j} eq_0"); + assert_eq!(eq_1, full_eq[2 * j + 1], "n={n} round={_round} j={j} eq_1"); + } + + let r = F::sample(&mut rng); + fold_evals_in_place(&mut full_eq, r); + split_eq.bind(r); + } + } + } + + #[test] + fn gruen_mul_matches_direct_product() { + let mut rng = StdRng::seed_from_u64(0xCC); + let tau: Vec = (0..5).map(|_| F::sample(&mut rng)).collect(); + 
let split_eq = GruenSplitEq::new(&tau); + + let q = UniPoly::from_coeffs(vec![F::from_u64(3), F::from_u64(7), F::from_u64(2)]); + let s = split_eq.gruen_mul(&q); + + let tau_k = split_eq.current_tau(); + let scalar = split_eq.current_scalar(); + for t in 0..10u64 { + let x = F::from_u64(t); + let l_x = scalar * (tau_k * x + (F::one() - tau_k) * (F::one() - x)); + let q_x = q.evaluate(&x); + assert_eq!(s.evaluate(&x), l_x * q_x, "t={t}"); + } + } +} diff --git a/src/protocol/sumcheck/types.rs b/src/protocol/sumcheck/types.rs new file mode 100644 index 00000000..7b4955a2 --- /dev/null +++ b/src/protocol/sumcheck/types.rs @@ -0,0 +1,378 @@ +//! Sumcheck data types: univariate polynomials, compressed representation, and proof container. + +use crate::error::HachiError; +use crate::primitives::serialization::{ + Compress, HachiDeserialize, HachiSerialize, SerializationError, Valid, Validate, +}; +use crate::protocol::transcript::labels; +use crate::protocol::transcript::Transcript; +use crate::FieldCore; +use crate::FromSmallInt; +use std::io::{Read, Write}; + +/// Univariate polynomial in coefficient form: `p(X) = Σ_{i=0}^d coeffs[i] * X^i`. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct UniPoly { + /// Coefficients from low degree to high degree. + pub coeffs: Vec, +} + +impl UniPoly { + /// Construct from coefficients in increasing-degree order. + pub fn from_coeffs(coeffs: Vec) -> Self { + Self { coeffs } + } + + /// Degree of the polynomial (0 for empty or constant). + pub fn degree(&self) -> usize { + self.coeffs.len().saturating_sub(1) + } + + /// Evaluate at `x` via Horner's method. + pub fn evaluate(&self, x: &E) -> E { + let mut acc = E::zero(); + for c in self.coeffs.iter().rev() { + acc = acc * *x + *c; + } + acc + } + + /// Compress this polynomial by omitting the linear coefficient. + /// + /// The verifier can reconstruct/evaluate the missing linear coefficient using + /// the per-round hint `g(0)+g(1)` from the sumcheck protocol. 
+ /// + /// This matches the technique used by Jolt's sumcheck (`CompressedUniPoly`). + pub fn compress(&self) -> CompressedUniPoly { + let coeffs = &self.coeffs; + if coeffs.is_empty() { + return CompressedUniPoly { + coeffs_except_linear_term: Vec::new(), + }; + } + if coeffs.len() == 1 { + return CompressedUniPoly { + coeffs_except_linear_term: vec![coeffs[0]], + }; + } + let mut out = Vec::with_capacity(coeffs.len().saturating_sub(1)); + out.push(coeffs[0]); + out.extend_from_slice(&coeffs[2..]); + CompressedUniPoly { + coeffs_except_linear_term: out, + } + } +} + +impl UniPoly { + /// Interpolate from evaluations at equispaced integer points `x = 0, 1, ..., d`. + /// + /// Uses Newton forward-difference interpolation: compute divided differences, + /// then expand via Horner on the nested Newton form. + /// + /// # Panics + /// + /// Panics if any required factorial inverse does not exist (field characteristic + /// must exceed the number of evaluation points). This is a prover-only + /// function and the condition always holds for Hachi's fields. 
+ pub fn from_evals(evals: &[E]) -> Self { + let n = evals.len(); + if n == 0 { + return Self::from_coeffs(vec![]); + } + if n == 1 { + return Self::from_coeffs(vec![evals[0]]); + } + + let mut table = evals.to_vec(); + let mut deltas = vec![table[0]]; + for _ in 1..n { + for j in 0..table.len() - 1 { + table[j] = table[j + 1] - table[j]; + } + table.pop(); + deltas.push(table[0]); + } + + let mut factorial = E::one(); + let mut divided_diffs = vec![deltas[0]]; + for (k, delta_k) in deltas.iter().enumerate().skip(1) { + factorial = factorial * E::from_u64(k as u64); + divided_diffs.push( + *delta_k + * factorial + .inv() + .expect("field characteristic too small for interpolation"), + ); + } + + let mut coeffs = vec![divided_diffs[n - 1]]; + + for k in (0..n - 1).rev() { + let shift = E::from_u64(k as u64); + let old_len = coeffs.len(); + let mut new_coeffs = vec![E::zero(); old_len + 1]; + + new_coeffs[0] = divided_diffs[k]; + for i in 0..old_len { + new_coeffs[i + 1] += coeffs[i]; + new_coeffs[i] -= shift * coeffs[i]; + } + + coeffs = new_coeffs; + } + + while coeffs.len() > 1 && coeffs.last().is_some_and(|c| c.is_zero()) { + coeffs.pop(); + } + + Self::from_coeffs(coeffs) + } +} + +impl Valid for UniPoly { + fn check(&self) -> Result<(), SerializationError> { + self.coeffs.check() + } +} + +impl HachiSerialize for UniPoly { + fn serialize_with_mode( + &self, + mut writer: W, + compress: Compress, + ) -> Result<(), SerializationError> { + self.coeffs.serialize_with_mode(&mut writer, compress) + } + + fn serialized_size(&self, compress: Compress) -> usize { + self.coeffs.serialized_size(compress) + } +} + +impl HachiDeserialize for UniPoly { + fn deserialize_with_mode( + mut reader: R, + compress: Compress, + validate: Validate, + ) -> Result { + let coeffs = Vec::::deserialize_with_mode(&mut reader, compress, validate)?; + let out = Self { coeffs }; + if matches!(validate, Validate::Yes) { + out.check()?; + } + Ok(out) + } +} + +/// Compressed univariate 
polynomial representation omitting the linear term. +/// +/// We store `[c0, c2, c3, ..., cd]`. Given the sumcheck hint `hint = g(0)+g(1)`, +/// the missing linear coefficient is: +/// +/// `c1 = hint - 2*c0 - Σ_{i=2..d} ci`. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct CompressedUniPoly { + /// Coefficients excluding the linear term: `[c0, c2, c3, ..., cd]`. + pub coeffs_except_linear_term: Vec, +} + +impl CompressedUniPoly { + /// Degree of the underlying uncompressed polynomial. + /// + /// `compress()` stores `[c0, c2, ..., cd]` — exactly `d` entries for + /// degree `d >= 2`. For `len <= 1` (degree 0 or 1, which are ambiguous + /// in compressed form) we report 0; this is conservative for the + /// verifier's degree-bound check since `degree_bound >= 2` in practice. + pub fn degree(&self) -> usize { + let len = self.coeffs_except_linear_term.len(); + if len <= 1 { + 0 + } else { + len + } + } + + fn recover_linear_term(&self, hint: &E) -> E { + if self.coeffs_except_linear_term.is_empty() { + return E::zero(); + } + + let c0 = self.coeffs_except_linear_term[0]; + let mut linear = *hint - c0 - c0; + for c in self.coeffs_except_linear_term.iter().skip(1) { + linear -= *c; + } + linear + } + + /// Decompress using `hint = g(0)+g(1)`. + pub fn decompress(&self, hint: &E) -> UniPoly { + if self.coeffs_except_linear_term.is_empty() { + return UniPoly::from_coeffs(Vec::new()); + } + let linear = self.recover_linear_term(hint); + let mut coeffs = Vec::with_capacity(self.coeffs_except_linear_term.len() + 1); + coeffs.push(self.coeffs_except_linear_term[0]); + coeffs.push(linear); + coeffs.extend_from_slice(&self.coeffs_except_linear_term[1..]); + UniPoly::from_coeffs(coeffs) + } + + /// Evaluate the uncompressed polynomial at `x`, using `hint = g(0)+g(1)`. + /// + /// This avoids materializing the full coefficient list. 
+ pub fn eval_from_hint(&self, hint: &E, x: &E) -> E { + if self.coeffs_except_linear_term.is_empty() { + return E::zero(); + } + + let linear = self.recover_linear_term(hint); + let mut acc = self.coeffs_except_linear_term[0] + (*x * linear); + + let mut pow = *x * *x; + for c in self.coeffs_except_linear_term.iter().skip(1) { + acc += *c * pow; + pow = pow * *x; + } + acc + } +} + +impl Valid for CompressedUniPoly { + fn check(&self) -> Result<(), SerializationError> { + self.coeffs_except_linear_term.check() + } +} + +impl HachiSerialize for CompressedUniPoly { + fn serialize_with_mode( + &self, + mut writer: W, + compress: Compress, + ) -> Result<(), SerializationError> { + self.coeffs_except_linear_term + .serialize_with_mode(&mut writer, compress) + } + + fn serialized_size(&self, compress: Compress) -> usize { + self.coeffs_except_linear_term.serialized_size(compress) + } +} + +impl HachiDeserialize for CompressedUniPoly { + fn deserialize_with_mode( + mut reader: R, + compress: Compress, + validate: Validate, + ) -> Result { + let coeffs_except_linear_term = + Vec::::deserialize_with_mode(&mut reader, compress, validate)?; + let out = Self { + coeffs_except_linear_term, + }; + if matches!(validate, Validate::Yes) { + out.check()?; + } + Ok(out) + } +} + +/// Sumcheck proof containing one compressed univariate polynomial per round. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct SumcheckProof { + /// One compressed univariate polynomial per sumcheck round. 
+ pub round_polys: Vec>, +} + +impl Valid for SumcheckProof { + fn check(&self) -> Result<(), SerializationError> { + self.round_polys.check() + } +} + +impl HachiSerialize for SumcheckProof { + fn serialize_with_mode( + &self, + mut writer: W, + compress: Compress, + ) -> Result<(), SerializationError> { + self.round_polys.serialize_with_mode(&mut writer, compress) + } + + fn serialized_size(&self, compress: Compress) -> usize { + self.round_polys.serialized_size(compress) + } +} + +impl HachiDeserialize for SumcheckProof { + fn deserialize_with_mode( + mut reader: R, + compress: Compress, + validate: Validate, + ) -> Result { + let round_polys = + Vec::>::deserialize_with_mode(&mut reader, compress, validate)?; + let out = Self { round_polys }; + if matches!(validate, Validate::Yes) { + out.check()?; + } + Ok(out) + } +} + +impl SumcheckProof { + /// Verifier-side sumcheck transcript driver. + /// + /// This method: + /// - absorbs the per-round prover message (compressed univariate), + /// - samples one challenge per round via `sample_challenge`, + /// - updates the running claim using `eval_from_hint`. + /// + /// It does **not** perform the final oracle check `final_claim == f(r*)`. + /// Callers (e.g. ring-switching) must compute `f(r*)` themselves and compare. + /// + /// # Errors + /// + /// Returns an error if the proof length does not match `num_rounds` or if any + /// per-round polynomial exceeds `degree_bound`. 
+ pub fn verify( + &self, + mut claim: E, + num_rounds: usize, + degree_bound: usize, + transcript: &mut T, + mut sample_challenge: S, + ) -> Result<(E, Vec), HachiError> + where + F: crate::FieldCore + crate::CanonicalField, + T: Transcript, + S: FnMut(&mut T) -> E, + { + if self.round_polys.len() != num_rounds { + return Err(HachiError::InvalidSize { + expected: num_rounds, + actual: self.round_polys.len(), + }); + } + + let mut r = Vec::with_capacity(num_rounds); + for poly in &self.round_polys { + if poly.degree() > degree_bound { + return Err(HachiError::InvalidInput(format!( + "sumcheck round poly degree {} exceeds bound {}", + poly.degree(), + degree_bound + ))); + } + + transcript.append_serde(labels::ABSORB_SUMCHECK_ROUND, poly); + let r_i = sample_challenge(transcript); + r.push(r_i); + + claim = poly.eval_from_hint(&claim, &r_i); + } + + Ok((claim, r)) + } +} diff --git a/src/protocol/transcript/hash.rs b/src/protocol/transcript/hash.rs new file mode 100644 index 00000000..cfae7b30 --- /dev/null +++ b/src/protocol/transcript/hash.rs @@ -0,0 +1,103 @@ +//! Generic hash-based transcript for protocol-layer Fiat-Shamir. +//! +//! Parameterised over any `Digest + Clone` hasher, eliminating the +//! near-identical Blake2b and Keccak implementations. + +use super::Transcript; +use crate::primitives::serialization::HachiSerialize; +use crate::{CanonicalField, FieldCore}; +use blake2::{Blake2b512, Digest}; +use sha3::Keccak256; +use std::marker::PhantomData; + +/// Hash-based transcript with labeled framing. +/// +/// Works with any cryptographic hash that implements `Digest + Clone`. 
+#[derive(Clone)] +pub struct HashTranscript +where + F: FieldCore + CanonicalField + 'static, +{ + hasher: D, + _field: PhantomData, +} + +impl HashTranscript +where + F: FieldCore + CanonicalField + 'static, +{ + #[inline] + fn append_bytes_impl(&mut self, label: &[u8], bytes: &[u8]) { + self.hasher.update(label); + self.hasher.update((bytes.len() as u64).to_le_bytes()); + self.hasher.update(bytes); + } + + #[inline] + fn challenge_and_chain(&mut self, label: &[u8]) -> Vec { + self.hasher.update(label); + let digest = self.hasher.clone().finalize(); + let out = digest.to_vec(); + self.hasher.update(&out); + out + } +} + +impl Transcript for HashTranscript +where + F: FieldCore + CanonicalField + 'static, +{ + fn new(domain_label: &[u8]) -> Self { + let mut hasher = D::new(); + hasher.update(domain_label); + Self { + hasher, + _field: PhantomData, + } + } + + fn append_bytes(&mut self, label: &[u8], bytes: &[u8]) { + self.append_bytes_impl(label, bytes); + } + + fn append_field(&mut self, label: &[u8], x: &F) { + self.append_bytes_impl(label, &x.to_canonical_u128().to_le_bytes()); + } + + fn append_serde(&mut self, label: &[u8], s: &S) { + let mut bytes = Vec::new(); + s.serialize_compressed(&mut bytes) + .expect("HachiSerialize should not fail"); + self.append_bytes_impl(label, &bytes); + } + + fn challenge_scalar(&mut self, label: &[u8]) -> F { + let bytes = self.challenge_and_chain(label); + let mut lo = [0u8; 16]; + lo.copy_from_slice(&bytes[..16]); + let sampled = u128::from_le_bytes(lo); + F::from_canonical_u128_reduced(sampled) + } +} + +impl HashTranscript +where + F: FieldCore + CanonicalField + 'static, +{ + /// Reset transcript state under a new domain label. + /// + /// This is an inherent method (not part of the `Transcript` trait) to + /// discourage use in production protocol code where resetting the + /// Fiat-Shamir chain would be unsound. 
+ pub fn reset(&mut self, domain_label: &[u8]) { + let mut hasher = D::new(); + hasher.update(domain_label); + self.hasher = hasher; + } +} + +/// Blake2b512 transcript with labeled framing. +pub type Blake2bTranscript = HashTranscript; + +/// Keccak256 transcript with labeled framing. +pub type KeccakTranscript = HashTranscript; diff --git a/src/protocol/transcript/labels.rs b/src/protocol/transcript/labels.rs new file mode 100644 index 00000000..4cc059be --- /dev/null +++ b/src/protocol/transcript/labels.rs @@ -0,0 +1,124 @@ +//! Hachi-native transcript labels. +//! +//! These constants are the single source of truth for protocol transcript +//! labels in Hachi core. External integrations should translate at adapter +//! boundaries instead of introducing foreign label names here. + +/// Top-level protocol domain label. +pub const DOMAIN_HACHI_PROTOCOL: &[u8] = b"hachi/protocol"; + +/// Absorb commitment object(s) (paper §4.1). +pub const ABSORB_COMMITMENT: &[u8] = b"hachi/absorb/commitment"; +/// Absorb claimed openings/evaluations before relation reduction (paper §4.2). +pub const ABSORB_EVALUATION_CLAIMS: &[u8] = b"hachi/absorb/evaluation-claims"; +/// Challenge for the evaluation-to-linear-relation reduction (paper §4.2). +pub const CHALLENGE_LINEAR_RELATION: &[u8] = b"hachi/challenge/linear-relation"; +/// Absorb ring-switch relation messages (paper §4.3). +pub const ABSORB_RING_SWITCH_MESSAGE: &[u8] = b"hachi/absorb/ring-switch-message"; +/// Challenge used by ring-switching relation checks (paper §4.3). +pub const CHALLENGE_RING_SWITCH: &[u8] = b"hachi/challenge/ring-switch"; +/// Absorb sparse-challenge sampling context (e.g. for short/sparse ring `c`). +pub const ABSORB_SPARSE_CHALLENGE: &[u8] = b"hachi/absorb/sparse-challenge"; +/// Challenge bytes used to sample sparse challenges (e.g. ring `c` with weight ω). 
+pub const CHALLENGE_SPARSE_CHALLENGE: &[u8] = b"hachi/challenge/sparse-challenge"; +/// Absorb the initial sumcheck claim before round messages begin. +pub const ABSORB_SUMCHECK_CLAIM: &[u8] = b"hachi/absorb/sumcheck-claim"; +/// Absorb per-round sumcheck messages (paper §4.3). +pub const ABSORB_SUMCHECK_ROUND: &[u8] = b"hachi/absorb/sumcheck-round"; +/// Challenge sampled per sumcheck round (paper §4.3). +pub const CHALLENGE_SUMCHECK_ROUND: &[u8] = b"hachi/challenge/sumcheck-round"; +/// Challenge for batched sumcheck coefficient sampling. +pub const CHALLENGE_SUMCHECK_BATCH: &[u8] = b"hachi/challenge/sumcheck-batch"; +/// Absorb recursion/stop-condition message payloads (paper §4.5). +pub const ABSORB_STOP_CONDITION: &[u8] = b"hachi/absorb/stop-condition"; +/// Challenge sampled for recursion stop-condition checks (paper §4.5). +pub const CHALLENGE_STOP_CONDITION: &[u8] = b"hachi/challenge/stop-condition"; + +/// Absorb the prover's stage-1 message `v = D · ŵ` (paper §4.2, Figure 3). +pub const ABSORB_PROVER_V: &[u8] = b"hachi/absorb/prover-stage1-v"; +/// Challenge label for stage-1 fold (sampling sparse `c_i`). +pub const CHALLENGE_STAGE1_FOLD: &[u8] = b"hachi/challenge/stage1-fold"; + +/// Absorb the `w` coefficient vector before sumcheck (paper §4.3). +pub const ABSORB_SUMCHECK_W: &[u8] = b"hachi/absorb/sumcheck-w"; +/// Challenge for sampling `τ₀` (F_0 range-check batching point, paper §4.3). +pub const CHALLENGE_TAU0: &[u8] = b"hachi/challenge/tau0"; +/// Challenge for sampling `τ₁` (F_α evaluation-relation batching point, paper §4.3). +pub const CHALLENGE_TAU1: &[u8] = b"hachi/challenge/tau1"; + +/// Labrador protocol domain label (used for recursive reduction stages). +pub const DOMAIN_LABRADOR_PROTOCOL: &[u8] = b"hachi/labrador/protocol"; +/// Greyhound evaluation-reduction domain label. +pub const DOMAIN_GREYHOUND_EVAL: &[u8] = b"hachi/greyhound/eval"; +/// Absorb canonical Greyhound evaluation context bytes (dimensions/backend id). 
+pub const ABSORB_GREYHOUND_EVAL_CONTEXT: &[u8] = b"hachi/absorb/greyhound-eval-context";
+/// Absorb canonicalized evaluation-point coordinates for Greyhound reduction.
+pub const ABSORB_GREYHOUND_EVAL_POINT: &[u8] = b"hachi/absorb/greyhound-eval-point";
+/// Absorb the claimed evaluation value for Greyhound reduction.
+pub const ABSORB_GREYHOUND_EVAL_VALUE: &[u8] = b"hachi/absorb/greyhound-eval-value";
+/// Absorb the Greyhound second outer commitment `u2`.
+pub const ABSORB_GREYHOUND_U2: &[u8] = b"hachi/absorb/greyhound-u2";
+/// Challenge for Greyhound column-fold coefficients.
+pub const CHALLENGE_GREYHOUND_FOLD: &[u8] = b"hachi/challenge/greyhound-fold";
+/// Absorb canonical Labrador level metadata (shape/config/tail/backend id).
+pub const ABSORB_LABRADOR_LEVEL_CONTEXT: &[u8] = b"hachi/absorb/labrador-level-context";
+/// Absorb Labrador JL projection vector `p`.
+pub const ABSORB_LABRADOR_JL_PROJECTION: &[u8] = b"hachi/absorb/labrador-jl-projection";
+/// Absorb Labrador JL nonce.
+pub const ABSORB_LABRADOR_JL_NONCE: &[u8] = b"hachi/absorb/labrador-jl-nonce";
+/// Challenge for Labrador aggregation/lift stage.
+pub const CHALLENGE_LABRADOR_AGGREGATION: &[u8] = b"hachi/challenge/labrador-aggregation";
+/// Challenge for Labrador JL collapse coefficients.
+pub const CHALLENGE_LABRADOR_JL_COLLAPSE: &[u8] = b"hachi/challenge/labrador-jl-collapse";
+/// Absorb Labrador inner commitment u1 at each recursion level.
+pub const ABSORB_LABRADOR_U1: &[u8] = b"hachi/absorb/labrador-u1";
+/// Absorb Labrador outer commitment u2 at each recursion level.
+pub const ABSORB_LABRADOR_U2: &[u8] = b"hachi/absorb/labrador-u2";
+/// Absorb Labrador lift polynomials (constant-term-removed).
+pub const ABSORB_LABRADOR_BB: &[u8] = b"hachi/absorb/labrador-bb";
+/// Absorb Labrador squared norm bound at each level.
+pub const ABSORB_LABRADOR_NORM: &[u8] = b"hachi/absorb/labrador-norm";
+/// Challenge for Labrador amortization fold (ring-element challenges).
+pub const CHALLENGE_LABRADOR_AMORTIZE: &[u8] = b"hachi/challenge/labrador-amortize";
+
+/// Return all Hachi-core transcript labels.
+pub fn all_labels() -> &'static [&'static [u8]] {
+    &[
+        DOMAIN_HACHI_PROTOCOL,
+        ABSORB_COMMITMENT,
+        ABSORB_EVALUATION_CLAIMS,
+        CHALLENGE_LINEAR_RELATION,
+        ABSORB_RING_SWITCH_MESSAGE,
+        CHALLENGE_RING_SWITCH,
+        ABSORB_SPARSE_CHALLENGE,
+        CHALLENGE_SPARSE_CHALLENGE,
+        ABSORB_SUMCHECK_CLAIM,
+        ABSORB_SUMCHECK_ROUND,
+        CHALLENGE_SUMCHECK_ROUND,
+        CHALLENGE_SUMCHECK_BATCH,
+        ABSORB_STOP_CONDITION,
+        CHALLENGE_STOP_CONDITION,
+        ABSORB_PROVER_V,
+        CHALLENGE_STAGE1_FOLD,
+        ABSORB_SUMCHECK_W,
+        CHALLENGE_TAU0,
+        CHALLENGE_TAU1,
+        DOMAIN_LABRADOR_PROTOCOL,
+        DOMAIN_GREYHOUND_EVAL,
+        ABSORB_GREYHOUND_EVAL_CONTEXT,
+        ABSORB_GREYHOUND_EVAL_POINT,
+        ABSORB_GREYHOUND_EVAL_VALUE,
+        ABSORB_GREYHOUND_U2,
+        CHALLENGE_GREYHOUND_FOLD,
+        ABSORB_LABRADOR_LEVEL_CONTEXT,
+        ABSORB_LABRADOR_JL_PROJECTION,
+        ABSORB_LABRADOR_JL_NONCE,
+        CHALLENGE_LABRADOR_AGGREGATION,
+        CHALLENGE_LABRADOR_JL_COLLAPSE,
+        ABSORB_LABRADOR_U1,
+        ABSORB_LABRADOR_U2,
+        ABSORB_LABRADOR_BB,
+        ABSORB_LABRADOR_NORM,
+        CHALLENGE_LABRADOR_AMORTIZE,
+    ]
+}
diff --git a/src/protocol/transcript/mod.rs b/src/protocol/transcript/mod.rs
new file mode 100644
index 00000000..e56e673f
--- /dev/null
+++ b/src/protocol/transcript/mod.rs
@@ -0,0 +1,102 @@
+//! Protocol transcript contracts and implementations.
+
+mod hash;
+pub mod labels;
+
+use crate::algebra::fields::lift::ExtField;
+use crate::algebra::ring::CyclotomicRing;
+use crate::error::HachiError;
+use crate::protocol::labrador::challenge::sample_labrador_challenge_coeffs;
+use crate::{CanonicalField, FieldCore, FromSmallInt, HachiSerialize};
+
+pub use hash::{Blake2bTranscript, HashTranscript, KeccakTranscript};
+
+/// Transcript interface for protocol Fiat-Shamir transforms.
+///
+/// The protocol layer is label-aware and uses deterministic byte encoding for
+/// all absorbed values.
+pub trait Transcript<F>: Clone + Send + Sync + 'static
+where
+    F: FieldCore + CanonicalField,
+{
+    /// Construct a new transcript under a domain label.
+    fn new(domain_label: &[u8]) -> Self;
+
+    /// Append labeled raw bytes.
+    fn append_bytes(&mut self, label: &[u8], bytes: &[u8]);
+
+    /// Append a field element with deterministic encoding.
+    fn append_field(&mut self, label: &[u8], x: &F);
+
+    /// Append a serializable protocol value.
+    fn append_serde<S: HachiSerialize>(&mut self, label: &[u8], s: &S);
+
+    /// Derive a challenge scalar under the provided label.
+    fn challenge_scalar(&mut self, label: &[u8]) -> F;
+}
+
+/// Sample an extension field challenge by drawing `EXT_DEGREE` base-field
+/// challenges and assembling them via `from_base_slice`.
+///
+/// When `E = F` (degree 1), this compiles to a single `challenge_scalar` call.
+pub fn sample_ext_challenge<F, T, E>(tr: &mut T, label: &[u8]) -> E
+where
+    F: FieldCore + CanonicalField,
+    T: Transcript<F>,
+    E: ExtField<F>,
+{
+    E::from_base_slice(
+        &(0..E::EXT_DEGREE)
+            .map(|_| tr.challenge_scalar(label))
+            .collect::<Vec<_>>(),
+    )
+}
+
+/// Fixed nonce for single-polynomial rejection sampling.
+const REJECTION_SAMPLER_SINGLE_NONCE: u64 = 0;
+
+/// Sample a dense ring-element challenge by drawing `D` scalar challenges.
+pub fn challenge_ring_element<F, T, const D: usize>(
+    tr: &mut T,
+    label: &[u8],
+) -> CyclotomicRing<F, D>
+where
+    F: FieldCore + CanonicalField,
+    T: Transcript<F>,
+{
+    CyclotomicRing::from_coefficients(std::array::from_fn(|_| tr.challenge_scalar(label)))
+}
+
+/// Sample a sparse ring-element challenge with operator-norm rejection sampling.
+///
+/// Squeezes a 16-byte seed from the transcript, then delegates to the Labrador
+/// rejection sampler which produces a polynomial with exactly `TAU1` coefficients
+/// in {+/-1} and `TAU2` in {+/-2}, retrying until the operator norm is bounded.
+///
+/// # Errors
+///
+/// Returns an error if `D` is incompatible with the rejection sampler.
+pub fn challenge_ring_element_rejection_sampled<F, T, const D: usize>(
+    tr: &mut T,
+    label: &[u8],
+) -> Result<CyclotomicRing<F, D>, HachiError>
+where
+    F: FieldCore + CanonicalField + FromSmallInt,
+    T: Transcript<F>,
+{
+    let mut seed = [0u8; 16];
+    for chunk in seed.chunks_mut(8) {
+        let s = tr.challenge_scalar(label);
+        let v = s.to_canonical_u128();
+        let len = chunk.len();
+        chunk.copy_from_slice(&v.to_le_bytes()[..len]);
+    }
+    let coeffs = sample_labrador_challenge_coeffs::<D>(1, &seed, REJECTION_SAMPLER_SINGLE_NONCE)?;
+    let poly = coeffs
+        .into_iter()
+        .next()
+        .ok_or_else(|| HachiError::InvalidInput("rejection sampler produced no output".into()))?;
+    Ok(CyclotomicRing::from_coefficients(std::array::from_fn(
+        |i| F::from_i64(poly[i] as i64),
+    )))
+}
diff --git a/src/test_utils.rs b/src/test_utils.rs
new file mode 100644
index 00000000..6222c3fd
--- /dev/null
+++ b/src/test_utils.rs
@@ -0,0 +1,180 @@
+//! Shared test configuration and helpers.
+//!
+//! This module is only compiled under `#[cfg(test)]` and provides common
+//! building blocks for both unit tests (inside `src/`) and integration
+//! tests (inside `tests/`).
+
+use std::array::from_fn;
+
+use crate::algebra::{CyclotomicRing, Fp64};
+use crate::error::HachiError;
+use crate::protocol::commitment::utils::flat_matrix::FlatMatrix;
+use crate::protocol::commitment::{
+    compute_num_digits, compute_num_digits_fold, CommitmentConfig, DecompositionParams,
+    HachiCommitmentLayout,
+};
+use crate::{FieldCore, FromSmallInt};
+
+/// Default test field: a 32-bit prime `p = 4294967197`.
+pub type F = Fp64<4294967197>;
+/// Ring degree used in tests.
+pub const D: usize = 64;
+
+/// Minimal commitment config for fast unit tests.
+#[derive(Clone)]
+pub struct TinyConfig;
+
+impl CommitmentConfig for TinyConfig {
+    const D: usize = 64;
+    const N_A: usize = 2;
+    const N_B: usize = 2;
+    const N_D: usize = 2;
+    const CHALLENGE_WEIGHT: usize = 3;
+
+    fn decomposition() -> DecompositionParams {
+        DecompositionParams {
+            log_basis: 3,
+            log_commit_bound: 32,
+            log_open_bound: None,
+        }
+    }
+
+    fn commitment_layout(_max_num_vars: usize) -> Result<HachiCommitmentLayout, HachiError> {
+        HachiCommitmentLayout::new::<Self>(1, 1, &Self::decomposition())
+    }
+}
+
+/// Number of ring elements per block (`2^m_vars`).
+pub const BLOCK_LEN: usize = 2;
+/// Number of blocks (`2^r_vars`).
+pub const NUM_BLOCKS: usize = 2;
+/// Gadget base exponent (`b = 2^log_basis()`), derived from `TinyConfig`.
+pub fn log_basis() -> u32 {
+    TinyConfig::decomposition().log_basis
+}
+/// Inner Ajtai row count from `TinyConfig`.
+pub const N_A: usize = TinyConfig::N_A;
+
+/// Decomposition depth for original coefficients under `TinyConfig`.
+pub fn num_digits_commit() -> usize {
+    let d = TinyConfig::decomposition();
+    compute_num_digits(d.log_commit_bound, d.log_basis)
+}
+
+/// Decomposition depth for opening / full-field coefficients under `TinyConfig`.
+pub fn num_digits_open() -> usize {
+    let d = TinyConfig::decomposition();
+    let log_open = d.log_open_bound.unwrap_or(d.log_commit_bound);
+    compute_num_digits(log_open, d.log_basis)
+}
+
+/// Decomposition depth for the folded witness `z_pre` under `TinyConfig`.
+pub fn num_digits_fold() -> usize {
+    let d = TinyConfig::decomposition();
+    compute_num_digits_fold(1, TinyConfig::CHALLENGE_WEIGHT, d.log_basis)
+}
+
+/// Dense matrix-vector multiply over cyclotomic rings.
+///
+/// Matrix rows may be wider than `vec` (e.g. when matrices are widened for
+/// multi-level folding); extra columns are treated as multiplying zero.
+pub fn mat_vec_mul(mat: &FlatMatrix<F>, vec: &[CyclotomicRing<F, D>]) -> Vec<CyclotomicRing<F, D>> {
+    let view = mat.view::<D>();
+    (0..view.num_rows())
+        .map(|i| {
+            let row = view.row(i);
+            assert!(row.len() >= vec.len());
+            row.iter()
+                .zip(vec.iter())
+                .fold(CyclotomicRing::<F, D>::zero(), |acc, (a, x)| {
+                    acc + (*a * *x)
+                })
+        })
+        .collect()
+}
+
+/// Generate deterministic test blocks of ring elements.
+pub fn sample_blocks() -> Vec<Vec<CyclotomicRing<F, D>>> {
+    (0..NUM_BLOCKS)
+        .map(|bi| {
+            (0..BLOCK_LEN)
+                .map(|bj| {
+                    let coeffs = from_fn(|k| F::from_u64((bi * 1_000 + bj * 100 + k) as u64));
+                    CyclotomicRing::from_coefficients(coeffs)
+                })
+                .collect()
+        })
+        .collect()
+}
+
+/// Generate deterministic inner opening-point scalars.
+pub fn sample_a() -> Vec<F> {
+    (0..BLOCK_LEN)
+        .map(|j| F::from_u64((j * 10 + 1) as u64))
+        .collect()
+}
+
+/// Generate deterministic outer opening-point scalars.
+pub fn sample_b() -> Vec<F> {
+    (0..NUM_BLOCKS)
+        .map(|i| F::from_u64((i * 7 + 3) as u64))
+        .collect()
+}
+
+/// Recompose a gadget-decomposed ring element: `sum_i parts[i] * b^i`.
+pub fn field_gadget_recompose(
+    parts: &[CyclotomicRing<F, D>],
+    log_basis: u32,
+) -> CyclotomicRing<F, D> {
+    let b = F::from_u64(1u64 << log_basis);
+    let mut result = CyclotomicRing::<F, D>::zero();
+    let mut b_power = F::one();
+    for part in parts {
+        result += part.scale(&b_power);
+        b_power *= b;
+    }
+    result
+}
+
+/// Recompose `z_hat` chunks (num_digits_fold-width) back to `z_pre` elements.
+pub fn recompose_z_hat(z_hat: &[CyclotomicRing<F, D>]) -> Vec<CyclotomicRing<F, D>> {
+    z_hat
+        .chunks(num_digits_fold())
+        .map(|chunk| field_gadget_recompose(chunk, log_basis()))
+        .collect()
+}
+
+/// Recompose a vector of gadget-decomposed elements (num_digits_commit-width chunks).
+pub fn gadget_recompose_vec(x_hat: &[CyclotomicRing<F, D>]) -> Vec<CyclotomicRing<F, D>> {
+    x_hat
+        .chunks(num_digits_commit())
+        .map(|chunk| field_gadget_recompose(chunk, log_basis()))
+        .collect()
+}
+
+/// Recompose a vector of i8 gadget-decomposed digit planes (num_digits_commit-width chunks).
+pub fn gadget_recompose_vec_i8(x_hat: &[[i8; D]]) -> Vec<CyclotomicRing<F, D>> {
+    x_hat
+        .chunks(num_digits_commit())
+        .map(|chunk| CyclotomicRing::gadget_recompose_pow2_i8(chunk, log_basis()))
+        .collect()
+}
+
+/// Alias for [`gadget_recompose_vec`] (same num_digits_commit-width recomposition).
+pub fn field_gadget_recompose_vec(v: &[CyclotomicRing<F, D>]) -> Vec<CyclotomicRing<F, D>> {
+    v.chunks(num_digits_commit())
+        .map(|chunk| field_gadget_recompose(chunk, log_basis()))
+        .collect()
+}
+
+/// Compute `a^T * G^{-1}(z)`: recompose `z` then inner-product with `a`.
+pub fn a_transpose_gadget_times_vec(a: &[F], z: &[CyclotomicRing<F, D>]) -> CyclotomicRing<F, D> {
+    let recomposed = field_gadget_recompose_vec(z);
+    assert_eq!(recomposed.len(), a.len());
+    recomposed
+        .iter()
+        .zip(a.iter())
+        .fold(CyclotomicRing::<F, D>::zero(), |acc, (z_j, a_j)| {
+            acc + z_j.scale(a_j)
+        })
+}
diff --git a/tests/algebra.rs b/tests/algebra.rs
new file mode 100644
index 00000000..668cacd0
--- /dev/null
+++ b/tests/algebra.rs
@@ -0,0 +1,1469 @@
+#![allow(missing_docs)]
+
+#[cfg(test)]
+mod tests {
+    use num_bigint::BigUint;
+    use rand::{rngs::StdRng, SeedableRng};
+
+    use hachi_pcs::algebra::backend::{CrtReconstruct, NttPrimeOps};
+    use hachi_pcs::algebra::ntt::butterfly::{forward_ntt, inverse_ntt, NttTwiddles};
+    use hachi_pcs::algebra::poly::Poly;
+    use hachi_pcs::algebra::tables::{
+        q128_garner, q128_primes, q32_garner, q64_garner, q64_primes, Q128_MODULUS,
+        Q128_NUM_PRIMES, Q32_MODULUS, Q32_NUM_PRIMES, Q32_PRIMES, Q64_MODULUS, Q64_NUM_PRIMES,
+    };
+    use hachi_pcs::algebra::{
+        pseudo_mersenne_modulus, Pow2Offset128Field, Pow2OffsetPrimeSpec, POW2_OFFSET_MAX,
+        POW2_OFFSET_PRIMES, POW2_OFFSET_TABLE,
+    };
+    use hachi_pcs::algebra::{
+        CyclotomicCrtNtt, CyclotomicRing, Fp128, Fp2, Fp2Config, Fp32, Fp4, Fp4Config, Fp64, LimbQ,
+        MontCoeff, Prime128M13M4P0, Prime128M37P3P0, Prime128M52M3P0, Prime128M54P4P0,
+        Prime128M8M4M1M0, ScalarBackend, VectorModule,
+    };
+    use hachi_pcs::primitives::serialization::SerializationError;
+    use
hachi_pcs::{ + CanonicalField, FieldCore, FieldSampling, FromSmallInt, HachiDeserialize, HachiSerialize, + Invertible, Module, PseudoMersenneField, + }; + + const P_159: u128 = 340282366920938463463374607431768211297u128; + + #[test] + fn fp32_basic_arith() { + type F = Fp32<251>; + let a = F::from_u64(17); + let b = F::from_u64(99); + assert_eq!((a + b).to_canonical_u32(), 116); + assert_eq!((a * b).to_canonical_u32(), (17u32 * 99) % 251); + + let inv = a.inv().unwrap(); + assert_eq!(a * inv, F::one()); + } + + #[test] + fn fp64_hachi_q_inv() { + type F = Fp64<4294967197>; + let two = F::from_u64(2); + let inv2 = two.inv().unwrap(); + assert_eq!(two * inv2, F::one()); + } + + #[test] + fn fp128_basic_arith() { + type F = Fp128; + + let a = F::from_u64(123); + let b = F::from_u64(456); + let c = a * b + a - b; + let inv = c.inv().unwrap(); + assert_eq!(c * inv, F::one()); + } + + fn rand_u128(rng: &mut R) -> u128 { + let lo = rng.next_u64() as u128; + let hi = rng.next_u64() as u128; + lo | (hi << 64) + } + + fn biguint_to_u128(x: &num_bigint::BigUint) -> u128 { + let mut bytes = x.to_bytes_le(); + bytes.resize(16, 0); + let mut arr = [0u8; 16]; + arr.copy_from_slice(&bytes[..16]); + u128::from_le_bytes(arr) + } + + fn big_mul_mod_u128(a: u128, b: u128, p: u128) -> u128 { + let n = BigUint::from(a) * BigUint::from(b); + let r = n % BigUint::from(p); + biguint_to_u128(&r) + } + + fn check_solinas_prime< + S: CanonicalField + FieldCore + Invertible + PseudoMersenneField + std::fmt::Debug, + >( + p: u128, + iters: usize, + seed: u64, + ) { + assert_eq!(::MODULUS_BITS, 128); + assert_eq!( + ::MODULUS_OFFSET, + 0u128.wrapping_sub(p) + ); + assert_eq!(std::mem::size_of::(), 16); + + let mut rng = StdRng::seed_from_u64(seed); + + for _ in 0..iters { + let a_raw = rand_u128(&mut rng); + let b_raw = rand_u128(&mut rng); + + let a = S::from_canonical_u128_reduced(a_raw); + let b = S::from_canonical_u128_reduced(b_raw); + + // Canonical range invariant. 
+ assert!(a.to_canonical_u128() < p); + assert!(b.to_canonical_u128() < p); + + // Add/sub/neg identities. + assert_eq!(a + S::zero(), a); + assert_eq!(a - S::zero(), a); + assert_eq!(a + (-a), S::zero()); + + // Multiplicative identity. + assert_eq!(a * S::one(), a); + + // BigUint oracle for mul and sqr (exercises reduction). + let aa = a.to_canonical_u128(); + let bb = b.to_canonical_u128(); + let got_mul = (a * b).to_canonical_u128(); + let exp_mul = big_mul_mod_u128(aa, bb, p); + assert_eq!(got_mul, exp_mul); + + let got_sqr = (a * a).to_canonical_u128(); + let exp_sqr = big_mul_mod_u128(aa, aa, p); + assert_eq!(got_sqr, exp_sqr); + + // Inversion checks (skip explicit inv on zero). + let inv = a.inv_or_zero(); + if a.is_zero() { + assert_eq!(inv, S::zero()); + } else { + assert_eq!(a * inv, S::one()); + assert_eq!(a.inv().unwrap(), inv); + } + } + } + + #[test] + fn fp128_sparse_primes_match_biguint_oracle() { + // These are the five sparse `2^128 - c` primes we care about. + const P13: u128 = 0xffffffffffffffffffffffffffffdff1u128; + const P37: u128 = 0xffffffffffffffffffffffe000000009u128; + const P52: u128 = 0xffffffffffffffffffeffffffffffff9u128; + const P54: u128 = 0xffffffffffffffffffc0000000000011u128; + const P275: u128 = 0xfffffffffffffffffffffffffffffeedu128; + + check_solinas_prime::(P13, 2_000, 13); + check_solinas_prime::(P37, 2_000, 37); + check_solinas_prime::(P52, 2_000, 52); + check_solinas_prime::(P54, 2_000, 54); + check_solinas_prime::(P275, 2_000, 275); + } + + struct NR; + impl Fp2Config> for NR { + fn non_residue() -> Fp32<251> { + -Fp32::<251>::one() + } + } + + struct NR4; + impl Fp4Config, NR> for NR4 { + fn non_residue() -> Fp2, NR> { + Fp2::new(Fp32::<251>::zero(), Fp32::<251>::one()) + } + } + + #[test] + fn fp2_fp4_inversion_smoke() { + type F = Fp32<251>; + type F2 = Fp2; + type F4 = Fp4; + + let x = F2::new(F::from_u64(3), F::from_u64(7)); + let inv = x.inv().unwrap(); + assert!((x * inv) == F2::one()); + + let y = F4::new( + 
F2::new(F::from_u64(5), F::from_u64(1)), + F2::new(F::from_u64(2), F::from_u64(9)), + ); + let invy = y.inv().unwrap(); + assert!((y * invy) == F4::one()); + } + + #[test] + fn vector_module_ops() { + type F = Fp32<251>; + + let a = VectorModule::([F::from_u64(1), F::from_u64(2), F::from_u64(3)]); + let b = VectorModule::([F::from_u64(3), F::from_u64(4), F::from_u64(5)]); + + let c = a + b; + assert_eq!(c.0[0], F::from_u64(4)); + + let d = a.scale(&F::from_u64(7)); + assert_eq!(d.0[1], F::from_u64(14)); + } + + #[test] + fn inv_zero_returns_none() { + assert!(Fp32::<251>::zero().inv().is_none()); + assert!(Fp64::<4294967197>::zero().inv().is_none()); + assert!(Fp128::::zero().inv().is_none()); + } + + #[test] + fn inv_or_zero_behavior_for_prime_fields() { + type F32 = Fp32<251>; + assert_eq!(F32::zero().inv_or_zero(), F32::zero()); + let x32 = F32::from_u64(17); + let inv32 = x32.inv_or_zero(); + assert_eq!(x32 * inv32, F32::one()); + + type F64 = Fp64<4294967197>; + assert_eq!(F64::zero().inv_or_zero(), F64::zero()); + let x64 = F64::from_u64(2); + let inv64 = x64.inv_or_zero(); + assert_eq!(x64 * inv64, F64::one()); + + type F128 = Fp128; + assert_eq!(F128::zero().inv_or_zero(), F128::zero()); + let x128 = F128::from_u64(12345); + let inv128 = x128.inv_or_zero(); + assert_eq!(x128 * inv128, F128::one()); + } + + #[test] + fn field_identities_fp32() { + type F = Fp32<251>; + let a = F::from_u64(42); + let b = F::from_u64(73); + let c = F::from_u64(11); + + // Additive identity + assert_eq!(a + F::zero(), a); + // Multiplicative identity + assert_eq!(a * F::one(), a); + // Additive inverse + assert_eq!(a + (-a), F::zero()); + // Distributivity + assert_eq!(a * (b + c), a * b + a * c); + // Commutativity + assert_eq!(a * b, b * a); + assert_eq!(a + b, b + a); + } + + #[test] + fn field_identities_fp64() { + type F = Fp64<4294967197>; + let a = F::from_u64(123456); + let b = F::from_u64(789012); + let c = F::from_u64(345678); + + assert_eq!(a + F::zero(), a); + 
assert_eq!(a * F::one(), a); + assert_eq!(a + (-a), F::zero()); + assert_eq!(a * (b + c), a * b + a * c); + } + + #[test] + fn field_identities_fp128() { + type F = Fp128; + let a = F::from_u64(999999); + let b = F::from_u64(888888); + let c = F::from_u64(777777); + + assert_eq!(a + F::zero(), a); + assert_eq!(a * F::one(), a); + assert_eq!(a + (-a), F::zero()); + assert_eq!(a * (b + c), a * b + a * c); + } + + #[test] + fn serialization_round_trip_fp32() { + type F = Fp32<251>; + let val = F::from_u64(42); + let mut buf = Vec::new(); + val.serialize_compressed(&mut buf).unwrap(); + let restored = F::deserialize_compressed(&buf[..]).unwrap(); + assert_eq!(val, restored); + } + + #[test] + fn serialization_round_trip_fp64() { + type F = Fp64<4294967197>; + let val = F::from_u64(123456789); + let mut buf = Vec::new(); + val.serialize_compressed(&mut buf).unwrap(); + let restored = F::deserialize_compressed(&buf[..]).unwrap(); + assert_eq!(val, restored); + } + + #[test] + fn serialization_round_trip_fp128() { + type F = Fp128; + let val = F::from_u64(999999999); + let mut buf = Vec::new(); + val.serialize_compressed(&mut buf).unwrap(); + let restored = F::deserialize_compressed(&buf[..]).unwrap(); + assert_eq!(val, restored); + } + + #[test] + fn serialization_round_trip_ext() { + type F = Fp32<251>; + type F2 = Fp2; + let val = F2::new(F::from_u64(3), F::from_u64(7)); + let mut buf = Vec::new(); + val.serialize_compressed(&mut buf).unwrap(); + let restored = F2::deserialize_compressed(&buf[..]).unwrap(); + assert!(val == restored); + } + + #[test] + fn serialization_round_trip_fp4() { + type F = Fp32<251>; + type F2 = Fp2; + type F4 = Fp4; + + let val = F4::new( + F2::new(F::from_u64(5), F::from_u64(1)), + F2::new(F::from_u64(2), F::from_u64(9)), + ); + let mut buf = Vec::new(); + val.serialize_compressed(&mut buf).unwrap(); + let restored = F4::deserialize_compressed(&buf[..]).unwrap(); + assert!(val == restored); + } + + #[test] + fn 
serialization_round_trip_vector_module() { + type F = Fp32<251>; + let val = VectorModule::([F::from_u64(1), F::from_u64(2), F::from_u64(3)]); + let mut buf = Vec::new(); + val.serialize_compressed(&mut buf).unwrap(); + let restored = VectorModule::::deserialize_compressed(&buf[..]).unwrap(); + assert_eq!(val, restored); + } + + #[test] + fn serialization_round_trip_poly() { + type F = Fp32<251>; + + let val = Poly::([ + F::from_u64(7), + F::from_u64(11), + F::from_u64(13), + F::from_u64(29), + ]); + let mut buf = Vec::new(); + val.serialize_compressed(&mut buf).unwrap(); + let restored = Poly::::deserialize_compressed(&buf[..]).unwrap(); + assert_eq!(val, restored); + } + + #[test] + fn deserialize_checked_rejects_non_canonical_field_elements() { + type F32 = Fp32<251>; + let bad32 = 251u32.to_le_bytes(); + let err32 = F32::deserialize_compressed(&bad32[..]).unwrap_err(); + assert!(matches!(err32, SerializationError::InvalidData(_))); + let unchecked32 = F32::deserialize_compressed_unchecked(&bad32[..]).unwrap(); + assert_eq!(unchecked32, F32::zero()); + + type F64 = Fp64<4294967197>; + let bad64 = 4294967197u64.to_le_bytes(); + let err64 = F64::deserialize_compressed(&bad64[..]).unwrap_err(); + assert!(matches!(err64, SerializationError::InvalidData(_))); + let unchecked64 = F64::deserialize_compressed_unchecked(&bad64[..]).unwrap(); + assert_eq!(unchecked64, F64::zero()); + + type F128 = Fp128; + let bad128 = P_159.to_le_bytes(); + let err128 = F128::deserialize_compressed(&bad128[..]).unwrap_err(); + assert!(matches!(err128, SerializationError::InvalidData(_))); + let unchecked128 = F128::deserialize_compressed_unchecked(&bad128[..]).unwrap(); + assert_eq!(unchecked128, F128::zero()); + + // Sparse 128-bit prime: same checked/unchecked behavior. 
+ type S13 = Prime128M13M4P0; + const P13: u128 = 0xffffffffffffffffffffffffffffdff1u128; + let bad13 = P13.to_le_bytes(); + let err13 = S13::deserialize_compressed(&bad13[..]).unwrap_err(); + assert!(matches!(err13, SerializationError::InvalidData(_))); + let unchecked13 = S13::deserialize_compressed_unchecked(&bad13[..]).unwrap(); + assert_eq!(unchecked13, S13::zero()); + } + + #[test] + fn fp2_conjugate_and_norm() { + type F = Fp32<251>; + type F2 = Fp2; + let x = F2::new(F::from_u64(3), F::from_u64(7)); + let conj = x.conjugate(); + assert!(conj == F2::new(F::from_u64(3), -F::from_u64(7))); + // For Fp2 with u^2 = -1: norm = c0^2 + c1^2 = 9 + 49 = 58 + assert_eq!(x.norm(), F::from_u64(58)); + // x * conjugate(x) should embed the norm into Fp2 + let prod = x * conj; + assert!(prod == F2::new(F::from_u64(58), F::zero())); + } + + #[test] + fn fp2_distributivity() { + type F = Fp32<251>; + type F2 = Fp2; + let a = F2::new(F::from_u64(3), F::from_u64(7)); + let b = F2::new(F::from_u64(11), F::from_u64(5)); + let c = F2::new(F::from_u64(2), F::from_u64(9)); + assert!(a * (b + c) == a * b + a * c); + } + + #[test] + fn limbq_from_to_u128_round_trip() { + for &val in &[0u128, 1, 12345, 123456789, (1u128 << 28) - 1] { + let limb: LimbQ<3> = LimbQ::from(val); + assert_eq!( + u128::try_from(limb).unwrap(), + val, + "round-trip failed for {val}" + ); + } + } + + #[test] + fn limbq_add_sub_inverse() { + let a: LimbQ<3> = LimbQ::from(12345u128); + let b: LimbQ<3> = LimbQ::from(6789u128); + let sum = a + b; + let diff = sum - b; + assert_eq!(diff, a); + } + + #[test] + fn limbq_ordering() { + let a: LimbQ<3> = LimbQ::from(100u128); + let b: LimbQ<3> = LimbQ::from(200u128); + assert!(a < b); + assert!(b > a); + assert_eq!(a, a); + } + + #[test] + fn ntt_normalize_in_range() { + for prime in &Q32_PRIMES { + for &a in &[0i16, 1, -1, 100, -100, prime.p - 1, -(prime.p - 1)] { + let n = prime.normalize(MontCoeff::from_raw(a)); + assert!( + n.raw() >= 0 && n.raw() < prime.p, + 
"normalize({a}) = {} for p={}", + n.raw(), + prime.p + ); + } + } + } + + #[test] + fn csubp_widened_handles_large_negative_i16() { + for &prime in &Q32_PRIMES { + let p = prime.p; + // Values in (-2p, -(2^15 - p)) that previously overflowed in narrow i16 + for &raw in &[-20000i16, -(p + p / 2), -(p + 1000)] { + if raw <= -2 * p || raw >= 0 { + continue; + } + let a = MontCoeff::from_raw(raw); + let reduced = prime.reduce_range(a); + let r = reduced.raw(); + assert!( + r > -p && r < p, + "reduce_range({raw}) = {r} not in (-{p}, {p}) for p={p}" + ); + + let norm = prime.normalize(reduced); + let n = norm.raw(); + assert!( + n >= 0 && n < p, + "normalize(reduce_range({raw})) = {n} not in [0, {p}) for p={p}" + ); + } + } + } + + #[test] + fn ntt_mul_commutative() { + let prime = Q32_PRIMES[0]; + let a = MontCoeff::from_raw(1234); + let b = MontCoeff::from_raw(5678); + assert_eq!(prime.mul(a, b), prime.mul(b, a)); + } + + #[test] + fn mont_coeff_round_trip() { + for prime in &Q32_PRIMES { + for &val in &[0i16, 1, 2, 100, prime.p - 1] { + let mont = prime.from_canonical(val); + let back = prime.to_canonical(mont); + assert_eq!(back, val, "round-trip failed for val={val}, p={}", prime.p); + } + } + } + + #[test] + fn poly_add_sub_neg() { + type F = Fp32<251>; + let a = Poly::([F::from_u64(1), F::from_u64(2), F::from_u64(3)]); + let b = Poly::([F::from_u64(10), F::from_u64(20), F::from_u64(30)]); + + let sum = a + b; + assert_eq!(sum.0[0], F::from_u64(11)); + assert_eq!(sum.0[1], F::from_u64(22)); + assert_eq!(sum.0[2], F::from_u64(33)); + + let diff = b - a; + assert_eq!(diff.0[0], F::from_u64(9)); + + let neg_a = -a; + assert_eq!(a + neg_a, Poly::zero()); + } + + #[test] + fn cyclotomic_ring_negacyclic_property() { + type F = Fp32<251>; + type R = CyclotomicRing; + + // X in the ring: [0, 1, 0, 0] + let x = R::x(); + + // X^2 + let x2 = x * x; + let expected_x2 = R::from_coefficients([F::zero(), F::zero(), F::one(), F::zero()]); + assert_eq!(x2, expected_x2); + + // X^4 
should equal -1 (because X^4 + 1 = 0 in the ring) + let x4 = x2 * x2; + assert_eq!(x4, -R::one(), "X^D should equal -1 in Z_q[X]/(X^D + 1)"); + } + + #[test] + fn cyclotomic_ring_mul_identity() { + type F = Fp32<251>; + type R = CyclotomicRing; + + let a = R::from_coefficients([ + F::from_u64(3), + F::from_u64(7), + F::from_u64(11), + F::from_u64(42), + ]); + assert_eq!(a * R::one(), a); + assert_eq!(R::one() * a, a); + } + + #[test] + fn cyclotomic_ring_mul_zero() { + type F = Fp32<251>; + type R = CyclotomicRing; + + let a = R::from_coefficients([ + F::from_u64(3), + F::from_u64(7), + F::from_u64(11), + F::from_u64(42), + ]); + assert_eq!(a * R::zero(), R::zero()); + } + + #[test] + fn cyclotomic_ring_commutativity() { + type F = Fp32<251>; + type R = CyclotomicRing; + + let a = R::from_coefficients([ + F::from_u64(3), + F::from_u64(7), + F::from_u64(11), + F::from_u64(42), + ]); + let b = R::from_coefficients([ + F::from_u64(5), + F::from_u64(13), + F::from_u64(99), + F::from_u64(1), + ]); + assert_eq!(a * b, b * a); + } + + #[test] + fn cyclotomic_ring_distributivity() { + type F = Fp32<251>; + type R = CyclotomicRing; + + let a = R::from_coefficients([ + F::from_u64(3), + F::from_u64(7), + F::from_u64(11), + F::from_u64(42), + ]); + let b = R::from_coefficients([ + F::from_u64(5), + F::from_u64(13), + F::from_u64(99), + F::from_u64(1), + ]); + let c = R::from_coefficients([ + F::from_u64(2), + F::from_u64(9), + F::from_u64(50), + F::from_u64(77), + ]); + assert_eq!(a * (b + c), a * b + a * c); + } + + #[test] + fn cyclotomic_ring_associativity() { + type F = Fp32<251>; + type R = CyclotomicRing; + + let a = R::from_coefficients([ + F::from_u64(3), + F::from_u64(7), + F::from_u64(11), + F::from_u64(42), + ]); + let b = R::from_coefficients([ + F::from_u64(5), + F::from_u64(13), + F::from_u64(99), + F::from_u64(1), + ]); + let c = R::from_coefficients([ + F::from_u64(2), + F::from_u64(9), + F::from_u64(50), + F::from_u64(77), + ]); + assert_eq!((a * b) * c, a * 
(b * c)); + } + + #[test] + fn cyclotomic_ring_additive_inverse() { + type F = Fp32<251>; + type R = CyclotomicRing; + + let a = R::from_coefficients([ + F::from_u64(3), + F::from_u64(7), + F::from_u64(11), + F::from_u64(42), + ]); + assert_eq!(a + (-a), R::zero()); + } + + #[test] + fn cyclotomic_ring_serialization_round_trip() { + type F = Fp32<251>; + type R = CyclotomicRing; + + let a = R::from_coefficients([ + F::from_u64(3), + F::from_u64(7), + F::from_u64(11), + F::from_u64(42), + ]); + let mut buf = Vec::new(); + a.serialize_compressed(&mut buf).unwrap(); + let restored = R::deserialize_compressed(&buf[..]).unwrap(); + assert_eq!(a, restored); + } + + #[test] + fn cyclotomic_ring_degree_64() { + type F = Fp64<4294967197>; + type R = CyclotomicRing; + + // X^64 = -1 in Z_q[X]/(X^64 + 1) + let x = R::x(); + let mut power = R::one(); + for _ in 0..64 { + power *= x; + } + assert_eq!(power, -R::one(), "X^64 should equal -1"); + } + + #[test] + fn ntt_forward_inverse_round_trip() { + let prime = Q32_PRIMES[0]; + let tw = NttTwiddles::::compute(prime); + + let original: [MontCoeff; 64] = + std::array::from_fn(|i| prime.from_canonical((i as i16) % prime.p)); + + let mut a = original; + forward_ntt(&mut a, prime, &tw); + inverse_ntt(&mut a, prime, &tw); + + for (i, (got, expected)) in a.iter().zip(original.iter()).enumerate() { + let got_canon = prime.to_canonical(prime.normalize(*got)); + let exp_canon = prime.to_canonical(prime.normalize(*expected)); + assert_eq!( + got_canon, exp_canon, + "NTT round-trip mismatch at index {i}: got {got_canon}, expected {exp_canon}" + ); + } + } + + #[test] + fn ntt_forward_inverse_all_primes() { + for (pi, prime) in Q32_PRIMES.iter().enumerate() { + let tw = NttTwiddles::::compute(*prime); + + let original: [_; 64] = + std::array::from_fn(|i| prime.from_canonical(((i * (pi + 1)) as i16) % prime.p)); + + let mut a = original; + forward_ntt(&mut a, *prime, &tw); + inverse_ntt(&mut a, *prime, &tw); + + for (i, (got, expected)) in 
a.iter().zip(original.iter()).enumerate() { + let g = prime.to_canonical(prime.normalize(*got)); + let e = prime.to_canonical(prime.normalize(*expected)); + assert_eq!( + g, e, + "prime[{pi}] p={}: round-trip mismatch at index {i}", + prime.p + ); + } + } + } + + #[test] + fn negacyclic_ntt_mul_matches_schoolbook_single_prime_d8() { + const D: usize = 8; + let prime = Q32_PRIMES[0]; + let tw = NttTwiddles::::compute(prime); + + let a_canon: [i16; D] = std::array::from_fn(|i| ((i as i16 * 7) + 3) % prime.p); + let b_canon: [i16; D] = std::array::from_fn(|i| ((i as i16 * 5) + 11) % prime.p); + + // Schoolbook negacyclic convolution mod p: X^D = -1. + let mut school = [0i16; D]; + for (i, &ai) in a_canon.iter().enumerate() { + for (j, &bj) in b_canon.iter().enumerate() { + let prod = (ai as i64 * bj as i64) % (prime.p as i64); + let idx = i + j; + if idx < D { + school[idx] = ((school[idx] as i64 + prod) % (prime.p as i64)) as i16; + } else { + let k = idx - D; + school[k] = ((school[k] as i64 - prod) % (prime.p as i64)) as i16; + } + } + } + for x in &mut school { + if *x < 0 { + *x = (*x as i64 + prime.p as i64) as i16; + } + } + + let mut a = std::array::from_fn(|i| prime.from_canonical(a_canon[i])); + let mut b = std::array::from_fn(|i| prime.from_canonical(b_canon[i])); + forward_ntt(&mut a, prime, &tw); + forward_ntt(&mut b, prime, &tw); + + let mut c: [_; D] = std::array::from_fn(|i| prime.mul(a[i], b[i])); + inverse_ntt(&mut c, prime, &tw); + + let got: [i16; D] = std::array::from_fn(|i| prime.to_canonical(prime.normalize(c[i]))); + assert_eq!(got, school); + } + + #[test] + fn negacyclic_ntt_forward_matches_manual_evals_d8() { + const D: usize = 8; + let prime = Q32_PRIMES[0]; + let tw = NttTwiddles::::compute(prime); + let p = prime.p as i64; + + fn pow_mod(mut base: i64, mut exp: i64, modulus: i64) -> i64 { + let mut acc = 1i64; + base %= modulus; + while exp > 0 { + if exp & 1 == 1 { + acc = (acc * base) % modulus; + } + base = (base * base) % modulus; + 
exp >>= 1; + } + acc + } + + // Compute canonical psi (primitive 2D-th root) directly. + let half = (p - 1) / 2; + let exp = (p - 1) / (2 * D as i64); + let mut psi = None; + for a in 2..p { + if pow_mod(a, half, p) == p - 1 { + let cand = pow_mod(a, exp, p); + if pow_mod(cand, D as i64, p) == p - 1 { + psi = Some(cand); + break; + } + } + } + let psi = psi.expect("psi should exist"); + let a_canon: [i16; D] = std::array::from_fn(|i| ((i as i16 * 7) + 3) % prime.p); + + let mut expected = Vec::with_capacity(D); + for k in 0..D { + let alpha = pow_mod(psi, (2 * k + 1) as i64, p); + let mut acc = 0i64; + let mut power = 1i64; + for &ai in &a_canon { + acc = (acc + (ai as i64) * power) % p; + power = (power * alpha) % p; + } + expected.push(acc as i16); + } + expected.sort_unstable(); + + let mut a = std::array::from_fn(|i| prime.from_canonical(a_canon[i])); + forward_ntt(&mut a, prime, &tw); + let mut got: Vec = a + .iter() + .map(|x| prime.to_canonical(prime.normalize(*x))) + .collect(); + got.sort_unstable(); + + assert_eq!(got, expected); + } + + #[test] + fn negacyclic_ntt_mul_matches_schoolbook_single_prime_d64() { + const D: usize = 64; + let prime = Q32_PRIMES[0]; + let tw = NttTwiddles::::compute(prime); + let p = prime.p as i64; + + let a_canon: [i16; D] = std::array::from_fn(|i| ((i as i16 * 7) + 3) % prime.p); + let b_canon: [i16; D] = std::array::from_fn(|i| ((i as i16 * 5) + 11) % prime.p); + + let mut school = [0i16; D]; + for (i, &ai) in a_canon.iter().enumerate() { + for (j, &bj) in b_canon.iter().enumerate() { + let prod = (ai as i64 * bj as i64) % p; + let idx = i + j; + if idx < D { + school[idx] = ((school[idx] as i64 + prod) % p) as i16; + } else { + let k = idx - D; + school[k] = ((school[k] as i64 - prod) % p) as i16; + } + } + } + for x in &mut school { + if *x < 0 { + *x = (*x as i64 + p) as i16; + } + } + + let mut a = std::array::from_fn(|i| prime.from_canonical(a_canon[i])); + let mut b = std::array::from_fn(|i| 
prime.from_canonical(b_canon[i])); + forward_ntt(&mut a, prime, &tw); + forward_ntt(&mut b, prime, &tw); + + let mut c: [_; D] = std::array::from_fn(|i| prime.reduce_range(prime.mul(a[i], b[i]))); + inverse_ntt(&mut c, prime, &tw); + + let got: [i16; D] = std::array::from_fn(|i| prime.to_canonical(prime.normalize(c[i]))); + assert_eq!(got, school); + } + + #[test] + fn negacyclic_ntt_mul_matches_schoolbook_all_q32_primes_d64() { + const D: usize = 64; + let a_canon: [i16; D] = std::array::from_fn(|i| (i as i16 * 7 + 3)); + let b_canon: [i16; D] = std::array::from_fn(|i| (i as i16 * 5 + 11)); + + for (pi, &prime) in Q32_PRIMES.iter().enumerate() { + let tw = NttTwiddles::::compute(prime); + let p = prime.p as i64; + + let a_mod: [i16; D] = + std::array::from_fn(|i| ((a_canon[i] as i64).rem_euclid(p)) as i16); + let b_mod: [i16; D] = + std::array::from_fn(|i| ((b_canon[i] as i64).rem_euclid(p)) as i16); + + let mut school = [0i16; D]; + for (i, &ai) in a_mod.iter().enumerate() { + for (j, &bj) in b_mod.iter().enumerate() { + let prod = (ai as i64 * bj as i64) % p; + let idx = i + j; + if idx < D { + school[idx] = ((school[idx] as i64 + prod) % p) as i16; + } else { + let k = idx - D; + school[k] = ((school[k] as i64 - prod) % p) as i16; + } + } + } + for x in &mut school { + if *x < 0 { + *x = (*x as i64 + p) as i16; + } + } + + let mut a = std::array::from_fn(|i| prime.from_canonical(a_mod[i])); + let mut b = std::array::from_fn(|i| prime.from_canonical(b_mod[i])); + forward_ntt(&mut a, prime, &tw); + forward_ntt(&mut b, prime, &tw); + + let mut c = [MontCoeff::from_raw(0i16); D]; + for i in 0..D { + c[i] = prime.reduce_range(prime.mul(a[i], b[i])); + } + inverse_ntt(&mut c, prime, &tw); + + let got: [i16; D] = std::array::from_fn(|i| prime.to_canonical(prime.normalize(c[i]))); + assert_eq!(got, school, "prime[{pi}] p={} mismatch", prime.p); + } + } + + #[test] + fn cyclotomic_ntt_crt_round_trip_q32() { + type F = Fp64<{ Q32_MODULUS }>; + type R = CyclotomicRing; + 
type N = CyclotomicCrtNtt; + + let twiddles: [NttTwiddles; Q32_NUM_PRIMES] = + std::array::from_fn(|k| NttTwiddles::compute(Q32_PRIMES[k])); + + let coeffs: [F; 64] = + std::array::from_fn(|i| F::from_u64(((i as u64 * 17) + 5) % Q32_MODULUS)); + let ring = R::from_coefficients(coeffs); + let ntt = N::from_ring(&ring, &Q32_PRIMES, &twiddles); + let garner = q32_garner(); + let round_trip = ntt.to_ring(&Q32_PRIMES, &twiddles, &garner); + + assert_eq!(ring, round_trip); + } + + #[test] + fn cyclotomic_ntt_reduced_ops_are_stable() { + type F = Fp64<{ Q32_MODULUS }>; + type R = CyclotomicRing; + type N = CyclotomicCrtNtt; + + let twiddles: [NttTwiddles; Q32_NUM_PRIMES] = + std::array::from_fn(|k| NttTwiddles::compute(Q32_PRIMES[k])); + + let a = R::from_coefficients(std::array::from_fn(|i| { + F::from_u64(((i as u64 * 3) + 1) % Q32_MODULUS) + })); + let b = R::from_coefficients(std::array::from_fn(|i| { + F::from_u64(((i as u64 * 11) + 7) % Q32_MODULUS) + })); + + let ntt_a = N::from_ring(&a, &Q32_PRIMES, &twiddles); + let ntt_b = N::from_ring(&b, &Q32_PRIMES, &twiddles); + + let sum = ntt_a.add_reduced(&ntt_b, &Q32_PRIMES); + let back = sum.sub_reduced(&ntt_b, &Q32_PRIMES); + assert_eq!(back, ntt_a); + + let garner = q32_garner(); + let zero_ntt = ntt_a.add_reduced(&ntt_a.neg_reduced(&Q32_PRIMES), &Q32_PRIMES); + let zero_ring = zero_ntt.to_ring(&Q32_PRIMES, &twiddles, &garner); + assert_eq!(zero_ring, R::zero()); + } + + #[test] + fn backend_path_matches_default_scalar_path() { + type F = Fp64<{ Q32_MODULUS }>; + type R = CyclotomicRing; + type N = CyclotomicCrtNtt; + + let twiddles: [NttTwiddles; Q32_NUM_PRIMES] = + std::array::from_fn(|k| NttTwiddles::compute(Q32_PRIMES[k])); + let ring = R::from_coefficients(std::array::from_fn(|i| { + F::from_u64(((i as u64 * 13) + 9) % Q32_MODULUS) + })); + + let default_ntt = N::from_ring(&ring, &Q32_PRIMES, &twiddles); + let backend_ntt = + N::from_ring_with_backend::(&ring, &Q32_PRIMES, &twiddles); + assert_eq!(default_ntt, 
backend_ntt); + + let garner = q32_garner(); + let default_back = default_ntt.to_ring(&Q32_PRIMES, &twiddles, &garner); + let backend_back = + backend_ntt.to_ring_with_backend::(&Q32_PRIMES, &twiddles, &garner); + assert_eq!(default_back, backend_back); + } + + #[test] + fn crt_ntt_mul_matches_schoolbook_q32() { + type F = Fp64<{ Q32_MODULUS }>; + type R = CyclotomicRing; + type N = CyclotomicCrtNtt; + + let twiddles: [NttTwiddles; Q32_NUM_PRIMES] = + std::array::from_fn(|k| NttTwiddles::compute(Q32_PRIMES[k])); + let garner = q32_garner(); + + let a = R::from_coefficients(std::array::from_fn(|i| { + F::from_u64(((i as u64 * 7) + 3) % Q32_MODULUS) + })); + let b = R::from_coefficients(std::array::from_fn(|i| { + F::from_u64(((i as u64 * 5) + 11) % Q32_MODULUS) + })); + + let schoolbook = a * b; + + let ntt_a = N::from_ring(&a, &Q32_PRIMES, &twiddles); + let ntt_b = N::from_ring(&b, &Q32_PRIMES, &twiddles); + let ntt_prod = ntt_a.pointwise_mul(&ntt_b, &Q32_PRIMES); + let ntt_result: R = ntt_prod.to_ring(&Q32_PRIMES, &twiddles, &garner); + + assert_eq!(schoolbook, ntt_result); + } + + #[test] + fn q128_garner_reconstruct_matches_coeffs_no_ntt() { + type F = Fp128<{ Q128_MODULUS }>; + + let primes = q128_primes(); + let garner = q128_garner(); + + let coeffs: [F; 64] = std::array::from_fn(|i| { + if i < 8 { + F::from_u64((i as u64 * 31) + 7) + } else { + F::zero() + } + }); + + let mut canonical = [[0i32; 64]; Q128_NUM_PRIMES]; + for (k, prime) in primes.iter().enumerate() { + let p = prime.p as u32 as u128; + for (i, c) in coeffs.iter().enumerate() { + canonical[k][i] = (c.to_canonical_u128() % p) as i32; + } + } + + let reconstructed: [F; 64] = + >::reconstruct( + &primes, &canonical, &garner, + ); + + assert_eq!(reconstructed, coeffs); + } + + #[test] + fn q128_prime_ntt_round_trip_per_prime() { + let primes = q128_primes(); + let twiddles: [NttTwiddles; Q128_NUM_PRIMES] = + std::array::from_fn(|k| NttTwiddles::compute(primes[k])); + + // Use the same sparse 
coefficient pattern as q128_ntt_round_trip, but test + // the per-prime NTT+Montgomery machinery in isolation (no Garner/Fp128). + let residues: [u32; 64] = + std::array::from_fn(|i| if i < 8 { (i as u32 * 31) + 7 } else { 0 }); + + for k in 0..Q128_NUM_PRIMES { + let prime = primes[k]; + let mut limb = [MontCoeff::from_raw(0i32); 64]; + for (i, r) in residues.iter().enumerate() { + let reduced = (*r as i64 % (prime.p as i64)) as i32; + limb[i] = >::from_canonical(prime, reduced); + } + + forward_ntt(&mut limb, prime, &twiddles[k]); + inverse_ntt(&mut limb, prime, &twiddles[k]); + + for (i, r) in residues.iter().enumerate() { + let expected = (*r as i64 % (prime.p as i64)) as i32; + let got = >::to_canonical(prime, limb[i]); + assert_eq!(got, expected, "prime idx={k} coeff idx={i}"); + } + } + } + + #[test] + fn q128_ntt_round_trip() { + type F = Fp128<{ Q128_MODULUS }>; + type R = CyclotomicRing; + type N = CyclotomicCrtNtt; + + let primes = q128_primes(); + let twiddles: [NttTwiddles; Q128_NUM_PRIMES] = + std::array::from_fn(|k| NttTwiddles::compute(primes[k])); + let garner = q128_garner(); + + let coeffs: [F; 64] = std::array::from_fn(|i| { + if i < 8 { + F::from_u64((i as u64 * 31) + 7) + } else { + F::zero() + } + }); + let ring = R::from_coefficients(coeffs); + let ntt = N::from_ring(&ring, &primes, &twiddles); + let round_trip: R = ntt.to_ring(&primes, &twiddles, &garner); + + assert_eq!(ring, round_trip); + } + + #[test] + fn crt_ntt_mul_matches_schoolbook_q128() { + type F = Fp128<{ Q128_MODULUS }>; + type R = CyclotomicRing; + type N = CyclotomicCrtNtt; + + let primes = q128_primes(); + let twiddles: [NttTwiddles; Q128_NUM_PRIMES] = + std::array::from_fn(|k| NttTwiddles::compute(primes[k])); + let garner = q128_garner(); + + let a = R::from_coefficients(std::array::from_fn(|i| { + if i < 8 { + F::from_u64((i as u64 * 7) + 3) + } else { + F::zero() + } + })); + let b = R::from_coefficients(std::array::from_fn(|i| { + if i < 8 { + F::from_u64((i as u64 * 
9) + 11) + } else { + F::zero() + } + })); + + let schoolbook = a * b; + + let ntt_a = N::from_ring(&a, &primes, &twiddles); + let ntt_b = N::from_ring(&b, &primes, &twiddles); + let ntt_prod = ntt_a.pointwise_mul(&ntt_b, &primes); + let ntt_result: R = ntt_prod.to_ring(&primes, &twiddles, &garner); + + assert_eq!(schoolbook, ntt_result); + } + + #[test] + fn q64_ntt_round_trip() { + type F = Fp64<{ Q64_MODULUS }>; + type R = CyclotomicRing; + type N = CyclotomicCrtNtt; + + let primes = q64_primes(); + let twiddles: [NttTwiddles; Q64_NUM_PRIMES] = + std::array::from_fn(|k| NttTwiddles::compute(primes[k])); + let garner = q64_garner(); + + let coeffs: [F; 64] = + std::array::from_fn(|i| F::from_u64(((i as u64 * 19) + 3) % Q64_MODULUS)); + let ring = R::from_coefficients(coeffs); + let ntt = N::from_ring(&ring, &primes, &twiddles); + let round_trip: R = ntt.to_ring(&primes, &twiddles, &garner); + + assert_eq!(ring, round_trip); + } + + #[test] + fn crt_ntt_mul_matches_schoolbook_q64() { + type F = Fp64<{ Q64_MODULUS }>; + type R = CyclotomicRing; + type N = CyclotomicCrtNtt; + + let primes = q64_primes(); + let twiddles: [NttTwiddles; Q64_NUM_PRIMES] = + std::array::from_fn(|k| NttTwiddles::compute(primes[k])); + let garner = q64_garner(); + + let a = R::from_coefficients(std::array::from_fn(|i| { + F::from_u64(((i as u64 * 5) + 9) % Q64_MODULUS) + })); + let b = R::from_coefficients(std::array::from_fn(|i| { + F::from_u64(((i as u64 * 17) + 13) % Q64_MODULUS) + })); + + let schoolbook = a * b; + + let ntt_a = N::from_ring(&a, &primes, &twiddles); + let ntt_b = N::from_ring(&b, &primes, &twiddles); + let ntt_prod = ntt_a.pointwise_mul(&ntt_b, &primes); + let ntt_result: R = ntt_prod.to_ring(&primes, &twiddles, &garner); + + assert_eq!(schoolbook, ntt_result); + } + + #[test] + fn field_sampling_respects_modulus() { + type F = Fp32<251>; + let mut rng = StdRng::seed_from_u64(42); + for _ in 0..1024 { + let x = F::sample(&mut rng); + assert!(x.to_canonical_u32() < 
251); + } + } + + #[test] + fn pow2_offset_registry_is_consistent() { + fn assert_is_pseudo_mersenne() {} + assert_is_pseudo_mersenne::(); + + for Pow2OffsetPrimeSpec { + bits, + offset, + modulus, + .. + } in POW2_OFFSET_PRIMES + { + assert!((offset as u128) <= POW2_OFFSET_MAX); + assert_eq!(POW2_OFFSET_TABLE[bits as usize], offset as i16); + assert_eq!( + Some(modulus), + pseudo_mersenne_modulus(bits, offset as u128), + "2^k-offset modulus mismatch for k={bits}, offset={offset}" + ); + assert_eq!(modulus % 8, 5); + } + + let x = Pow2Offset128Field::from_u64(1234567); + let inv = x.inv().unwrap(); + assert_eq!(x * inv, Pow2Offset128Field::one()); + } + + #[test] + fn cyclotomic_sigma_is_ring_automorphism() { + type F = Fp32<251>; + type R = CyclotomicRing; + let a = R::from_coefficients(std::array::from_fn(|i| F::from_u64((3 * i + 1) as u64))); + let b = R::from_coefficients(std::array::from_fn(|i| F::from_u64((5 * i + 2) as u64))); + + let k1 = 3usize; + let k2 = 5usize; + let two_d = 16usize; + + assert_eq!(a.sigma(1), a); + assert_eq!(a.sigma_m1().sigma_m1(), a); + assert_eq!(a.sigma(k1).sigma(k2), a.sigma((k1 * k2) % two_d)); + assert_eq!((a * b).sigma(k1), a.sigma(k1) * b.sigma(k1)); + } + + #[test] + fn cyclotomic_balanced_pow2_decompose_recompose_round_trip() { + type F = Fp64<{ Q32_MODULUS }>; + type R = CyclotomicRing; + + let ring = R::from_coefficients(std::array::from_fn(|i| { + F::from_u64(((i as u64 * 73) + 17) % Q32_MODULUS) + })); + + // Q32 balanced base-16: 9 levels absorb the carry-out near q/2. 
+ let digits = ring.balanced_decompose_pow2(9, 4); + let round_trip = R::gadget_recompose_pow2(&digits, 4); + assert_eq!(round_trip, ring); + } + + #[test] + fn sparse_pm1_challenge_has_expected_weight() { + type F = Fp32<251>; + type R = CyclotomicRing; + + let mut rng = StdRng::seed_from_u64(123); + let challenge = R::sample_sparse_pm1(&mut rng, 11); + assert_eq!(challenge.hamming_weight(), 11); + + for c in challenge.coefficients() { + let x = c.to_canonical_u32(); + if x != 0 { + assert!(x == 1 || x == 250, "nonzero coefficient must be +/-1"); + } + } + } + + #[test] + fn negacyclic_shift_equals_mul_by_monomial() { + type F = Fp32<251>; + type R = CyclotomicRing; + + let a = R::from_coefficients(std::array::from_fn(|i| F::from_u64((3 * i + 1) as u64))); + + for k in 0..8 { + let mut monomial_coeffs = [F::zero(); 8]; + monomial_coeffs[k] = F::one(); + let monomial = R::from_coefficients(monomial_coeffs); + assert_eq!( + a.negacyclic_shift(k), + a * monomial, + "negacyclic_shift({k}) != mul by X^{k}" + ); + } + + assert_eq!(a.negacyclic_shift(0), a); + assert_eq!( + a.negacyclic_shift(8), + a, + "shift by D should be identity mod D" + ); + } + + #[test] + fn negacyclic_shift_degree_64() { + type F = Fp64<4294967197>; + type R = CyclotomicRing; + + let a = R::from_coefficients(std::array::from_fn(|i| F::from_u64((7 * i + 3) as u64))); + let x = R::x(); + let mut x_pow = R::one(); + for k in 0..64 { + assert_eq!( + a.negacyclic_shift(k), + a * x_pow, + "negacyclic_shift({k}) mismatch at D=64" + ); + x_pow *= x; + } + } + + #[test] + fn mul_by_monomial_sum_matches_ring_mul() { + type F = Fp32<251>; + type R = CyclotomicRing; + + let a = R::from_coefficients(std::array::from_fn(|i| F::from_u64((5 * i + 2) as u64))); + + // Sum of X^1 + X^3 + X^5 + let positions = [1, 3, 5]; + let mut sparse = [F::zero(); 8]; + for &p in &positions { + sparse[p] = F::one(); + } + let sparse_ring = R::from_coefficients(sparse); + + assert_eq!( + a.mul_by_monomial_sum(&positions), + a * 
sparse_ring, + "mul_by_monomial_sum should equal ring mul by sparse element" + ); + } + + #[test] + fn mul_by_monomial_sum_single_position_equals_shift() { + type F = Fp32<251>; + type R = CyclotomicRing; + + let a = R::from_coefficients(std::array::from_fn(|i| F::from_u64((i + 1) as u64))); + for k in 0..8 { + assert_eq!( + a.mul_by_monomial_sum(&[k]), + a.negacyclic_shift(k), + "single-position monomial_sum should equal negacyclic_shift" + ); + } + } + + #[test] + fn mul_by_monomial_sum_empty_is_zero() { + type F = Fp32<251>; + type R = CyclotomicRing; + + let a = R::from_coefficients(std::array::from_fn(|i| F::from_u64((i + 1) as u64))); + assert_eq!(a.mul_by_monomial_sum(&[]), R::zero()); + } + + #[test] + fn mul_by_sparse_matches_schoolbook() { + use hachi_pcs::algebra::SparseChallenge; + + type F = Fp64<4294967197>; + type R = CyclotomicRing; + + let a = R::from_coefficients(std::array::from_fn(|i| F::from_u64((3 * i + 7) as u64))); + + let challenge = SparseChallenge { + positions: vec![2, 17, 41], + coeffs: vec![1, -1, 1], + }; + let dense: R = challenge.to_dense().unwrap(); + + let via_sparse = a.mul_by_sparse(&challenge); + let via_schoolbook = a * dense; + + assert_eq!( + via_sparse, via_schoolbook, + "mul_by_sparse must equal schoolbook multiplication" + ); + } + + #[test] + fn mul_by_sparse_with_all_negative_coeffs() { + use hachi_pcs::algebra::SparseChallenge; + + type F = Fp64<4294967197>; + type R = CyclotomicRing; + + let a = R::from_coefficients(std::array::from_fn(|i| F::from_u64((i + 1) as u64))); + + let challenge = SparseChallenge { + positions: vec![0, 5, 63], + coeffs: vec![-1, -1, -1], + }; + let dense: R = challenge.to_dense().unwrap(); + + assert_eq!(a.mul_by_sparse(&challenge), a * dense); + } + + #[test] + fn is_zero_detects_zero_and_nonzero() { + type F = Fp32<251>; + type R = CyclotomicRing; + + assert!(R::zero().is_zero()); + assert!(!R::one().is_zero()); + + let a = R::from_coefficients(std::array::from_fn(|i| F::from_u64(i as 
u64))); + assert!(!a.is_zero()); + } + + #[test] + fn kron_scalars_matches_kron_row_constant_rings() { + type F = Fp64<4294967197>; + type R = CyclotomicRing; + + let scalars_a: Vec = (0..4).map(|i| F::from_u64(i * 3 + 1)).collect(); + let scalars_b: Vec = (0..3).map(|i| F::from_u64(i * 7 + 2)).collect(); + + let rings_a: Vec = scalars_a + .iter() + .map(|&s| { + let mut c = [F::zero(); 16]; + c[0] = s; + R::from_coefficients(c) + }) + .collect(); + let rings_b: Vec = scalars_b + .iter() + .map(|&s| { + let mut c = [F::zero(); 16]; + c[0] = s; + R::from_coefficients(c) + }) + .collect(); + + let via_ring: Vec = rings_a + .iter() + .flat_map(|l| rings_b.iter().map(move |r| *l * *r)) + .collect(); + + let via_scalar: Vec = scalars_a + .iter() + .flat_map(|&l| { + scalars_b.iter().map(move |&r| { + let mut c = [F::zero(); 16]; + c[0] = l * r; + R::from_coefficients(c) + }) + }) + .collect(); + + assert_eq!(via_ring, via_scalar); + } +} diff --git a/tests/commitment_contract.rs b/tests/commitment_contract.rs new file mode 100644 index 00000000..bf9b5ff9 --- /dev/null +++ b/tests/commitment_contract.rs @@ -0,0 +1,206 @@ +#![allow(missing_docs)] + +use hachi_pcs::algebra::CyclotomicRing; +use hachi_pcs::algebra::Fp64; +use hachi_pcs::algebra::SparseChallenge; +use hachi_pcs::protocol::commitment::utils::crt_ntt::NttSlotCache; +use hachi_pcs::protocol::commitment::utils::flat_matrix::FlatMatrix; +use hachi_pcs::protocol::commitment::{DummyProof, HachiCommitment}; +use hachi_pcs::protocol::hachi_poly_ops::HachiPolyOps; +use hachi_pcs::protocol::transcript::labels; +use hachi_pcs::protocol::{ + AppendToTranscript, BasisMode, Blake2bTranscript, CommitmentScheme, HachiCommitmentLayout, + Transcript, +}; +use hachi_pcs::{CanonicalField, FieldCore, FromSmallInt, HachiError}; + +type F = Fp64<4294967197>; + +/// Trivial polynomial wrapper that implements `HachiPolyOps`. 
+#[derive(Debug, Clone)] +struct DummyPoly { + coeffs: Vec, +} + +impl DummyPoly { + fn evaluate(&self, point: &[F]) -> F { + assert_eq!(point.len(), self.num_vars()); + let mut acc = self.coeffs[0]; + for (i, r_i) in point.iter().enumerate() { + acc += self.coeffs[i + 1] * *r_i; + } + acc + } + + fn num_vars(&self) -> usize { + self.coeffs.len().saturating_sub(1) + } +} + +impl HachiPolyOps for DummyPoly { + type CommitCache = NttSlotCache<1>; + + fn num_ring_elems(&self) -> usize { + self.coeffs.len() + } + + fn evaluate_ring(&self, scalars: &[F]) -> CyclotomicRing { + let mut acc = F::zero(); + for (c, &s) in self.coeffs.iter().zip(scalars.iter()) { + acc += *c * s; + } + CyclotomicRing::from_coefficients([acc]) + } + + fn fold_blocks(&self, _scalars: &[F], _block_len: usize) -> Vec> { + vec![] + } + + fn decompose_fold( + &self, + _challenges: &[SparseChallenge], + _block_len: usize, + _num_digits: usize, + _log_basis: u32, + ) -> Vec> { + vec![] + } + + fn commit_inner( + &self, + _a_matrix: &FlatMatrix, + _ntt_a: &NttSlotCache<1>, + _block_len: usize, + _num_digits_commit: usize, + _num_digits_open: usize, + _log_basis: u32, + ) -> Result>, HachiError> { + Ok(vec![]) + } +} + +#[derive(Clone)] +struct DummySetup { + _max_num_vars: usize, +} + +#[derive(Clone)] +struct DummyScheme; + +impl CommitmentScheme for DummyScheme { + type ProverSetup = DummySetup; + type VerifierSetup = DummySetup; + type Commitment = HachiCommitment; + type Proof = DummyProof; + type CommitHint = HachiCommitment; + + fn setup_prover(max_num_vars: usize) -> Self::ProverSetup { + DummySetup { + _max_num_vars: max_num_vars, + } + } + + fn setup_verifier(setup: &Self::ProverSetup) -> Self::VerifierSetup { + setup.clone() + } + + fn commit>( + _poly: &P, + _setup: &Self::ProverSetup, + _layout: &HachiCommitmentLayout, + ) -> Result<(Self::Commitment, Self::CommitHint), HachiError> { + let c = HachiCommitment(0); + Ok((c, c)) + } + + fn prove, P: HachiPolyOps>( + _setup: 
&Self::ProverSetup, + _poly: &P, + _opening_point: &[F], + _hint: Self::CommitHint, + transcript: &mut T, + commitment: &Self::Commitment, + _basis: BasisMode, + _layout: &HachiCommitmentLayout, + ) -> Result { + commitment.append_to_transcript(labels::ABSORB_COMMITMENT, transcript); + let q = transcript.challenge_scalar(labels::CHALLENGE_LINEAR_RELATION); + Ok(DummyProof(q.to_canonical_u128())) + } + + fn verify>( + proof: &Self::Proof, + _setup: &Self::VerifierSetup, + transcript: &mut T, + _opening_point: &[F], + _opening: &F, + commitment: &Self::Commitment, + _basis: BasisMode, + _layout: &HachiCommitmentLayout, + ) -> Result<(), HachiError> { + commitment.append_to_transcript(labels::ABSORB_COMMITMENT, transcript); + let q = transcript.challenge_scalar(labels::CHALLENGE_LINEAR_RELATION); + if proof.0 == q.to_canonical_u128() { + Ok(()) + } else { + Err(HachiError::InvalidProof) + } + } + + fn protocol_name() -> &'static [u8] { + b"HachiDummy" + } +} + +#[test] +fn commitment_scheme_round_trip() { + let poly = DummyPoly { + coeffs: vec![F::from_u64(3), F::from_u64(5), F::from_u64(7)], + }; + let opening_point = [F::from_u64(11), F::from_u64(13)]; + + let psetup = DummyScheme::setup_prover(poly.num_vars()); + let vsetup = DummyScheme::setup_verifier(&psetup); + + let layout = HachiCommitmentLayout { + m_vars: 0, + r_vars: 0, + block_len: 1, + num_blocks: 1, + num_digits_commit: 1, + num_digits_open: 1, + num_digits_fold: 1, + inner_width: 1, + outer_width: 1, + d_matrix_width: 1, + log_basis: 1, + }; + let (commitment, hint) = DummyScheme::commit(&poly, &psetup, &layout).unwrap(); + let opening = poly.evaluate(&opening_point); + + let mut prover_t = Blake2bTranscript::::new(labels::DOMAIN_HACHI_PROTOCOL); + let proof = DummyScheme::prove( + &psetup, + &poly, + &opening_point, + hint, + &mut prover_t, + &commitment, + BasisMode::Lagrange, + &layout, + ) + .unwrap(); + + let mut verifier_t = Blake2bTranscript::::new(labels::DOMAIN_HACHI_PROTOCOL); + 
DummyScheme::verify( + &proof, + &vsetup, + &mut verifier_t, + &opening_point, + &opening, + &commitment, + BasisMode::Lagrange, + &layout, + ) + .unwrap(); +} diff --git a/tests/hachi_sumcheck.rs b/tests/hachi_sumcheck.rs new file mode 100644 index 00000000..ce33b91b --- /dev/null +++ b/tests/hachi_sumcheck.rs @@ -0,0 +1,225 @@ +#![allow(missing_docs)] + +use hachi_pcs::algebra::ring::CyclotomicRing; +use hachi_pcs::algebra::Fp64; +use hachi_pcs::protocol::sumcheck::eq_poly::EqPolynomial; +use hachi_pcs::protocol::sumcheck::norm_sumcheck::{NormSumcheckProver, NormSumcheckVerifier}; +use hachi_pcs::protocol::sumcheck::relation_sumcheck::{ + RelationSumcheckProver, RelationSumcheckVerifier, +}; +use hachi_pcs::protocol::sumcheck::{multilinear_eval, range_check_eval}; +use hachi_pcs::protocol::transcript::labels; +use hachi_pcs::protocol::{prove_sumcheck, verify_sumcheck, Blake2bTranscript, Transcript}; +use hachi_pcs::{FieldCore, FieldSampling, FromSmallInt}; +use rand::rngs::StdRng; +use rand::SeedableRng; +use std::time::Instant; + +type F = Fp64<4294967197>; + +fn run_f0_e2e(num_u: usize, num_l: usize, b: usize) { + let num_vars = num_u + num_l; + let n = 1usize << num_vars; + let mut rng = StdRng::seed_from_u64(0xF0); + + let w_evals: Vec = (0..n).map(|i| F::from_u64((i % b) as u64)).collect(); + let tau0: Vec = (0..num_vars).map(|_| F::sample(&mut rng)).collect(); + + let t0 = Instant::now(); + let mut prover = NormSumcheckProver::new(&tau0, w_evals.clone(), b); + let mut pt = Blake2bTranscript::::new(labels::DOMAIN_HACHI_PROTOCOL); + + let (proof, prover_challenges, final_claim) = + prove_sumcheck::(&mut prover, &mut pt, |tr| { + tr.challenge_scalar(labels::CHALLENGE_SUMCHECK_ROUND) + }) + .unwrap(); + let prove_time = t0.elapsed(); + + // Sanity: prover's final claim matches oracle evaluation. 
+ let oracle = EqPolynomial::mle(&tau0, &prover_challenges) + * range_check_eval(multilinear_eval(&w_evals, &prover_challenges).unwrap(), b); + assert_eq!(final_claim, oracle, "prover final claim != oracle eval"); + + let t1 = Instant::now(); + let verifier = NormSumcheckVerifier::new(tau0, w_evals, b); + let mut vt = Blake2bTranscript::::new(labels::DOMAIN_HACHI_PROTOCOL); + + let verifier_challenges = verify_sumcheck::(&proof, &verifier, &mut vt, |tr| { + tr.challenge_scalar(labels::CHALLENGE_SUMCHECK_ROUND) + }) + .unwrap(); + let verify_time = t1.elapsed(); + + assert_eq!(prover_challenges, verifier_challenges); + + eprintln!( + "[F0 e2e] num_u={num_u} num_l={num_l} b={b} n=2^{num_vars}={n} \ + prove={prove_time:.2?} verify={verify_time:.2?} \ + rounds={} degree={}", + proof.round_polys.len(), + 2 * b, + ); +} + +#[test] +fn f0_sumcheck_e2e_small() { + run_f0_e2e(3, 2, 2); +} + +#[test] +fn f0_sumcheck_e2e() { + run_f0_e2e(4, 3, 2); +} + +#[test] +fn f0_sumcheck_e2e_larger_b() { + run_f0_e2e(3, 3, 3); +} + +fn run_f_alpha_e2e(num_u: usize, num_i: usize) { + let num_l = D.trailing_zeros() as usize; + let num_vars = num_u + num_l; + let n = 1usize << num_vars; + let mut rng = StdRng::seed_from_u64(0xFA); + + let w_evals: Vec = (0..n).map(|_| F::sample(&mut rng)).collect(); + let alpha_evals_y: Vec = (0..D).map(|_| F::sample(&mut rng)).collect(); + let m_alpha_evals: Vec = (0..(1usize << (num_i + num_u))) + .map(|_| F::sample(&mut rng)) + .collect(); + let tau1: Vec = (0..num_i).map(|_| F::sample(&mut rng)).collect(); + + // Compute m(x) = Σ_i ẽq(τ₁, i) · M̃_α(i, x) + let eq_tau1 = EqPolynomial::evals(&tau1); + let num_x = 1usize << num_u; + let m_evals_x: Vec = (0..num_x) + .map(|x_idx| { + (0..(1usize << num_i)) + .map(|i_idx| eq_tau1[i_idx] * m_alpha_evals[i_idx * num_x + x_idx]) + .fold(F::zero(), |a, v| a + v) + }) + .collect(); + + // Compute y_a[i] = Σ_x M̃_α(i,x) · w_α(x), where w_α(x) = Σ_y w(x,y) · α̃(y) + let num_y = D; + let num_rows = 1usize << num_i; 
+ let w_alpha: Vec = (0..num_x) + .map(|x| { + (0..num_y) + .map(|y| w_evals[x + y * num_x] * alpha_evals_y[y]) + .fold(F::zero(), |a, v| a + v) + }) + .collect(); + let y_a: Vec = (0..num_rows) + .map(|i| { + (0..num_x) + .map(|x| m_alpha_evals[i * num_x + x] * w_alpha[x]) + .fold(F::zero(), |a, v| a + v) + }) + .collect(); + + // Embed y_a values as constant ring elements for the verifier. + let v_rings: Vec> = y_a + .iter() + .map(|&val| { + let mut coeffs = [F::zero(); D]; + coeffs[0] = val; + CyclotomicRing::from_coefficients(coeffs) + }) + .collect(); + let u_rings: Vec> = vec![]; + let u_eval_ring = CyclotomicRing::::zero(); + let ring_alpha = F::one(); + + let t0 = Instant::now(); + let mut prover = + RelationSumcheckProver::new(w_evals.clone(), &alpha_evals_y, &m_evals_x, num_u, num_l); + let mut pt = Blake2bTranscript::::new(labels::DOMAIN_HACHI_PROTOCOL); + + let (proof, prover_challenges, final_claim) = + prove_sumcheck::(&mut prover, &mut pt, |tr| { + tr.challenge_scalar(labels::CHALLENGE_SUMCHECK_ROUND) + }) + .unwrap(); + let prove_time = t0.elapsed(); + + // Sanity: prover's final claim matches oracle evaluation. 
+ let (x_ch, y_ch) = prover_challenges.split_at(num_u); + let oracle = multilinear_eval(&w_evals, &prover_challenges).unwrap() + * multilinear_eval(&alpha_evals_y, y_ch).unwrap() + * multilinear_eval(&m_evals_x, x_ch).unwrap(); + assert_eq!(final_claim, oracle, "prover final claim != oracle eval"); + + let t1 = Instant::now(); + let verifier = RelationSumcheckVerifier::::new( + w_evals, + alpha_evals_y, + m_evals_x, + tau1, + v_rings, + u_rings, + u_eval_ring, + ring_alpha, + num_u, + num_l, + ); + let mut vt = Blake2bTranscript::::new(labels::DOMAIN_HACHI_PROTOCOL); + + let verifier_challenges = verify_sumcheck::(&proof, &verifier, &mut vt, |tr| { + tr.challenge_scalar(labels::CHALLENGE_SUMCHECK_ROUND) + }) + .unwrap(); + let verify_time = t1.elapsed(); + + assert_eq!(prover_challenges, verifier_challenges); + + eprintln!( + "[Fα e2e] num_u={num_u} num_l={num_l} num_i={num_i} n=2^{num_vars}={n} \ + prove={prove_time:.2?} verify={verify_time:.2?} \ + rounds={} degree=2", + proof.round_polys.len(), + ); +} + +#[test] +fn f_alpha_sumcheck_e2e_small() { + run_f_alpha_e2e::<4>(3, 2); +} + +#[test] +fn f_alpha_sumcheck_e2e() { + run_f_alpha_e2e::<8>(4, 3); +} + +#[test] +fn f_alpha_sumcheck_e2e_asymmetric() { + run_f_alpha_e2e::<4>(5, 4); +} + +#[test] +fn from_evals_matches_direct_polynomial() { + use hachi_pcs::protocol::UniPoly; + + // Verify that interpolation at integer points reproduces the polynomial. 
+ let mut rng = StdRng::seed_from_u64(0xEE); + + for degree in 0..6usize { + let coeffs: Vec = (0..=degree).map(|_| F::sample(&mut rng)).collect(); + let poly = UniPoly::from_coeffs(coeffs); + + let evals: Vec = (0..=degree) + .map(|t| poly.evaluate(&F::from_u64(t as u64))) + .collect(); + let reconstructed = UniPoly::from_evals(&evals); + + for x_u64 in [0u64, 1, 2, 3, 7, 13] { + let x = F::from_u64(x_u64); + assert_eq!( + poly.evaluate(&x), + reconstructed.evaluate(&x), + "degree {degree}, x={x_u64}" + ); + } + } +} diff --git a/tests/label_schedule.rs b/tests/label_schedule.rs new file mode 100644 index 00000000..c943b83b --- /dev/null +++ b/tests/label_schedule.rs @@ -0,0 +1,63 @@ +#![allow(missing_docs)] + +use hachi_pcs::algebra::Fp64; +use hachi_pcs::protocol::transcript::labels; +use hachi_pcs::protocol::{Blake2bTranscript, Transcript}; + +type F = Fp64<4294967197>; + +#[test] +fn label_namespace_does_not_include_dory_literals() { + let banned = ["vmv_", "beta", "alpha", "gamma", "final_e", "dory"]; + for label in labels::all_labels() { + let text = std::str::from_utf8(label).expect("labels must be valid utf8 literals"); + for needle in &banned { + assert!( + !text.contains(needle), + "label `{text}` must not contain banned token `{needle}`" + ); + } + } +} + +fn run_hachi_schedule>(transcript: &mut T) -> (F, F, F) { + transcript.append_bytes(labels::ABSORB_COMMITMENT, b"C"); + transcript.append_bytes(labels::ABSORB_EVALUATION_CLAIMS, b"O"); + let c_linear_relation = transcript.challenge_scalar(labels::CHALLENGE_LINEAR_RELATION); + + transcript.append_bytes(labels::ABSORB_RING_SWITCH_MESSAGE, b"RS"); + let c_ring_switch = transcript.challenge_scalar(labels::CHALLENGE_RING_SWITCH); + + transcript.append_bytes(labels::ABSORB_SUMCHECK_ROUND, b"SC1"); + let c_sumcheck = transcript.challenge_scalar(labels::CHALLENGE_SUMCHECK_ROUND); + transcript.append_bytes(labels::ABSORB_STOP_CONDITION, b"STOP"); + let _ = 
transcript.challenge_scalar(labels::CHALLENGE_STOP_CONDITION); + + (c_linear_relation, c_ring_switch, c_sumcheck) +} + +#[test] +fn schedule_is_replayable_with_hachi_labels() { + let mut prover = Blake2bTranscript::::new(labels::DOMAIN_HACHI_PROTOCOL); + let mut verifier = Blake2bTranscript::::new(labels::DOMAIN_HACHI_PROTOCOL); + assert_eq!( + run_hachi_schedule(&mut prover), + run_hachi_schedule(&mut verifier) + ); +} + +#[test] +fn schedule_detects_reordered_round_messages() { + let mut t1 = Blake2bTranscript::::new(labels::DOMAIN_HACHI_PROTOCOL); + let mut t2 = Blake2bTranscript::::new(labels::DOMAIN_HACHI_PROTOCOL); + + t1.append_bytes(labels::ABSORB_COMMITMENT, b"C"); + t1.append_bytes(labels::ABSORB_EVALUATION_CLAIMS, b"O"); + let a = t1.challenge_scalar(labels::CHALLENGE_LINEAR_RELATION); + + t2.append_bytes(labels::ABSORB_EVALUATION_CLAIMS, b"O"); + t2.append_bytes(labels::ABSORB_COMMITMENT, b"C"); + let b = t2.challenge_scalar(labels::CHALLENGE_LINEAR_RELATION); + + assert_ne!(a, b); +} diff --git a/tests/onehot_commitment.rs b/tests/onehot_commitment.rs new file mode 100644 index 00000000..eb80586f --- /dev/null +++ b/tests/onehot_commitment.rs @@ -0,0 +1,161 @@ +#![allow(missing_docs)] + +use hachi_pcs::algebra::CyclotomicRing; +use hachi_pcs::protocol::commitment::{HachiCommitmentCore, RingCommitmentScheme}; +use hachi_pcs::test_utils::*; +use hachi_pcs::{FieldCore, FromSmallInt}; + +type Core = HachiCommitmentCore; + +fn psetup() -> >::ProverSetup { + >::setup(16) + .unwrap() + .0 +} + +/// Compare the optimized one-hot path against the default dense path. +/// +/// The default implementation materializes the full vector and calls +/// `commit_coeffs`. The optimized impl uses sparse inner Ajtai. +/// Both must produce identical (commitment, s_all, t_hat_all). +fn assert_onehot_matches_dense(onehot_k: usize, indices: &[usize]) { + let opt_indices: Vec> = indices.iter().map(|&i| Some(i)).collect(); + let setup = psetup(); + + // Optimized sparse path. 
+ let w_sparse = >::commit_onehot( + onehot_k, + &opt_indices, + &setup, + ) + .unwrap(); + + // Reference: materialize the full one-hot vector, pack into ring elements, + // and commit via the dense path. + let total_field = indices.len() * onehot_k; + let total_ring = total_field / D; + let mut field_elems = vec![F::zero(); total_field]; + for (c, &idx) in indices.iter().enumerate() { + field_elems[c * onehot_k + idx] = F::from_u64(1); + } + let ring_coeffs: Vec> = (0..total_ring) + .map(|r| { + let coeffs: [F; D] = std::array::from_fn(|i| field_elems[r * D + i]); + CyclotomicRing::from_coefficients(coeffs) + }) + .collect(); + let w_dense = + >::commit_coeffs(&ring_coeffs, &setup) + .unwrap(); + + assert_eq!( + w_sparse.commitment, w_dense.commitment, + "commitments must match" + ); + assert_eq!( + w_sparse.t_hat, w_dense.t_hat, + "t_hat_all (decomposed inner output) must match" + ); +} + +#[test] +fn onehot_k_gt_d_basic() { + // K=128, D=64 => K/D=2, T=2 => T*K=256 => 4 ring elements + assert_onehot_matches_dense(128, &[0, 64]); +} + +#[test] +fn onehot_k_gt_d_various_positions() { + assert_onehot_matches_dense(128, &[127, 0]); + assert_onehot_matches_dense(128, &[63, 65]); + assert_onehot_matches_dense(128, &[32, 96]); +} + +#[test] +fn onehot_k_much_gt_d() { + // K=256, D=64 => K/D=4, T=1 => T*K=256 => 4 ring elements + assert_onehot_matches_dense(256, &[0]); + assert_onehot_matches_dense(256, &[63]); + assert_onehot_matches_dense(256, &[64]); + assert_onehot_matches_dense(256, &[255]); + assert_onehot_matches_dense(256, &[100]); +} + +#[test] +fn onehot_k_eq_d_basic() { + // K=64=D, T=4 => 4 ring elements, each is a monomial X^{idx}. 
+ assert_onehot_matches_dense(64, &[0, 0, 0, 0]); +} + +#[test] +fn onehot_k_eq_d_varied() { + assert_onehot_matches_dense(64, &[0, 31, 32, 63]); + assert_onehot_matches_dense(64, &[1, 2, 3, 4]); + assert_onehot_matches_dense(64, &[63, 63, 63, 63]); +} + +#[test] +fn onehot_k_lt_d_basic() { + // K=16, D=64 => D/K=4, T=16 => T*K=256 => 4 ring elements. + // Each ring element spans 4 chunks, so has 4 nonzero coefficients. + let indices: Vec = (0..16).map(|i| i % 16).collect(); + assert_onehot_matches_dense(16, &indices); +} + +#[test] +fn onehot_k_lt_d_all_zeros() { + let indices = vec![0; 16]; + assert_onehot_matches_dense(16, &indices); +} + +#[test] +fn onehot_k_lt_d_all_max() { + let indices = vec![15; 16]; + assert_onehot_matches_dense(16, &indices); +} + +#[test] +fn onehot_k_lt_d_mixed() { + let indices = vec![0, 15, 7, 3, 12, 1, 8, 14, 5, 10, 2, 9, 6, 11, 4, 13]; + assert_onehot_matches_dense(16, &indices); +} + +#[test] +fn onehot_k_lt_d_ratio_2() { + // K=32, D=64 => D/K=2, T=8 => T*K=256 => 4 ring elements. 
+ let indices = vec![0, 31, 16, 8, 24, 4, 12, 20]; + assert_onehot_matches_dense(32, &indices); +} + +#[test] +fn onehot_rejects_non_divisible_k_and_d() { + let setup = psetup(); + let result = >::commit_onehot( + 17, + &[Some(0usize); 4], + &setup, + ); + assert!(result.is_err()); +} + +#[test] +fn onehot_rejects_out_of_range_index() { + let setup = psetup(); + let result = >::commit_onehot( + 64, + &[Some(0usize), Some(64), Some(0), Some(0)], + &setup, + ); + assert!(result.is_err()); +} + +#[test] +fn onehot_rejects_wrong_total_size() { + let setup = psetup(); + let result = >::commit_onehot( + 64, + &[Some(0usize), Some(0), Some(0)], + &setup, + ); + assert!(result.is_err()); +} diff --git a/tests/primality.rs b/tests/primality.rs new file mode 100644 index 00000000..0d7a3a8f --- /dev/null +++ b/tests/primality.rs @@ -0,0 +1,124 @@ +#![allow(missing_docs)] + +use hachi_pcs::algebra::{pseudo_mersenne_modulus, Pow2OffsetPrimeSpec, POW2_OFFSET_PRIMES}; + +// Strong probable-prime test using multiple fixed bases. +// This is not a formal primality certificate, but is sufficient as a +// practical regression guard for the current Pow2Offset profiles. 
+fn is_probable_prime_miller_rabin(n: u128) -> bool { + if n < 2 { + return false; + } + if n % 2 == 0 { + return n == 2; + } + + const SMALL_PRIMES: [u128; 11] = [3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37]; + for p in SMALL_PRIMES { + if n == p { + return true; + } + if n % p == 0 { + return false; + } + } + + let (d, s) = decompose_pow2(n - 1); + const BASES: [u128; 24] = [ + 2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53, 59, 61, 67, 71, 73, 79, 83, 89, + ]; + + 'outer: for a in BASES { + if a >= n { + continue; + } + let mut x = pow_mod(a, d, n); + if x == 1 || x == n - 1 { + continue; + } + for _ in 1..s { + x = mul_mod(x, x, n); + if x == n - 1 { + continue 'outer; + } + } + return false; + } + + true +} + +fn decompose_pow2(mut d: u128) -> (u128, u32) { + let mut s = 0u32; + while d % 2 == 0 { + d >>= 1; + s += 1; + } + (d, s) +} + +fn pow_mod(mut base: u128, mut exp: u128, modulus: u128) -> u128 { + let mut result = 1u128; + base %= modulus; + while exp > 0 { + if (exp & 1) == 1 { + result = mul_mod(result, base, modulus); + } + base = mul_mod(base, base, modulus); + exp >>= 1; + } + result +} + +fn mul_mod(mut a: u128, mut b: u128, modulus: u128) -> u128 { + let mut result = 0u128; + a %= modulus; + b %= modulus; + while b > 0 { + if (b & 1) == 1 { + result = add_mod(result, a, modulus); + } + a = add_mod(a, a, modulus); + b >>= 1; + } + result +} + +fn add_mod(a: u128, b: u128, modulus: u128) -> u128 { + if a >= modulus - b { + a - (modulus - b) + } else { + a + b + } +} + +#[test] +fn pow2_offset_profiles_are_probable_primes() { + for Pow2OffsetPrimeSpec { + bits, + offset, + modulus, + } in POW2_OFFSET_PRIMES + { + assert_eq!( + Some(modulus), + pseudo_mersenne_modulus(bits, offset as u128), + "profile formula mismatch for bits={bits}, offset={offset}" + ); + assert!( + is_probable_prime_miller_rabin(modulus), + "Miller-Rabin rejected bits={bits}, offset={offset}, q={modulus}" + ); + } +} + +#[test] +fn miller_rabin_rejects_known_composites() 
{ + let composites: [u128; 9] = [4, 9, 15, 21, 341, 561, 645, 1105, 1729]; + for n in composites { + assert!( + !is_probable_prime_miller_rabin(n), + "composite unexpectedly accepted: {n}" + ); + } +} diff --git a/tests/ring_commitment_core.rs b/tests/ring_commitment_core.rs new file mode 100644 index 00000000..e7da6bca --- /dev/null +++ b/tests/ring_commitment_core.rs @@ -0,0 +1,159 @@ +#![allow(missing_docs)] + +use hachi_pcs::algebra::CyclotomicRing; +use hachi_pcs::protocol::commitment::{ + utils::linear::decompose_block, CommitmentConfig, DecompositionParams, HachiCommitmentCore, + HachiCommitmentLayout, RingCommitmentScheme, SmallTestCommitmentConfig, +}; +use hachi_pcs::test_utils::*; +use hachi_pcs::{FromSmallInt, HachiError}; +use std::array::from_fn; + +#[derive(Clone)] +struct BadDegreeConfig; + +impl CommitmentConfig for BadDegreeConfig { + const D: usize = 32; + const N_A: usize = 8; + const N_B: usize = 4; + const N_D: usize = 4; + const CHALLENGE_WEIGHT: usize = 3; + + fn decomposition() -> DecompositionParams { + DecompositionParams { + log_basis: 3, + log_commit_bound: 32, + log_open_bound: None, + } + } + + fn commitment_layout(_max_num_vars: usize) -> Result { + HachiCommitmentLayout::new::(4, 2, &Self::decomposition()) + } +} + +#[test] +fn setup_shape_is_consistent() { + let (p1, v1) = + >::setup(16).unwrap(); + let (p2, v2) = + >::setup(16).unwrap(); + + assert_eq!(p1.expanded.seed.max_num_vars, 16); + assert_eq!(v1.expanded.seed.max_num_vars, 16); + assert_eq!(p2.expanded.seed.max_num_vars, 16); + assert_eq!(v2.expanded.seed.max_num_vars, 16); + assert_eq!(p1.expanded.A.num_rows(), TinyConfig::N_A); + assert!(p1.expanded.A.num_cols_at::() >= BLOCK_LEN * num_digits_commit()); + assert_eq!(p1.expanded.B.num_rows(), TinyConfig::N_B); + assert!(p1.expanded.B.num_cols_at::() >= TinyConfig::N_A * num_digits_open() * NUM_BLOCKS); +} + +#[test] +fn commit_is_deterministic_and_shape_consistent() { + let (psetup, _) = + >::setup(16).unwrap(); + let 
blocks = sample_blocks(); + + let w1 = >::commit_ring_blocks( + &blocks, &psetup, + ) + .unwrap(); + let w2 = >::commit_ring_blocks( + &blocks, &psetup, + ) + .unwrap(); + + assert_eq!(w1.commitment, w2.commitment); + assert_eq!(w1.t_hat, w2.t_hat); + + let num_blocks = NUM_BLOCKS; + assert_eq!(w1.commitment.u.len(), TinyConfig::N_B); + assert_eq!(w1.t_hat.len(), num_blocks); + let depth = num_digits_commit(); + assert!(w1.t_hat.iter().all(|t| t.len() == TinyConfig::N_A * depth)); +} + +#[test] +fn commit_ring_coeffs_matches_block_commitment() { + let (psetup, _) = + >::setup(16).unwrap(); + let blocks = sample_blocks(); + + let wb = >::commit_ring_blocks( + &blocks, &psetup, + ) + .unwrap(); + + // Sequential layout: block 0 elements, then block 1 elements, etc. + let f_coeffs: Vec<_> = blocks + .iter() + .flat_map(|block| block.iter().copied()) + .collect(); + + let wc = >::commit_coeffs( + &f_coeffs, &psetup, + ) + .unwrap(); + + assert_eq!(wb.commitment, wc.commitment); + assert_eq!(wb.t_hat, wc.t_hat); +} + +#[test] +fn opening_satisfies_inner_and_outer_equations() { + let (psetup, _) = + >::setup(16).unwrap(); + let blocks = sample_blocks(); + let w = >::commit_ring_blocks( + &blocks, &psetup, + ) + .unwrap(); + + let depth = num_digits_commit(); + let log_basis = log_basis(); + for (i, block) in blocks.iter().enumerate() { + let s_i = decompose_block(block, depth, log_basis); + let lhs = mat_vec_mul(&psetup.expanded.A, &s_i); + let rhs: Vec> = (0..TinyConfig::N_A) + .map(|j| { + let start = j * depth; + let end = start + depth; + CyclotomicRing::gadget_recompose_pow2_i8(&w.t_hat[i][start..end], log_basis) + }) + .collect(); + assert_eq!(lhs, rhs); + } + + let t_hat_flat_ring: Vec> = w + .t_hat + .iter() + .flat_map(|x| x.iter()) + .map(|plane| { + let coeffs: [F; D] = from_fn(|k| F::from_i64(plane[k] as i64)); + CyclotomicRing::from_coefficients(coeffs) + }) + .collect(); + let outer = mat_vec_mul(&psetup.expanded.B, &t_hat_flat_ring); + assert_eq!(outer, 
w.commitment.u); +} + +#[test] +fn small_test_config_has_expected_shape() { + assert_eq!(SmallTestCommitmentConfig::D, 16); + let layout = SmallTestCommitmentConfig::commitment_layout(8).unwrap(); + assert_eq!(layout.block_len, 16); + assert_eq!(layout.num_blocks, 4); + let depth = layout.num_digits_commit; + assert!(depth > 0); +} + +#[test] +fn setup_rejects_mismatched_degree() { + let err = >::setup(16) + .unwrap_err(); + match err { + HachiError::InvalidSetup(msg) => assert!(msg.contains("mismatches")), + other => panic!("unexpected error: {other:?}"), + } +} diff --git a/tests/sparse_challenge.rs b/tests/sparse_challenge.rs new file mode 100644 index 00000000..fd64ac3c --- /dev/null +++ b/tests/sparse_challenge.rs @@ -0,0 +1,98 @@ +#![allow(missing_docs)] + +use hachi_pcs::algebra::fields::LiftBase; +use hachi_pcs::algebra::ring::{CyclotomicRing, SparseChallenge, SparseChallengeConfig}; +use hachi_pcs::algebra::Fp64; +use hachi_pcs::protocol::challenges::sparse::sparse_challenge_from_transcript; +use hachi_pcs::protocol::transcript::labels::DOMAIN_HACHI_PROTOCOL; +use hachi_pcs::protocol::transcript::{Blake2bTranscript, Transcript}; +use hachi_pcs::{FieldCore, FromSmallInt}; + +type F = Fp64<4294967197>; + +const D: usize = 16; + +fn dense_eval>(alpha: E, x: &CyclotomicRing) -> E { + let mut acc = E::zero(); + let mut pow = E::one(); + for c in x.coefficients().iter().copied() { + acc += E::lift_base(c) * pow; + pow = pow * alpha; + } + acc +} + +#[test] +fn sparse_challenge_validate_and_to_dense() { + let cfg = SparseChallengeConfig { + weight: 3, + nonzero_coeffs: vec![-1, 1], + }; + cfg.validate::().unwrap(); + + let s = SparseChallenge { + positions: vec![0, 7, 12], + coeffs: vec![1, -1, 1], + }; + s.validate::().unwrap(); + assert_eq!(s.hamming_weight(), 3); + assert_eq!(s.l1_norm(), 3); + + let dense = s.to_dense::().unwrap(); + assert_eq!(dense.hamming_weight(), 3); + assert_eq!(dense.coefficients()[0], F::one()); + assert_eq!(dense.coefficients()[7], 
-F::one()); + assert_eq!(dense.coefficients()[12], F::one()); +} + +#[test] +fn sparse_eval_at_alpha_matches_dense_eval() { + let alpha = F::from_u64(5); + let alpha_pows = { + let mut out = Vec::with_capacity(D); + let mut acc = F::one(); + for _ in 0..D { + out.push(acc); + acc *= alpha; + } + out + }; + + let s = SparseChallenge { + positions: vec![1, 3, 9], + coeffs: vec![2, -1, 1], + }; + let dense = s.to_dense::().unwrap(); + + let sparse_eval = s.eval_at_alpha::(&alpha_pows).unwrap(); + let dense_eval = dense_eval::(alpha, &dense); + assert_eq!(sparse_eval, dense_eval); +} + +#[test] +fn sparse_challenge_sampling_is_deterministic_and_exact_weight() { + let cfg = SparseChallengeConfig { + weight: 8, + nonzero_coeffs: vec![-1, 1], + }; + + let mut t1 = Blake2bTranscript::::new(DOMAIN_HACHI_PROTOCOL); + let mut t2 = Blake2bTranscript::::new(DOMAIN_HACHI_PROTOCOL); + + // Make transcript state non-empty to avoid degenerate behavior. + t1.append_field(b"seed", &F::from_u64(123)); + t2.append_field(b"seed", &F::from_u64(123)); + + let c1 = sparse_challenge_from_transcript::(&mut t1, b"c", 0, &cfg).unwrap(); + let c2 = sparse_challenge_from_transcript::(&mut t2, b"c", 0, &cfg).unwrap(); + assert_eq!(c1, c2); + c1.validate::().unwrap(); + assert_eq!(c1.hamming_weight(), cfg.weight); + assert_eq!(c1.l1_norm(), cfg.weight as u64); + + // Different instance_idx should change the sample. 
+ let mut t3 = Blake2bTranscript::::new(DOMAIN_HACHI_PROTOCOL); + t3.append_field(b"seed", &F::from_u64(123)); + let c3 = sparse_challenge_from_transcript::(&mut t3, b"c", 1, &cfg).unwrap(); + assert_ne!(c1, c3); +} diff --git a/tests/sumcheck_core.rs b/tests/sumcheck_core.rs new file mode 100644 index 00000000..b75f0586 --- /dev/null +++ b/tests/sumcheck_core.rs @@ -0,0 +1,290 @@ +#![allow(missing_docs)] + +use std::time::Instant; + +use hachi_pcs::algebra::poly::multilinear_eval; +use hachi_pcs::algebra::Fp64; +use hachi_pcs::error::HachiError; +use hachi_pcs::protocol::transcript::labels; +use hachi_pcs::protocol::{ + prove_sumcheck, verify_sumcheck, Blake2bTranscript, CompressedUniPoly, SumcheckInstanceProver, + SumcheckInstanceVerifier, SumcheckProof, Transcript, UniPoly, +}; +use hachi_pcs::{FieldCore, FieldSampling, FromSmallInt}; +use rand::rngs::StdRng; +use rand::RngCore; +use rand::SeedableRng; + +type F = Fp64<4294967197>; + +#[test] +fn compressed_unipoly_round_trip_and_eval() { + let mut rng = StdRng::seed_from_u64(123); + + for degree in 0..8usize { + let coeffs: Vec = (0..=degree).map(|_| F::sample(&mut rng)).collect(); + let poly = UniPoly::from_coeffs(coeffs); + + // Hint is g(0) + g(1). + let hint = poly.evaluate(&F::zero()) + poly.evaluate(&F::one()); + + let compressed = poly.compress(); + let decompressed = compressed.decompress(&hint); + + // Decompression should be functionally equivalent (it may materialize + // a trailing zero linear term for constant polynomials). 
+ for x_u64 in [0u64, 1, 2, 3, 17] { + let x = F::from_u64(x_u64); + let direct = poly.evaluate(&x); + let decompressed_direct = decompressed.evaluate(&x); + let via_hint = compressed.eval_from_hint(&hint, &x); + assert_eq!(direct, decompressed_direct); + assert_eq!(direct, via_hint); + } + } +} + +#[test] +fn sumcheck_proof_verifier_driver_is_transcript_deterministic() { + // This test checks that the verifier driver absorbs messages and samples challenges + // consistently, and that the returned (final_claim, r_vec) matches a manual replay. + let mut rng = StdRng::seed_from_u64(999); + + let num_rounds = 5usize; + let degree_bound = 7usize; + + // Build random per-round univariates (degree <= degree_bound), compress them. + let round_polys: Vec> = (0..num_rounds) + .map(|_| { + let deg = (rng.next_u32() as usize) % (degree_bound + 1); + let coeffs: Vec = (0..=deg).map(|_| F::sample(&mut rng)).collect(); + UniPoly::from_coeffs(coeffs).compress() + }) + .collect(); + + let proof = SumcheckProof { round_polys }; + let claim0 = F::sample(&mut rng); + + // Verifier run. + let mut t1 = Blake2bTranscript::::new(labels::DOMAIN_HACHI_PROTOCOL); + let (final_claim_1, r_1) = proof + .verify::(claim0, num_rounds, degree_bound, &mut t1, |tr| { + tr.challenge_scalar(labels::CHALLENGE_SUMCHECK_ROUND) + }) + .unwrap(); + + // Manual replay with a fresh transcript (must match). 
+ let mut t2 = Blake2bTranscript::::new(labels::DOMAIN_HACHI_PROTOCOL); + let mut claim = claim0; + let mut r_manual = Vec::with_capacity(num_rounds); + for poly in &proof.round_polys { + t2.append_serde(labels::ABSORB_SUMCHECK_ROUND, poly); + let r_i = t2.challenge_scalar(labels::CHALLENGE_SUMCHECK_ROUND); + r_manual.push(r_i); + claim = poly.eval_from_hint(&claim, &r_i); + } + + assert_eq!(r_1, r_manual); + assert_eq!(final_claim_1, claim); +} + +struct DenseSumcheckProver { + evals: Vec, + num_vars: usize, +} + +impl SumcheckInstanceProver for DenseSumcheckProver { + fn num_rounds(&self) -> usize { + self.num_vars + } + + fn degree_bound(&self) -> usize { + 1 + } + + fn input_claim(&self) -> E { + self.evals.iter().copied().fold(E::zero(), |a, b| a + b) + } + + fn compute_round_univariate(&mut self, _round: usize, _previous_claim: E) -> UniPoly { + let half = self.evals.len() / 2; + let mut eval_0 = E::zero(); + let mut eval_1 = E::zero(); + for i in 0..half { + eval_0 += self.evals[2 * i]; + eval_1 += self.evals[2 * i + 1]; + } + UniPoly::from_coeffs(vec![eval_0, eval_1 - eval_0]) + } + + fn ingest_challenge(&mut self, _round: usize, r: E) { + let half = self.evals.len() / 2; + let mut new_evals = Vec::with_capacity(half); + for i in 0..half { + new_evals.push(self.evals[2 * i] + r * (self.evals[2 * i + 1] - self.evals[2 * i])); + } + self.evals = new_evals; + } +} + +struct DenseSumcheckVerifier { + evals: Vec, + num_vars: usize, + claim: E, +} + +impl SumcheckInstanceVerifier for DenseSumcheckVerifier { + fn num_rounds(&self) -> usize { + self.num_vars + } + + fn degree_bound(&self) -> usize { + 1 + } + + fn input_claim(&self) -> E { + self.claim + } + + fn expected_output_claim(&self, challenges: &[E]) -> Result { + multilinear_eval(&self.evals, challenges) + } +} + +#[test] +fn prove_and_verify_single_sumcheck() { + let num_vars = 4; + let n = 1 << num_vars; + + let evals: Vec = (1..=n).map(|i| F::from_u64(i as u64)).collect(); + let claim: F = 
evals.iter().copied().fold(F::zero(), |a, b| a + b); + + let mut prover = DenseSumcheckProver { + evals: evals.clone(), + num_vars, + }; + + let mut prover_transcript = Blake2bTranscript::::new(labels::DOMAIN_HACHI_PROTOCOL); + + let (proof, prover_challenges, _final_claim) = + prove_sumcheck::(&mut prover, &mut prover_transcript, |tr| { + tr.challenge_scalar(labels::CHALLENGE_SUMCHECK_ROUND) + }) + .unwrap(); + + let verifier = DenseSumcheckVerifier { + evals, + num_vars, + claim, + }; + + let mut verifier_transcript = Blake2bTranscript::::new(labels::DOMAIN_HACHI_PROTOCOL); + + let verifier_challenges = + verify_sumcheck::(&proof, &verifier, &mut verifier_transcript, |tr| { + tr.challenge_scalar(labels::CHALLENGE_SUMCHECK_ROUND) + }) + .unwrap(); + + assert_eq!(prover_challenges, verifier_challenges); +} + +#[test] +fn verify_rejects_wrong_claim() { + let num_vars = 3; + let n = 1 << num_vars; + + let evals: Vec = (1..=n).map(|i| F::from_u64(i as u64)).collect(); + let correct_claim: F = evals.iter().copied().fold(F::zero(), |a, b| a + b); + let wrong_claim = correct_claim + F::one(); + + // Prove with correct claim. + let mut prover = DenseSumcheckProver { + evals: evals.clone(), + num_vars, + }; + let mut pt = Blake2bTranscript::::new(labels::DOMAIN_HACHI_PROTOCOL); + + let (proof, _, _) = prove_sumcheck::(&mut prover, &mut pt, |tr| { + tr.challenge_scalar(labels::CHALLENGE_SUMCHECK_ROUND) + }) + .unwrap(); + + // Verify with *wrong* claim — should fail. + let verifier = DenseSumcheckVerifier { + evals, + num_vars, + claim: wrong_claim, + }; + let mut vt = Blake2bTranscript::::new(labels::DOMAIN_HACHI_PROTOCOL); + + let result = verify_sumcheck::(&proof, &verifier, &mut vt, |tr| { + tr.challenge_scalar(labels::CHALLENGE_SUMCHECK_ROUND) + }); + + assert!(result.is_err()); +} + +/// End-to-end sumcheck over 2^20 random field elements. 
+/// +/// The prover holds a multilinear polynomial f with 2^20 evaluations and +/// proves that Σ_{b ∈ {0,1}^20} f(b) = claimed_sum. The verifier checks the +/// proof using only the proof transcript and the oracle evaluation f(r). +#[test] +fn e2e_sumcheck_2_pow_20() { + let num_vars = 20; + let n: usize = 1 << num_vars; // 1,048,576 + + let mut rng = StdRng::seed_from_u64(42); + let evals: Vec = (0..n).map(|_| F::sample(&mut rng)).collect(); + let claim: F = evals.iter().copied().fold(F::zero(), |a, b| a + b); + + let t0 = Instant::now(); + + let mut prover = DenseSumcheckProver { + evals: evals.clone(), + num_vars, + }; + let mut prover_transcript = Blake2bTranscript::::new(labels::DOMAIN_HACHI_PROTOCOL); + + let (proof, prover_challenges, final_claim) = + prove_sumcheck::(&mut prover, &mut prover_transcript, |tr| { + tr.challenge_scalar(labels::CHALLENGE_SUMCHECK_ROUND) + }) + .unwrap(); + + let prove_time = t0.elapsed(); + + // Proof is just 20 compressed univariate polynomials (degree 1 each). + assert_eq!(proof.round_polys.len(), num_vars); + + // Sanity: final claim must equal f evaluated at the challenge point. 
+ let oracle_eval = multilinear_eval(&evals, &prover_challenges).unwrap(); + assert_eq!(final_claim, oracle_eval); + + let t1 = Instant::now(); + + let verifier = DenseSumcheckVerifier { + evals, + num_vars, + claim, + }; + let mut verifier_transcript = Blake2bTranscript::::new(labels::DOMAIN_HACHI_PROTOCOL); + + let verifier_challenges = + verify_sumcheck::(&proof, &verifier, &mut verifier_transcript, |tr| { + tr.challenge_scalar(labels::CHALLENGE_SUMCHECK_ROUND) + }) + .unwrap(); + + let verify_time = t1.elapsed(); + + assert_eq!(prover_challenges, verifier_challenges); + + eprintln!( + "[e2e_sumcheck_2_pow_20] n=2^{num_vars}={n} \ + prove={prove_time:.2?} verify={verify_time:.2?} \ + rounds={} degree=1", + proof.round_polys.len() + ); +} diff --git a/tests/sumcheck_prover_driver.rs b/tests/sumcheck_prover_driver.rs new file mode 100644 index 00000000..3abcaf53 --- /dev/null +++ b/tests/sumcheck_prover_driver.rs @@ -0,0 +1,97 @@ +#![allow(missing_docs)] + +use hachi_pcs::algebra::Fp64; +use hachi_pcs::protocol::transcript::labels; +use hachi_pcs::protocol::{ + prove_sumcheck, Blake2bTranscript, SumcheckInstanceProver, Transcript, UniPoly, +}; +use hachi_pcs::{FieldCore, FieldSampling}; +use rand::rngs::StdRng; +use rand::SeedableRng; + +type F = Fp64<4294967197>; + +/// A tiny prover-side sumcheck instance for a multilinear function in evaluation-table form. +/// +/// Variable order convention: the current round binds the least-significant index bit first, +/// i.e. pairs are `(i<<1)|0` and `(i<<1)|1` (matches the common LSB-first sumcheck table fold). 
+struct DenseTableSumcheck { + table: Vec, +} + +impl DenseTableSumcheck { + fn new(table: Vec) -> Self { + assert!(table.len().is_power_of_two()); + Self { table } + } +} + +impl SumcheckInstanceProver for DenseTableSumcheck { + fn num_rounds(&self) -> usize { + self.table.len().trailing_zeros() as usize + } + + fn degree_bound(&self) -> usize { + 1 + } + + fn input_claim(&self) -> F { + self.table.iter().copied().fold(F::zero(), |a, b| a + b) + } + + fn compute_round_univariate(&mut self, _round: usize, _previous_claim: F) -> UniPoly { + let half = self.table.len() / 2; + let mut s0 = F::zero(); + let mut s1 = F::zero(); + for i in 0..half { + s0 += self.table[i << 1]; + s1 += self.table[(i << 1) | 1]; + } + UniPoly::from_coeffs(vec![s0, s1 - s0]) + } + + fn ingest_challenge(&mut self, _round: usize, r_round: F) { + let half = self.table.len() / 2; + let mut next = Vec::with_capacity(half); + let one_minus = F::one() - r_round; + for i in 0..half { + let v0 = self.table[i << 1]; + let v1 = self.table[(i << 1) | 1]; + next.push(one_minus * v0 + r_round * v1); + } + self.table = next; + } +} + +#[test] +fn prover_driver_produces_proof_that_verifier_replays() { + let mut rng = StdRng::seed_from_u64(2026); + let num_rounds = 8usize; + let n = 1usize << num_rounds; + + let table: Vec = (0..n).map(|_| F::sample(&mut rng)).collect(); + let mut prover_inst = DenseTableSumcheck::new(table.clone()); + let mut prover_t = Blake2bTranscript::::new(labels::DOMAIN_HACHI_PROTOCOL); + let (proof, r_vec, final_claim) = + prove_sumcheck::(&mut prover_inst, &mut prover_t, |tr| { + tr.challenge_scalar(labels::CHALLENGE_SUMCHECK_ROUND) + }) + .unwrap(); + + // After folding all variables, the table should be a single value equal to f(r*). + assert_eq!(prover_inst.table.len(), 1); + assert_eq!(final_claim, prover_inst.table[0]); + + // Verifier replay must derive the same (final_claim, r_vec). 
+ let initial_claim = table.iter().copied().fold(F::zero(), |acc, x| acc + x); + let mut verifier_t = Blake2bTranscript::::new(labels::DOMAIN_HACHI_PROTOCOL); + verifier_t.append_serde(labels::ABSORB_SUMCHECK_CLAIM, &initial_claim); + let (final_claim_v, r_vec_v) = proof + .verify::(initial_claim, num_rounds, 1, &mut verifier_t, |tr| { + tr.challenge_scalar(labels::CHALLENGE_SUMCHECK_ROUND) + }) + .unwrap(); + + assert_eq!(r_vec_v, r_vec); + assert_eq!(final_claim_v, final_claim); +} diff --git a/tests/transcript.rs b/tests/transcript.rs new file mode 100644 index 00000000..50335bec --- /dev/null +++ b/tests/transcript.rs @@ -0,0 +1,136 @@ +#![allow(missing_docs)] + +use hachi_pcs::algebra::Fp64; +use hachi_pcs::protocol::transcript::labels; +use hachi_pcs::protocol::{Blake2bTranscript, KeccakTranscript, Transcript}; + +type F = Fp64<4294967197>; + +fn sample_schedule>(transcript: &mut T) -> F { + transcript.append_bytes(labels::ABSORB_COMMITMENT, b"commitment-a"); + transcript.append_bytes(labels::ABSORB_COMMITMENT, b"commitment-b"); + transcript.append_serde(labels::ABSORB_EVALUATION_CLAIMS, &42u64); + let rho = transcript.challenge_scalar(labels::CHALLENGE_LINEAR_RELATION); + + transcript.append_bytes(labels::ABSORB_RING_SWITCH_MESSAGE, b"ring-switch"); + let zeta = transcript.challenge_scalar(labels::CHALLENGE_RING_SWITCH); + + transcript.append_field(labels::ABSORB_SUMCHECK_ROUND, &(rho + zeta)); + let r = transcript.challenge_scalar(labels::CHALLENGE_SUMCHECK_ROUND); + + transcript.append_field(labels::ABSORB_STOP_CONDITION, &r); + transcript.challenge_scalar(labels::CHALLENGE_STOP_CONDITION) +} + +#[test] +fn transcript_is_deterministic_for_identical_schedule() { + let mut t1 = Blake2bTranscript::::new(labels::DOMAIN_HACHI_PROTOCOL); + let mut t2 = Blake2bTranscript::::new(labels::DOMAIN_HACHI_PROTOCOL); + + let c1 = sample_schedule(&mut t1); + let c2 = sample_schedule(&mut t2); + assert_eq!(c1, c2); +} + +#[test] +fn transcript_differs_when_label_changes() 
{ + let mut t1 = Blake2bTranscript::::new(labels::DOMAIN_HACHI_PROTOCOL); + let mut t2 = Blake2bTranscript::::new(labels::DOMAIN_HACHI_PROTOCOL); + + t1.append_bytes(labels::ABSORB_COMMITMENT, b"same-bytes"); + t2.append_bytes(labels::ABSORB_EVALUATION_CLAIMS, b"same-bytes"); + let c1 = t1.challenge_scalar(labels::CHALLENGE_SUMCHECK_ROUND); + let c2 = t2.challenge_scalar(labels::CHALLENGE_SUMCHECK_ROUND); + assert_ne!(c1, c2); +} + +#[test] +fn transcript_differs_when_absorb_order_changes() { + let mut t1 = Blake2bTranscript::::new(labels::DOMAIN_HACHI_PROTOCOL); + let mut t2 = Blake2bTranscript::::new(labels::DOMAIN_HACHI_PROTOCOL); + + t1.append_bytes(labels::ABSORB_COMMITMENT, b"A"); + t1.append_bytes(labels::ABSORB_EVALUATION_CLAIMS, b"B"); + + t2.append_bytes(labels::ABSORB_EVALUATION_CLAIMS, b"B"); + t2.append_bytes(labels::ABSORB_COMMITMENT, b"A"); + + let c1 = t1.challenge_scalar(labels::CHALLENGE_LINEAR_RELATION); + let c2 = t2.challenge_scalar(labels::CHALLENGE_LINEAR_RELATION); + assert_ne!(c1, c2); +} + +#[test] +fn transcript_reset_restores_domain_state() { + let mut t = Blake2bTranscript::::new(labels::DOMAIN_HACHI_PROTOCOL); + t.append_bytes(labels::ABSORB_COMMITMENT, b"before-reset"); + let _ = t.challenge_scalar(labels::CHALLENGE_STOP_CONDITION); + + t.reset(labels::DOMAIN_HACHI_PROTOCOL); + let after_reset = sample_schedule(&mut t); + + let mut fresh = Blake2bTranscript::::new(labels::DOMAIN_HACHI_PROTOCOL); + let fresh_challenge = sample_schedule(&mut fresh); + assert_eq!(after_reset, fresh_challenge); +} + +#[test] +fn keccak_transcript_is_deterministic_for_identical_schedule() { + let mut t1 = KeccakTranscript::::new(labels::DOMAIN_HACHI_PROTOCOL); + let mut t2 = KeccakTranscript::::new(labels::DOMAIN_HACHI_PROTOCOL); + + let c1 = sample_schedule(&mut t1); + let c2 = sample_schedule(&mut t2); + assert_eq!(c1, c2); +} + +#[test] +fn keccak_transcript_differs_when_label_changes() { + let mut t1 = 
KeccakTranscript::::new(labels::DOMAIN_HACHI_PROTOCOL); + let mut t2 = KeccakTranscript::::new(labels::DOMAIN_HACHI_PROTOCOL); + + t1.append_bytes(labels::ABSORB_COMMITMENT, b"same-bytes"); + t2.append_bytes(labels::ABSORB_EVALUATION_CLAIMS, b"same-bytes"); + let c1 = t1.challenge_scalar(labels::CHALLENGE_SUMCHECK_ROUND); + let c2 = t2.challenge_scalar(labels::CHALLENGE_SUMCHECK_ROUND); + assert_ne!(c1, c2); +} + +#[test] +fn keccak_transcript_differs_when_absorb_order_changes() { + let mut t1 = KeccakTranscript::::new(labels::DOMAIN_HACHI_PROTOCOL); + let mut t2 = KeccakTranscript::::new(labels::DOMAIN_HACHI_PROTOCOL); + + t1.append_bytes(labels::ABSORB_COMMITMENT, b"A"); + t1.append_bytes(labels::ABSORB_EVALUATION_CLAIMS, b"B"); + + t2.append_bytes(labels::ABSORB_EVALUATION_CLAIMS, b"B"); + t2.append_bytes(labels::ABSORB_COMMITMENT, b"A"); + + let c1 = t1.challenge_scalar(labels::CHALLENGE_LINEAR_RELATION); + let c2 = t2.challenge_scalar(labels::CHALLENGE_LINEAR_RELATION); + assert_ne!(c1, c2); +} + +#[test] +fn keccak_transcript_reset_restores_domain_state() { + let mut t = KeccakTranscript::::new(labels::DOMAIN_HACHI_PROTOCOL); + t.append_bytes(labels::ABSORB_COMMITMENT, b"before-reset"); + let _ = t.challenge_scalar(labels::CHALLENGE_STOP_CONDITION); + + t.reset(labels::DOMAIN_HACHI_PROTOCOL); + let after_reset = sample_schedule(&mut t); + + let mut fresh = KeccakTranscript::::new(labels::DOMAIN_HACHI_PROTOCOL); + let fresh_challenge = sample_schedule(&mut fresh); + assert_eq!(after_reset, fresh_challenge); +} + +#[test] +fn blake2b_and_keccak_diverge_on_same_schedule() { + let mut blake = Blake2bTranscript::::new(labels::DOMAIN_HACHI_PROTOCOL); + let mut keccak = KeccakTranscript::::new(labels::DOMAIN_HACHI_PROTOCOL); + let b = sample_schedule(&mut blake); + let k = sample_schedule(&mut keccak); + assert_ne!(b, k); +}