diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 7aa166de..8d38457b 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -6,6 +6,9 @@ on: pull_request: branches: ["**", main] +permissions: + contents: read + env: RUSTFLAGS: -D warnings CARGO_TERM_COLOR: always @@ -19,8 +22,8 @@ jobs: name: Format runs-on: ubuntu-latest steps: - - uses: actions/checkout@v4 - - uses: dtolnay/rust-toolchain@stable + - uses: actions/checkout@34e114876b0b11c390a56381ad16ebd13914f8d5 # v4 + - uses: actions-rust-lang/setup-rust-toolchain@a0b538fa0b742a6aa35d6e2c169b4bd06d225a98 # v1 with: components: rustfmt - name: Check formatting @@ -30,21 +33,21 @@ jobs: name: Clippy runs-on: ubuntu-latest steps: - - uses: actions/checkout@v4 - - uses: dtolnay/rust-toolchain@stable + - uses: actions/checkout@34e114876b0b11c390a56381ad16ebd13914f8d5 # v4 + - uses: actions-rust-lang/setup-rust-toolchain@a0b538fa0b742a6aa35d6e2c169b4bd06d225a98 # v1 with: components: clippy - name: Clippy (all features) - run: cargo clippy -q --message-format=short --all-features --all-targets -- -D warnings + run: cargo clippy --all --all-targets --all-features -- -D warnings - name: Clippy (no default features) - run: cargo clippy -q --message-format=short --no-default-features --lib -- -D warnings + run: cargo clippy --all --all-targets --no-default-features -- -D warnings doc: name: Documentation runs-on: ubuntu-latest steps: - - uses: actions/checkout@v4 - - uses: dtolnay/rust-toolchain@stable + - uses: actions/checkout@34e114876b0b11c390a56381ad16ebd13914f8d5 # v4 + - uses: actions-rust-lang/setup-rust-toolchain@a0b538fa0b742a6aa35d6e2c169b4bd06d225a98 # v1 - name: Build documentation run: cargo doc -q --no-deps --all-features env: @@ -54,9 +57,9 @@ jobs: name: Test runs-on: ubuntu-latest steps: - - uses: actions/checkout@v4 - - uses: dtolnay/rust-toolchain@stable + - uses: actions/checkout@34e114876b0b11c390a56381ad16ebd13914f8d5 # v4 + - uses: 
actions-rust-lang/setup-rust-toolchain@a0b538fa0b742a6aa35d6e2c169b4bd06d225a98 # v1 - name: Install cargo-nextest - uses: taiki-e/install-action@nextest + uses: taiki-e/install-action@f092c064826410a38929a5791d2c0225b94432fe # nextest - name: Run tests - run: cargo nextest run -q --all-features + run: cargo nextest run --all-features diff --git a/.github/workflows/onehot-bench.yml b/.github/workflows/onehot-bench.yml new file mode 100644 index 00000000..f3765f0f --- /dev/null +++ b/.github/workflows/onehot-bench.yml @@ -0,0 +1,277 @@ +name: Onehot 32 Variables Benchmark + +on: + push: + branches: [main] + pull_request: + branches: ["**", main] + workflow_dispatch: + +permissions: + actions: read + contents: read + issues: write + pull-requests: write + +env: + CARGO_TERM_COLOR: always + HACHI_BENCH_ARTIFACT_NAME: onehot-bench-32-variables-data + HACHI_BENCH_MODE: onehot + HACHI_BENCH_NUM_VARS: 32 + +concurrency: + group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }} + cancel-in-progress: true + +jobs: + bench: + name: Onehot 32 Variables (1-of-256) + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@34e114876b0b11c390a56381ad16ebd13914f8d5 # v4 + with: + fetch-depth: 0 + + - name: Initialize benchmark paths + run: | + echo "HACHI_BENCH_ARTIFACT_DIR=$RUNNER_TEMP/onehot-bench-artifact" >> "$GITHUB_ENV" + echo "HACHI_BENCH_MAIN_BASELINE_DIR=$RUNNER_TEMP/onehot-bench-main-baseline" >> "$GITHUB_ENV" + echo "HACHI_BENCH_PREVIOUS_RUN_DIR=$RUNNER_TEMP/onehot-bench-previous-run" >> "$GITHUB_ENV" + echo "HACHI_BENCH_REPORT=$RUNNER_TEMP/onehot-bench.md" >> "$GITHUB_ENV" + echo "HACHI_BENCH_COMMENT=$RUNNER_TEMP/onehot-bench-comment.md" >> "$GITHUB_ENV" + echo "HACHI_BENCH_SOURCE_SHA=$GITHUB_SHA" >> "$GITHUB_ENV" + echo "HACHI_BENCH_SOURCE_BRANCH=${{ github.head_ref || github.ref_name }}" >> "$GITHUB_ENV" + echo "HACHI_BENCH_BASE_REF=" >> "$GITHUB_ENV" + echo "HACHI_BENCH_MERGE_BASE_SHA=" >> "$GITHUB_ENV" + { + echo 
"HACHI_BENCH_SOURCE_SUBJECT<> "$GITHUB_ENV" + + - name: Determine PR benchmark merge base + if: github.event_name == 'pull_request' + run: | + echo "HACHI_BENCH_BASE_REF=${{ github.event.pull_request.base.ref }}" >> "$GITHUB_ENV" + merge_base_sha="$(git merge-base "${{ github.event.pull_request.base.sha }}" "${{ github.event.pull_request.head.sha }}" || true)" + if [ -n "$merge_base_sha" ]; then + echo "HACHI_BENCH_MERGE_BASE_SHA=$merge_base_sha" >> "$GITHUB_ENV" + fi + + - uses: actions-rust-lang/setup-rust-toolchain@a0b538fa0b742a6aa35d6e2c169b4bd06d225a98 # v1 + + - name: Build profile example + run: cargo build --release --quiet --example profile + + - name: Run onehot 32 variables benchmark (1-of-256) + run: | + python3 scripts/onehot_bench_report.py run \ + --binary ./target/release/examples/profile \ + --output-dir "$HACHI_BENCH_ARTIFACT_DIR" \ + --mode "$HACHI_BENCH_MODE" \ + --num-vars "$HACHI_BENCH_NUM_VARS" + + - name: Upload benchmark artifact + if: always() + continue-on-error: true + uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02 # v4.6.2 + with: + name: ${{ env.HACHI_BENCH_ARTIFACT_NAME }} + path: ${{ env.HACHI_BENCH_ARTIFACT_DIR }} + if-no-files-found: warn + retention-days: 30 + + - name: Determine comparison baseline artifacts + if: always() && github.event_name == 'pull_request' + continue-on-error: true + id: bench-baselines + uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8.0.0 + with: + script: | + const { owner, repo } = context.repo; + const artifactName = process.env.HACHI_BENCH_ARTIFACT_NAME; + const workflowName = process.env.GITHUB_WORKFLOW; + const currentRunId = Number(process.env.GITHUB_RUN_ID); + const currentSha = process.env.GITHUB_SHA; + const pullRequest = context.payload.pull_request; + const headRef = pullRequest.head.ref; + const baseRef = process.env.HACHI_BENCH_BASE_REF; + const mergeBaseSha = process.env.HACHI_BENCH_MERGE_BASE_SHA; + + function setBaselineOutput(prefix, run, 
label) { + core.setOutput(`${prefix}-run-id`, run ? String(run.id) : ''); + core.setOutput(`${prefix}-sha`, run ? run.head_sha : ''); + core.setOutput(`${prefix}-label`, run ? label : ''); + } + + async function firstRunWithArtifact(runs) { + for (const run of runs) { + if (run.id === currentRunId) { + continue; + } + if (run.name !== workflowName || run.conclusion !== 'success') { + continue; + } + const artifactsResponse = await github.rest.actions.listWorkflowRunArtifacts({ + owner, + repo, + run_id: run.id, + per_page: 100, + }); + const artifact = artifactsResponse.data.artifacts.find(candidate => + candidate.name === artifactName && !candidate.expired + ); + if (artifact) { + return run; + } + } + return null; + } + + const prRunsResponse = await github.rest.actions.listWorkflowRunsForRepo({ + owner, + repo, + event: 'pull_request', + branch: headRef, + status: 'completed', + per_page: 100, + }); + const previousPrCandidates = prRunsResponse.data.workflow_runs.filter(run => { + if (run.head_sha === currentSha) { + return false; + } + const prs = run.pull_requests || []; + return prs.length === 0 || prs.some(pr => pr.number === pullRequest.number); + }); + const previousPrRun = await firstRunWithArtifact(previousPrCandidates); + if (!previousPrRun) { + core.info('No previous PR baseline artifact found.'); + } + setBaselineOutput('previous', previousPrRun, 'the previous successful PR update'); + + const baseRunsResponse = await github.rest.actions.listWorkflowRunsForRepo({ + owner, + repo, + event: 'push', + branch: baseRef, + status: 'completed', + per_page: 100, + }); + const baseCandidates = baseRunsResponse.data.workflow_runs.filter(run => + run.name === workflowName && run.conclusion === 'success' + ); + const exactMergeBaseCandidates = mergeBaseSha + ? baseCandidates.filter(run => run.head_sha === mergeBaseSha) + : []; + const exactMergeBaseRun = + exactMergeBaseCandidates.length > 0 + ? 
await firstRunWithArtifact(exactMergeBaseCandidates) + : null; + const fallbackBaseCandidates = mergeBaseSha + ? baseCandidates.filter(run => run.head_sha !== mergeBaseSha) + : baseCandidates; + const mainRun = + exactMergeBaseRun ?? await firstRunWithArtifact(fallbackBaseCandidates); + if (!mainRun) { + core.info('No main baseline artifact found.'); + } + const mainLabel = + mainRun && Boolean(mergeBaseSha) && mainRun.head_sha === mergeBaseSha + ? `merge-base on \`${baseRef}\`` + : `the latest successful \`${baseRef}\` run`; + setBaselineOutput('main', mainRun, mainLabel); + + - name: Download main baseline artifact + if: always() && steps.bench-baselines.outputs.main-run-id != '' + continue-on-error: true + uses: actions/download-artifact@d3f86a106a0bac45b974a628896c90dbdf5c8093 # v4 + with: + name: ${{ env.HACHI_BENCH_ARTIFACT_NAME }} + path: ${{ env.HACHI_BENCH_MAIN_BASELINE_DIR }} + run-id: ${{ steps.bench-baselines.outputs.main-run-id }} + github-token: ${{ github.token }} + + - name: Download previous run artifact + if: always() && steps.bench-baselines.outputs.previous-run-id != '' + continue-on-error: true + uses: actions/download-artifact@d3f86a106a0bac45b974a628896c90dbdf5c8093 # v4 + with: + name: ${{ env.HACHI_BENCH_ARTIFACT_NAME }} + path: ${{ env.HACHI_BENCH_PREVIOUS_RUN_DIR }} + run-id: ${{ steps.bench-baselines.outputs.previous-run-id }} + github-token: ${{ github.token }} + + - name: Render benchmark report + if: always() + continue-on-error: true + env: + HACHI_BENCH_MAIN_BASELINE_LABEL: ${{ steps.bench-baselines.outputs.main-label }} + HACHI_BENCH_MAIN_BASELINE_SHA: ${{ steps.bench-baselines.outputs.main-sha }} + HACHI_BENCH_PREVIOUS_BASELINE_LABEL: ${{ steps.bench-baselines.outputs.previous-label }} + HACHI_BENCH_PREVIOUS_BASELINE_SHA: ${{ steps.bench-baselines.outputs.previous-sha }} + run: | + if [ ! -f "$HACHI_BENCH_ARTIFACT_DIR/summary.json" ]; then + echo "Benchmark summary not found; benchmark step likely failed." 
> "$HACHI_BENCH_REPORT" + else + python3 scripts/onehot_bench_report.py render \ + "$HACHI_BENCH_ARTIFACT_DIR/summary.json" \ + --main-baseline-dir "$HACHI_BENCH_MAIN_BASELINE_DIR" \ + --previous-baseline-dir "$HACHI_BENCH_PREVIOUS_RUN_DIR" > "$HACHI_BENCH_REPORT" + fi + cp "$HACHI_BENCH_REPORT" "$HACHI_BENCH_ARTIFACT_DIR/report.md" + cat "$HACHI_BENCH_REPORT" >> "$GITHUB_STEP_SUMMARY" + { + echo '' + echo + cat "$HACHI_BENCH_REPORT" + echo + echo '> Posted by Cursor assistant (model: GPT-5.4) on behalf of the user (Quang Dao) with approval.' + } > "$HACHI_BENCH_COMMENT" + + - name: Upsert benchmark PR comment + if: >- + always() && + github.event_name == 'pull_request' && + github.event.pull_request.head.repo.full_name == github.repository + continue-on-error: true + uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8.0.0 + with: + script: | + const fs = require('fs'); + const commentPath = process.env.HACHI_BENCH_COMMENT; + if (!fs.existsSync(commentPath)) { + core.info('No benchmark comment body found.'); + return; + } + + const marker = ''; + const body = fs.readFileSync(commentPath, 'utf8'); + const { owner, repo } = context.repo; + const issue_number = context.issue.number; + try { + const comments = await github.paginate( + github.rest.issues.listComments, + { owner, repo, issue_number, per_page: 100 } + ); + const existing = comments.find(comment => + comment.user?.login === 'github-actions[bot]' && comment.body?.includes(marker) + ); + + if (existing) { + await github.rest.issues.updateComment({ + owner, + repo, + comment_id: existing.id, + body, + }); + } else { + await github.rest.issues.createComment({ + owner, + repo, + issue_number, + body, + }); + } + } catch (error) { + core.warning(`Skipping benchmark PR comment upsert: ${error.message}`); + } diff --git a/.gitignore b/.gitignore index bd04f281..98e812bf 100644 --- a/.gitignore +++ b/.gitignore @@ -5,3 +5,7 @@ .urs PUBLISH_CHECKLIST.md + +profile_traces/ + +.cursor/ diff --git 
a/AGENTS.md b/AGENTS.md new file mode 100644 index 00000000..68f667fa --- /dev/null +++ b/AGENTS.md @@ -0,0 +1,35 @@ +# AGENTS.md + +**Compatibility notice (explicit): This repo makes NO backward-compatibility guarantees. Breaking changes are allowed and expected.** + +## Project Overview + +Hachi is a lattice-based polynomial commitment scheme (PCS) with transparent setup and post-quantum security. Built in Rust. Intended to replace Dory in Jolt. + +## Essential Commands + +```bash +cargo clippy --all --message-format=short -q -- -D warnings +cargo fmt -q +cargo test # no nextest yet +``` + +## Crate Structure + +Two workspace members: `hachi-pcs` (root) and `derive` (proc macros). + +- `src/primitives/` — Core traits: `FieldCore`, `Module`, `MultilinearLagrange`, `Transcript`, serialization +- `src/algebra/` — Concrete backends: prime fields, extension fields, cyclotomic rings, NTT, domains +- `src/protocol/` — Protocol layer: commitment, prover, verifier, opening (ring-switch), challenges, transcript +- `src/error.rs` — Error types + +## Key Abstractions + +- `CommitmentScheme` / `StreamingCommitmentScheme` — top-level PCS traits +- `FieldCore` + `PseudoMersenneField` + `Module` — arithmetic over lattice-friendly fields and rings +- `MultilinearLagrange` — multilinear polynomial in Lagrange basis +- `Transcript` — Fiat-Shamir + +## Feature Flags + +- `parallel` — Rayon parallelization diff --git a/CLAUDE.md b/CLAUDE.md new file mode 120000 index 00000000..47dc3e3d --- /dev/null +++ b/CLAUDE.md @@ -0,0 +1 @@ +AGENTS.md \ No newline at end of file diff --git a/CONSTANT_TIME_NOTES.md b/CONSTANT_TIME_NOTES.md new file mode 100644 index 00000000..dac24600 --- /dev/null +++ b/CONSTANT_TIME_NOTES.md @@ -0,0 +1,42 @@ +# Constant-Time Review Notes (Phase 0/1 Algebra) + +This note tracks timing-sensitive implementation decisions for the current +algebra and ring stack. 
+ +## Reviewed Components + +- `src/algebra/fields/fp32.rs` +- `src/algebra/fields/fp64.rs` +- `src/algebra/fields/fp128.rs` +- `src/algebra/ntt/prime.rs` +- `src/algebra/ntt/butterfly.rs` +- `src/algebra/ring/cyclotomic.rs` +- `src/algebra/ring/crt_ntt_repr.rs` + +## Current State + +- Branchless primitives are in place for: + - `Fp32/Fp64/Fp128` add/sub/neg raw helpers. + - `Fp128` multiplication reduction (`reduce_u256`) with branchless conditional subtract. + - `Fp32/Fp64` multiplication reduction (division-free fixed-iteration paths). + - NTT helper operations `csubp`, `caddp`, and `center`. +- NTT butterfly arithmetic runs in fixed loop structure independent of data. +- Ring multiplication (`CyclotomicRing`) is fixed-structure schoolbook over `D`. +- CRT reconstruction inner accumulation now uses fixed-trip, branchless + modular add/mul-by-small-factor helpers. +- Prime fields now expose `Invertible::inv_or_zero()` for secret-bearing + inversion use-cases without input-dependent branching on zero. +- CRT reconstruction final projection now uses a division-free fixed-iteration + reducer (`reduce_u128_divfree`) instead of `% q`. + +## Known Timing Risks / Follow-ups + +- `FieldCore::inv()` still returns `Option` and therefore branches on zero; + treat that API as public-value oriented. Use `Invertible::inv_or_zero()` + in secret-dependent paths. + +## Action Items Before Production-Critical Use + +1. Wire secret-bearing call sites to `Invertible::inv_or_zero()` as + protocol code matures. +2. Add dedicated CT review tests/checklists for any arithmetic subsystem changes. diff --git a/Cargo.lock b/Cargo.lock index a505ce5c..0a1f995c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2,6 +2,29 @@ # It is not intended for manual editing. 
version = 3 +[[package]] +name = "aes" +version = "0.8.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b169f7a6d4742236a0a00c541b845991d0ac43e546831af1249753ab4c3aa3a0" +dependencies = [ + "cfg-if", + "cipher", + "cpufeatures", +] + +[[package]] +name = "ahash" +version = "0.8.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5a15f179cd60c4584b8a8c596927aadc462e27f2ca70c04e0071964a73ba7a75" +dependencies = [ + "cfg-if", + "once_cell", + "version_check", + "zerocopy", +] + [[package]] name = "aho-corasick" version = "1.1.4" @@ -11,6 +34,33 @@ dependencies = [ "memchr", ] +[[package]] +name = "allocative" +version = "0.3.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8fac2ce611db8b8cee9b2aa886ca03c924e9da5e5295d0dbd0526e5d0b0710f7" +dependencies = [ + "allocative_derive", + "ctor", +] + +[[package]] +name = "allocative_derive" +version = "0.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fe233a377643e0fc1a56421d7c90acdec45c291b30345eb9f08e8d0ddce5a4ab" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.110", +] + +[[package]] +name = "allocator-api2" +version = "0.2.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "683d7910e743518b0e34f1186f92494becacb047c7b6bf616c96772180fef923" + [[package]] name = "anes" version = "0.1.6" @@ -23,12 +73,154 @@ version = "1.0.13" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5192cca8006f1fd4f7237516f40fa183bb07f8fbdfedaa0036de5ea9b0b45e78" +[[package]] +name = "ark-bn254" +version = "0.5.0" +source = "git+https://github.com/a16z/arkworks-algebra?branch=dev/twist-shout#76bb3a4518928f1ff7f15875f940d614bb9845e6" +dependencies = [ + "ark-ec", + "ark-ff", + "ark-serialize", + "ark-std", +] + +[[package]] +name = "ark-ec" +version = "0.5.0" +source = 
"git+https://github.com/a16z/arkworks-algebra?branch=dev/twist-shout#76bb3a4518928f1ff7f15875f940d614bb9845e6" +dependencies = [ + "ahash", + "ark-ff", + "ark-poly", + "ark-serialize", + "ark-std", + "educe", + "fnv", + "hashbrown", + "itertools 0.13.0", + "num-bigint", + "num-integer", + "num-traits", + "zeroize", +] + +[[package]] +name = "ark-ff" +version = "0.5.0" +source = "git+https://github.com/a16z/arkworks-algebra?branch=dev/twist-shout#76bb3a4518928f1ff7f15875f940d614bb9845e6" +dependencies = [ + "allocative", + "ark-ff-asm", + "ark-ff-macros", + "ark-serialize", + "ark-std", + "arrayvec", + "digest", + "educe", + "itertools 0.13.0", + "num-bigint", + "num-traits", + "paste", + "zeroize", +] + +[[package]] +name = "ark-ff-asm" +version = "0.5.0" +source = "git+https://github.com/a16z/arkworks-algebra?branch=dev/twist-shout#76bb3a4518928f1ff7f15875f940d614bb9845e6" +dependencies = [ + "quote", + "syn 2.0.110", +] + +[[package]] +name = "ark-ff-macros" +version = "0.5.0" +source = "git+https://github.com/a16z/arkworks-algebra?branch=dev/twist-shout#76bb3a4518928f1ff7f15875f940d614bb9845e6" +dependencies = [ + "num-bigint", + "num-traits", + "proc-macro2", + "quote", + "syn 2.0.110", +] + +[[package]] +name = "ark-poly" +version = "0.5.0" +source = "git+https://github.com/a16z/arkworks-algebra?branch=dev/twist-shout#76bb3a4518928f1ff7f15875f940d614bb9845e6" +dependencies = [ + "ahash", + "ark-ff", + "ark-serialize", + "ark-std", + "educe", + "fnv", + "hashbrown", +] + +[[package]] +name = "ark-serialize" +version = "0.5.0" +source = "git+https://github.com/a16z/arkworks-algebra?branch=dev/twist-shout#76bb3a4518928f1ff7f15875f940d614bb9845e6" +dependencies = [ + "ark-serialize-derive", + "ark-std", + "arrayvec", + "digest", + "num-bigint", +] + +[[package]] +name = "ark-serialize-derive" +version = "0.5.0" +source = "git+https://github.com/a16z/arkworks-algebra?branch=dev/twist-shout#76bb3a4518928f1ff7f15875f940d614bb9845e6" +dependencies = [ + "proc-macro2", 
+ "quote", + "syn 2.0.110", +] + +[[package]] +name = "ark-std" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "246a225cc6131e9ee4f24619af0f19d67761fff15d7ccc22e42b80846e69449a" +dependencies = [ + "num-traits", + "rand", +] + +[[package]] +name = "arrayvec" +version = "0.7.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7c02d123df017efcdfbd739ef81735b36c5ba83ec3c59c80a9d7ecc718f92e50" + [[package]] name = "autocfg" version = "1.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c08606f8c3cbf4ce6ec8e28fb0014a2c086708fe954eaa885384a6165172e7e8" +[[package]] +name = "blake2" +version = "0.10.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "46502ad458c9a52b69d4d4d32775c788b7a1b85e8bc9d482d92250fc0e3f8efe" +dependencies = [ + "digest", +] + +[[package]] +name = "block-buffer" +version = "0.10.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3078c7629b62d3f0439517fa394996acacc5cbc91c5a20d8c658e77abd503a71" +dependencies = [ + "generic-array", +] + [[package]] name = "bumpalo" version = "3.19.0" @@ -74,6 +266,16 @@ dependencies = [ "half", ] +[[package]] +name = "cipher" +version = "0.4.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "773f3b9af64447d2ce9850330c473515014aa235e6a783b02db81ff39e4a3dad" +dependencies = [ + "crypto-common", + "inout", +] + [[package]] name = "clap" version = "4.5.51" @@ -99,6 +301,15 @@ version = "0.7.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a1d728cc89cf3aee9ff92b05e62b19ee65a02b5702cff7d5a377e32c6ae29d8d" +[[package]] +name = "cpufeatures" +version = "0.2.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "59ed5838eebb26a2bb2e58f6d5b5316989ae9d08bab10e0e6d103e656d1b0280" +dependencies = [ + "libc", +] + [[package]] name = "criterion" version = "0.5.1" @@ -111,7 
+322,7 @@ dependencies = [ "clap", "criterion-plot", "is-terminal", - "itertools", + "itertools 0.10.5", "num-traits", "once_cell", "oorandom", @@ -132,7 +343,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6b50826342786a51a89e2da3a28f1c32b06e387201bc2d19791f622c673706b1" dependencies = [ "cast", - "itertools", + "itertools 0.10.5", ] [[package]] @@ -166,12 +377,100 @@ version = "0.2.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "460fbee9c2c2f33933d720630a6a0bac33ba7053db5344fac858d4b8952d77d5" +[[package]] +name = "crypto-common" +version = "0.1.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "78c8292055d1c1df0cce5d180393dc8cce0abec0a7102adb6c7b1eef6016d60a" +dependencies = [ + "generic-array", + "typenum", +] + +[[package]] +name = "ctor" +version = "0.1.26" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6d2301688392eb071b0bf1a37be05c469d3cc4dbbd95df672fe28ab021e6a096" +dependencies = [ + "quote", + "syn 1.0.109", +] + +[[package]] +name = "ctr" +version = "0.9.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0369ee1ad671834580515889b80f2ea915f23b8be8d0daa4bbaf2ac5c7590835" +dependencies = [ + "cipher", +] + +[[package]] +name = "digest" +version = "0.10.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9ed9a281f7bc9b7576e61468ba615a66a5c8cfdff42420a70aa82701a3b1e292" +dependencies = [ + "block-buffer", + "crypto-common", + "subtle", +] + +[[package]] +name = "educe" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1d7bc049e1bd8cdeb31b68bbd586a9464ecf9f3944af3958a7a9d0f8b9799417" +dependencies = [ + "enum-ordinalize", + "proc-macro2", + "quote", + "syn 2.0.110", +] + [[package]] name = "either" version = "1.15.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = 
"48c757948c5ede0e46177b7add2e67155f70e33c07fea8284df6576da70b3719" +[[package]] +name = "enum-ordinalize" +version = "4.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4a1091a7bb1f8f2c4b28f1fe2cef4980ca2d410a3d727d67ecc3178c9b0800f0" +dependencies = [ + "enum-ordinalize-derive", +] + +[[package]] +name = "enum-ordinalize-derive" +version = "4.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8ca9601fb2d62598ee17836250842873a413586e5d7ed88b356e38ddbb0ec631" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.110", +] + +[[package]] +name = "fnv" +version = "1.0.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1" + +[[package]] +name = "generic-array" +version = "0.14.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "85649ca51fd72272d7821adaf274ad91c288277713d9c18820d8499a7ff69e9a" +dependencies = [ + "typenum", + "version_check", +] + [[package]] name = "getrandom" version = "0.2.16" @@ -189,20 +488,29 @@ version = "0.1.0" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.110", ] [[package]] name = "hachi-pcs" version = "0.1.0" dependencies = [ + "aes", + "ark-bn254", + "ark-ff", + "blake2", "criterion", + "ctr", "hachi-derive", + "num-bigint", "rand", "rand_core", "rayon", + "sha3", "thiserror", "tracing", + "tracing-chrome", + "tracing-subscriber", ] [[package]] @@ -216,12 +524,30 @@ dependencies = [ "zerocopy", ] +[[package]] +name = "hashbrown" +version = "0.15.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9229cfe53dfd69f0609a49f65461bd93001ea1ef889cd5529dd176593f5338a1" +dependencies = [ + "allocator-api2", +] + [[package]] name = "hermit-abi" version = "0.5.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "fc0fef456e4baa96da950455cd02c081ca953b141298e41db3fc7e36b1da849c" +[[package]] 
+name = "inout" +version = "0.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "879f10e63c20629ecabbb64a8010319738c66a5cd0c29b02d63d272b03751d01" +dependencies = [ + "generic-array", +] + [[package]] name = "is-terminal" version = "0.4.17" @@ -242,6 +568,15 @@ dependencies = [ "either", ] +[[package]] +name = "itertools" +version = "0.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "413ee7dfc52ee1a4949ceeb7dbc8a33f2d6c088194d9f922fb8318faf1f01186" +dependencies = [ + "either", +] + [[package]] name = "itoa" version = "1.0.15" @@ -258,18 +593,76 @@ dependencies = [ "wasm-bindgen", ] +[[package]] +name = "keccak" +version = "0.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cb26cec98cce3a3d96cbb7bced3c4b16e3d13f27ec56dbd62cbc8f39cfb9d653" +dependencies = [ + "cpufeatures", +] + +[[package]] +name = "lazy_static" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bbd2bcb4c963f2ddae06a2efc7e9f3591312473c50c6685e1f298068316e66fe" + [[package]] name = "libc" version = "0.2.177" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2874a2af47a2325c2001a6e6fad9b16a53b802102b528163885171cf92b15976" +[[package]] +name = "log" +version = "0.4.29" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5e5032e24019045c762d3c0f28f5b6b8bbf38563a65908389bf7978758920897" + +[[package]] +name = "matchers" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d1525a2a28c7f4fa0fc98bb91ae755d1e2d1505079e05539e35bc876b5d65ae9" +dependencies = [ + "regex-automata", +] + [[package]] name = "memchr" version = "2.7.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f52b00d39961fc5b2736ea853c9cc86238e165017a493d1d5c8eac6bdc4cc273" +[[package]] +name = "nu-ansi-term" +version = "0.50.3" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "7957b9740744892f114936ab4a57b3f487491bbeafaf8083688b16841a4240e5" +dependencies = [ + "windows-sys", +] + +[[package]] +name = "num-bigint" +version = "0.4.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a5e44f723f1133c9deac646763579fdb3ac745e418f2a7af9cd0c431da1f20b9" +dependencies = [ + "num-integer", + "num-traits", +] + +[[package]] +name = "num-integer" +version = "0.1.46" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7969661fd2958a5cb096e56c8e1ad0444ac2bbcd0061bd28660485a44879858f" +dependencies = [ + "num-traits", +] + [[package]] name = "num-traits" version = "0.2.19" @@ -291,6 +684,12 @@ version = "11.1.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d6790f58c7ff633d8771f42965289203411a5e5c68388703c06e14f24770b41e" +[[package]] +name = "paste" +version = "1.0.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "57c0d7b74b563b49d38dae00a0c37d4d6de9b432382b2892f0574ddcae73fd0a" + [[package]] name = "pin-project-lite" version = "0.2.16" @@ -479,7 +878,7 @@ checksum = "d540f220d3187173da220f885ab66608367b6574e925011a9353e4badda91d79" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.110", ] [[package]] @@ -495,6 +894,48 @@ dependencies = [ "serde_core", ] +[[package]] +name = "sha3" +version = "0.10.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "75872d278a8f37ef87fa0ddbda7802605cb18344497949862c0d4dcb291eba60" +dependencies = [ + "digest", + "keccak", +] + +[[package]] +name = "sharded-slab" +version = "0.1.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f40ca3c46823713e0d4209592e8d6e826aa57e928f09752619fc696c499637f6" +dependencies = [ + "lazy_static", +] + +[[package]] +name = "smallvec" +version = "1.15.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"67b1b7a3b5fe4f1376887184045fcf45c69e92af734b7aaddc05fb777b6fbd03" + +[[package]] +name = "subtle" +version = "2.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "13c2bddecc57b384dee18652358fb23172facb8a2c51ccc10d74c157bdea3292" + +[[package]] +name = "syn" +version = "1.0.109" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "72b64191b275b66ffe2469e8af2c1cfe3bafa67b529ead792a6d0160888b4237" +dependencies = [ + "proc-macro2", + "quote", + "unicode-ident", +] + [[package]] name = "syn" version = "2.0.110" @@ -523,7 +964,16 @@ checksum = "3ff15c8ecd7de3849db632e14d18d2571fa09dfc5ed93479bc4485c7a517c913" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.110", +] + +[[package]] +name = "thread_local" +version = "1.1.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f60246a4944f24f6e018aa17cdeffb7818b76356965d03b07d6a9886e8962185" +dependencies = [ + "cfg-if", ] [[package]] @@ -538,9 +988,9 @@ dependencies = [ [[package]] name = "tracing" -version = "0.1.41" +version = "0.1.44" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "784e0ac535deb450455cbfa28a6f0df145ea1bb7ae51b821cf5e7927fdcfbdd0" +checksum = "63e71662fa4b2a2c3a26f570f037eb95bb1f85397f3cd8076caed2f026a6d100" dependencies = [ "pin-project-lite", "tracing-attributes", @@ -549,30 +999,89 @@ dependencies = [ [[package]] name = "tracing-attributes" -version = "0.1.30" +version = "0.1.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "81383ab64e72a7a8b8e13130c49e3dab29def6d0c7d76a03087b3cf71c5c6903" +checksum = "7490cfa5ec963746568740651ac6781f701c9c5ea257c58e057f3ba8cf69e8da" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.110", +] + +[[package]] +name = "tracing-chrome" +version = "0.7.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bf0a738ed5d6450a9fb96e86a23ad808de2b727fd1394585da5cdd6788ffe724" 
+dependencies = [ + "serde_json", + "tracing-core", + "tracing-subscriber", ] [[package]] name = "tracing-core" -version = "0.1.34" +version = "0.1.36" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b9d12581f227e93f094d3af2ae690a574abb8a2b9b7a96e7cfe9647b2b617678" +checksum = "db97caf9d906fbde555dd62fa95ddba9eecfd14cb388e4f491a66d74cd5fb79a" dependencies = [ "once_cell", + "valuable", ] +[[package]] +name = "tracing-log" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ee855f1f400bd0e5c02d150ae5de3840039a3f54b025156404e34c23c03f47c3" +dependencies = [ + "log", + "once_cell", + "tracing-core", +] + +[[package]] +name = "tracing-subscriber" +version = "0.3.22" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2f30143827ddab0d256fd843b7a66d164e9f271cfa0dde49142c5ca0ca291f1e" +dependencies = [ + "matchers", + "nu-ansi-term", + "once_cell", + "regex-automata", + "sharded-slab", + "smallvec", + "thread_local", + "tracing", + "tracing-core", + "tracing-log", +] + +[[package]] +name = "typenum" +version = "1.19.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "562d481066bde0658276a35467c4af00bdc6ee726305698a55b86e61d7ad82bb" + [[package]] name = "unicode-ident" version = "1.0.22" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9312f7c4f6ff9069b165498234ce8be658059c6728633667c526e27dc2cf1df5" +[[package]] +name = "valuable" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ba73ea9cf16a25df0c8caa16c51acb937d5712a8429db78a3ee29d5dcacd3a65" + +[[package]] +name = "version_check" +version = "0.9.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0b928f33d975fc6ad9f86c8f283853ad26bdd5b10b7f1542aa2fa15e2289105a" + [[package]] name = "walkdir" version = "2.5.0" @@ -621,7 +1130,7 @@ dependencies = [ "bumpalo", "proc-macro2", "quote", - 
"syn", + "syn 2.0.110", "wasm-bindgen-shared", ] @@ -685,5 +1194,25 @@ checksum = "88d2b8d9c68ad2b9e4340d7832716a4d21a22a1154777ad56ea55c51a9cf3831" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.110", +] + +[[package]] +name = "zeroize" +version = "1.8.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b97154e67e32c85465826e8bcc1c59429aaaf107c1e4a9e53c8d8ccd5eff88d0" +dependencies = [ + "zeroize_derive", +] + +[[package]] +name = "zeroize_derive" +version = "1.4.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "85a5b4158499876c763cb03bc4e49185d3cccbabb15b33c627f7884f43db852e" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.110", ] diff --git a/Cargo.toml b/Cargo.toml index 47abb4d8..c35b80bc 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -6,9 +6,11 @@ resolver = "2" name = "hachi-pcs" version = "0.1.0" edition = "2021" -rust-version = "1.75" +rust-version = "1.88" authors = [ "Markos Georghiades ", + "Quang Dao ", + "Omid Bodaghi ", ] license = "Apache-2.0 OR MIT" description = "A high performance and modular implementation of the Hachi polynomial commitment scheme." 
@@ -32,19 +34,52 @@ include = [ all-features = true [features] -default = [] +default = ["parallel"] parallel = ["dep:rayon"] +disk-persistence = [] [dependencies] thiserror = "2.0" -rand_core = "0.6" +rand_core = { version = "0.6", features = ["getrandom"] } hachi-derive = { version = "0.1.0", path = "derive" } tracing = "0.1" rayon = { version = "1.10", optional = true } +blake2 = "0.10.6" +sha3 = "0.10.8" +aes = "0.8.4" +ctr = "0.9.2" [dev-dependencies] rand = "0.8" criterion = { version = "0.5", features = ["html_reports"] } +num-bigint = "0.4.6" +ark-bn254 = { git = "https://github.com/a16z/arkworks-algebra", branch = "dev/twist-shout", features = ["scalar_field"] } +ark-ff = { git = "https://github.com/a16z/arkworks-algebra", branch = "dev/twist-shout" } +tracing-chrome = "0.7" +tracing-subscriber = { version = "0.3", features = ["env-filter", "registry"] } + +[[example]] +name = "profile" + +[[bench]] +name = "ring_ntt" +harness = false + +[[bench]] +name = "field_arith" +harness = false + +[[bench]] +name = "fp64_reduce_probe" +harness = false + +[[bench]] +name = "hachi_e2e" +harness = false + +[[bench]] +name = "labrador_jl_aggregation" +harness = false [lints.rust] missing_docs = "warn" diff --git a/HACHI_PROGRESS.md b/HACHI_PROGRESS.md new file mode 100644 index 00000000..afca63e6 --- /dev/null +++ b/HACHI_PROGRESS.md @@ -0,0 +1,200 @@ +## Hachi PCS implementation progress + +This file is the **single source of truth** for implementation status and near-term priorities. + +### Goals (project-level) + +- **Production-ready implementation**: correctness, security, maintainability, and performance are first-class goals. +- **Standalone codebase**: implementation and comments should stand on their own; external acknowledgements live in `README.md`. +- **Constant-time cryptographic core**: arithmetic and protocol-critical paths must be constant-time with respect to secret data. 
+- **No shortcuts / no fallback design**: avoid temporary or degraded code paths in the core implementation. + +### Non-negotiable requirements + +- **Constant-time discipline** + - No secret-dependent branches or memory access patterns in cryptographic hot paths. + - No secret-indexed table lookups; table access patterns must be independent of secret data. + - Keep data representations and reductions explicit and auditable for timing behavior. + - Add targeted tests/reviews for constant-time-sensitive code as features land. +- **Code quality bar** + - Clear naming, explicit invariants, small cohesive modules, and API docs for public interfaces. + - No placeholder crypto logic in mainline code (no "temporary" arithmetic shortcuts). + - Tests are required for correctness-critical arithmetic before dependent protocol code is built. + - No section-banner comments (e.g., `// ---- Section ----`, `// === ... ===`). Let the code and doc-comments speak for themselves. +- **Standalone implementation policy** + - Do not mention external inspirations/ports in core code comments. + - Keep terminology and structure internally coherent and project-native. + - Keep external attribution limited to dedicated docs (for now: `README.md` acknowledgements). +- **Git discipline** + - Do not commit or push without explicit user approval. + +### Implementation workflow (cautious + approval-driven) + +- Before each major subsystem, present implementation options with trade-offs. +- Seek explicit approval before proceeding with a selected option. +- Pause at milestone boundaries for review and feedback before continuing. +- Prefer slow, verifiable progress over rapid, high-risk changes. +- Ask for user input frequently when requirements are ambiguous or involve design trade-offs. + +### Definition of Done (all crypto-critical work) + +- **Security / constant-time** + - Secret-independent control flow and memory access in cryptographic paths. 
+ - Constant-time review notes included for non-trivial arithmetic/ring changes. +- **Correctness** + - Unit tests for edge cases and algebraic identities. + - Cross-check vectors/reference checks added where practical. +- **Code quality** + - Clear naming, explicit invariants, and no placeholder logic in core paths. + - Public interfaces documented sufficiently for safe usage. +- **Performance** + - Hot-path performance impact evaluated (benchmark or measured rationale). +- **Tooling + CI** + - `cargo fmt --all --check` passes. + - `cargo clippy --all --all-targets --all-features` passes. + - `cargo test` (or targeted suite for touched modules) passes. +- **Process** + - Implementation options reviewed with user before major subsystem changes. + - Milestone update recorded in this file. + +### Scope (current) + +- **Implemented so far (Phase 0 + Phase 1 functional core)**: prime fields (32/64/128-bit representations), extension fields, cyclotomic `R_q = Z_q[X]/(X^d + 1)`, CRT+NTT representation, backend/domain layering, ring automorphisms, and functional gadget decomposition. +- **Phase 2+ protocol status**: interface scaffold plus ring-native §4.1 commitment core are present (`Transcript`, Blake2b/Keccak backends, phase-grounded labels, `RingCommitmentScheme`, config layer, and setup/commit implementation). Sumcheck core building blocks (univariate messages + transcript-driving prover/verifier driver) are now implemented, with tests. Open-check prover/verifier paths remain stubbed. +- **Deferred future phase**: integration into Jolt (replacement of Dory with Hachi) is intentionally out of current execution scope; cross-repo analysis is design input only. + +### Critical review snapshot (2026-02-13) + +- **Phase 1 functional milestone appears complete** + - Ring/gadget components listed in Phase 1 are implemented and currently checked off. + - Conversion and arithmetic paths in coefficient and CRT+NTT domains are exercised by passing tests. 
+- **Not yet "production-ready" despite functional completion** + - Constant-time hardening follow-ups narrowed: secret-bearing call-sites still need to migrate from `FieldCore::inv()` to `Invertible::inv_or_zero()` as protocol code lands (see `CONSTANT_TIME_NOTES.md`). + - Current ring multiplication in coefficient form remains `O(D^2)` schoolbook (`src/algebra/ring/cyclotomic.rs`), with CRT+NTT available as the faster domain path. +- **Tooling/quality gate status (current branch snapshot)** + - `cargo test` passes, including protocol transcript/label/commitment contract tests and new ring-commitment core/config/stub tests. + - `cargo fmt --all --check` passes. + - `cargo clippy --all --all-targets --all-features` passes. +- **Phase 2 scaffold + commitment core landed; proof-system work still pending** + - `src/protocol/*` now provides transcript + commitment abstraction boundaries with `Transcript` naming. + - Two transcript backends are wired (`Blake2bTranscript`, `KeccakTranscript`) with deterministic replay/order/reset tests. + - Hachi-native labels are now calibrated to paper-stage phases (§4.1, §4.2, §4.3, §4.5). + - Commitment absorption is label-directed at call sites (`AppendToTranscript` no longer hardcodes commitment labels). + - Ring-native commitment setup/commit flow for §4.1 is implemented in `src/protocol/commitment/commit.rs` behind `RingCommitmentScheme`. + - Sumcheck core module landed (`src/protocol/sumcheck.rs`) with unit/integration tests (`tests/sumcheck_core.rs`, `tests/sumcheck_prover_driver.rs`). + - Prover/verifier split folders are wired with explicit stubs (`src/protocol/prover/stub.rs`, `src/protocol/verifier/stub.rs`) for future open-check implementation. +- **Conclusion** + - Treat **Phase 1 as functionally complete**. + - Treat **Phase 2 as active/in-progress** (commitment core implemented; prove/verify and later reductions still open). + - Remaining strict CT follow-ups stay tracked in `CONSTANT_TIME_NOTES.md`. 
+ +### Status board + +#### Phase 0 — Algebra + +- [x] Prime field `Fp32` (u32 storage; u64 mul) implementing `FieldCore + CanonicalField` (`src/algebra/fields/fp32.rs`) +- [x] Prime field `Fp64` (u64 storage; u128 mul) implementing `FieldCore + CanonicalField` (`src/algebra/fields/fp64.rs`) +- [x] Prime field `Fp128` (u128 storage; 256-bit intermediate) implementing `FieldCore + CanonicalField` (`src/algebra/fields/fp128.rs`, `src/algebra/fields/u256.rs`) +- [x] Branchless constant-time `add_raw`, `sub_raw`, `neg` for all field types +- [x] Constant-time inversion helper for prime fields: `Invertible::inv_or_zero()` (`src/primitives/arithmetic.rs`, `src/algebra/fields/fp*.rs`) +- [x] Division-free fixed-iteration reduction for `Fp32/Fp64` multiplication paths +- [x] Division-free fixed-iteration CRT final projection (replaced `% q` in scalar reconstruction path) +- [x] Rejection-sampled `FieldSampling::sample()` for all field types (no modular bias) +- [x] Pow2Offset pseudo-Mersenne registry + aliases (`q = 2^k - offset`, bounded `k <= 128`, `q % 8 == 5`) (`src/algebra/fields/pseudo_mersenne.rs`) +- [x] Constant-time review notes for current algebra/ring paths (`CONSTANT_TIME_NOTES.md`) +- [x] Deterministic parameter presets + - [x] `q = 2^32 - 99` constants scaffold (`src/algebra/ntt/tables.rs`) + - [x] `Pow2Offset` presets selected for 64/128-bit path: + - `q = 2^64 - 59` (`POW2_OFFSET_MODULUS_64`) + - `q = 2^128 - 275` (`POW2_OFFSET_MODULUS_128`) + - source: `src/algebra/fields/pseudo_mersenne.rs` +- [x] `Module` implementations: + - [x] `VectorModule` (fixed-length vectors; `Module` via scalar*vector mul) (`src/algebra/module.rs`) + - [x] `PolyModule` removed from current scope (not needed for near-term Hachi milestones) +- [ ] Extension fields: + - [x] `Fp2` quadratic extension (`src/algebra/fields/ext.rs`) + - [x] `Fp4` tower extension (`src/algebra/fields/ext.rs`) +- [x] Serialization for algebra types (`HachiSerialize` / `HachiDeserialize`) (+ `u128/i128` 
primitives in `src/primitives/serialization.rs`) +- [x] NTT small-prime arithmetic: Montgomery-like `fpmul`, Barrett-like `fpred`, branchless `csubq`/`caddq`/`center` (`src/algebra/ntt/prime.rs`) +- [x] CRT limb arithmetic: `LimbQ`, `QData` (`src/algebra/ntt/crt.rs`) +- [x] Tests (49 total in `tests/algebra.rs`): + - [x] field arithmetic, identities, distributivity (Fp32/Fp64/Fp128) + - [x] zero inversion returns None + - [x] serialization round-trips (all field types, extensions, Poly, VectorModule) + - [x] Fp2 conjugate, norm, distributivity + - [x] U256 wide multiply and bit access + - [x] LimbQ round-trip, add/sub inverse, QData consistency + - [x] NTT normalize range, fpmul commutativity + - [x] Poly add/sub/neg + - [x] Cyclotomic ring identities and serialization (D=4, D=64) + - [x] NTT forward/inverse round-trips (single prime and all Q32 primes) + - [x] Cyclotomic CRT+NTT full round-trip (`from_ring` -> `to_ring`) + - [x] Scalar backend path equivalence (`*_with_backend` vs default path) + - [x] Pow2Offset profile invariants (`q = 2^k - offset`, `q % 8 == 5`) + - [x] `FieldSampling::sample()` output bound checks + - [x] Checked deserialization rejects non-canonical field encodings + - [x] Galois automorphism checks (`sigma` composition + multiplicativity) + - [x] Functional gadget decompose/recompose round-trip checks + - [x] Sparse `+/-1` challenge support checks (`hamming_weight = omega`) +- [x] Dedicated Pow2Offset primality regression tests (`tests/primality.rs`) + - [x] Miller-Rabin probable-prime checks for all registered Pow2Offset moduli + - [x] Composite sanity rejection checks + +#### Phase 1 — Ring + gadgets (functional core) + +- [x] Cyclotomic ring `Rq` with `X^D = -1` (`src/algebra/ring/cyclotomic.rs`) +- [x] CRT+NTT-domain ring representation + CRT conversion (`src/algebra/ring/crt_ntt_repr.rs`) +- [x] Backend/domain layering for ring execution (`src/algebra/backend/*`, `src/algebra/domains/*`) +- [x] Galois automorphisms `sigma_i: X ↦ X^i` 
(odd `i`) +- [x] Functional gadget decomposition/recomposition (`G^{-1}` / `G` behavior) for base-`2^d` digits, without materializing dense gadget matrices +- [x] sparse short challenges (paper: `||c||_1 ≤ ω`, sparse ±1) + +#### Phase 2+ — Protocol (later) + +- [x] Protocol module scaffold (`src/protocol/*`) and top-level re-exports +- [x] Transcript interface (`Transcript`) plus Blake2b/Keccak implementations +- [x] Hachi-native transcript label schedule aligned to paper phases (§4.1/§4.2/§4.3/§4.5) +- [x] Commitment trait surface + streaming trait surface + contract tests +- [x] Label-directed transcript absorption for commitments (`AppendToTranscript` takes label at call site) +- [x] ring-native commitment core (`RingCommitmentScheme`, `commit.rs`, config wiring) for §4.1 setup/commit +- [x] protocol prover/verifier folder split with explicit stubs (`prover/stub.rs`, `verifier/stub.rs`) +- [x] ring-commitment tests (`ring_commitment_core`, `ring_commitment_config`, `prover_verifier_stub_contract`) +- [x] sumcheck core building blocks (univariate messages + transcript-driving prover/verifier driver) (`src/protocol/sumcheck.rs`) +- [x] sumcheck core tests (`tests/sumcheck_core.rs`, `tests/sumcheck_prover_driver.rs`) +- [ ] commitment open-check prove/verify implementation (currently stubs) +- [ ] evaluation → linear relation (paper §4.2) +- [ ] ring-switching constraints as sumcheck instances (paper §4.3, Fig. 
4–7) +- [ ] recursion / “stop condition” + optional Greyhound composition (§4.5) + +#### Phase 3 — Integration into Jolt (deferred; not active now) + +- [ ] Define compatibility boundary document (what must match Jolt/Dory behavior vs what can remain Hachi-native) +- [ ] Provide Jolt-facing transcript adapter design (`Jolt` transcript pattern ↔ Hachi transcript object) +- [ ] Provide Jolt-facing PCS shim design (`CommitmentScheme`/`StreamingCommitmentScheme` mapping) +- [ ] Add transcript/commitment compatibility tests for integration-readiness (without wiring into Jolt yet) + +### Conventions + +- **Correctness first**: lock arithmetic with tests before touching protocol code. +- **Security first**: enforce constant-time behavior for secret-dependent operations. +- **Lean deps**: avoid heavyweight crypto crates until there is a clear need. +- **Explicit parameter sets**: each field/ring preset lives in code with a clear name and rationale. + +### Module layout + +``` +src/algebra/ +├── backend/ Backend execution traits + scalar backend +├── domains/ Domain-level aliases (coefficient / CRT+NTT) +├── fields/ Prime fields, pseudo-mersenne registry, u256, and extensions +├── ntt/ NTT kernels (butterfly), prime kernels (prime), CRT helpers (crt), presets (tables) +├── module.rs VectorModule +├── poly.rs Poly container +└── ring/ Cyclotomic ring and CRT+NTT representation +``` + +### References + +- Hachi paper: `paper/hachi.pdf` +- Core traits: `src/primitives/arithmetic.rs`, `src/primitives/serialization.rs` + diff --git a/NTT_PRIME_ANALYSIS.md b/NTT_PRIME_ANALYSIS.md new file mode 100644 index 00000000..e6b9c85c --- /dev/null +++ b/NTT_PRIME_ANALYSIS.md @@ -0,0 +1,146 @@ +# NTT Prime Analysis (Pow2Offset / Solinas Context) + +This note records the current analysis for small NTT primes and CRT coverage targets. 
+ +## References + +- NIST ML-KEM: `paper/standards/NIST.FIPS.203.pdf` +- NIST ML-DSA: `paper/standards/NIST.FIPS.204.pdf` +- Current small-prime table: `src/algebra/ntt/tables.rs` +- Labrador generator heuristic: `../labrador/data.py` + +## Why does `2D` divide `p - 1`? + +For negacyclic NTT on `Z_p[X]/(X^D + 1)`, we need a primitive `2D`-th root `psi` such that: + +- `psi^D = -1 (mod p)` +- `psi^(2D) = 1 (mod p)` + +Over prime fields, `F_p^*` is cyclic of size `p - 1`, so an element of order `2D` exists iff: + +- `2D | (p - 1)` + +So yes, the `128 | (p - 1)` condition is directly tied to `D = 64`. + +## What if `D = 1024`? + +Then requirement becomes: + +- `2D = 2048`, so `2048 | (p - 1)`. + +Under the current "small prime" cap (`p < 2^14`), this is extremely restrictive. + +## Why `< 2^14` in current code? + +This is a backend implementation constraint, not a hard NTT math requirement: + +- Current small-prime NTT backend stores modulus/coefficients in signed 16-bit lanes (`i16`). +- It relies on centered signed arithmetic and butterfly add/sub before full normalization. +- Keeping `p < 2^14` leaves practical headroom in those 16-bit operations. +- Current CRT limb code is also radix-`2^14`, matching this design style. + +So the `2^14` cap is about the present `i16` scalar kernel design. If we introduce an `i32` backend, this cap can be raised substantially. + +## Exhaustive counts (for `p < 2^14`) + +We classify exact Solinas as: + +- `p = 2^x - 2^y + 1`. 
+ +Results: + +- `D = 64` (`128 | p-1`) + - all small NTT primes: **31** + - exact Solinas NTT primes: **6** + - all-prime set: + - `257, 641, 769, 1153, 1409, 2689, 3329, 3457, 4481, 4993, 6529, 7297, 7681, 7937, 9473, 9601, 9857, 10369, 10753, 11393, 11777, 12161, 12289, 13313, 13441, 13697, 14081, 14593, 15233, 15361, 16001` + - Solinas set: `257, 769, 7681, 7937, 12289, 15361` +- `D = 256` (`512 | p-1`) + - all small NTT primes: **6** + - exact Solinas NTT primes: **3** + - Solinas set: `7681, 12289, 15361` +- `D = 1024` (`2048 | p-1`) + - all small NTT primes: **1** + - exact Solinas NTT primes: **1** + - Solinas set: `12289` + +Conclusion: for higher `D`, the small-prime pool shrinks rapidly. + +## 30-bit exploration (`p < 2^30`) with NTT constraints + +To assess a larger-prime backend direction, we scanned for primes under `2^30` with: + +- `p ≡ 1 (mod 2D)`. + +Below are the **full outputs of the bounded search run** (top 30 largest primes found by descending scan): + +### `D = 64` (`2D = 128`) + +- Top-30 list: + - `1073741441, 1073739649, 1073738753, 1073736449, 1073735297, 1073734913, 1073732993, 1073732609, 1073731201, 1073731073, 1073730817, 1073728897, 1073727617, 1073726977, 1073722753, 1073719681, 1073717377, 1073716993, 1073713409, 1073712769, 1073712257, 1073710721, 1073708929, 1073707009, 1073703809, 1073702657, 1073702401, 1073698817, 1073696257, 1073693441` +- Coverage for `q = 2^128 - 275`: + - `P > q`: **5** limbs + - `P > 128*q^2`: **9** limbs + +### `D = 1024` (`2D = 2048`) + +- Top-30 list: + - `1073707009, 1073698817, 1073692673, 1073682433, 1073668097, 1073655809, 1073651713, 1073643521, 1073620993, 1073600513, 1073569793, 1073563649, 1073551361, 1073539073, 1073522689, 1073510401, 1073508353, 1073479681, 1073453057, 1073442817, 1073440769, 1073430529, 1073412097, 1073391617, 1073385473, 1073354753, 1073350657, 1073330177, 1073299457, 1073268737` +- Coverage for `q = 2^128 - 275`: + - `P > q`: **5** limbs + - `P > 128*q^2`: **9** limbs + 
+### Bit-estimate sanity check + +- `ceil(128 / 30) = 5` +- `ceil(263 / 30) = 9` (for `128*q^2 ~ 2^263`) + +This matches the concrete product counts above. + +## CRT size targets for `q = 2^128 - 275` + +Two common thresholds: + +1. Minimal uniqueness target: + - `P = prod(p_i) > q` +2. Labrador conservative heuristic: + - `P > 128 * q^2` (from `data.py`, with `FIXME` comment) + +### Limb counts at `D = 64` with current small-prime pool + +- Using all small NTT primes (`31` available): + - `P > q` achievable with **10** limbs + - `P > 128*q^2` achievable with **20** limbs +- Using only exact Solinas NTT primes (`6` available): + - `P > q`: **not achievable** + - `P > 128*q^2`: **not achievable** + - total product is only about `2^70` + +### Limb counts at `D = 1024` with current small-prime pool + +- Only one qualifying prime (`12289`) under `p < 2^14`, so neither threshold is achievable. + +## What is Labrador's safety margin doing? + +In Labrador code, prime selection stops at: + +- `P > 128 * q^2` + +Interpretation: + +- `q^2` tracks product-scale growth, +- extra factor `128` gives additional headroom (for `N=64`, this is `2N`), +- but their own `FIXME` comment indicates this is a conservative engineering bound, not a tight proof. + +So treat this as a robust heuristic rather than a formal minimum. + +## Practical implication for Hachi + +- If we stay with `D=64` and small i16-ish primes, we need non-Solinas primes in the CRT set. +- If we push to `D=1024`, we must either: + - lift prime size beyond `<2^14`, or + - change CRT strategy (fewer larger limbs / different backend), or + - avoid strict small-prime CRT-NTT at that degree. +- A mixed backend model is sensible: + - keep the current `i16` backend for small-prime kernels, + - add an `i32`/wider backend for larger-prime kernels (e.g., up to ~30-bit). 
diff --git a/README.md b/README.md index 492b4d12..8036344f 100644 --- a/README.md +++ b/README.md @@ -3,3 +3,7 @@ A high performance and modular implementation of the Hachi polynomial commitment scheme. Hachi is a lattice-based polynomial commitment scheme with transparent setup and post-quantum security. + +## Acknowledgements + +The CRT/NTT and small-prime arithmetic design in this repository is informed by the Labrador/Greyhound C implementation family. In particular, the current pseudo-Mersenne profile uses moduli of the form `q = 2^k - offset` (smallest prime below `2^k` with `q % 8 == 5`). Hachi provides a Rust-native architecture and APIs, while drawing algorithmic inspiration from those implementations. diff --git a/benches/field_arith.rs b/benches/field_arith.rs new file mode 100644 index 00000000..e5f07f25 --- /dev/null +++ b/benches/field_arith.rs @@ -0,0 +1,1448 @@ +#![allow(missing_docs)] + +use ark_bn254::Fr as BN254Fr; +use ark_ff::{AdditiveGroup, Field}; +use criterion::{black_box, criterion_group, criterion_main, Criterion, Throughput}; +use hachi_pcs::algebra::fields::fp128::{Prime128M18M0, Prime128M54P0}; +use hachi_pcs::algebra::fields::fp32::Fp32; +use hachi_pcs::algebra::{HasPacking, PackedField, PackedValue, Prime128M13M4P0, Prime128M8M4M1M0}; +use hachi_pcs::algebra::{ + Pow2Offset24Field, Pow2Offset30Field, Pow2Offset31Field, Pow2Offset32Field, Pow2Offset40Field, + Pow2Offset48Field, Pow2Offset56Field, Pow2Offset64Field, +}; +use hachi_pcs::{CanonicalField, FieldCore, FieldSampling, FromSmallInt, Invertible}; +use rand::{rngs::StdRng, RngCore, SeedableRng}; +use std::env; +#[cfg(feature = "parallel")] +use std::thread; + +#[cfg(feature = "parallel")] +use rayon::prelude::*; +#[cfg(feature = "parallel")] +use rayon::ThreadPoolBuilder; + +fn rand_u128(rng: &mut R) -> u128 { + let lo = rng.next_u64() as u128; + let hi = rng.next_u64() as u128; + lo | (hi << 64) +} + +fn env_usize(name: &str, default: usize) -> usize { + env::var(name) + .ok() 
+ .and_then(|v| v.parse::().ok()) + .unwrap_or(default) +} + +fn bench_mul(c: &mut Criterion) { + type F13 = Prime128M13M4P0; + type F275 = Prime128M8M4M1M0; + type F2p18p1 = Prime128M18M0; + type F2p54m1 = Prime128M54P0; + + let mut rng = StdRng::seed_from_u64(0x5eed); + let inputs_u128: Vec = (0..2048).map(|_| rand_u128(&mut rng)).collect(); + + let inputs_f13: Vec = inputs_u128 + .iter() + .copied() + .map(F13::from_canonical_u128_reduced) + .collect(); + + let inputs_f275: Vec = inputs_u128 + .iter() + .copied() + .map(F275::from_canonical_u128_reduced) + .collect(); + let inputs_f2p18p1: Vec = inputs_u128 + .iter() + .copied() + .map(F2p18p1::from_canonical_u128_reduced) + .collect(); + let inputs_f2p54m1: Vec = inputs_u128 + .iter() + .copied() + .map(F2p54m1::from_canonical_u128_reduced) + .collect(); + + let mut group = c.benchmark_group("field_mul"); + + group.bench_function("fp128_prime128m13m4p0", |b| { + b.iter(|| { + let mut acc = F13::one(); + for x in inputs_f13.iter() { + acc = acc * *x + acc; + } + black_box(acc) + }) + }); + + group.bench_function("fp128_prime128m8m4m1m0", |b| { + b.iter(|| { + let mut acc = F275::one(); + for x in inputs_f275.iter() { + acc = acc * *x + acc; + } + black_box(acc) + }) + }); + + group.bench_function("fp128_prime128m18m0_shift_special", |b| { + b.iter(|| { + let mut acc = F2p18p1::one(); + for x in inputs_f2p18p1.iter() { + acc = acc * *x + acc; + } + black_box(acc) + }) + }); + + group.bench_function("fp128_prime128m54p0_shift_special", |b| { + b.iter(|| { + let mut acc = F2p54m1::one(); + for x in inputs_f2p54m1.iter() { + acc = acc * *x + acc; + } + black_box(acc) + }) + }); + + group.finish(); +} + +fn bench_mul_only(c: &mut Criterion) { + type F13 = Prime128M13M4P0; + type F2p18p1 = Prime128M18M0; + type F2p54m1 = Prime128M54P0; + + let mut rng = StdRng::seed_from_u64(0x5eed); + let inputs_f13: Vec = (0..2048) + .map(|_| F13::from_canonical_u128_reduced(rand_u128(&mut rng))) + .collect(); + let inputs_f2p18p1: 
Vec = (0..2048) + .map(|_| F2p18p1::from_canonical_u128_reduced(rand_u128(&mut rng))) + .collect(); + let inputs_f2p54m1: Vec = (0..2048) + .map(|_| F2p54m1::from_canonical_u128_reduced(rand_u128(&mut rng))) + .collect(); + + let mut group = c.benchmark_group("field_mul_only"); + + group.bench_function("mul_chain_2048", |b| { + b.iter(|| { + let mut acc = F13::one(); + for x in inputs_f13.iter() { + acc *= *x; + } + black_box(acc) + }) + }); + + group.bench_function("mul_chain_16384", |b| { + b.iter(|| { + let mut acc = F13::one(); + for _ in 0..8 { + for x in inputs_f13.iter() { + acc *= *x; + } + } + black_box(acc) + }) + }); + + group.bench_function("mul_parallel_1024", |b| { + b.iter(|| { + let mut sum = F13::zero(); + for pair in inputs_f13.chunks_exact(2) { + sum += pair[0] * pair[1]; + } + black_box(sum) + }) + }); + + group.bench_function("mul_chain_2048_special_m18m0", |b| { + b.iter(|| { + let mut acc = F2p18p1::one(); + for x in inputs_f2p18p1.iter() { + acc *= *x; + } + black_box(acc) + }) + }); + + group.bench_function("mul_chain_2048_special_m54p0", |b| { + b.iter(|| { + let mut acc = F2p54m1::one(); + for x in inputs_f2p54m1.iter() { + acc *= *x; + } + black_box(acc) + }) + }); + + group.finish(); +} + +fn bench_mul_isolated(c: &mut Criterion) { + use ark_ff::UniformRand; + + type F13 = Prime128M13M4P0; + + let mut rng = StdRng::seed_from_u64(0x5eed); + let a_fp128 = F13::from_canonical_u128_reduced(rand_u128(&mut rng)); + let b_fp128 = F13::from_canonical_u128_reduced(rand_u128(&mut rng)); + let a_bn254 = BN254Fr::rand(&mut rng); + let b_bn254 = BN254Fr::rand(&mut rng); + + let mut group = c.benchmark_group("field_mul_isolated"); + + group.bench_function("fp128_black_box_only", |b| b.iter(|| black_box(a_fp128))); + + group.bench_function("bn254_black_box_only", |b| b.iter(|| black_box(a_bn254))); + + group.bench_function("fp128_pair_passthrough", |b| { + b.iter(|| { + let x = black_box(a_fp128); + let y = black_box(b_fp128); + black_box((x, y)) + }) 
+ }); + + group.bench_function("bn254_pair_passthrough", |b| { + b.iter(|| { + let x = black_box(a_bn254); + let y = black_box(b_bn254); + black_box((x, y)) + }) + }); + + group.bench_function("fp128_mul_single", |b| { + b.iter(|| { + let x = black_box(a_fp128); + let y = black_box(b_fp128); + black_box(x * y) + }) + }); + + group.bench_function("bn254_mul_single", |b| { + b.iter(|| { + let x = black_box(a_bn254); + let y = black_box(b_bn254); + black_box(x * y) + }) + }); + + let lanes_fp128: [(F13, F13); 8] = std::array::from_fn(|_| { + ( + F13::from_canonical_u128_reduced(rand_u128(&mut rng)), + F13::from_canonical_u128_reduced(rand_u128(&mut rng)), + ) + }); + let lanes_bn254: [(BN254Fr, BN254Fr); 8] = + std::array::from_fn(|_| (BN254Fr::rand(&mut rng), BN254Fr::rand(&mut rng))); + + group.bench_function("fp128_mul_8way_independent", |b| { + b.iter(|| { + let lanes = black_box(&lanes_fp128); + let p0 = lanes[0].0 * lanes[0].1; + let p1 = lanes[1].0 * lanes[1].1; + let p2 = lanes[2].0 * lanes[2].1; + let p3 = lanes[3].0 * lanes[3].1; + let p4 = lanes[4].0 * lanes[4].1; + let p5 = lanes[5].0 * lanes[5].1; + let p6 = lanes[6].0 * lanes[6].1; + let p7 = lanes[7].0 * lanes[7].1; + black_box([p0, p1, p2, p3, p4, p5, p6, p7]) + }) + }); + + group.bench_function("fp128_8way_passthrough", |b| { + b.iter(|| { + let lanes = black_box(&lanes_fp128); + let p0 = lanes[0].0; + let p1 = lanes[1].0; + let p2 = lanes[2].0; + let p3 = lanes[3].0; + let p4 = lanes[4].0; + let p5 = lanes[5].0; + let p6 = lanes[6].0; + let p7 = lanes[7].0; + black_box([p0, p1, p2, p3, p4, p5, p6, p7]) + }) + }); + + group.bench_function("bn254_mul_8way_independent", |b| { + b.iter(|| { + let lanes = black_box(&lanes_bn254); + let p0 = lanes[0].0 * lanes[0].1; + let p1 = lanes[1].0 * lanes[1].1; + let p2 = lanes[2].0 * lanes[2].1; + let p3 = lanes[3].0 * lanes[3].1; + let p4 = lanes[4].0 * lanes[4].1; + let p5 = lanes[5].0 * lanes[5].1; + let p6 = lanes[6].0 * lanes[6].1; + let p7 = lanes[7].0 * 
lanes[7].1; + black_box([p0, p1, p2, p3, p4, p5, p6, p7]) + }) + }); + + group.bench_function("bn254_8way_passthrough", |b| { + b.iter(|| { + let lanes = black_box(&lanes_bn254); + let p0 = lanes[0].0; + let p1 = lanes[1].0; + let p2 = lanes[2].0; + let p3 = lanes[3].0; + let p4 = lanes[4].0; + let p5 = lanes[5].0; + let p6 = lanes[6].0; + let p7 = lanes[7].0; + black_box([p0, p1, p2, p3, p4, p5, p6, p7]) + }) + }); + + group.finish(); +} + +fn bench_sqr(c: &mut Criterion) { + type F13 = Prime128M13M4P0; + + let mut rng = StdRng::seed_from_u64(0x5eed); + let start = F13::from_canonical_u128_reduced(rand_u128(&mut rng)); + + let mut group = c.benchmark_group("field_sqr"); + + group.bench_function("sqr_chain_2048", |b| { + b.iter(|| { + let mut acc = start; + for _ in 0..2048 { + acc = acc.square(); + } + black_box(acc) + }) + }); + + group.bench_function("mul_self_chain_2048", |b| { + b.iter(|| { + let mut acc = start; + for _ in 0..2048 { + acc = acc * acc; + } + black_box(acc) + }) + }); + + group.finish(); +} + +fn bench_inv(c: &mut Criterion) { + type F13 = Prime128M13M4P0; + + let mut rng = StdRng::seed_from_u64(0x1a2b_3c4d_5e6f_7788); + let inputs: Vec = (0..256) + .map(|_| F13::from_canonical_u128_reduced(rand_u128(&mut rng))) + .collect(); + + c.bench_function("fp128_inv_or_zero_prime128m13m4p0", |b| { + b.iter(|| { + let mut acc = F13::one(); + for x in inputs.iter() { + acc *= x.inv_or_zero(); + } + black_box(acc) + }) + }); +} + +fn bench_bn254(c: &mut Criterion) { + use ark_ff::UniformRand; + + let mut rng = StdRng::seed_from_u64(0x5eed); + let inputs: Vec = (0..2048).map(|_| BN254Fr::rand(&mut rng)).collect(); + + let mut group = c.benchmark_group("bn254_fr"); + + group.bench_function("mul_add_chain_2048", |b| { + b.iter(|| { + let mut acc = BN254Fr::ONE; + for x in inputs.iter() { + acc = acc * x + acc; + } + black_box(acc) + }) + }); + + group.bench_function("mul_chain_2048", |b| { + b.iter(|| { + let mut acc = BN254Fr::ONE; + for x in inputs.iter() { 
+ acc *= x; + } + black_box(acc) + }) + }); + + group.bench_function("mul_chain_16384", |b| { + b.iter(|| { + let mut acc = BN254Fr::ONE; + for _ in 0..8 { + for x in inputs.iter() { + acc *= x; + } + } + black_box(acc) + }) + }); + + group.bench_function("mul_parallel_1024", |b| { + b.iter(|| { + let mut sum = BN254Fr::ZERO; + for pair in inputs.chunks_exact(2) { + sum += pair[0] * pair[1]; + } + black_box(sum) + }) + }); + + group.bench_function("sqr_chain_2048", |b| { + b.iter(|| { + let mut acc = inputs[0]; + for _ in 0..2048 { + acc.square_in_place(); + } + black_box(acc) + }) + }); + + group.bench_function("inv_256", |b| { + b.iter(|| { + let mut acc = BN254Fr::ONE; + for x in inputs[..256].iter() { + acc *= x.inverse().unwrap_or(BN254Fr::ZERO); + } + black_box(acc) + }) + }); + + group.finish(); +} + +fn bench_packed_fp128_backend(c: &mut Criterion) { + type F = Prime128M13M4P0; + type PF = ::Packing; + let packed_streams = env_usize("HACHI_BENCH_PACKED_STREAMS", 8); + let latency_iters = env_usize("HACHI_BENCH_LATENCY_ITERS", 4096); + let throughput_iters = env_usize("HACHI_BENCH_THROUGHPUT_ITERS", 256); + let stream_iters = env_usize("HACHI_BENCH_STREAM_ITERS", 2048); + let mix_iters = env_usize("HACHI_BENCH_MIX_ITERS", 256); + let mix_muls = env_usize("HACHI_BENCH_MIX_MULS", 3); + let mix_adds = env_usize("HACHI_BENCH_MIX_ADDS", 1); + let mix_subs = env_usize("HACHI_BENCH_MIX_SUBS", 1); + + assert!(packed_streams > 0, "HACHI_BENCH_PACKED_STREAMS must be > 0"); + assert!(latency_iters > 0, "HACHI_BENCH_LATENCY_ITERS must be > 0"); + assert!( + throughput_iters > 0, + "HACHI_BENCH_THROUGHPUT_ITERS must be > 0" + ); + assert!(stream_iters > 0, "HACHI_BENCH_STREAM_ITERS must be > 0"); + assert!(mix_iters > 0, "HACHI_BENCH_MIX_ITERS must be > 0"); + + let muls_per_stream = throughput_iters + 1; + let mix_ops = mix_muls + mix_adds + mix_subs; + assert!(mix_ops > 0, "at least one mix operation must be enabled"); + + let backend = if cfg!(all(target_arch = 
"aarch64", target_feature = "neon")) { + "aarch64_neon" + } else { + "scalar_fallback" + }; + let mut group = c.benchmark_group(format!("field_packed_backend/{backend}/w{}", PF::WIDTH)); + + let mut rng = StdRng::seed_from_u64(0xd00d_f00d_1122_3344); + let scalar_stream_len = PF::WIDTH * stream_iters; + let lhs: Vec = (0..scalar_stream_len) + .map(|_| F::from_canonical_u128_reduced(rand_u128(&mut rng))) + .collect(); + let rhs: Vec = (0..scalar_stream_len) + .map(|_| F::from_canonical_u128_reduced(rand_u128(&mut rng))) + .collect(); + + let packed_lhs: Vec = PF::pack_slice(&lhs); + let packed_rhs: Vec = PF::pack_slice(&rhs); + let scalar_latency_inputs: Vec = (0..latency_iters) + .map(|_| F::from_canonical_u128_reduced(rand_u128(&mut rng))) + .collect(); + let packed_latency_inputs: Vec = (0..latency_iters) + .map(|_| PF::from_fn(|_| F::from_canonical_u128_reduced(rand_u128(&mut rng)))) + .collect(); + + let scalar_streams = packed_streams * PF::WIDTH; + let scalar_lanes: Vec<(F, F)> = (0..scalar_streams) + .map(|_| { + ( + F::from_canonical_u128_reduced(rand_u128(&mut rng)), + F::from_canonical_u128_reduced(rand_u128(&mut rng)), + ) + }) + .collect(); + let packed_lanes: Vec<(PF, PF)> = (0..packed_streams) + .map(|_| { + ( + PF::from_fn(|_| F::from_canonical_u128_reduced(rand_u128(&mut rng))), + PF::from_fn(|_| F::from_canonical_u128_reduced(rand_u128(&mut rng))), + ) + }) + .collect(); + + group.throughput(Throughput::Elements(scalar_stream_len as u64)); + group.bench_function("scalar_add_stream", |b| { + let mut out = lhs.clone(); + b.iter(|| { + for (dst, src) in out.iter_mut().zip(rhs.iter()) { + *dst += *src; + } + black_box(out[0]) + }) + }); + + group.throughput(Throughput::Elements(scalar_stream_len as u64)); + group.bench_function("packed_add_stream", |b| { + let mut out = packed_lhs.clone(); + b.iter(|| { + for (dst, src) in out.iter_mut().zip(packed_rhs.iter()) { + *dst += *src; + } + black_box(out[0].extract(0)) + }) + }); + + 
group.throughput(Throughput::Elements(latency_iters as u64)); + group.bench_function("scalar_mul_latency_chain", |b| { + b.iter(|| { + let mut acc = F::one(); + for x in scalar_latency_inputs.iter() { + acc *= *x; + } + black_box(acc) + }) + }); + + group.throughput(Throughput::Elements((latency_iters * PF::WIDTH) as u64)); + group.bench_function("packed_mul_latency_chain", |b| { + b.iter(|| { + let mut acc = PF::broadcast(F::one()); + for x in packed_latency_inputs.iter() { + acc *= *x; + } + black_box(acc.extract(0)) + }) + }); + + group.throughput(Throughput::Elements( + (scalar_streams * muls_per_stream) as u64, + )); + group.bench_function("scalar_mul_throughput_8way", |b| { + b.iter(|| { + let lanes = black_box(&scalar_lanes); + let mut acc: Vec = lanes.iter().map(|(a, b)| *a * *b).collect(); + for _ in 0..throughput_iters { + for (acc_i, lane) in acc.iter_mut().zip(lanes.iter()) { + *acc_i *= lane.0; + } + } + black_box(acc[0]) + }) + }); + + group.throughput(Throughput::Elements( + (packed_streams * muls_per_stream * PF::WIDTH) as u64, + )); + group.bench_function("packed_mul_throughput_8way", |b| { + b.iter(|| { + let lanes = black_box(&packed_lanes); + let mut acc: Vec = lanes.iter().map(|(a, b)| *a * *b).collect(); + for _ in 0..throughput_iters { + for (acc_i, lane) in acc.iter_mut().zip(lanes.iter()) { + *acc_i *= lane.0; + } + } + black_box(acc[0].extract(0)) + }) + }); + + group.throughput(Throughput::Elements( + (scalar_streams * mix_iters * mix_ops) as u64, + )); + group.bench_function("scalar_mix_sumcheck_like", |b| { + b.iter(|| { + let lanes = black_box(&scalar_lanes); + let mut acc: Vec = lanes.iter().map(|(a, b)| *a + *b).collect(); + for _ in 0..mix_iters { + for (acc_i, lane) in acc.iter_mut().zip(lanes.iter()) { + let (x, y) = *lane; + for _ in 0..mix_muls { + *acc_i *= x; + } + for _ in 0..mix_adds { + *acc_i += y; + } + for _ in 0..mix_subs { + *acc_i -= x; + } + } + } + black_box(acc[0]) + }) + }); + + 
group.throughput(Throughput::Elements( + (packed_streams * PF::WIDTH * mix_iters * mix_ops) as u64, + )); + group.bench_function("packed_mix_sumcheck_like", |b| { + b.iter(|| { + let lanes = black_box(&packed_lanes); + let mut acc: Vec = lanes.iter().map(|(a, b)| *a + *b).collect(); + for _ in 0..mix_iters { + for (acc_i, lane) in acc.iter_mut().zip(lanes.iter()) { + let (x, y) = *lane; + for _ in 0..mix_muls { + *acc_i *= x; + } + for _ in 0..mix_adds { + *acc_i += y; + } + for _ in 0..mix_subs { + *acc_i -= x; + } + } + } + black_box(acc[0].extract(0)) + }) + }); + + group.finish(); +} + +fn bench_fp32_fp64_mul(c: &mut Criterion) { + let mut rng = StdRng::seed_from_u64(0x3264_3264); + let n = 2048; + + let inputs_24: Vec = + (0..n).map(|_| FieldSampling::sample(&mut rng)).collect(); + let inputs_30: Vec = + (0..n).map(|_| FieldSampling::sample(&mut rng)).collect(); + let inputs_31: Vec = + (0..n).map(|_| FieldSampling::sample(&mut rng)).collect(); + let inputs_32: Vec = + (0..n).map(|_| FieldSampling::sample(&mut rng)).collect(); + let inputs_40: Vec = + (0..n).map(|_| FieldSampling::sample(&mut rng)).collect(); + let inputs_64: Vec = + (0..n).map(|_| FieldSampling::sample(&mut rng)).collect(); + + let mut group = c.benchmark_group("fp32_fp64_mul"); + + macro_rules! 
chain_bench { + ($name:expr, $ty:ty, $inputs:expr) => { + group.bench_function(concat!($name, "_mul_chain_2048"), |b| { + b.iter(|| { + let mut acc = <$ty>::one(); + for x in $inputs.iter() { + acc *= *x; + } + black_box(acc) + }) + }); + group.bench_function(concat!($name, "_mul_add_chain_2048"), |b| { + b.iter(|| { + let mut acc = <$ty>::one(); + for x in $inputs.iter() { + acc = acc * *x + acc; + } + black_box(acc) + }) + }); + }; + } + + chain_bench!("fp32_2pow24m3", Pow2Offset24Field, inputs_24); + chain_bench!("fp32_2pow30m35", Pow2Offset30Field, inputs_30); + chain_bench!("fp32_2pow31m19", Pow2Offset31Field, inputs_31); + chain_bench!("fp32_2pow32m99", Pow2Offset32Field, inputs_32); + chain_bench!("fp64_2pow40m195", Pow2Offset40Field, inputs_40); + chain_bench!("fp64_2pow64m59", Pow2Offset64Field, inputs_64); + + group.finish(); +} + +fn bench_widening_ops(c: &mut Criterion) { + type F = Prime128M8M4M1M0; + + let mut rng = StdRng::seed_from_u64(0x01de_be0c_0001); + let a = F::from_canonical_u128_reduced(rand_u128(&mut rng)); + let b = F::from_canonical_u128_reduced(rand_u128(&mut rng)); + let b_u64 = rng.next_u64(); + + let mut group = c.benchmark_group("widening_ops"); + + group.bench_function("mul_wide_u64_only", |bench| { + bench.iter(|| black_box(black_box(a).mul_wide_u64(black_box(b_u64)))) + }); + + group.bench_function("mul_wide_only", |bench| { + bench.iter(|| black_box(black_box(a).mul_wide(black_box(b)))) + }); + + let limbs3 = [rng.next_u64(), rng.next_u64(), rng.next_u64()]; + let limbs4 = [ + rng.next_u64(), + rng.next_u64(), + rng.next_u64(), + rng.next_u64(), + ]; + + group.bench_function("mul_wide_limbs_3_to_5_only", |bench| { + bench.iter(|| black_box(black_box(a).mul_wide_limbs::<3, 5>(black_box(limbs3)))) + }); + group.bench_function("mul_wide_limbs_3_to_4_only", |bench| { + bench.iter(|| black_box(black_box(a).mul_wide_limbs::<3, 4>(black_box(limbs3)))) + }); + group.bench_function("mul_wide_limbs_4_to_5_only", |bench| { + bench.iter(|| 
black_box(black_box(a).mul_wide_limbs::<4, 5>(black_box(limbs4)))) + }); + group.bench_function("mul_wide_limbs_4_to_4_only", |bench| { + bench.iter(|| black_box(black_box(a).mul_wide_limbs::<4, 4>(black_box(limbs4)))) + }); + + group.bench_function("full_mul_u64_reduce", |bench| { + bench.iter(|| black_box(black_box(a) * F::from_u64(black_box(b_u64)))) + }); + + group.bench_function("full_mul_reduce", |bench| { + bench.iter(|| black_box(black_box(a) * black_box(b))) + }); + + let wide3 = a.mul_wide_u64(b_u64); + let wide4 = a.mul_wide(b); + let wide5 = { + let mut l = [0u64; 5]; + l[..3].copy_from_slice(&wide3); + l[4] = rng.next_u64() & 0xFF; + l + }; + + group.bench_function("solinas_reduce_3_limbs", |bench| { + bench.iter(|| black_box(F::solinas_reduce(black_box(&wide3)))) + }); + + group.bench_function("solinas_reduce_4_limbs", |bench| { + bench.iter(|| black_box(F::solinas_reduce(black_box(&wide4)))) + }); + + group.bench_function("solinas_reduce_5_limbs", |bench| { + bench.iter(|| black_box(F::solinas_reduce(black_box(&wide5)))) + }); + + group.bench_function("mul_wide_u64_roundtrip", |bench| { + bench.iter(|| { + let x = black_box(a); + let y = black_box(b_u64); + black_box(F::solinas_reduce(&x.mul_wide_u64(y))) + }) + }); + + group.bench_function("mul_wide_roundtrip", |bench| { + bench.iter(|| { + let x = black_box(a); + let y = black_box(b); + black_box(F::solinas_reduce(&x.mul_wide(y))) + }) + }); + + group.bench_function("mul_wide_limbs_3_to_5_roundtrip", |bench| { + bench.iter(|| { + let x = black_box(a); + let m = black_box(limbs3); + black_box(F::solinas_reduce(&x.mul_wide_limbs::<3, 5>(m))) + }) + }); + group.bench_function("mul_wide_limbs_3_to_4_roundtrip", |bench| { + bench.iter(|| { + let x = black_box(a); + let m = black_box(limbs3); + black_box(F::solinas_reduce(&x.mul_wide_limbs::<3, 4>(m))) + }) + }); + group.bench_function("mul_wide_limbs_4_to_5_roundtrip", |bench| { + bench.iter(|| { + let x = black_box(a); + let m = black_box(limbs4); + 
black_box(F::solinas_reduce(&x.mul_wide_limbs::<4, 5>(m))) + }) + }); + group.bench_function("mul_wide_limbs_4_to_4_roundtrip", |bench| { + bench.iter(|| { + let x = black_box(a); + let m = black_box(limbs4); + black_box(F::solinas_reduce(&x.mul_wide_limbs::<4, 4>(m))) + }) + }); + + group.finish(); +} + +fn bench_accumulator_pattern(c: &mut Criterion) { + type F = Prime128M8M4M1M0; + + let mut rng = StdRng::seed_from_u64(0xacc0_1a70_0002); + let inputs_a: Vec = (0..256) + .map(|_| F::from_canonical_u128_reduced(rand_u128(&mut rng))) + .collect(); + let inputs_b_u64: Vec = (0..256).map(|_| rng.next_u64()).collect(); + let inputs_b_f: Vec = (0..256) + .map(|_| F::from_canonical_u128_reduced(rand_u128(&mut rng))) + .collect(); + + let mut group = c.benchmark_group("accumulator_pattern"); + + for &n in &[16, 64, 256] { + group.bench_function(format!("eager_mul_u64_{n}"), |bench| { + bench.iter(|| { + let a_s = black_box(&inputs_a[..n]); + let b_s = black_box(&inputs_b_u64[..n]); + let mut acc = F::zero(); + for i in 0..n { + acc += a_s[i] * F::from_u64(b_s[i]); + } + black_box(acc) + }) + }); + + group.bench_function(format!("widening_accum_u64_{n}"), |bench| { + bench.iter(|| { + let a_s = black_box(&inputs_a[..n]); + let b_s = black_box(&inputs_b_u64[..n]); + let mut acc = [0u64; 5]; + for i in 0..n { + let wide = a_s[i].mul_wide_u64(b_s[i]); + let mut carry: u64 = 0; + for j in 0..3 { + let sum = acc[j] as u128 + wide[j] as u128 + carry as u128; + acc[j] = sum as u64; + carry = (sum >> 64) as u64; + } + for item in &mut acc[3..5] { + let sum = *item as u128 + carry as u128; + *item = sum as u64; + carry = (sum >> 64) as u64; + } + } + black_box(F::solinas_reduce(&acc)) + }) + }); + + group.bench_function(format!("eager_mul_full_{n}"), |bench| { + bench.iter(|| { + let a_s = black_box(&inputs_a[..n]); + let b_s = black_box(&inputs_b_f[..n]); + let mut acc = F::zero(); + for i in 0..n { + acc += a_s[i] * b_s[i]; + } + black_box(acc) + }) + }); + + 
group.bench_function(format!("widening_accum_full_{n}"), |bench| { + bench.iter(|| { + let a_s = black_box(&inputs_a[..n]); + let b_s = black_box(&inputs_b_f[..n]); + let mut acc = [0u64; 6]; + for i in 0..n { + let wide = a_s[i].mul_wide(b_s[i]); + let mut carry: u64 = 0; + for j in 0..4 { + let sum = acc[j] as u128 + wide[j] as u128 + carry as u128; + acc[j] = sum as u64; + carry = (sum >> 64) as u64; + } + for item in &mut acc[4..6] { + let sum = *item as u128 + carry as u128; + *item = sum as u64; + carry = (sum >> 64) as u64; + } + } + black_box(F::solinas_reduce(&acc)) + }) + }); + } + + group.finish(); +} + +fn bench_throughput(c: &mut Criterion) { + let n = 4096u64; + let mut rng = StdRng::seed_from_u64(0xdead_cafe); + + type M31 = Fp32<{ (1u32 << 31) - 1 }>; + + let a24: Vec = (0..n).map(|_| FieldSampling::sample(&mut rng)).collect(); + let b24: Vec = (0..n).map(|_| FieldSampling::sample(&mut rng)).collect(); + let a30: Vec = (0..n).map(|_| FieldSampling::sample(&mut rng)).collect(); + let b30: Vec = (0..n).map(|_| FieldSampling::sample(&mut rng)).collect(); + let a31: Vec = (0..n).map(|_| FieldSampling::sample(&mut rng)).collect(); + let b31: Vec = (0..n).map(|_| FieldSampling::sample(&mut rng)).collect(); + let am31: Vec = (0..n).map(|_| FieldSampling::sample(&mut rng)).collect(); + let bm31: Vec = (0..n).map(|_| FieldSampling::sample(&mut rng)).collect(); + let a32: Vec = (0..n).map(|_| FieldSampling::sample(&mut rng)).collect(); + let b32: Vec = (0..n).map(|_| FieldSampling::sample(&mut rng)).collect(); + let a40: Vec = (0..n).map(|_| FieldSampling::sample(&mut rng)).collect(); + let b40: Vec = (0..n).map(|_| FieldSampling::sample(&mut rng)).collect(); + let a48: Vec = (0..n).map(|_| FieldSampling::sample(&mut rng)).collect(); + let b48: Vec = (0..n).map(|_| FieldSampling::sample(&mut rng)).collect(); + let a56: Vec = (0..n).map(|_| FieldSampling::sample(&mut rng)).collect(); + let b56: Vec = (0..n).map(|_| FieldSampling::sample(&mut rng)).collect(); + 
let a64: Vec = (0..n).map(|_| FieldSampling::sample(&mut rng)).collect(); + let b64: Vec = (0..n).map(|_| FieldSampling::sample(&mut rng)).collect(); + let a128: Vec = (0..n) + .map(|_| Prime128M8M4M1M0::from_canonical_u128_reduced(rand_u128(&mut rng))) + .collect(); + let b128: Vec = (0..n) + .map(|_| Prime128M8M4M1M0::from_canonical_u128_reduced(rand_u128(&mut rng))) + .collect(); + + let mut out24 = vec![Pow2Offset24Field::zero(); n as usize]; + let mut out30 = vec![Pow2Offset30Field::zero(); n as usize]; + let mut out31 = vec![Pow2Offset31Field::zero(); n as usize]; + let mut outm31 = vec![M31::zero(); n as usize]; + let mut out32 = vec![Pow2Offset32Field::zero(); n as usize]; + let mut out40 = vec![Pow2Offset40Field::zero(); n as usize]; + let mut out48 = vec![Pow2Offset48Field::zero(); n as usize]; + let mut out56 = vec![Pow2Offset56Field::zero(); n as usize]; + let mut out64 = vec![Pow2Offset64Field::zero(); n as usize]; + let mut out128 = vec![Prime128M8M4M1M0::zero(); n as usize]; + + let mut group = c.benchmark_group("throughput"); + group.throughput(Throughput::Elements(n)); + + macro_rules! 
bench_op { + ($name:expr, $a:expr, $b:expr, $out:expr, $op:tt) => { + group.bench_function($name, |bench| { + bench.iter(|| { + let a = black_box(&$a); + let b = black_box(&$b); + let out = &mut $out; + for i in 0..n as usize { + out[i] = a[i] $op b[i]; + } + }) + }); + }; + } + + bench_op!("fp32_24b_mul", a24, b24, out24, *); + bench_op!("fp32_24b_add", a24, b24, out24, +); + bench_op!("fp32_30b_mul", a30, b30, out30, *); + bench_op!("fp32_30b_add", a30, b30, out30, +); + bench_op!("fp32_31b_mul", a31, b31, out31, *); + bench_op!("fp32_31b_add", a31, b31, out31, +); + bench_op!("fp32_m31_mul", am31, bm31, outm31, *); + bench_op!("fp32_m31_add", am31, bm31, outm31, +); + bench_op!("fp32_32b_mul", a32, b32, out32, *); + bench_op!("fp32_32b_add", a32, b32, out32, +); + bench_op!("fp64_40b_mul", a40, b40, out40, *); + bench_op!("fp64_40b_add", a40, b40, out40, +); + bench_op!("fp64_48b_mul", a48, b48, out48, *); + bench_op!("fp64_48b_add", a48, b48, out48, +); + bench_op!("fp64_56b_mul", a56, b56, out56, *); + bench_op!("fp64_56b_add", a56, b56, out56, +); + bench_op!("fp64_64b_mul", a64, b64, out64, *); + bench_op!("fp64_64b_add", a64, b64, out64, +); + bench_op!("fp128_mul", a128, b128, out128, *); + bench_op!("fp128_add", a128, b128, out128, +); + + group.finish(); +} + +fn bench_packed_throughput(c: &mut Criterion) { + use hachi_pcs::algebra::{Fp128Packing, Fp32Packing, Fp64Packing}; + + let n = 4096u64; + let mut rng = StdRng::seed_from_u64(0xbeef_cafe); + + macro_rules! 
packed_bench { + ($group:expr, $label:expr, $field:ty, $packing:ty, $rng:expr, $n:expr) => {{ + let lhs: Vec<$field> = (0..$n).map(|_| FieldSampling::sample($rng)).collect(); + let rhs: Vec<$field> = (0..$n).map(|_| FieldSampling::sample($rng)).collect(); + let lhs_p = <$packing>::pack_slice(&lhs); + let rhs_p = <$packing>::pack_slice(&rhs); + let mut out_p = vec![<$packing>::broadcast(<$field>::zero()); lhs_p.len()]; + + $group.bench_function(concat!($label, "_packed_mul"), |b| { + b.iter(|| { + let a = black_box(&lhs_p); + let b_v = black_box(&rhs_p); + let out = &mut out_p; + for i in 0..out.len() { + out[i] = a[i] * b_v[i]; + } + }) + }); + $group.bench_function(concat!($label, "_packed_add"), |b| { + b.iter(|| { + let a = black_box(&lhs_p); + let b_v = black_box(&rhs_p); + let out = &mut out_p; + for i in 0..out.len() { + out[i] = a[i] + b_v[i]; + } + }) + }); + $group.bench_function(concat!($label, "_packed_sub"), |b| { + b.iter(|| { + let a = black_box(&lhs_p); + let b_v = black_box(&rhs_p); + let out = &mut out_p; + for i in 0..out.len() { + out[i] = a[i] - b_v[i]; + } + }) + }); + }}; + } + + let mut group = c.benchmark_group("packed_throughput"); + group.throughput(Throughput::Elements(n)); + + use hachi_pcs::algebra::fields::pseudo_mersenne::*; + type M31 = Fp32<{ (1u32 << 31) - 1 }>; + + type P24 = Fp32Packing<{ POW2_OFFSET_MODULUS_24 }>; + type P30 = Fp32Packing<{ POW2_OFFSET_MODULUS_30 }>; + type P31 = Fp32Packing<{ POW2_OFFSET_MODULUS_31 }>; + type PM31 = Fp32Packing<{ (1u32 << 31) - 1 }>; + type P32 = Fp32Packing<{ POW2_OFFSET_MODULUS_32 }>; + type P40 = Fp64Packing<{ POW2_OFFSET_MODULUS_40 }>; + type P48 = Fp64Packing<{ POW2_OFFSET_MODULUS_48 }>; + type P56 = Fp64Packing<{ POW2_OFFSET_MODULUS_56 }>; + type P64 = Fp64Packing<{ POW2_OFFSET_MODULUS_64 }>; + type P128 = Fp128Packing<{ POW2_OFFSET_MODULUS_128 }>; + + packed_bench!(group, "fp32_24b", Pow2Offset24Field, P24, &mut rng, n); + packed_bench!(group, "fp32_30b", Pow2Offset30Field, P30, &mut 
rng, n); + packed_bench!(group, "fp32_31b", Pow2Offset31Field, P31, &mut rng, n); + packed_bench!(group, "fp32_m31", M31, PM31, &mut rng, n); + packed_bench!(group, "fp32_32b", Pow2Offset32Field, P32, &mut rng, n); + packed_bench!(group, "fp64_40b", Pow2Offset40Field, P40, &mut rng, n); + packed_bench!(group, "fp64_48b", Pow2Offset48Field, P48, &mut rng, n); + packed_bench!(group, "fp64_56b", Pow2Offset56Field, P56, &mut rng, n); + packed_bench!(group, "fp64_64b", Pow2Offset64Field, P64, &mut rng, n); + packed_bench!(group, "fp128", Prime128M8M4M1M0, P128, &mut rng, n); + + group.finish(); +} + +#[cfg(feature = "parallel")] +fn bench_parallel_throughput(c: &mut Criterion) { + use hachi_pcs::algebra::{Fp32Packing, Fp64Packing}; + + let profile = env::var("HACHI_BENCH_PAR_PROFILE").unwrap_or_else(|_| "dev".to_string()); + let default_n = match profile.as_str() { + "scale" | "large" => 1 << 20, + "xlarge" => 1 << 22, + _ => 1 << 15, + }; + let n = env_usize("HACHI_BENCH_PAR_N", default_n); + let default_chunk = match profile.as_str() { + "scale" | "large" => 1 << 14, + "xlarge" => 1 << 15, + _ => 1 << 12, + }; + let chunk = env_usize("HACHI_BENCH_PAR_CHUNK", default_chunk); + let threads = env_usize( + "HACHI_BENCH_PAR_THREADS", + thread::available_parallelism() + .map(|v| v.get()) + .unwrap_or(1), + ); + + assert!(threads > 0, "HACHI_BENCH_PAR_THREADS must be > 0"); + assert!(n > 0, "HACHI_BENCH_PAR_N must be > 0"); + assert!(chunk > 0, "HACHI_BENCH_PAR_CHUNK must be > 0"); + assert!(n % 4 == 0, "HACHI_BENCH_PAR_N must be divisible by 4"); + + let pool = ThreadPoolBuilder::new() + .num_threads(threads) + .build() + .expect("failed to build rayon pool"); + + let mut rng = StdRng::seed_from_u64(0xfeed_face); + + let lhs31: Vec = (0..n).map(|_| FieldSampling::sample(&mut rng)).collect(); + let rhs31: Vec = (0..n).map(|_| FieldSampling::sample(&mut rng)).collect(); + let lhs64: Vec = (0..n).map(|_| FieldSampling::sample(&mut rng)).collect(); + let rhs64: Vec = 
(0..n).map(|_| FieldSampling::sample(&mut rng)).collect(); + let lhs128: Vec = (0..n) + .map(|_| Prime128M13M4P0::from_canonical_u128_reduced(rand_u128(&mut rng))) + .collect(); + let rhs128: Vec = (0..n) + .map(|_| Prime128M13M4P0::from_canonical_u128_reduced(rand_u128(&mut rng))) + .collect(); + + type P31 = Fp32Packing<{ hachi_pcs::algebra::fields::pseudo_mersenne::POW2_OFFSET_MODULUS_31 }>; + type P64 = Fp64Packing<{ hachi_pcs::algebra::fields::pseudo_mersenne::POW2_OFFSET_MODULUS_64 }>; + type F128 = Prime128M13M4P0; + type P128 = ::Packing; + let chunk31_p = (chunk / P31::WIDTH).max(1); + let chunk64_p = (chunk / P64::WIDTH).max(1); + let chunk128_p = (chunk / P128::WIDTH).max(1); + + let lhs31_p = P31::pack_slice(&lhs31); + let rhs31_p = P31::pack_slice(&rhs31); + let lhs64_p = P64::pack_slice(&lhs64); + let rhs64_p = P64::pack_slice(&rhs64); + let lhs128_p = P128::pack_slice(&lhs128); + let rhs128_p = P128::pack_slice(&rhs128); + + let mut out31 = vec![Pow2Offset31Field::zero(); n]; + let mut out64 = vec![Pow2Offset64Field::zero(); n]; + let mut out128 = vec![F128::zero(); n]; + let mut out31_p = vec![P31::broadcast(Pow2Offset31Field::zero()); lhs31_p.len()]; + let mut out64_p = vec![P64::broadcast(Pow2Offset64Field::zero()); lhs64_p.len()]; + let mut out128_p = vec![P128::broadcast(F128::zero()); lhs128_p.len()]; + + let mut group = c.benchmark_group(format!( + "parallel_throughput/{profile}/t{threads}/n{n}/c{chunk}" + )); + group.throughput(Throughput::Elements(n as u64)); + + group.bench_function("fp32_31b_mul_seq", |b| { + b.iter(|| { + let a = black_box(&lhs31); + let b_v = black_box(&rhs31); + let out = &mut out31; + for i in 0..out.len() { + out[i] = a[i] * b_v[i]; + } + black_box(out[0]) + }) + }); + + group.bench_function("fp32_31b_mul_par_zip", |b| { + b.iter(|| { + let a = black_box(&lhs31); + let b_v = black_box(&rhs31); + let out = &mut out31; + pool.install(|| { + out.par_iter_mut() + .zip(a.par_iter()) + .zip(b_v.par_iter()) + 
.for_each(|((dst, lhs), rhs)| *dst = *lhs * *rhs); + }); + black_box(out[0]) + }) + }); + + group.bench_function("fp32_31b_mul_par_chunked", |b| { + b.iter(|| { + let a = black_box(&lhs31); + let b_v = black_box(&rhs31); + let out = &mut out31; + pool.install(|| { + out.par_chunks_mut(chunk) + .zip(a.par_chunks(chunk)) + .zip(b_v.par_chunks(chunk)) + .for_each(|((dst, lhs), rhs)| { + for i in 0..dst.len() { + dst[i] = lhs[i] * rhs[i]; + } + }); + }); + black_box(out[0]) + }) + }); + + group.bench_function("fp32_31b_packed_mul_seq", |b| { + b.iter(|| { + let a = black_box(&lhs31_p); + let b_v = black_box(&rhs31_p); + let out = &mut out31_p; + for i in 0..out.len() { + out[i] = a[i] * b_v[i]; + } + black_box(out[0].extract(0)) + }) + }); + + group.bench_function("fp32_31b_packed_mul_par_zip", |b| { + b.iter(|| { + let a = black_box(&lhs31_p); + let b_v = black_box(&rhs31_p); + let out = &mut out31_p; + pool.install(|| { + out.par_iter_mut() + .zip(a.par_iter()) + .zip(b_v.par_iter()) + .for_each(|((dst, lhs), rhs)| *dst = *lhs * *rhs); + }); + black_box(out[0].extract(0)) + }) + }); + + group.bench_function("fp32_31b_packed_mul_par_chunked", |b| { + b.iter(|| { + let a = black_box(&lhs31_p); + let b_v = black_box(&rhs31_p); + let out = &mut out31_p; + pool.install(|| { + out.par_chunks_mut(chunk31_p) + .zip(a.par_chunks(chunk31_p)) + .zip(b_v.par_chunks(chunk31_p)) + .for_each(|((dst, lhs), rhs)| { + for i in 0..dst.len() { + dst[i] = lhs[i] * rhs[i]; + } + }); + }); + black_box(out[0].extract(0)) + }) + }); + + group.bench_function("fp64_64b_mul_seq", |b| { + b.iter(|| { + let a = black_box(&lhs64); + let b_v = black_box(&rhs64); + let out = &mut out64; + for i in 0..out.len() { + out[i] = a[i] * b_v[i]; + } + black_box(out[0]) + }) + }); + + group.bench_function("fp64_64b_mul_par_zip", |b| { + b.iter(|| { + let a = black_box(&lhs64); + let b_v = black_box(&rhs64); + let out = &mut out64; + pool.install(|| { + out.par_iter_mut() + .zip(a.par_iter()) + 
.zip(b_v.par_iter()) + .for_each(|((dst, lhs), rhs)| *dst = *lhs * *rhs); + }); + black_box(out[0]) + }) + }); + + group.bench_function("fp64_64b_mul_par_chunked", |b| { + b.iter(|| { + let a = black_box(&lhs64); + let b_v = black_box(&rhs64); + let out = &mut out64; + pool.install(|| { + out.par_chunks_mut(chunk) + .zip(a.par_chunks(chunk)) + .zip(b_v.par_chunks(chunk)) + .for_each(|((dst, lhs), rhs)| { + for i in 0..dst.len() { + dst[i] = lhs[i] * rhs[i]; + } + }); + }); + black_box(out[0]) + }) + }); + + group.bench_function("fp64_64b_packed_mul_seq", |b| { + b.iter(|| { + let a = black_box(&lhs64_p); + let b_v = black_box(&rhs64_p); + let out = &mut out64_p; + for i in 0..out.len() { + out[i] = a[i] * b_v[i]; + } + black_box(out[0].extract(0)) + }) + }); + + group.bench_function("fp64_64b_packed_mul_par_zip", |b| { + b.iter(|| { + let a = black_box(&lhs64_p); + let b_v = black_box(&rhs64_p); + let out = &mut out64_p; + pool.install(|| { + out.par_iter_mut() + .zip(a.par_iter()) + .zip(b_v.par_iter()) + .for_each(|((dst, lhs), rhs)| *dst = *lhs * *rhs); + }); + black_box(out[0].extract(0)) + }) + }); + + group.bench_function("fp64_64b_packed_mul_par_chunked", |b| { + b.iter(|| { + let a = black_box(&lhs64_p); + let b_v = black_box(&rhs64_p); + let out = &mut out64_p; + pool.install(|| { + out.par_chunks_mut(chunk64_p) + .zip(a.par_chunks(chunk64_p)) + .zip(b_v.par_chunks(chunk64_p)) + .for_each(|((dst, lhs), rhs)| { + for i in 0..dst.len() { + dst[i] = lhs[i] * rhs[i]; + } + }); + }); + black_box(out[0].extract(0)) + }) + }); + + group.bench_function("fp128_mul_seq", |b| { + b.iter(|| { + let a = black_box(&lhs128); + let b_v = black_box(&rhs128); + let out = &mut out128; + for i in 0..out.len() { + out[i] = a[i] * b_v[i]; + } + black_box(out[0]) + }) + }); + + group.bench_function("fp128_mul_par_chunked", |b| { + b.iter(|| { + let a = black_box(&lhs128); + let b_v = black_box(&rhs128); + let out = &mut out128; + pool.install(|| { + out.par_chunks_mut(chunk) + 
.zip(a.par_chunks(chunk)) + .zip(b_v.par_chunks(chunk)) + .for_each(|((dst, lhs), rhs)| { + for i in 0..dst.len() { + dst[i] = lhs[i] * rhs[i]; + } + }); + }); + black_box(out[0]) + }) + }); + + group.bench_function("fp128_packed_mul_seq", |b| { + b.iter(|| { + let a = black_box(&lhs128_p); + let b_v = black_box(&rhs128_p); + let out = &mut out128_p; + for i in 0..out.len() { + out[i] = a[i] * b_v[i]; + } + black_box(out[0].extract(0)) + }) + }); + + group.bench_function("fp128_packed_mul_par_chunked", |b| { + b.iter(|| { + let a = black_box(&lhs128_p); + let b_v = black_box(&rhs128_p); + let out = &mut out128_p; + pool.install(|| { + out.par_chunks_mut(chunk128_p) + .zip(a.par_chunks(chunk128_p)) + .zip(b_v.par_chunks(chunk128_p)) + .for_each(|((dst, lhs), rhs)| { + for i in 0..dst.len() { + dst[i] = lhs[i] * rhs[i]; + } + }); + }); + black_box(out[0].extract(0)) + }) + }); + + group.finish(); +} + +#[cfg(not(feature = "parallel"))] +fn bench_parallel_throughput(_: &mut Criterion) {} + +fn bench_packed_sumcheck_mix(c: &mut Criterion) { + use hachi_pcs::algebra::{Fp128Packing, Fp32Packing, Fp64Packing}; + + let n = 4096u64; + let mut rng = StdRng::seed_from_u64(0xface_bead); + + macro_rules! 
sumcheck_bench { + ($group:expr, $label:expr, $field:ty, $packing:ty, $rng:expr, $n:expr) => {{ + let eq: Vec<$field> = (0..$n).map(|_| FieldSampling::sample($rng)).collect(); + let poly: Vec<$field> = (0..$n).map(|_| FieldSampling::sample($rng)).collect(); + let eq_p = <$packing>::pack_slice(&eq); + let poly_p = <$packing>::pack_slice(&poly); + let mut acc = <$packing>::broadcast(<$field>::zero()); + + $group.bench_function(concat!($label, "_packed_macc"), |b| { + b.iter(|| { + let e = black_box(&eq_p); + let p_v = black_box(&poly_p); + acc = <$packing>::broadcast(<$field>::zero()); + for i in 0..e.len() { + acc += e[i] * p_v[i]; + } + black_box(acc) + }) + }); + }}; + } + + let mut group = c.benchmark_group("packed_sumcheck_mix"); + group.throughput(Throughput::Elements(n)); + + use hachi_pcs::algebra::fields::pseudo_mersenne::*; + type M31 = Fp32<{ (1u32 << 31) - 1 }>; + + type P24 = Fp32Packing<{ POW2_OFFSET_MODULUS_24 }>; + type P30 = Fp32Packing<{ POW2_OFFSET_MODULUS_30 }>; + type P31 = Fp32Packing<{ POW2_OFFSET_MODULUS_31 }>; + type PM31 = Fp32Packing<{ (1u32 << 31) - 1 }>; + type P32 = Fp32Packing<{ POW2_OFFSET_MODULUS_32 }>; + type P40 = Fp64Packing<{ POW2_OFFSET_MODULUS_40 }>; + type P48 = Fp64Packing<{ POW2_OFFSET_MODULUS_48 }>; + type P56 = Fp64Packing<{ POW2_OFFSET_MODULUS_56 }>; + type P64 = Fp64Packing<{ POW2_OFFSET_MODULUS_64 }>; + type P128 = Fp128Packing<{ POW2_OFFSET_MODULUS_128 }>; + + sumcheck_bench!(group, "fp32_24b", Pow2Offset24Field, P24, &mut rng, n); + sumcheck_bench!(group, "fp32_30b", Pow2Offset30Field, P30, &mut rng, n); + sumcheck_bench!(group, "fp32_31b", Pow2Offset31Field, P31, &mut rng, n); + sumcheck_bench!(group, "fp32_m31", M31, PM31, &mut rng, n); + sumcheck_bench!(group, "fp32_32b", Pow2Offset32Field, P32, &mut rng, n); + sumcheck_bench!(group, "fp64_40b", Pow2Offset40Field, P40, &mut rng, n); + sumcheck_bench!(group, "fp64_48b", Pow2Offset48Field, P48, &mut rng, n); + sumcheck_bench!(group, "fp64_56b", Pow2Offset56Field, P56, 
&mut rng, n); + sumcheck_bench!(group, "fp64_64b", Pow2Offset64Field, P64, &mut rng, n); + sumcheck_bench!(group, "fp128", Prime128M8M4M1M0, P128, &mut rng, n); + + group.finish(); +} + +criterion_group!( + field_arith, + bench_mul, + bench_mul_only, + bench_mul_isolated, + bench_sqr, + bench_inv, + bench_packed_fp128_backend, + bench_bn254, + bench_fp32_fp64_mul, + bench_widening_ops, + bench_accumulator_pattern, + bench_throughput, + bench_packed_throughput, + bench_packed_sumcheck_mix, + bench_parallel_throughput +); +criterion_main!(field_arith); diff --git a/benches/fp64_reduce_probe.rs b/benches/fp64_reduce_probe.rs new file mode 100644 index 00000000..073544a9 --- /dev/null +++ b/benches/fp64_reduce_probe.rs @@ -0,0 +1,118 @@ +#![allow(missing_docs)] + +use criterion::{black_box, criterion_group, criterion_main, Criterion, Throughput}; + +const P40: u64 = hachi_pcs::algebra::fields::pseudo_mersenne::POW2_OFFSET_MODULUS_40; +const P64: u64 = hachi_pcs::algebra::fields::pseudo_mersenne::POW2_OFFSET_MODULUS_64; +const C40: u64 = (1u64 << 40) - P40; // 195 +const C64: u64 = 0u64.wrapping_sub(P64); // 59 +const MASK40: u64 = (1u64 << 40) - 1; +const MASK64_U128: u128 = u64::MAX as u128; + +#[inline(always)] +fn mul_c40_split(x: u64) -> u64 { + let c = C40 as u32; + let x_lo = x as u32; + let x_hi = (x >> 32) as u32; + (c as u64 * x_lo as u64).wrapping_add((c as u64 * x_hi as u64) << 32) +} + +#[inline(always)] +fn reduce40_split(lo: u64, hi: u64) -> u64 { + let high = (lo >> 40) | (hi << 24); + let f1 = (lo & MASK40).wrapping_add(mul_c40_split(high)); + let f2 = (f1 & MASK40).wrapping_add(mul_c40_split(f1 >> 40)); + let reduced = f2.wrapping_sub(P40); + let borrow = reduced >> 63; + reduced.wrapping_add(borrow.wrapping_neg() & P40) +} + +#[inline(always)] +fn reduce40_direct(lo: u64, hi: u64) -> u64 { + let high = (lo >> 40) | (hi << 24); + let f1 = (lo & MASK40).wrapping_add(C40.wrapping_mul(high)); + let f2 = (f1 & MASK40).wrapping_add(C40.wrapping_mul(f1 >> 
40)); + let reduced = f2.wrapping_sub(P40); + let borrow = reduced >> 63; + reduced.wrapping_add(borrow.wrapping_neg() & P40) +} + +#[inline(always)] +fn reduce64(lo: u64, hi: u64) -> u64 { + let f1 = (lo as u128) + (C64 as u128) * (hi as u128); + let f2 = (f1 & MASK64_U128) + (C64 as u128) * ((f1 >> 64) as u64 as u128); + let reduced = f2.wrapping_sub(P64 as u128); + let borrow = reduced >> 127; + reduced.wrapping_add(borrow.wrapping_neg() & (P64 as u128)) as u64 +} + +#[inline(always)] +fn next_u64(state: &mut u64) -> u64 { + let mut x = *state; + x ^= x << 13; + x ^= x >> 7; + x ^= x << 17; + *state = x; + x +} + +fn bench_fp64_reduce_probe(c: &mut Criterion) { + let n = 1 << 13; + let mut seed = 0x9e37_79b9_7f4a_7c15u64; + + let mut pairs40 = Vec::with_capacity(n); + let mut pairs64 = Vec::with_capacity(n); + for _ in 0..n { + let a40 = next_u64(&mut seed) % P40; + let b40 = next_u64(&mut seed) % P40; + let x40 = (a40 as u128) * (b40 as u128); + pairs40.push((x40 as u64, (x40 >> 64) as u64)); + + let a64 = next_u64(&mut seed); + let b64 = next_u64(&mut seed); + let x64 = (a64 as u128) * (b64 as u128); + pairs64.push((x64 as u64, (x64 >> 64) as u64)); + } + + for &(lo, hi) in &pairs40 { + assert_eq!(reduce40_split(lo, hi), reduce40_direct(lo, hi)); + } + + let mut group = c.benchmark_group("fp64_reduce_probe"); + group.throughput(Throughput::Elements(n as u64)); + + group.bench_function("reduce40_split", |b| { + b.iter(|| { + let mut acc = 0u64; + for &(lo, hi) in black_box(&pairs40) { + acc ^= reduce40_split(lo, hi); + } + black_box(acc) + }) + }); + + group.bench_function("reduce40_direct", |b| { + b.iter(|| { + let mut acc = 0u64; + for &(lo, hi) in black_box(&pairs40) { + acc ^= reduce40_direct(lo, hi); + } + black_box(acc) + }) + }); + + group.bench_function("reduce64", |b| { + b.iter(|| { + let mut acc = 0u64; + for &(lo, hi) in black_box(&pairs64) { + acc ^= reduce64(lo, hi); + } + black_box(acc) + }) + }); + + group.finish(); +} + 
+criterion_group!(fp64_reduce_probe, bench_fp64_reduce_probe); +criterion_main!(fp64_reduce_probe); diff --git a/benches/hachi_e2e.rs b/benches/hachi_e2e.rs new file mode 100644 index 00000000..601d21c3 --- /dev/null +++ b/benches/hachi_e2e.rs @@ -0,0 +1,419 @@ +#![allow(missing_docs)] + +use criterion::measurement::WallTime; +use criterion::{black_box, criterion_group, BatchSize, BenchmarkGroup, Criterion}; +use hachi_pcs::algebra::poly::multilinear_eval; +use hachi_pcs::algebra::Fp128; +use hachi_pcs::protocol::commitment::{ + Fp128FullCommitmentConfig, Fp128LogBasisCommitmentConfig, Fp128OneHotCommitmentConfig, +}; +use hachi_pcs::protocol::commitment_scheme::HachiCommitmentScheme; +use hachi_pcs::protocol::hachi_poly_ops::{DensePoly, OneHotPoly}; +use hachi_pcs::protocol::transcript::Blake2bTranscript; +use hachi_pcs::protocol::CommitmentConfig; +use hachi_pcs::{BasisMode, CanonicalField, CommitmentScheme, FromSmallInt, Transcript}; +use rand::rngs::StdRng; +use rand::{Rng, SeedableRng}; +use std::time::Duration; + +type F = Fp128<0xfffffffffffffffffffffffffffffeed>; + +fn make_dense_evals(nv: usize) -> Vec { + let mut rng = StdRng::seed_from_u64(0xdead_beef); + let len = 1usize << nv; + let decomp = Cfg::decomposition(); + if decomp.log_commit_bound >= 128 { + (0..len) + .map(|_| F::from_canonical_u128_reduced(rng.gen::())) + .collect() + } else { + let half_bound = 1i64 << (decomp.log_commit_bound.min(62) - 1); + (0..len) + .map(|_| F::from_i64(rng.gen_range(-half_bound..half_bound))) + .collect() + } +} + +fn random_point(nv: usize) -> Vec { + let mut rng = StdRng::seed_from_u64(0xcafe_babe); + (0..nv) + .map(|_| F::from_canonical_u128_reduced(rng.gen::())) + .collect() +} + +fn configure_group(group: &mut BenchmarkGroup<'_, WallTime>, nv: usize) { + if nv >= 20 { + group.sample_size(10); + group.measurement_time(Duration::from_secs(30)); + } +} + +fn bench_dense_phases( + c: &mut Criterion, + label: &str, + nv: usize, +) { + let layout = 
Cfg::commitment_layout(nv).expect("benchmark layout"); + let evals = make_dense_evals::(nv); + let poly = DensePoly::::from_field_evals(nv, &evals).unwrap(); + let pt = random_point(nv); + let opening = multilinear_eval(&evals, &pt).unwrap(); + + let mut group = c.benchmark_group(format!("hachi/{label}/nv{nv}")); + configure_group(&mut group, nv); + + group.bench_function("setup", |b| { + b.iter(|| { + black_box( + as CommitmentScheme>::setup_prover(black_box( + nv, + )), + ) + }) + }); + + let setup = as CommitmentScheme>::setup_prover(nv); + + group.bench_function("commit", |b| { + b.iter(|| { + black_box( + as CommitmentScheme>::commit( + black_box(&poly), + black_box(&setup), + black_box(&layout), + ) + .unwrap(), + ) + }) + }); + + let (commitment, hint) = + as CommitmentScheme>::commit(&poly, &setup, &layout) + .unwrap(); + + group.bench_function("prove", |b| { + b.iter_batched( + || hint.clone(), + |h| { + let mut transcript = Blake2bTranscript::::new(b"bench"); + black_box( + as CommitmentScheme>::prove( + &setup, + &poly, + &pt, + h, + &mut transcript, + &commitment, + BasisMode::Lagrange, + &layout, + ) + .unwrap(), + ) + }, + BatchSize::LargeInput, + ) + }); + + let verifier_setup = + as CommitmentScheme>::setup_verifier(&setup); + let mut prover_transcript = Blake2bTranscript::::new(b"bench"); + let proof = as CommitmentScheme>::prove( + &setup, + &poly, + &pt, + hint, + &mut prover_transcript, + &commitment, + BasisMode::Lagrange, + &layout, + ) + .unwrap(); + + group.bench_function("verify", |b| { + b.iter(|| { + let mut transcript = Blake2bTranscript::::new(b"bench"); + as CommitmentScheme>::verify( + black_box(&proof), + black_box(&verifier_setup), + &mut transcript, + black_box(&pt), + black_box(&opening), + black_box(&commitment), + BasisMode::Lagrange, + black_box(&layout), + ) + .unwrap(); + }) + }); + + group.bench_function("e2e", |b| { + b.iter(|| { + let (cm, h) = as CommitmentScheme>::commit( + &poly, &setup, &layout, + ) + .unwrap(); + let 
mut pt_tr = Blake2bTranscript::::new(b"bench"); + let pf = as CommitmentScheme>::prove( + &setup, + &poly, + &pt, + h, + &mut pt_tr, + &cm, + BasisMode::Lagrange, + &layout, + ) + .unwrap(); + let mut vt_tr = Blake2bTranscript::::new(b"bench"); + as CommitmentScheme>::verify( + &pf, + &verifier_setup, + &mut vt_tr, + &pt, + &opening, + &cm, + BasisMode::Lagrange, + &layout, + ) + .unwrap(); + black_box(()) + }) + }); + + group.finish(); +} + +fn bench_onehot_phases( + c: &mut Criterion, + label: &str, + nv: usize, +) { + let layout = Cfg::commitment_layout(nv).expect("benchmark layout"); + let total_ring = layout.num_blocks * layout.block_len; + let onehot_k = D; + + let mut rng = StdRng::seed_from_u64(0xbeef_cafe); + let indices: Vec> = (0..total_ring) + .map(|_| Some(rng.gen_range(0..onehot_k))) + .collect(); + + let onehot_poly = + OneHotPoly::::new(onehot_k, indices.clone(), layout.r_vars, layout.m_vars).unwrap(); + + let dense_evals: Vec = { + let mut evals = vec![F::from_u64(0); total_ring * onehot_k]; + for (ci, opt_idx) in indices.iter().enumerate() { + if let Some(idx) = opt_idx { + evals[ci * onehot_k + idx] = F::from_u64(1); + } + } + evals + }; + let pt = random_point(nv); + let opening = multilinear_eval(&dense_evals, &pt).unwrap(); + + let setup = as CommitmentScheme>::setup_prover(nv); + + let mut group = c.benchmark_group(format!("hachi/{label}/nv{nv}")); + configure_group(&mut group, nv); + + group.bench_function("commit_onehot", |b| { + b.iter(|| { + black_box( + as CommitmentScheme>::commit( + black_box(&onehot_poly), + black_box(&setup), + black_box(&layout), + ) + .unwrap(), + ) + }) + }); + + let (commitment, hint) = as CommitmentScheme>::commit( + &onehot_poly, + &setup, + &layout, + ) + .unwrap(); + + group.bench_function("prove", |b| { + b.iter_batched( + || hint.clone(), + |h| { + let mut transcript = Blake2bTranscript::::new(b"bench"); + black_box( + as CommitmentScheme>::prove( + &setup, + &onehot_poly, + &pt, + h, + &mut transcript, + 
&commitment, + BasisMode::Lagrange, + &layout, + ) + .unwrap(), + ) + }, + BatchSize::LargeInput, + ) + }); + + let verifier_setup = + as CommitmentScheme>::setup_verifier(&setup); + let mut prover_transcript = Blake2bTranscript::::new(b"bench"); + let proof = as CommitmentScheme>::prove( + &setup, + &onehot_poly, + &pt, + hint, + &mut prover_transcript, + &commitment, + BasisMode::Lagrange, + &layout, + ) + .unwrap(); + + group.bench_function("verify", |b| { + b.iter(|| { + let mut transcript = Blake2bTranscript::::new(b"bench"); + as CommitmentScheme>::verify( + black_box(&proof), + black_box(&verifier_setup), + &mut transcript, + black_box(&pt), + black_box(&opening), + black_box(&commitment), + BasisMode::Lagrange, + black_box(&layout), + ) + .unwrap(); + }) + }); + + group.bench_function("e2e", |b| { + b.iter(|| { + let (cm, h) = as CommitmentScheme>::commit( + &onehot_poly, + &setup, + &layout, + ) + .unwrap(); + let mut pt_tr = Blake2bTranscript::::new(b"bench"); + let pf = as CommitmentScheme>::prove( + &setup, + &onehot_poly, + &pt, + h, + &mut pt_tr, + &cm, + BasisMode::Lagrange, + &layout, + ) + .unwrap(); + let mut vt_tr = Blake2bTranscript::::new(b"bench"); + as CommitmentScheme>::verify( + &pf, + &verifier_setup, + &mut vt_tr, + &pt, + &opening, + &cm, + BasisMode::Lagrange, + &layout, + ) + .unwrap(); + black_box(()) + }) + }); + + group.finish(); +} + +fn bench_full_nv15(c: &mut Criterion) { + bench_dense_phases::<{ Fp128FullCommitmentConfig::D }, Fp128FullCommitmentConfig>( + c, "full", 15, + ); +} +fn bench_full_nv20(c: &mut Criterion) { + bench_dense_phases::<{ Fp128FullCommitmentConfig::D }, Fp128FullCommitmentConfig>( + c, "full", 20, + ); +} +fn bench_full_nv25(c: &mut Criterion) { + bench_dense_phases::<{ Fp128FullCommitmentConfig::D }, Fp128FullCommitmentConfig>( + c, "full", 25, + ); +} + +fn bench_onehot_nv15(c: &mut Criterion) { + bench_onehot_phases::<{ Fp128OneHotCommitmentConfig::D }, Fp128OneHotCommitmentConfig>( + c, "onehot", 15, + 
); +} +fn bench_onehot_nv20(c: &mut Criterion) { + bench_onehot_phases::<{ Fp128OneHotCommitmentConfig::D }, Fp128OneHotCommitmentConfig>( + c, "onehot", 20, + ); +} +fn bench_onehot_nv25(c: &mut Criterion) { + bench_onehot_phases::<{ Fp128OneHotCommitmentConfig::D }, Fp128OneHotCommitmentConfig>( + c, "onehot", 25, + ); +} + +fn bench_logbasis_nv15(c: &mut Criterion) { + bench_dense_phases::<{ Fp128LogBasisCommitmentConfig::D }, Fp128LogBasisCommitmentConfig>( + c, "logbasis", 15, + ); +} +fn bench_logbasis_nv20(c: &mut Criterion) { + bench_dense_phases::<{ Fp128LogBasisCommitmentConfig::D }, Fp128LogBasisCommitmentConfig>( + c, "logbasis", 20, + ); +} +fn bench_logbasis_nv25(c: &mut Criterion) { + bench_dense_phases::<{ Fp128LogBasisCommitmentConfig::D }, Fp128LogBasisCommitmentConfig>( + c, "logbasis", 25, + ); +} + +criterion_group!( + hachi_benches, + bench_full_nv15, + bench_full_nv20, + bench_full_nv25, + bench_onehot_nv15, + bench_onehot_nv20, + bench_onehot_nv25, + bench_logbasis_nv15, + bench_logbasis_nv20, + bench_logbasis_nv25, +); + +/// Set `HACHI_PARALLEL=0` to run benchmarks single-threaded. 
+fn main() { + #[cfg(feature = "parallel")] + { + let num_threads = if std::env::var("HACHI_PARALLEL") + .map(|v| v == "0") + .unwrap_or(false) + { + tracing::info!("HACHI_PARALLEL=0: running single-threaded"); + 1 + } else { + 0 + }; + rayon::ThreadPoolBuilder::new() + .num_threads(num_threads) + .stack_size(64 * 1024 * 1024) + .build_global() + .ok(); + } + + hachi_benches(); + criterion::Criterion::default() + .configure_from_args() + .final_summary(); +} diff --git a/benches/labrador_jl_aggregation.rs b/benches/labrador_jl_aggregation.rs new file mode 100644 index 00000000..fb041e88 --- /dev/null +++ b/benches/labrador_jl_aggregation.rs @@ -0,0 +1,64 @@ +#![allow(missing_docs)] + +use criterion::{black_box, criterion_group, criterion_main, Criterion}; +use hachi_pcs::algebra::fields::Prime128M13M4P0; +use hachi_pcs::algebra::{Pow2Offset32Field, Pow2Offset64Field}; +use hachi_pcs::protocol::labrador::aggregation::aggregate_jl_contraints_one_lift; +use hachi_pcs::protocol::labrador::LabradorJlMatrix; +use hachi_pcs::protocol::transcript::{labels, Blake2bTranscript}; +use hachi_pcs::{CanonicalField, FieldCore, Transcript}; + +const D: usize = 64; +// Observed in realistic profile runs for NV=25 full mode. 
+const BENCH_COLS: usize = 4_128_768; + +fn sample_omega_from_transcript(transcript: &mut T) -> [F; 256] +where + F: FieldCore + CanonicalField, + T: Transcript, +{ + std::array::from_fn(|_| transcript.challenge_scalar(labels::CHALLENGE_LABRADOR_JL_COLLAPSE)) +} + +fn bench_aggregate_jl_contraints_one_lift_for_field( + c: &mut Criterion, + field_name: &str, +) { + let cols = BENCH_COLS; + let mut transcript = Blake2bTranscript::::new(b"bench/labrador-jl-aggregation"); + let matrix = LabradorJlMatrix::generate::(&mut transcript, cols).unwrap(); + let omega = sample_omega_from_transcript::(&mut transcript); + c.bench_function( + &format!("labrador/aggregate_jl_contraints_one_lift/{field_name}"), + |b| { + b.iter(|| { + let got = + aggregate_jl_contraints_one_lift::(black_box(&matrix), black_box(&omega)) + .unwrap(); + black_box(got); + }) + }, + ); +} + +fn bench_aggregate_jl_contraints_one_lift_fp32(c: &mut Criterion) { + bench_aggregate_jl_contraints_one_lift_for_field::(c, "fp32"); +} + +fn bench_aggregate_jl_contraints_one_lift_fp64(c: &mut Criterion) { + type F64 = Pow2Offset64Field; + bench_aggregate_jl_contraints_one_lift_for_field::(c, "fp64"); +} + +fn bench_aggregate_jl_contraints_one_lift_fp128(c: &mut Criterion) { + type F128 = Prime128M13M4P0; + bench_aggregate_jl_contraints_one_lift_for_field::(c, "fp128"); +} + +criterion_group!( + labrador_jl_aggregation, + bench_aggregate_jl_contraints_one_lift_fp32, + bench_aggregate_jl_contraints_one_lift_fp64, + bench_aggregate_jl_contraints_one_lift_fp128 +); +criterion_main!(labrador_jl_aggregation); diff --git a/benches/ring_ntt.rs b/benches/ring_ntt.rs new file mode 100644 index 00000000..891b3f05 --- /dev/null +++ b/benches/ring_ntt.rs @@ -0,0 +1,699 @@ +#![allow(missing_docs)] + +use criterion::{black_box, criterion_group, criterion_main, BenchmarkId, Criterion}; +use hachi_pcs::algebra::ntt::butterfly::{forward_ntt, inverse_ntt, NttTwiddles}; +use hachi_pcs::algebra::tables::{ + q128_primes, q32_garner, 
Q128_NUM_PRIMES, Q32_MODULUS, Q32_NUM_PRIMES, Q32_PRIMES, +}; +use hachi_pcs::algebra::{ + CrtNttParamSet, CyclotomicCrtNtt, CyclotomicRing, Fp64, HasPacking, MontCoeff, + PackedPartialSplitEval32, PartialSplitEval32, PartialSplitNtt32, Prime128Offset5823, +}; +use hachi_pcs::{FieldCore, FromSmallInt}; + +type F = Fp64<{ Q32_MODULUS }>; +type R = CyclotomicRing; +type N = CyclotomicCrtNtt; +type F128 = Prime128Offset5823; +type R128 = CyclotomicRing; +type N128 = CyclotomicCrtNtt; +type PF128 = ::Packing; +const CACHE_MAT_ROWS: usize = 8; +const CACHE_MAT_COLS: usize = 16; +const MUL_BATCH_FACTORS: [usize; 3] = [1, 4, 16]; + +fn sample_ring(seed: u64) -> R { + let coeffs = std::array::from_fn(|i| { + let x = seed + .wrapping_mul(31) + .wrapping_add((i as u64).wrapping_mul(17)); + F::from_u64(x % Q32_MODULUS) + }); + R::from_coefficients(coeffs) +} + +fn sample_ring_q128m5823(seed: u64) -> R128 { + let coeffs = std::array::from_fn(|i| { + let x = seed + .wrapping_mul(29) + .wrapping_add((i as u64).wrapping_mul(13)); + let centered = (x % 257) as i64 - 128; + F128::from_i64(centered) + }); + R128::from_coefficients(coeffs) +} + +fn sample_centered_i8(seed: u64) -> [i8; 64] { + std::array::from_fn(|i| { + let x = seed + .wrapping_mul(43) + .wrapping_add((i as u64).wrapping_mul(17)); + ((x % 256) as i16 - 128) as i8 + }) +} + +fn sample_ring_q128m5823_tag(seed: u64, tag: u64) -> R128 { + sample_ring_q128m5823(seed.wrapping_mul(131).wrapping_add(tag)) +} + +fn pack_split_batch(batch: &[PartialSplitEval32]) -> Vec> { + let width = PackedPartialSplitEval32::::WIDTH; + debug_assert_eq!(batch.len() % width, 0); + batch + .chunks_exact(width) + .map(|chunk| PackedPartialSplitEval32::::from_fn(|lane| chunk[lane])) + .collect() +} + +fn bench_ring_schoolbook_mul(c: &mut Criterion) { + let lhs = sample_ring(3); + let rhs = sample_ring(11); + c.bench_function("ring_schoolbook_mul_d64", |b| { + b.iter(|| black_box(lhs) * black_box(rhs)) + }); +} + +fn 
bench_ntt_single_prime_round_trip(c: &mut Criterion) { + let prime = Q32_PRIMES[0]; + let tw = NttTwiddles::::compute(prime); + let base: [MontCoeff; 64] = + std::array::from_fn(|i| prime.from_canonical(((i * 5 + 7) as i16) % prime.p)); + + c.bench_function("ntt_single_prime_forward_inverse_d64", |b| { + b.iter(|| { + let mut a = base; + forward_ntt(&mut a, prime, &tw); + inverse_ntt(&mut a, prime, &tw); + black_box(a) + }) + }); +} + +fn bench_crt_round_trip(c: &mut Criterion) { + let ring = sample_ring(19); + let twiddles: [NttTwiddles; Q32_NUM_PRIMES] = + std::array::from_fn(|k| NttTwiddles::compute(Q32_PRIMES[k])); + let garner = q32_garner(); + + c.bench_function("ring_ntt_crt_round_trip_d64_k6", |b| { + b.iter(|| { + let ntt = N::from_ring(black_box(&ring), &Q32_PRIMES, &twiddles); + let back: R = ntt.to_ring(&Q32_PRIMES, &twiddles, &garner); + black_box(back) + }) + }); +} + +fn bench_ring_schoolbook_mul_q128m5823(c: &mut Criterion) { + let lhs = sample_ring_q128m5823(23); + let rhs = sample_ring_q128m5823(41); + c.bench_function("ring_schoolbook_mul_d64_q128m5823", |b| { + b.iter(|| black_box(lhs) * black_box(rhs)) + }); +} + +fn bench_partial_split_mul_q128m5823(c: &mut Criterion) { + let lhs = sample_ring_q128m5823(23); + let rhs = sample_ring_q128m5823(41); + let split = PartialSplitNtt32::::compute(); + c.bench_function("ring_partial_split_mul_d64_q128m5823", |b| { + b.iter(|| split.multiply_d64(black_box(&lhs), black_box(&rhs))) + }); +} + +fn bench_crt_mul_q128m5823(c: &mut Criterion) { + let lhs = sample_ring_q128m5823(23); + let rhs = sample_ring_q128m5823(41); + let params = CrtNttParamSet::new(q128_primes()); + + c.bench_function("ring_crt_ntt_mul_d64_q128m5823_k5", |b| { + b.iter(|| { + let lhs_ntt = N128::from_ring_with_params(black_box(&lhs), ¶ms); + let rhs_ntt = N128::from_ring_with_params(black_box(&rhs), ¶ms); + let prod = lhs_ntt.pointwise_mul_with_params(&rhs_ntt, ¶ms); + let out: R128 = prod.to_ring_with_params(¶ms); + black_box(out) + 
}) + }); +} + +fn bench_partial_split_mul_i8_rhs_q128m5823(c: &mut Criterion) { + let lhs = sample_ring_q128m5823(23); + let rhs = sample_centered_i8(41); + let split = PartialSplitNtt32::::compute(); + c.bench_function("ring_partial_split_mul_i8_rhs_d64_q128m5823", |b| { + b.iter(|| split.multiply_d64_rhs_i8(black_box(&lhs), black_box(&rhs))) + }); +} + +fn bench_crt_mul_i8_rhs_q128m5823(c: &mut Criterion) { + let lhs = sample_ring_q128m5823(23); + let rhs = sample_centered_i8(41); + let params = CrtNttParamSet::new(q128_primes()); + + c.bench_function("ring_crt_ntt_mul_i8_rhs_d64_q128m5823_k5", |b| { + b.iter(|| { + let lhs_ntt = N128::from_ring_with_params(black_box(&lhs), ¶ms); + let rhs_ntt = N128::from_i8_with_params(black_box(&rhs), ¶ms); + let prod = lhs_ntt.pointwise_mul_with_params(&rhs_ntt, ¶ms); + let out: R128 = prod.to_ring_with_params(¶ms); + black_box(out) + }) + }); +} + +fn bench_cached_mul_batch_scaling_q128m5823(c: &mut Criterion) { + let width = PackedPartialSplitEval32::::WIDTH; + let split = PartialSplitNtt32::::compute(); + let packed = split.packed::(); + let params = CrtNttParamSet::new(q128_primes()); + let mut group = c.benchmark_group("ring_cached_mul_batch_scaling_d64_q128m5823"); + + for factor in MUL_BATCH_FACTORS { + let count = factor * width; + let lhs_split: Vec> = (0..count) + .map(|idx| { + PartialSplitEval32::from_ring(&split, &sample_ring_q128m5823_tag(23, idx as u64)) + }) + .collect(); + let rhs_split: Vec> = (0..count) + .map(|idx| { + PartialSplitEval32::from_ring(&split, &sample_ring_q128m5823_tag(41, idx as u64)) + }) + .collect(); + let lhs_packed = pack_split_batch(&lhs_split); + let rhs_packed = pack_split_batch(&rhs_split); + let lhs_crt: Vec = (0..count) + .map(|idx| { + N128::from_ring_with_params(&sample_ring_q128m5823_tag(23, idx as u64), ¶ms) + }) + .collect(); + let rhs_crt: Vec = (0..count) + .map(|idx| { + N128::from_ring_with_params(&sample_ring_q128m5823_tag(41, idx as u64), ¶ms) + }) + .collect(); + + 
group.bench_with_input( + BenchmarkId::new("split_scalar", count), + &count, + |b, &count| { + b.iter(|| { + let out: Vec = (0..count) + .map(|idx| { + lhs_split[idx] + .pointwise_mul(black_box(&rhs_split[idx]), &split) + .to_ring(&split) + }) + .collect(); + black_box(out) + }) + }, + ); + + group.bench_with_input( + BenchmarkId::new("split_packed", count), + &count, + |b, &count| { + b.iter(|| { + let mut out = Vec::with_capacity(count); + for idx in 0..(count / width) { + let acc = + packed.pointwise_mul(&lhs_packed[idx], black_box(&rhs_packed[idx])); + packed.append_rings(&acc, &mut out); + } + black_box(out) + }) + }, + ); + + group.bench_with_input(BenchmarkId::new("crt_simd", count), &count, |b, &count| { + b.iter(|| { + let out: Vec = (0..count) + .map(|idx| { + let mut acc = N128::zero(); + acc.add_assign_pointwise_mul_with_params( + &lhs_crt[idx], + black_box(&rhs_crt[idx]), + ¶ms, + ); + acc.to_ring_with_params(¶ms) + }) + .collect(); + black_box(out) + }) + }); + } + + group.finish(); +} + +fn bench_cached_mul_batch_scaling_i8_rhs_q128m5823(c: &mut Criterion) { + let width = PackedPartialSplitEval32::::WIDTH; + let split = PartialSplitNtt32::::compute(); + let packed = split.packed::(); + let params = CrtNttParamSet::new(q128_primes()); + let mut group = c.benchmark_group("ring_cached_mul_batch_scaling_i8_rhs_d64_q128m5823"); + + for factor in MUL_BATCH_FACTORS { + let count = factor * width; + let lhs_split: Vec> = (0..count) + .map(|idx| { + PartialSplitEval32::from_ring(&split, &sample_ring_q128m5823_tag(23, idx as u64)) + }) + .collect(); + let rhs_split: Vec> = (0..count) + .map(|idx| PartialSplitEval32::from_i8(&split, &sample_centered_i8(41 + idx as u64))) + .collect(); + let lhs_packed = pack_split_batch(&lhs_split); + let rhs_packed = pack_split_batch(&rhs_split); + let lhs_crt: Vec = (0..count) + .map(|idx| { + N128::from_ring_with_params(&sample_ring_q128m5823_tag(23, idx as u64), ¶ms) + }) + .collect(); + let rhs_crt: Vec = (0..count) + 
.map(|idx| N128::from_i8_with_params(&sample_centered_i8(41 + idx as u64), ¶ms)) + .collect(); + + group.bench_with_input( + BenchmarkId::new("split_scalar", count), + &count, + |b, &count| { + b.iter(|| { + let out: Vec = (0..count) + .map(|idx| { + lhs_split[idx] + .pointwise_mul(black_box(&rhs_split[idx]), &split) + .to_ring(&split) + }) + .collect(); + black_box(out) + }) + }, + ); + + group.bench_with_input( + BenchmarkId::new("split_packed", count), + &count, + |b, &count| { + b.iter(|| { + let mut out = Vec::with_capacity(count); + for idx in 0..(count / width) { + let acc = + packed.pointwise_mul(&lhs_packed[idx], black_box(&rhs_packed[idx])); + packed.append_rings(&acc, &mut out); + } + black_box(out) + }) + }, + ); + + group.bench_with_input(BenchmarkId::new("crt_simd", count), &count, |b, &count| { + b.iter(|| { + let out: Vec = (0..count) + .map(|idx| { + let mut acc = N128::zero(); + acc.add_assign_pointwise_mul_with_params( + &lhs_crt[idx], + black_box(&rhs_crt[idx]), + ¶ms, + ); + acc.to_ring_with_params(¶ms) + }) + .collect(); + black_box(out) + }) + }); + } + + group.finish(); +} + +fn bench_partial_split_cyclic_mul_q128m5823(c: &mut Criterion) { + let split = PartialSplitNtt32::::compute(); + let lhs = sample_ring_q128m5823(23); + let rhs = sample_ring_q128m5823(41); + + c.bench_function("ring_partial_split_cyclic_mul_d64_q128m5823", |b| { + b.iter(|| { + let out = split.multiply_cyclic_d64(black_box(&lhs), black_box(&rhs)); + black_box(out) + }) + }); +} + +fn bench_crt_cyclic_mul_q128m5823(c: &mut Criterion) { + let params = CrtNttParamSet::new(q128_primes()); + let lhs = sample_ring_q128m5823(23); + let rhs = sample_ring_q128m5823(41); + + c.bench_function("ring_crt_ntt_cyclic_mul_d64_q128m5823_k5", |b| { + b.iter(|| { + let lhs_ntt = N128::from_ring_cyclic(black_box(&lhs), ¶ms); + let rhs_ntt = N128::from_ring_cyclic(black_box(&rhs), ¶ms); + let prod = lhs_ntt.pointwise_mul_with_params(&rhs_ntt, ¶ms); + let out: R128 = 
prod.to_ring_cyclic(¶ms); + black_box(out) + }) + }); +} + +fn bench_partial_split_quotient_q128m5823(c: &mut Criterion) { + let split = PartialSplitNtt32::::compute(); + let lhs = sample_ring_q128m5823(23); + let rhs = sample_ring_q128m5823(41); + + c.bench_function("ring_partial_split_quotient_d64_q128m5823", |b| { + b.iter(|| { + let out = split.unreduced_quotient_d64(black_box(&lhs), black_box(&rhs)); + black_box(out) + }) + }); +} + +fn bench_crt_quotient_q128m5823(c: &mut Criterion) { + let params = CrtNttParamSet::new(q128_primes()); + let lhs = sample_ring_q128m5823(23); + let rhs = sample_ring_q128m5823(41); + + c.bench_function("ring_crt_ntt_quotient_d64_q128m5823_k5", |b| { + b.iter(|| { + let lhs_neg = N128::from_ring_with_params(black_box(&lhs), ¶ms); + let rhs_neg = N128::from_ring_with_params(black_box(&rhs), ¶ms); + let neg: R128 = lhs_neg + .pointwise_mul_with_params(&rhs_neg, ¶ms) + .to_ring_with_params(¶ms); + + let lhs_cyc = N128::from_ring_cyclic(black_box(&lhs), ¶ms); + let rhs_cyc = N128::from_ring_cyclic(black_box(&rhs), ¶ms); + let cyc: R128 = lhs_cyc + .pointwise_mul_with_params(&rhs_cyc, ¶ms) + .to_ring_cyclic(¶ms); + + let out = R128::from_coefficients(std::array::from_fn(|i| { + (cyc.coefficients()[i] - neg.coefficients()[i]) * F128::TWO_INV + })); + black_box(out) + }) + }); +} + +fn bench_partial_split_cached_matvec_q128m5823(c: &mut Criterion) { + let split = PartialSplitNtt32::::compute(); + let matrix: Vec>> = (0..CACHE_MAT_ROWS) + .map(|r| { + (0..CACHE_MAT_COLS) + .map(|col| { + PartialSplitEval32::from_ring( + &split, + &sample_ring_q128m5823_tag(23, (r * CACHE_MAT_COLS + col) as u64), + ) + }) + .collect() + }) + .collect(); + let vector: Vec> = (0..CACHE_MAT_COLS) + .map(|col| { + PartialSplitEval32::from_ring(&split, &sample_ring_q128m5823_tag(41, col as u64)) + }) + .collect(); + + c.bench_function("ring_partial_split_cached_matvec_d64_q128m5823", |b| { + b.iter(|| { + let out: Vec = matrix + .iter() + .map(|row| { + let mut 
acc = PartialSplitEval32::zero(); + for (mat_entry, vec_entry) in row.iter().zip(vector.iter()) { + acc.add_mul_assign(mat_entry, black_box(vec_entry), &split); + } + acc.to_ring(&split) + }) + .collect(); + black_box(out) + }) + }); +} + +fn bench_partial_split_cached_matvec_i8_rhs_q128m5823(c: &mut Criterion) { + let split = PartialSplitNtt32::::compute(); + let matrix: Vec>> = (0..CACHE_MAT_ROWS) + .map(|r| { + (0..CACHE_MAT_COLS) + .map(|col| { + PartialSplitEval32::from_ring( + &split, + &sample_ring_q128m5823_tag(23, (r * CACHE_MAT_COLS + col) as u64), + ) + }) + .collect() + }) + .collect(); + let vector: Vec> = (0..CACHE_MAT_COLS) + .map(|col| PartialSplitEval32::from_i8(&split, &sample_centered_i8(41 + col as u64))) + .collect(); + + c.bench_function( + "ring_partial_split_cached_matvec_i8_rhs_d64_q128m5823", + |b| { + b.iter(|| { + let out: Vec = matrix + .iter() + .map(|row| { + let mut acc = PartialSplitEval32::zero(); + for (mat_entry, vec_entry) in row.iter().zip(vector.iter()) { + acc.add_mul_assign(mat_entry, black_box(vec_entry), &split); + } + acc.to_ring(&split) + }) + .collect(); + black_box(out) + }) + }, + ); +} + +fn bench_partial_split_packed_cached_matvec_q128m5823(c: &mut Criterion) { + let split = PartialSplitNtt32::::compute(); + let packed = split.packed::(); + let matrix_scalar: Vec>> = (0..CACHE_MAT_ROWS) + .map(|r| { + (0..CACHE_MAT_COLS) + .map(|col| { + PartialSplitEval32::from_ring( + &split, + &sample_ring_q128m5823_tag(23, (r * CACHE_MAT_COLS + col) as u64), + ) + }) + .collect() + }) + .collect(); + let vector_scalar: Vec> = (0..CACHE_MAT_COLS) + .map(|col| { + PartialSplitEval32::from_ring(&split, &sample_ring_q128m5823_tag(41, col as u64)) + }) + .collect(); + let mut matrix_chunks = matrix_scalar.chunks_exact(PackedPartialSplitEval32::::WIDTH); + let matrix_packed: Vec>> = matrix_chunks + .by_ref() + .map(|row_chunk| { + (0..CACHE_MAT_COLS) + .map(|col| PackedPartialSplitEval32::::from_fn(|lane| row_chunk[lane][col])) + 
.collect() + }) + .collect(); + let matrix_scalar_tail = matrix_chunks.remainder(); + let vector_packed: Vec> = vector_scalar + .iter() + .map(PackedPartialSplitEval32::::broadcast) + .collect(); + + c.bench_function( + "ring_partial_split_packed_cached_matvec_d64_q128m5823", + |b| { + b.iter(|| { + let mut out = Vec::with_capacity(CACHE_MAT_ROWS); + for packed_row in &matrix_packed { + let mut acc = PackedPartialSplitEval32::::zero(); + for (mat_entry, vec_entry) in packed_row.iter().zip(vector_packed.iter()) { + packed.add_mul_assign(&mut acc, mat_entry, black_box(vec_entry)); + } + packed.append_rings(&acc, &mut out); + } + for row in matrix_scalar_tail { + let mut acc = PartialSplitEval32::zero(); + for (mat_entry, vec_entry) in row.iter().zip(vector_scalar.iter()) { + acc.add_mul_assign(mat_entry, black_box(vec_entry), &split); + } + out.push(acc.to_ring(&split)); + } + black_box(out) + }) + }, + ); +} + +fn bench_crt_simd_cached_matvec_q128m5823(c: &mut Criterion) { + let params = CrtNttParamSet::new(q128_primes()); + let matrix: Vec> = (0..CACHE_MAT_ROWS) + .map(|r| { + (0..CACHE_MAT_COLS) + .map(|col| { + N128::from_ring_with_params( + &sample_ring_q128m5823_tag(23, (r * CACHE_MAT_COLS + col) as u64), + ¶ms, + ) + }) + .collect() + }) + .collect(); + let vector: Vec = (0..CACHE_MAT_COLS) + .map(|col| N128::from_ring_with_params(&sample_ring_q128m5823_tag(41, col as u64), ¶ms)) + .collect(); + + c.bench_function("ring_crt_ntt_simd_cached_matvec_d64_q128m5823_k5", |b| { + b.iter(|| { + let out: Vec = matrix + .iter() + .map(|row| { + let mut acc = N128::zero(); + for (mat_entry, vec_entry) in row.iter().zip(vector.iter()) { + acc.add_assign_pointwise_mul_with_params( + mat_entry, + black_box(vec_entry), + ¶ms, + ); + } + acc.to_ring_with_params(¶ms) + }) + .collect(); + black_box(out) + }) + }); +} + +fn bench_partial_split_packed_cached_matvec_i8_rhs_q128m5823(c: &mut Criterion) { + let split = PartialSplitNtt32::::compute(); + let packed = split.packed::(); 
+ let matrix_scalar: Vec>> = (0..CACHE_MAT_ROWS) + .map(|r| { + (0..CACHE_MAT_COLS) + .map(|col| { + PartialSplitEval32::from_ring( + &split, + &sample_ring_q128m5823_tag(23, (r * CACHE_MAT_COLS + col) as u64), + ) + }) + .collect() + }) + .collect(); + let vector_scalar: Vec> = (0..CACHE_MAT_COLS) + .map(|col| PartialSplitEval32::from_i8(&split, &sample_centered_i8(41 + col as u64))) + .collect(); + let mut matrix_chunks = matrix_scalar.chunks_exact(PackedPartialSplitEval32::::WIDTH); + let matrix_packed: Vec>> = matrix_chunks + .by_ref() + .map(|row_chunk| { + (0..CACHE_MAT_COLS) + .map(|col| PackedPartialSplitEval32::::from_fn(|lane| row_chunk[lane][col])) + .collect() + }) + .collect(); + let matrix_scalar_tail = matrix_chunks.remainder(); + let vector_packed: Vec> = vector_scalar + .iter() + .map(PackedPartialSplitEval32::::broadcast) + .collect(); + + c.bench_function( + "ring_partial_split_packed_cached_matvec_i8_rhs_d64_q128m5823", + |b| { + b.iter(|| { + let mut out = Vec::with_capacity(CACHE_MAT_ROWS); + for packed_row in &matrix_packed { + let mut acc = PackedPartialSplitEval32::::zero(); + for (mat_entry, vec_entry) in packed_row.iter().zip(vector_packed.iter()) { + packed.add_mul_assign(&mut acc, mat_entry, black_box(vec_entry)); + } + packed.append_rings(&acc, &mut out); + } + for row in matrix_scalar_tail { + let mut acc = PartialSplitEval32::zero(); + for (mat_entry, vec_entry) in row.iter().zip(vector_scalar.iter()) { + acc.add_mul_assign(mat_entry, black_box(vec_entry), &split); + } + out.push(acc.to_ring(&split)); + } + black_box(out) + }) + }, + ); +} + +fn bench_crt_simd_cached_matvec_i8_rhs_q128m5823(c: &mut Criterion) { + let params = CrtNttParamSet::new(q128_primes()); + let matrix: Vec> = (0..CACHE_MAT_ROWS) + .map(|r| { + (0..CACHE_MAT_COLS) + .map(|col| { + N128::from_ring_with_params( + &sample_ring_q128m5823_tag(23, (r * CACHE_MAT_COLS + col) as u64), + ¶ms, + ) + }) + .collect() + }) + .collect(); + let vector: Vec = 
(0..CACHE_MAT_COLS) + .map(|col| N128::from_i8_with_params(&sample_centered_i8(41 + col as u64), ¶ms)) + .collect(); + + c.bench_function( + "ring_crt_ntt_simd_cached_matvec_i8_rhs_d64_q128m5823_k5", + |b| { + b.iter(|| { + let out: Vec = matrix + .iter() + .map(|row| { + let mut acc = N128::zero(); + for (mat_entry, vec_entry) in row.iter().zip(vector.iter()) { + acc.add_assign_pointwise_mul_with_params( + mat_entry, + black_box(vec_entry), + ¶ms, + ); + } + acc.to_ring_with_params(¶ms) + }) + .collect(); + black_box(out) + }) + }, + ); +} + +criterion_group!( + ring_ntt, + bench_ring_schoolbook_mul, + bench_ntt_single_prime_round_trip, + bench_crt_round_trip, + bench_ring_schoolbook_mul_q128m5823, + bench_partial_split_mul_q128m5823, + bench_crt_mul_q128m5823, + bench_partial_split_mul_i8_rhs_q128m5823, + bench_crt_mul_i8_rhs_q128m5823, + bench_cached_mul_batch_scaling_q128m5823, + bench_cached_mul_batch_scaling_i8_rhs_q128m5823, + bench_partial_split_cyclic_mul_q128m5823, + bench_crt_cyclic_mul_q128m5823, + bench_partial_split_quotient_q128m5823, + bench_crt_quotient_q128m5823, + bench_partial_split_cached_matvec_q128m5823, + bench_partial_split_packed_cached_matvec_q128m5823, + bench_crt_simd_cached_matvec_q128m5823, + bench_partial_split_cached_matvec_i8_rhs_q128m5823, + bench_partial_split_packed_cached_matvec_i8_rhs_q128m5823, + bench_crt_simd_cached_matvec_i8_rhs_q128m5823 +); +criterion_main!(ring_ntt); diff --git a/examples/codegen_probe_special.rs b/examples/codegen_probe_special.rs new file mode 100644 index 00000000..bc6d5b50 --- /dev/null +++ b/examples/codegen_probe_special.rs @@ -0,0 +1,143 @@ +#![allow(missing_docs)] + +//! Codegen probe for packed/scalar Fp64 multiply kernels. +//! +//! Build with: +//! 
`cargo rustc --example codegen_probe_special --release -- --emit=asm` + +use hachi_pcs::algebra::fields::pseudo_mersenne::{POW2_OFFSET_MODULUS_40, POW2_OFFSET_MODULUS_64}; +use hachi_pcs::algebra::{Fp64, Fp64Packing, PackedValue}; +use hachi_pcs::CanonicalField; + +const MASK40: u64 = (1u64 << 40) - 1; +const P40: u64 = POW2_OFFSET_MODULUS_40; +const C40: u64 = (1u64 << 40) - P40; // 195 +const P64: u64 = POW2_OFFSET_MODULUS_64; +const C64: u64 = 0u64.wrapping_sub(P64); // 59 + +#[inline(always)] +fn mul_c40_split(x: u64) -> u64 { + let c = C40 as u32; + let x_lo = x as u32; + let x_hi = (x >> 32) as u32; + (c as u64 * x_lo as u64).wrapping_add((c as u64 * x_hi as u64) << 32) +} + +#[inline(always)] +fn mul_c40_shiftadd(x: u64) -> u64 { + // 195x = (128 + 64 + 2 + 1) * x + (x << 7) + .wrapping_add(x << 6) + .wrapping_add(x << 1) + .wrapping_add(x) +} + +#[inline(always)] +fn reduce40_with_mulc(lo: u64, hi: u64, mulc: fn(u64) -> u64) -> u64 { + let high = (lo >> 40) | (hi << 24); + let f1 = (lo & MASK40).wrapping_add(mulc(high)); + let f2 = (f1 & MASK40).wrapping_add(mulc(f1 >> 40)); + let reduced = f2.wrapping_sub(P40); + let borrow = reduced >> 63; + reduced.wrapping_add(borrow.wrapping_neg() & P40) +} + +#[inline(always)] +fn reduce64(lo: u64, hi: u64) -> u64 { + let f1 = (lo as u128) + (C64 as u128) * (hi as u128); + let f2 = (f1 as u64 as u128) + (C64 as u128) * ((f1 >> 64) as u64 as u128); + let reduced = f2.wrapping_sub(P64 as u128); + let borrow = reduced >> 127; + reduced.wrapping_add(borrow.wrapping_neg() & (P64 as u128)) as u64 +} + +#[inline(never)] +#[no_mangle] +pub extern "C" fn probe_reduce40_split(lo: u64, hi: u64) -> u64 { + reduce40_with_mulc(lo, hi, mul_c40_split) +} + +#[inline(never)] +#[no_mangle] +pub extern "C" fn probe_reduce40_shiftadd(lo: u64, hi: u64) -> u64 { + reduce40_with_mulc(lo, hi, mul_c40_shiftadd) +} + +#[inline(never)] +#[no_mangle] +pub extern "C" fn probe_reduce64(lo: u64, hi: u64) -> u64 { + reduce64(lo, hi) +} + 
+#[inline(never)] +#[no_mangle] +pub extern "C" fn probe_packed_fp64_40_mul(a0: u64, a1: u64, b0: u64, b1: u64) -> u64 { + type F = Fp64<{ POW2_OFFSET_MODULUS_40 }>; + type PF = Fp64Packing<{ POW2_OFFSET_MODULUS_40 }>; + + let a = PF::from_fn(|i| { + if i == 0 { + F::from_canonical_u64(a0) + } else { + F::from_canonical_u64(a1) + } + }); + let b = PF::from_fn(|i| { + if i == 0 { + F::from_canonical_u64(b0) + } else { + F::from_canonical_u64(b1) + } + }); + let c = a * b; + (c.extract(0).to_canonical_u128() as u64) ^ (c.extract(1).to_canonical_u128() as u64) +} + +#[inline(never)] +#[no_mangle] +pub extern "C" fn probe_packed_fp64_64_mul(a0: u64, a1: u64, b0: u64, b1: u64) -> u64 { + type F = Fp64<{ POW2_OFFSET_MODULUS_64 }>; + type PF = Fp64Packing<{ POW2_OFFSET_MODULUS_64 }>; + + let a = PF::from_fn(|i| { + if i == 0 { + F::from_canonical_u64(a0) + } else { + F::from_canonical_u64(a1) + } + }); + let b = PF::from_fn(|i| { + if i == 0 { + F::from_canonical_u64(b0) + } else { + F::from_canonical_u64(b1) + } + }); + let c = a * b; + (c.extract(0).to_canonical_u128() as u64) ^ (c.extract(1).to_canonical_u128() as u64) +} + +#[inline(never)] +#[no_mangle] +pub extern "C" fn probe_scalar_fp64_40_mul(a: u64, b: u64) -> u64 { + type F = Fp64<{ POW2_OFFSET_MODULUS_40 }>; + (F::from_canonical_u64(a) * F::from_canonical_u64(b)).to_canonical_u128() as u64 +} + +#[inline(never)] +#[no_mangle] +pub extern "C" fn probe_scalar_fp64_64_mul(a: u64, b: u64) -> u64 { + type F = Fp64<{ POW2_OFFSET_MODULUS_64 }>; + (F::from_canonical_u64(a) * F::from_canonical_u64(b)).to_canonical_u128() as u64 +} + +fn main() { + let x = probe_packed_fp64_40_mul(1, 2, 3, 4) + ^ probe_packed_fp64_64_mul(5, 6, 7, 8) + ^ probe_scalar_fp64_40_mul(9, 10) + ^ probe_scalar_fp64_64_mul(11, 12) + ^ probe_reduce40_split(13, 14) + ^ probe_reduce40_shiftadd(15, 16) + ^ probe_reduce64(17, 18); + std::hint::black_box(x); +} diff --git a/examples/profile.rs b/examples/profile.rs new file mode 100644 index 
00000000..7c45815b --- /dev/null +++ b/examples/profile.rs @@ -0,0 +1,685 @@ +#![allow(missing_docs)] + +use hachi_pcs::algebra::Fp128; +use hachi_pcs::primitives::serialization::Compress; +use hachi_pcs::protocol::commitment::{ + Fp128BoundedCommitmentConfig, Fp128D64BoundedCommitmentConfig, Fp128FullCommitmentConfig, + Fp128LogBasisCommitmentConfig, Fp128OneHotCommitmentConfig, HachiCommitmentLayout, +}; +use hachi_pcs::protocol::commitment_scheme::HachiCommitmentScheme; +use hachi_pcs::protocol::hachi_poly_ops::{DensePoly, OneHotPoly}; +use hachi_pcs::protocol::opening_point::{ + reduce_inner_opening_to_ring_element, ring_opening_point_from_field, +}; +use hachi_pcs::protocol::proof::{ + FlatLabradorLevelProof, FlatLabradorWitness, HachiLevelProof, HachiProof, HachiProofTail, + LabradorTail, +}; +use hachi_pcs::protocol::transcript::Blake2bTranscript; +use hachi_pcs::protocol::CommitmentConfig; +use hachi_pcs::{ + BasisMode, CanonicalField, CommitmentScheme, FromSmallInt, HachiPolyOps, HachiSerialize, + Transcript, +}; +use rand::rngs::StdRng; +use rand::{Rng, SeedableRng}; +use std::env; +use std::fs; +use std::time::{Instant, SystemTime, UNIX_EPOCH}; +use tracing_chrome::ChromeLayerBuilder; +use tracing_subscriber::fmt::format::FmtSpan; +use tracing_subscriber::prelude::*; +use tracing_subscriber::EnvFilter; + +type F = Fp128<0xfffffffffffffffffffffffffffffeed>; +const ONEHOT_K: usize = 256; + +fn env_flag(name: &str, default: bool) -> bool { + env::var(name) + .ok() + .map(|value| value != "0") + .unwrap_or(default) +} + +fn opening_from_poly>( + poly: &P, + point: &[F], + layout: &HachiCommitmentLayout, + basis: BasisMode, +) -> F { + let alpha_bits = D.trailing_zeros() as usize; + assert_eq!(point.len(), alpha_bits + layout.m_vars + layout.r_vars); + + let inner_point = &point[..alpha_bits]; + let reduced_point = &point[alpha_bits..]; + let ring_opening_point = + ring_opening_point_from_field(reduced_point, layout.r_vars, layout.m_vars, basis) + 
.expect("opening point shape should match layout"); + + let (y_ring, _) = poly.evaluate_and_fold( + &ring_opening_point.b, + &ring_opening_point.a, + layout.block_len, + ); + let v = reduce_inner_opening_to_ring_element::(inner_point, basis) + .expect("inner opening point should match ring dimension"); + (y_ring * v.sigma_m1()).coefficients()[0] +} + +fn run_prove>( + label: &str, + setup: & as CommitmentScheme>::ProverSetup, + poly: &P, + pt: &[F], + opening: F, + layout: &HachiCommitmentLayout, +) { + type Scheme = HachiCommitmentScheme; + + let t0 = Instant::now(); + let (commitment, hint) = + as CommitmentScheme>::commit(poly, setup, layout).unwrap(); + tracing::info!(label, elapsed_s = t0.elapsed().as_secs_f64(), "commit"); + + let t0 = Instant::now(); + let mut prover_transcript = Blake2bTranscript::::new(b"profile"); + let proof = as CommitmentScheme>::prove( + setup, + poly, + pt, + hint, + &mut prover_transcript, + &commitment, + BasisMode::Lagrange, + layout, + ) + .unwrap(); + tracing::info!(label, elapsed_s = t0.elapsed().as_secs_f64(), "prove"); + print_proof_summary(label, &proof); + + let t0 = Instant::now(); + let verifier_setup = as CommitmentScheme>::setup_verifier(setup); + let mut verifier_transcript = Blake2bTranscript::::new(b"profile"); + match as CommitmentScheme>::verify( + &proof, + &verifier_setup, + &mut verifier_transcript, + pt, + &opening, + &commitment, + BasisMode::Lagrange, + layout, + ) { + Ok(()) => tracing::info!(label, elapsed_s = t0.elapsed().as_secs_f64(), "verify OK"), + Err(e) => { + tracing::error!(label, elapsed_s = t0.elapsed().as_secs_f64(), error = %e, "verify FAILED") + } + } +} + +fn print_proof_summary(label: &str, proof: &HachiProof) { + let top_levels_len_size = std::mem::size_of::(); + let top_tail_tag_size = std::mem::size_of::(); + let hachi_levels_total: usize = proof + .levels + .iter() + .map(|level| level.serialized_size(Compress::No)) + .sum(); + let tail_total = match &proof.tail { + 
HachiProofTail::Direct(final_w) => final_w.serialized_size(Compress::No), + HachiProofTail::Labrador(tail) => tail.serialized_size(Compress::No), + }; + let accounted_total = top_levels_len_size + top_tail_tag_size + hachi_levels_total + tail_total; + + tracing::info!( + label, + levels = proof.levels.len(), + proof_size_bytes = proof.size(), + accounted_bytes = accounted_total, + hachi_fold_bytes = hachi_levels_total, + tail_bytes = tail_total, + "proof summary" + ); + debug_assert_eq!(accounted_total, proof.size()); + eprintln!( + "[{label}] proof framing: levels_len={top_levels_len_size} bytes, tail_tag={top_tail_tag_size} byte" + ); + + for (i, lp) in proof.levels.iter().enumerate() { + print_hachi_level_breakdown(label, i, lp); + } + match &proof.tail { + HachiProofTail::Direct(final_w) => { + eprintln!( + "[{label}] tail_choice: kind=direct, bytes={}", + final_w.serialized_size(Compress::No) + ); + eprintln!( + "[{label}] final_w: total={} bytes, elems={}, bits/elem={}", + final_w.serialized_size(Compress::No), + final_w.num_elems, + final_w.bits_per_elem, + ); + } + HachiProofTail::Labrador(tail) => { + eprintln!( + "[{label}] tail_choice: kind=labrador, bytes={}, labrador_levels={}", + tail.serialized_size(Compress::No), + tail.labrador_proof.levels.len() + ); + print_labrador_tail_breakdown(label, tail); + } + } +} + +fn print_hachi_level_breakdown(label: &str, level_idx: usize, level: &HachiLevelProof) -> usize { + let y_ring_size = level.y_ring.serialized_size(Compress::No); + let v_size = level.v.serialized_size(Compress::No); + let stage1_sumcheck_size = level.stage1.sumcheck.serialized_size(Compress::No); + let stage1_s_claim_size = level.stage1.s_claim.serialized_size(Compress::No); + let stage2_sumcheck_size = level.stage2.sumcheck.serialized_size(Compress::No); + let next_w_commitment_size = level.stage2.next_w_commitment.serialized_size(Compress::No); + let next_w_eval_size = level.stage2.next_w_eval.serialized_size(Compress::No); + let total = 
level.serialized_size(Compress::No); + + eprintln!("[{label}] hachi_fold L{level_idx}: total={total} bytes"); + eprintln!( + "[{label}] y_ring={} bytes ({} ring elems, D={})", + y_ring_size, + level.y_ring.count(), + level.y_ring.ring_dim(), + ); + eprintln!( + "[{label}] v={} bytes ({} ring elems, D={})", + v_size, + level.v.count(), + level.v.ring_dim(), + ); + eprintln!("[{label}] stage1_sumcheck={stage1_sumcheck_size} bytes"); + eprintln!("[{label}] stage1_s_claim={stage1_s_claim_size} bytes"); + eprintln!("[{label}] stage2_sumcheck={stage2_sumcheck_size} bytes"); + eprintln!( + "[{label}] next_w_commitment={next_w_commitment_size} bytes ({} ring elems, D={})", + level.stage2.next_w_commitment.count(), + level.w_commit_d(), + ); + eprintln!("[{label}] next_w_eval={next_w_eval_size} bytes"); + + debug_assert_eq!( + total, + y_ring_size + + v_size + + stage1_sumcheck_size + + stage1_s_claim_size + + stage2_sumcheck_size + + next_w_commitment_size + + next_w_eval_size + ); + total +} + +fn print_labrador_tail_breakdown(label: &str, tail: &LabradorTail) -> usize { + let labrador_proof_size = tail.labrador_proof.serialized_size(Compress::No); + let v_size = tail.v.serialized_size(Compress::No); + let y_ring_size = tail.y_ring.serialized_size(Compress::No); + let witness_norm_bound_sq_size = tail.witness_norm_bound_sq.serialized_size(Compress::No); + let total = tail.serialized_size(Compress::No); + + eprintln!("[{label}] final_w: Labrador tail"); + eprintln!("[{label}] labrador_tail: total={total} bytes"); + eprintln!("[{label}] labrador_proof={labrador_proof_size} bytes"); + eprintln!( + "[{label}] v={} bytes ({} ring elems, D={})", + v_size, + tail.v.count(), + tail.v.ring_dim(), + ); + eprintln!( + "[{label}] y_ring={} bytes ({} ring elems, D={})", + y_ring_size, + tail.y_ring.count(), + tail.y_ring.ring_dim(), + ); + eprintln!("[{label}] witness_norm_bound_sq={witness_norm_bound_sq_size} bytes"); + debug_assert_eq!( + total, + labrador_proof_size + v_size + 
y_ring_size + witness_norm_bound_sq_size + ); + + let labrador_levels_len_size = std::mem::size_of::(); + let labrador_levels_total: usize = tail + .labrador_proof + .levels + .iter() + .map(|level| level.serialized_size(Compress::No)) + .sum(); + let final_opening_witness_size = tail + .labrador_proof + .final_opening_witness + .serialized_size(Compress::No); + let labrador_accounted = + labrador_levels_len_size + labrador_levels_total + final_opening_witness_size; + eprintln!( + "[{label}] labrador_fold: levels={}, total={} bytes, levels_len={} bytes, final_opening_witness={} bytes", + tail.labrador_proof.levels.len(), + labrador_proof_size, + labrador_levels_len_size, + final_opening_witness_size, + ); + debug_assert_eq!(labrador_proof_size, labrador_accounted); + + for (i, level) in tail.labrador_proof.levels.iter().enumerate() { + print_labrador_level_breakdown(label, i, level); + } + print_labrador_final_witness_breakdown(label, &tail.labrador_proof.final_opening_witness); + + total +} + +fn print_labrador_level_breakdown( + label: &str, + level_idx: usize, + level: &FlatLabradorLevelProof, +) -> usize { + let tail_flag_size = std::mem::size_of::(); + let input_row_lengths_size = level.input_row_lengths.serialized_size(Compress::No); + let config_size = level.config.serialized_size(Compress::No); + let virtual_row_len_size = level.virtual_row_len.serialized_size(Compress::No); + let row_split_counts_size = level.row_split_counts.serialized_size(Compress::No); + let inner_opening_payload_size = level.inner_opening_payload.serialized_size(Compress::No); + let linear_garbage_payload_size = level.linear_garbage_payload.serialized_size(Compress::No); + let jl_projection_size = level.jl_projection.len() * std::mem::size_of::(); + let jl_nonce_size = level.jl_nonce.serialized_size(Compress::No); + let jl_lift_residuals_size = level.jl_lift_residuals.serialized_size(Compress::No); + let next_witness_norm_sq_size = 
level.next_witness_norm_sq.serialized_size(Compress::No); + let total = level.serialized_size(Compress::No); + + eprintln!( + "[{label}] labrador_fold L{level_idx}: total={total} bytes, tail={}", + level.tail + ); + eprintln!( + "[{label}] params: input_row_lengths={:?}, virtual_row_len={}, virtual_row_count={}, row_split_counts={:?}, witness_digit_parts={}, witness_digit_bits={}, aux_digit_parts={}, aux_digit_bits={}, inner_commit_rank={}, outer_commit_rank={}", + level.input_row_lengths, + level.virtual_row_len, + level.row_split_counts.iter().sum::(), + level.row_split_counts, + level.config.witness_digit_parts, + level.config.witness_digit_bits, + level.config.aux_digit_parts, + level.config.aux_digit_bits, + level.config.inner_commit_rank, + level.config.outer_commit_rank, + ); + eprintln!( + "[{label}] framing: tail_flag={tail_flag_size}, input_row_lengths={input_row_lengths_size}, config={config_size}, virtual_row_len={virtual_row_len_size}, row_split_counts={row_split_counts_size}, next_witness_norm_sq={next_witness_norm_sq_size}" + ); + eprintln!( + "[{label}] msg inner_opening_payload={} bytes ({} ring elems, D={})", + inner_opening_payload_size, + level.inner_opening_payload.count(), + level.inner_opening_payload.ring_dim(), + ); + eprintln!( + "[{label}] msg linear_garbage_payload={} bytes ({} ring elems, D={})", + linear_garbage_payload_size, + level.linear_garbage_payload.count(), + level.linear_garbage_payload.ring_dim(), + ); + eprintln!( + "[{label}] msg jl_projection={jl_projection_size} bytes, jl_nonce={jl_nonce_size} bytes" + ); + eprintln!( + "[{label}] msg jl_lift_residuals={} bytes ({} ring elems, D={})", + jl_lift_residuals_size, + level.jl_lift_residuals.count(), + level.jl_lift_residuals.ring_dim(), + ); + + debug_assert_eq!( + total, + tail_flag_size + + input_row_lengths_size + + config_size + + virtual_row_len_size + + row_split_counts_size + + inner_opening_payload_size + + linear_garbage_payload_size + + jl_projection_size + + 
jl_nonce_size + + jl_lift_residuals_size + + next_witness_norm_sq_size + ); + total +} + +fn print_labrador_final_witness_breakdown(label: &str, witness: &FlatLabradorWitness) -> usize { + let rows_len_size = std::mem::size_of::(); + let rows_total: usize = witness + .rows + .iter() + .map(|row| row.serialized_size(Compress::No)) + .sum(); + let total = witness.serialized_size(Compress::No); + + eprintln!( + "[{label}] final_opening_witness: total={total} bytes, rows_len={rows_len_size} bytes" + ); + for (row_idx, row) in witness.rows.iter().enumerate() { + eprintln!( + "[{label}] row{row_idx}={} bytes ({} ring elems, D={})", + row.serialized_size(Compress::No), + row.count(), + row.ring_dim(), + ); + } + debug_assert_eq!(total, rows_len_size + rows_total); + total +} + +fn print_layout(layout: &HachiCommitmentLayout) { + tracing::debug!( + m_vars = layout.m_vars, + r_vars = layout.r_vars, + num_blocks = layout.num_blocks, + block_len = layout.block_len, + delta_commit = layout.num_digits_commit, + delta_open = layout.num_digits_open, + delta_fold = layout.num_digits_fold, + log_basis = layout.log_basis, + "layout" + ); +} + +fn run_dense(nv: usize, layout: &HachiCommitmentLayout) { + let mut rng = StdRng::seed_from_u64(0xbeef_cafe); + let pt: Vec = (0..nv) + .map(|_| F::from_canonical_u128_reduced(rng.gen::())) + .collect(); + let (poly, opening) = { + let len = 1usize << nv; + let decomp = Cfg::decomposition(); + let half_bound = 1i64 << (decomp.log_commit_bound.min(62) - 1); + let evals: Vec = if decomp.log_commit_bound >= 128 { + (0..len) + .map(|_| F::from_canonical_u128_reduced(rng.gen::())) + .collect() + } else { + (0..len) + .map(|_| F::from_i64(rng.gen_range(-half_bound..half_bound))) + .collect() + }; + let poly = DensePoly::::from_field_evals(nv, &evals).unwrap(); + let opening = opening_from_poly(&poly, &pt, layout, BasisMode::Lagrange); + (poly, opening) + }; + + let t0 = Instant::now(); + let setup = as CommitmentScheme>::setup_prover(nv); + 
tracing::info!( + label = "dense", + elapsed_s = t0.elapsed().as_secs_f64(), + "setup" + ); + + run_prove::("dense", &setup, &poly, &pt, opening, layout); +} + +fn run_onehot(nv: usize, layout: &HachiCommitmentLayout) { + let mut rng = StdRng::seed_from_u64(0xbeef_cafe); + let total_field = (layout.num_blocks * layout.block_len) + .checked_mul(D) + .expect("total field size overflow"); + let onehot_k = ONEHOT_K; + let total_chunks = total_field / onehot_k; + assert_eq!( + total_chunks * onehot_k, + total_field, + "onehot K must divide total field size" + ); + + let indices: Vec> = (0..total_chunks) + .map(|_| Some(rng.gen_range(0..onehot_k))) + .collect(); + let onehot_poly = + OneHotPoly::::new(onehot_k, indices, layout.r_vars, layout.m_vars).unwrap(); + let pt: Vec = (0..nv) + .map(|_| F::from_canonical_u128_reduced(rng.gen::())) + .collect(); + let opening = opening_from_poly(&onehot_poly, &pt, layout, BasisMode::Lagrange); + + let t0 = Instant::now(); + let setup = as CommitmentScheme>::setup_prover(nv); + tracing::info!( + label = "onehot", + elapsed_s = t0.elapsed().as_secs_f64(), + "setup" + ); + + run_prove::("onehot", &setup, &onehot_poly, &pt, opening, layout); +} + +fn run_dense_mode(title: &str, nv: usize) { + let layout = resolve_layout::(nv); + tracing::info!("{}", title); + print_layout(&layout); + run_dense::(nv, &layout); +} + +fn run_onehot_mode(title: &str, nv: usize) { + let layout = resolve_layout::(nv); + tracing::info!("{}", title); + print_layout(&layout); + run_onehot::(nv, &layout); +} + +fn main() { + #[cfg(feature = "parallel")] + rayon::ThreadPoolBuilder::new() + .stack_size(64 * 1024 * 1024) + .build_global() + .ok(); + + if cfg!(debug_assertions) && env::var("HACHI_ALLOW_DEBUG_PROFILE").as_deref() != Ok("1") { + eprintln!("examples/profile must be run with --release for meaningful timings."); + eprintln!("Re-run with: cargo run --release --example profile"); + eprintln!("Set HACHI_ALLOW_DEBUG_PROFILE=1 to override this guard."); + 
std::process::exit(2); + } + + let nv: usize = env::var("HACHI_NUM_VARS") + .ok() + .and_then(|s| s.parse().ok()) + .unwrap_or(25); + + let mode = env::var("HACHI_MODE").unwrap_or_else(|_| "full".to_string()); + let enable_trace = env_flag("HACHI_PROFILE_TRACE", true); + let enable_ansi = env_flag("HACHI_PROFILE_ANSI", true); + let span_events = if env_flag("HACHI_PROFILE_SPAN_CLOSES", true) { + FmtSpan::CLOSE + } else { + FmtSpan::NONE + }; + let log_filter = + EnvFilter::try_new(env::var("HACHI_PROFILE_LOG").unwrap_or_else(|_| "trace".to_string())) + .unwrap_or_else(|_| EnvFilter::new("trace")); + + let timestamp = SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap() + .as_secs(); + let trace_file = format!("profile_traces/hachi_nv{nv}_{mode}_{timestamp}.json"); + + let fmt_layer = tracing_subscriber::fmt::layer() + .with_ansi(enable_ansi) + .with_span_events(span_events) + .compact() + .with_target(false); + let _chrome_guard = if enable_trace { + fs::create_dir_all("profile_traces").ok(); + let (chrome_layer, guard) = ChromeLayerBuilder::new() + .include_args(true) + .file(&trace_file) + .build(); + tracing_subscriber::registry() + .with(log_filter) + .with(fmt_layer) + .with(chrome_layer) + .init(); + tracing::info!(trace_file = %trace_file, "Perfetto trace"); + Some(guard) + } else { + tracing_subscriber::registry() + .with(log_filter) + .with(fmt_layer) + .init(); + tracing::info!("Perfetto trace disabled"); + None + }; + tracing::info!(num_vars = nv, mode = %mode, "profile config"); + + match mode.as_str() { + "full" => { + type Cfg = Fp128FullCommitmentConfig; + run_dense_mode::<{ Fp128FullCommitmentConfig::D }, Cfg>( + "=== full (dense, log_commit_bound=128) ===", + nv, + ); + } + "onehot" => { + type Cfg = Fp128OneHotCommitmentConfig; + run_onehot_mode::<{ Fp128OneHotCommitmentConfig::D }, Cfg>( + "=== onehot (1-of-256, log_commit_bound=1) ===", + nv, + ); + } + "logbasis" => { + type Cfg = Fp128LogBasisCommitmentConfig; + run_dense_mode::<{ 
Fp128LogBasisCommitmentConfig::D }, Cfg>( + "=== logbasis (dense, log_commit_bound=3) ===", + nv, + ); + } + "all" => { + { + type Cfg = Fp128FullCommitmentConfig; + run_dense_mode::<{ Fp128FullCommitmentConfig::D }, Cfg>( + "=== full (dense, log_commit_bound=128) ===", + nv, + ); + } + { + type Cfg = Fp128OneHotCommitmentConfig; + run_onehot_mode::<{ Fp128OneHotCommitmentConfig::D }, Cfg>( + "=== onehot (1-of-256, log_commit_bound=1) ===", + nv, + ); + } + { + type Cfg = Fp128LogBasisCommitmentConfig; + run_dense_mode::<{ Fp128LogBasisCommitmentConfig::D }, Cfg>( + "=== logbasis (dense, log_commit_bound=3) ===", + nv, + ); + } + } + "compare_onehot" => { + { + type Cfg = Fp128D64BoundedCommitmentConfig<1, 3, 3>; + run_onehot_mode::<{ Cfg::D }, Cfg>( + "=== [A] onehot (1-of-256), basis=3 everywhere ===", + nv, + ); + } + { + type Cfg = Fp128D64BoundedCommitmentConfig<1, 2, 2>; + run_onehot_mode::<{ Cfg::D }, Cfg>( + "=== [B] onehot (1-of-256), basis=2 everywhere ===", + nv, + ); + } + { + type Cfg = Fp128D64BoundedCommitmentConfig<1, 2, 3>; + run_onehot_mode::<{ Cfg::D }, Cfg>( + "=== [C] onehot (1-of-256), L0 basis=2, w-levels basis=3 ===", + nv, + ); + } + { + type Cfg = Fp128D64BoundedCommitmentConfig<1, 2, 4>; + run_onehot_mode::<{ Cfg::D }, Cfg>( + "=== [D] onehot (1-of-256), L0 basis=2, w-levels basis=4 ===", + nv, + ); + } + } + "compare_logbasis" => { + { + type Cfg = Fp128BoundedCommitmentConfig<3, 3, 3>; + run_dense_mode::<{ Cfg::D }, Cfg>( + "=== [A] logbasis coeffs, basis=3 everywhere ===", + nv, + ); + } + { + type Cfg = Fp128BoundedCommitmentConfig<3, 2, 2>; + run_dense_mode::<{ Cfg::D }, Cfg>( + "=== [B] logbasis coeffs, basis=2 everywhere ===", + nv, + ); + } + { + type Cfg = Fp128BoundedCommitmentConfig<3, 2, 3>; + run_dense_mode::<{ Cfg::D }, Cfg>( + "=== [C] logbasis coeffs, L0 basis=2, w-levels basis=3 ===", + nv, + ); + } + { + type Cfg = Fp128BoundedCommitmentConfig<3, 2, 4>; + run_dense_mode::<{ Cfg::D }, Cfg>( + "=== [D] logbasis coeffs, L0 
basis=2, w-levels basis=4 ===", + nv, + ); + } + } + "compare_basis" => { + { + type Cfg = Fp128BoundedCommitmentConfig<128, 3, 3>; + run_dense_mode::<{ Cfg::D }, Cfg>( + "=== [A] baseline: log_basis=3 everywhere ===", + nv, + ); + } + { + type Cfg = Fp128BoundedCommitmentConfig<128, 2, 2>; + run_dense_mode::<{ Cfg::D }, Cfg>("=== [B] log_basis=2 everywhere ===", nv); + } + { + type Cfg = Fp128BoundedCommitmentConfig<128, 2, 3>; + run_dense_mode::<{ Cfg::D }, Cfg>("=== [C] L0 basis=2, w-levels basis=3 ===", nv); + } + { + type Cfg = Fp128BoundedCommitmentConfig<128, 2, 4>; + run_dense_mode::<{ Cfg::D }, Cfg>("=== [D] L0 basis=2, w-levels basis=4 ===", nv); + } + } + other => { + tracing::error!( + mode = other, + "Unknown HACHI_MODE. Use: full, onehot, logbasis, all, compare_onehot, compare_logbasis, compare_basis" + ); + std::process::exit(1); + } + } + + if enable_trace { + tracing::info!(trace_file = %trace_file, "Done. Trace saved"); + } else { + tracing::info!("Done"); + } +} + +fn resolve_layout(nv: usize) -> HachiCommitmentLayout { + Cfg::commitment_layout(nv).expect("layout") +} diff --git a/paper/hachi.pdf b/paper/hachi.pdf deleted file mode 100644 index 33354fc2..00000000 Binary files a/paper/hachi.pdf and /dev/null differ diff --git a/rust-toolchain.toml b/rust-toolchain.toml new file mode 100644 index 00000000..c5b9f7f3 --- /dev/null +++ b/rust-toolchain.toml @@ -0,0 +1,4 @@ +[toolchain] +channel = "1.88" +profile = "minimal" +components = ["cargo", "rustc", "clippy", "rustfmt"] diff --git a/rustfmt.toml b/rustfmt.toml new file mode 100644 index 00000000..541c50b7 --- /dev/null +++ b/rustfmt.toml @@ -0,0 +1,5 @@ +edition = "2021" +tab_spaces = 4 +newline_style = "Unix" +use_try_shorthand = true +use_field_init_shorthand = true diff --git a/scripts/estimate_hachi_d64_k256_onehot_sis.py b/scripts/estimate_hachi_d64_k256_onehot_sis.py new file mode 100644 index 00000000..27ac36fd --- /dev/null +++ b/scripts/estimate_hachi_d64_k256_onehot_sis.py @@ -0,0 
+1,670 @@ +#!/usr/bin/env python3 +""" +Estimate SIS security for a prospective Hachi D=64 one-hot family. + +This script is meant to be readable on its own. It fixes one concrete +prospective `D = 64` parameter regime, explains the challenge family in plain +language, and prints the extracted SIS instances that determine the security +floor. + +The setup studied here is: + +- field modulus `q = 2^128 - 5823` +- ring degree `D = 64`, i.e. ring `F_q[X] / (X^64 + 1)` +- one-hot chunk size `K = 256` +- sparse challenge coefficients in `{-2, -1, 1, 2}` + +The default challenge family is the rigorous split `k=32` family +`C_{21,<=6}`. Intuitively: + +- split the `64` ring coefficients into the even and odd positions +- in each half, choose exactly `21` active slots out of `32` +- each active slot gets sign `+/-` +- at most `6` of those active slots may use magnitude `2` +- the two halves are then interleaved back into one ring element + +That is what the name `C_{21,<=6}` means. Its conservative challenge mass is +`54 = 2 * (21 + 6)`. The same-budget larger variant `C_{22,<=6}` has mass +`56`. + +For comparison, the script also supports two older direct full-ring shells: + +- `(28,11)`: exactly `28` coefficients in `+/-1` and `11` in `+/-2` +- `(31,10)`: exactly `31` coefficients in `+/-1` and `10` in `+/-2` + +Those legacy shells are useful comparison points, but they are not proven. + +The script prints two reports: + +1. one main estimate for a chosen parameter point +2. 
one sweep over `nv` to find the largest point that still clears 128 bits + +Glossary for the printed quantities: + +- `nv`: total multilinear variables before ring packing +- `alpha = log2(D)`: variables absorbed by the ring structure; here `alpha = 6` +- `m_vars`, `r_vars`: inner-block and outer/fold variables, with + `nv - alpha = m_vars + r_vars` +- `num_blocks = 2^r_vars`, `block_len = 2^m_vars` +- `LOG_BASIS = 3`: base-8 digit decomposition +- `N_A`, `N_B`, `N_D`: module ranks of the three main commitment layers +- `inner_width`, `outer_width`, `d_matrix_width`: extracted SIS widths in ring + elements +- `width_ring`: SIS width measured in ring elements; the estimator sees + `m = width_ring * D` field coordinates +- `collision_inf`: `l_inf` collision bound passed to the estimator +- `A_fullwidth`, `B`, `D`, `M_code`, `M_tight`: the extracted SIS instances +- `overall floor`: the minimum security estimate across those instances + +Modeling choices: + +- the script calls the Euclidean `SIS.lattice(...)` path from + `lattice-estimator` +- the reduction model is pinned to `BDGL16 + lgsa` +- for the folded one-hot witness `z_pre`, it uses the tighter onehot-aware + bound `||z_pre||_inf <= 2^r_vars * 2` + instead of the older dense proxy + `2^r_vars * challenge_mass * 2^(LOG_BASIS - 1)` + +Run from the `hachi/` repo root with either: + + sage -python scripts/estimate_hachi_d64_k256_onehot_sis.py + +or, if the estimator is not in the default sibling location: + + LATTICE_ESTIMATOR_PATH="../lattice-estimator" \ + sage -python scripts/estimate_hachi_d64_k256_onehot_sis.py +""" + +from __future__ import annotations + +import argparse +import os +import sys +from dataclasses import dataclass +from pathlib import Path +from textwrap import dedent + + +# Fixed family parameters for the experiment. 
+Q = (1 << 128) - 5823 +Q_LABEL = "2^128 - 5823" +D = 64 +K = 256 +LOG_BASIS = 3 +DELTA_COMMIT = 1 +DELTA_OPEN = 43 +MAX_ABS_CHALLENGE_COEFF = 2 +DEFAULT_CHALLENGE_MASS = 54 + +ALPHA = D.bit_length() - 1 + + +class HelpFormatter( + argparse.ArgumentDefaultsHelpFormatter, + argparse.RawDescriptionHelpFormatter, +): + """Combine preserved formatting with automatic default display.""" + + +HELP_EPILOG = dedent( + """\ + Terminology: + nv total multilinear variables before ring packing + alpha log2(D), i.e. packing variables absorbed by the ring + m_vars inner-block variables; block_len = 2^m_vars + r_vars outer/fold variables; num_blocks = 2^r_vars + N_A/N_B/N_D module ranks of the A, B, and D commitment layers + challenge_mass + conservative L1 bound for the sparse challenge family + delta_* base-8 digit counts used in the extracted bounds + A/B/D/M SIS instances reported by the script + """ +) + + +TERMINOLOGY_LINES = [ + ("q", "field modulus used in the SIS instance"), + ("D", "ring degree; ring is F_q[X] / (X^D + 1)"), + ("K", "one-hot chunk size before ring packing"), + ("alpha", "log2(D), the number of variables absorbed by ring packing"), + ("nv", "total multilinear variables before subtracting alpha"), + ("m_vars", "inner-block variables; block_len = 2^m_vars"), + ("r_vars", "outer variables; num_blocks = 2^r_vars"), + ("LOG_BASIS", "digit decomposition base exponent, so base = 2^LOG_BASIS = 8"), + ("challenge_mass", "conservative L1 bound for the sparse challenge family"), + ("delta_commit/open/fold", "numbers of base-8 digits used in the extracted bounds"), + ("inner/outer/D widths", "ring-element counts of the A, B, and D SIS instances"), + ("width_ring", "SIS width measured in ring elements; field-coordinate width is width_ring * D"), + ("collision_inf", "l_inf collision bound passed to the SIS estimator"), + ("A/B/D/M", "the extracted SIS instances whose minimum gives the overall floor"), +] + + +RIGOROUS_SPLIT_BY_MASS = { + 54: (21, 6), + 56: (22, 6), +} 
+ +LEGACY_RAW_SHELL_BY_MASS = { + 50: (28, 11), + 51: (31, 10), +} + + +@dataclass(frozen=True) +class Layout: + """Split `reduced_vars = nv - alpha` into inner-block and outer variables.""" + + nv: int + m_vars: int + r_vars: int + delta_fold_tight: int + delta_fold_code: int + + @property + def num_blocks(self) -> int: + return 1 << self.r_vars + + @property + def block_len(self) -> int: + return 1 << self.m_vars + + @property + def inner_width(self) -> int: + return self.block_len * DELTA_COMMIT + + def outer_width(self, n_a: int) -> int: + return n_a * DELTA_OPEN * self.num_blocks + + @property + def d_matrix_width(self) -> int: + return DELTA_OPEN * self.num_blocks + + +@dataclass(frozen=True) +class LayerEstimate: + """Security estimate for one extracted SIS instance.""" + + name: str + sec_bits: float + rank: int + width_ring_elems: int + collision_inf: int + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser( + description=( + "Estimate SIS security for the Hachi D=64, K=256 one-hot family " + "and print both a main configuration report and a rank-1 nv sweep." + ), + formatter_class=HelpFormatter, + epilog=HELP_EPILOG, + ) + parser.add_argument( + "--estimator-path", + help=( + "Path to the lattice-estimator repo. Defaults to " + "LATTICE_ESTIMATOR_PATH or a sibling ../lattice-estimator checkout." 
+ ), + ) + parser.add_argument( + "--candidate-nv", + type=int, + default=44, + help="Total multilinear variables nv for the main configuration report.", + ) + parser.add_argument( + "--candidate-na", + type=int, + default=1, + help="Module rank N_A of the inner A commitment layer.", + ) + parser.add_argument( + "--candidate-nb", + type=int, + default=2, + help="Module rank N_B of the outer B commitment layer.", + ) + parser.add_argument( + "--candidate-nd", + type=int, + default=2, + help="Module rank N_D of the outer D commitment layer.", + ) + parser.add_argument( + "--sweep-min-nv", + type=int, + default=28, + help="Smallest total multilinear variable count nv in the sweep.", + ) + parser.add_argument( + "--sweep-max-nv", + type=int, + default=44, + help="Largest total multilinear variable count nv in the sweep.", + ) + parser.add_argument( + "--challenge-mass", + type=int, + default=DEFAULT_CHALLENGE_MASS, + help=( + "Conservative L1 mass of the D=64 {+/-1, +/-2} family. " + "Use 54 for rigorous split C_{21,<=6} (default), 56 for " + "rigorous split C_{22,<=6}, or 50/51 for the older raw " + "direct full-ring shells (28,11)/(31,10)." + ), + ) + return parser.parse_args() + + +def repo_root() -> Path: + return Path(__file__).resolve().parents[1] + + +def locate_estimator_repo(explicit: str | None) -> Path: + candidates: list[Path] = [] + if explicit: + candidates.append(Path(explicit).expanduser()) + + env_path = os.environ.get("LATTICE_ESTIMATOR_PATH") + if env_path: + candidates.append(Path(env_path).expanduser()) + + root = repo_root() + candidates.extend( + [ + root / "lattice-estimator", + root / "third_party" / "lattice-estimator", + root / "vendor" / "lattice-estimator", + root.parent / "lattice-estimator", + ] + ) + + for candidate in candidates: + if (candidate / "estimator" / "__init__.py").exists(): + return candidate.resolve() + + raise SystemExit( + "Could not locate lattice-estimator. " + "Set LATTICE_ESTIMATOR_PATH or pass --estimator-path." 
+ ) + + +def load_estimator(estimator_repo: Path): + sys.path.insert(0, str(estimator_repo)) + from estimator import SIS # type: ignore + from estimator.reduction import RC # type: ignore + from sage.all import log # type: ignore + + return SIS, RC, log + + +def compute_num_digits(log_bound: int, log_basis: int) -> int: + """Return the number of signed base-2^log_basis digits needed for a bound.""" + + if log_basis <= 0 or log_basis >= 128: + raise ValueError("invalid log_basis") + if log_bound == 0: + return 1 + + levels = (log_bound + log_basis - 1) // log_basis + total_bits = levels * log_basis + + if total_bits <= log_bound: + b = 1 << log_basis + half_b_minus_1 = b // 2 - 1 + b_pow = b**levels + max_positive = half_b_minus_1 * ((b_pow - 1) // (b - 1)) + required = (1 << (log_bound - 1)) - 1 + if max_positive < required: + levels += 1 + + return max(levels, 1) + + +def compute_num_digits_fold_code(r_vars: int, challenge_mass: int) -> int: + """Fold-digit count from the generic dense proxy used by the code path.""" + + beta = challenge_mass * (1 << (r_vars + LOG_BASIS - 1)) + return compute_num_digits(beta.bit_length(), LOG_BASIS) + + +def compute_num_digits_fold_tight(r_vars: int) -> int: + """Fold-digit count from the tighter onehot-aware bound on z_pre.""" + + # A monomial times a sparse challenge only sees the maximum absolute + # challenge coefficient, not the full challenge L1 mass. 
+ beta = (1 << r_vars) * MAX_ABS_CHALLENGE_COEFF + return compute_num_digits(beta.bit_length(), LOG_BASIS) + + +def describe_challenge_family(challenge_mass: int) -> str: + if challenge_mass == 54: + return "rigorous split k=32 family C_{21,<=6}" + if challenge_mass == 56: + return "rigorous split k=32 family C_{22,<=6}" + if challenge_mass == 50: + return "legacy raw direct full-ring shell (28,11) [not yet proven]" + if challenge_mass == 51: + return "legacy raw direct full-ring shell (31,10) [not yet proven]" + return f"custom D=64 {{+/-1, +/-2}} family with conservative L1 mass {challenge_mass}" + + +def challenge_family_definition_lines(challenge_mass: int) -> list[str]: + """Return a self-contained description of the selected challenge family.""" + + if challenge_mass in RIGOROUS_SPLIT_BY_MASS: + w, m = RIGOROUS_SPLIT_BY_MASS[challenge_mass] + return [ + "This is the rigorous split family.", + "Think of a challenge as an even half plus an odd half.", + f"In each half: choose exactly {w} active slots out of 32, give them signs +/- ,", + f"and allow magnitude 2 on at most {m} of those active slots.", + "Then interleave the two halves back into one ring element as a0(X^2) + X a1(X^2).", + f"Selected parameters: w = {w}, m = {m}.", + f"Exact support = 2w = {2 * w}.", + f"Conservative challenge_mass = 2(w + m) = {2 * (w + m)}.", + ] + + if challenge_mass in LEGACY_RAW_SHELL_BY_MASS: + n1, n2 = LEGACY_RAW_SHELL_BY_MASS[challenge_mass] + return [ + "This is the older direct full-ring comparison shell.", + f"Choose exactly {n1} positions with coefficients +/-1 and exactly {n2} positions with coefficients +/-2", + "across all 64 slots of the ring element.", + f"Selected parameters: n1 = {n1}, n2 = {n2}.", + f"Exact support = n1 + n2 = {n1 + n2}.", + f"Conservative challenge_mass = n1 + 2*n2 = {n1 + 2 * n2}.", + "This shell is kept only for comparison; it is not proven.", + ] + + return [ + "Custom conservative model with no named exact family attached.", + "The script 
will treat challenge_mass only as an L1 proxy in the folded bound.", + f"Selected challenge_mass = {challenge_mass}.", + ] + + +def best_layout_onehot(nv: int, n_a: int, challenge_mass: int) -> Layout: + """Choose the cheapest (m_vars, r_vars) split for a given nv and N_A.""" + + alpha = D.bit_length() - 1 + reduced_vars = nv - alpha + if reduced_vars <= 1: + raise ValueError(f"nv={nv} is too small for D={D}") + + best: tuple[int, int, int, int] | None = None + for r_vars in range(1, reduced_vars): + m_vars = reduced_vars - r_vars + delta_fold_tight = compute_num_digits_fold_tight(r_vars) + cost = ( + (DELTA_OPEN + n_a * DELTA_COMMIT) * (1 << r_vars) + + DELTA_COMMIT * delta_fold_tight * (1 << m_vars) + ) + candidate = (cost, m_vars, r_vars, delta_fold_tight) + if best is None or candidate < best: + best = candidate + + assert best is not None + _, m_vars, r_vars, delta_fold_tight = best + delta_fold_code = compute_num_digits_fold_code(r_vars, challenge_mass) + return Layout( + nv=nv, + m_vars=m_vars, + r_vars=r_vars, + delta_fold_tight=delta_fold_tight, + delta_fold_code=delta_fold_code, + ) + + +def estimate_sec_bits(SIS, RC, log, rank: int, width_ring_elems: int, collision_inf: int) -> float: + """Estimate security in bits for one extracted SIS instance.""" + + n = rank * D + m = width_ring_elems * D + length_bound = (m**0.5) * collision_inf + out = SIS.lattice( + SIS.Parameters(n=n, q=Q, m=m, length_bound=length_bound, norm=2, tag="repro"), + red_cost_model=RC.BDGL16, + red_shape_model="lgsa", + log_level=0, + ) + return float(log(out["rop"], 2)) + + +def main_configuration_estimates( + SIS, RC, log, layout: Layout, n_a: int, n_b: int, n_d: int +) -> list[LayerEstimate]: + """Estimate all reported SIS layers for one chosen parameter point.""" + + inner_width = layout.inner_width + outer_width = layout.outer_width(n_a) + d_matrix_width = layout.d_matrix_width + + a_bits = estimate_sec_bits(SIS, RC, log, n_a, inner_width, 2) + b_bits = estimate_sec_bits(SIS, RC, 
log, n_b, outer_width, 7)
    d_bits = estimate_sec_bits(SIS, RC, log, n_d, d_matrix_width, 7)

    # Folded (M) instance: l_inf collision bound from the onehot-aware
    # z_pre bound, rank is the sum of the three layer ranks plus 2.
    m_collision_inf = 2 * ((1 << layout.r_vars) * MAX_ABS_CHALLENGE_COEFF)
    m_rank = n_a + n_b + n_d + 2
    # Two M widths: the generic code-path fold-digit proxy vs the tighter
    # onehot-aware fold-digit count from the layout.
    m_code_width = d_matrix_width + outer_width + inner_width * layout.delta_fold_code
    m_tight_width = d_matrix_width + outer_width + inner_width * layout.delta_fold_tight
    m_code_bits = estimate_sec_bits(SIS, RC, log, m_rank, m_code_width, m_collision_inf)
    m_tight_bits = estimate_sec_bits(SIS, RC, log, m_rank, m_tight_width, m_collision_inf)

    return [
        LayerEstimate("A_fullwidth", a_bits, n_a, inner_width, 2),
        LayerEstimate("B", b_bits, n_b, outer_width, 7),
        LayerEstimate("D", d_bits, n_d, d_matrix_width, 7),
        LayerEstimate("M_code", m_code_bits, m_rank, m_code_width, m_collision_inf),
        LayerEstimate("M_tight", m_tight_bits, m_rank, m_tight_width, m_collision_inf),
    ]


def sweep_rank1_cutoff(SIS, RC, log, min_nv: int, max_nv: int, n_a: int, challenge_mass: int):
    """Sweep nv with N_B = N_D = 1 and report the 128-bit cutoff.

    Returns:
        (rows, cutoff): `rows` holds one dict per nv with the layout split
        and per-layer security bits; `cutoff` is the largest swept nv whose
        overall floor is >= 128 bits, or None if none reaches 128 bits.
    """

    rows = []
    for nv in range(min_nv, max_nv + 1):
        # Cheapest (m_vars, r_vars) split for this nv at inner rank n_a.
        layout = best_layout_onehot(nv, n_a=n_a, challenge_mass=challenge_mass)
        outer_width = layout.outer_width(n_a)
        d_matrix_width = layout.d_matrix_width
        # M-instance rank with N_B = N_D = 1 fixed (n_a + 1 + 1 + 2).
        m_rank = n_a + 1 + 1 + 2
        m_collision_inf = 2 * ((1 << layout.r_vars) * MAX_ABS_CHALLENGE_COEFF)
        # The sweep only evaluates M with the tighter onehot-aware width.
        m_tight_width = d_matrix_width + outer_width + layout.inner_width * layout.delta_fold_tight

        a_bits = estimate_sec_bits(SIS, RC, log, n_a, layout.inner_width, 2)
        b_bits = estimate_sec_bits(SIS, RC, log, 1, outer_width, 7)
        d_bits = estimate_sec_bits(SIS, RC, log, 1, d_matrix_width, 7)
        m_bits = estimate_sec_bits(SIS, RC, log, m_rank, m_tight_width, m_collision_inf)
        # Overall floor for this nv: the weakest of the four instances.
        overall = min(a_bits, b_bits, d_bits, m_bits)
        rows.append(
            {
                "nv": nv,
                "m_vars": layout.m_vars,
                "r_vars": layout.r_vars,
                "delta_fold": layout.delta_fold_tight,
                "A_bits": a_bits,
                "B_bits": b_bits,
                "D_bits": d_bits,
                "BD_floor_bits": 
min(b_bits, d_bits), + "M_bits": m_bits, + "overall_bits": overall, + } + ) + + at_least_128 = [row["nv"] for row in rows if row["overall_bits"] >= 128.0] + cutoff = max(at_least_128) if at_least_128 else None + return rows, cutoff + + +def fmt(bits: float) -> str: + return f"{bits:.2f}" + + +def print_header(title: str) -> None: + print() + print(title) + print("=" * len(title)) + + +def print_terminology() -> None: + """Print the glossary used by the report.""" + + print_header("Terminology") + for term, definition in TERMINOLOGY_LINES: + print(f"- {term}: {definition}") + + +def print_challenge_family_definition(challenge_mass: int) -> None: + """Print the exact family definition associated with the chosen mass.""" + + print_header("Challenge Family Definition") + for line in challenge_family_definition_lines(challenge_mass): + print(line) + + +def print_intro(estimator_repo: Path, challenge_mass: int) -> None: + print_header("Hachi D=64, K=256 one-hot SIS estimator") + print(f"repo_root = {repo_root()}") + print(f"estimator_repo = {estimator_repo}") + print(f"field modulus = {Q_LABEL}") + print(f"ring degree D = {D}") + print(f"one-hot chunk size K = {K}") + print( + f"one-hot sparsity = 1-of-{K} " + f"(equiv. 
1-sparse over {K} slots, density = {100.0 / K:.2f}%)" + ) + print(f"packing alpha = log2(D) = {ALPHA}") + print(f"digit basis = 2^{LOG_BASIS} = {1 << LOG_BASIS}") + print(f"challenge family = {describe_challenge_family(challenge_mass)}") + print(f"challenge mass (L1) = {challenge_mass}") + print(f"folded z bound = onehot-aware: ||z_pre||_inf <= 2^r_vars * {MAX_ABS_CHALLENGE_COEFF}") + print(f"estimator model = BDGL16 + lgsa") + + +def print_main_configuration( + layout: Layout, + estimates: list[LayerEstimate], + n_a: int, + n_b: int, + n_d: int, + challenge_mass: int, +) -> None: + inner_width = layout.inner_width + outer_width = layout.outer_width(n_a) + d_matrix_width = layout.d_matrix_width + m_code_width = d_matrix_width + outer_width + inner_width * layout.delta_fold_code + m_tight_width = d_matrix_width + outer_width + inner_width * layout.delta_fold_tight + m_collision_inf = 2 * ((1 << layout.r_vars) * MAX_ABS_CHALLENGE_COEFF) + + print_header("Main configuration estimate") + print("This section estimates a single parameter point for the family above.") + print() + print(f"main nv = {layout.nv}") + print(f"N_A, N_B, N_D = {n_a}, {n_b}, {n_d}") + print(f"challenge family = {describe_challenge_family(challenge_mass)}") + print(f"challenge mass (L1) = {challenge_mass}") + print(f"reduced vars = nv - alpha = {layout.nv - ALPHA}") + print(f"layout = (m_vars={layout.m_vars}, r_vars={layout.r_vars})") + print(f"num_blocks = 2^r_vars = {layout.num_blocks}") + print(f"block_len = 2^m_vars = {layout.block_len}") + print(f"delta_fold_tight = {layout.delta_fold_tight}") + print(f"delta_fold_code = {layout.delta_fold_code}") + print(f"inner_width = {inner_width}") + print(f"outer_width = {outer_width}") + print(f"d_matrix_width = {d_matrix_width}") + print(f"M_code width = {m_code_width}") + print(f"M_tight width = {m_tight_width}") + print(f"collision_inf(M) = {m_collision_inf}") + print() + print(f"{'layer':<12} {'sec_bits':>10} {'rank':>8} {'width_ring':>14} 
{'collision_inf':>14}") + for estimate in estimates: + print( + f"{estimate.name:<12} {fmt(estimate.sec_bits):>10} " + f"{estimate.rank:>8} {estimate.width_ring_elems:>14} {estimate.collision_inf:>14}" + ) + overall = min(estimate.sec_bits for estimate in estimates) + overall_layer = min(estimates, key=lambda estimate: estimate.sec_bits).name + print() + print("Layer legend:") + print("- A_fullwidth: conservative full-support proxy for the inner A layer") + print("- B / D: outer commitment layers with digit-collision bound 7") + print("- M_code: folded witness width using the generic code-style delta_fold proxy") + print("- M_tight: folded witness width using the tighter onehot-aware delta_fold") + print() + print(f"overall floor = {fmt(overall)} bits ({overall_layer})") + + +def print_sweep(rows: list[dict], cutoff: int | None) -> None: + print_header("Rank-1 cutoff sweep") + print("This sweep fixes N_B = N_D = 1 and searches for the largest nv with overall >= 128 bits.") + print( + "Columns: nv = total variables, m_vars/r_vars = layout split, " + "d_fold = delta_fold_tight, A/B/D/M = security bits by layer." 
+ ) + print() + print( + f"{'nv':>4} {'m_vars':>7} {'r_vars':>7} {'d_fold':>7} " + f"{'A':>8} {'B/D':>8} {'M':>8} {'overall':>8}" + ) + for row in rows: + print( + f"{row['nv']:>4} {row['m_vars']:>7} {row['r_vars']:>7} {row['delta_fold']:>7} " + f"{fmt(row['A_bits']):>8} {fmt(row['BD_floor_bits']):>8} " + f"{fmt(row['M_bits']):>8} {fmt(row['overall_bits']):>8}" + ) + print() + if cutoff is None: + print("largest nv with overall >= 128 bits: none in sweep") + else: + print(f"largest nv with overall >= 128 bits: {cutoff}") + + +def main() -> None: + args = parse_args() + estimator_repo = locate_estimator_repo(args.estimator_path) + SIS, RC, log = load_estimator(estimator_repo) + + main_layout = best_layout_onehot( + args.candidate_nv, + n_a=args.candidate_na, + challenge_mass=args.challenge_mass, + ) + main_estimates = main_configuration_estimates( + SIS, + RC, + log, + main_layout, + n_a=args.candidate_na, + n_b=args.candidate_nb, + n_d=args.candidate_nd, + ) + rows, cutoff = sweep_rank1_cutoff( + SIS, + RC, + log, + min_nv=args.sweep_min_nv, + max_nv=args.sweep_max_nv, + n_a=args.candidate_na, + challenge_mass=args.challenge_mass, + ) + + print_intro(estimator_repo, args.challenge_mass) + print_challenge_family_definition(args.challenge_mass) + print_terminology() + print_main_configuration( + main_layout, + main_estimates, + n_a=args.candidate_na, + n_b=args.candidate_nb, + n_d=args.candidate_nd, + challenge_mass=args.challenge_mass, + ) + print_sweep(rows, cutoff) + + +if __name__ == "__main__": + main() diff --git a/scripts/onehot_bench_report.py b/scripts/onehot_bench_report.py new file mode 100644 index 00000000..bcfcbce1 --- /dev/null +++ b/scripts/onehot_bench_report.py @@ -0,0 +1,396 @@ +#!/usr/bin/env python3 +import argparse +import json +import os +import pathlib +import re +import shlex +import subprocess +import sys +from dataclasses import dataclass +from datetime import datetime, timezone + + +ANSI_RE = re.compile(r"\x1b\[[0-9;]*m") +KV_RE = 
re.compile(r'([A-Za-z_]+)=(".*?"|\S+)') +RSS_PATTERNS = [ + re.compile(r"Maximum resident set size \(kbytes\):\s+(\d+)"), + re.compile(r"^\s*(\d+)\s+maximum resident set size$", re.MULTILINE), +] +ONEHOT_ARITY = 256 + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser( + description="Run and render the Hachi onehot benchmark report." + ) + subparsers = parser.add_subparsers(dest="command", required=True) + + run_parser = subparsers.add_parser("run", help="Run the benchmark and write summary files.") + run_parser.add_argument("--binary", required=True, help="Path to the benchmark binary.") + run_parser.add_argument( + "--output-dir", required=True, help="Directory where logs and summary.json are written." + ) + run_parser.add_argument("--mode", default="onehot", help="Benchmark mode.") + run_parser.add_argument("--num-vars", type=int, default=32, help="Number of variables.") + + render_parser = subparsers.add_parser( + "render", help="Render a markdown report from summary.json files." 
    )
    render_parser.add_argument("summary", help="Path to the current summary.json file.")
    render_parser.add_argument(
        "--main-baseline-dir",
        default="",
        help="Optional artifact directory containing the main-baseline summary.json.",
    )
    render_parser.add_argument(
        "--previous-baseline-dir",
        default="",
        help="Optional artifact directory containing the previous-run summary.json.",
    )

    return parser.parse_args()


def parse_kvs(line: str) -> dict[str, str]:
    """Extract key=value pairs from one log line.

    ANSI colour sequences are stripped first. Values may be bare tokens or
    double-quoted strings; surrounding quotes and a trailing comma are
    removed. A key appearing twice keeps its last value.
    """
    line = ANSI_RE.sub("", line)
    out: dict[str, str] = {}
    for key, raw_value in KV_RE.findall(line):
        # Bare tokens can carry a trailing comma from structured log output.
        value = raw_value.rstrip(",")
        if value.startswith('"') and value.endswith('"'):
            value = value[1:-1]
        out[key] = value
    return out


def write_text(path: pathlib.Path, text: str) -> None:
    """Write `text` to `path` as UTF-8, creating parent directories as needed."""
    path.parent.mkdir(parents=True, exist_ok=True)
    path.write_text(text, encoding="utf-8")


def time_command(binary: str) -> list[str]:
    """Wrap `binary` in /usr/bin/time with the platform's verbose flag.

    macOS ships BSD time (`-l`); every other platform is assumed to have
    GNU time's `-v`.
    """
    if sys.platform == "darwin":
        return ["/usr/bin/time", "-l", binary]
    return ["/usr/bin/time", "-v", binary]


def require_float(summary: dict[str, object], key: str) -> float:
    """Return summary[key] coerced to float; raise ValueError if absent."""
    value = summary.get(key)
    if value is None:
        raise ValueError(f"missing required metric: {key}")
    return float(value)


def require_int(summary: dict[str, object], key: str) -> int:
    """Return summary[key] coerced to int; raise ValueError if absent."""
    value = summary.get(key)
    if value is None:
        raise ValueError(f"missing required metric: {key}")
    return int(value)


def derive_hachi_labrador_split(total: float, hachi: float, labrador: float) -> tuple[float, float]:
    """Split a total duration into (hachi, labrador) components.

    If only one component was observed (the other is 0.0), the missing one
    is derived as the non-negative remainder of the total. If neither was
    observed, the entire total is attributed to hachi.
    """
    if hachi == 0.0 and labrador == 0.0:
        return total, 0.0
    if hachi == 0.0:
        return max(total - labrador, 0.0), labrador
    if labrador == 0.0:
        return hachi, max(total - hachi, 0.0)
    return hachi, labrador


def benchmark_name(mode: str, num_vars: int) -> str:
    """Human-readable benchmark label used in summaries and report headers."""
    if mode == "onehot":
        return f"1-of-{ONEHOT_ARITY} one-hot with {num_vars} variables"
    return f"{mode} with {num_vars} variables"


def extract_summary(log_text: str, mode: str, num_vars: int) -> 
dict[str, object]: + summary: dict[str, object] = { + "schema_version": 1, + "benchmark": benchmark_name(mode, num_vars), + "mode": mode, + "num_vars": num_vars, + "collected_at": datetime.now(timezone.utc).isoformat(), + } + + for line in log_text.splitlines(): + line = ANSI_RE.sub("", line) + kvs = parse_kvs(line) + if " INFO setup" in line and kvs.get("label") == mode: + summary["setup_s"] = float(kvs["elapsed_s"]) + elif " INFO commit" in line and kvs.get("label") == mode: + summary["commit_s"] = float(kvs["elapsed_s"]) + elif "hachi prove complete" in line: + summary["prove_hachi_s"] = float(kvs["elapsed_s"]) + if "levels" in kvs: + summary["hachi_levels"] = int(kvs["levels"]) + elif "labrador prove complete" in line: + summary["prove_labrador_s"] = float(kvs["elapsed_s"]) + if "levels" in kvs: + summary["labrador_levels"] = int(kvs["levels"]) + elif " INFO prove" in line and kvs.get("label") == mode: + summary["prove_total_s"] = float(kvs["elapsed_s"]) + elif "hachi verify complete" in line: + summary["verify_hachi_s"] = float(kvs["elapsed_s"]) + elif "labrador verify complete" in line: + summary["verify_labrador_s"] = float(kvs["elapsed_s"]) + if "levels" in kvs and "labrador_levels" not in summary: + summary["labrador_levels"] = int(kvs["levels"]) + elif "verify OK" in line and kvs.get("label") == mode: + summary["verify_total_s"] = float(kvs["elapsed_s"]) + elif "proof summary" in line and kvs.get("label") == mode: + summary["proof_size_bytes"] = int(kvs["proof_size_bytes"]) + summary["hachi_fold_bytes"] = int(kvs["hachi_fold_bytes"]) + summary["tail_bytes"] = int(kvs["tail_bytes"]) + if "levels" in kvs and "hachi_levels" not in summary: + summary["hachi_levels"] = int(kvs["levels"]) + elif "estimated tail comparison" in line: + if "selected_tail" in kvs: + summary["selected_tail"] = kvs["selected_tail"] + if "packed_direct_bytes" in kvs: + summary["packed_direct_bytes"] = int(kvs["packed_direct_bytes"]) + if "estimated_labrador_tail_bytes" in kvs: + 
summary["estimated_labrador_tail_bytes"] = int( + kvs["estimated_labrador_tail_bytes"] + ) + + for index, pattern in enumerate(RSS_PATTERNS): + rss_match = pattern.search(log_text) + if rss_match: + rss_value = int(rss_match.group(1)) + if index == 1 and sys.platform == "darwin": + rss_value //= 1024 + summary["max_rss_kib"] = rss_value + break + + prove_total = require_float(summary, "prove_total_s") + prove_hachi = float(summary.get("prove_hachi_s", 0.0)) + prove_labrador = float(summary.get("prove_labrador_s", 0.0)) + summary["prove_hachi_s"], summary["prove_labrador_s"] = derive_hachi_labrador_split( + prove_total, + prove_hachi, + prove_labrador, + ) + + verify_total = require_float(summary, "verify_total_s") + verify_hachi = float(summary.get("verify_hachi_s", 0.0)) + verify_labrador = float(summary.get("verify_labrador_s", 0.0)) + summary["verify_hachi_s"], summary["verify_labrador_s"] = derive_hachi_labrador_split( + verify_total, + verify_hachi, + verify_labrador, + ) + + summary.setdefault("selected_tail", "unknown") + return summary + + +def run_benchmark(args: argparse.Namespace) -> int: + output_dir = pathlib.Path(args.output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + + env = os.environ.copy() + env["HACHI_MODE"] = args.mode + env["HACHI_NUM_VARS"] = str(args.num_vars) + env.setdefault("HACHI_PROFILE_TRACE", "0") + env.setdefault("HACHI_PROFILE_SPAN_CLOSES", "0") + env.setdefault("HACHI_PROFILE_LOG", "info") + env.setdefault("HACHI_PROFILE_ANSI", "0") + + command = time_command(args.binary) + completed = subprocess.run(command, capture_output=True, text=True, env=env) + combined_log = completed.stdout + completed.stderr + + write_text(output_dir / "stdout.log", completed.stdout) + write_text(output_dir / "stderr.log", completed.stderr) + write_text(output_dir / "benchmark.log", combined_log) + write_text(output_dir / "command.txt", " ".join(shlex.quote(part) for part in command) + "\n") + + if completed.returncode != 0: + return 
completed.returncode

    summary = extract_summary(combined_log, mode=args.mode, num_vars=args.num_vars)
    # Record run provenance alongside the parsed metrics.
    summary["command"] = command
    summary["binary"] = args.binary
    summary["exit_code"] = completed.returncode
    summary["env"] = {
        "HACHI_MODE": env["HACHI_MODE"],
        "HACHI_NUM_VARS": env["HACHI_NUM_VARS"],
        "HACHI_PROFILE_TRACE": env["HACHI_PROFILE_TRACE"],
        "HACHI_PROFILE_SPAN_CLOSES": env["HACHI_PROFILE_SPAN_CLOSES"],
        "HACHI_PROFILE_LOG": env["HACHI_PROFILE_LOG"],
        "HACHI_PROFILE_ANSI": env["HACHI_PROFILE_ANSI"],
    }

    write_text(output_dir / "summary.json", json.dumps(summary, indent=2, sort_keys=True) + "\n")
    return 0


def load_summary(path: pathlib.Path) -> dict[str, object]:
    """Parse one summary.json file into a dict."""
    return json.loads(path.read_text(encoding="utf-8"))


def load_optional_summary(dir_path: str) -> dict[str, object] | None:
    """Load <dir_path>/summary.json; None if dir_path is empty or file absent."""
    if not dir_path:
        return None
    summary_path = pathlib.Path(dir_path) / "summary.json"
    if not summary_path.exists():
        return None
    return load_summary(summary_path)


def commit_ref(sha: str | None) -> str | None:
    """Render a short commit reference, hyperlinked when GITHUB_REPOSITORY is set."""
    if not sha:
        return None
    short = sha[:7]
    repo = os.environ.get("GITHUB_REPOSITORY")
    if repo:
        return f"[`{short}`](https://github.com/{repo}/commit/{sha})"
    return f"`{short}`"


def fmt_seconds(value: float) -> str:
    # Durations with millisecond precision.
    return f"{value:.3f}"


def fmt_mib(value_kib: float) -> str:
    # KiB -> MiB with one decimal place.
    return f"{value_kib / 1024.0:.1f}"


def fmt_bytes(value: float) -> str:
    # Thousands-separated integer byte count.
    return f"{int(round(value)):,}"


@dataclass(frozen=True)
class Metric:
    """One row of the rendered report's metric table."""

    # summary.json key holding the raw value.
    key: str
    # Display name shown in the table's first column.
    name: str
    # Unit label printed in the table's last column.
    unit: str
    # Formats the raw float for display (e.g. fmt_seconds, fmt_mib).
    value_formatter: callable


TIME_METRICS = [
    Metric("setup_s", "Setup", "s", fmt_seconds),
    Metric("commit_s", "Commit", "s", fmt_seconds),
    Metric("prove_hachi_s", "Prove (Hachi)", "s", fmt_seconds),
    Metric("prove_labrador_s", "Prove (Labrador)", "s", fmt_seconds),
    Metric("prove_total_s", "Prove (Total)", "s", fmt_seconds),
    Metric("verify_hachi_s", "Verify (Hachi)", "s", fmt_seconds),
    Metric("verify_labrador_s", "Verify 
(Labrador)", "s", fmt_seconds), + Metric("verify_total_s", "Verify (Total)", "s", fmt_seconds), + Metric("max_rss_kib", "Max RSS", "MiB", fmt_mib), +] + + +def render_metric_row( + metric: Metric, + current: dict[str, object], + baselines: list[tuple[str, dict[str, object] | None]], +) -> str: + current_value = current.get(metric.key) + if current_value is None: + return "" + + columns: list[str] = [] + for _, summary in baselines: + if summary is None or summary.get(metric.key) is None: + columns.append("n/a") + else: + columns.append(metric.value_formatter(float(summary[metric.key]))) + + columns.append(metric.value_formatter(float(current_value))) + return f"| {metric.name} | " + " | ".join(columns) + f" | {metric.unit} |" + + +def render_report(args: argparse.Namespace) -> int: + summary_path = pathlib.Path(args.summary) + current = load_summary(summary_path) + + baselines: list[tuple[str, dict[str, object] | None]] = [ + ("Main baseline", load_optional_summary(args.main_baseline_dir)), + ("Previous run", load_optional_summary(args.previous_baseline_dir)), + ] + visible_baselines = [(label, summary) for label, summary in baselines if summary is not None] + + source_sha = os.environ.get("HACHI_BENCH_SOURCE_SHA") + source_subject = os.environ.get("HACHI_BENCH_SOURCE_SUBJECT") + source_branch = os.environ.get("HACHI_BENCH_SOURCE_BRANCH") or os.environ.get("GITHUB_REF_NAME") + main_baseline_sha = os.environ.get("HACHI_BENCH_MAIN_BASELINE_SHA") + main_baseline_label = os.environ.get("HACHI_BENCH_MAIN_BASELINE_LABEL") + previous_baseline_sha = os.environ.get("HACHI_BENCH_PREVIOUS_BASELINE_SHA") + previous_baseline_label = os.environ.get("HACHI_BENCH_PREVIOUS_BASELINE_LABEL") + + print("## One-hot 32 Variables Benchmark Report") + print() + print(f"- Benchmark: `{benchmark_name(current['mode'], int(current['num_vars']))}`") + if current["mode"] == "onehot": + print( + f"- Sparsity: `1-of-{ONEHOT_ARITY}` one-hot " + f"(equivalently, `1`-sparse over `{ONEHOT_ARITY}` 
slots, density `{100.0 / ONEHOT_ARITY:.2f}%`)." + ) + ref = commit_ref(source_sha) + if ref: + print(f"- Latest run: {ref}") + if source_subject: + print(f"- Message: {source_subject}") + if source_branch: + print(f"- Ref: `{source_branch}`") + if visible_baselines: + main_ref = commit_ref(main_baseline_sha) + if baselines[0][1] is not None: + if main_ref and main_baseline_label: + print(f"- Main baseline: {main_ref} from {main_baseline_label}.") + elif main_ref: + print(f"- Main baseline: {main_ref}.") + elif main_baseline_label: + print(f"- Main baseline: {main_baseline_label}.") + + previous_ref = commit_ref(previous_baseline_sha) + if baselines[1][1] is not None: + if previous_ref and previous_baseline_label: + print(f"- Previous run: {previous_ref} from {previous_baseline_label}.") + elif previous_ref: + print(f"- Previous run: {previous_ref}.") + elif previous_baseline_label: + print(f"- Previous run: {previous_baseline_label}.") + print( + "- Command: `target/release/examples/profile` with " + f"`HACHI_MODE={current['mode']}` `HACHI_NUM_VARS={current['num_vars']}` " + "`HACHI_PROFILE_TRACE=0` `HACHI_PROFILE_SPAN_CLOSES=0` " + "`HACHI_PROFILE_LOG=info` `HACHI_PROFILE_ANSI=0`." 
+ ) + print("- Memory: maximum resident set size from `/usr/bin/time` on the benchmark process.") + print() + + column_labels = [label for label, _ in visible_baselines] + ["Latest run"] + print("| Metric | " + " | ".join(column_labels) + " | Unit |") + print("| --- | " + " | ".join("---:" for _ in column_labels) + " | --- |") + + for metric in TIME_METRICS: + row = render_metric_row(metric, current, visible_baselines) + if row: + print(row) + + print() + print(f"- Tail: `{current.get('selected_tail', 'unknown')}`") + if current.get("proof_size_bytes") is not None: + print(f"- Proof size: `{fmt_bytes(float(current['proof_size_bytes']))} B`") + if current.get("hachi_levels") is not None: + print(f"- Hachi levels: `{current['hachi_levels']}`") + if current.get("labrador_levels") is not None: + print(f"- Labrador levels: `{current['labrador_levels']}`") + + return 0 + + +def main() -> int: + args = parse_args() + if args.command == "run": + return run_benchmark(args) + if args.command == "render": + return render_report(args) + raise ValueError(f"unsupported command: {args.command}") + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/src/algebra/backend/mod.rs b/src/algebra/backend/mod.rs new file mode 100644 index 00000000..3b885ce1 --- /dev/null +++ b/src/algebra/backend/mod.rs @@ -0,0 +1,7 @@ +//! Backend contracts and concrete backend implementations. + +pub mod scalar; +pub mod traits; + +pub use scalar::ScalarBackend; +pub use traits::{CrtReconstruct, NttPrimeOps, NttTransform, RingBackend}; diff --git a/src/algebra/backend/scalar.rs b/src/algebra/backend/scalar.rs new file mode 100644 index 00000000..a1386d72 --- /dev/null +++ b/src/algebra/backend/scalar.rs @@ -0,0 +1,94 @@ +//! Default scalar backend: delegates to NTT kernels and uses Garner's +//! algorithm for CRT reconstruction. 
+ +use super::traits::{CrtReconstruct, NttPrimeOps, NttTransform}; +use crate::algebra::ntt::butterfly::{forward_ntt, inverse_ntt, NttTwiddles}; +use crate::algebra::ntt::crt::GarnerData; +use crate::algebra::ntt::prime::{MontCoeff, NttPrime, PrimeWidth}; +use crate::algebra::ring::CrtNttConvertibleField; + +/// Default scalar backend implementation. +#[derive(Debug, Clone, Copy, Default)] +pub struct ScalarBackend; + +impl NttPrimeOps for ScalarBackend { + #[inline] + fn from_canonical(prime: NttPrime, value: W) -> MontCoeff { + prime.from_canonical(value) + } + + #[inline] + fn to_canonical(prime: NttPrime, value: MontCoeff) -> W { + prime.to_canonical(value) + } + + #[inline] + fn reduce_range(prime: NttPrime, value: MontCoeff) -> MontCoeff { + prime.reduce_range(value) + } + + #[inline] + fn pointwise_mul( + prime: NttPrime, + out: &mut [MontCoeff; D], + lhs: &[MontCoeff; D], + rhs: &[MontCoeff; D], + ) { + prime.pointwise_mul(out, lhs, rhs); + } +} + +impl NttTransform for ScalarBackend { + #[inline] + fn forward_ntt(limb: &mut [MontCoeff; D], prime: NttPrime, twiddles: &NttTwiddles) { + forward_ntt(limb, prime, twiddles); + } + + #[inline] + fn inverse_ntt(limb: &mut [MontCoeff; D], prime: NttPrime, twiddles: &NttTwiddles) { + inverse_ntt(limb, prime, twiddles); + } +} + +impl CrtReconstruct for ScalarBackend { + fn reconstruct( + primes: &[NttPrime; K], + canonical: &[[W; D]; K], + garner: &GarnerData, + ) -> [F; D] { + let mut coeffs = [F::zero(); D]; + for (d, coeff) in coeffs.iter_mut().enumerate() { + // Garner mixed-radix decomposition (all arithmetic in i64, mod p_i). 
+ let mut v = [0i64; K]; + v[0] = canonical[0][d].to_i64(); + for i in 1..K { + let pi = primes[i].p.to_i64(); + let mut temp = canonical[i][d].to_i64(); + #[allow(clippy::needless_range_loop)] + for j in 0..i { + temp -= v[j]; + temp = ((temp % pi) + pi) % pi; + temp = (temp * garner.gamma[i][j].to_i64()) % pi; + } + // Center the mixed-radix digit to keep the final reconstruction + // in a small signed range when inputs are centered. + if temp > pi / 2 { + temp -= pi; + } + v[i] = temp; + } + + // Horner accumulation in the target field F. + let mut result = F::from_i64(v[0]); + let mut partial_prod = F::from_i64(primes[0].p.to_i64()); + for i in 1..K { + result += F::from_i64(v[i]) * partial_prod; + if i + 1 < K { + partial_prod = partial_prod * F::from_i64(primes[i].p.to_i64()); + } + } + *coeff = result; + } + coeffs + } +} diff --git a/src/algebra/backend/traits.rs b/src/algebra/backend/traits.rs new file mode 100644 index 00000000..118a5f1f --- /dev/null +++ b/src/algebra/backend/traits.rs @@ -0,0 +1,59 @@ +//! Backend traits for CRT+NTT execution semantics. +//! +//! All traits are generic over `W: PrimeWidth` to support both +//! `i16` (primes < 2^14) and `i32` (primes < 2^30) NTT backends. + +use crate::algebra::ntt::butterfly::NttTwiddles; +use crate::algebra::ntt::crt::GarnerData; +use crate::algebra::ntt::prime::{MontCoeff, NttPrime, PrimeWidth}; +use crate::algebra::ring::CrtNttConvertibleField; + +/// Per-prime arithmetic primitives used by CRT+NTT domains. +pub trait NttPrimeOps { + /// Convert canonical coefficient to backend prime representation. + fn from_canonical(prime: NttPrime, value: W) -> MontCoeff; + + /// Convert backend prime representation back to canonical coefficient. + fn to_canonical(prime: NttPrime, value: MontCoeff) -> W; + + /// Range-reduce one coefficient from `(-2p, 2p)` to `(-p, p)`. + fn reduce_range(prime: NttPrime, value: MontCoeff) -> MontCoeff; + + /// Pointwise multiplication in backend prime representation. 
+ fn pointwise_mul( + prime: NttPrime, + out: &mut [MontCoeff; D], + lhs: &[MontCoeff; D], + rhs: &[MontCoeff; D], + ); +} + +/// Forward/inverse transform kernels for one NTT limb. +pub trait NttTransform { + /// Forward transform from coefficient limb to NTT limb. + fn forward_ntt(limb: &mut [MontCoeff; D], prime: NttPrime, twiddles: &NttTwiddles); + + /// Inverse transform from NTT limb to coefficient limb. + fn inverse_ntt(limb: &mut [MontCoeff; D], prime: NttPrime, twiddles: &NttTwiddles); +} + +/// CRT reconstruction from per-prime canonical coefficients via Garner's algorithm. +pub trait CrtReconstruct { + /// Reconstruct coefficient-domain values from canonical CRT residues. + fn reconstruct( + primes: &[NttPrime; K], + canonical_limbs: &[[W; D]; K], + garner: &GarnerData, + ) -> [F; D]; +} + +/// Convenience composition trait for full ring backend capability. +pub trait RingBackend: + NttPrimeOps + NttTransform + CrtReconstruct +{ +} + +impl RingBackend for T where + T: NttPrimeOps + NttTransform + CrtReconstruct +{ +} diff --git a/src/algebra/fields/ext.rs b/src/algebra/fields/ext.rs new file mode 100644 index 00000000..b46cbbf9 --- /dev/null +++ b/src/algebra/fields/ext.rs @@ -0,0 +1,789 @@ +//! Quadratic and quartic extension fields. + +use super::wide::{AccumPair, HasUnreducedOps}; +use crate::algebra::module::VectorModule; +use crate::primitives::serialization::{ + Compress, HachiDeserialize, HachiSerialize, SerializationError, Valid, Validate, +}; +use crate::{AdditiveGroup, FieldCore, FieldSampling, FromSmallInt}; + +/// `Fp2Config` with non-residue = -1. +/// +/// Valid when `p ≡ 3 (mod 4)`, i.e. -1 is a quadratic non-residue. +pub struct NegOneNr; + +impl Fp2Config for NegOneNr { + const IS_NEG_ONE: bool = true; + + fn non_residue() -> F { + -F::one() + } +} + +/// `Fp2Config` with non-residue = 2. +/// +/// Valid when `p ≡ 5 (mod 8)`, i.e. 2 is a quadratic non-residue. 
+/// All Hachi pseudo-Mersenne primes (`2^k - c` with `c ≡ 3 mod 8`) +/// satisfy this. +pub struct TwoNr; + +impl Fp2Config for TwoNr { + fn non_residue() -> F { + F::from_u64(2) + } +} +use rand_core::RngCore; +use std::io::{Read, Write}; +use std::marker::PhantomData; +use std::ops::{Add, AddAssign, Mul, Neg, Sub, SubAssign}; + +/// Parameters for an `Fp2` quadratic extension over base field `F`. +pub trait Fp2Config { + /// Whether the non-residue is -1. + /// + /// When `true`, multiplication by the non-residue is a free negation and + /// the Karatsuba/squaring routines can avoid a base-field multiply. + const IS_NEG_ONE: bool = false; + + /// Non-residue `NR` such that `u^2 = NR`. + fn non_residue() -> F; +} + +/// Quadratic extension element `c0 + c1 * u` with `u^2 = NR`. +pub struct Fp2> { + /// Constant term. + pub c0: F, + /// Coefficient of `u`. + pub c1: F, + _cfg: PhantomData C>, +} + +impl> Fp2 { + /// Construct `c0 + c1 * u`. + #[inline] + pub fn new(c0: F, c1: F) -> Self { + Self { + c0, + c1, + _cfg: PhantomData, + } + } + + /// Multiply a base-field element by the non-residue. + /// + /// When `IS_NEG_ONE` is true this is just a negation (no multiply). + #[inline(always)] + fn mul_nr(x: F) -> F { + if C::IS_NEG_ONE { + -x + } else { + C::non_residue() * x + } + } + + /// Return the conjugate `c0 - c1 * u`. + #[inline] + pub fn conjugate(self) -> Self { + Self::new(self.c0, -self.c1) + } + + /// Return the norm in the base field: `c0^2 - NR * c1^2`. 
+ #[inline] + pub fn norm(self) -> F { + (self.c0 * self.c0) - Self::mul_nr(self.c1 * self.c1) + } +} + +impl> std::fmt::Debug for Fp2 { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("Fp2") + .field("c0", &self.c0) + .field("c1", &self.c1) + .finish() + } +} + +impl> Clone for Fp2 { + fn clone(&self) -> Self { + *self + } +} + +impl> Copy for Fp2 {} + +impl> Default for Fp2 { + fn default() -> Self { + Self::new(F::zero(), F::zero()) + } +} + +impl> PartialEq for Fp2 { + fn eq(&self, other: &Self) -> bool { + self.c0 == other.c0 && self.c1 == other.c1 + } +} + +impl> Eq for Fp2 {} + +impl> Add for Fp2 { + type Output = Self; + fn add(self, rhs: Self) -> Self::Output { + Self::new(self.c0 + rhs.c0, self.c1 + rhs.c1) + } +} +impl> Sub for Fp2 { + type Output = Self; + fn sub(self, rhs: Self) -> Self::Output { + Self::new(self.c0 - rhs.c0, self.c1 - rhs.c1) + } +} +impl> Neg for Fp2 { + type Output = Self; + fn neg(self) -> Self::Output { + Self::new(-self.c0, -self.c1) + } +} +impl> AddAssign for Fp2 { + #[inline] + fn add_assign(&mut self, rhs: Self) { + *self = *self + rhs; + } +} +impl> SubAssign for Fp2 { + #[inline] + fn sub_assign(&mut self, rhs: Self) { + *self = *self - rhs; + } +} +impl> Mul for Fp2 { + type Output = Self; + fn mul(self, rhs: Self) -> Self::Output { + let v0 = self.c0 * rhs.c0; + let v1 = self.c1 * rhs.c1; + Self::new( + v0 + Self::mul_nr(v1), + (self.c0 + self.c1) * (rhs.c0 + rhs.c1) - v0 - v1, + ) + } +} + +impl<'a, F: FieldCore, C: Fp2Config> Add<&'a Self> for Fp2 { + type Output = Self; + fn add(self, rhs: &'a Self) -> Self::Output { + self + *rhs + } +} +impl<'a, F: FieldCore, C: Fp2Config> Sub<&'a Self> for Fp2 { + type Output = Self; + fn sub(self, rhs: &'a Self) -> Self::Output { + self - *rhs + } +} +impl<'a, F: FieldCore, C: Fp2Config> Mul<&'a Self> for Fp2 { + type Output = Self; + fn mul(self, rhs: &'a Self) -> Self::Output { + self * *rhs + } +} + +impl> Valid for Fp2 { + fn check(&self) 
-> Result<(), SerializationError> { + self.c0.check()?; + self.c1.check()?; + Ok(()) + } +} + +impl> HachiSerialize for Fp2 { + fn serialize_with_mode( + &self, + mut writer: W, + compress: Compress, + ) -> Result<(), SerializationError> { + self.c0.serialize_with_mode(&mut writer, compress)?; + self.c1.serialize_with_mode(&mut writer, compress)?; + Ok(()) + } + + fn serialized_size(&self, compress: Compress) -> usize { + self.c0.serialized_size(compress) + self.c1.serialized_size(compress) + } +} + +impl> HachiDeserialize for Fp2 { + fn deserialize_with_mode( + mut reader: R, + compress: Compress, + validate: Validate, + ) -> Result { + let c0 = F::deserialize_with_mode(&mut reader, compress, validate)?; + let c1 = F::deserialize_with_mode(&mut reader, compress, validate)?; + let out = Self::new(c0, c1); + if matches!(validate, Validate::Yes) { + out.check()?; + } + Ok(out) + } +} + +impl> AdditiveGroup for Fp2 { + const ZERO: Self = Self { + c0: F::ZERO, + c1: F::ZERO, + _cfg: PhantomData, + }; +} + +impl> FieldCore for Fp2 { + fn one() -> Self { + Self::new(F::one(), F::zero()) + } + + fn is_zero(&self) -> bool { + self.c0.is_zero() && self.c1.is_zero() + } + + /// Specialized squaring: 2 base-field multiplications instead of 3. 
+ /// + /// `(c0 + c1·u)^2 = (c0^2 + NR·c1^2) + (2·c0·c1)·u` + fn square(&self) -> Self { + let v0 = self.c0 * self.c0; + let v1 = self.c1 * self.c1; + Self::new(v0 + Self::mul_nr(v1), (self.c0 + self.c0) * self.c1) + } + + fn inv(self) -> Option { + if self.is_zero() { + return None; + } + let inv_n = self.norm().inv()?; + Some(Self::new(self.c0 * inv_n, (-self.c1) * inv_n)) + } + + const TWO_INV: Self = Self { + c0: F::TWO_INV, + c1: F::ZERO, + _cfg: PhantomData, + }; +} + +impl> FieldSampling for Fp2 { + fn sample(rng: &mut R) -> Self { + Self::new(F::sample(rng), F::sample(rng)) + } +} + +impl> FromSmallInt for Fp2 { + fn from_u64(val: u64) -> Self { + Self::new(F::from_u64(val), F::zero()) + } + + fn from_i64(val: i64) -> Self { + Self::new(F::from_i64(val), F::zero()) + } +} + +impl> HasUnreducedOps for Fp2 { + type MulU64Accum = AccumPair; + type ProductAccum = AccumPair; + + #[inline] + fn mul_u64_unreduced(self, small: u64) -> AccumPair { + AccumPair( + self.c0.mul_u64_unreduced(small), + self.c1.mul_u64_unreduced(small), + ) + } + + #[inline] + fn mul_to_product_accum(self, other: Self) -> AccumPair { + // Karatsuba: (c0 + c1·u)(d0 + d1·u) = (c0·d0 + NR·c1·d1) + (c0·d1 + c1·d0)·u + let v0 = self.c0.mul_to_product_accum(other.c0); + let v1 = self.c1.mul_to_product_accum(other.c1); + let cross = (self.c0 + self.c1).mul_to_product_accum(other.c0 + other.c1); + + let nr_v1 = if C::IS_NEG_ONE { -v1 } else { v1 + v1 }; + AccumPair(v0 + nr_v1, cross - v0 - v1) + } + + #[inline] + fn reduce_mul_u64_accum(accum: AccumPair) -> Self { + Self::new( + F::reduce_mul_u64_accum(accum.0), + F::reduce_mul_u64_accum(accum.1), + ) + } + + #[inline] + fn reduce_product_accum(accum: AccumPair) -> Self { + Self::new( + F::reduce_product_accum(accum.0), + F::reduce_product_accum(accum.1), + ) + } +} + +/// Parameters for an `Fp4` quadratic extension over `Fp2`. +pub trait Fp4Config> { + /// Non-residue `NR2` in `Fp2` such that `v^2 = NR2`. 
+ fn non_residue() -> Fp2; +} + +/// `Fp4Config` with non-residue `u ∈ Fp2` (the element `(0, 1)`). +/// +/// This is the standard tower choice: `Fp4 = Fp2[v] / (v^2 - u)`. +pub struct UnitNr; + +impl> Fp4Config for UnitNr { + fn non_residue() -> Fp2 { + Fp2::new(F::zero(), F::one()) + } +} + +/// Quartic extension element `c0 + c1 * v` over `Fp2`, where `v^2 = NR2`. +pub struct Fp4, C4: Fp4Config> { + /// Constant term. + pub c0: Fp2, + /// Coefficient of `v`. + pub c1: Fp2, + _cfg: PhantomData C4>, +} + +impl, C4: Fp4Config> Fp4 { + /// Construct `c0 + c1 * v`. + #[inline] + pub fn new(c0: Fp2, c1: Fp2) -> Self { + Self { + c0, + c1, + _cfg: PhantomData, + } + } + + /// Return the norm in `Fp2`: `c0^2 - NR2 * c1^2`. + #[inline] + pub fn norm(self) -> Fp2 { + let nr2 = C4::non_residue(); + (self.c0 * self.c0) - (nr2 * (self.c1 * self.c1)) + } +} + +impl, C4: Fp4Config> std::fmt::Debug + for Fp4 +{ + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("Fp4") + .field("c0", &self.c0) + .field("c1", &self.c1) + .finish() + } +} + +impl, C4: Fp4Config> Clone for Fp4 { + fn clone(&self) -> Self { + *self + } +} + +impl, C4: Fp4Config> Copy for Fp4 {} + +impl, C4: Fp4Config> Default for Fp4 { + fn default() -> Self { + Self::new( + Fp2::new(F::zero(), F::zero()), + Fp2::new(F::zero(), F::zero()), + ) + } +} + +impl, C4: Fp4Config> PartialEq for Fp4 { + fn eq(&self, other: &Self) -> bool { + self.c0 == other.c0 && self.c1 == other.c1 + } +} + +impl, C4: Fp4Config> Eq for Fp4 {} + +impl, C4: Fp4Config> Add for Fp4 { + type Output = Self; + fn add(self, rhs: Self) -> Self::Output { + Self::new(self.c0 + rhs.c0, self.c1 + rhs.c1) + } +} +impl, C4: Fp4Config> Sub for Fp4 { + type Output = Self; + fn sub(self, rhs: Self) -> Self::Output { + Self::new(self.c0 - rhs.c0, self.c1 - rhs.c1) + } +} +impl, C4: Fp4Config> Neg for Fp4 { + type Output = Self; + fn neg(self) -> Self::Output { + Self::new(-self.c0, -self.c1) + } +} +impl, C4: Fp4Config> 
AddAssign for Fp4 { + #[inline] + fn add_assign(&mut self, rhs: Self) { + *self = *self + rhs; + } +} +impl, C4: Fp4Config> SubAssign for Fp4 { + #[inline] + fn sub_assign(&mut self, rhs: Self) { + *self = *self - rhs; + } +} +impl, C4: Fp4Config> Mul for Fp4 { + type Output = Self; + fn mul(self, rhs: Self) -> Self::Output { + let nr2 = C4::non_residue(); + let v0 = self.c0 * rhs.c0; + let v1 = self.c1 * rhs.c1; + Self::new( + v0 + (nr2 * v1), + (self.c0 + self.c1) * (rhs.c0 + rhs.c1) - v0 - v1, + ) + } +} + +impl<'a, F: FieldCore, C2: Fp2Config, C4: Fp4Config> Add<&'a Self> for Fp4 { + type Output = Self; + fn add(self, rhs: &'a Self) -> Self::Output { + self + *rhs + } +} +impl<'a, F: FieldCore, C2: Fp2Config, C4: Fp4Config> Sub<&'a Self> for Fp4 { + type Output = Self; + fn sub(self, rhs: &'a Self) -> Self::Output { + self - *rhs + } +} +impl<'a, F: FieldCore, C2: Fp2Config, C4: Fp4Config> Mul<&'a Self> for Fp4 { + type Output = Self; + fn mul(self, rhs: &'a Self) -> Self::Output { + self * *rhs + } +} + +impl, C4: Fp4Config> Valid for Fp4 { + fn check(&self) -> Result<(), SerializationError> { + self.c0.check()?; + self.c1.check()?; + Ok(()) + } +} + +impl, C4: Fp4Config> HachiSerialize for Fp4 { + fn serialize_with_mode( + &self, + mut writer: W, + compress: Compress, + ) -> Result<(), SerializationError> { + self.c0.serialize_with_mode(&mut writer, compress)?; + self.c1.serialize_with_mode(&mut writer, compress)?; + Ok(()) + } + + fn serialized_size(&self, compress: Compress) -> usize { + self.c0.serialized_size(compress) + self.c1.serialized_size(compress) + } +} + +impl, C4: Fp4Config> HachiDeserialize + for Fp4 +{ + fn deserialize_with_mode( + mut reader: R, + compress: Compress, + validate: Validate, + ) -> Result { + let c0 = Fp2::::deserialize_with_mode(&mut reader, compress, validate)?; + let c1 = Fp2::::deserialize_with_mode(&mut reader, compress, validate)?; + let out = Self::new(c0, c1); + if matches!(validate, Validate::Yes) { + out.check()?; + } 
+ Ok(out) + } +} + +impl, C4: Fp4Config> AdditiveGroup for Fp4 { + const ZERO: Self = Self { + c0: Fp2::ZERO, + c1: Fp2::ZERO, + _cfg: PhantomData, + }; +} + +impl, C4: Fp4Config> FieldCore for Fp4 { + fn one() -> Self { + Self::new(Fp2::one(), Fp2::zero()) + } + + fn is_zero(&self) -> bool { + self.c0.is_zero() && self.c1.is_zero() + } + + fn square(&self) -> Self { + let nr2 = C4::non_residue(); + let v0 = self.c0.square(); + let v1 = self.c1.square(); + Self::new(v0 + nr2 * v1, (self.c0 + self.c0) * self.c1) + } + + fn inv(self) -> Option { + if self.is_zero() { + return None; + } + let inv_n = self.norm().inv()?; + Some(Self::new(self.c0 * inv_n, (-self.c1) * inv_n)) + } + + const TWO_INV: Self = Self { + c0: Fp2::TWO_INV, + c1: Fp2::ZERO, + _cfg: PhantomData, + }; +} + +impl, C4: Fp4Config> FieldSampling + for Fp4 +{ + fn sample(rng: &mut R) -> Self { + Self::new(Fp2::sample(rng), Fp2::sample(rng)) + } +} + +impl, C4: Fp4Config> FromSmallInt + for Fp4 +{ + fn from_u64(val: u64) -> Self { + Self::new(Fp2::from_u64(val), Fp2::zero()) + } + + fn from_i64(val: i64) -> Self { + Self::new(Fp2::from_i64(val), Fp2::zero()) + } +} + +// Scalar * VectorModule impls for extension scalars. 
+ +impl Mul, N>> for Fp2 +where + F: FieldCore + Valid, + C: Fp2Config, +{ + type Output = VectorModule, N>; + fn mul(self, rhs: VectorModule, N>) -> Self::Output { + let mut out = rhs.0; + for coeff in &mut out { + *coeff = self * *coeff; + } + VectorModule(out) + } +} + +impl<'a, F, C, const N: usize> Mul<&'a VectorModule, N>> for Fp2 +where + F: FieldCore + Valid, + C: Fp2Config, +{ + type Output = VectorModule, N>; + fn mul(self, rhs: &'a VectorModule, N>) -> Self::Output { + self * *rhs + } +} + +impl Mul, N>> for Fp4 +where + F: FieldCore + Valid, + C2: Fp2Config, + C4: Fp4Config, +{ + type Output = VectorModule, N>; + fn mul(self, rhs: VectorModule, N>) -> Self::Output { + let mut out = rhs.0; + for coeff in &mut out { + *coeff = self * *coeff; + } + VectorModule(out) + } +} + +impl<'a, F, C2, C4, const N: usize> Mul<&'a VectorModule, N>> for Fp4 +where + F: FieldCore + Valid, + C2: Fp2Config, + C4: Fp4Config, +{ + type Output = VectorModule, N>; + fn mul(self, rhs: &'a VectorModule, N>) -> Self::Output { + self * *rhs + } +} + +// Convenience aliases for extension fields with NR = 2 (valid for all Hachi +// pseudo-Mersenne primes where p ≡ 5 mod 8). + +/// Quadratic extension over any `F` with non-residue 2. +pub type Ext2 = Fp2; + +/// Quartic extension as tower `Ext2[v]/(v^2 - u)`. 
+pub type Ext4 = Fp4; + +#[cfg(test)] +mod tests { + use super::*; + use crate::algebra::fields::lift::ExtField; + use crate::algebra::Fp64; + use crate::{FieldCore, FieldSampling, FromSmallInt}; + use rand::rngs::StdRng; + use rand::SeedableRng; + + type F = Fp64<4294967197>; + type E2 = Ext2; + type E4 = Ext4; + + #[test] + fn fp2_add_sub_identity() { + let a = E2::new(F::from_u64(3), F::from_u64(5)); + let b = E2::new(F::from_u64(7), F::from_u64(11)); + let c = a + b; + assert_eq!(c - b, a); + assert_eq!(c - a, b); + } + + #[test] + fn fp2_mul_one() { + let a = E2::new(F::from_u64(42), F::from_u64(13)); + assert_eq!(a * E2::one(), a); + assert_eq!(E2::one() * a, a); + } + + #[test] + fn fp2_mul_commutativity() { + let mut rng = StdRng::seed_from_u64(1234); + let a = E2::sample(&mut rng); + let b = E2::sample(&mut rng); + assert_eq!(a * b, b * a); + } + + #[test] + fn fp2_karatsuba_matches_schoolbook() { + let mut rng = StdRng::seed_from_u64(5678); + for _ in 0..100 { + let a = E2::sample(&mut rng); + let b = E2::sample(&mut rng); + let nr = >::non_residue(); + let expected = E2::new( + (a.c0 * b.c0) + (nr * (a.c1 * b.c1)), + (a.c0 * b.c1) + (a.c1 * b.c0), + ); + assert_eq!(a * b, expected); + } + } + + #[test] + fn fp2_square_matches_mul() { + let mut rng = StdRng::seed_from_u64(9012); + for _ in 0..100 { + let a = E2::sample(&mut rng); + assert_eq!(a.square(), a * a, "square mismatch for {a:?}"); + } + } + + #[test] + fn fp2_inv() { + let mut rng = StdRng::seed_from_u64(3456); + for _ in 0..50 { + let a = E2::sample(&mut rng); + if !a.is_zero() { + let inv = a.inv().unwrap(); + assert_eq!(a * inv, E2::one()); + } + } + } + + #[test] + fn fp4_mul_commutativity() { + let mut rng = StdRng::seed_from_u64(7890); + let a = E4::sample(&mut rng); + let b = E4::sample(&mut rng); + assert_eq!(a * b, b * a); + } + + #[test] + fn fp4_square_matches_mul() { + let mut rng = StdRng::seed_from_u64(1111); + for _ in 0..50 { + let a = E4::sample(&mut rng); + 
assert_eq!(a.square(), a * a); + } + } + + #[test] + fn fp4_inv() { + let mut rng = StdRng::seed_from_u64(2222); + for _ in 0..50 { + let a = E4::sample(&mut rng); + if !a.is_zero() { + let inv = a.inv().unwrap(); + assert_eq!(a * inv, E4::one()); + } + } + } + + #[test] + fn from_small_int_fp2() { + let a = E2::from_u64(42); + assert_eq!(a, E2::new(F::from_u64(42), F::zero())); + + let b = E2::from_i64(-3); + assert_eq!(b, E2::new(F::from_i64(-3), F::zero())); + + let c = E2::from_u8(7); + assert_eq!(c, E2::from_u64(7)); + + let d = E2::from_u32(100_000); + assert_eq!(d, E2::from_u64(100_000)); + } + + #[test] + fn from_small_int_fp4() { + let a = E4::from_u64(42); + assert_eq!(a, E4::new(E2::from_u64(42), E2::zero())); + + let b = E4::from_i64(-7); + assert_eq!(b, E4::new(E2::from_i64(-7), E2::zero())); + } + + #[test] + fn ext_field_degree() { + assert_eq!(>::EXT_DEGREE, 1); + assert_eq!(>::EXT_DEGREE, 2); + assert_eq!(>::EXT_DEGREE, 4); + } + + #[test] + fn ext_field_from_base_slice() { + let c0 = F::from_u64(3); + let c1 = F::from_u64(5); + let e2 = E2::from_base_slice(&[c0, c1]); + assert_eq!(e2, E2::new(c0, c1)); + + let c2 = F::from_u64(7); + let c3 = F::from_u64(11); + let e4 = E4::from_base_slice(&[c0, c1, c2, c3]); + assert_eq!(e4, E4::new(E2::new(c0, c1), E2::new(c2, c3))); + } + + #[test] + fn eq_impl() { + let a = E2::new(F::from_u64(1), F::from_u64(2)); + let b = E2::new(F::from_u64(1), F::from_u64(2)); + let c = E2::new(F::from_u64(1), F::from_u64(3)); + assert_eq!(a, b); + assert_ne!(a, c); + } +} diff --git a/src/algebra/fields/fp128.rs b/src/algebra/fields/fp128.rs new file mode 100644 index 00000000..5b5a8ca5 --- /dev/null +++ b/src/algebra/fields/fp128.rs @@ -0,0 +1,1067 @@ +//! 128-bit prime field for primes of the form `p = 2^128 − c` with `c < 2^64`. +//! +//! Uses Solinas-style two-fold reduction: no Montgomery form, ~23 cycles/mul +//! on both AArch64 and x86-64. The offset `c` is computed at compile time +//! 
from the const-generic modulus `P`. +//! +//! ## Naming convention for built-in primes +//! +//! The built-in type names encode the **signed terms as they appear in the +//! modulus `p`** (excluding the leading `+2^128` term). For example, +//! `Prime128M13M4P0` denotes `p = 2^128 − 2^13 − 2^4 + 2^0`. + +use std::io::{Read, Write}; +use std::ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign}; + +use rand_core::RngCore; + +use crate::primitives::serialization::{ + Compress, HachiDeserialize, HachiSerialize, SerializationError, Valid, Validate, +}; +use crate::{ + AdditiveGroup, CanonicalField, FieldCore, FieldSampling, FromSmallInt, Invertible, + PseudoMersenneField, +}; + +/// Pack two u64 limbs into `[lo, hi]`. +#[inline(always)] +const fn pack(lo: u64, hi: u64) -> [u64; 2] { + [lo, hi] +} + +/// Convert `u128` → `[u64; 2]`. +#[inline(always)] +const fn from_u128(x: u128) -> [u64; 2] { + [x as u64, (x >> 64) as u64] +} + +/// Convert `[u64; 2]` → `u128`. +#[inline(always)] +const fn to_u128(x: [u64; 2]) -> u128 { + x[0] as u128 | (x[1] as u128) << 64 +} + +use super::util::{is_pow2_u64, log2_pow2_u64, mul64_wide}; + +/// 128-bit prime field element for primes `p = 2^128 − c` with `c < 2^64`. +/// +/// Stored as `[u64; 2]` (lo, hi) for 8-byte alignment and direct limb access. +/// +/// The offset `c = 2^128 − p` and all derived constants are computed at +/// compile time from the const-generic `P`. Instantiating `Fp128` with a +/// modulus that is not of this form is a compile-time error. +#[derive(Debug, Clone, Copy, Default)] +pub struct Fp128(pub(crate) [u64; 2]); + +impl PartialEq for Fp128

{ + fn eq(&self, other: &Self) -> bool { + self.0 == other.0 + } +} + +impl Eq for Fp128

{} + +impl Fp128

{ + /// Offset `c = 2^128 − p`. Validated at compile time. + pub const C: u128 = { + let c = 0u128.wrapping_sub(P); + assert!(P != 0, "modulus must be nonzero"); + assert!(P & 1 == 1, "modulus must be odd"); + assert!(c < (1u128 << 64), "P must be 2^128 - c with c < 2^64"); + // Fused overflow+canonicalize requires C(C+1) < P. + assert!( + c * (c + 1) < P, + "C(C+1) < P required for fused canonicalize" + ); + c + }; + /// Low 64 bits of `C` (always equals `C` since `C < 2^64`). + pub const C_LO: u64 = Self::C as u64; + /// +1 means `C = 2^a + 1`, -1 means `C = 2^a - 1`, 0 means generic. + const C_SHIFT_KIND: i8 = { + let c = Self::C_LO; + if c > 1 && is_pow2_u64(c - 1) { + 1 + } else if c == u64::MAX || is_pow2_u64(c + 1) { + -1 + } else { + 0 + } + }; + const C_SHIFT: u32 = { + let c = Self::C_LO; + if Self::C_SHIFT_KIND == 1 { + log2_pow2_u64(c - 1) + } else if Self::C_SHIFT_KIND == -1 { + if c == u64::MAX { + 64 + } else { + log2_pow2_u64(c + 1) + } + } else { + 0 + } + }; + + /// Multiply by `C = 2^128 - P`. For `C = 2^a ± 1`, this is shift/add or + /// shift/sub only; otherwise it falls back to generic widening multiply. + #[inline(always)] + fn mul_c_wide(x: u64) -> (u64, u64) { + if Self::C_SHIFT_KIND == 1 { + let v = ((x as u128) << Self::C_SHIFT) + x as u128; + (v as u64, (v >> 64) as u64) + } else if Self::C_SHIFT_KIND == -1 { + let v = ((x as u128) << Self::C_SHIFT) - x as u128; + (v as u64, (v >> 64) as u64) + } else { + mul64_wide(Self::C_LO, x) + } + } + + /// Create from a canonical representative in `[0, p)`. + #[inline] + pub fn from_canonical_u128(x: u128) -> Self { + debug_assert!(x < P); + Self(from_u128(x)) + } + + /// Return the canonical representative in `[0, p)`. + #[inline] + pub fn to_canonical_u128(self) -> u128 { + to_u128(self.0) + } + + /// Const-evaluable `from_i64`. Embeds a small signed integer into `Fp`. 
+ pub const fn from_i64_const(val: i64) -> Self { + if val >= 0 { + Self(from_u128(val as u128)) + } else { + Self(Self::sub_raw( + pack(0, 0), + from_u128(val.unsigned_abs() as u128), + )) + } + } + + /// Const-evaluable lookup table for balanced digits in `[-b/2, b/2)` + /// where `b = 2^log_basis`. Requires `log_basis <= 4`. + pub const fn digit_lut(log_basis: u32) -> [Self; 16] { + assert!(log_basis > 0 && log_basis <= 4); + let b = 1u32 << log_basis; + let half_b = (b / 2) as i64; + let mut lut = [Self(pack(0, 0)); 16]; + let mut i = 0u32; + while i < b { + lut[i as usize] = Self::from_i64_const(i as i64 - half_b); + i += 1; + } + lut + } + + #[inline(always)] + fn add_raw(a: [u64; 2], b: [u64; 2]) -> [u64; 2] { + let (s, carry) = to_u128(a).overflowing_add(to_u128(b)); + let (reduced, borrow) = s.overflowing_sub(P); + from_u128(if carry | !borrow { + reduced + } else { + reduced.wrapping_add(P) + }) + } + + #[inline(always)] + const fn sub_raw(a: [u64; 2], b: [u64; 2]) -> [u64; 2] { + let (diff, borrow) = to_u128(a).overflowing_sub(to_u128(b)); + from_u128(if borrow { diff.wrapping_add(P) } else { diff }) + } + + /// Fold 2 + canonicalize: reduce `[t0, t1] + t2·2^128` into `[0, p)`. + /// + /// Correctness argument for the fused overflow+canonicalize: + /// + /// Let `v = base + C·t2` (mathematical, not mod 2^128). + /// From the fold-1 mac chain, `t2 ≤ C`, so `C·t2 ≤ C²`. + /// + /// - **No overflow** (`v < 2^128`): `s = v`, and the standard + /// canonicalize applies — `s + C` carries iff `s ≥ P`. + /// - **Overflow** (`v ≥ 2^128`): `s = v − 2^128`, so `s < C·t2 ≤ C²`. + /// The correct reduced value is `s + C` (since `2^128 ≡ C mod P`). + /// Because `s + C < C² + C = C(C+1)` and `C(C+1) < P` for all + /// `C < 2^64`, the value `s + C` is already in `[0, P)` — no + /// further canonicalization is needed, and `s + C < 2^128` so the + /// add does NOT carry. 
+ /// + /// Therefore `if (overflow | carry) { s + C } else { s }` is correct + /// in both cases, fusing the overflow correction with canonicalization. + #[inline(always)] + fn fold2_canonicalize(t0: u64, t1: u64, t2: u64) -> [u64; 2] { + let (ct2_lo, ct2_hi) = Self::mul_c_wide(t2); + + let (s0, carry0) = t0.overflowing_add(ct2_lo); + let (s1a, carry1a) = t1.overflowing_add(ct2_hi); + let (s1, carry1b) = s1a.overflowing_add(carry0 as u64); + let overflow = carry1a | carry1b; + + let (r0, carry2) = s0.overflowing_add(Self::C_LO); + let (r1, carry3) = s1.overflowing_add(carry2 as u64); + + pack( + if overflow | carry3 { r0 } else { s0 }, + if overflow | carry3 { r1 } else { s1 }, + ) + } + + /// Solinas fold for exactly 4 limbs: `[r0,r1] + C·[r2,r3]` → 3 limbs, + /// then `fold2_canonicalize`. + #[inline(always)] + fn reduce_4(r0: u64, r1: u64, r2: u64, r3: u64) -> [u64; 2] { + let (cr2_lo, cr2_hi) = Self::mul_c_wide(r2); + let (cr3_lo, cr3_hi) = Self::mul_c_wide(r3); + + let t0_sum = r0 as u128 + cr2_lo as u128; + let t0 = t0_sum as u64; + let carryf = (t0_sum >> 64) as u64; + + let t1_sum = r1 as u128 + cr2_hi as u128 + cr3_lo as u128 + carryf as u128; + let t1 = t1_sum as u64; + + let t2_sum = cr3_hi as u128 + (t1_sum >> 64); + let t2 = t2_sum as u64; + debug_assert_eq!(t2_sum >> 64, 0); + + Self::fold2_canonicalize(t0, t1, t2) + } + + #[inline(always)] + fn mul_raw(a: [u64; 2], b: [u64; 2]) -> [u64; 2] { + let [r0, r1, r2, r3] = Self(a).mul_wide(Self(b)); + Self::reduce_4(r0, r1, r2, r3) + } + + #[inline(always)] + fn sqr_wide(self) -> [u64; 4] { + let (a0, a1) = (self.0[0], self.0[1]); + let (p00_lo, p00_hi) = mul64_wide(a0, a0); + let (p01_lo, p01_hi) = mul64_wide(a0, a1); + let (p11_lo, p11_hi) = mul64_wide(a1, a1); + + let row1 = p00_hi as u128 + (p01_lo as u128) * 2; + let r0 = p00_lo; + let r1 = row1 as u64; + let carry1 = (row1 >> 64) as u64; + + let row2 = (p01_hi as u128) * 2 + p11_lo as u128 + carry1 as u128; + let r2 = row2 as u64; + let carry2 = 
(row2 >> 64) as u64; + + let row3 = p11_hi as u128 + carry2 as u128; + let r3 = row3 as u64; + debug_assert_eq!(row3 >> 64, 0); + + [r0, r1, r2, r3] + } + + #[inline(always)] + fn sqr_raw(a: [u64; 2]) -> [u64; 2] { + let [r0, r1, r2, r3] = Self(a).sqr_wide(); + Self::reduce_4(r0, r1, r2, r3) + } + + /// Squaring, equivalent to `self * self`. + #[inline(always)] + pub fn square(self) -> Self { + Self(Self::sqr_raw(self.0)) + } + + fn pow_u128(self, mut exp: u128) -> Self { + let mut base = self; + let mut acc = Self::one(); + while exp > 0 { + if (exp & 1) == 1 { + acc *= base; + } + base = Self(Self::sqr_raw(base.0)); + exp >>= 1; + } + acc + } + + /// Extract the canonical `[lo, hi]` limb representation. + #[inline(always)] + pub fn to_limbs(self) -> [u64; 2] { + self.0 + } + + /// 128×64 → 192-bit widening multiply, **no reduction**. + /// + /// Returns `[lo, mid, hi]` representing `self · other` as a 192-bit + /// integer. Cost: 2 widening `mul64`. + #[inline(always)] + pub fn mul_wide_u64(self, other: u64) -> [u64; 3] { + let (a0, a1) = (self.0[0], self.0[1]); + let (p0_lo, p0_hi) = mul64_wide(a0, other); + let (p1_lo, p1_hi) = mul64_wide(a1, other); + let mid = p0_hi as u128 + p1_lo as u128; + let hi = p1_hi + (mid >> 64) as u64; + [p0_lo, mid as u64, hi] + } + + /// 128×128 → 256-bit widening multiply, **no reduction**. + /// + /// Returns `[r0, r1, r2, r3]` representing `self · other` as a 256-bit + /// integer. This is the schoolbook 2×2 portion of the Solinas multiply, + /// without the reduction fold. Cost: 4 widening `mul64`. 
+ #[inline(always)] + pub fn mul_wide(self, other: Self) -> [u64; 4] { + let (a0, a1) = (self.0[0], self.0[1]); + let (b0, b1) = (other.0[0], other.0[1]); + let (p00_lo, p00_hi) = mul64_wide(a0, b0); + let (p01_lo, p01_hi) = mul64_wide(a0, b1); + let (p10_lo, p10_hi) = mul64_wide(a1, b0); + let (p11_lo, p11_hi) = mul64_wide(a1, b1); + + let row1 = p00_hi as u128 + p01_lo as u128 + p10_lo as u128; + let r0 = p00_lo; + let r1 = row1 as u64; + let carry1 = (row1 >> 64) as u64; + + let row2 = p01_hi as u128 + p10_hi as u128 + p11_lo as u128 + carry1 as u128; + let r2 = row2 as u64; + let carry2 = (row2 >> 64) as u64; + + let row3 = p11_hi as u128 + carry2 as u128; + let r3 = row3 as u64; + debug_assert_eq!(row3 >> 64, 0); + + [r0, r1, r2, r3] + } + + /// 128×128 → 256-bit widening multiply with a raw `u128` operand, + /// **no reduction**. + #[inline(always)] + pub fn mul_wide_u128(self, other: u128) -> [u64; 4] { + self.mul_wide(Self(from_u128(other))) + } + + /// 128×(64*M) → (64*OUT) widening multiply, **no reduction**. + /// + /// Multiplies a canonical Fp128 value (`[u64; 2]`) by an arbitrary + /// little-endian limb array and returns the little-endian product + /// truncated/extended to `OUT` limbs. + #[inline(always)] + pub fn mul_wide_limbs(self, other: [u64; M]) -> [u64; OUT] { + let (a0, a1) = (self.0[0], self.0[1]); + + // Hot-path specializations used by Jolt (M in {3,4}, OUT in {4,5}). + // These avoid loop/control-flow overhead in tight sumcheck FMAs. 
+ if M == 3 && OUT == 5 { + let b0 = other[0]; + let b1 = other[1]; + let b2 = other[2]; + + let (p00_lo, p00_hi) = mul64_wide(a0, b0); + let (p01_lo, p01_hi) = mul64_wide(a0, b1); + let (p02_lo, p02_hi) = mul64_wide(a0, b2); + let (p10_lo, p10_hi) = mul64_wide(a1, b0); + let (p11_lo, p11_hi) = mul64_wide(a1, b1); + let (p12_lo, p12_hi) = mul64_wide(a1, b2); + + let r0 = p00_lo; + + let row1 = p00_hi as u128 + p01_lo as u128 + p10_lo as u128; + let r1 = row1 as u64; + let carry1 = row1 >> 64; + + let row2 = p01_hi as u128 + p02_lo as u128 + p10_hi as u128 + p11_lo as u128 + carry1; + let r2 = row2 as u64; + let carry2 = row2 >> 64; + + let row3 = p02_hi as u128 + p11_hi as u128 + p12_lo as u128 + carry2; + let r3 = row3 as u64; + let carry3 = row3 >> 64; + + let row4 = p12_hi as u128 + carry3; + let r4 = row4 as u64; + debug_assert_eq!(row4 >> 64, 0); + + let mut out = [0u64; OUT]; + out[0] = r0; + out[1] = r1; + out[2] = r2; + out[3] = r3; + out[4] = r4; + return out; + } + if M == 3 && OUT == 4 { + let b0 = other[0]; + let b1 = other[1]; + let b2 = other[2]; + + let (p00_lo, p00_hi) = mul64_wide(a0, b0); + let (p01_lo, p01_hi) = mul64_wide(a0, b1); + let (p02_lo, p02_hi) = mul64_wide(a0, b2); + let (p10_lo, p10_hi) = mul64_wide(a1, b0); + let (p11_lo, p11_hi) = mul64_wide(a1, b1); + let p12_lo = a1.wrapping_mul(b2); + + let r0 = p00_lo; + + let row1 = p00_hi as u128 + p01_lo as u128 + p10_lo as u128; + let r1 = row1 as u64; + let carry1 = row1 >> 64; + + let row2 = p01_hi as u128 + p02_lo as u128 + p10_hi as u128 + p11_lo as u128 + carry1; + let r2 = row2 as u64; + let carry2 = row2 >> 64; + + let row3 = p02_hi as u128 + p11_hi as u128 + p12_lo as u128 + carry2; + let r3 = row3 as u64; + + let mut out = [0u64; OUT]; + out[0] = r0; + out[1] = r1; + out[2] = r2; + out[3] = r3; + return out; + } + if M == 4 && OUT == 6 { + let b0 = other[0]; + let b1 = other[1]; + let b2 = other[2]; + let b3 = other[3]; + + let (p00_lo, p00_hi) = mul64_wide(a0, b0); + let (p01_lo, 
p01_hi) = mul64_wide(a0, b1); + let (p02_lo, p02_hi) = mul64_wide(a0, b2); + let (p03_lo, p03_hi) = mul64_wide(a0, b3); + let (p10_lo, p10_hi) = mul64_wide(a1, b0); + let (p11_lo, p11_hi) = mul64_wide(a1, b1); + let (p12_lo, p12_hi) = mul64_wide(a1, b2); + let (p13_lo, p13_hi) = mul64_wide(a1, b3); + + let r0 = p00_lo; + + let row1 = p00_hi as u128 + p01_lo as u128 + p10_lo as u128; + let r1 = row1 as u64; + let carry1 = row1 >> 64; + + let row2 = p01_hi as u128 + p02_lo as u128 + p10_hi as u128 + p11_lo as u128 + carry1; + let r2 = row2 as u64; + let carry2 = row2 >> 64; + + let row3 = p02_hi as u128 + p03_lo as u128 + p11_hi as u128 + p12_lo as u128 + carry2; + let r3 = row3 as u64; + let carry3 = row3 >> 64; + + let row4 = p03_hi as u128 + p12_hi as u128 + p13_lo as u128 + carry3; + let r4 = row4 as u64; + let carry4 = row4 >> 64; + + let row5 = p13_hi as u128 + carry4; + let r5 = row5 as u64; + debug_assert_eq!(row5 >> 64, 0); + + let mut out = [0u64; OUT]; + out[0] = r0; + out[1] = r1; + out[2] = r2; + out[3] = r3; + out[4] = r4; + out[5] = r5; + return out; + } + if M == 4 && OUT == 5 { + let b0 = other[0]; + let b1 = other[1]; + let b2 = other[2]; + let b3 = other[3]; + + let (p00_lo, p00_hi) = mul64_wide(a0, b0); + let (p01_lo, p01_hi) = mul64_wide(a0, b1); + let (p02_lo, p02_hi) = mul64_wide(a0, b2); + let (p03_lo, p03_hi) = mul64_wide(a0, b3); + let (p10_lo, p10_hi) = mul64_wide(a1, b0); + let (p11_lo, p11_hi) = mul64_wide(a1, b1); + let (p12_lo, p12_hi) = mul64_wide(a1, b2); + let p13_lo = a1.wrapping_mul(b3); + + let r0 = p00_lo; + + let row1 = p00_hi as u128 + p01_lo as u128 + p10_lo as u128; + let r1 = row1 as u64; + let carry1 = row1 >> 64; + + let row2 = p01_hi as u128 + p02_lo as u128 + p10_hi as u128 + p11_lo as u128 + carry1; + let r2 = row2 as u64; + let carry2 = row2 >> 64; + + let row3 = p02_hi as u128 + p03_lo as u128 + p11_hi as u128 + p12_lo as u128 + carry2; + let r3 = row3 as u64; + let carry3 = row3 >> 64; + + let row4 = p03_hi as u128 + 
p12_hi as u128 + p13_lo as u128 + carry3; + let r4 = row4 as u64; + + let mut out = [0u64; OUT]; + out[0] = r0; + out[1] = r1; + out[2] = r2; + out[3] = r3; + out[4] = r4; + return out; + } + if M == 4 && OUT == 4 { + let b0 = other[0]; + let b1 = other[1]; + let b2 = other[2]; + let b3 = other[3]; + + let (p00_lo, p00_hi) = mul64_wide(a0, b0); + let (p01_lo, p01_hi) = mul64_wide(a0, b1); + let (p02_lo, p02_hi) = mul64_wide(a0, b2); + let p03_lo = a0.wrapping_mul(b3); + let (p10_lo, p10_hi) = mul64_wide(a1, b0); + let (p11_lo, p11_hi) = mul64_wide(a1, b1); + let p12_lo = a1.wrapping_mul(b2); + + let r0 = p00_lo; + + let row1 = p00_hi as u128 + p01_lo as u128 + p10_lo as u128; + let r1 = row1 as u64; + let carry1 = row1 >> 64; + + let row2 = p01_hi as u128 + p02_lo as u128 + p10_hi as u128 + p11_lo as u128 + carry1; + let r2 = row2 as u64; + let carry2 = row2 >> 64; + + let row3 = p02_hi as u128 + p03_lo as u128 + p11_hi as u128 + p12_lo as u128 + carry2; + let r3 = row3 as u64; + + let mut out = [0u64; OUT]; + out[0] = r0; + out[1] = r1; + out[2] = r2; + out[3] = r3; + return out; + } + + let mut out = [0u64; OUT]; + + for (i, &b) in other.iter().enumerate() { + if i >= OUT { + break; + } + + let (p0_lo, p0_hi) = mul64_wide(a0, b); + let (p1_lo, p1_hi) = mul64_wide(a1, b); + + let s0 = out[i] as u128 + p0_lo as u128; + out[i] = s0 as u64; + let mut carry = s0 >> 64; + + if i + 1 >= OUT { + continue; + } + let s1 = out[i + 1] as u128 + p0_hi as u128 + p1_lo as u128 + carry; + out[i + 1] = s1 as u64; + carry = s1 >> 64; + + if i + 2 >= OUT { + continue; + } + let s2 = out[i + 2] as u128 + p1_hi as u128 + carry; + out[i + 2] = s2 as u64; + + let mut carry_hi = s2 >> 64; + let mut j = i + 3; + while carry_hi != 0 && j < OUT { + let sj = out[j] as u128 + carry_hi; + out[j] = sj as u64; + carry_hi = sj >> 64; + j += 1; + } + } + + out + } + + /// Reduce an arbitrary-width little-endian limb array to a canonical + /// field element via iterated Solinas folding. 
+ /// + /// Each fold splits at the 128-bit boundary and replaces + /// `hi · 2^128` with `hi · C`, reducing width by one limb per + /// iteration. Supports 0–10 input limbs (up to 640 bits). + /// + /// # Panics + /// + /// Panics if `limbs.len() > 10`. + #[inline(always)] + pub fn solinas_reduce(limbs: &[u64]) -> Self { + match limbs.len() { + 0 => Self::zero(), + 1 => Self(pack(limbs[0], 0)), + 2 => Self::from_canonical_u128_reduced(to_u128([limbs[0], limbs[1]])), + 3 => Self(Self::fold2_canonicalize(limbs[0], limbs[1], limbs[2])), + 4 => Self(Self::reduce_4(limbs[0], limbs[1], limbs[2], limbs[3])), + 5 => { + let (l0, l1, l2, l3, l4) = (limbs[0], limbs[1], limbs[2], limbs[3], limbs[4]); + let (c2_lo, c2_hi) = Self::mul_c_wide(l2); + let (c3_lo, c3_hi) = Self::mul_c_wide(l3); + let (c4_lo, c4_hi) = Self::mul_c_wide(l4); + + let s0 = l0 as u128 + c2_lo as u128; + let s1 = l1 as u128 + c2_hi as u128 + c3_lo as u128 + (s0 >> 64); + let s2 = c3_hi as u128 + c4_lo as u128 + (s1 >> 64); + let s3 = c4_hi as u128 + (s2 >> 64); + debug_assert_eq!(s3 >> 64, 0); + + Self(Self::reduce_4(s0 as u64, s1 as u64, s2 as u64, s3 as u64)) + } + n => { + assert!(n <= 10, "solinas_reduce supports at most 10 limbs"); + let mut buf = [0u64; 11]; + buf[..n].copy_from_slice(limbs); + let mut len = n; + let c = Self::C_LO; + + while len > 5 { + let high_len = len - 2; + let mut next = [0u64; 11]; + + let mut carry: u64 = 0; + for i in 0..high_len { + let wide = c as u128 * buf[i + 2] as u128 + carry as u128; + next[i] = wide as u64; + carry = (wide >> 64) as u64; + } + next[high_len] = carry; + + let s0 = next[0] as u128 + buf[0] as u128; + next[0] = s0 as u64; + let s1 = next[1] as u128 + buf[1] as u128 + (s0 >> 64); + next[1] = s1 as u64; + let mut c_out = (s1 >> 64) as u64; + for limb in &mut next[2..=high_len] { + if c_out == 0 { + break; + } + let s = *limb as u128 + c_out as u128; + *limb = s as u64; + c_out = (s >> 64) as u64; + } + debug_assert_eq!(c_out, 0); + + buf = next; + len 
-= 1; + while len > 5 && buf[len - 1] == 0 { + len -= 1; + } + } + + Self::solinas_reduce(&buf[..len]) + } + } + } +} + +impl Add for Fp128

{ + type Output = Self; + #[inline] + fn add(self, rhs: Self) -> Self::Output { + Self(Self::add_raw(self.0, rhs.0)) + } +} + +impl Sub for Fp128

{ + type Output = Self; + #[inline] + fn sub(self, rhs: Self) -> Self::Output { + Self(Self::sub_raw(self.0, rhs.0)) + } +} + +impl Mul for Fp128

{ + type Output = Self; + #[inline] + fn mul(self, rhs: Self) -> Self::Output { + Self(Self::mul_raw(self.0, rhs.0)) + } +} + +impl Neg for Fp128

{ + type Output = Self; + #[inline] + fn neg(self) -> Self::Output { + Self(Self::sub_raw(pack(0, 0), self.0)) + } +} + +impl AddAssign for Fp128

{ + #[inline] + fn add_assign(&mut self, rhs: Self) { + *self = *self + rhs; + } +} + +impl SubAssign for Fp128

{ + #[inline] + fn sub_assign(&mut self, rhs: Self) { + *self = *self - rhs; + } +} + +impl MulAssign for Fp128

{ + #[inline] + fn mul_assign(&mut self, rhs: Self) { + *self = *self * rhs; + } +} + +impl<'a, const P: u128> Add<&'a Self> for Fp128

{ + type Output = Self; + #[inline] + fn add(self, rhs: &'a Self) -> Self::Output { + self + *rhs + } +} + +impl<'a, const P: u128> Sub<&'a Self> for Fp128

{ + type Output = Self; + #[inline] + fn sub(self, rhs: &'a Self) -> Self::Output { + self - *rhs + } +} + +impl<'a, const P: u128> Mul<&'a Self> for Fp128

{ + type Output = Self; + #[inline] + fn mul(self, rhs: &'a Self) -> Self::Output { + self * *rhs + } +} + +impl Valid for Fp128

{ + fn check(&self) -> Result<(), SerializationError> { + if to_u128(self.0) < P { + Ok(()) + } else { + Err(SerializationError::InvalidData("Fp128 out of range".into())) + } + } +} + +impl HachiSerialize for Fp128

{ + fn serialize_with_mode( + &self, + mut writer: W, + _compress: Compress, + ) -> Result<(), SerializationError> { + to_u128(self.0).serialize_with_mode(&mut writer, Compress::No)?; + Ok(()) + } + + fn serialized_size(&self, _compress: Compress) -> usize { + 16 + } +} + +impl HachiDeserialize for Fp128

{ + fn deserialize_with_mode( + mut reader: R, + _compress: Compress, + validate: Validate, + ) -> Result { + let x = u128::deserialize_with_mode(&mut reader, Compress::No, validate)?; + if matches!(validate, Validate::Yes) && x >= P { + return Err(SerializationError::InvalidData( + "Fp128 out of range".to_string(), + )); + } + + // Without validation, reduce without division. + // For `p = 2^128 − c` with `c < 2^64` we have `p > 2^127`, + // so any `u128` is in `[0, 2p)` and one conditional subtract suffices. + let out = if matches!(validate, Validate::Yes) { + x + } else { + let (sub, borrow) = x.overflowing_sub(P); + if borrow { + x + } else { + sub + } + }; + Ok(Self(from_u128(out))) + } +} + +impl AdditiveGroup for Fp128

{ + const ZERO: Self = Self(pack(0, 0)); +} + +impl FieldCore for Fp128

{ + fn one() -> Self { + Self(pack(1, 0)) + } + + fn is_zero(&self) -> bool { + self.0 == [0, 0] + } + + fn inv(self) -> Option { + let inv = self.inv_or_zero(); + if self.is_zero() { + None + } else { + Some(inv) + } + } + + const TWO_INV: Self = { + let v = (P >> 1) + 1; + Self(pack(v as u64, (v >> 64) as u64)) + }; +} + +impl Invertible for Fp128

{ + fn inv_or_zero(self) -> Self { + let candidate = self.pow_u128(P.wrapping_sub(2)); + let v = to_u128(self.0); + let nz = ((v | v.wrapping_neg()) >> 127) & 1; + let mask = 0u128.wrapping_sub(nz); + let masked = to_u128(candidate.0) & mask; + Self(from_u128(masked)) + } +} + +impl FieldSampling for Fp128

{ + fn sample(rng: &mut R) -> Self { + loop { + let lo = rng.next_u64(); + let hi = rng.next_u64(); + let x = lo as u128 | (hi as u128) << 64; + if x < P { + return Self(pack(lo, hi)); + } + } + } +} + +impl FromSmallInt for Fp128

{ + fn from_u64(val: u64) -> Self { + // For Fp128 pseudo-Mersenne primes, p = 2^128 - c with c < 2^64. + // Therefore any u64 is always canonical (< p), so this can be a + // direct limb construction with no reduction path. + Self(from_u128(val as u128)) + } + + fn from_i64(val: i64) -> Self { + Self::from_i64_const(val) + } + + fn digit_lut(log_basis: u32) -> [Self; 16] { + Self::digit_lut(log_basis) + } +} + +impl CanonicalField for Fp128

{ + fn to_canonical_u128(self) -> u128 { + to_u128(self.0) + } + + fn from_canonical_u128_checked(val: u128) -> Option { + if val < P { + Some(Self(from_u128(val))) + } else { + None + } + } + + fn from_canonical_u128_reduced(val: u128) -> Self { + let (sub, borrow) = val.overflowing_sub(P); + Self(from_u128(if borrow { val } else { sub })) + } +} + +impl PseudoMersenneField for Fp128

{ + const MODULUS_BITS: u32 = 128; + const MODULUS_OFFSET: u128 = Self::C; +} + +/// `p = 2^128 − 2^13 − 2^4 + 2^0` (C = 8207). +pub type Prime128M13M4P0 = Fp128<0xffffffffffffffffffffffffffffdff1>; +/// `p = 2^128 − 2^37 + 2^3 + 2^0` (C = 137438953463). +pub type Prime128M37P3P0 = Fp128<0xffffffffffffffffffffffe000000009>; +/// `p = 2^128 − 2^52 − 2^3 + 2^0` (C = 4503599627370487). +pub type Prime128M52M3P0 = Fp128<0xffffffffffffffffffeffffffffffff9>; +/// `p = 2^128 − 2^54 + 2^4 + 2^0` (C = 18014398509481967). +pub type Prime128M54P4P0 = Fp128<0xffffffffffffffffffc0000000000011>; +/// `p = 2^128 − 2^8 − 2^4 − 2^1 − 2^0` (C = 275). +pub type Prime128M8M4M1M0 = Fp128<0xfffffffffffffffffffffffffffffeed>; +/// `p = 2^128 − 5823` (C = 5823). +pub type Prime128Offset5823 = Fp128<0xffffffffffffffffffffffffffffe941>; +/// `p = 2^128 − 2^18 − 2^0` (C = 2^18 + 1). +pub type Prime128M18M0 = Fp128<0xfffffffffffffffffffffffffffbffff>; +/// `p = 2^128 − 2^54 + 2^0` (C = 2^54 − 1). +pub type Prime128M54P0 = Fp128<0xffffffffffffffffffc0000000000001>; + +#[cfg(test)] +mod tests { + use super::*; + use crate::{FieldSampling, PseudoMersenneField}; + use rand::rngs::StdRng; + use rand::SeedableRng; + use rand_core::RngCore; + + type F = Prime128M8M4M1M0; + + #[test] + fn to_limbs_roundtrip() { + let mut rng = StdRng::seed_from_u64(0xdead_beef_cafe_1234); + for _ in 0..1000 { + let a: F = FieldSampling::sample(&mut rng); + assert_eq!(Fp128(a.to_limbs()), a); + } + } + + #[test] + fn mul_wide_u64_matches_full_mul() { + let mut rng = StdRng::seed_from_u64(0x1122_3344_5566_7788); + for _ in 0..1000 { + let a: F = FieldSampling::sample(&mut rng); + let b = rng.next_u64(); + let expected = a * F::from_u64(b); + let reduced = F::solinas_reduce(&a.mul_wide_u64(b)); + assert_eq!(reduced, expected); + } + } + + #[test] + fn mul_wide_matches_full_mul() { + let mut rng = StdRng::seed_from_u64(0xaabb_ccdd_eeff_0011); + for _ in 0..1000 { + let a: F = FieldSampling::sample(&mut rng); + let b: F = 
FieldSampling::sample(&mut rng); + let expected = a * b; + let reduced = F::solinas_reduce(&a.mul_wide(b)); + assert_eq!(reduced, expected); + } + } + + #[test] + fn mul_wide_u128_matches_full_mul() { + let mut rng = StdRng::seed_from_u64(0x9988_7766_5544_3322); + for _ in 0..1000 { + let a: F = FieldSampling::sample(&mut rng); + let b = rng.next_u64() as u128 | ((rng.next_u64() as u128) << 64); + let expected = a * F::from_canonical_u128_reduced(b); + let reduced = F::solinas_reduce(&a.mul_wide_u128(b)); + assert_eq!(reduced, expected); + } + } + + #[test] + fn mul_wide_limbs_roundtrips_through_reduction() { + let mut rng = StdRng::seed_from_u64(0x1bad_f00d_0ddc_afe1); + for _ in 0..1000 { + let a: F = FieldSampling::sample(&mut rng); + let b3 = [rng.next_u64(), rng.next_u64(), rng.next_u64()]; + let b4 = [ + rng.next_u64(), + rng.next_u64(), + rng.next_u64(), + rng.next_u64(), + ]; + + let got3_full = a.mul_wide_limbs::<3, 5>(b3); + let got3_trunc = a.mul_wide_limbs::<3, 4>(b3); + assert_eq!( + got3_trunc, + [got3_full[0], got3_full[1], got3_full[2], got3_full[3]] + ); + let exp3 = a * F::solinas_reduce(&b3); + assert_eq!(F::solinas_reduce(&got3_full), exp3); + + let got4_full = a.mul_wide_limbs::<4, 6>(b4); + let got4_trunc = a.mul_wide_limbs::<4, 4>(b4); + assert_eq!( + got4_trunc, + [got4_full[0], got4_full[1], got4_full[2], got4_full[3]] + ); + let exp4 = a * F::solinas_reduce(&b4); + assert_eq!(F::solinas_reduce(&got4_full), exp4); + } + } + + #[test] + fn solinas_reduce_small_inputs() { + assert_eq!(F::solinas_reduce(&[]), F::zero()); + assert_eq!(F::solinas_reduce(&[42]), F::from_u64(42)); + let one_shifted = F::from_canonical_u128_reduced(1u128 << 64); + assert_eq!(F::solinas_reduce(&[0, 1]), one_shifted); + } + + #[test] + fn solinas_reduce_4_limbs_max() { + // 2^256 - 1 ≡ C² - 1 (mod P), since 2^128 ≡ C + let c = F::from_canonical_u128_reduced(::MODULUS_OFFSET); + let expected = c * c - F::one(); + assert_eq!(F::solinas_reduce(&[u64::MAX; 4]), 
expected); + } + + #[test] + fn solinas_reduce_9_limbs() { + // 1 + 2^512 = 1 + (2^128)^4 ≡ 1 + C^4 + let c = F::from_canonical_u128_reduced(::MODULUS_OFFSET); + let expected = F::one() + c * c * c * c; + assert_eq!(F::solinas_reduce(&[1, 0, 0, 0, 0, 0, 0, 0, 1]), expected); + } + + #[test] + fn solinas_reduce_accumulated_products() { + let mut rng = StdRng::seed_from_u64(0xfeed_face_0bad_c0de); + let mut acc = [0u64; 5]; + let mut expected = F::zero(); + + for _ in 0..200 { + let a: F = FieldSampling::sample(&mut rng); + let b = rng.next_u64(); + let wide = a.mul_wide_u64(b); + + let mut carry: u64 = 0; + for j in 0..5 { + let addend = if j < 3 { wide[j] } else { 0 }; + let sum = acc[j] as u128 + addend as u128 + carry as u128; + acc[j] = sum as u64; + carry = (sum >> 64) as u64; + } + assert_eq!(carry, 0); + expected += a * F::from_u64(b); + } + + assert_eq!(F::solinas_reduce(&acc), expected); + } + + #[test] + fn solinas_reduce_cross_prime() { + // Verify with Prime128M18M0 (C = 2^18 + 1, shift+add path) + type G = Prime128M18M0; + let c = G::from_canonical_u128_reduced(::MODULUS_OFFSET); + let expected = c * c - G::one(); + assert_eq!(G::solinas_reduce(&[u64::MAX; 4]), expected); + + // Verify with Prime128M54P0 (C = 2^54 - 1, shift-sub path) + type H = Prime128M54P0; + let c = H::from_canonical_u128_reduced(::MODULUS_OFFSET); + let expected = c * c - H::one(); + assert_eq!(H::solinas_reduce(&[u64::MAX; 4]), expected); + } + + #[test] + fn from_i64_handles_min_without_overflow() { + let x = F::from_i64(i64::MIN); + let y = F::from_u64(i64::MIN.unsigned_abs()); + assert_eq!(x + y, F::zero()); + } +} diff --git a/src/algebra/fields/fp32.rs b/src/algebra/fields/fp32.rs new file mode 100644 index 00000000..77419476 --- /dev/null +++ b/src/algebra/fields/fp32.rs @@ -0,0 +1,471 @@ +//! Prime field for primes of the form `p = 2^k − c` with `c` small, backed +//! by `u32` storage. +//! +//! Uses Solinas-style two-fold reduction: the offset `c` and fold point `k` +//! 
are computed at compile time from the const-generic modulus `P`. + +use std::ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign}; + +use rand_core::RngCore; + +use crate::primitives::serialization::{ + Compress, HachiDeserialize, HachiSerialize, SerializationError, Valid, Validate, +}; +use crate::{ + AdditiveGroup, CanonicalField, FieldCore, FieldSampling, FromSmallInt, Invertible, + PseudoMersenneField, +}; +use std::io::{Read, Write}; + +/// Prime field element for primes `p = 2^k − c` stored as `u32`. +/// +/// The fold point `k` and offset `c = 2^k − p` are computed at compile time +/// from the const-generic `P`. Instantiating with a modulus that does not +/// satisfy the Solinas conditions is a compile-time error. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] +pub struct Fp32(pub(crate) u32); + +impl Fp32

{ + /// Fold point: smallest `k` such that `P ≤ 2^k`. + const BITS: u32 = 32 - P.leading_zeros(); + + /// Offset `c = 2^k − P`. + pub const C: u32 = { + let c = if Self::BITS == 32 { + 0u32.wrapping_sub(P) + } else { + (1u32 << Self::BITS) - P + }; + assert!(P != 0, "modulus must be nonzero"); + assert!(P & 1 == 1, "modulus must be odd"); + assert!( + (c as u64) * (c as u64 + 1) < P as u64, + "C(C+1) < P required for fused canonicalize" + ); + c + }; + + /// Mask for extracting the low `BITS` bits from a u64. + const MASK: u64 = if Self::BITS == 32 { + u32::MAX as u64 + } else { + (1u64 << Self::BITS) - 1 + }; + + /// Create from a canonical representative in `[0, P)`. + #[inline] + pub fn from_canonical_u32(x: u32) -> Self { + debug_assert!(x < P); + Self(x) + } + + /// Return the canonical representative in `[0, P)`. + #[inline] + pub fn to_canonical_u32(self) -> u32 { + self.0 + } + + /// Solinas reduction: fold a u64 at bit `BITS` until the value fits, + /// then conditionally subtract `P`. + /// + /// For multiplication products (< 2^{2·BITS}) exactly 2 folds suffice; + /// for arbitrary u64 inputs (e.g. `from_u64`) the loop runs at most + /// `ceil(64 / BITS)` iterations. + #[inline(always)] + fn reduce_u64(x: u64) -> u32 { + let c = Self::C as u64; + let mut v = x; + while v >> Self::BITS != 0 { + v = (v & Self::MASK) + c * (v >> Self::BITS); + } + let reduced = v.wrapping_sub(P as u64); + let borrow = reduced >> 63; + reduced.wrapping_add(borrow.wrapping_neg() & (P as u64)) as u32 + } + + /// Reduce a `u128` to canonical form (for `from_canonical_u128_reduced`). 
+ #[inline(always)] + fn reduce_u128(x: u128) -> u32 { + let c = Self::C as u128; + let bits = Self::BITS; + let mask = if bits == 32 { + u32::MAX as u128 + } else { + (1u128 << bits) - 1 + }; + let mut v = x; + while v >> bits != 0 { + v = (v & mask) + c * (v >> bits); + } + let f = v as u64; + let reduced = f.wrapping_sub(P as u64); + let borrow = reduced >> 63; + reduced.wrapping_add(borrow.wrapping_neg() & (P as u64)) as u32 + } + + /// Two-fold Solinas reduction for multiplication products. + /// + /// Input must be < 2^{2·BITS} (guaranteed for `a*b` where `a,b < P`). + /// Exactly 2 folds + conditional subtract, no loop. + #[inline(always)] + fn reduce_product(x: u64) -> u32 { + let c = Self::C as u64; + let f1 = (x & Self::MASK) + c * (x >> Self::BITS); + let f2 = (f1 & Self::MASK) + c * (f1 >> Self::BITS); + let reduced = f2.wrapping_sub(P as u64); + let borrow = reduced >> 63; + reduced.wrapping_add(borrow.wrapping_neg() & (P as u64)) as u32 + } + + #[inline(always)] + fn add_raw(a: u32, b: u32) -> u32 { + let s = (a as u64) + (b as u64); + let reduced = s.wrapping_sub(P as u64); + let borrow = reduced >> 63; + reduced.wrapping_add(borrow.wrapping_neg() & (P as u64)) as u32 + } + + #[inline(always)] + fn sub_raw(a: u32, b: u32) -> u32 { + let diff = (a as u64).wrapping_sub(b as u64); + let borrow = diff >> 63; + diff.wrapping_add(borrow.wrapping_neg() & (P as u64)) as u32 + } + + #[inline(always)] + fn mul_raw(a: u32, b: u32) -> u32 { + Self::reduce_product((a as u64) * (b as u64)) + } + + #[inline(always)] + fn sqr_raw(a: u32) -> u32 { + Self::mul_raw(a, a) + } + + /// Squaring, equivalent to `self * self`. + #[inline(always)] + pub fn square(self) -> Self { + Self(Self::sqr_raw(self.0)) + } + + fn pow(self, mut exp: u64) -> Self { + let mut base = self; + let mut acc = Self::one(); + while exp > 0 { + if (exp & 1) == 1 { + acc *= base; + } + base = base.square(); + exp >>= 1; + } + acc + } + + /// Extract the canonical value. 
+ #[inline(always)] + pub fn to_limbs(self) -> u32 { + self.0 + } + + /// 32×32 → 64-bit widening multiply, **no reduction**. + #[inline(always)] + pub fn mul_wide(self, other: Self) -> u64 { + (self.0 as u64) * (other.0 as u64) + } + + /// 32×32 → 64-bit widening multiply with a raw `u32` operand, + /// **no reduction**. + #[inline(always)] + pub fn mul_wide_u32(self, other: u32) -> u64 { + (self.0 as u64) * (other as u64) + } + + /// Reduce a u64 value via Solinas folding to a canonical field element. + #[inline(always)] + pub fn solinas_reduce(x: u64) -> Self { + Self(Self::reduce_u64(x)) + } +} + +impl Add for Fp32

{ + type Output = Self; + #[inline] + fn add(self, rhs: Self) -> Self::Output { + Self(Self::add_raw(self.0, rhs.0)) + } +} + +impl Sub for Fp32

{ + type Output = Self; + #[inline] + fn sub(self, rhs: Self) -> Self::Output { + Self(Self::sub_raw(self.0, rhs.0)) + } +} + +impl Mul for Fp32

{ + type Output = Self; + #[inline] + fn mul(self, rhs: Self) -> Self::Output { + Self(Self::mul_raw(self.0, rhs.0)) + } +} + +impl Neg for Fp32

{ + type Output = Self; + #[inline] + fn neg(self) -> Self::Output { + Self(Self::sub_raw(0, self.0)) + } +} + +impl AddAssign for Fp32

{ + #[inline] + fn add_assign(&mut self, rhs: Self) { + *self = *self + rhs; + } +} + +impl SubAssign for Fp32

{ + #[inline] + fn sub_assign(&mut self, rhs: Self) { + *self = *self - rhs; + } +} + +impl MulAssign for Fp32

{ + #[inline] + fn mul_assign(&mut self, rhs: Self) { + *self = *self * rhs; + } +} + +impl<'a, const P: u32> Add<&'a Self> for Fp32

{ + type Output = Self; + #[inline] + fn add(self, rhs: &'a Self) -> Self::Output { + self + *rhs + } +} + +impl<'a, const P: u32> Sub<&'a Self> for Fp32

{ + type Output = Self; + #[inline] + fn sub(self, rhs: &'a Self) -> Self::Output { + self - *rhs + } +} + +impl<'a, const P: u32> Mul<&'a Self> for Fp32

{ + type Output = Self; + #[inline] + fn mul(self, rhs: &'a Self) -> Self::Output { + self * *rhs + } +} + +impl Valid for Fp32

{ + fn check(&self) -> Result<(), SerializationError> { + if self.0 < P { + Ok(()) + } else { + Err(SerializationError::InvalidData("Fp32 out of range".into())) + } + } +} + +impl HachiSerialize for Fp32

{ + fn serialize_with_mode( + &self, + mut writer: W, + _compress: Compress, + ) -> Result<(), SerializationError> { + self.0.serialize_with_mode(&mut writer, Compress::No)?; + Ok(()) + } + + fn serialized_size(&self, _compress: Compress) -> usize { + 4 + } +} + +impl HachiDeserialize for Fp32

{ + fn deserialize_with_mode( + mut reader: R, + _compress: Compress, + validate: Validate, + ) -> Result { + let x = u32::deserialize_with_mode(&mut reader, Compress::No, validate)?; + if matches!(validate, Validate::Yes) && x >= P { + return Err(SerializationError::InvalidData( + "Fp32 out of range".to_string(), + )); + } + let out = if matches!(validate, Validate::Yes) { + Self(x) + } else { + Self(Self::reduce_u64(x as u64)) + }; + Ok(out) + } +} + +impl AdditiveGroup for Fp32

{ + const ZERO: Self = Self(0); +} + +impl FieldCore for Fp32

{ + fn one() -> Self { + Self(if P > 1 { 1 } else { 0 }) + } + + fn is_zero(&self) -> bool { + self.0 == 0 + } + + fn inv(self) -> Option { + let inv = self.inv_or_zero(); + if self.is_zero() { + None + } else { + Some(inv) + } + } + + const TWO_INV: Self = Self((P as u64).div_ceil(2) as u32); +} + +impl Invertible for Fp32

{ + fn inv_or_zero(self) -> Self { + let candidate = self.pow((P as u64).wrapping_sub(2)); + let nz = ((self.0 | self.0.wrapping_neg()) >> 31) & 1; + let mask = 0u32.wrapping_sub(nz); + Self(candidate.0 & mask) + } +} + +impl FieldSampling for Fp32

{ + fn sample(rng: &mut R) -> Self { + Self(Self::reduce_u64(rng.next_u64())) + } +} + +impl FromSmallInt for Fp32

{ + fn from_u64(val: u64) -> Self { + Self(Self::reduce_u64(val)) + } + + fn from_i64(val: i64) -> Self { + if val >= 0 { + Self::from_u64(val as u64) + } else { + -Self::from_u64(val.unsigned_abs()) + } + } +} + +impl CanonicalField for Fp32

{ + fn to_canonical_u128(self) -> u128 { + self.0 as u128 + } + + fn from_canonical_u128_checked(val: u128) -> Option { + if val < P as u128 { + Some(Self(val as u32)) + } else { + None + } + } + + fn from_canonical_u128_reduced(val: u128) -> Self { + Self(Self::reduce_u128(val)) + } +} + +impl PseudoMersenneField for Fp32

{ + const MODULUS_BITS: u32 = Self::BITS; + const MODULUS_OFFSET: u128 = Self::C as u128; +} + +#[cfg(test)] +mod tests { + use super::*; + use rand::rngs::StdRng; + use rand::SeedableRng; + + type F = Fp32<251>; // 2^8 - 5 + + #[test] + fn solinas_constants() { + assert_eq!(F::BITS, 8); + assert_eq!(F::C, 5); + assert_eq!(F::MASK, 255); + + type G = Fp32<{ (1u32 << 24) - 3 }>; // 2^24 - 3 + assert_eq!(G::BITS, 24); + assert_eq!(G::C, 3); + } + + #[test] + fn basic_arithmetic() { + let a = F::from_u64(100); + let b = F::from_u64(200); + assert_eq!((a + b).to_canonical_u32(), (100 + 200) % 251); + assert_eq!((a * b).to_canonical_u32(), (100 * 200) % 251); + assert_eq!((b - a).to_canonical_u32(), 100); + assert_eq!((-a).to_canonical_u32(), 251 - 100); + } + + #[test] + fn mul_wide_matches_full_mul() { + let mut rng = StdRng::seed_from_u64(0x1234_5678); + for _ in 0..1000 { + let a: F = FieldSampling::sample(&mut rng); + let b: F = FieldSampling::sample(&mut rng); + let expected = a * b; + let reduced = F::solinas_reduce(a.mul_wide(b)); + assert_eq!(reduced, expected); + } + } + + #[test] + fn mul_wide_u32_matches() { + let mut rng = StdRng::seed_from_u64(0xabcd_ef01); + for _ in 0..1000 { + let a: F = FieldSampling::sample(&mut rng); + let b = rng.next_u32() % 251; + let expected = a * F::from_canonical_u32(b); + let reduced = F::solinas_reduce(a.mul_wide_u32(b)); + assert_eq!(reduced, expected); + } + } + + #[test] + fn reduce_large_values() { + assert_eq!( + F::from_u64(u64::MAX).to_canonical_u32(), + (u64::MAX % 251) as u32 + ); + assert_eq!(F::from_u64(0).to_canonical_u32(), 0); + assert_eq!(F::from_u64(251).to_canonical_u32(), 0); + assert_eq!(F::from_u64(252).to_canonical_u32(), 1); + } + + #[test] + fn pseudo_mersenne_trait() { + assert_eq!(::MODULUS_BITS, 8); + assert_eq!(::MODULUS_OFFSET, 5); + } + + #[test] + fn cross_prime_32bit() { + type G = Fp32<{ u32::MAX - 98 }>; // 2^32 - 99 + assert_eq!(G::BITS, 32); + assert_eq!(G::C, 99); + + let a = 
G::from_u64(1_000_000); + let b = G::from_u64(2_000_000); + let product = (1_000_000u64 * 2_000_000u64) % ((1u64 << 32) - 99); + assert_eq!((a * b).to_canonical_u32(), product as u32); + } +} diff --git a/src/algebra/fields/fp64.rs b/src/algebra/fields/fp64.rs new file mode 100644 index 00000000..5e88ec5b --- /dev/null +++ b/src/algebra/fields/fp64.rs @@ -0,0 +1,585 @@ +//! Prime field for primes of the form `p = 2^k − c` with `c` small, backed +//! by `u64` storage. +//! +//! Uses Solinas-style two-fold reduction. For `c = 2^a ± 1` the fold +//! multiply is replaced by shift+add/sub, saving a u128 widening multiply. + +use std::ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign}; + +use rand_core::RngCore; + +use crate::primitives::serialization::{ + Compress, HachiDeserialize, HachiSerialize, SerializationError, Valid, Validate, +}; +use crate::{ + AdditiveGroup, CanonicalField, FieldCore, FieldSampling, FromSmallInt, Invertible, + PseudoMersenneField, +}; +use std::io::{Read, Write}; + +use super::util::{is_pow2_u64, log2_pow2_u64, mul64_wide}; + +/// Prime field element for primes `p = 2^k − c` stored as `u64`. +/// +/// The fold point `k` and offset `c = 2^k − p` are computed at compile time +/// from the const-generic `P`. For `c = 2^a ± 1`, the fold multiply is +/// replaced by shift+add/sub. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] +pub struct Fp64(pub(crate) u64); + +impl Fp64

{ + /// Fold point: smallest `k` such that `P ≤ 2^k`. + const BITS: u32 = 64 - P.leading_zeros(); + + /// Offset `c = 2^k − P`. + pub const C: u64 = { + let c = if Self::BITS == 64 { + 0u64.wrapping_sub(P) + } else { + (1u64 << Self::BITS) - P + }; + assert!(P != 0, "modulus must be nonzero"); + assert!(P & 1 == 1, "modulus must be odd"); + assert!( + (c as u128) * (c as u128 + 1) < P as u128, + "C(C+1) < P required for fused canonicalize" + ); + c + }; + + /// +1 means `C = 2^a + 1`, -1 means `C = 2^a - 1`, 0 means generic. + const C_SHIFT_KIND: i8 = { + let c = Self::C; + if c > 1 && is_pow2_u64(c - 1) { + 1 + } else if c == u64::MAX || is_pow2_u64(c + 1) { + -1 + } else { + 0 + } + }; + + const C_SHIFT: u32 = { + let c = Self::C; + if Self::C_SHIFT_KIND == 1 { + log2_pow2_u64(c - 1) + } else if Self::C_SHIFT_KIND == -1 { + if c == u64::MAX { + 64 + } else { + log2_pow2_u64(c + 1) + } + } else { + 0 + } + }; + + /// Mask for extracting the low `BITS` bits from a u128. + const MASK: u128 = if Self::BITS == 64 { + u64::MAX as u128 + } else { + (1u128 << Self::BITS) - 1 + }; + + /// u64-width mask (only valid when BITS < 64). + const MASK64: u64 = if Self::BITS < 64 { + (1u64 << Self::BITS) - 1 + } else { + u64::MAX + }; + + /// Whether Solinas folding of a multiplication product can stay + /// entirely in u64. True when BITS < 64 and C·2^BITS < 2^64. + const FOLD_IN_U64: bool = Self::BITS < 64 && (Self::C as u128) < (1u128 << (64 - Self::BITS)); + + /// u64 multiply by C, split into u32-wide halves so LLVM emits + /// `umull` (32×32→64) instead of promoting to u128. + /// Only valid when C fits in u32 (always true: C < sqrt(P) < 2^32). + #[inline(always)] + fn mul_c_narrow(x: u64) -> u64 { + #[cfg(target_arch = "x86_64")] + { + // x86_64 has fast scalar 64-bit multiply; use one multiply instead + // of two widened 32-bit multiplies in the fold hot path. 
+ Self::C.wrapping_mul(x) + } + #[cfg(not(target_arch = "x86_64"))] + { + let c = Self::C as u32; + let x_lo = x as u32; + let x_hi = (x >> 32) as u32; + (c as u64 * x_lo as u64).wrapping_add((c as u64 * x_hi as u64) << 32) + } + } + + /// Multiply `x` by `C`. For `C = 2^a ± 1` uses shift+add/sub. + #[inline(always)] + fn mul_c(x: u64) -> u128 { + if Self::C_SHIFT_KIND == 1 { + ((x as u128) << Self::C_SHIFT) + x as u128 + } else if Self::C_SHIFT_KIND == -1 { + ((x as u128) << Self::C_SHIFT) - x as u128 + } else { + (Self::C as u128) * (x as u128) + } + } + + /// Create from a canonical representative in `[0, P)`. + #[inline] + pub fn from_canonical_u64(x: u64) -> Self { + debug_assert!(x < P); + Self(x) + } + + /// Return the canonical representative in `[0, P)`. + #[inline] + pub fn to_canonical_u64(self) -> u64 { + self.0 + } + + /// Solinas reduction: fold a u128 at bit `BITS` until the value fits, + /// then conditionally subtract `P`. + /// + /// For multiplication products (< 2^{2·BITS}) exactly 2 folds suffice; + /// for arbitrary u128 inputs the loop runs at most `ceil(128 / BITS)` + /// iterations. + #[inline(always)] + fn reduce_u128(x: u128) -> u64 { + let mut v = x; + while v >> Self::BITS != 0 { + v = (v & Self::MASK) + Self::mul_c((v >> Self::BITS) as u64); + } + let reduced = v.wrapping_sub(P as u128); + let borrow = reduced >> 127; + reduced.wrapping_add(borrow.wrapping_neg() & (P as u128)) as u64 + } + + /// Two-fold Solinas reduction for multiplication products. + /// + /// Input must be < 2^{2·BITS} (guaranteed for `a*b` where `a,b < P`). + /// Exactly 2 folds + conditional subtract, no loop. + /// + /// When `FOLD_IN_U64` is true the entire reduction stays in u64, + /// avoiding expensive u128 mask/shift on sub-word primes. 
+ #[inline(always)] + fn reduce_product(x: u128) -> u64 { + if Self::FOLD_IN_U64 { + let lo = x as u64; + let hi = (x >> 64) as u64; + let high = (lo >> Self::BITS) | (hi << (64 - Self::BITS)); + let f1 = (lo & Self::MASK64) + Self::mul_c_narrow(high); + let f2 = (f1 & Self::MASK64) + Self::mul_c_narrow(f1 >> Self::BITS); + let reduced = f2.wrapping_sub(P); + let borrow = reduced >> 63; + reduced.wrapping_add(borrow.wrapping_neg() & P) + } else { + let f1 = (x & Self::MASK) + Self::mul_c((x >> Self::BITS) as u64); + let f2 = (f1 & Self::MASK) + Self::mul_c((f1 >> Self::BITS) as u64); + let reduced = f2.wrapping_sub(P as u128); + let borrow = reduced >> 127; + reduced.wrapping_add(borrow.wrapping_neg() & (P as u128)) as u64 + } + } + + /// BMI2 fast path: avoid re-materializing `u128` product in the common + /// sub-word configuration where reduction stays in `u64`. + #[cfg(all(target_arch = "x86_64", target_feature = "bmi2"))] + #[inline(always)] + fn reduce_product_wide(lo: u64, hi: u64) -> u64 { + if Self::FOLD_IN_U64 { + let high = (lo >> Self::BITS) | (hi << (64 - Self::BITS)); + let f1 = (lo & Self::MASK64) + Self::mul_c_narrow(high); + let f2 = (f1 & Self::MASK64) + Self::mul_c_narrow(f1 >> Self::BITS); + let reduced = f2.wrapping_sub(P); + let borrow = reduced >> 63; + reduced.wrapping_add(borrow.wrapping_neg() & P) + } else { + Self::reduce_product(lo as u128 | ((hi as u128) << 64)) + } + } + + #[inline(always)] + fn add_raw(a: u64, b: u64) -> u64 { + if Self::BITS <= 62 { + let s = a + b; + let reduced = s.wrapping_sub(P); + let borrow = reduced >> 63; + reduced.wrapping_add(borrow.wrapping_neg() & P) + } else { + let s = (a as u128) + (b as u128); + let reduced = s.wrapping_sub(P as u128); + let borrow = reduced >> 127; + reduced.wrapping_add(borrow.wrapping_neg() & (P as u128)) as u64 + } + } + + #[inline(always)] + fn sub_raw(a: u64, b: u64) -> u64 { + if Self::BITS <= 62 { + let diff = a.wrapping_sub(b); + let borrow = diff >> 63; + 
diff.wrapping_add(borrow.wrapping_neg() & P) + } else { + let diff = (a as u128).wrapping_sub(b as u128); + let borrow = diff >> 127; + diff.wrapping_add(borrow.wrapping_neg() & (P as u128)) as u64 + } + } + + #[inline(always)] + fn mul_raw(a: u64, b: u64) -> u64 { + #[cfg(all(target_arch = "x86_64", target_feature = "bmi2"))] + { + let (lo, hi) = mul64_wide(a, b); + Self::reduce_product_wide(lo, hi) + } + #[cfg(not(all(target_arch = "x86_64", target_feature = "bmi2")))] + { + Self::reduce_product((a as u128) * (b as u128)) + } + } + + #[inline(always)] + fn sqr_raw(a: u64) -> u64 { + Self::mul_raw(a, a) + } + + /// Squaring, equivalent to `self * self`. + #[inline(always)] + pub fn square(self) -> Self { + Self(Self::sqr_raw(self.0)) + } + + fn pow(self, mut exp: u64) -> Self { + let mut base = self; + let mut acc = Self::one(); + while exp > 0 { + if (exp & 1) == 1 { + acc *= base; + } + base = base.square(); + exp >>= 1; + } + acc + } + + /// Extract the canonical value. + #[inline(always)] + pub fn to_limbs(self) -> u64 { + self.0 + } + + /// 64×64 → 128-bit widening multiply, **no reduction**. + #[inline(always)] + pub fn mul_wide(self, other: Self) -> u128 { + let (lo, hi) = mul64_wide(self.0, other.0); + lo as u128 | ((hi as u128) << 64) + } + + /// 64×64 → 128-bit widening multiply with a raw `u64` operand, + /// **no reduction**. + #[inline(always)] + pub fn mul_wide_u64(self, other: u64) -> u128 { + let (lo, hi) = mul64_wide(self.0, other); + lo as u128 | ((hi as u128) << 64) + } + + /// Reduce a u128 value via Solinas folding to a canonical field element. + #[inline(always)] + pub fn solinas_reduce(x: u128) -> Self { + Self(Self::reduce_u128(x)) + } +} + +impl Add for Fp64

{ + type Output = Self; + #[inline] + fn add(self, rhs: Self) -> Self::Output { + Self(Self::add_raw(self.0, rhs.0)) + } +} + +impl Sub for Fp64

{ + type Output = Self; + #[inline] + fn sub(self, rhs: Self) -> Self::Output { + Self(Self::sub_raw(self.0, rhs.0)) + } +} + +impl Mul for Fp64

{ + type Output = Self; + #[inline] + fn mul(self, rhs: Self) -> Self::Output { + Self(Self::mul_raw(self.0, rhs.0)) + } +} + +impl Neg for Fp64

{ + type Output = Self; + #[inline] + fn neg(self) -> Self::Output { + Self(Self::sub_raw(0, self.0)) + } +} + +impl AddAssign for Fp64

{ + #[inline] + fn add_assign(&mut self, rhs: Self) { + *self = *self + rhs; + } +} + +impl SubAssign for Fp64

{ + #[inline] + fn sub_assign(&mut self, rhs: Self) { + *self = *self - rhs; + } +} + +impl MulAssign for Fp64

{ + #[inline] + fn mul_assign(&mut self, rhs: Self) { + *self = *self * rhs; + } +} + +impl<'a, const P: u64> Add<&'a Self> for Fp64

{ + type Output = Self; + #[inline] + fn add(self, rhs: &'a Self) -> Self::Output { + self + *rhs + } +} + +impl<'a, const P: u64> Sub<&'a Self> for Fp64

{ + type Output = Self; + #[inline] + fn sub(self, rhs: &'a Self) -> Self::Output { + self - *rhs + } +} + +impl<'a, const P: u64> Mul<&'a Self> for Fp64

{ + type Output = Self; + #[inline] + fn mul(self, rhs: &'a Self) -> Self::Output { + self * *rhs + } +} + +impl Valid for Fp64

{ + fn check(&self) -> Result<(), SerializationError> { + if self.0 < P { + Ok(()) + } else { + Err(SerializationError::InvalidData("Fp64 out of range".into())) + } + } +} + +impl HachiSerialize for Fp64

{ + fn serialize_with_mode( + &self, + mut writer: W, + _compress: Compress, + ) -> Result<(), SerializationError> { + self.0.serialize_with_mode(&mut writer, Compress::No)?; + Ok(()) + } + + fn serialized_size(&self, _compress: Compress) -> usize { + 8 + } +} + +impl HachiDeserialize for Fp64

{ + fn deserialize_with_mode( + mut reader: R, + _compress: Compress, + validate: Validate, + ) -> Result { + let x = u64::deserialize_with_mode(&mut reader, Compress::No, validate)?; + if matches!(validate, Validate::Yes) && x >= P { + return Err(SerializationError::InvalidData( + "Fp64 out of range".to_string(), + )); + } + let out = if matches!(validate, Validate::Yes) { + Self(x) + } else { + Self(Self::reduce_u128(x as u128)) + }; + Ok(out) + } +} + +impl AdditiveGroup for Fp64

{ + const ZERO: Self = Self(0); +} + +impl FieldCore for Fp64

{ + fn one() -> Self { + Self(if P > 1 { 1 } else { 0 }) + } + + fn is_zero(&self) -> bool { + self.0 == 0 + } + + fn inv(self) -> Option { + let inv = self.inv_or_zero(); + if self.is_zero() { + None + } else { + Some(inv) + } + } + + const TWO_INV: Self = Self((P as u128).div_ceil(2) as u64); +} + +impl Invertible for Fp64

{ + fn inv_or_zero(self) -> Self { + let candidate = self.pow(P.wrapping_sub(2)); + let nz = ((self.0 | self.0.wrapping_neg()) >> 63) & 1; + let mask = 0u64.wrapping_sub(nz); + Self(candidate.0 & mask) + } +} + +impl FieldSampling for Fp64

{ + fn sample(rng: &mut R) -> Self { + let lo = rng.next_u64() as u128; + let hi = rng.next_u64() as u128; + Self(Self::reduce_u128(lo | (hi << 64))) + } +} + +impl FromSmallInt for Fp64

{ + fn from_u64(val: u64) -> Self { + Self(Self::reduce_u128(val as u128)) + } + + fn from_i64(val: i64) -> Self { + if val >= 0 { + Self::from_u64(val as u64) + } else { + -Self::from_u64(val.unsigned_abs()) + } + } +} + +impl CanonicalField for Fp64

{ + fn to_canonical_u128(self) -> u128 { + self.0 as u128 + } + + fn from_canonical_u128_checked(val: u128) -> Option { + if val < P as u128 { + Some(Self(val as u64)) + } else { + None + } + } + + fn from_canonical_u128_reduced(val: u128) -> Self { + Self(Self::reduce_u128(val)) + } +} + +impl PseudoMersenneField for Fp64

{ + const MODULUS_BITS: u32 = Self::BITS; + const MODULUS_OFFSET: u128 = Self::C as u128; +} + +#[cfg(test)] +mod tests { + use super::*; + use rand::rngs::StdRng; + use rand::SeedableRng; + + type F40 = Fp64<{ (1u64 << 40) - 195 }>; // 2^40 - 195 + type F64 = Fp64<{ u64::MAX - 58 }>; // 2^64 - 59 + + #[test] + fn solinas_constants() { + assert_eq!(F40::BITS, 40); + assert_eq!(F40::C, 195); + + assert_eq!(F64::BITS, 64); + assert_eq!(F64::C, 59); + } + + #[test] + fn basic_arithmetic_sub_word() { + let a = F40::from_u64(1_000_000); + let b = F40::from_u64(2_000_000); + let p = (1u64 << 40) - 195; + assert_eq!((a + b).to_canonical_u64(), 3_000_000); + assert_eq!( + (a * b).to_canonical_u64(), + (1_000_000u128 * 2_000_000u128 % p as u128) as u64 + ); + } + + #[test] + fn basic_arithmetic_full_word() { + let a = F64::from_u64(1_000_000_000); + let b = F64::from_u64(2_000_000_000); + let p = u64::MAX - 58; + assert_eq!( + (a * b).to_canonical_u64(), + (1_000_000_000u128 * 2_000_000_000u128 % p as u128) as u64 + ); + } + + #[test] + fn mul_wide_matches_full_mul() { + let mut rng = StdRng::seed_from_u64(0xdead_beef); + for _ in 0..1000 { + let a: F40 = FieldSampling::sample(&mut rng); + let b: F40 = FieldSampling::sample(&mut rng); + let expected = a * b; + let reduced = F40::solinas_reduce(a.mul_wide(b)); + assert_eq!(reduced, expected); + } + } + + #[test] + fn mul_wide_u64_matches() { + let mut rng = StdRng::seed_from_u64(0xcafe_d00d); + for _ in 0..1000 { + let a: F40 = FieldSampling::sample(&mut rng); + let b = rng.next_u64() % ((1u64 << 40) - 195); + let expected = a * F40::from_canonical_u64(b); + let reduced = F40::solinas_reduce(a.mul_wide_u64(b)); + assert_eq!(reduced, expected); + } + } + + #[test] + fn pseudo_mersenne_trait() { + assert_eq!(::MODULUS_BITS, 40); + assert_eq!(::MODULUS_OFFSET, 195); + assert_eq!(::MODULUS_BITS, 64); + assert_eq!(::MODULUS_OFFSET, 59); + } + + #[test] + fn shift_optimization_detected() { + type G = Fp64<{ (1u64 << 56) - 27 }>; 
// C = 27, not 2^a±1 + assert_eq!(G::C_SHIFT_KIND, 0); + + type H = Fp64<{ u64::MAX - 58 }>; // C = 59, not 2^a±1 + assert_eq!(H::C_SHIFT_KIND, 0); + } + + #[test] + fn reduce_u128_large() { + assert_eq!(F64::from_canonical_u128_reduced(u128::MAX), { + let p = u64::MAX as u128 - 58; + F64::from_canonical_u64((u128::MAX % p) as u64) + }); + } +} diff --git a/src/algebra/fields/lift.rs b/src/algebra/fields/lift.rs new file mode 100644 index 00000000..e05a8aa7 --- /dev/null +++ b/src/algebra/fields/lift.rs @@ -0,0 +1,100 @@ +//! Helpers for embedding base fields into extension fields. + +use crate::algebra::fields::ext::{Fp2, Fp2Config, Fp4, Fp4Config}; +use crate::primitives::serialization::Valid; +use crate::{FieldCore, FromSmallInt}; + +/// Lift a base-field element into an extension field. +/// +/// This is intentionally small: for extension towers we embed into the constant term. +pub trait LiftBase: FieldCore { + /// Embed `x ∈ F` as a constant in `Self`. + fn lift_base(x: F) -> Self; +} + +/// An algebraic extension of base field `F`. +/// +/// Provides the extension degree and a constructor from a slice of base-field +/// coefficients (in the canonical basis `{1, u, u^2, ...}`). +pub trait ExtField: FieldCore + LiftBase + FromSmallInt { + /// Extension degree: `[Self : F]`. + const EXT_DEGREE: usize; + + /// Construct from a coefficient slice `[c0, c1, ..., c_{d-1}]`. + /// + /// # Panics + /// Panics if `coeffs.len() != Self::EXT_DEGREE`. 
+ fn from_base_slice(coeffs: &[F]) -> Self; +} + +impl ExtField for F { + const EXT_DEGREE: usize = 1; + + #[inline] + fn from_base_slice(coeffs: &[F]) -> Self { + assert_eq!(coeffs.len(), 1); + coeffs[0] + } +} + +impl ExtField for Fp2 +where + F: FieldCore + FromSmallInt + Valid, + C: Fp2Config, +{ + const EXT_DEGREE: usize = 2; + + #[inline] + fn from_base_slice(coeffs: &[F]) -> Self { + assert_eq!(coeffs.len(), 2); + Self::new(coeffs[0], coeffs[1]) + } +} + +impl ExtField for Fp4 +where + F: FieldCore + FromSmallInt + Valid, + C2: Fp2Config, + C4: Fp4Config, +{ + const EXT_DEGREE: usize = 4; + + #[inline] + fn from_base_slice(coeffs: &[F]) -> Self { + assert_eq!(coeffs.len(), 4); + Self::new( + Fp2::new(coeffs[0], coeffs[1]), + Fp2::new(coeffs[2], coeffs[3]), + ) + } +} + +impl LiftBase for F { + #[inline] + fn lift_base(x: F) -> Self { + x + } +} + +impl LiftBase for Fp2 +where + F: FieldCore + Valid, + C: Fp2Config, +{ + #[inline] + fn lift_base(x: F) -> Self { + Self::new(x, F::zero()) + } +} + +impl LiftBase for Fp4 +where + F: FieldCore + Valid, + C2: Fp2Config, + C4: Fp4Config, +{ + #[inline] + fn lift_base(x: F) -> Self { + Self::new(Fp2::new(x, F::zero()), Fp2::new(F::zero(), F::zero())) + } +} diff --git a/src/algebra/fields/mod.rs b/src/algebra/fields/mod.rs new file mode 100644 index 00000000..1f191b57 --- /dev/null +++ b/src/algebra/fields/mod.rs @@ -0,0 +1,48 @@ +//! Prime fields and extension field towers. 
+ +pub mod ext; +pub mod fp128; +pub mod fp32; +pub mod fp64; +pub mod lift; +pub mod packed; +#[cfg(all( + target_arch = "x86_64", + target_feature = "avx2", + not(all(target_feature = "avx512f", target_feature = "avx512dq")) +))] +pub mod packed_avx2; +#[cfg(all( + target_arch = "x86_64", + target_feature = "avx512f", + target_feature = "avx512dq" +))] +pub mod packed_avx512; +pub mod packed_ext; +#[cfg(all(target_arch = "aarch64", target_feature = "neon"))] +pub mod packed_neon; +pub mod pseudo_mersenne; +pub(crate) mod util; +pub mod wide; + +pub use ext::{Ext2, Ext4, Fp2, Fp2Config, Fp4, Fp4Config, NegOneNr, TwoNr, UnitNr}; +pub use fp128::{ + Fp128, Prime128M13M4P0, Prime128M37P3P0, Prime128M52M3P0, Prime128M54P4P0, Prime128M8M4M1M0, + Prime128Offset5823, +}; +pub use fp32::Fp32; +pub use fp64::Fp64; +pub use lift::{ExtField, LiftBase}; +pub use packed::{ + Fp128Packing, Fp32Packing, Fp64Packing, HasPacking, NoPacking, PackedField, PackedValue, +}; +pub use pseudo_mersenne::{ + is_pow2_offset, pow2_offset, pseudo_mersenne_modulus, Pow2Offset128Field, Pow2Offset24Field, + Pow2Offset30Field, Pow2Offset31Field, Pow2Offset32Field, Pow2Offset40Field, Pow2Offset48Field, + Pow2Offset56Field, Pow2Offset64Field, Pow2OffsetPrimeSpec, POW2_OFFSET_IMPLEMENTED_MAX_BITS, + POW2_OFFSET_MAX, POW2_OFFSET_PRIMES, POW2_OFFSET_TABLE, +}; +pub use wide::{ + AccumPair, Fp128MulU64Accum, Fp128ProductAccum, Fp128x8i32, Fp32x2i32, Fp64ProductAccum, + Fp64x4i32, HasUnreducedOps, HasWide, ReduceTo, +}; diff --git a/src/algebra/fields/packed.rs b/src/algebra/fields/packed.rs new file mode 100644 index 00000000..6da7917b --- /dev/null +++ b/src/algebra/fields/packed.rs @@ -0,0 +1,455 @@ +//! Packed field abstractions and architecture-specific SIMD backends. + +use crate::algebra::fields::{Fp128, Fp32, Fp64}; +use crate::FieldCore; +use core::ops::{Add, AddAssign, Mul, MulAssign, Sub, SubAssign}; + +/// Array-like packed values over a scalar type. 
+pub trait PackedValue: 'static + Copy + Send + Sync { + /// Scalar value type carried by each lane. + type Value: 'static + Copy + Send + Sync; + + /// Number of scalar lanes. + const WIDTH: usize; + + /// Build from a lane generator. + fn from_fn(f: F) -> Self + where + F: FnMut(usize) -> Self::Value; + + /// Extract one lane. + fn extract(&self, lane: usize) -> Self::Value; + + /// Pack a scalar slice into packed values. + /// + /// # Panics + /// + /// Panics if the length is not divisible by `WIDTH`. + #[inline] + fn pack_slice(buf: &[Self::Value]) -> Vec { + assert!( + buf.len() % Self::WIDTH == 0, + "slice length {} must be divisible by WIDTH {}", + buf.len(), + Self::WIDTH + ); + buf.chunks_exact(Self::WIDTH) + .map(|chunk| Self::from_fn(|i| chunk[i])) + .collect() + } + + /// Packed prefix + scalar suffix split. + #[inline] + fn pack_slice_with_suffix(buf: &[Self::Value]) -> (Vec, &[Self::Value]) { + let split = buf.len() - (buf.len() % Self::WIDTH); + let (packed, suffix) = buf.split_at(split); + (Self::pack_slice(packed), suffix) + } + + /// Unpack packed values into a flat scalar vector. + #[inline] + fn unpack_slice(buf: &[Self]) -> Vec { + let mut out = Vec::with_capacity(buf.len() * Self::WIDTH); + for packed in buf { + for lane in 0..Self::WIDTH { + out.push(packed.extract(lane)); + } + } + out + } +} + +/// Packed arithmetic over a scalar field. +pub trait PackedField: + PackedValue + Add + Sub + Mul +{ + /// Scalar field type. + type Scalar: FieldCore; + + /// Broadcast one scalar across all lanes. + fn broadcast(value: Self::Scalar) -> Self; +} + +/// Scalar fallback packed type with one lane. 
+#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)] +#[repr(transparent)] +pub struct NoPacking(pub [T; 1]); + +impl PackedValue for NoPacking +where + T: 'static + Copy + Send + Sync, +{ + type Value = T; + const WIDTH: usize = 1; + + #[inline] + fn from_fn(mut f: F) -> Self + where + F: FnMut(usize) -> Self::Value, + { + Self([f(0)]) + } + + #[inline] + fn extract(&self, lane: usize) -> Self::Value { + debug_assert_eq!(lane, 0); + self.0[0] + } +} + +impl Add for NoPacking { + type Output = Self; + #[inline] + fn add(self, rhs: Self) -> Self { + Self([self.0[0] + rhs.0[0]]) + } +} + +impl Sub for NoPacking { + type Output = Self; + #[inline] + fn sub(self, rhs: Self) -> Self { + Self([self.0[0] - rhs.0[0]]) + } +} + +impl Mul for NoPacking { + type Output = Self; + #[inline] + fn mul(self, rhs: Self) -> Self { + Self([self.0[0] * rhs.0[0]]) + } +} + +impl AddAssign for NoPacking { + #[inline] + fn add_assign(&mut self, rhs: Self) { + *self = *self + rhs; + } +} + +impl SubAssign for NoPacking { + #[inline] + fn sub_assign(&mut self, rhs: Self) { + *self = *self - rhs; + } +} + +impl MulAssign for NoPacking { + #[inline] + fn mul_assign(&mut self, rhs: Self) { + *self = *self * rhs; + } +} + +impl PackedField for NoPacking { + type Scalar = T; + + #[inline] + fn broadcast(value: Self::Scalar) -> Self { + Self([value]) + } +} + +/// Scalar field -> packed field association. +pub trait HasPacking: FieldCore { + /// Packed representation for this scalar field. + type Packing: PackedField; +} + +/// Selected packed backend for `Fp128`. +#[cfg(all(target_arch = "aarch64", target_feature = "neon"))] +pub type Fp128Packing = super::packed_neon::PackedFp128Neon

; + +/// Selected packed backend for `Fp128`. +#[cfg(all( + target_arch = "x86_64", + target_feature = "avx512f", + target_feature = "avx512dq" +))] +pub type Fp128Packing = super::packed_avx512::PackedFp128Avx512

; + +/// Selected packed backend for `Fp128`. +#[cfg(all( + target_arch = "x86_64", + target_feature = "avx2", + not(all(target_feature = "avx512f", target_feature = "avx512dq")) +))] +pub type Fp128Packing = super::packed_avx2::PackedFp128Avx2

; + +/// Selected packed backend for `Fp128`. +#[cfg(not(any( + all(target_arch = "aarch64", target_feature = "neon"), + all(target_arch = "x86_64", target_feature = "avx2") +)))] +pub type Fp128Packing = NoPacking>; + +impl HasPacking for Fp128

{ + type Packing = Fp128Packing

; +} + +/// Selected packed backend for `Fp32`. +#[cfg(all(target_arch = "aarch64", target_feature = "neon"))] +pub type Fp32Packing = super::packed_neon::PackedFp32Neon

; + +/// Selected packed backend for `Fp32`. +#[cfg(all( + target_arch = "x86_64", + target_feature = "avx512f", + target_feature = "avx512dq" +))] +pub type Fp32Packing = super::packed_avx512::PackedFp32Avx512

; + +/// Selected packed backend for `Fp32`. +#[cfg(all( + target_arch = "x86_64", + target_feature = "avx2", + not(all(target_feature = "avx512f", target_feature = "avx512dq")) +))] +pub type Fp32Packing = super::packed_avx2::PackedFp32Avx2

; + +/// Selected packed backend for `Fp32`. +#[cfg(not(any( + all(target_arch = "aarch64", target_feature = "neon"), + all(target_arch = "x86_64", target_feature = "avx2") +)))] +pub type Fp32Packing = NoPacking>; + +impl HasPacking for Fp32

{ + type Packing = Fp32Packing

; +} + +/// Selected packed backend for `Fp64`. +#[cfg(all(target_arch = "aarch64", target_feature = "neon"))] +pub type Fp64Packing = super::packed_neon::PackedFp64Neon

; + +/// Selected packed backend for `Fp64`. +#[cfg(all( + target_arch = "x86_64", + target_feature = "avx512f", + target_feature = "avx512dq" +))] +pub type Fp64Packing = super::packed_avx512::PackedFp64Avx512

; + +/// Selected packed backend for `Fp64`. +#[cfg(all( + target_arch = "x86_64", + target_feature = "avx2", + not(all(target_feature = "avx512f", target_feature = "avx512dq")) +))] +pub type Fp64Packing = super::packed_avx2::PackedFp64Avx2

; + +/// Selected packed backend for `Fp64`. +#[cfg(not(any( + all(target_arch = "aarch64", target_feature = "neon"), + all(target_arch = "x86_64", target_feature = "avx2") +)))] +pub type Fp64Packing = NoPacking>; + +impl HasPacking for Fp64

{ + type Packing = Fp64Packing

; +} + +#[cfg(test)] +mod tests { + use super::{HasPacking, PackedField, PackedValue}; + use crate::algebra::fields::{ + Pow2Offset24Field, Pow2Offset31Field, Pow2Offset32Field, Pow2Offset40Field, + Pow2Offset64Field, Prime128M13M4P0, + }; + use crate::{CanonicalField, FieldCore, FieldSampling, FromSmallInt}; + use rand::{rngs::StdRng, RngCore, SeedableRng}; + + fn rand_u128(rng: &mut R) -> u128 { + let lo = rng.next_u64() as u128; + let hi = rng.next_u64() as u128; + lo | (hi << 64) + } + + fn check_packed_add_sub_mul(seed: u64) + where + F: FieldCore + FieldSampling + PartialEq + std::fmt::Debug, + PF: PackedField + PackedValue, + { + let mut rng = StdRng::seed_from_u64(seed); + let len = PF::WIDTH * 17 + 3; + let lhs: Vec = (0..len).map(|_| FieldSampling::sample(&mut rng)).collect(); + let rhs: Vec = (0..len).map(|_| FieldSampling::sample(&mut rng)).collect(); + + let (lhs_p, lhs_s) = PF::pack_slice_with_suffix(&lhs); + let (rhs_p, rhs_s) = PF::pack_slice_with_suffix(&rhs); + + let add_p: Vec = lhs_p + .iter() + .zip(rhs_p.iter()) + .map(|(&a, &b)| a + b) + .collect(); + let sub_p: Vec = lhs_p + .iter() + .zip(rhs_p.iter()) + .map(|(&a, &b)| a - b) + .collect(); + let mul_p: Vec = lhs_p + .iter() + .zip(rhs_p.iter()) + .map(|(&a, &b)| a * b) + .collect(); + + let mut add_out = PF::unpack_slice(&add_p); + let mut sub_out = PF::unpack_slice(&sub_p); + let mut mul_out = PF::unpack_slice(&mul_p); + + for (&a, &b) in lhs_s.iter().zip(rhs_s.iter()) { + add_out.push(a + b); + sub_out.push(a - b); + mul_out.push(a * b); + } + + for i in 0..len { + assert_eq!( + add_out[i], + lhs[i] + rhs[i], + "packed add mismatch at lane {i}" + ); + assert_eq!( + sub_out[i], + lhs[i] - rhs[i], + "packed sub mismatch at lane {i}" + ); + assert_eq!( + mul_out[i], + lhs[i] * rhs[i], + "packed mul mismatch at lane {i}" + ); + } + } + + fn check_broadcast_roundtrip(val: F) + where + F: FieldCore + PartialEq + std::fmt::Debug, + PF: PackedField + PackedValue, + { + let p = 
PF::broadcast(val); + for lane in 0..PF::WIDTH { + assert_eq!(p.extract(lane), val); + } + } + + #[test] + fn packed_fp128_add_sub_mul_match_scalar() { + type F = Prime128M13M4P0; + type PF = ::Packing; + + let mut rng = StdRng::seed_from_u64(0x55aa_4422_1177_0033); + let len = PF::WIDTH * 17 + 3; + let lhs: Vec = (0..len) + .map(|_| F::from_canonical_u128_reduced(rand_u128(&mut rng))) + .collect(); + let rhs: Vec = (0..len) + .map(|_| F::from_canonical_u128_reduced(rand_u128(&mut rng))) + .collect(); + + let (lhs_p, lhs_s) = PF::pack_slice_with_suffix(&lhs); + let (rhs_p, rhs_s) = PF::pack_slice_with_suffix(&rhs); + + let add_p: Vec = lhs_p + .iter() + .zip(rhs_p.iter()) + .map(|(&a, &b)| a + b) + .collect(); + let sub_p: Vec = lhs_p + .iter() + .zip(rhs_p.iter()) + .map(|(&a, &b)| a - b) + .collect(); + let mul_p: Vec = lhs_p + .iter() + .zip(rhs_p.iter()) + .map(|(&a, &b)| a * b) + .collect(); + + let mut add_out = PF::unpack_slice(&add_p); + let mut sub_out = PF::unpack_slice(&sub_p); + let mut mul_out = PF::unpack_slice(&mul_p); + + for (&a, &b) in lhs_s.iter().zip(rhs_s.iter()) { + add_out.push(a + b); + sub_out.push(a - b); + mul_out.push(a * b); + } + + for i in 0..len { + assert_eq!( + add_out[i], + lhs[i] + rhs[i], + "packed add mismatch at lane {i}" + ); + assert_eq!( + sub_out[i], + lhs[i] - rhs[i], + "packed sub mismatch at lane {i}" + ); + assert_eq!( + mul_out[i], + lhs[i] * rhs[i], + "packed mul mismatch at lane {i}" + ); + } + } + + #[test] + fn fp128_broadcast_and_extract_roundtrip() { + type F = Prime128M13M4P0; + type PF = ::Packing; + check_broadcast_roundtrip::(F::from_u64(42)); + } + + #[test] + fn packed_fp32_24b_add_sub_mul() { + type F = Pow2Offset24Field; + type PF = ::Packing; + check_packed_add_sub_mul::(0xaa24_bb24_cc24_dd24); + } + + #[test] + fn packed_fp32_31b_add_sub_mul() { + type F = Pow2Offset31Field; + type PF = ::Packing; + check_packed_add_sub_mul::(0xaa31_bb31_cc31_dd31); + } + + #[test] + fn packed_fp32_32b_add_sub_mul() { 
+ type F = Pow2Offset32Field; + type PF = ::Packing; + check_packed_add_sub_mul::(0xaa32_bb32_cc32_dd32); + } + + #[test] + fn fp32_broadcast_and_extract_roundtrip() { + type F = Pow2Offset24Field; + type PF = ::Packing; + check_broadcast_roundtrip::(F::from_u64(42)); + } + + #[test] + fn packed_fp64_40b_add_sub_mul() { + type F = Pow2Offset40Field; + type PF = ::Packing; + check_packed_add_sub_mul::(0xaa40_bb40_cc40_dd40); + } + + #[test] + fn packed_fp64_64b_add_sub_mul() { + type F = Pow2Offset64Field; + type PF = ::Packing; + check_packed_add_sub_mul::(0xaa64_bb64_cc64_dd64); + } + + #[test] + fn fp64_broadcast_and_extract_roundtrip() { + type F = Pow2Offset40Field; + type PF = ::Packing; + check_broadcast_roundtrip::(F::from_u64(42)); + } +} diff --git a/src/algebra/fields/packed_avx2.rs b/src/algebra/fields/packed_avx2.rs new file mode 100644 index 00000000..c29685e2 --- /dev/null +++ b/src/algebra/fields/packed_avx2.rs @@ -0,0 +1,767 @@ +//! AVX2 packed backends for Fp32, Fp64, Fp128. +//! +//! Techniques adapted from plonky2 (Goldilocks) and plonky3 (Mersenne-31). + +use super::packed::{PackedField, PackedValue}; +use crate::algebra::fields::{Fp128, Fp32, Fp64}; +use core::arch::x86_64::*; +use core::fmt; +use core::mem::transmute; +use core::ops::{Add, AddAssign, Mul, MulAssign, Sub, SubAssign}; + +/// Duplicate high 32 bits of each 64-bit lane into the low 32 bits. +/// Uses the float `movehdup` instruction which runs on port 5 (doesn't compete +/// with multiply on ports 0/1). +#[inline(always)] +unsafe fn movehdup_epi32(x: __m256i) -> __m256i { + _mm256_castps_si256(_mm256_movehdup_ps(_mm256_castsi256_ps(x))) +} + +#[inline(always)] +unsafe fn moveldup_epi32(x: __m256i) -> __m256i { + _mm256_castps_si256(_mm256_moveldup_ps(_mm256_castsi256_ps(x))) +} + +/// 64×64→128 schoolbook multiply using 32×32→64 partial products. +/// Returns (hi, lo) representing the 128-bit product. 
+#[inline] +unsafe fn mul64_64_256(x: __m256i, y: __m256i) -> (__m256i, __m256i) { + let x_hi = movehdup_epi32(x); + let y_hi = movehdup_epi32(y); + + let mul_ll = _mm256_mul_epu32(x, y); + let mul_lh = _mm256_mul_epu32(x, y_hi); + let mul_hl = _mm256_mul_epu32(x_hi, y); + let mul_hh = _mm256_mul_epu32(x_hi, y_hi); + + let mul_ll_hi = _mm256_srli_epi64::<32>(mul_ll); + let t0 = _mm256_add_epi64(mul_hl, mul_ll_hi); + let mask32 = _mm256_set1_epi64x(0xFFFF_FFFF_i64); + let t0_lo = _mm256_and_si256(t0, mask32); + let t0_hi = _mm256_srli_epi64::<32>(t0); + let t1 = _mm256_add_epi64(mul_lh, t0_lo); + let t2 = _mm256_add_epi64(mul_hh, t0_hi); + let t1_hi = _mm256_srli_epi64::<32>(t1); + let res_hi = _mm256_add_epi64(t2, t1_hi); + + let t1_lo = moveldup_epi32(t1); + let res_lo = _mm256_blend_epi32::<0b10101010>(mul_ll, t1_lo); + + (res_hi, res_lo) +} + +/// Number of `Fp32` lanes in an AVX2 packed vector. +pub const FP32_WIDTH: usize = 8; + +/// AVX2 packed arithmetic for `Fp32

`, processing 8 lanes. +#[derive(Clone, Copy)] +#[repr(transparent)] +pub struct PackedFp32Avx2(pub [Fp32

; FP32_WIDTH]); + +impl PackedFp32Avx2

{ + const BITS: u32 = 32 - P.leading_zeros(); + + const C: u32 = { + let c = if Self::BITS == 32 { + 0u32.wrapping_sub(P) + } else { + (1u32 << Self::BITS) - P + }; + assert!(P != 0, "modulus must be nonzero"); + assert!(P & 1 == 1, "modulus must be odd"); + assert!( + (c as u64) * (c as u64 + 1) < P as u64, + "C(C+1) < P required for fused canonicalize" + ); + c + }; + + const MASK_U64: u64 = if Self::BITS == 32 { + u32::MAX as u64 + } else { + (1u64 << Self::BITS) - 1 + }; + + #[inline(always)] + fn to_vec(self) -> __m256i { + unsafe { transmute(self) } + } + + #[inline(always)] + unsafe fn from_vec(v: __m256i) -> Self { + unsafe { transmute(v) } + } +} + +impl Default for PackedFp32Avx2

{ + #[inline] + fn default() -> Self { + Self([Fp32(0); FP32_WIDTH]) + } +} + +impl fmt::Debug for PackedFp32Avx2

{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_tuple("PackedFp32Avx2").field(&self.0).finish() + } +} + +impl PartialEq for PackedFp32Avx2

{ + #[inline] + fn eq(&self, other: &Self) -> bool { + self.0 == other.0 + } +} + +impl Eq for PackedFp32Avx2

{} + +impl Add for PackedFp32Avx2

{ + type Output = Self; + #[inline] + fn add(self, rhs: Self) -> Self { + unsafe { + let a = self.to_vec(); + let b = rhs.to_vec(); + let p = _mm256_set1_epi32(P as i32); + + let result = if Self::BITS <= 31 { + let t = _mm256_add_epi32(a, b); + let u = _mm256_sub_epi32(t, p); + _mm256_min_epu32(t, u) + } else { + let c = _mm256_set1_epi32(Self::C as i32); + let t = _mm256_add_epi32(a, b); + // Emulate unsigned compare: XOR with 0x80000000 converts u32 compare to i32 + let sign32 = _mm256_set1_epi32(i32::MIN); + let overflow = + _mm256_cmpgt_epi32(_mm256_xor_si256(a, sign32), _mm256_xor_si256(t, sign32)); + // Step 1: correct overflow (2^32 ≡ C mod P) + let t2 = _mm256_add_epi32(t, _mm256_and_si256(overflow, c)); + // Step 2: subtract P if t2 >= P + let r = _mm256_sub_epi32(t2, p); + _mm256_min_epu32(t2, r) + }; + + Self::from_vec(result) + } + } +} + +impl Sub for PackedFp32Avx2

{ + type Output = Self; + #[inline] + fn sub(self, rhs: Self) -> Self { + unsafe { + let a = self.to_vec(); + let b = rhs.to_vec(); + let p = _mm256_set1_epi32(P as i32); + + let result = if Self::BITS <= 31 { + let t = _mm256_sub_epi32(a, b); + let u = _mm256_add_epi32(t, p); + _mm256_min_epu32(t, u) + } else { + let t = _mm256_sub_epi32(a, b); + let sign32 = _mm256_set1_epi32(i32::MIN); + let underflow = + _mm256_cmpgt_epi32(_mm256_xor_si256(b, sign32), _mm256_xor_si256(a, sign32)); + let corrected = _mm256_add_epi32(t, p); + _mm256_blendv_epi8(t, corrected, underflow) + }; + + Self::from_vec(result) + } + } +} + +impl Mul for PackedFp32Avx2

{ + type Output = Self; + #[inline] + fn mul(self, rhs: Self) -> Self { + unsafe { + let a = self.to_vec(); + let b = rhs.to_vec(); + + let prod_evn = _mm256_mul_epu32(a, b); + let a_odd = movehdup_epi32(a); + let b_odd = movehdup_epi32(b); + let prod_odd = _mm256_mul_epu32(a_odd, b_odd); + + let mask = _mm256_set1_epi64x(Self::MASK_U64 as i64); + let c_vec = _mm256_set1_epi64x(Self::C as i64); + let shift = _mm_set_epi64x(0, Self::BITS as i64); + + // Fold 1 + let evn_lo = _mm256_and_si256(prod_evn, mask); + let evn_hi = _mm256_srl_epi64(prod_evn, shift); + let evn_f1 = _mm256_add_epi64(evn_lo, _mm256_mul_epu32(evn_hi, c_vec)); + + let odd_lo = _mm256_and_si256(prod_odd, mask); + let odd_hi = _mm256_srl_epi64(prod_odd, shift); + let odd_f1 = _mm256_add_epi64(odd_lo, _mm256_mul_epu32(odd_hi, c_vec)); + + // Fold 2 + let evn_f1_lo = _mm256_and_si256(evn_f1, mask); + let evn_f1_hi = _mm256_srl_epi64(evn_f1, shift); + let evn_f2 = _mm256_add_epi64(evn_f1_lo, _mm256_mul_epu32(evn_f1_hi, c_vec)); + + let odd_f1_lo = _mm256_and_si256(odd_f1, mask); + let odd_f1_hi = _mm256_srl_epi64(odd_f1, shift); + let odd_f2 = _mm256_add_epi64(odd_f1_lo, _mm256_mul_epu32(odd_f1_hi, c_vec)); + + // Recombine even/odd: shift odd results into high 32-bit positions, blend. + let odd_shifted = _mm256_slli_epi64::<32>(odd_f2); + let combined = _mm256_blend_epi32::<0b10101010>(evn_f2, odd_shifted); + + // Conditional subtract P + let p_vec = _mm256_set1_epi32(P as i32); + let reduced = _mm256_sub_epi32(combined, p_vec); + Self::from_vec(_mm256_min_epu32(combined, reduced)) + } + } +} + +impl PackedValue for PackedFp32Avx2

{ + type Value = Fp32

; + const WIDTH: usize = FP32_WIDTH; + + #[inline] + fn from_fn(mut f: F) -> Self + where + F: FnMut(usize) -> Self::Value, + { + Self([f(0), f(1), f(2), f(3), f(4), f(5), f(6), f(7)]) + } + + #[inline] + fn extract(&self, lane: usize) -> Self::Value { + debug_assert!(lane < FP32_WIDTH); + self.0[lane] + } +} + +impl AddAssign for PackedFp32Avx2

{ + #[inline] + fn add_assign(&mut self, rhs: Self) { + *self = *self + rhs; + } +} + +impl SubAssign for PackedFp32Avx2

{ + #[inline] + fn sub_assign(&mut self, rhs: Self) { + *self = *self - rhs; + } +} + +impl MulAssign for PackedFp32Avx2

{ + #[inline] + fn mul_assign(&mut self, rhs: Self) { + *self = *self * rhs; + } +} + +impl PackedField for PackedFp32Avx2

{ + type Scalar = Fp32

; + + #[inline] + fn broadcast(value: Self::Scalar) -> Self { + Self([value; FP32_WIDTH]) + } +} + +/// Number of `Fp64` lanes in an AVX2 packed vector. +pub const FP64_WIDTH: usize = 4; + +/// AVX2 packed arithmetic for `Fp64

`, processing 4 lanes. +#[derive(Clone, Copy)] +#[repr(transparent)] +pub struct PackedFp64Avx2(pub [Fp64

; FP64_WIDTH]); + +impl PackedFp64Avx2

{ + const BITS: u32 = 64 - P.leading_zeros(); + + const C_LO: u64 = { + let c = if Self::BITS == 64 { + 0u64.wrapping_sub(P) + } else { + (1u64 << Self::BITS) - P + }; + assert!(P != 0, "modulus must be nonzero"); + assert!(P & 1 == 1, "modulus must be odd"); + c + }; + + const MASK64: u64 = if Self::BITS < 64 { + (1u64 << Self::BITS) - 1 + } else { + u64::MAX + }; + + #[inline(always)] + fn to_vec(self) -> __m256i { + unsafe { transmute(self) } + } + + #[inline(always)] + unsafe fn from_vec(v: __m256i) -> Self { + unsafe { transmute(v) } + } + + #[inline] + unsafe fn reduce128_vec(hi: __m256i, lo: __m256i) -> __m256i { + if Self::BITS < 64 { + Self::reduce128_small_k(hi, lo) + } else { + Self::reduce128_full_k(hi, lo) + } + } + + /// Reduction for BITS < 64. All intermediates fit in u64 — no overflow. + #[inline] + unsafe fn reduce128_small_k(hi: __m256i, lo: __m256i) -> __m256i { + let mask_k = _mm256_set1_epi64x(Self::MASK64 as i64); + let c_vec = _mm256_set1_epi64x(Self::C_LO as i64); + let p_vec = _mm256_set1_epi64x(P as i64); + let shift_k = _mm_set_epi64x(0, Self::BITS as i64); + let shift_64mk = _mm_set_epi64x(0, (64 - Self::BITS) as i64); + + let lo_k = _mm256_and_si256(lo, mask_k); + let lo_upper = _mm256_srl_epi64(lo, shift_k); + let hi_shifted = _mm256_sll_epi64(hi, shift_64mk); + let hi_k = _mm256_or_si256(lo_upper, hi_shifted); + + let c_hi_lo = _mm256_mul_epu32(c_vec, hi_k); + let hi_k_top = _mm256_srli_epi64::<32>(hi_k); + let c_hi_top = _mm256_mul_epu32(c_vec, hi_k_top); + let c_hi_top_shifted = _mm256_slli_epi64::<32>(c_hi_top); + let c_hi_full = _mm256_add_epi64(c_hi_lo, c_hi_top_shifted); + + let fold1 = _mm256_add_epi64(lo_k, c_hi_full); + + let fold1_lo_k = _mm256_and_si256(fold1, mask_k); + let fold1_hi = _mm256_srl_epi64(fold1, shift_k); + let c_fold1_hi = _mm256_mul_epu32(c_vec, fold1_hi); + let fold2 = _mm256_add_epi64(fold1_lo_k, c_fold1_hi); + + let reduced = _mm256_sub_epi64(fold2, p_vec); + let sign = _mm256_set1_epi64x(i64::MIN); + 
let fold2_s = _mm256_xor_si256(fold2, sign); + let reduced_s = _mm256_xor_si256(reduced, sign); + let fold2_lt = _mm256_cmpgt_epi64(reduced_s, fold2_s); + _mm256_blendv_epi8(reduced, fold2, fold2_lt) + } + + /// Reduction for BITS == 64. Uses XOR-with-SIGN_BIT trick for unsigned + /// overflow detection. + #[inline] + unsafe fn reduce128_full_k(hi: __m256i, lo: __m256i) -> __m256i { + let c_vec = _mm256_set1_epi64x(Self::C_LO as i64); + let p_vec = _mm256_set1_epi64x(P as i64); + let sign = _mm256_set1_epi64x(i64::MIN); + let c_hi_lo = _mm256_mul_epu32(c_vec, hi); + let hi_hi = _mm256_srli_epi64::<32>(hi); + let c_hi_hi = _mm256_mul_epu32(c_vec, hi_hi); + + let c_hi_hi_lo32 = _mm256_slli_epi64::<32>(c_hi_hi); + let c_hi_carry = _mm256_srli_epi64::<32>(c_hi_hi); + + let sum_lo = _mm256_add_epi64(c_hi_lo, c_hi_hi_lo32); + let c_hi_lo_s = _mm256_xor_si256(c_hi_lo, sign); + let sum_lo_s = _mm256_xor_si256(sum_lo, sign); + let carry0 = _mm256_cmpgt_epi64(c_hi_lo_s, sum_lo_s); + let overflow = _mm256_sub_epi64(c_hi_carry, carry0); + + let s = _mm256_add_epi64(lo, sum_lo); + let lo_s = _mm256_xor_si256(lo, sign); + let s_s = _mm256_xor_si256(s, sign); + let carry1 = _mm256_cmpgt_epi64(lo_s, s_s); + let total_overflow = _mm256_sub_epi64(overflow, carry1); + + let final_corr = _mm256_mul_epu32(c_vec, total_overflow); + let result = _mm256_add_epi64(s, final_corr); + let s2_s = _mm256_xor_si256(s, sign); + let result_s = _mm256_xor_si256(result, sign); + let carry_f = _mm256_cmpgt_epi64(s2_s, result_s); + let corr_f = _mm256_and_si256(carry_f, c_vec); + let result = _mm256_add_epi64(result, corr_f); + + let result_s2 = _mm256_xor_si256(result, sign); + let p_s = _mm256_xor_si256(p_vec, sign); + let lt_p = _mm256_cmpgt_epi64(p_s, result_s2); + let sub_amt = _mm256_andnot_si256(lt_p, p_vec); + _mm256_sub_epi64(result, sub_amt) + } +} + +impl Default for PackedFp64Avx2

{ + #[inline] + fn default() -> Self { + Self([Fp64(0); FP64_WIDTH]) + } +} + +impl fmt::Debug for PackedFp64Avx2

{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_tuple("PackedFp64Avx2").field(&self.0).finish() + } +} + +impl PartialEq for PackedFp64Avx2

{ + #[inline] + fn eq(&self, other: &Self) -> bool { + self.0 == other.0 + } +} + +impl Eq for PackedFp64Avx2

{} + +impl Add for PackedFp64Avx2

{ + type Output = Self; + #[inline] + fn add(self, rhs: Self) -> Self { + unsafe { + let a = self.to_vec(); + let b = rhs.to_vec(); + let p = _mm256_set1_epi64x(P as i64); + + let result = if Self::BITS <= 62 { + // a + b < 2P < 2^63: no overflow. + let s = _mm256_add_epi64(a, b); + let r = _mm256_sub_epi64(s, p); + // s < P? Use signed compare after shift trick. + let sign = _mm256_set1_epi64x(i64::MIN); + let s_s = _mm256_xor_si256(s, sign); + let p_s = _mm256_xor_si256(p, sign); + let borrow = _mm256_cmpgt_epi64(p_s, s_s); + _mm256_blendv_epi8(r, s, borrow) + } else { + // a + b can overflow u64. + let s = _mm256_add_epi64(a, b); + let sign = _mm256_set1_epi64x(i64::MIN); + let a_s = _mm256_xor_si256(a, sign); + let s_s = _mm256_xor_si256(s, sign); + let overflow = _mm256_cmpgt_epi64(a_s, s_s); + let c = _mm256_set1_epi64x(Self::C_LO as i64); + let s_plus_c = _mm256_add_epi64(s, c); + let s_minus_p = _mm256_sub_epi64(s, p); + let p_s = _mm256_xor_si256(p, sign); + let lt_p = _mm256_cmpgt_epi64(p_s, s_s); + let no_of = _mm256_blendv_epi8(s_minus_p, s, lt_p); + _mm256_blendv_epi8(no_of, s_plus_c, overflow) + }; + + Self::from_vec(result) + } + } +} + +impl Sub for PackedFp64Avx2

{ + type Output = Self; + #[inline] + fn sub(self, rhs: Self) -> Self { + unsafe { + let a = self.to_vec(); + let b = rhs.to_vec(); + let p = _mm256_set1_epi64x(P as i64); + let d = _mm256_sub_epi64(a, b); + + let sign = _mm256_set1_epi64x(i64::MIN); + let a_s = _mm256_xor_si256(a, sign); + let b_s = _mm256_xor_si256(b, sign); + let underflow = _mm256_cmpgt_epi64(b_s, a_s); + let corrected = _mm256_add_epi64(d, p); + Self::from_vec(_mm256_blendv_epi8(d, corrected, underflow)) + } + } +} + +impl Mul for PackedFp64Avx2

{ + type Output = Self; + #[inline] + fn mul(self, rhs: Self) -> Self { + unsafe { + let (hi, lo) = mul64_64_256(self.to_vec(), rhs.to_vec()); + Self::from_vec(Self::reduce128_vec(hi, lo)) + } + } +} + +impl PackedValue for PackedFp64Avx2

{ + type Value = Fp64

; + const WIDTH: usize = FP64_WIDTH; + + #[inline] + fn from_fn(mut f: F) -> Self + where + F: FnMut(usize) -> Self::Value, + { + Self([f(0), f(1), f(2), f(3)]) + } + + #[inline] + fn extract(&self, lane: usize) -> Self::Value { + debug_assert!(lane < FP64_WIDTH); + self.0[lane] + } +} + +impl AddAssign for PackedFp64Avx2

{ + #[inline] + fn add_assign(&mut self, rhs: Self) { + *self = *self + rhs; + } +} + +impl SubAssign for PackedFp64Avx2

{ + #[inline] + fn sub_assign(&mut self, rhs: Self) { + *self = *self - rhs; + } +} + +impl MulAssign for PackedFp64Avx2

{ + #[inline] + fn mul_assign(&mut self, rhs: Self) { + *self = *self * rhs; + } +} + +impl PackedField for PackedFp64Avx2

{ + type Scalar = Fp64

; + + #[inline] + fn broadcast(value: Self::Scalar) -> Self { + Self([value; FP64_WIDTH]) + } +} + +/// Number of `Fp128` lanes in an AVX2 packed vector. +pub const FP128_WIDTH: usize = 4; + +/// AVX2 packed arithmetic for `Fp128

`, 4 lanes in SoA layout. +/// +/// Stores 4 elements as separate `lo` and `hi` `u64` arrays, enabling +/// vectorized add/sub via `__m256i`. Mul remains scalar per-lane. +#[derive(Clone, Copy)] +pub struct PackedFp128Avx2 { + lo: [u64; FP128_WIDTH], + hi: [u64; FP128_WIDTH], +} + +impl PackedFp128Avx2

{ + const P_LO: u64 = P as u64; + const P_HI: u64 = (P >> 64) as u64; +} + +impl Default for PackedFp128Avx2

{ + #[inline] + fn default() -> Self { + Self { + lo: [0; FP128_WIDTH], + hi: [0; FP128_WIDTH], + } + } +} + +impl fmt::Debug for PackedFp128Avx2

{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let elems: Vec<_> = (0..FP128_WIDTH).map(|i| self.extract(i)).collect(); + f.debug_tuple("PackedFp128Avx2").field(&elems).finish() + } +} + +impl PartialEq for PackedFp128Avx2

{ + #[inline] + fn eq(&self, other: &Self) -> bool { + self.lo == other.lo && self.hi == other.hi + } +} + +impl Eq for PackedFp128Avx2

{} + +impl Add for PackedFp128Avx2

{ + type Output = Self; + #[inline] + fn add(self, rhs: Self) -> Self { + unsafe { + let a_lo = _mm256_loadu_si256(self.lo.as_ptr().cast()); + let a_hi = _mm256_loadu_si256(self.hi.as_ptr().cast()); + let b_lo = _mm256_loadu_si256(rhs.lo.as_ptr().cast()); + let b_hi = _mm256_loadu_si256(rhs.hi.as_ptr().cast()); + let p_lo = _mm256_set1_epi64x(Self::P_LO as i64); + let p_hi = _mm256_set1_epi64x(Self::P_HI as i64); + let sign = _mm256_set1_epi64x(i64::MIN); + let one = _mm256_set1_epi64x(1); + + // 128-bit add with unsigned compare emulation (XOR sign bit) + let sum_lo = _mm256_add_epi64(a_lo, b_lo); + let carry_lo = + _mm256_cmpgt_epi64(_mm256_xor_si256(a_lo, sign), _mm256_xor_si256(sum_lo, sign)); + let carry_lo_bit = _mm256_and_si256(carry_lo, one); + + let hi_tmp = _mm256_add_epi64(a_hi, b_hi); + let ov1 = + _mm256_cmpgt_epi64(_mm256_xor_si256(a_hi, sign), _mm256_xor_si256(hi_tmp, sign)); + let sum_hi = _mm256_add_epi64(hi_tmp, carry_lo_bit); + let ov2 = _mm256_cmpgt_epi64( + _mm256_xor_si256(hi_tmp, sign), + _mm256_xor_si256(sum_hi, sign), + ); + let carry_128 = _mm256_or_si256(ov1, ov2); + + // 128-bit subtract P + let red_lo = _mm256_sub_epi64(sum_lo, p_lo); + let borrow_lo = + _mm256_cmpgt_epi64(_mm256_xor_si256(p_lo, sign), _mm256_xor_si256(sum_lo, sign)); + let borrow_lo_bit = _mm256_and_si256(borrow_lo, one); + + let red_hi_tmp = _mm256_sub_epi64(sum_hi, p_hi); + let bw1 = + _mm256_cmpgt_epi64(_mm256_xor_si256(p_hi, sign), _mm256_xor_si256(sum_hi, sign)); + let red_hi = _mm256_sub_epi64(red_hi_tmp, borrow_lo_bit); + let bw2 = _mm256_cmpgt_epi64( + _mm256_xor_si256(borrow_lo_bit, sign), + _mm256_xor_si256(red_hi_tmp, sign), + ); + let borrow = _mm256_or_si256(bw1, bw2); + + // use_reduced = carry_128 | !borrow + let not_borrow = _mm256_xor_si256(borrow, _mm256_set1_epi64x(-1)); + let use_reduced = _mm256_or_si256(carry_128, not_borrow); + let out_lo = _mm256_blendv_epi8(sum_lo, red_lo, use_reduced); + let out_hi = _mm256_blendv_epi8(sum_hi, red_hi, 
use_reduced); + + let mut result = Self::default(); + _mm256_storeu_si256(result.lo.as_mut_ptr().cast(), out_lo); + _mm256_storeu_si256(result.hi.as_mut_ptr().cast(), out_hi); + result + } + } +} + +impl Sub for PackedFp128Avx2

{ + type Output = Self; + #[inline] + fn sub(self, rhs: Self) -> Self { + unsafe { + let a_lo = _mm256_loadu_si256(self.lo.as_ptr().cast()); + let a_hi = _mm256_loadu_si256(self.hi.as_ptr().cast()); + let b_lo = _mm256_loadu_si256(rhs.lo.as_ptr().cast()); + let b_hi = _mm256_loadu_si256(rhs.hi.as_ptr().cast()); + let p_lo = _mm256_set1_epi64x(Self::P_LO as i64); + let p_hi = _mm256_set1_epi64x(Self::P_HI as i64); + let sign = _mm256_set1_epi64x(i64::MIN); + let one = _mm256_set1_epi64x(1); + + // 128-bit sub + let diff_lo = _mm256_sub_epi64(a_lo, b_lo); + let borrow_lo = + _mm256_cmpgt_epi64(_mm256_xor_si256(b_lo, sign), _mm256_xor_si256(a_lo, sign)); + let borrow_lo_bit = _mm256_and_si256(borrow_lo, one); + + let hi_tmp = _mm256_sub_epi64(a_hi, b_hi); + let bw1 = + _mm256_cmpgt_epi64(_mm256_xor_si256(b_hi, sign), _mm256_xor_si256(a_hi, sign)); + let diff_hi = _mm256_sub_epi64(hi_tmp, borrow_lo_bit); + let bw2 = _mm256_cmpgt_epi64( + _mm256_xor_si256(borrow_lo_bit, sign), + _mm256_xor_si256(hi_tmp, sign), + ); + let borrow_128 = _mm256_or_si256(bw1, bw2); + + // Correction: add P back where underflow occurred + let corr_lo = _mm256_add_epi64(diff_lo, p_lo); + let carry_lo = _mm256_cmpgt_epi64( + _mm256_xor_si256(diff_lo, sign), + _mm256_xor_si256(corr_lo, sign), + ); + let carry_lo_bit = _mm256_and_si256(carry_lo, one); + let corr_hi = _mm256_add_epi64(diff_hi, p_hi); + let corr_hi = _mm256_add_epi64(corr_hi, carry_lo_bit); + + let out_lo = _mm256_blendv_epi8(diff_lo, corr_lo, borrow_128); + let out_hi = _mm256_blendv_epi8(diff_hi, corr_hi, borrow_128); + + let mut result = Self::default(); + _mm256_storeu_si256(result.lo.as_mut_ptr().cast(), out_lo); + _mm256_storeu_si256(result.hi.as_mut_ptr().cast(), out_hi); + result + } + } +} + +impl Mul for PackedFp128Avx2

{ + type Output = Self; + #[inline] + fn mul(self, rhs: Self) -> Self { + let mut out = Self::default(); + for i in 0..FP128_WIDTH { + let a = Fp128::

([self.lo[i], self.hi[i]]); + let b = Fp128::

([rhs.lo[i], rhs.hi[i]]); + let r = a * b; + out.lo[i] = r.0[0]; + out.hi[i] = r.0[1]; + } + out + } +} + +impl PackedValue for PackedFp128Avx2

{ + type Value = Fp128

; + const WIDTH: usize = FP128_WIDTH; + + #[inline] + fn from_fn(mut f: F) -> Self + where + F: FnMut(usize) -> Self::Value, + { + let mut lo = [0u64; FP128_WIDTH]; + let mut hi = [0u64; FP128_WIDTH]; + for i in 0..FP128_WIDTH { + let v = f(i); + lo[i] = v.0[0]; + hi[i] = v.0[1]; + } + Self { lo, hi } + } + + #[inline] + fn extract(&self, lane: usize) -> Self::Value { + debug_assert!(lane < FP128_WIDTH); + Fp128([self.lo[lane], self.hi[lane]]) + } +} + +impl AddAssign for PackedFp128Avx2

{ + #[inline] + fn add_assign(&mut self, rhs: Self) { + *self = *self + rhs; + } +} + +impl SubAssign for PackedFp128Avx2

{ + #[inline] + fn sub_assign(&mut self, rhs: Self) { + *self = *self - rhs; + } +} + +impl MulAssign for PackedFp128Avx2

{ + #[inline] + fn mul_assign(&mut self, rhs: Self) { + *self = *self * rhs; + } +} + +impl PackedField for PackedFp128Avx2

{ + type Scalar = Fp128

; + + #[inline] + fn broadcast(value: Self::Scalar) -> Self { + Self { + lo: [value.0[0]; FP128_WIDTH], + hi: [value.0[1]; FP128_WIDTH], + } + } +} diff --git a/src/algebra/fields/packed_avx512.rs b/src/algebra/fields/packed_avx512.rs new file mode 100644 index 00000000..343c1dfc --- /dev/null +++ b/src/algebra/fields/packed_avx512.rs @@ -0,0 +1,729 @@ +//! AVX-512 packed backends for Fp32, Fp64, Fp128. +//! +//! Requires AVX-512F + AVX-512DQ. Uses native unsigned comparisons and mask +//! registers for branchless conditionals. + +use super::packed::{PackedField, PackedValue}; +use crate::algebra::fields::{Fp128, Fp32, Fp64}; +use crate::FieldCore; +use core::arch::x86_64::*; +use core::fmt; +use core::mem::transmute; +use core::ops::{Add, AddAssign, Mul, MulAssign, Sub, SubAssign}; + +#[inline(always)] +unsafe fn movehdup_epi32_512(x: __m512i) -> __m512i { + _mm512_castps_si512(_mm512_movehdup_ps(_mm512_castsi512_ps(x))) +} + +#[inline(always)] +unsafe fn moveldup_epi32_512(x: __m512i) -> __m512i { + _mm512_castps_si512(_mm512_moveldup_ps(_mm512_castsi512_ps(x))) +} + +/// 64×64→128 schoolbook multiply using 32×32→64 partial products. +/// Returns (hi, lo) representing the 128-bit product. +/// Adapted from plonky3's Goldilocks AVX-512 backend. 
+#[inline] +unsafe fn mul64_64_512(x: __m512i, y: __m512i) -> (__m512i, __m512i) { + let x_hi = movehdup_epi32_512(x); + let y_hi = movehdup_epi32_512(y); + + let mul_ll = _mm512_mul_epu32(x, y); + let mul_lh = _mm512_mul_epu32(x, y_hi); + let mul_hl = _mm512_mul_epu32(x_hi, y); + let mul_hh = _mm512_mul_epu32(x_hi, y_hi); + + let mul_ll_hi = _mm512_srli_epi64::<32>(mul_ll); + let t0 = _mm512_add_epi64(mul_hl, mul_ll_hi); + let mask32 = _mm512_set1_epi64(0xFFFF_FFFF_i64); + let t0_lo = _mm512_and_si512(t0, mask32); + let t0_hi = _mm512_srli_epi64::<32>(t0); + let t1 = _mm512_add_epi64(mul_lh, t0_lo); + let t2 = _mm512_add_epi64(mul_hh, t0_hi); + let t1_hi = _mm512_srli_epi64::<32>(t1); + let res_hi = _mm512_add_epi64(t2, t1_hi); + + let t1_lo = moveldup_epi32_512(t1); + let res_lo = _mm512_mask_blend_epi32(0b0101_0101_0101_0101, t1_lo, mul_ll); + + (res_hi, res_lo) +} + +/// Number of `Fp32` lanes in an AVX-512 packed vector. +pub const FP32_WIDTH: usize = 16; + +/// AVX-512 packed arithmetic for `Fp32

`, processing 16 lanes. +#[derive(Clone, Copy)] +#[repr(transparent)] +pub struct PackedFp32Avx512(pub [Fp32

; FP32_WIDTH]); + +impl PackedFp32Avx512

{ + const BITS: u32 = 32 - P.leading_zeros(); + + const C: u32 = { + let c = if Self::BITS == 32 { + 0u32.wrapping_sub(P) + } else { + (1u32 << Self::BITS) - P + }; + assert!(P != 0, "modulus must be nonzero"); + assert!(P & 1 == 1, "modulus must be odd"); + assert!( + (c as u64) * (c as u64 + 1) < P as u64, + "C(C+1) < P required for fused canonicalize" + ); + c + }; + + const MASK_U64: u64 = if Self::BITS == 32 { + u32::MAX as u64 + } else { + (1u64 << Self::BITS) - 1 + }; + + #[inline(always)] + fn to_vec(self) -> __m512i { + unsafe { transmute(self) } + } + + #[inline(always)] + unsafe fn from_vec(v: __m512i) -> Self { + unsafe { transmute(v) } + } +} + +impl Default for PackedFp32Avx512

{ + #[inline] + fn default() -> Self { + Self([Fp32(0); FP32_WIDTH]) + } +} + +impl fmt::Debug for PackedFp32Avx512

{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_tuple("PackedFp32Avx512").field(&self.0).finish() + } +} + +impl PartialEq for PackedFp32Avx512

{ + #[inline] + fn eq(&self, other: &Self) -> bool { + self.0 == other.0 + } +} + +impl Eq for PackedFp32Avx512

{} + +impl Add for PackedFp32Avx512

{ + type Output = Self; + #[inline] + fn add(self, rhs: Self) -> Self { + unsafe { + let a = self.to_vec(); + let b = rhs.to_vec(); + let p = _mm512_set1_epi32(P as i32); + + let result = if Self::BITS <= 31 { + let t = _mm512_add_epi32(a, b); + let u = _mm512_sub_epi32(t, p); + _mm512_min_epu32(t, u) + } else { + let c = _mm512_set1_epi32(Self::C as i32); + let t = _mm512_add_epi32(a, b); + // Step 1: correct overflow (2^32 ≡ C mod P) + let overflow = _mm512_cmplt_epu32_mask(t, a); + let t2 = _mm512_mask_add_epi32(t, overflow, t, c); + // Step 2: subtract P if t2 >= P + let geq_p = _mm512_cmpge_epu32_mask(t2, p); + _mm512_mask_sub_epi32(t2, geq_p, t2, p) + }; + + Self::from_vec(result) + } + } +} + +impl Sub for PackedFp32Avx512

{ + type Output = Self; + #[inline] + fn sub(self, rhs: Self) -> Self { + unsafe { + let a = self.to_vec(); + let b = rhs.to_vec(); + let p = _mm512_set1_epi32(P as i32); + + let result = if Self::BITS <= 31 { + let t = _mm512_sub_epi32(a, b); + let u = _mm512_add_epi32(t, p); + _mm512_min_epu32(t, u) + } else { + let t = _mm512_sub_epi32(a, b); + let underflow = _mm512_cmplt_epu32_mask(a, b); + _mm512_mask_add_epi32(t, underflow, t, p) + }; + + Self::from_vec(result) + } + } +} + +impl Mul for PackedFp32Avx512

{ + type Output = Self; + #[inline] + fn mul(self, rhs: Self) -> Self { + unsafe { + let a = self.to_vec(); + let b = rhs.to_vec(); + + let prod_evn = _mm512_mul_epu32(a, b); + let a_odd = movehdup_epi32_512(a); + let b_odd = movehdup_epi32_512(b); + let prod_odd = _mm512_mul_epu32(a_odd, b_odd); + + let mask = _mm512_set1_epi64(Self::MASK_U64 as i64); + let c_vec = _mm512_set1_epi64(Self::C as i64); + let shift = _mm_set_epi64x(0, Self::BITS as i64); + + // Fold 1 + let evn_lo = _mm512_and_si512(prod_evn, mask); + let evn_hi = _mm512_srl_epi64(prod_evn, shift); + let evn_f1 = _mm512_add_epi64(evn_lo, _mm512_mul_epu32(evn_hi, c_vec)); + + let odd_lo = _mm512_and_si512(prod_odd, mask); + let odd_hi = _mm512_srl_epi64(prod_odd, shift); + let odd_f1 = _mm512_add_epi64(odd_lo, _mm512_mul_epu32(odd_hi, c_vec)); + + // Fold 2 + let evn_f1_lo = _mm512_and_si512(evn_f1, mask); + let evn_f1_hi = _mm512_srl_epi64(evn_f1, shift); + let evn_f2 = _mm512_add_epi64(evn_f1_lo, _mm512_mul_epu32(evn_f1_hi, c_vec)); + + let odd_f1_lo = _mm512_and_si512(odd_f1, mask); + let odd_f1_hi = _mm512_srl_epi64(odd_f1, shift); + let odd_f2 = _mm512_add_epi64(odd_f1_lo, _mm512_mul_epu32(odd_f1_hi, c_vec)); + + // Recombine even/odd + let odd_shifted = _mm512_slli_epi64::<32>(odd_f2); + let combined = _mm512_mask_blend_epi32(0b1010101010101010, evn_f2, odd_shifted); + + // Conditional subtract P + let p_vec = _mm512_set1_epi32(P as i32); + let reduced = _mm512_sub_epi32(combined, p_vec); + Self::from_vec(_mm512_min_epu32(combined, reduced)) + } + } +} + +impl PackedValue for PackedFp32Avx512

{ + type Value = Fp32

; + const WIDTH: usize = FP32_WIDTH; + + #[inline] + fn from_fn(mut f: F) -> Self + where + F: FnMut(usize) -> Self::Value, + { + Self([ + f(0), + f(1), + f(2), + f(3), + f(4), + f(5), + f(6), + f(7), + f(8), + f(9), + f(10), + f(11), + f(12), + f(13), + f(14), + f(15), + ]) + } + + #[inline] + fn extract(&self, lane: usize) -> Self::Value { + debug_assert!(lane < FP32_WIDTH); + self.0[lane] + } +} + +impl AddAssign for PackedFp32Avx512

{ + #[inline] + fn add_assign(&mut self, rhs: Self) { + *self = *self + rhs; + } +} + +impl SubAssign for PackedFp32Avx512

{ + #[inline] + fn sub_assign(&mut self, rhs: Self) { + *self = *self - rhs; + } +} + +impl MulAssign for PackedFp32Avx512

{ + #[inline] + fn mul_assign(&mut self, rhs: Self) { + *self = *self * rhs; + } +} + +impl PackedField for PackedFp32Avx512

{ + type Scalar = Fp32

; + + #[inline] + fn broadcast(value: Self::Scalar) -> Self { + Self([value; FP32_WIDTH]) + } +} + +/// Number of `Fp64` lanes in an AVX-512 packed vector. +pub const FP64_WIDTH: usize = 8; + +/// AVX-512 packed arithmetic for `Fp64

`, processing 8 lanes. +#[derive(Clone, Copy)] +#[repr(transparent)] +pub struct PackedFp64Avx512(pub [Fp64

; FP64_WIDTH]); + +impl PackedFp64Avx512

{ + const BITS: u32 = 64 - P.leading_zeros(); + + const C_LO: u64 = { + let c = if Self::BITS == 64 { + 0u64.wrapping_sub(P) + } else { + (1u64 << Self::BITS) - P + }; + assert!(P != 0, "modulus must be nonzero"); + assert!(P & 1 == 1, "modulus must be odd"); + c + }; + + const MASK64: u64 = if Self::BITS < 64 { + (1u64 << Self::BITS) - 1 + } else { + u64::MAX + }; + + #[inline(always)] + fn to_vec(self) -> __m512i { + unsafe { transmute(self) } + } + + #[inline(always)] + unsafe fn from_vec(v: __m512i) -> Self { + unsafe { transmute(v) } + } + + /// Vectorized 128-bit Solinas reduction for p = 2^BITS - C. + /// Given (hi, lo) = 128-bit product, computes result ≡ (hi*2^64 + lo) mod p. + #[inline] + unsafe fn reduce128_vec(hi: __m512i, lo: __m512i) -> __m512i { + if Self::BITS < 64 { + Self::reduce128_small_k(hi, lo) + } else { + Self::reduce128_full_k(hi, lo) + } + } + + /// Reduction for BITS < 64 (e.g. 40-bit prime). No overflow issues: all + /// intermediates fit in u64. + #[inline] + unsafe fn reduce128_small_k(hi: __m512i, lo: __m512i) -> __m512i { + let mask_k = _mm512_set1_epi64(Self::MASK64 as i64); + let c_vec = _mm512_set1_epi64(Self::C_LO as i64); + let p_vec = _mm512_set1_epi64(P as i64); + let shift_k = _mm_set_epi64x(0, Self::BITS as i64); + let shift_64mk = _mm_set_epi64x(0, (64 - Self::BITS) as i64); + + let lo_k = _mm512_and_si512(lo, mask_k); + let lo_upper = _mm512_srl_epi64(lo, shift_k); + let hi_shifted = _mm512_sll_epi64(hi, shift_64mk); + let hi_k = _mm512_or_si512(lo_upper, hi_shifted); + + // c * hi_k: hi_k may exceed 32 bits, split into lo32 and top + let c_hi_lo = _mm512_mul_epu32(c_vec, hi_k); + let hi_k_top = _mm512_srli_epi64::<32>(hi_k); + let c_hi_top = _mm512_mul_epu32(c_vec, hi_k_top); + let c_hi_top_shifted = _mm512_slli_epi64::<32>(c_hi_top); + let c_hi_full = _mm512_add_epi64(c_hi_lo, c_hi_top_shifted); + + let fold1 = _mm512_add_epi64(lo_k, c_hi_full); + + let fold1_lo_k = _mm512_and_si512(fold1, mask_k); + let fold1_hi = 
_mm512_srl_epi64(fold1, shift_k); + let c_fold1_hi = _mm512_mul_epu32(c_vec, fold1_hi); + let fold2 = _mm512_add_epi64(fold1_lo_k, c_fold1_hi); + + let reduced = _mm512_sub_epi64(fold2, p_vec); + _mm512_min_epu64(fold2, reduced) + } + + /// Reduction for BITS == 64 (e.g. p = 2^64 - 87). Tracks overflow from + /// c*hi exceeding 64 bits, using native unsigned comparisons. + #[inline] + unsafe fn reduce128_full_k(hi: __m512i, lo: __m512i) -> __m512i { + let c_vec = _mm512_set1_epi64(Self::C_LO as i64); + let p_vec = _mm512_set1_epi64(P as i64); + let one = _mm512_set1_epi64(1); + + // c * hi_lo32 + let c_hi_lo = _mm512_mul_epu32(c_vec, hi); + // c * hi_hi32 + let hi_hi = _mm512_srli_epi64::<32>(hi); + let c_hi_hi = _mm512_mul_epu32(c_vec, hi_hi); + + let c_hi_hi_lo32 = _mm512_slli_epi64::<32>(c_hi_hi); + let c_hi_carry = _mm512_srli_epi64::<32>(c_hi_hi); + + // Lower 64 bits of c * hi + let sum_lo = _mm512_add_epi64(c_hi_lo, c_hi_hi_lo32); + let carry0 = _mm512_cmplt_epu64_mask(sum_lo, c_hi_lo); + let overflow = _mm512_mask_add_epi64(c_hi_carry, carry0, c_hi_carry, one); + + // lo + sum_lo + let s = _mm512_add_epi64(lo, sum_lo); + let carry1 = _mm512_cmplt_epu64_mask(s, lo); + let total_overflow = _mm512_mask_add_epi64(overflow, carry1, overflow, one); + + // Fold overflow: total_overflow * c (at most ~2^15) + let final_corr = _mm512_mul_epu32(c_vec, total_overflow); + let result = _mm512_add_epi64(s, final_corr); + let carry_f = _mm512_cmplt_epu64_mask(result, s); + let result = _mm512_mask_add_epi64(result, carry_f, result, c_vec); + + let ge_mask = _mm512_cmpge_epu64_mask(result, p_vec); + _mm512_mask_sub_epi64(result, ge_mask, result, p_vec) + } +} + +impl Default for PackedFp64Avx512

{ + #[inline] + fn default() -> Self { + Self([Fp64(0); FP64_WIDTH]) + } +} + +impl fmt::Debug for PackedFp64Avx512

{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_tuple("PackedFp64Avx512").field(&self.0).finish() + } +} + +impl PartialEq for PackedFp64Avx512

{ + #[inline] + fn eq(&self, other: &Self) -> bool { + self.0 == other.0 + } +} + +impl Eq for PackedFp64Avx512

{} + +impl Add for PackedFp64Avx512

{ + type Output = Self; + #[inline] + fn add(self, rhs: Self) -> Self { + unsafe { + let a = self.to_vec(); + let b = rhs.to_vec(); + let p = _mm512_set1_epi64(P as i64); + + let result = if Self::BITS <= 62 { + let s = _mm512_add_epi64(a, b); + let geq_p = _mm512_cmpge_epu64_mask(s, p); + _mm512_mask_sub_epi64(s, geq_p, s, p) + } else { + let s = _mm512_add_epi64(a, b); + let overflow = _mm512_cmplt_epu64_mask(s, a); + let c = _mm512_set1_epi64(Self::C_LO as i64); + let geq_p = _mm512_cmpge_epu64_mask(s, p); + let no_of = _mm512_mask_sub_epi64(s, geq_p, s, p); + let s_plus_c = _mm512_add_epi64(s, c); + _mm512_mask_blend_epi64(overflow, no_of, s_plus_c) + }; + + Self::from_vec(result) + } + } +} + +impl Sub for PackedFp64Avx512

{ + type Output = Self; + #[inline] + fn sub(self, rhs: Self) -> Self { + unsafe { + let a = self.to_vec(); + let b = rhs.to_vec(); + let p = _mm512_set1_epi64(P as i64); + let d = _mm512_sub_epi64(a, b); + let underflow = _mm512_cmplt_epu64_mask(a, b); + Self::from_vec(_mm512_mask_add_epi64(d, underflow, d, p)) + } + } +} + +impl Mul for PackedFp64Avx512

{ + type Output = Self; + #[inline] + fn mul(self, rhs: Self) -> Self { + unsafe { + let (hi, lo) = mul64_64_512(self.to_vec(), rhs.to_vec()); + Self::from_vec(Self::reduce128_vec(hi, lo)) + } + } +} + +impl PackedValue for PackedFp64Avx512

{ + type Value = Fp64

; + const WIDTH: usize = FP64_WIDTH; + + #[inline] + fn from_fn(mut f: F) -> Self + where + F: FnMut(usize) -> Self::Value, + { + Self([f(0), f(1), f(2), f(3), f(4), f(5), f(6), f(7)]) + } + + #[inline] + fn extract(&self, lane: usize) -> Self::Value { + debug_assert!(lane < FP64_WIDTH); + self.0[lane] + } +} + +impl AddAssign for PackedFp64Avx512

{ + #[inline] + fn add_assign(&mut self, rhs: Self) { + *self = *self + rhs; + } +} + +impl SubAssign for PackedFp64Avx512

{ + #[inline] + fn sub_assign(&mut self, rhs: Self) { + *self = *self - rhs; + } +} + +impl MulAssign for PackedFp64Avx512

{ + #[inline] + fn mul_assign(&mut self, rhs: Self) { + *self = *self * rhs; + } +} + +impl PackedField for PackedFp64Avx512

{ + type Scalar = Fp64

; + + #[inline] + fn broadcast(value: Self::Scalar) -> Self { + Self([value; FP64_WIDTH]) + } +} + +/// Number of `Fp128` lanes in an AVX-512 packed vector. +pub const FP128_WIDTH: usize = 8; + +/// AVX-512 packed arithmetic for `Fp128

`, 8 lanes in SoA layout. +/// +/// Stores 8 elements as separate `lo` and `hi` `u64` arrays, enabling +/// vectorized add/sub via `__m512i`. Mul remains scalar per-lane. +#[derive(Clone, Copy)] +pub struct PackedFp128Avx512 { + lo: [u64; FP128_WIDTH], + hi: [u64; FP128_WIDTH], +} + +impl PackedFp128Avx512

{ + const P_LO: u64 = P as u64; + const P_HI: u64 = (P >> 64) as u64; +} + +impl Default for PackedFp128Avx512

{ + #[inline] + fn default() -> Self { + Self { + lo: [0; FP128_WIDTH], + hi: [0; FP128_WIDTH], + } + } +} + +impl fmt::Debug for PackedFp128Avx512

{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let elems: Vec<_> = (0..FP128_WIDTH).map(|i| self.extract(i)).collect(); + f.debug_tuple("PackedFp128Avx512").field(&elems).finish() + } +} + +impl PartialEq for PackedFp128Avx512

{ + #[inline] + fn eq(&self, other: &Self) -> bool { + self.lo == other.lo && self.hi == other.hi + } +} + +impl Eq for PackedFp128Avx512

{} + +impl Add for PackedFp128Avx512

{ + type Output = Self; + #[inline] + fn add(self, rhs: Self) -> Self { + unsafe { + let a_lo = _mm512_loadu_si512(self.lo.as_ptr().cast()); + let a_hi = _mm512_loadu_si512(self.hi.as_ptr().cast()); + let b_lo = _mm512_loadu_si512(rhs.lo.as_ptr().cast()); + let b_hi = _mm512_loadu_si512(rhs.hi.as_ptr().cast()); + let p_lo = _mm512_set1_epi64(Self::P_LO as i64); + let p_hi = _mm512_set1_epi64(Self::P_HI as i64); + let one = _mm512_set1_epi64(1); + + // 128-bit add: (sum_hi, sum_lo) = (a_hi, a_lo) + (b_hi, b_lo) + let sum_lo = _mm512_add_epi64(a_lo, b_lo); + let carry_lo = _mm512_cmplt_epu64_mask(sum_lo, a_lo); + let hi_tmp = _mm512_add_epi64(a_hi, b_hi); + let ov1 = _mm512_cmplt_epu64_mask(hi_tmp, a_hi); + let sum_hi = _mm512_mask_add_epi64(hi_tmp, carry_lo, hi_tmp, one); + let ov2 = _mm512_cmplt_epu64_mask(sum_hi, hi_tmp); + let carry_128 = ov1 | ov2; + + // 128-bit subtract P: (red_hi, red_lo) = (sum_hi, sum_lo) - P + let red_lo = _mm512_sub_epi64(sum_lo, p_lo); + let borrow_lo = _mm512_cmplt_epu64_mask(sum_lo, p_lo); + let red_hi_tmp = _mm512_sub_epi64(sum_hi, p_hi); + let bw1 = _mm512_cmplt_epu64_mask(sum_hi, p_hi); + let red_hi = _mm512_mask_sub_epi64(red_hi_tmp, borrow_lo, red_hi_tmp, one); + let bw2 = _mm512_cmplt_epu64_mask(red_hi_tmp, _mm512_maskz_mov_epi64(borrow_lo, one)); + let borrow = bw1 | bw2; + + // Use reduced if: overflow happened OR subtraction didn't borrow + let use_reduced = carry_128 | !borrow; + let out_lo = _mm512_mask_blend_epi64(use_reduced, sum_lo, red_lo); + let out_hi = _mm512_mask_blend_epi64(use_reduced, sum_hi, red_hi); + + let mut result = Self::default(); + _mm512_storeu_si512(result.lo.as_mut_ptr().cast(), out_lo); + _mm512_storeu_si512(result.hi.as_mut_ptr().cast(), out_hi); + result + } + } +} + +impl Sub for PackedFp128Avx512

{ + type Output = Self; + #[inline] + fn sub(self, rhs: Self) -> Self { + unsafe { + let a_lo = _mm512_loadu_si512(self.lo.as_ptr().cast()); + let a_hi = _mm512_loadu_si512(self.hi.as_ptr().cast()); + let b_lo = _mm512_loadu_si512(rhs.lo.as_ptr().cast()); + let b_hi = _mm512_loadu_si512(rhs.hi.as_ptr().cast()); + let p_lo = _mm512_set1_epi64(Self::P_LO as i64); + let p_hi = _mm512_set1_epi64(Self::P_HI as i64); + let one = _mm512_set1_epi64(1); + + // 128-bit sub: (diff_hi, diff_lo) = (a_hi, a_lo) - (b_hi, b_lo) + let diff_lo = _mm512_sub_epi64(a_lo, b_lo); + let borrow_lo = _mm512_cmplt_epu64_mask(a_lo, b_lo); + let hi_tmp = _mm512_sub_epi64(a_hi, b_hi); + let bw1 = _mm512_cmplt_epu64_mask(a_hi, b_hi); + let diff_hi = _mm512_mask_sub_epi64(hi_tmp, borrow_lo, hi_tmp, one); + let bw2 = _mm512_cmplt_epu64_mask(hi_tmp, _mm512_maskz_mov_epi64(borrow_lo, one)); + let borrow_128 = bw1 | bw2; + + // Correction: add P back where underflow occurred + let corr_lo = _mm512_add_epi64(diff_lo, p_lo); + let carry_lo = _mm512_cmplt_epu64_mask(corr_lo, diff_lo); + let corr_hi = _mm512_add_epi64(diff_hi, p_hi); + let corr_hi = _mm512_mask_add_epi64(corr_hi, carry_lo, corr_hi, one); + + let out_lo = _mm512_mask_blend_epi64(borrow_128, diff_lo, corr_lo); + let out_hi = _mm512_mask_blend_epi64(borrow_128, diff_hi, corr_hi); + + let mut result = Self::default(); + _mm512_storeu_si512(result.lo.as_mut_ptr().cast(), out_lo); + _mm512_storeu_si512(result.hi.as_mut_ptr().cast(), out_hi); + result + } + } +} + +impl Mul for PackedFp128Avx512

{ + type Output = Self; + #[inline] + fn mul(self, rhs: Self) -> Self { + let mut out = Self::default(); + for i in 0..FP128_WIDTH { + let a = Fp128::

([self.lo[i], self.hi[i]]); + let b = Fp128::

([rhs.lo[i], rhs.hi[i]]); + let r = a * b; + out.lo[i] = r.0[0]; + out.hi[i] = r.0[1]; + } + out + } +} + +impl PackedValue for PackedFp128Avx512

{ + type Value = Fp128

; + const WIDTH: usize = FP128_WIDTH; + + #[inline] + fn from_fn(mut f: F) -> Self + where + F: FnMut(usize) -> Self::Value, + { + let mut lo = [0u64; FP128_WIDTH]; + let mut hi = [0u64; FP128_WIDTH]; + for i in 0..FP128_WIDTH { + let v = f(i); + lo[i] = v.0[0]; + hi[i] = v.0[1]; + } + Self { lo, hi } + } + + #[inline] + fn extract(&self, lane: usize) -> Self::Value { + debug_assert!(lane < FP128_WIDTH); + Fp128([self.lo[lane], self.hi[lane]]) + } +} + +impl AddAssign for PackedFp128Avx512

{ + #[inline] + fn add_assign(&mut self, rhs: Self) { + *self = *self + rhs; + } +} + +impl SubAssign for PackedFp128Avx512

{ + #[inline] + fn sub_assign(&mut self, rhs: Self) { + *self = *self - rhs; + } +} + +impl MulAssign for PackedFp128Avx512

{ + #[inline] + fn mul_assign(&mut self, rhs: Self) { + *self = *self * rhs; + } +} + +impl PackedField for PackedFp128Avx512

{ + type Scalar = Fp128

; + + #[inline] + fn broadcast(value: Self::Scalar) -> Self { + Self { + lo: [value.0[0]; FP128_WIDTH], + hi: [value.0[1]; FP128_WIDTH], + } + } +} diff --git a/src/algebra/fields/packed_ext.rs b/src/algebra/fields/packed_ext.rs new file mode 100644 index 00000000..9725b598 --- /dev/null +++ b/src/algebra/fields/packed_ext.rs @@ -0,0 +1,421 @@ +//! Packed extension field types using transpose-based packing. +//! +//! A `PackedFp2` stores `[PF; 2]` where `PF` is the packed base field. +//! Each `PF` lane contains the corresponding coefficient of an `Fp2` element. +//! This enables WIDTH-fold parallel arithmetic over `Fp2` using existing SIMD +//! base-field operations. + +use crate::algebra::fields::ext::{Fp2, Fp2Config, Fp4, Fp4Config}; +use crate::algebra::fields::packed::{HasPacking, PackedField, PackedValue}; +use crate::primitives::serialization::Valid; +use crate::FieldCore; +use core::ops::{Add, Mul, Sub}; + +/// Packed `Fp2` elements stored in transpose layout: `[PF; 2]`. +/// +/// If `PF` has width `W`, this represents `W` parallel `Fp2` values. +pub struct PackedFp2, PF: PackedField> { + /// Degree-0 coefficient (packed across SIMD lanes). + pub c0: PF, + /// Degree-1 coefficient (packed across SIMD lanes). + pub c1: PF, + _marker: std::marker::PhantomData (F, C)>, +} + +impl, PF: PackedField> Clone for PackedFp2 { + fn clone(&self) -> Self { + *self + } +} + +impl, PF: PackedField> Copy for PackedFp2 {} + +impl, PF: PackedField> std::fmt::Debug + for PackedFp2 +{ + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("PackedFp2").finish_non_exhaustive() + } +} + +impl, PF: PackedField> PackedFp2 { + /// Create a `PackedFp2` from its two packed coefficients. 
+ #[inline] + pub fn new(c0: PF, c1: PF) -> Self { + Self { + c0, + c1, + _marker: std::marker::PhantomData, + } + } + + #[inline] + fn mul_nr(x: PF) -> PF { + if C::IS_NEG_ONE { + let zero = PF::broadcast(F::zero()); + zero - x + } else { + PF::broadcast(C::non_residue()) * x + } + } +} + +impl PackedValue for PackedFp2 +where + F: FieldCore + Valid + 'static, + C: Fp2Config + 'static, + PF: PackedField, +{ + type Value = Fp2; + const WIDTH: usize = PF::WIDTH; + + fn from_fn(mut f: G) -> Self + where + G: FnMut(usize) -> Self::Value, + { + let mut c0s = Vec::with_capacity(PF::WIDTH); + let mut c1s = Vec::with_capacity(PF::WIDTH); + for i in 0..PF::WIDTH { + let val = f(i); + c0s.push(val.c0); + c1s.push(val.c1); + } + Self::new(PF::from_fn(|i| c0s[i]), PF::from_fn(|i| c1s[i])) + } + + fn extract(&self, lane: usize) -> Self::Value { + Fp2::new(self.c0.extract(lane), self.c1.extract(lane)) + } +} + +impl Add for PackedFp2 +where + F: FieldCore, + C: Fp2Config, + PF: PackedField, +{ + type Output = Self; + #[inline] + fn add(self, rhs: Self) -> Self { + Self::new(self.c0 + rhs.c0, self.c1 + rhs.c1) + } +} + +impl Sub for PackedFp2 +where + F: FieldCore, + C: Fp2Config, + PF: PackedField, +{ + type Output = Self; + #[inline] + fn sub(self, rhs: Self) -> Self { + Self::new(self.c0 - rhs.c0, self.c1 - rhs.c1) + } +} + +impl Mul for PackedFp2 +where + F: FieldCore, + C: Fp2Config, + PF: PackedField, +{ + type Output = Self; + #[inline] + fn mul(self, rhs: Self) -> Self { + let v0 = self.c0 * rhs.c0; + let v1 = self.c1 * rhs.c1; + Self::new( + v0 + Self::mul_nr(v1), + (self.c0 + self.c1) * (rhs.c0 + rhs.c1) - v0 - v1, + ) + } +} + +impl PackedField for PackedFp2 +where + F: FieldCore + Valid + 'static, + C: Fp2Config + 'static, + PF: PackedField, +{ + type Scalar = Fp2; + + #[inline] + fn broadcast(value: Self::Scalar) -> Self { + Self::new(PF::broadcast(value.c0), PF::broadcast(value.c1)) + } +} + +impl HasPacking for Fp2 +where + F: FieldCore + Valid + HasPacking + 
'static, + C: Fp2Config + 'static, +{ + type Packing = PackedFp2; +} + +/// Packed `Fp4` elements stored in transpose layout: `[PackedFp2; 2]`. +pub struct PackedFp4< + F: FieldCore, + C2: Fp2Config, + C4: Fp4Config, + PF: PackedField, +> { + /// Low half (`Fp2` coefficient 0). + pub c0: PackedFp2, + /// High half (`Fp2` coefficient 1). + pub c1: PackedFp2, + _marker: std::marker::PhantomData C4>, +} + +impl Clone for PackedFp4 +where + F: FieldCore, + C2: Fp2Config, + C4: Fp4Config, + PF: PackedField, +{ + fn clone(&self) -> Self { + *self + } +} + +impl Copy for PackedFp4 +where + F: FieldCore, + C2: Fp2Config, + C4: Fp4Config, + PF: PackedField, +{ +} + +impl std::fmt::Debug for PackedFp4 +where + F: FieldCore, + C2: Fp2Config, + C4: Fp4Config, + PF: PackedField, +{ + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("PackedFp4").finish_non_exhaustive() + } +} + +impl PackedFp4 +where + F: FieldCore, + C2: Fp2Config, + C4: Fp4Config, + PF: PackedField, +{ + /// Create a `PackedFp4` from its two `PackedFp2` halves. 
+ #[inline] + pub fn new(c0: PackedFp2, c1: PackedFp2) -> Self { + Self { + c0, + c1, + _marker: std::marker::PhantomData, + } + } +} + +impl PackedValue for PackedFp4 +where + F: FieldCore + Valid + 'static, + C2: Fp2Config + 'static, + C4: Fp4Config + 'static, + PF: PackedField, +{ + type Value = Fp4; + const WIDTH: usize = PF::WIDTH; + + fn from_fn(mut f: G) -> Self + where + G: FnMut(usize) -> Self::Value, + { + let mut c0s: Vec> = Vec::with_capacity(PF::WIDTH); + let mut c1s: Vec> = Vec::with_capacity(PF::WIDTH); + for i in 0..PF::WIDTH { + let val = f(i); + c0s.push(val.c0); + c1s.push(val.c1); + } + Self::new( + PackedFp2::from_fn(|i| c0s[i]), + PackedFp2::from_fn(|i| c1s[i]), + ) + } + + fn extract(&self, lane: usize) -> Self::Value { + Fp4::new(self.c0.extract(lane), self.c1.extract(lane)) + } +} + +impl Add for PackedFp4 +where + F: FieldCore, + C2: Fp2Config, + C4: Fp4Config, + PF: PackedField, +{ + type Output = Self; + #[inline] + fn add(self, rhs: Self) -> Self { + Self::new(self.c0 + rhs.c0, self.c1 + rhs.c1) + } +} + +impl Sub for PackedFp4 +where + F: FieldCore, + C2: Fp2Config, + C4: Fp4Config, + PF: PackedField, +{ + type Output = Self; + #[inline] + fn sub(self, rhs: Self) -> Self { + Self::new(self.c0 - rhs.c0, self.c1 - rhs.c1) + } +} + +impl Mul for PackedFp4 +where + F: FieldCore + Valid + 'static, + C2: Fp2Config + 'static, + C4: Fp4Config, + PF: PackedField, +{ + type Output = Self; + #[inline] + fn mul(self, rhs: Self) -> Self { + let nr2 = PackedFp2::broadcast(C4::non_residue()); + let v0 = self.c0 * rhs.c0; + let v1 = self.c1 * rhs.c1; + Self::new( + v0 + nr2 * v1, + (self.c0 + self.c1) * (rhs.c0 + rhs.c1) - v0 - v1, + ) + } +} + +impl PackedField for PackedFp4 +where + F: FieldCore + Valid + 'static, + C2: Fp2Config + 'static, + C4: Fp4Config + 'static, + PF: PackedField, +{ + type Scalar = Fp4; + + #[inline] + fn broadcast(value: Self::Scalar) -> Self { + Self::new( + PackedFp2::broadcast(value.c0), + PackedFp2::broadcast(value.c1), + 
) + } +} + +impl HasPacking for Fp4 +where + F: FieldCore + Valid + HasPacking + 'static, + C2: Fp2Config + 'static, + C4: Fp4Config + 'static, +{ + type Packing = PackedFp4; +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::algebra::fields::ext::{Ext2, Ext4, TwoNr, UnitNr}; + use crate::algebra::Fp64; + use crate::{FieldSampling, FromSmallInt}; + use rand::rngs::StdRng; + use rand::SeedableRng; + + type F = Fp64<4294967197>; + type E2 = Ext2; + type E4 = Ext4; + type PE2 = PackedFp2::Packing>; + type PE4 = PackedFp4::Packing>; + + #[test] + fn packed_fp2_add() { + let mut rng = StdRng::seed_from_u64(100); + let width = ::WIDTH; + let a_elems: Vec = (0..width).map(|_| E2::sample(&mut rng)).collect(); + let b_elems: Vec = (0..width).map(|_| E2::sample(&mut rng)).collect(); + + let pa = PE2::from_fn(|i| a_elems[i]); + let pb = PE2::from_fn(|i| b_elems[i]); + let pc = pa + pb; + + for i in 0..width { + assert_eq!(pc.extract(i), a_elems[i] + b_elems[i]); + } + } + + #[test] + fn packed_fp2_mul() { + let mut rng = StdRng::seed_from_u64(200); + let width = ::WIDTH; + let a_elems: Vec = (0..width).map(|_| E2::sample(&mut rng)).collect(); + let b_elems: Vec = (0..width).map(|_| E2::sample(&mut rng)).collect(); + + let pa = PE2::from_fn(|i| a_elems[i]); + let pb = PE2::from_fn(|i| b_elems[i]); + let pc = pa * pb; + + for i in 0..width { + assert_eq!( + pc.extract(i), + a_elems[i] * b_elems[i], + "packed Fp2 mul mismatch at lane {i}" + ); + } + } + + #[test] + fn packed_fp2_broadcast() { + let val = E2::new(F::from_u64(7), F::from_u64(11)); + let packed = PE2::broadcast(val); + let width = ::WIDTH; + for i in 0..width { + assert_eq!(packed.extract(i), val); + } + } + + #[test] + fn packed_fp4_mul() { + let mut rng = StdRng::seed_from_u64(300); + let width = ::WIDTH; + let a_elems: Vec = (0..width).map(|_| E4::sample(&mut rng)).collect(); + let b_elems: Vec = (0..width).map(|_| E4::sample(&mut rng)).collect(); + + let pa = PE4::from_fn(|i| a_elems[i]); + let pb = 
PE4::from_fn(|i| b_elems[i]); + let pc = pa * pb; + + for i in 0..width { + assert_eq!( + pc.extract(i), + a_elems[i] * b_elems[i], + "packed Fp4 mul mismatch at lane {i}" + ); + } + } + + #[test] + fn pack_unpack_roundtrip_fp2() { + let mut rng = StdRng::seed_from_u64(400); + let width = ::WIDTH; + let elems: Vec = (0..width * 3).map(|_| E2::sample(&mut rng)).collect(); + + let packed = PE2::pack_slice(&elems); + let unpacked = PE2::unpack_slice(&packed); + + assert_eq!(elems, unpacked); + } +} diff --git a/src/algebra/fields/packed_neon.rs b/src/algebra/fields/packed_neon.rs new file mode 100644 index 00000000..e76c1ed2 --- /dev/null +++ b/src/algebra/fields/packed_neon.rs @@ -0,0 +1,781 @@ +//! AArch64 NEON packed backends for Fp32, Fp64, Fp128. + +use super::packed::{PackedField, PackedValue}; +use crate::algebra::fields::{Fp128, Fp32, Fp64}; +use crate::FieldCore; +use core::arch::aarch64::{ + uint32x2_t, uint32x4_t, uint64x2_t, vaddq_u32, vaddq_u64, vandq_u64, vbslq_u32, vbslq_u64, + vcgtq_u64, vcltq_u32, vcltq_u64, vcombine_u32, vdup_n_u32, vdupq_n_s64, vdupq_n_u32, + vdupq_n_u64, veorq_u64, vget_low_u32, vminq_u32, vmovn_u64, vmull_high_u32, vmull_u32, + vorrq_u64, vshlq_u64, vsubq_u32, vsubq_u64, +}; +use core::fmt; +use core::mem::transmute; +use core::ops::{Add, AddAssign, Mul, MulAssign, Sub, SubAssign}; + +/// Number of packed `Fp128` lanes in this backend. +pub const WIDTH: usize = 2; + +/// True SoA layout for two packed `Fp128` lanes. +/// +/// `lo = [lane0.lo, lane1.lo]` +/// `hi = [lane0.hi, lane1.hi]` +#[derive(Clone, Copy)] +pub struct PackedFp128Neon { + lo: [u64; 2], + hi: [u64; 2], +} + +#[inline(always)] +fn to_vec(x: [u64; 2]) -> uint64x2_t { + unsafe { transmute::<[u64; 2], uint64x2_t>(x) } +} + +#[inline(always)] +fn from_vec(v: uint64x2_t) -> [u64; 2] { + unsafe { transmute::(v) } +} + +#[inline(always)] +fn mask_to_bit(mask: uint64x2_t) -> uint64x2_t { + // SAFETY: NEON intrinsics are available under this cfg. 
+ unsafe { vandq_u64(mask, vdupq_n_u64(1)) } +} + +#[inline(always)] +const fn modulus_lo() -> u64 { + P as u64 +} + +#[inline(always)] +const fn modulus_hi() -> u64 { + (P >> 64) as u64 +} + +use super::util::{is_pow2_u64, log2_pow2_u64}; + +impl PackedFp128Neon

{ + const C: u128 = { + let c = 0u128.wrapping_sub(P); + assert!(P != 0, "modulus must be nonzero"); + assert!(P & 1 == 1, "modulus must be odd"); + assert!(c < (1u128 << 64), "P must be 2^128 - c with c < 2^64"); + assert!( + c * (c + 1) < P, + "C(C+1) < P required for fused canonicalize" + ); + c + }; + const C_LO: u64 = Self::C as u64; + const C_SHIFT_KIND: i8 = { + let c = Self::C_LO; + if c > 1 && is_pow2_u64(c - 1) { + 1 + } else if c == u64::MAX || is_pow2_u64(c + 1) { + -1 + } else { + 0 + } + }; + const C_SHIFT: u32 = { + let c = Self::C_LO; + if Self::C_SHIFT_KIND == 1 { + log2_pow2_u64(c - 1) + } else if Self::C_SHIFT_KIND == -1 { + if c == u64::MAX { + 64 + } else { + log2_pow2_u64(c + 1) + } + } else { + 0 + } + }; + + #[inline(always)] + fn mul_wide_u64(a: u64, b: u64) -> (u64, u64) { + let prod = (a as u128) * (b as u128); + (prod as u64, (prod >> 64) as u64) + } + + #[inline(always)] + fn mul_c_wide(x: u64) -> (u64, u64) { + if Self::C_SHIFT_KIND == 1 { + let v = ((x as u128) << Self::C_SHIFT) + x as u128; + (v as u64, (v >> 64) as u64) + } else if Self::C_SHIFT_KIND == -1 { + let v = ((x as u128) << Self::C_SHIFT) - x as u128; + (v as u64, (v >> 64) as u64) + } else { + Self::mul_wide_u64(Self::C_LO, x) + } + } + + #[inline(always)] + fn fold2_canonicalize(t0: u64, t1: u64, t2: u64) -> (u64, u64) { + let (ct2_lo, ct2_hi) = Self::mul_c_wide(t2); + + let (s0, carry0) = t0.overflowing_add(ct2_lo); + let (s1a, carry1a) = t1.overflowing_add(ct2_hi); + let (s1, carry1b) = s1a.overflowing_add(carry0 as u64); + let overflow = carry1a | carry1b; + + let (r0, carry2) = s0.overflowing_add(Self::C_LO); + let (r1, carry3) = s1.overflowing_add(carry2 as u64); + + if overflow | carry3 { + (r0, r1) + } else { + (s0, s1) + } + } + + #[inline(always)] + fn mul_raw_lane(a0: u64, a1: u64, b0: u64, b1: u64) -> (u64, u64) { + let (p00_lo, p00_hi) = Self::mul_wide_u64(a0, b0); + let (p01_lo, p01_hi) = Self::mul_wide_u64(a0, b1); + let (p10_lo, p10_hi) = 
Self::mul_wide_u64(a1, b0); + let (p11_lo, p11_hi) = Self::mul_wide_u64(a1, b1); + + let row1 = p00_hi as u128 + p01_lo as u128 + p10_lo as u128; + let r0 = p00_lo; + let r1 = row1 as u64; + let carry1 = (row1 >> 64) as u64; + + let row2 = p01_hi as u128 + p10_hi as u128 + p11_lo as u128 + carry1 as u128; + let r2 = row2 as u64; + let carry2 = (row2 >> 64) as u64; + + let row3 = p11_hi as u128 + carry2 as u128; + let r3 = row3 as u64; + debug_assert_eq!(row3 >> 64, 0); + + let (cr2_lo, cr2_hi) = Self::mul_c_wide(r2); + let (cr3_lo, cr3_hi) = Self::mul_c_wide(r3); + + let t0_sum = r0 as u128 + cr2_lo as u128; + let t0 = t0_sum as u64; + let carryf = (t0_sum >> 64) as u64; + + let t1_sum = r1 as u128 + cr2_hi as u128 + cr3_lo as u128 + carryf as u128; + let t1 = t1_sum as u64; + + let t2_sum = cr3_hi as u128 + (t1_sum >> 64); + let t2 = t2_sum as u64; + debug_assert_eq!(t2_sum >> 64, 0); + + Self::fold2_canonicalize(t0, t1, t2) + } +} + +impl Default for PackedFp128Neon

{ + #[inline] + fn default() -> Self { + Self::broadcast(Fp128::zero()) + } +} + +impl fmt::Debug for PackedFp128Neon

{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_tuple("PackedFp128Neon") + .field(&[self.extract(0), self.extract(1)]) + .finish() + } +} + +impl PartialEq for PackedFp128Neon

{ + #[inline] + fn eq(&self, other: &Self) -> bool { + self.extract(0) == other.extract(0) && self.extract(1) == other.extract(1) + } +} + +impl Eq for PackedFp128Neon

{} + +impl PackedValue for PackedFp128Neon

{ + type Value = Fp128

; + const WIDTH: usize = WIDTH; + + #[inline] + fn from_fn(mut f: F) -> Self + where + F: FnMut(usize) -> Self::Value, + { + let x0 = f(0); + let x1 = f(1); + Self { + lo: [x0.0[0], x1.0[0]], + hi: [x0.0[1], x1.0[1]], + } + } + + #[inline] + fn extract(&self, lane: usize) -> Self::Value { + debug_assert!(lane < WIDTH); + Fp128([self.lo[lane], self.hi[lane]]) + } +} + +impl Add for PackedFp128Neon

{ + type Output = Self; + #[inline] + fn add(self, rhs: Self) -> Self { + let lo_a = to_vec(self.lo); + let hi_a = to_vec(self.hi); + let lo_b = to_vec(rhs.lo); + let hi_b = to_vec(rhs.hi); + + let (out_lo, out_hi) = unsafe { + let p_lo = vdupq_n_u64(modulus_lo::

()); + let p_hi = vdupq_n_u64(modulus_hi::

()); + + let sum_lo = vaddq_u64(lo_a, lo_b); + let carry_lo = mask_to_bit(vcltq_u64(sum_lo, lo_a)); + + let hi_tmp = vaddq_u64(hi_a, hi_b); + let carry_hi1 = vcltq_u64(hi_tmp, hi_a); + let sum_hi = vaddq_u64(hi_tmp, carry_lo); + let carry_hi2 = vcltq_u64(sum_hi, hi_tmp); + let carry_128 = vorrq_u64(carry_hi1, carry_hi2); + + let red_lo = vsubq_u64(sum_lo, p_lo); + let borrow_lo = mask_to_bit(vcgtq_u64(p_lo, sum_lo)); + + let red_hi_tmp = vsubq_u64(sum_hi, p_hi); + let borrow_hi1 = vcgtq_u64(p_hi, sum_hi); + let red_hi = vsubq_u64(red_hi_tmp, borrow_lo); + let borrow_hi2 = vcltq_u64(red_hi_tmp, borrow_lo); + let borrow = vorrq_u64(borrow_hi1, borrow_hi2); + + let not_borrow = veorq_u64(borrow, vdupq_n_u64(u64::MAX)); + let use_reduced = vorrq_u64(carry_128, not_borrow); + let out_lo = vbslq_u64(use_reduced, red_lo, sum_lo); + let out_hi = vbslq_u64(use_reduced, red_hi, sum_hi); + (out_lo, out_hi) + }; + + Self { + lo: from_vec(out_lo), + hi: from_vec(out_hi), + } + } +} + +impl Sub for PackedFp128Neon

{ + type Output = Self; + #[inline] + fn sub(self, rhs: Self) -> Self { + let lo_a = to_vec(self.lo); + let hi_a = to_vec(self.hi); + let lo_b = to_vec(rhs.lo); + let hi_b = to_vec(rhs.hi); + + let (out_lo, out_hi) = unsafe { + let p_lo = vdupq_n_u64(modulus_lo::

()); + let p_hi = vdupq_n_u64(modulus_hi::

()); + + let diff_lo = vsubq_u64(lo_a, lo_b); + let borrow_lo = mask_to_bit(vcltq_u64(lo_a, lo_b)); + + let diff_hi_tmp = vsubq_u64(hi_a, hi_b); + let borrow_hi1 = vcltq_u64(hi_a, hi_b); + let diff_hi = vsubq_u64(diff_hi_tmp, borrow_lo); + let borrow_hi2 = vcltq_u64(diff_hi_tmp, borrow_lo); + let borrow_128 = vorrq_u64(borrow_hi1, borrow_hi2); + + let corr_lo = vaddq_u64(diff_lo, p_lo); + let carry_lo = mask_to_bit(vcltq_u64(corr_lo, diff_lo)); + + let corr_hi_tmp = vaddq_u64(diff_hi, p_hi); + let corr_hi = vaddq_u64(corr_hi_tmp, carry_lo); + + let out_lo = vbslq_u64(borrow_128, corr_lo, diff_lo); + let out_hi = vbslq_u64(borrow_128, corr_hi, diff_hi); + (out_lo, out_hi) + }; + + Self { + lo: from_vec(out_lo), + hi: from_vec(out_hi), + } + } +} + +impl Mul for PackedFp128Neon

{ + type Output = Self; + #[inline] + fn mul(self, rhs: Self) -> Self { + let (o0_lo, o0_hi) = Self::mul_raw_lane(self.lo[0], self.hi[0], rhs.lo[0], rhs.hi[0]); + let (o1_lo, o1_hi) = Self::mul_raw_lane(self.lo[1], self.hi[1], rhs.lo[1], rhs.hi[1]); + + Self { + lo: [o0_lo, o1_lo], + hi: [o0_hi, o1_hi], + } + } +} + +impl AddAssign for PackedFp128Neon

{ + #[inline] + fn add_assign(&mut self, rhs: Self) { + *self = *self + rhs; + } +} + +impl SubAssign for PackedFp128Neon

{ + #[inline] + fn sub_assign(&mut self, rhs: Self) { + *self = *self - rhs; + } +} + +impl MulAssign for PackedFp128Neon

{ + #[inline] + fn mul_assign(&mut self, rhs: Self) { + *self = *self * rhs; + } +} + +impl PackedField for PackedFp128Neon

{ + type Scalar = Fp128

; + + #[inline] + fn broadcast(value: Self::Scalar) -> Self { + Self::from_fn(|_| value) + } +} + +/// Number of packed `Fp32` lanes. +pub const FP32_WIDTH: usize = 4; + +/// NEON packed `Fp32` backend: 4 lanes in `uint32x4_t`. +#[derive(Clone, Copy)] +pub struct PackedFp32Neon { + vals: [u32; 4], +} + +#[inline(always)] +fn to_vec32(x: [u32; 4]) -> uint32x4_t { + unsafe { transmute::<[u32; 4], uint32x4_t>(x) } +} + +#[inline(always)] +fn from_vec32(v: uint32x4_t) -> [u32; 4] { + unsafe { transmute::(v) } +} + +impl PackedFp32Neon

{ + const BITS: u32 = 32 - P.leading_zeros(); + + const C: u32 = { + let c = if Self::BITS == 32 { + 0u32.wrapping_sub(P) + } else { + (1u32 << Self::BITS) - P + }; + assert!(P != 0, "modulus must be nonzero"); + assert!(P & 1 == 1, "modulus must be odd"); + assert!( + (c as u64) * (c as u64 + 1) < P as u64, + "C(C+1) < P required for fused canonicalize" + ); + c + }; + + const MASK_U64: u64 = if Self::BITS == 32 { + u32::MAX as u64 + } else { + (1u64 << Self::BITS) - 1 + }; + + #[inline(always)] + fn mul_c_u64(hi: uint64x2_t, c: uint32x2_t) -> uint64x2_t { + unsafe { + let hi_narrow = vmovn_u64(hi); + vmull_u32(hi_narrow, c) + } + } + + #[inline(always)] + fn solinas_reduce(prod_lo: uint64x2_t, prod_hi: uint64x2_t) -> uint32x4_t { + unsafe { + let mask = vdupq_n_u64(Self::MASK_U64); + let neg_bits = vdupq_n_s64(-(Self::BITS as i64)); + let c = vdup_n_u32(Self::C); + + let f1_lo = vaddq_u64( + vandq_u64(prod_lo, mask), + Self::mul_c_u64(vshlq_u64(prod_lo, neg_bits), c), + ); + let f1_hi = vaddq_u64( + vandq_u64(prod_hi, mask), + Self::mul_c_u64(vshlq_u64(prod_hi, neg_bits), c), + ); + + let f2_lo = vaddq_u64( + vandq_u64(f1_lo, mask), + Self::mul_c_u64(vshlq_u64(f1_lo, neg_bits), c), + ); + let f2_hi = vaddq_u64( + vandq_u64(f1_hi, mask), + Self::mul_c_u64(vshlq_u64(f1_hi, neg_bits), c), + ); + + if Self::BITS < 32 { + let result = vcombine_u32(vmovn_u64(f2_lo), vmovn_u64(f2_hi)); + let p = vdupq_n_u32(P); + vminq_u32(result, vsubq_u32(result, p)) + } else { + let p_u64 = vdupq_n_u64(P as u64); + + let red_lo = vsubq_u64(f2_lo, p_u64); + let keep_lo = vcltq_u64(f2_lo, p_u64); + let out_lo = vbslq_u64(keep_lo, f2_lo, red_lo); + + let red_hi = vsubq_u64(f2_hi, p_u64); + let keep_hi = vcltq_u64(f2_hi, p_u64); + let out_hi = vbslq_u64(keep_hi, f2_hi, red_hi); + + vcombine_u32(vmovn_u64(out_lo), vmovn_u64(out_hi)) + } + } + } +} + +impl Default for PackedFp32Neon

{ + #[inline] + fn default() -> Self { + Self { vals: [0; 4] } + } +} + +impl fmt::Debug for PackedFp32Neon

{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_tuple("PackedFp32Neon").field(&self.vals).finish() + } +} + +impl PartialEq for PackedFp32Neon

{ + #[inline] + fn eq(&self, other: &Self) -> bool { + self.vals == other.vals + } +} + +impl Eq for PackedFp32Neon

{} + +impl Add for PackedFp32Neon

{ + type Output = Self; + #[inline] + fn add(self, rhs: Self) -> Self { + let a = to_vec32(self.vals); + let b = to_vec32(rhs.vals); + let result = unsafe { + let p = vdupq_n_u32(P); + if Self::BITS <= 31 { + let t = vaddq_u32(a, b); + vminq_u32(t, vsubq_u32(t, p)) + } else { + let c = vdupq_n_u32(Self::C); + let t = vaddq_u32(a, b); + let overflow = vcltq_u32(t, a); + let t_plus_c = vaddq_u32(t, c); + let no_of = vminq_u32(t, vsubq_u32(t, p)); + vbslq_u32(overflow, t_plus_c, no_of) + } + }; + Self { + vals: from_vec32(result), + } + } +} + +impl Sub for PackedFp32Neon

{ + type Output = Self; + #[inline] + fn sub(self, rhs: Self) -> Self { + let a = to_vec32(self.vals); + let b = to_vec32(rhs.vals); + let result = unsafe { + let p = vdupq_n_u32(P); + if Self::BITS <= 31 { + let t = vsubq_u32(a, b); + vminq_u32(t, vaddq_u32(t, p)) + } else { + let t = vsubq_u32(a, b); + let underflow = vcltq_u32(a, b); + vbslq_u32(underflow, vaddq_u32(t, p), t) + } + }; + Self { + vals: from_vec32(result), + } + } +} + +impl Mul for PackedFp32Neon

{ + type Output = Self; + #[inline] + fn mul(self, rhs: Self) -> Self { + let a = to_vec32(self.vals); + let b = to_vec32(rhs.vals); + let result = unsafe { + let prod_lo = vmull_u32(vget_low_u32(a), vget_low_u32(b)); + let prod_hi = vmull_high_u32(a, b); + Self::solinas_reduce(prod_lo, prod_hi) + }; + Self { + vals: from_vec32(result), + } + } +} + +impl PackedValue for PackedFp32Neon

{ + type Value = Fp32

; + const WIDTH: usize = FP32_WIDTH; + + #[inline] + fn from_fn(mut f: F) -> Self + where + F: FnMut(usize) -> Self::Value, + { + Self { + vals: [f(0).0, f(1).0, f(2).0, f(3).0], + } + } + + #[inline] + fn extract(&self, lane: usize) -> Self::Value { + debug_assert!(lane < FP32_WIDTH); + Fp32(self.vals[lane]) + } +} + +impl AddAssign for PackedFp32Neon

{ + #[inline] + fn add_assign(&mut self, rhs: Self) { + *self = *self + rhs; + } +} + +impl SubAssign for PackedFp32Neon

{ + #[inline] + fn sub_assign(&mut self, rhs: Self) { + *self = *self - rhs; + } +} + +impl MulAssign for PackedFp32Neon

{ + #[inline] + fn mul_assign(&mut self, rhs: Self) { + *self = *self * rhs; + } +} + +impl PackedField for PackedFp32Neon

{ + type Scalar = Fp32

; + + #[inline] + fn broadcast(value: Self::Scalar) -> Self { + Self { vals: [value.0; 4] } + } +} + +/// Number of packed `Fp64` lanes. +pub const FP64_WIDTH: usize = 2; + +/// NEON packed `Fp64` backend: 2 lanes in `uint64x2_t`. +#[derive(Clone, Copy)] +pub struct PackedFp64Neon { + vals: [u64; 2], +} + +impl PackedFp64Neon

{ + const BITS: u32 = 64 - P.leading_zeros(); + + const C_LO: u64 = { + let c = if Self::BITS == 64 { + 0u64.wrapping_sub(P) + } else { + (1u64 << Self::BITS) - P + }; + assert!(P != 0, "modulus must be nonzero"); + assert!(P & 1 == 1, "modulus must be odd"); + c + }; + + const MASK64: u64 = if Self::BITS < 64 { + (1u64 << Self::BITS) - 1 + } else { + u64::MAX + }; + + const MASK_U128: u128 = if Self::BITS == 64 { + u64::MAX as u128 + } else { + (1u128 << Self::BITS) - 1 + }; + + const FOLD_IN_U64: bool = + Self::BITS < 64 && (Self::C_LO as u128) < (1u128 << (64 - Self::BITS)); + + #[inline(always)] + fn mul_c_narrow(x: u64) -> u64 { + Self::C_LO.wrapping_mul(x) + } + + #[inline(always)] + fn reduce_product(x: u128) -> u64 { + if Self::FOLD_IN_U64 { + let lo = x as u64; + let hi = (x >> 64) as u64; + let high = (lo >> Self::BITS) | (hi << (64 - Self::BITS)); + let f1 = (lo & Self::MASK64).wrapping_add(Self::mul_c_narrow(high)); + let f2 = (f1 & Self::MASK64).wrapping_add(Self::mul_c_narrow(f1 >> Self::BITS)); + let reduced = f2.wrapping_sub(P); + let borrow = reduced >> 63; + reduced.wrapping_add(borrow.wrapping_neg() & P) + } else { + let f1 = + (x & Self::MASK_U128) + (Self::C_LO as u128) * ((x >> Self::BITS) as u64 as u128); + let f2 = + (f1 & Self::MASK_U128) + (Self::C_LO as u128) * ((f1 >> Self::BITS) as u64 as u128); + let reduced = f2.wrapping_sub(P as u128); + let borrow = reduced >> 127; + reduced.wrapping_add(borrow.wrapping_neg() & (P as u128)) as u64 + } + } +} + +impl Default for PackedFp64Neon

{ + #[inline] + fn default() -> Self { + Self { vals: [0; 2] } + } +} + +impl fmt::Debug for PackedFp64Neon

{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_tuple("PackedFp64Neon").field(&self.vals).finish() + } +} + +impl PartialEq for PackedFp64Neon

{ + #[inline] + fn eq(&self, other: &Self) -> bool { + self.vals == other.vals + } +} + +impl Eq for PackedFp64Neon

{} + +impl Add for PackedFp64Neon

{ + type Output = Self; + #[inline] + fn add(self, rhs: Self) -> Self { + let a = to_vec(self.vals); + let b = to_vec(rhs.vals); + let result = unsafe { + let p = vdupq_n_u64(P); + if Self::BITS <= 62 { + let s = vaddq_u64(a, b); + let r = vsubq_u64(s, p); + let borrow = vcltq_u64(s, p); + vbslq_u64(borrow, s, r) + } else { + let s = vaddq_u64(a, b); + let overflow = vcltq_u64(s, a); + let c = vdupq_n_u64(Self::C_LO); + let s_plus_c = vaddq_u64(s, c); + let s_minus_p = vsubq_u64(s, p); + let borrow = vcltq_u64(s, p); + let no_of = vbslq_u64(borrow, s, s_minus_p); + vbslq_u64(overflow, s_plus_c, no_of) + } + }; + Self { + vals: from_vec(result), + } + } +} + +impl Sub for PackedFp64Neon

{ + type Output = Self; + #[inline] + fn sub(self, rhs: Self) -> Self { + let a = to_vec(self.vals); + let b = to_vec(rhs.vals); + let result = unsafe { + let p = vdupq_n_u64(P); + let d = vsubq_u64(a, b); + let underflow = vcltq_u64(a, b); + vbslq_u64(underflow, vaddq_u64(d, p), d) + }; + Self { + vals: from_vec(result), + } + } +} + +impl Mul for PackedFp64Neon

{ + type Output = Self; + #[inline] + fn mul(self, rhs: Self) -> Self { + let x0 = (self.vals[0] as u128) * (rhs.vals[0] as u128); + let x1 = (self.vals[1] as u128) * (rhs.vals[1] as u128); + let r0 = Self::reduce_product(x0); + let r1 = Self::reduce_product(x1); + Self { vals: [r0, r1] } + } +} + +impl PackedValue for PackedFp64Neon

{ + type Value = Fp64

; + const WIDTH: usize = FP64_WIDTH; + + #[inline] + fn from_fn(mut f: F) -> Self + where + F: FnMut(usize) -> Self::Value, + { + Self { + vals: [f(0).0, f(1).0], + } + } + + #[inline] + fn extract(&self, lane: usize) -> Self::Value { + debug_assert!(lane < FP64_WIDTH); + Fp64(self.vals[lane]) + } +} + +impl AddAssign for PackedFp64Neon

{ + #[inline] + fn add_assign(&mut self, rhs: Self) { + *self = *self + rhs; + } +} + +impl SubAssign for PackedFp64Neon

{ + #[inline] + fn sub_assign(&mut self, rhs: Self) { + *self = *self - rhs; + } +} + +impl MulAssign for PackedFp64Neon

{ + #[inline] + fn mul_assign(&mut self, rhs: Self) { + *self = *self * rhs; + } +} + +impl PackedField for PackedFp64Neon

{ + type Scalar = Fp64

; + + #[inline] + fn broadcast(value: Self::Scalar) -> Self { + Self { vals: [value.0; 2] } + } +} diff --git a/src/algebra/fields/pseudo_mersenne.rs b/src/algebra/fields/pseudo_mersenne.rs new file mode 100644 index 00000000..17920899 --- /dev/null +++ b/src/algebra/fields/pseudo_mersenne.rs @@ -0,0 +1,177 @@ +//! `2^k - offset` pseudo-Mersenne registry and field aliases. +//! +//! This module models the specific flavor where each modulus is the smallest +//! prime below `2^k` with `q % 8 == 5`, written as `q = 2^k - offset`. + +use super::{Fp128, Fp32, Fp64}; + +/// Offset table (`q = 2^k - offset[k]`) imported from `labrador/data.py`. +pub const POW2_OFFSET_TABLE: [i16; 256] = [ + -1, -1, -1, 3, 3, 3, 3, 19, 27, 3, 3, 19, 3, 75, 3, 19, 99, 91, 11, 19, 3, 19, 3, 27, 3, 91, + 27, 115, 299, 3, 35, 19, 99, 355, 131, 451, 243, 123, 107, 19, 195, 75, 11, 67, 539, 139, 635, + 115, 59, 123, 27, 139, 395, 315, 131, 67, 27, 195, 27, 99, 107, 259, 171, 259, 59, 115, 203, + 19, 83, 19, 35, 411, 107, 475, 35, 427, 123, 43, 11, 67, 1307, 51, 315, 139, 35, 19, 35, 67, + 299, 99, 75, 315, 83, 51, 3, 211, 147, 595, 51, 115, 99, 99, 483, 339, 395, 139, 1187, 171, 59, + 91, 195, 835, 75, 211, 11, 67, 3, 451, 563, 867, 395, 531, 3, 67, 59, 579, 203, 507, 275, 315, + 27, 315, 347, 99, 603, 795, 243, 339, 203, 187, 27, 171, 1491, 355, 83, 355, 1371, 387, 347, + 99, 3, 195, 539, 171, 243, 499, 195, 19, 155, 91, 75, 1011, 627, 867, 155, 115, 1811, 771, + 1467, 643, 195, 19, 155, 531, 3, 267, 563, 339, 563, 507, 107, 283, 267, 147, 59, 339, 371, + 1411, 363, 819, 11, 19, 915, 123, 75, 915, 459, 75, 627, 459, 75, 1035, 195, 187, 1515, 1219, + 1443, 91, 299, 451, 171, 1099, 99, 3, 395, 1147, 683, 675, 243, 355, 395, 3, 875, 235, 363, + 1131, 155, 835, 723, 91, 27, 235, 875, 3, 83, 259, 875, 1515, 731, 531, 467, 819, 267, 475, + 1923, 163, 107, 411, 387, 75, 2331, 355, 1515, 1723, 1427, 19, +]; + +/// Maximum supported offset in this `2^k - offset` specialization. 
+pub const POW2_OFFSET_MAX: u128 = 1u128 << 16; + +/// Current active bit-size bound for concrete field aliases in this phase. +pub const POW2_OFFSET_IMPLEMENTED_MAX_BITS: u32 = 128; + +/// Metadata describing a `2^k - offset` pseudo-Mersenne modulus. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct Pow2OffsetPrimeSpec { + /// `k` in `2^k - offset`. + pub bits: u32, + /// `offset` in `2^k - offset`. + pub offset: u16, + /// Modulus value. + pub modulus: u128, +} + +/// Return table offset for `q = 2^k - offset` when available and positive. +pub const fn pow2_offset(bits: u32) -> Option { + if bits as usize >= POW2_OFFSET_TABLE.len() { + return None; + } + let offset = POW2_OFFSET_TABLE[bits as usize]; + if offset <= 0 { + None + } else { + Some(offset as u16) + } +} + +/// Compute `2^k - offset` for `k <= 128`. +pub const fn pseudo_mersenne_modulus(bits: u32, offset: u128) -> Option { + if bits == 0 || bits > 128 || offset == 0 { + return None; + } + if bits == 128 { + Some(u128::MAX - (offset - 1)) + } else { + Some((1u128 << bits) - offset) + } +} + +/// Check whether `(k, offset)` is accepted by the `2^k - offset` policy. +pub const fn is_pow2_offset(bits: u32, offset: u128) -> bool { + if bits > POW2_OFFSET_IMPLEMENTED_MAX_BITS || offset > POW2_OFFSET_MAX { + return false; + } + match pow2_offset(bits) { + Some(qoff) => (qoff as u128) == offset, + None => false, + } +} + +/// `offset` for `k = 24`. +pub const POW2_OFFSET_24: u16 = 3; +/// `offset` for `k = 30`. +pub const POW2_OFFSET_30: u16 = 35; +/// `offset` for `k = 31`. +pub const POW2_OFFSET_31: u16 = 19; +/// `offset` for `k = 32`. +pub const POW2_OFFSET_32: u16 = 99; +/// `offset` for `k = 40`. +pub const POW2_OFFSET_40: u16 = 195; +/// `offset` for `k = 48`. +pub const POW2_OFFSET_48: u16 = 59; +/// `offset` for `k = 56`. +pub const POW2_OFFSET_56: u16 = 27; +/// `offset` for `k = 64`. +pub const POW2_OFFSET_64: u16 = 59; +/// `offset` for `k = 128`. 
+pub const POW2_OFFSET_128: u16 = 275; + +/// `2^24 - 3`. +pub const POW2_OFFSET_MODULUS_24: u32 = ((1u128 << 24) - (POW2_OFFSET_24 as u128)) as u32; +/// `2^30 - 35`. +pub const POW2_OFFSET_MODULUS_30: u32 = ((1u128 << 30) - (POW2_OFFSET_30 as u128)) as u32; +/// `2^31 - 19`. +pub const POW2_OFFSET_MODULUS_31: u32 = ((1u128 << 31) - (POW2_OFFSET_31 as u128)) as u32; +/// `2^32 - 99`. +pub const POW2_OFFSET_MODULUS_32: u32 = ((1u128 << 32) - (POW2_OFFSET_32 as u128)) as u32; +/// `2^40 - 195`. +pub const POW2_OFFSET_MODULUS_40: u64 = ((1u128 << 40) - (POW2_OFFSET_40 as u128)) as u64; +/// `2^48 - 59`. +pub const POW2_OFFSET_MODULUS_48: u64 = ((1u128 << 48) - (POW2_OFFSET_48 as u128)) as u64; +/// `2^56 - 27`. +pub const POW2_OFFSET_MODULUS_56: u64 = ((1u128 << 56) - (POW2_OFFSET_56 as u128)) as u64; +/// `2^64 - 59`. +pub const POW2_OFFSET_MODULUS_64: u64 = u64::MAX - ((POW2_OFFSET_64 as u64) - 1); +/// `2^128 - 275`. +pub const POW2_OFFSET_MODULUS_128: u128 = u128::MAX - (POW2_OFFSET_128 as u128 - 1); + +/// Alias for `2^24 - offset`. +pub type Pow2Offset24Field = Fp32; +/// Alias for `2^30 - offset`. +pub type Pow2Offset30Field = Fp32; +/// Alias for `2^31 - offset`. +pub type Pow2Offset31Field = Fp32; +/// Alias for `2^32 - offset`. +pub type Pow2Offset32Field = Fp32; +/// Alias for `2^40 - offset`. +pub type Pow2Offset40Field = Fp64; +/// Alias for `2^48 - offset`. +pub type Pow2Offset48Field = Fp64; +/// Alias for `2^56 - offset`. +pub type Pow2Offset56Field = Fp64; +/// Alias for `2^64 - offset`. +pub type Pow2Offset64Field = Fp64; +/// Alias for `2^128 - offset`. +pub type Pow2Offset128Field = Fp128; + +/// `2^k - offset` profiles currently enabled in-code. +/// +/// Each listed modulus satisfies `q % 8 == 5`. 
+pub const POW2_OFFSET_PRIMES: [Pow2OffsetPrimeSpec; 7] = [ + Pow2OffsetPrimeSpec { + bits: 24, + offset: POW2_OFFSET_24, + modulus: POW2_OFFSET_MODULUS_24 as u128, + }, + Pow2OffsetPrimeSpec { + bits: 32, + offset: POW2_OFFSET_32, + modulus: POW2_OFFSET_MODULUS_32 as u128, + }, + Pow2OffsetPrimeSpec { + bits: 40, + offset: POW2_OFFSET_40, + modulus: POW2_OFFSET_MODULUS_40 as u128, + }, + Pow2OffsetPrimeSpec { + bits: 48, + offset: POW2_OFFSET_48, + modulus: POW2_OFFSET_MODULUS_48 as u128, + }, + Pow2OffsetPrimeSpec { + bits: 56, + offset: POW2_OFFSET_56, + modulus: POW2_OFFSET_MODULUS_56 as u128, + }, + Pow2OffsetPrimeSpec { + bits: 64, + offset: POW2_OFFSET_64, + modulus: POW2_OFFSET_MODULUS_64 as u128, + }, + Pow2OffsetPrimeSpec { + bits: 128, + offset: POW2_OFFSET_128, + modulus: POW2_OFFSET_MODULUS_128, + }, +]; + +// All PseudoMersenneField impls for Fp32/Fp64/Fp128 are blanket impls in +// their respective modules (fp32.rs, fp64.rs, fp128.rs). diff --git a/src/algebra/fields/util.rs b/src/algebra/fields/util.rs new file mode 100644 index 00000000..9496f1ce --- /dev/null +++ b/src/algebra/fields/util.rs @@ -0,0 +1,38 @@ +//! Shared helpers for field arithmetic backends. + +#[inline(always)] +pub(crate) const fn is_pow2_u64(x: u64) -> bool { + x != 0 && (x & (x - 1)) == 0 +} + +#[inline(always)] +pub(crate) const fn log2_pow2_u64(mut x: u64) -> u32 { + let mut k = 0u32; + while x > 1 { + x >>= 1; + k += 1; + } + k +} + +/// `a * b` widening to 128 bits; returns `(lo64, hi64)`. 
+#[inline(always)] +pub(crate) fn mul64_wide(a: u64, b: u64) -> (u64, u64) { + #[cfg(all(target_arch = "x86_64", target_feature = "bmi2"))] + { + unsafe { mul64_wide_bmi2(a, b) } + } + #[cfg(not(all(target_arch = "x86_64", target_feature = "bmi2")))] + { + let prod = (a as u128) * (b as u128); + (prod as u64, (prod >> 64) as u64) + } +} + +#[cfg(all(target_arch = "x86_64", target_feature = "bmi2"))] +#[inline(always)] +unsafe fn mul64_wide_bmi2(a: u64, b: u64) -> (u64, u64) { + let mut hi = 0; + let lo = unsafe { std::arch::x86_64::_mulx_u64(a, b, &mut hi) }; + (lo, hi) +} diff --git a/src/algebra/fields/wide.rs b/src/algebra/fields/wide.rs new file mode 100644 index 00000000..e160f3d1 --- /dev/null +++ b/src/algebra/fields/wide.rs @@ -0,0 +1,1238 @@ +//! Wide unreduced field accumulators for carry-free signed addition. +//! +//! Each type splits a canonical field element into 16-bit limbs stored in +//! `i32` slots. Addition and negation are element-wise i32 ops — no carry +//! propagation, no modular reduction. Reduction back to canonical form +//! happens once after accumulation via [`reduce`](Fp128x8i32::reduce). +//! +//! The i32 overflow budget is `i32::MAX / u16::MAX ≈ 32,769` signed +//! additions before any limb can overflow. + +use std::ops::{Add, AddAssign, Neg, Sub, SubAssign}; + +use crate::{AdditiveGroup, CanonicalField, FieldCore}; + +use super::fp128::Fp128; +use super::fp32::Fp32; +use super::fp64::Fp64; + +/// Wide unreduced accumulator for `Fp32`: 2 × i32 limbs (16-bit data each). +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +#[repr(C)] +pub struct Fp32x2i32(pub [i32; 2]); + +impl Fp32x2i32 { + /// Returns the zero accumulator. + #[inline] + pub fn zero() -> Self { + ::ZERO + } +} + +impl From> for Fp32x2i32 { + #[inline] + fn from(x: Fp32

) -> Self { + let v = x.0; + Self([(v & 0xFFFF) as i32, (v >> 16) as i32]) + } +} + +impl Fp32x2i32 { + /// Multiply every limb by a small signed scalar. + /// + /// Safe when `|small| * max_limb_magnitude` fits in i32. After `From`, + /// limbs are in `[0, 0xFFFF]`, so `|small| ≤ 32_767` is safe for a single + /// product. For accumulation of `k` scaled values, require + /// `k * |small| * 0xFFFF < i32::MAX`, i.e. roughly `k * |small| < 32_768`. + #[inline] + pub fn scale_i32(self, small: i32) -> Self { + Self([self.0[0] * small, self.0[1] * small]) + } + + /// Reduce back to canonical `Fp32

`. + /// + /// Carry-propagates the i32 limbs into a signed value, normalizes to + /// `[0, p)`, and returns the canonical field element. + #[inline] + pub fn reduce(self) -> Fp32

{ + let [l0, l1] = self.0; + // Carry-propagate: value = l0 + l1 * 2^16 + let wide = l0 as i64 + (l1 as i64) * (1i64 << 16); + // Normalize to [0, p) + let p = P as i64; + let normalized = ((wide % p) + p) % p; + Fp32::from_canonical_u32(normalized as u32) + } +} + +impl Add for Fp32x2i32 { + type Output = Self; + #[inline] + fn add(self, rhs: Self) -> Self { + Self([self.0[0] + rhs.0[0], self.0[1] + rhs.0[1]]) + } +} + +impl AddAssign for Fp32x2i32 { + #[inline] + fn add_assign(&mut self, rhs: Self) { + self.0[0] += rhs.0[0]; + self.0[1] += rhs.0[1]; + } +} + +impl Sub for Fp32x2i32 { + type Output = Self; + #[inline] + fn sub(self, rhs: Self) -> Self { + Self([self.0[0] - rhs.0[0], self.0[1] - rhs.0[1]]) + } +} + +impl SubAssign for Fp32x2i32 { + #[inline] + fn sub_assign(&mut self, rhs: Self) { + self.0[0] -= rhs.0[0]; + self.0[1] -= rhs.0[1]; + } +} + +impl Neg for Fp32x2i32 { + type Output = Self; + #[inline] + fn neg(self) -> Self { + Self([-self.0[0], -self.0[1]]) + } +} + +/// Wide unreduced accumulator for `Fp64`: 4 × i32 limbs (16-bit data each). +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +#[repr(C)] +pub struct Fp64x4i32(pub [i32; 4]); + +impl Fp64x4i32 { + /// Returns the zero accumulator. + #[inline] + pub fn zero() -> Self { + ::ZERO + } +} + +impl From> for Fp64x4i32 { + #[inline] + fn from(x: Fp64

) -> Self { + let v = x.0; + Self([ + (v & 0xFFFF) as i32, + ((v >> 16) & 0xFFFF) as i32, + ((v >> 32) & 0xFFFF) as i32, + ((v >> 48) & 0xFFFF) as i32, + ]) + } +} + +impl Fp64x4i32 { + /// Multiply every limb by a small signed scalar. See [`Fp32x2i32::scale_i32`]. + #[inline] + pub fn scale_i32(self, small: i32) -> Self { + Self([ + self.0[0] * small, + self.0[1] * small, + self.0[2] * small, + self.0[3] * small, + ]) + } + + /// Reduce back to canonical `Fp64

`. + #[inline] + pub fn reduce(self) -> Fp64

{ + let [l0, l1, l2, l3] = self.0; + // Carry-propagate: value = l0 + l1*2^16 + l2*2^32 + l3*2^48 + let wide = l0 as i128 + + (l1 as i128) * (1i128 << 16) + + (l2 as i128) * (1i128 << 32) + + (l3 as i128) * (1i128 << 48); + let p = P as i128; + let normalized = ((wide % p) + p) % p; + Fp64::

::from_canonical_u64(normalized as u64) + } +} + +#[cfg(target_arch = "aarch64")] +impl Add for Fp64x4i32 { + type Output = Self; + #[inline] + fn add(self, rhs: Self) -> Self { + unsafe { + use std::arch::aarch64::*; + let a = vld1q_s32(self.0.as_ptr()); + let b = vld1q_s32(rhs.0.as_ptr()); + let mut out = [0i32; 4]; + vst1q_s32(out.as_mut_ptr(), vaddq_s32(a, b)); + Self(out) + } + } +} + +#[cfg(target_arch = "aarch64")] +impl AddAssign for Fp64x4i32 { + #[inline] + fn add_assign(&mut self, rhs: Self) { + *self = *self + rhs; + } +} + +#[cfg(target_arch = "aarch64")] +impl Sub for Fp64x4i32 { + type Output = Self; + #[inline] + fn sub(self, rhs: Self) -> Self { + unsafe { + use std::arch::aarch64::*; + let a = vld1q_s32(self.0.as_ptr()); + let b = vld1q_s32(rhs.0.as_ptr()); + let mut out = [0i32; 4]; + vst1q_s32(out.as_mut_ptr(), vsubq_s32(a, b)); + Self(out) + } + } +} + +#[cfg(target_arch = "aarch64")] +impl SubAssign for Fp64x4i32 { + #[inline] + fn sub_assign(&mut self, rhs: Self) { + *self = *self - rhs; + } +} + +#[cfg(target_arch = "aarch64")] +impl Neg for Fp64x4i32 { + type Output = Self; + #[inline] + fn neg(self) -> Self { + unsafe { + use std::arch::aarch64::*; + let a = vld1q_s32(self.0.as_ptr()); + let mut out = [0i32; 4]; + vst1q_s32(out.as_mut_ptr(), vnegq_s32(a)); + Self(out) + } + } +} + +#[cfg(not(target_arch = "aarch64"))] +impl Add for Fp64x4i32 { + type Output = Self; + #[inline] + fn add(self, rhs: Self) -> Self { + Self([ + self.0[0] + rhs.0[0], + self.0[1] + rhs.0[1], + self.0[2] + rhs.0[2], + self.0[3] + rhs.0[3], + ]) + } +} + +#[cfg(not(target_arch = "aarch64"))] +impl AddAssign for Fp64x4i32 { + #[inline] + fn add_assign(&mut self, rhs: Self) { + self.0[0] += rhs.0[0]; + self.0[1] += rhs.0[1]; + self.0[2] += rhs.0[2]; + self.0[3] += rhs.0[3]; + } +} + +#[cfg(not(target_arch = "aarch64"))] +impl Sub for Fp64x4i32 { + type Output = Self; + #[inline] + fn sub(self, rhs: Self) -> Self { + Self([ + self.0[0] - rhs.0[0], + self.0[1] - 
rhs.0[1], + self.0[2] - rhs.0[2], + self.0[3] - rhs.0[3], + ]) + } +} + +#[cfg(not(target_arch = "aarch64"))] +impl SubAssign for Fp64x4i32 { + #[inline] + fn sub_assign(&mut self, rhs: Self) { + self.0[0] -= rhs.0[0]; + self.0[1] -= rhs.0[1]; + self.0[2] -= rhs.0[2]; + self.0[3] -= rhs.0[3]; + } +} + +#[cfg(not(target_arch = "aarch64"))] +impl Neg for Fp64x4i32 { + type Output = Self; + #[inline] + fn neg(self) -> Self { + Self([-self.0[0], -self.0[1], -self.0[2], -self.0[3]]) + } +} + +/// Wide unreduced accumulator for `Fp128`: 8 × i32 limbs (16-bit data each). +/// +/// On AVX2, one element fits a single 256-bit YMM register. On NEON, it +/// spans two 128-bit Q registers. All arithmetic is carry-free element-wise +/// i32 operations. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +#[repr(C)] +pub struct Fp128x8i32(pub [i32; 8]); + +impl Fp128x8i32 { + /// Returns the zero accumulator. + #[inline] + pub fn zero() -> Self { + ::ZERO + } +} + +impl From> for Fp128x8i32 { + #[inline] + fn from(x: Fp128

) -> Self { + let lo = x.0[0]; + let hi = x.0[1]; + Self([ + (lo & 0xFFFF) as i32, + ((lo >> 16) & 0xFFFF) as i32, + ((lo >> 32) & 0xFFFF) as i32, + ((lo >> 48) & 0xFFFF) as i32, + (hi & 0xFFFF) as i32, + ((hi >> 16) & 0xFFFF) as i32, + ((hi >> 32) & 0xFFFF) as i32, + ((hi >> 48) & 0xFFFF) as i32, + ]) + } +} + +impl Fp128x8i32 { + /// Multiply every limb by a small signed scalar. See [`Fp32x2i32::scale_i32`]. + #[inline] + pub fn scale_i32(self, small: i32) -> Self { + Self([ + self.0[0] * small, + self.0[1] * small, + self.0[2] * small, + self.0[3] * small, + self.0[4] * small, + self.0[5] * small, + self.0[6] * small, + self.0[7] * small, + ]) + } + + /// Reduce back to canonical `Fp128

`. + /// + /// Carry-propagates the 8 × i32 limbs into unsigned u64 limbs, then + /// applies Solinas reduction. + #[inline] + pub fn reduce(self) -> Fp128

{ + let limbs = self.0; + + // Carry-propagate from low to high, accumulating into i64 slots. + // Each i32 limb can be in [-32769*65535, 32769*65535] ≈ ±2^31. + // After propagation, each 16-bit "digit" is in [0, 65535] and we + // may have a signed residual in the top that overflows 128 bits. + let mut carry: i64 = 0; + let mut digits = [0u16; 8]; + for i in 0..8 { + let v = limbs[i] as i64 + carry; + // Arithmetic right-shift to propagate sign correctly + digits[i] = (v & 0xFFFF) as u16; + carry = v >> 16; + } + + // Reassemble into u64 limbs + let lo = digits[0] as u64 + | (digits[1] as u64) << 16 + | (digits[2] as u64) << 32 + | (digits[3] as u64) << 48; + let hi = digits[4] as u64 + | (digits[5] as u64) << 16 + | (digits[6] as u64) << 32 + | (digits[7] as u64) << 48; + + // p = 2^128 - c, so 2^128 ≡ c (mod p). + // value = lo + hi*2^64 + carry*2^128 ≡ lo + hi*2^64 + carry*c (mod p). + let c = Fp128::

::C_LO; + if carry == 0 { + Fp128::

::from_canonical_u128_reduced(lo as u128 | (hi as u128) << 64) + } else if carry > 0 { + Fp128::

::solinas_reduce(&[lo, hi, carry as u64]) + } else { + // carry < 0: value = base - |carry|*c. + let neg_carry = (-carry) as u64; + let sub = neg_carry as u128 * c as u128; + let base = lo as u128 | (hi as u128) << 64; + if base >= sub { + Fp128::

::from_canonical_u128_reduced(base - sub) + } else { + let diff = sub - base; + Fp128::

::from_canonical_u128_reduced(P - diff) + } + } + } +} + +#[cfg(target_arch = "aarch64")] +impl Add for Fp128x8i32 { + type Output = Self; + #[inline] + fn add(self, rhs: Self) -> Self { + unsafe { + use std::arch::aarch64::*; + let a0 = vld1q_s32(self.0.as_ptr()); + let a1 = vld1q_s32(self.0.as_ptr().add(4)); + let b0 = vld1q_s32(rhs.0.as_ptr()); + let b1 = vld1q_s32(rhs.0.as_ptr().add(4)); + let mut out = [0i32; 8]; + vst1q_s32(out.as_mut_ptr(), vaddq_s32(a0, b0)); + vst1q_s32(out.as_mut_ptr().add(4), vaddq_s32(a1, b1)); + Self(out) + } + } +} + +#[cfg(target_arch = "aarch64")] +impl AddAssign for Fp128x8i32 { + #[inline] + fn add_assign(&mut self, rhs: Self) { + *self = *self + rhs; + } +} + +#[cfg(target_arch = "aarch64")] +impl Sub for Fp128x8i32 { + type Output = Self; + #[inline] + fn sub(self, rhs: Self) -> Self { + unsafe { + use std::arch::aarch64::*; + let a0 = vld1q_s32(self.0.as_ptr()); + let a1 = vld1q_s32(self.0.as_ptr().add(4)); + let b0 = vld1q_s32(rhs.0.as_ptr()); + let b1 = vld1q_s32(rhs.0.as_ptr().add(4)); + let mut out = [0i32; 8]; + vst1q_s32(out.as_mut_ptr(), vsubq_s32(a0, b0)); + vst1q_s32(out.as_mut_ptr().add(4), vsubq_s32(a1, b1)); + Self(out) + } + } +} + +#[cfg(target_arch = "aarch64")] +impl SubAssign for Fp128x8i32 { + #[inline] + fn sub_assign(&mut self, rhs: Self) { + *self = *self - rhs; + } +} + +#[cfg(target_arch = "aarch64")] +impl Neg for Fp128x8i32 { + type Output = Self; + #[inline] + fn neg(self) -> Self { + unsafe { + use std::arch::aarch64::*; + let a0 = vld1q_s32(self.0.as_ptr()); + let a1 = vld1q_s32(self.0.as_ptr().add(4)); + let mut out = [0i32; 8]; + vst1q_s32(out.as_mut_ptr(), vnegq_s32(a0)); + vst1q_s32(out.as_mut_ptr().add(4), vnegq_s32(a1)); + Self(out) + } + } +} + +#[cfg(not(target_arch = "aarch64"))] +impl Add for Fp128x8i32 { + type Output = Self; + #[inline] + fn add(self, rhs: Self) -> Self { + Self([ + self.0[0] + rhs.0[0], + self.0[1] + rhs.0[1], + self.0[2] + rhs.0[2], + self.0[3] + rhs.0[3], + self.0[4] + 
rhs.0[4], + self.0[5] + rhs.0[5], + self.0[6] + rhs.0[6], + self.0[7] + rhs.0[7], + ]) + } +} + +#[cfg(not(target_arch = "aarch64"))] +impl AddAssign for Fp128x8i32 { + #[inline] + fn add_assign(&mut self, rhs: Self) { + self.0[0] += rhs.0[0]; + self.0[1] += rhs.0[1]; + self.0[2] += rhs.0[2]; + self.0[3] += rhs.0[3]; + self.0[4] += rhs.0[4]; + self.0[5] += rhs.0[5]; + self.0[6] += rhs.0[6]; + self.0[7] += rhs.0[7]; + } +} + +#[cfg(not(target_arch = "aarch64"))] +impl Sub for Fp128x8i32 { + type Output = Self; + #[inline] + fn sub(self, rhs: Self) -> Self { + Self([ + self.0[0] - rhs.0[0], + self.0[1] - rhs.0[1], + self.0[2] - rhs.0[2], + self.0[3] - rhs.0[3], + self.0[4] - rhs.0[4], + self.0[5] - rhs.0[5], + self.0[6] - rhs.0[6], + self.0[7] - rhs.0[7], + ]) + } +} + +#[cfg(not(target_arch = "aarch64"))] +impl SubAssign for Fp128x8i32 { + #[inline] + fn sub_assign(&mut self, rhs: Self) { + self.0[0] -= rhs.0[0]; + self.0[1] -= rhs.0[1]; + self.0[2] -= rhs.0[2]; + self.0[3] -= rhs.0[3]; + self.0[4] -= rhs.0[4]; + self.0[5] -= rhs.0[5]; + self.0[6] -= rhs.0[6]; + self.0[7] -= rhs.0[7]; + } +} + +#[cfg(not(target_arch = "aarch64"))] +impl Neg for Fp128x8i32 { + type Output = Self; + #[inline] + fn neg(self) -> Self { + Self([ + -self.0[0], -self.0[1], -self.0[2], -self.0[3], -self.0[4], -self.0[5], -self.0[6], + -self.0[7], + ]) + } +} + +impl AdditiveGroup for Fp32x2i32 { + const ZERO: Self = Self([0; 2]); +} + +impl AdditiveGroup for Fp64x4i32 { + const ZERO: Self = Self([0; 4]); +} + +impl AdditiveGroup for Fp128x8i32 { + const ZERO: Self = Self([0; 8]); +} + +/// Accumulator for `Fp64 × u64` products (also used for `Fp64 × Fp64`). +/// +/// Each product is ≤ 128 bits, split into two u64 halves stored as u128 slots. +/// Headroom: 2^64 additions per slot before overflow. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct Fp64ProductAccum(pub [u128; 2]); + +impl Fp64ProductAccum { + /// Reduce accumulated products to a canonical `Fp64

`. + #[inline] + pub fn reduce(self) -> Fp64

{ + let [s0, s1] = self.0; + // s0 = Σ lo_i, s1 = Σ hi_i; value = s0 + s1 * 2^64 + let a = Fp64::

::solinas_reduce(s0); + let b = Fp64::

::solinas_reduce(s1); + let shift = Fp64::

::solinas_reduce(1u128 << 64); + let b_shifted = Fp64::

::solinas_reduce(b.mul_wide_u64(shift.to_limbs())); + a + b_shifted + } +} + +impl From> for Fp64ProductAccum { + #[inline] + fn from(x: Fp64

) -> Self { + Self([x.to_limbs() as u128, 0]) + } +} + +impl AdditiveGroup for Fp64ProductAccum { + const ZERO: Self = Self([0; 2]); +} + +impl Add for Fp64ProductAccum { + type Output = Self; + #[inline] + fn add(self, rhs: Self) -> Self { + Self([self.0[0] + rhs.0[0], self.0[1] + rhs.0[1]]) + } +} +impl AddAssign for Fp64ProductAccum { + #[inline] + fn add_assign(&mut self, rhs: Self) { + self.0[0] += rhs.0[0]; + self.0[1] += rhs.0[1]; + } +} +impl Sub for Fp64ProductAccum { + type Output = Self; + #[inline] + fn sub(self, rhs: Self) -> Self { + Self([ + self.0[0].wrapping_sub(rhs.0[0]), + self.0[1].wrapping_sub(rhs.0[1]), + ]) + } +} +impl SubAssign for Fp64ProductAccum { + #[inline] + fn sub_assign(&mut self, rhs: Self) { + self.0[0] = self.0[0].wrapping_sub(rhs.0[0]); + self.0[1] = self.0[1].wrapping_sub(rhs.0[1]); + } +} +impl Neg for Fp64ProductAccum { + type Output = Self; + #[inline] + fn neg(self) -> Self { + Self([self.0[0].wrapping_neg(), self.0[1].wrapping_neg()]) + } +} + +/// Accumulator for `Fp128 × u64` products. +/// +/// Each `mul_wide_u64` produces 3 u64 limbs; stored as `[u128; 3]`. +/// Headroom: 2^64 additions per slot. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct Fp128MulU64Accum(pub [u128; 3]); + +impl Fp128MulU64Accum { + /// Reduce to canonical `Fp128

`. + #[inline] + pub fn reduce(self) -> Fp128

{ + let [s0, s1, s2] = self.0; + let c0 = s0 >> 64; + let r0 = s0 as u64; + let t1 = s1 + c0; + let r1 = t1 as u64; + let c1 = t1 >> 64; + let t2 = s2 + c1; + let r2 = t2 as u64; + let r3 = (t2 >> 64) as u64; + Fp128::

::solinas_reduce(&[r0, r1, r2, r3]) + } +} + +impl From> for Fp128MulU64Accum { + #[inline] + fn from(x: Fp128

) -> Self { + let [lo, hi] = x.to_limbs(); + Self([lo as u128, hi as u128, 0]) + } +} + +impl AdditiveGroup for Fp128MulU64Accum { + const ZERO: Self = Self([0; 3]); +} + +impl Add for Fp128MulU64Accum { + type Output = Self; + #[inline] + fn add(self, rhs: Self) -> Self { + Self([ + self.0[0] + rhs.0[0], + self.0[1] + rhs.0[1], + self.0[2] + rhs.0[2], + ]) + } +} +impl AddAssign for Fp128MulU64Accum { + #[inline] + fn add_assign(&mut self, rhs: Self) { + self.0[0] += rhs.0[0]; + self.0[1] += rhs.0[1]; + self.0[2] += rhs.0[2]; + } +} +impl Sub for Fp128MulU64Accum { + type Output = Self; + #[inline] + fn sub(self, rhs: Self) -> Self { + Self([ + self.0[0].wrapping_sub(rhs.0[0]), + self.0[1].wrapping_sub(rhs.0[1]), + self.0[2].wrapping_sub(rhs.0[2]), + ]) + } +} +impl SubAssign for Fp128MulU64Accum { + #[inline] + fn sub_assign(&mut self, rhs: Self) { + self.0[0] = self.0[0].wrapping_sub(rhs.0[0]); + self.0[1] = self.0[1].wrapping_sub(rhs.0[1]); + self.0[2] = self.0[2].wrapping_sub(rhs.0[2]); + } +} +impl Neg for Fp128MulU64Accum { + type Output = Self; + #[inline] + fn neg(self) -> Self { + Self([ + self.0[0].wrapping_neg(), + self.0[1].wrapping_neg(), + self.0[2].wrapping_neg(), + ]) + } +} + +/// Accumulator for `Fp128 × Fp128` products. +/// +/// Each `mul_wide` produces 4 u64 limbs; stored as `[u128; 4]`. +/// Headroom: 2^64 additions per slot. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct Fp128ProductAccum(pub [u128; 4]); + +impl Fp128ProductAccum { + /// Reduce to canonical `Fp128

`. + #[inline] + pub fn reduce(self) -> Fp128

{ + let [s0, s1, s2, s3] = self.0; + let c0 = s0 >> 64; + let r0 = s0 as u64; + let t1 = s1 + c0; + let r1 = t1 as u64; + let c1 = t1 >> 64; + let t2 = s2 + c1; + let r2 = t2 as u64; + let c2 = t2 >> 64; + let t3 = s3 + c2; + let r3 = t3 as u64; + let r4 = (t3 >> 64) as u64; + Fp128::

::solinas_reduce(&[r0, r1, r2, r3, r4]) + } +} + +impl From> for Fp128ProductAccum { + #[inline] + fn from(x: Fp128

) -> Self { + let [lo, hi] = x.to_limbs(); + Self([lo as u128, hi as u128, 0, 0]) + } +} + +impl AdditiveGroup for Fp128ProductAccum { + const ZERO: Self = Self([0; 4]); +} + +impl Add for Fp128ProductAccum { + type Output = Self; + #[inline] + fn add(self, rhs: Self) -> Self { + Self([ + self.0[0] + rhs.0[0], + self.0[1] + rhs.0[1], + self.0[2] + rhs.0[2], + self.0[3] + rhs.0[3], + ]) + } +} +impl AddAssign for Fp128ProductAccum { + #[inline] + fn add_assign(&mut self, rhs: Self) { + self.0[0] += rhs.0[0]; + self.0[1] += rhs.0[1]; + self.0[2] += rhs.0[2]; + self.0[3] += rhs.0[3]; + } +} +impl Sub for Fp128ProductAccum { + type Output = Self; + #[inline] + fn sub(self, rhs: Self) -> Self { + Self([ + self.0[0].wrapping_sub(rhs.0[0]), + self.0[1].wrapping_sub(rhs.0[1]), + self.0[2].wrapping_sub(rhs.0[2]), + self.0[3].wrapping_sub(rhs.0[3]), + ]) + } +} +impl SubAssign for Fp128ProductAccum { + #[inline] + fn sub_assign(&mut self, rhs: Self) { + self.0[0] = self.0[0].wrapping_sub(rhs.0[0]); + self.0[1] = self.0[1].wrapping_sub(rhs.0[1]); + self.0[2] = self.0[2].wrapping_sub(rhs.0[2]); + self.0[3] = self.0[3].wrapping_sub(rhs.0[3]); + } +} +impl Neg for Fp128ProductAccum { + type Output = Self; + #[inline] + fn neg(self) -> Self { + Self([ + self.0[0].wrapping_neg(), + self.0[1].wrapping_neg(), + self.0[2].wrapping_neg(), + self.0[3].wrapping_neg(), + ]) + } +} + +/// Pair accumulator for extension fields. +/// +/// Wraps two base-field accumulators `(c0, c1)` component-wise. 
+#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct AccumPair(pub A, pub A); + +impl AdditiveGroup for AccumPair { + const ZERO: Self = Self(A::ZERO, A::ZERO); +} + +impl Add for AccumPair { + type Output = Self; + #[inline] + fn add(self, rhs: Self) -> Self { + Self(self.0 + rhs.0, self.1 + rhs.1) + } +} +impl AddAssign for AccumPair { + #[inline] + fn add_assign(&mut self, rhs: Self) { + self.0 += rhs.0; + self.1 += rhs.1; + } +} +impl Sub for AccumPair { + type Output = Self; + #[inline] + fn sub(self, rhs: Self) -> Self { + Self(self.0 - rhs.0, self.1 - rhs.1) + } +} +impl SubAssign for AccumPair { + #[inline] + fn sub_assign(&mut self, rhs: Self) { + self.0 -= rhs.0; + self.1 -= rhs.1; + } +} +impl Neg for AccumPair { + type Output = Self; + #[inline] + fn neg(self) -> Self { + Self(-self.0, -self.1) + } +} + +/// Reduce a wide unreduced accumulator back to a canonical field element. +pub trait ReduceTo { + /// Carry-propagate and reduce to a canonical field element. + fn reduce(self) -> F; +} + +impl ReduceTo> for Fp32x2i32 { + #[inline] + fn reduce(self) -> Fp32

{ + Fp32x2i32::reduce::

(self) + } +} + +impl ReduceTo> for Fp64x4i32 { + #[inline] + fn reduce(self) -> Fp64

{ + Fp64x4i32::reduce::

(self) + } +} + +impl ReduceTo> for Fp128x8i32 { + #[inline] + fn reduce(self) -> Fp128

{ + Fp128x8i32::reduce::

(self) + } +} + +/// Multi-level unreduced multiplication hierarchy. +/// +/// Provides `field × u64` and `field × field` widening multiplies that return +/// accumulator types supporting carry-free addition. Reduction back to a +/// canonical field element happens once after accumulation. +pub trait HasUnreducedOps: FieldCore { + /// Accumulator for `self × u64` products (narrower than full product). + type MulU64Accum: AdditiveGroup; + /// Accumulator for `self × self` products. + type ProductAccum: AdditiveGroup; + + /// Widening `self × small` with no reduction. + fn mul_u64_unreduced(self, small: u64) -> Self::MulU64Accum; + /// Widening `self × other` with no reduction. + fn mul_to_product_accum(self, other: Self) -> Self::ProductAccum; + + /// Reduce a narrow-mul accumulator to a canonical field element. + fn reduce_mul_u64_accum(accum: Self::MulU64Accum) -> Self; + /// Reduce a full-product accumulator to a canonical field element. + fn reduce_product_accum(accum: Self::ProductAccum) -> Self; +} + +impl HasUnreducedOps for Fp64

{ + type MulU64Accum = Fp64ProductAccum; + type ProductAccum = Fp64ProductAccum; + + #[inline] + fn mul_u64_unreduced(self, small: u64) -> Fp64ProductAccum { + let wide = self.mul_wide_u64(small); + Fp64ProductAccum([wide & u64::MAX as u128, wide >> 64]) + } + + #[inline] + fn mul_to_product_accum(self, other: Self) -> Fp64ProductAccum { + let wide = self.mul_wide(other); + Fp64ProductAccum([wide & u64::MAX as u128, wide >> 64]) + } + + #[inline] + fn reduce_mul_u64_accum(accum: Fp64ProductAccum) -> Self { + accum.reduce::

() + } + + #[inline] + fn reduce_product_accum(accum: Fp64ProductAccum) -> Self { + accum.reduce::

() + } +} + +impl HasUnreducedOps for Fp128

{ + type MulU64Accum = Fp128MulU64Accum; + type ProductAccum = Fp128ProductAccum; + + #[inline] + fn mul_u64_unreduced(self, small: u64) -> Fp128MulU64Accum { + let [lo, mid, hi] = self.mul_wide_u64(small); + Fp128MulU64Accum([lo as u128, mid as u128, hi as u128]) + } + + #[inline] + fn mul_to_product_accum(self, other: Self) -> Fp128ProductAccum { + let [r0, r1, r2, r3] = self.mul_wide(other); + Fp128ProductAccum([r0 as u128, r1 as u128, r2 as u128, r3 as u128]) + } + + #[inline] + fn reduce_mul_u64_accum(accum: Fp128MulU64Accum) -> Self { + accum.reduce::

() + } + + #[inline] + fn reduce_product_accum(accum: Fp128ProductAccum) -> Self { + accum.reduce::

() + } +} + +/// Element-wise scaling of a wide accumulator by a small signed integer. +pub trait ScaleI32 { + /// Scale each element by `small`. + fn scale_i32(self, small: i32) -> Self; +} + +impl ScaleI32 for Fp32x2i32 { + #[inline] + fn scale_i32(self, small: i32) -> Self { + self.scale_i32(small) + } +} + +impl ScaleI32 for Fp64x4i32 { + #[inline] + fn scale_i32(self, small: i32) -> Self { + self.scale_i32(small) + } +} + +impl ScaleI32 for Fp128x8i32 { + #[inline] + fn scale_i32(self, small: i32) -> Self { + self.scale_i32(small) + } +} + +/// Associates a field type with its wide unreduced accumulator. +pub trait HasWide: FieldCore { + /// The wide accumulator type. + type Wide: AdditiveGroup + From + ReduceTo + ScaleI32; + + /// Convert `self` to wide form and scale every limb by `small`. + /// + /// Equivalent to `Self::Wide::from(self).scale_i32(small)` but avoids + /// the trait-method ambiguity at call sites. + #[inline] + fn mul_small_to_wide(self, small: i32) -> Self::Wide { + Self::Wide::from(self).scale_i32(small) + } +} + +impl HasWide for Fp32

{ + type Wide = Fp32x2i32; +} + +impl HasWide for Fp64

{ + type Wide = Fp64x4i32; +} + +impl HasWide for Fp128

{ + type Wide = Fp128x8i32; +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::algebra::fields::{Pow2Offset24Field, Pow2Offset40Field, Prime128M8M4M1M0}; + use crate::{FieldCore, FieldSampling, FromSmallInt}; + use rand::rngs::StdRng; + use rand::SeedableRng; + use rand_core::RngCore; + + type F128 = Prime128M8M4M1M0; + type F32 = Pow2Offset24Field; + type F64 = Pow2Offset40Field; + + const P128: u128 = 0xfffffffffffffffffffffffffffffeed; + const P32: u32 = (1 << 24) - 3; + const P64: u64 = (1 << 40) - 195; + + #[test] + fn fp128_roundtrip() { + let mut rng = StdRng::seed_from_u64(0xdead_1234); + for _ in 0..1000 { + let a: F128 = FieldSampling::sample(&mut rng); + let wide = Fp128x8i32::from(a); + let back = wide.reduce::(); + assert_eq!(a, back, "roundtrip failed for {a:?}"); + } + } + + #[test] + fn fp128_accumulate_matches_scalar() { + let mut rng = StdRng::seed_from_u64(0xbeef_cafe_4321); + let n = 1000; + let vals: Vec = (0..n).map(|_| FieldSampling::sample(&mut rng)).collect(); + + let scalar_sum = vals.iter().fold(F128::zero(), |acc, &x| acc + x); + + let wide_sum = vals + .iter() + .fold(Fp128x8i32::zero(), |acc, &x| acc + Fp128x8i32::from(x)); + let reduced = wide_sum.reduce::(); + + assert_eq!(scalar_sum, reduced); + } + + #[test] + fn fp128_add_sub_neg_match_scalar() { + let mut rng = StdRng::seed_from_u64(0x1122_3344_5566); + for _ in 0..500 { + let a: F128 = FieldSampling::sample(&mut rng); + let b: F128 = FieldSampling::sample(&mut rng); + + let wa = Fp128x8i32::from(a); + let wb = Fp128x8i32::from(b); + + assert_eq!((wa + wb).reduce::(), a + b); + assert_eq!((wa - wb).reduce::(), a - b); + assert_eq!((-wa).reduce::(), -a); + } + } + + #[test] + fn fp128_mixed_add_sub_stress() { + let mut rng = StdRng::seed_from_u64(0xaaaa_bbbb_cccc); + let n = 500; + let vals: Vec = (0..n).map(|_| FieldSampling::sample(&mut rng)).collect(); + + let mut scalar = F128::zero(); + let mut wide = Fp128x8i32::zero(); + for (i, &v) in vals.iter().enumerate() { + 
let wv = Fp128x8i32::from(v); + if i % 3 == 0 { + scalar -= v; + wide -= wv; + } else { + scalar += v; + wide += wv; + } + } + assert_eq!(wide.reduce::(), scalar); + } + + #[test] + fn fp32_roundtrip() { + let mut rng = StdRng::seed_from_u64(0x3232_3232); + for _ in 0..1000 { + let a: F32 = FieldSampling::sample(&mut rng); + let wide = Fp32x2i32::from(a); + let back = wide.reduce::(); + assert_eq!(a, back); + } + } + + #[test] + fn fp32_accumulate_matches_scalar() { + let mut rng = StdRng::seed_from_u64(0x3232_abcd); + let n = 1000; + let vals: Vec = (0..n).map(|_| FieldSampling::sample(&mut rng)).collect(); + + let scalar_sum = vals.iter().fold(F32::zero(), |acc, &x| acc + x); + let wide_sum = vals + .iter() + .fold(Fp32x2i32::zero(), |acc, &x| acc + Fp32x2i32::from(x)); + assert_eq!(wide_sum.reduce::(), scalar_sum); + } + + #[test] + fn fp64_roundtrip() { + let mut rng = StdRng::seed_from_u64(0x6464_6464); + for _ in 0..1000 { + let a: F64 = FieldSampling::sample(&mut rng); + let wide = Fp64x4i32::from(a); + let back = wide.reduce::(); + assert_eq!(a, back); + } + } + + #[test] + fn fp64_accumulate_matches_scalar() { + let mut rng = StdRng::seed_from_u64(0x6464_beef); + let n = 1000; + let vals: Vec = (0..n).map(|_| FieldSampling::sample(&mut rng)).collect(); + + let scalar_sum = vals.iter().fold(F64::zero(), |acc, &x| acc + x); + let wide_sum = vals + .iter() + .fold(Fp64x4i32::zero(), |acc, &x| acc + Fp64x4i32::from(x)); + assert_eq!(wide_sum.reduce::(), scalar_sum); + } + + #[test] + fn fp64_product_accum_matches_scalar() { + let mut rng = StdRng::seed_from_u64(0x6464_4444); + let n = 500; + let a_vals: Vec = (0..n).map(|_| FieldSampling::sample(&mut rng)).collect(); + let b_vals: Vec = (0..n).map(|_| FieldSampling::sample(&mut rng)).collect(); + + let scalar_sum: F64 = a_vals + .iter() + .zip(b_vals.iter()) + .fold(F64::zero(), |acc, (&a, &b)| acc + a * b); + + let accum_sum = a_vals + .iter() + .zip(b_vals.iter()) + .fold(Fp64ProductAccum::ZERO, |acc, (&a, 
&b)| { + acc + a.mul_to_product_accum(b) + }); + assert_eq!(F64::reduce_product_accum(accum_sum), scalar_sum); + } + + #[test] + fn fp64_mul_u64_accum_matches_scalar() { + let mut rng = StdRng::seed_from_u64(0x6464_5555); + let n = 500; + let a_vals: Vec = (0..n).map(|_| FieldSampling::sample(&mut rng)).collect(); + let b_vals: Vec = (0..n).map(|_| rng.next_u64() >> 32).collect(); + + let scalar_sum: F64 = a_vals + .iter() + .zip(b_vals.iter()) + .fold(F64::zero(), |acc, (&a, &b)| acc + a * F64::from_u64(b)); + + let accum_sum = a_vals + .iter() + .zip(b_vals.iter()) + .fold(Fp64ProductAccum::ZERO, |acc, (&a, &b)| { + acc + a.mul_u64_unreduced(b) + }); + assert_eq!(F64::reduce_mul_u64_accum(accum_sum), scalar_sum); + } + + #[test] + fn fp128_product_accum_matches_scalar() { + let mut rng = StdRng::seed_from_u64(0x0128_6666); + let n = 500; + let a_vals: Vec = (0..n).map(|_| FieldSampling::sample(&mut rng)).collect(); + let b_vals: Vec = (0..n).map(|_| FieldSampling::sample(&mut rng)).collect(); + + let scalar_sum: F128 = a_vals + .iter() + .zip(b_vals.iter()) + .fold(F128::zero(), |acc, (&a, &b)| acc + a * b); + + let accum_sum = a_vals + .iter() + .zip(b_vals.iter()) + .fold(Fp128ProductAccum::ZERO, |acc, (&a, &b)| { + acc + a.mul_to_product_accum(b) + }); + assert_eq!(F128::reduce_product_accum(accum_sum), scalar_sum); + } + + #[test] + fn fp128_mul_u64_accum_matches_scalar() { + let mut rng = StdRng::seed_from_u64(0x0128_7777); + let n = 500; + let a_vals: Vec = (0..n).map(|_| FieldSampling::sample(&mut rng)).collect(); + let b_vals: Vec = (0..n).map(|_| rng.next_u64()).collect(); + + let scalar_sum: F128 = a_vals + .iter() + .zip(b_vals.iter()) + .fold(F128::zero(), |acc, (&a, &b)| acc + a * F128::from_u64(b)); + + let accum_sum = a_vals + .iter() + .zip(b_vals.iter()) + .fold(Fp128MulU64Accum::ZERO, |acc, (&a, &b)| { + acc + a.mul_u64_unreduced(b) + }); + assert_eq!(F128::reduce_mul_u64_accum(accum_sum), scalar_sum); + } + + #[test] + fn 
fp128_product_accum_sub_neg() { + let mut rng = StdRng::seed_from_u64(0x0128_8888); + let n = 500; + let a_vals: Vec = (0..n).map(|_| FieldSampling::sample(&mut rng)).collect(); + let b_vals: Vec = (0..n).map(|_| FieldSampling::sample(&mut rng)).collect(); + + let mut scalar_sum = F128::zero(); + let mut accum_pos = Fp128ProductAccum::ZERO; + let mut accum_neg = Fp128ProductAccum::ZERO; + for (i, (&a, &b)) in a_vals.iter().zip(b_vals.iter()).enumerate() { + let prod = a.mul_to_product_accum(b); + if i % 2 == 0 { + scalar_sum += a * b; + accum_pos += prod; + } else { + scalar_sum -= a * b; + accum_neg += prod; + } + } + let result = F128::reduce_product_accum(accum_pos) - F128::reduce_product_accum(accum_neg); + assert_eq!(result, scalar_sum); + } +} diff --git a/src/algebra/mod.rs b/src/algebra/mod.rs new file mode 100644 index 00000000..1ee911f4 --- /dev/null +++ b/src/algebra/mod.rs @@ -0,0 +1,33 @@ +//! Concrete algebra backends and arithmetic building blocks. +//! +//! This module includes: +//! - Generic prime fields and extensions (`fields`) +//! - Module and polynomial containers (`module`, `poly`) +//! - Low-level NTT and CRT+NTT arithmetic scaffolding (`ntt`) + +pub mod backend; +pub mod fields; +pub mod module; +pub mod ntt; +pub mod poly; +pub mod ring; + +// Flat re-exports for convenience. 
+pub use backend::{CrtReconstruct, NttPrimeOps, NttTransform, RingBackend, ScalarBackend}; +pub use fields::{ + is_pow2_offset, pow2_offset, pseudo_mersenne_modulus, ExtField, Fp128, Fp128Packing, Fp2, + Fp2Config, Fp32, Fp32Packing, Fp4, Fp4Config, Fp64, Fp64Packing, HasPacking, LiftBase, + NoPacking, PackedField, PackedValue, Pow2Offset128Field, Pow2Offset24Field, Pow2Offset30Field, + Pow2Offset31Field, Pow2Offset32Field, Pow2Offset40Field, Pow2Offset48Field, Pow2Offset56Field, + Pow2Offset64Field, Pow2OffsetPrimeSpec, Prime128M13M4P0, Prime128M37P3P0, Prime128M52M3P0, + Prime128M54P4P0, Prime128M8M4M1M0, Prime128Offset5823, POW2_OFFSET_IMPLEMENTED_MAX_BITS, + POW2_OFFSET_MAX, POW2_OFFSET_PRIMES, POW2_OFFSET_TABLE, +}; +pub use module::VectorModule; +pub use ntt::tables; +pub use ntt::{GarnerData, LimbQ, MontCoeff, NttPrime, PrimeWidth, RADIX_BITS}; +pub use ring::{ + CenteredMontLut, CrtNttConvertibleField, CrtNttParamSet, CyclotomicCrtNtt, CyclotomicRing, + DigitMontLut, PackedPartialSplitEval32, PackedPartialSplitNtt32, PartialSplitEval32, + PartialSplitNtt32, SparseChallenge, SparseChallengeConfig, +}; diff --git a/src/algebra/module.rs b/src/algebra/module.rs new file mode 100644 index 00000000..bf41c567 --- /dev/null +++ b/src/algebra/module.rs @@ -0,0 +1,194 @@ +//! Simple module implementations. + +use super::fields::{Fp128, Fp32, Fp64}; +use crate::primitives::serialization::{ + Compress, HachiDeserialize, HachiSerialize, SerializationError, Valid, Validate, +}; +use crate::{CanonicalField, FieldCore, FieldSampling, Module}; +use rand_core::RngCore; +use std::io::{Read, Write}; +use std::ops::{Add, Mul, Neg, Sub}; + +/// Fixed-length vector module over a scalar type `F`. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct VectorModule(pub [F; N]); + +impl VectorModule { + /// Construct the zero vector. 
+ #[inline] + pub fn zero_vec() -> Self { + Self([F::zero(); N]) + } +} + +impl Add for VectorModule { + type Output = Self; + fn add(self, rhs: Self) -> Self::Output { + let mut out = self.0; + for (dst, src) in out.iter_mut().zip(rhs.0.iter()) { + *dst += *src; + } + Self(out) + } +} + +impl Sub for VectorModule { + type Output = Self; + fn sub(self, rhs: Self) -> Self::Output { + let mut out = self.0; + for (dst, src) in out.iter_mut().zip(rhs.0.iter()) { + *dst -= *src; + } + Self(out) + } +} + +impl Neg for VectorModule { + type Output = Self; + fn neg(self) -> Self::Output { + let mut out = self.0; + for coeff in &mut out { + *coeff = -*coeff; + } + Self(out) + } +} + +impl<'a, F: FieldCore, const N: usize> Add<&'a Self> for VectorModule { + type Output = Self; + fn add(self, rhs: &'a Self) -> Self::Output { + self + *rhs + } +} + +impl<'a, F: FieldCore, const N: usize> Sub<&'a Self> for VectorModule { + type Output = Self; + fn sub(self, rhs: &'a Self) -> Self::Output { + self - *rhs + } +} + +impl Valid for VectorModule { + fn check(&self) -> Result<(), SerializationError> { + for x in self.0.iter() { + x.check()?; + } + Ok(()) + } +} + +impl HachiSerialize for VectorModule { + fn serialize_with_mode( + &self, + mut writer: W, + compress: Compress, + ) -> Result<(), SerializationError> { + for x in self.0.iter() { + x.serialize_with_mode(&mut writer, compress)?; + } + Ok(()) + } + + fn serialized_size(&self, compress: Compress) -> usize { + self.0.iter().map(|x| x.serialized_size(compress)).sum() + } +} + +impl HachiDeserialize for VectorModule { + fn deserialize_with_mode( + mut reader: R, + compress: Compress, + validate: Validate, + ) -> Result { + let mut arr = [F::zero(); N]; + for coeff in &mut arr { + *coeff = F::deserialize_with_mode(&mut reader, compress, validate)?; + } + let out = Self(arr); + if matches!(validate, Validate::Yes) { + out.check()?; + } + Ok(out) + } +} + +impl Module for VectorModule +where + F: FieldCore + + CanonicalField + + 
FieldSampling + + Valid + + Mul, Output = VectorModule> + + for<'a> Mul<&'a VectorModule, Output = VectorModule>, +{ + type Scalar = F; + + fn zero() -> Self { + Self::zero_vec() + } + + fn scale(&self, k: &Self::Scalar) -> Self { + // Delegate to Scalar * VectorModule to satisfy the Module trait’s scalar bounds. + *k * *self + } + + fn random(rng: &mut R) -> Self { + Self(std::array::from_fn(|_| F::sample(rng))) + } +} + +// Scalar * VectorModule impls for our local prime field types. + +impl Mul, N>> for Fp32

{ + type Output = VectorModule, N>; + fn mul(self, rhs: VectorModule, N>) -> Self::Output { + let mut out = rhs.0; + for coeff in &mut out { + *coeff = self * *coeff; + } + VectorModule(out) + } +} + +impl<'a, const P: u32, const N: usize> Mul<&'a VectorModule, N>> for Fp32

{ + type Output = VectorModule, N>; + fn mul(self, rhs: &'a VectorModule, N>) -> Self::Output { + self * *rhs + } +} + +impl Mul, N>> for Fp64

{ + type Output = VectorModule, N>; + fn mul(self, rhs: VectorModule, N>) -> Self::Output { + let mut out = rhs.0; + for coeff in &mut out { + *coeff = self * *coeff; + } + VectorModule(out) + } +} + +impl<'a, const P: u64, const N: usize> Mul<&'a VectorModule, N>> for Fp64

{ + type Output = VectorModule, N>; + fn mul(self, rhs: &'a VectorModule, N>) -> Self::Output { + self * *rhs + } +} + +impl Mul, N>> for Fp128

{ + type Output = VectorModule, N>; + fn mul(self, rhs: VectorModule, N>) -> Self::Output { + let mut out = rhs.0; + for coeff in &mut out { + *coeff = self * *coeff; + } + VectorModule(out) + } +} + +impl<'a, const P: u128, const N: usize> Mul<&'a VectorModule, N>> for Fp128

{ + type Output = VectorModule, N>; + fn mul(self, rhs: &'a VectorModule, N>) -> Self::Output { + self * *rhs + } +} diff --git a/src/algebra/ntt/butterfly.rs b/src/algebra/ntt/butterfly.rs new file mode 100644 index 00000000..fb942750 --- /dev/null +++ b/src/algebra/ntt/butterfly.rs @@ -0,0 +1,382 @@ +//! NTT butterfly transforms for negacyclic rings `Z_p[X]/(X^D + 1)`. +//! +//! Implements a negacyclic NTT via the standard **twist + cyclic NTT** method. +//! +//! Let `psi` be a primitive `2D`-th root of unity (`psi^D = -1 mod p`) and +//! `omega = psi^2`, a primitive `D`-th root of unity. For polynomials modulo +//! `X^D + 1`, we: +//! - pre-twist coefficients by `psi^i` +//! - run a cyclic size-`D` NTT using `omega` +//! - inverse-cyclic NTT using `omega^{-1}` +//! - post-untwist by `psi^{-i}` + +use super::prime::{MontCoeff, NttPrime, PrimeWidth}; + +/// Precomputed twiddle factors for a specific prime and degree `D`. +/// +/// `D` must be a power of two. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct NttTwiddles { + /// Stage roots for iterative forward cyclic NTT in Montgomery form. + pub(crate) fwd_wlen: [MontCoeff; D], + /// Stage roots for iterative inverse cyclic NTT in Montgomery form. + pub(crate) inv_wlen: [MontCoeff; D], + /// Number of active stages in the twiddle arrays (`log2(D)`). + pub(crate) num_stages: usize, + /// Twist factors `psi^i` for negacyclic embedding, in Montgomery form. + pub(crate) psi_pows: [MontCoeff; D], + /// Untwist factors `psi^{-i}`, in Montgomery form. + pub(crate) psi_inv_pows: [MontCoeff; D], + /// `D^{-1} mod p` in Montgomery form, used for inverse NTT final scaling. + pub(crate) d_inv: MontCoeff, + /// Fused `D^{-1} * psi^{-i}` for each index, in Montgomery form. + pub(crate) d_inv_psi_inv: [MontCoeff; D], + /// Per-position forward twiddles, packed across stages. + /// Stage s (with butterfly half-length 2^s) occupies `[2^s - 1 .. 2^(s+1) - 2]`. 
+ /// Breaks the serial `w = mul(w, wlen)` dependency chain in butterfly loops. + pub(crate) fwd_twiddles: [MontCoeff; D], + /// Per-position inverse twiddles, same layout as `fwd_twiddles`. + pub(crate) inv_twiddles: [MontCoeff; D], +} + +impl NttTwiddles { + /// Compute twiddle factors for the given prime. + /// + /// Finds a primitive `2D`-th root `psi` and derives `omega = psi^2`. + /// Fills cyclic forward/inverse twiddles for `omega` and twist/untwist + /// tables for `psi`. All values are stored in Montgomery form. + /// + /// # Panics + /// + /// Panics if `D` is not a power of two, or if `2D` does not divide `p - 1`. + pub fn compute(prime: NttPrime) -> Self { + assert!(D.is_power_of_two(), "D must be a power of two"); + let p = prime.p.to_i64(); + assert!( + (p - 1) % (2 * D as i64) == 0, + "2D must divide p - 1 for NTT roots to exist" + ); + + let psi = find_primitive_root_2d(p, D); + let omega = (psi * psi) % p; + let omega_inv = pow_mod(omega, p - 2, p); + + let psi_inv = pow_mod(psi, p - 2, p); + let mut psi_pows = [MontCoeff::from_raw(W::default()); D]; + let mut psi_inv_pows = [MontCoeff::from_raw(W::default()); D]; + let mut cur = 1i64; + let mut cur_inv = 1i64; + for i in 0..D { + psi_pows[i] = prime.from_canonical(W::from_i64(cur)); + psi_inv_pows[i] = prime.from_canonical(W::from_i64(cur_inv)); + cur = (cur * psi) % p; + cur_inv = (cur_inv * psi_inv) % p; + } + + let mut fwd_wlen = [MontCoeff::from_raw(W::default()); D]; + let mut inv_wlen = [MontCoeff::from_raw(W::default()); D]; + let mut len = 1usize; + let mut stage = 0usize; + while len < D { + let exp = (D / (2 * len)) as i64; + fwd_wlen[stage] = prime.from_canonical(W::from_i64(pow_mod(omega, exp, p))); + inv_wlen[stage] = prime.from_canonical(W::from_i64(pow_mod(omega_inv, exp, p))); + len *= 2; + stage += 1; + } + + let d_inv_canonical = pow_mod(D as i64, p - 2, p); + let d_inv = prime.from_canonical(W::from_i64(d_inv_canonical)); + + let mut d_inv_psi_inv = 
[MontCoeff::from_raw(W::default()); D]; + for i in 0..D { + d_inv_psi_inv[i] = prime.mul(d_inv, psi_inv_pows[i]); + } + + let num_stages = stage; + let mut fwd_twiddles = [MontCoeff::from_raw(W::default()); D]; + let mut inv_twiddles = [MontCoeff::from_raw(W::default()); D]; + let one = prime.from_canonical(W::from_i64(1)); + for s in 0..num_stages { + let len = 1usize << s; + let base = len - 1; + let mut w_fwd = one; + let mut w_inv = one; + for j in 0..len { + fwd_twiddles[base + j] = w_fwd; + inv_twiddles[base + j] = w_inv; + w_fwd = prime.mul(w_fwd, fwd_wlen[s]); + w_inv = prime.mul(w_inv, inv_wlen[s]); + } + } + + Self { + fwd_wlen, + inv_wlen, + num_stages, + psi_pows, + psi_inv_pows, + d_inv, + d_inv_psi_inv, + fwd_twiddles, + inv_twiddles, + } + } +} + +/// Forward negacyclic NTT (twist + cyclic Gentleman-Sande DIF). +/// +/// Transforms `D` coefficients in-place from coefficient form to NTT +/// evaluation form. Both outputs of each butterfly are range-reduced +/// to prevent overflow. 
+pub fn forward_ntt( + a: &mut [MontCoeff; D], + prime: NttPrime, + tw: &NttTwiddles, +) { + #[cfg(target_arch = "aarch64")] + if super::neon::use_neon_ntt() { + if std::mem::size_of::() == std::mem::size_of::() { + unsafe { + super::neon::forward_ntt_i32( + &mut *(a as *mut _ as *mut [MontCoeff; D]), + *(&prime as *const _ as *const NttPrime), + &*(tw as *const _ as *const NttTwiddles), + ); + } + return; + } + if std::mem::size_of::() == std::mem::size_of::() { + unsafe { + super::neon::forward_ntt_i16( + &mut *(a as *mut _ as *mut [MontCoeff; D]), + *(&prime as *const _ as *const NttPrime), + &*(tw as *const _ as *const NttTwiddles), + ); + } + return; + } + } + + for (ai, psi) in a.iter_mut().zip(tw.psi_pows.iter()) { + *ai = prime.mul(*ai, *psi); + } + + let mut len = D / 2; + while len > 0 { + let twiddle_base = len - 1; + let mut start = 0usize; + while start < D { + for j in 0..len { + let w = tw.fwd_twiddles[twiddle_base + j]; + let u = a[start + j]; + let v = a[start + j + len]; + let sum = u.raw().wrapping_add(v.raw()); + let diff = u.raw().wrapping_sub(v.raw()); + a[start + j] = prime.reduce_range(MontCoeff::from_raw(sum)); + a[start + j + len] = prime.mul(MontCoeff::from_raw(diff), w); + } + start += 2 * len; + } + len /= 2; + } + + prime.reduce_range_in_place(a); +} + +/// Inverse negacyclic NTT (cyclic Cooley-Tukey DIT + untwist). +/// +/// Transforms `D` evaluations in-place back to coefficient form. +/// Includes the final `D^{-1}` scaling. 
+pub fn inverse_ntt( + a: &mut [MontCoeff; D], + prime: NttPrime, + tw: &NttTwiddles, +) { + #[cfg(target_arch = "aarch64")] + if super::neon::use_neon_ntt() { + if std::mem::size_of::() == std::mem::size_of::() { + unsafe { + super::neon::inverse_ntt_i32( + &mut *(a as *mut _ as *mut [MontCoeff; D]), + *(&prime as *const _ as *const NttPrime), + &*(tw as *const _ as *const NttTwiddles), + ); + } + return; + } + if std::mem::size_of::() == std::mem::size_of::() { + unsafe { + super::neon::inverse_ntt_i16( + &mut *(a as *mut _ as *mut [MontCoeff; D]), + *(&prime as *const _ as *const NttPrime), + &*(tw as *const _ as *const NttTwiddles), + ); + } + return; + } + } + + let mut len = 1usize; + while len < D { + let twiddle_base = len - 1; + let mut start = 0usize; + while start < D { + for j in 0..len { + let w = tw.inv_twiddles[twiddle_base + j]; + let u = a[start + j]; + let v = prime.mul(a[start + j + len], w); + let sum = u.raw().wrapping_add(v.raw()); + let diff = u.raw().wrapping_sub(v.raw()); + a[start + j] = prime.reduce_range(MontCoeff::from_raw(sum)); + a[start + j + len] = prime.reduce_range(MontCoeff::from_raw(diff)); + } + start += 2 * len; + } + len *= 2; + } + + for (ai, fused) in a.iter_mut().zip(tw.d_inv_psi_inv.iter()) { + *ai = prime.mul(*ai, *fused); + } +} + +/// Forward cyclic NTT (Gentleman-Sande DIF, **no** negacyclic twist). +/// +/// Evaluates a polynomial at the D-th roots of *unity* (roots of X^D - 1) +/// rather than X^D + 1. Used with `inverse_ntt_cyclic` to compute unreduced +/// polynomial products via CRT over (X^D - 1)(X^D + 1). 
+pub fn forward_ntt_cyclic( + a: &mut [MontCoeff; D], + prime: NttPrime, + tw: &NttTwiddles, +) { + #[cfg(target_arch = "aarch64")] + if super::neon::use_neon_ntt() { + if std::mem::size_of::() == std::mem::size_of::() { + unsafe { + super::neon::forward_ntt_cyclic_i32( + &mut *(a as *mut _ as *mut [MontCoeff; D]), + *(&prime as *const _ as *const NttPrime), + &*(tw as *const _ as *const NttTwiddles), + ); + } + return; + } + if std::mem::size_of::() == std::mem::size_of::() { + unsafe { + super::neon::forward_ntt_cyclic_i16( + &mut *(a as *mut _ as *mut [MontCoeff; D]), + *(&prime as *const _ as *const NttPrime), + &*(tw as *const _ as *const NttTwiddles), + ); + } + return; + } + } + + let mut len = D / 2; + while len > 0 { + let twiddle_base = len - 1; + let mut start = 0usize; + while start < D { + for j in 0..len { + let w = tw.fwd_twiddles[twiddle_base + j]; + let u = a[start + j]; + let v = a[start + j + len]; + let sum = u.raw().wrapping_add(v.raw()); + let diff = u.raw().wrapping_sub(v.raw()); + a[start + j] = prime.reduce_range(MontCoeff::from_raw(sum)); + a[start + j + len] = prime.mul(MontCoeff::from_raw(diff), w); + } + start += 2 * len; + } + len /= 2; + } + prime.reduce_range_in_place(a); +} + +/// Inverse cyclic NTT (Cooley-Tukey DIT, **no** negacyclic untwist). +/// +/// Recovers coefficients of a polynomial from evaluations at D-th roots of unity. +/// Includes the `D^{-1}` scaling factor. 
+pub fn inverse_ntt_cyclic( + a: &mut [MontCoeff; D], + prime: NttPrime, + tw: &NttTwiddles, +) { + #[cfg(target_arch = "aarch64")] + if super::neon::use_neon_ntt() { + if std::mem::size_of::() == std::mem::size_of::() { + unsafe { + super::neon::inverse_ntt_cyclic_i32( + &mut *(a as *mut _ as *mut [MontCoeff; D]), + *(&prime as *const _ as *const NttPrime), + &*(tw as *const _ as *const NttTwiddles), + ); + } + return; + } + if std::mem::size_of::() == std::mem::size_of::() { + unsafe { + super::neon::inverse_ntt_cyclic_i16( + &mut *(a as *mut _ as *mut [MontCoeff; D]), + *(&prime as *const _ as *const NttPrime), + &*(tw as *const _ as *const NttTwiddles), + ); + } + return; + } + } + + let mut len = 1usize; + while len < D { + let twiddle_base = len - 1; + let mut start = 0usize; + while start < D { + for j in 0..len { + let w = tw.inv_twiddles[twiddle_base + j]; + let u = a[start + j]; + let v = prime.mul(a[start + j + len], w); + let sum = u.raw().wrapping_add(v.raw()); + let diff = u.raw().wrapping_sub(v.raw()); + a[start + j] = prime.reduce_range(MontCoeff::from_raw(sum)); + a[start + j + len] = prime.reduce_range(MontCoeff::from_raw(diff)); + } + start += 2 * len; + } + len *= 2; + } + + for c in a.iter_mut() { + *c = prime.mul(*c, tw.d_inv); + } +} + +/// Find a primitive `2D`-th root of unity mod `p`. +fn find_primitive_root_2d(p: i64, d: usize) -> i64 { + let half = (p - 1) / 2; + let exp = (p - 1) / (2 * d as i64); + for a in 2..p { + if pow_mod(a, half, p) == p - 1 { + let psi = pow_mod(a, exp, p); + debug_assert_eq!(pow_mod(psi, d as i64, p), p - 1, "psi^D != -1"); + return psi; + } + } + panic!("no primitive root found for p={p}"); +} + +/// Modular exponentiation: `base^exp mod modulus`. 
+fn pow_mod(mut base: i64, mut exp: i64, modulus: i64) -> i64 { + let mut result = 1i64; + base %= modulus; + while exp > 0 { + if exp & 1 == 1 { + result = result * base % modulus; + } + base = base * base % modulus; + exp >>= 1; + } + result +} diff --git a/src/algebra/ntt/crt.rs b/src/algebra/ntt/crt.rs new file mode 100644 index 00000000..e0897937 --- /dev/null +++ b/src/algebra/ntt/crt.rs @@ -0,0 +1,202 @@ +//! CRT helpers: Garner reconstruction and limb-based modular arithmetic. + +use std::cmp::Ordering; +use std::fmt; +use std::ops::{Add, Sub}; + +use super::prime::{NttPrime, PrimeWidth}; + +/// Limb radix bit-width (`2^14`). +pub const RADIX_BITS: u32 = 14; +const RADIX: i32 = 1 << RADIX_BITS; +const RADIX_MASK: i32 = RADIX - 1; + +/// Precomputed Garner inverse table for CRT reconstruction. +/// +/// `gamma[i][j]` = `p_j^{-1} mod p_i` for `j < i`. Upper triangle and +/// diagonal entries are zero (unused). +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct GarnerData { + /// `gamma[i][j]` = `p_j^{-1} mod p_i` for `j < i`. + pub gamma: [[W; K]; K], +} + +impl GarnerData { + /// Compute Garner constants from a set of NTT primes. + pub fn compute(primes: &[NttPrime; K]) -> Self { + let mut gamma = [[W::default(); K]; K]; + for i in 1..K { + let pi = primes[i].p.to_i64(); + #[allow(clippy::needless_range_loop)] + for j in 0..i { + let pj = primes[j].p.to_i64(); + let inv = mod_inverse_i64(pj, pi); + gamma[i][j] = W::from_i64(inv); + } + } + Self { gamma } + } +} + +/// Modular inverse via extended GCD, operating in `i64`. +fn mod_inverse_i64(a: i64, modulus: i64) -> i64 { + let (mut t, mut new_t) = (0i64, 1i64); + let (mut r, mut new_r) = (modulus, ((a % modulus) + modulus) % modulus); + while new_r != 0 { + let q = r / new_r; + (t, new_t) = (new_t, t - q * new_t); + (r, new_r) = (new_r, r - q * new_r); + } + assert_eq!(r, 1, "modular inverse does not exist"); + ((t % modulus) + modulus) % modulus +} + +/// Fixed-width radix-`2^14` integer. 
+/// +/// Limbs are little-endian: `limbs[0]` is least significant. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct LimbQ { + /// Little-endian limbs. + pub limbs: [u16; L], +} + +impl Default for LimbQ { + #[inline] + fn default() -> Self { + Self::zero() + } +} + +impl LimbQ { + /// Zero value. + #[inline] + pub const fn zero() -> Self { + Self { limbs: [0; L] } + } + + /// Construct directly from limbs. + #[inline] + pub const fn from_limbs(limbs: [u16; L]) -> Self { + Self { limbs } + } + + /// Conditional subtraction: if `self >= modulus`, return `self - modulus` (branchless). + #[inline] + pub fn csub_mod(self, modulus: Self) -> Self { + let mut diff = [0u16; L]; + let mut borrow = 0i32; + for (i, df) in diff.iter_mut().enumerate() { + let d = self.limbs[i] as i32 - modulus.limbs[i] as i32 + borrow; + borrow = d >> 31; + if i + 1 < L { + *df = (d - borrow * RADIX) as u16; + } else { + *df = d as u16; + } + } + let mask = borrow as u16; + let mut result = [0u16; L]; + for (i, r) in result.iter_mut().enumerate() { + *r = (self.limbs[i] & mask) | (diff[i] & !mask); + } + Self { limbs: result } + } +} + +impl From for LimbQ { + fn from(mut x: u128) -> Self { + let mut out = [0u16; L]; + for (i, limb) in out.iter_mut().enumerate() { + if i + 1 < L { + *limb = (x & (RADIX_MASK as u128)) as u16; + x >>= RADIX_BITS; + } else { + *limb = x as u16; + } + } + Self { limbs: out } + } +} + +impl TryFrom> for u128 { + type Error = &'static str; + + fn try_from(limb: LimbQ) -> Result { + if (L as u32) * RADIX_BITS > 128 { + return Err("LimbQ too wide for u128"); + } + let mut acc = 0u128; + for i in (0..L).rev() { + acc <<= RADIX_BITS; + acc |= limb.limbs[i] as u128; + } + Ok(acc) + } +} + +impl PartialOrd for LimbQ { + #[inline] + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} + +impl Ord for LimbQ { + fn cmp(&self, other: &Self) -> Ordering { + for i in (0..L).rev() { + match self.limbs[i].cmp(&other.limbs[i]) { + Ordering::Equal 
=> continue, + ord => return ord, + } + } + Ordering::Equal + } +} + +impl Add for LimbQ { + type Output = Self; + + fn add(self, rhs: Self) -> Self { + let mut out = [0u16; L]; + let mut carry = 0i32; + for (i, out_limb) in out.iter_mut().enumerate() { + let s = self.limbs[i] as i32 + rhs.limbs[i] as i32 + carry; + if i + 1 < L { + carry = s >> RADIX_BITS; + *out_limb = (s & RADIX_MASK) as u16; + } else { + *out_limb = s as u16; + } + } + Self { limbs: out } + } +} + +impl Sub for LimbQ { + type Output = Self; + + fn sub(self, rhs: Self) -> Self { + let mut out = [0u16; L]; + let mut borrow = 0i32; + for (i, out_limb) in out.iter_mut().enumerate() { + let d = self.limbs[i] as i32 - rhs.limbs[i] as i32 + borrow; + if i + 1 < L { + borrow = d >> 31; + *out_limb = (d - borrow * RADIX) as u16; + } else { + *out_limb = d as u16; + } + } + Self { limbs: out } + } +} + +impl fmt::Display for LimbQ { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + if let Ok(val) = u128::try_from(*self) { + write!(f, "{val}") + } else { + write!(f, "LimbQ{:?}", self.limbs) + } + } +} diff --git a/src/algebra/ntt/mod.rs b/src/algebra/ntt/mod.rs new file mode 100644 index 00000000..b059e9b6 --- /dev/null +++ b/src/algebra/ntt/mod.rs @@ -0,0 +1,12 @@ +//! NTT-friendly small-prime arithmetic and CRT helpers. + +pub mod butterfly; +pub mod crt; +#[cfg(target_arch = "aarch64")] +pub(crate) mod neon; +pub mod prime; +pub mod tables; + +pub use butterfly::NttTwiddles; +pub use crt::{GarnerData, LimbQ, RADIX_BITS}; +pub use prime::{MontCoeff, NttPrime, PrimeWidth}; diff --git a/src/algebra/ntt/neon.rs b/src/algebra/ntt/neon.rs new file mode 100644 index 00000000..10ed1817 --- /dev/null +++ b/src/algebra/ntt/neon.rs @@ -0,0 +1,1001 @@ +//! AArch64 NEON SIMD kernels for NTT butterfly, Montgomery multiply, +//! and pointwise operations. +//! +//! Provides vectorized i32 (for Q64/Q128) and i16 (for Q32) paths. +//! Dispatch is controlled by [`use_neon_ntt`]: set `HACHI_SCALAR_NTT=1` +//! 
to force the scalar fallback for A/B performance comparison. + +use std::arch::aarch64::*; +use std::sync::OnceLock; + +use super::butterfly::NttTwiddles; +use super::prime::{MontCoeff, NttPrime}; + +/// Whether the NEON NTT path is active. Cached on first call. +/// Set `HACHI_SCALAR_NTT=1` to force scalar fallback. +pub(crate) fn use_neon_ntt() -> bool { + static ENABLED: OnceLock = OnceLock::new(); + *ENABLED.get_or_init(|| std::env::var("HACHI_SCALAR_NTT").map_or(true, |v| v != "1")) +} + +/// 4-wide Montgomery multiply for i32 primes. +/// +/// Uses two 2-wide `vmull_s32` chains (since i32×i32→i64 only fills 2 lanes +/// of a 128-bit register) and combines the results. +#[inline(always)] +unsafe fn mont_mul_4x_i32(a: int32x4_t, b: int32x4_t, p: int32x2_t, pinv: int32x2_t) -> int32x4_t { + let a_lo = vget_low_s32(a); + let a_hi = vget_high_s32(a); + let b_lo = vget_low_s32(b); + let b_hi = vget_high_s32(b); + + // Low pair + let c_lo = vmull_s32(a_lo, b_lo); + let t_lo = vmul_s32(vmovn_s64(c_lo), pinv); + let tp_lo = vmull_s32(t_lo, p); + let r_lo = vmovn_s64(vshrq_n_s64::<32>(vsubq_s64(c_lo, tp_lo))); + + // High pair + let c_hi = vmull_s32(a_hi, b_hi); + let t_hi = vmul_s32(vmovn_s64(c_hi), pinv); + let tp_hi = vmull_s32(t_hi, p); + let r_hi = vmovn_s64(vshrq_n_s64::<32>(vsubq_s64(c_hi, tp_hi))); + + vcombine_s32(r_lo, r_hi) +} + +/// 4-wide range reduction for i32: maps `(-2p, 2p)` → `(-p, p)`. +/// +/// Uses comparison-first approach to avoid the i64 widening that the +/// scalar `csubp`/`caddp` path requires (since `a - p` can overflow i32). 
+#[inline(always)] +unsafe fn reduce_range_4x_i32(a: int32x4_t, p: int32x4_t) -> int32x4_t { + let zero = vdupq_n_s32(0); + + // csubp: subtract p where a >= p + let ge_mask = vcgeq_s32(a, p); + let after_sub = vsubq_s32( + a, + vreinterpretq_s32_u32(vandq_u32(vreinterpretq_u32_s32(p), ge_mask)), + ); + + // caddp: add p where result < 0 + let lt_mask = vcltq_s32(after_sub, zero); + vaddq_s32( + after_sub, + vreinterpretq_s32_u32(vandq_u32(vreinterpretq_u32_s32(p), lt_mask)), + ) +} + +/// NEON-accelerated forward negacyclic NTT for i32 primes. +/// +/// Processes 4 butterfly pairs per iteration when `len >= 4`; +/// falls back to scalar for the final 2 stages (`len = 2, 1`). +pub(crate) unsafe fn forward_ntt_i32( + a: &mut [MontCoeff; D], + prime: NttPrime, + tw: &NttTwiddles, +) { + let p_d = vdup_n_s32(prime.p); + let pinv_d = vdup_n_s32(prime.pinv); + let p_q = vdupq_n_s32(prime.p); + let a_ptr = a.as_mut_ptr() as *mut i32; + + // Pre-twist by psi^i + { + let psi_ptr = tw.psi_pows.as_ptr() as *const i32; + let mut i = 0; + while i + 4 <= D { + let ai = vld1q_s32(a_ptr.add(i)); + let psi = vld1q_s32(psi_ptr.add(i)); + vst1q_s32(a_ptr.add(i), mont_mul_4x_i32(ai, psi, p_d, pinv_d)); + i += 4; + } + } + + // DIF butterfly stages + let mut len = D / 2; + while len > 0 { + let twiddle_base = len - 1; + let tw_ptr = tw.fwd_twiddles.as_ptr() as *const i32; + let mut start = 0usize; + while start < D { + if len >= 4 { + let mut j = 0; + while j < len { + let u = vld1q_s32(a_ptr.add(start + j)); + let v = vld1q_s32(a_ptr.add(start + j + len)); + let w = vld1q_s32(tw_ptr.add(twiddle_base + j)); + + let sum = vaddq_s32(u, v); + let diff = vsubq_s32(u, v); + + vst1q_s32(a_ptr.add(start + j), reduce_range_4x_i32(sum, p_q)); + vst1q_s32( + a_ptr.add(start + j + len), + mont_mul_4x_i32(diff, w, p_d, pinv_d), + ); + j += 4; + } + } else { + for j in 0..len { + let w = tw.fwd_twiddles[twiddle_base + j]; + let u = a[start + j]; + let v = a[start + j + len]; + let sum = 
u.raw().wrapping_add(v.raw()); + let diff = u.raw().wrapping_sub(v.raw()); + a[start + j] = prime.reduce_range(MontCoeff::from_raw(sum)); + a[start + j + len] = prime.mul(MontCoeff::from_raw(diff), w); + } + } + start += 2 * len; + } + len /= 2; + } + + // Final reduce_range pass + reduce_range_in_place_i32(a, p_q); +} + +/// NEON-accelerated inverse negacyclic NTT for i32 primes. +pub(crate) unsafe fn inverse_ntt_i32( + a: &mut [MontCoeff; D], + prime: NttPrime, + tw: &NttTwiddles, +) { + let p_d = vdup_n_s32(prime.p); + let pinv_d = vdup_n_s32(prime.pinv); + let p_q = vdupq_n_s32(prime.p); + let a_ptr = a.as_mut_ptr() as *mut i32; + + // DIT butterfly stages + let mut len = 1usize; + while len < D { + let twiddle_base = len - 1; + let tw_ptr = tw.inv_twiddles.as_ptr() as *const i32; + let mut start = 0usize; + while start < D { + if len >= 4 { + let mut j = 0; + while j < len { + let w = vld1q_s32(tw_ptr.add(twiddle_base + j)); + let u = vld1q_s32(a_ptr.add(start + j)); + let v_raw = vld1q_s32(a_ptr.add(start + j + len)); + let v = mont_mul_4x_i32(v_raw, w, p_d, pinv_d); + + let sum = vaddq_s32(u, v); + let diff = vsubq_s32(u, v); + + vst1q_s32(a_ptr.add(start + j), reduce_range_4x_i32(sum, p_q)); + vst1q_s32(a_ptr.add(start + j + len), reduce_range_4x_i32(diff, p_q)); + j += 4; + } + } else { + for j in 0..len { + let w = tw.inv_twiddles[twiddle_base + j]; + let u = a[start + j]; + let v = prime.mul(a[start + j + len], w); + let sum = u.raw().wrapping_add(v.raw()); + let diff = u.raw().wrapping_sub(v.raw()); + a[start + j] = prime.reduce_range(MontCoeff::from_raw(sum)); + a[start + j + len] = prime.reduce_range(MontCoeff::from_raw(diff)); + } + } + start += 2 * len; + } + len *= 2; + } + + // Fused D^{-1} * psi^{-i} untwist + { + let fused_ptr = tw.d_inv_psi_inv.as_ptr() as *const i32; + let mut i = 0; + while i + 4 <= D { + let ai = vld1q_s32(a_ptr.add(i)); + let f = vld1q_s32(fused_ptr.add(i)); + vst1q_s32(a_ptr.add(i), mont_mul_4x_i32(ai, f, p_d, pinv_d)); + 
i += 4; + } + } +} + +/// NEON-accelerated forward cyclic NTT for i32 (no negacyclic twist). +pub(crate) unsafe fn forward_ntt_cyclic_i32( + a: &mut [MontCoeff; D], + prime: NttPrime, + tw: &NttTwiddles, +) { + let p_d = vdup_n_s32(prime.p); + let pinv_d = vdup_n_s32(prime.pinv); + let p_q = vdupq_n_s32(prime.p); + let a_ptr = a.as_mut_ptr() as *mut i32; + + let mut len = D / 2; + while len > 0 { + let twiddle_base = len - 1; + let tw_ptr = tw.fwd_twiddles.as_ptr() as *const i32; + let mut start = 0usize; + while start < D { + if len >= 4 { + let mut j = 0; + while j < len { + let u = vld1q_s32(a_ptr.add(start + j)); + let v = vld1q_s32(a_ptr.add(start + j + len)); + let w = vld1q_s32(tw_ptr.add(twiddle_base + j)); + let sum = vaddq_s32(u, v); + let diff = vsubq_s32(u, v); + vst1q_s32(a_ptr.add(start + j), reduce_range_4x_i32(sum, p_q)); + vst1q_s32( + a_ptr.add(start + j + len), + mont_mul_4x_i32(diff, w, p_d, pinv_d), + ); + j += 4; + } + } else { + for j in 0..len { + let w = tw.fwd_twiddles[twiddle_base + j]; + let u = a[start + j]; + let v = a[start + j + len]; + let sum = u.raw().wrapping_add(v.raw()); + let diff = u.raw().wrapping_sub(v.raw()); + a[start + j] = prime.reduce_range(MontCoeff::from_raw(sum)); + a[start + j + len] = prime.mul(MontCoeff::from_raw(diff), w); + } + } + start += 2 * len; + } + len /= 2; + } + reduce_range_in_place_i32(a, p_q); +} + +/// NEON-accelerated inverse cyclic NTT for i32 (no negacyclic untwist). 
+pub(crate) unsafe fn inverse_ntt_cyclic_i32( + a: &mut [MontCoeff; D], + prime: NttPrime, + tw: &NttTwiddles, +) { + let p_d = vdup_n_s32(prime.p); + let pinv_d = vdup_n_s32(prime.pinv); + let p_q = vdupq_n_s32(prime.p); + let a_ptr = a.as_mut_ptr() as *mut i32; + + let mut len = 1usize; + while len < D { + let twiddle_base = len - 1; + let tw_ptr = tw.inv_twiddles.as_ptr() as *const i32; + let mut start = 0usize; + while start < D { + if len >= 4 { + let mut j = 0; + while j < len { + let w = vld1q_s32(tw_ptr.add(twiddle_base + j)); + let u = vld1q_s32(a_ptr.add(start + j)); + let v_raw = vld1q_s32(a_ptr.add(start + j + len)); + let v = mont_mul_4x_i32(v_raw, w, p_d, pinv_d); + let sum = vaddq_s32(u, v); + let diff = vsubq_s32(u, v); + vst1q_s32(a_ptr.add(start + j), reduce_range_4x_i32(sum, p_q)); + vst1q_s32(a_ptr.add(start + j + len), reduce_range_4x_i32(diff, p_q)); + j += 4; + } + } else { + for j in 0..len { + let w = tw.inv_twiddles[twiddle_base + j]; + let u = a[start + j]; + let v = prime.mul(a[start + j + len], w); + let sum = u.raw().wrapping_add(v.raw()); + let diff = u.raw().wrapping_sub(v.raw()); + a[start + j] = prime.reduce_range(MontCoeff::from_raw(sum)); + a[start + j + len] = prime.reduce_range(MontCoeff::from_raw(diff)); + } + } + start += 2 * len; + } + len *= 2; + } + + // D^{-1} scaling + { + let d_inv = tw.d_inv; + let d_inv_q = vdupq_n_s32(d_inv.raw()); + let mut i = 0; + while i + 4 <= D { + let ai = vld1q_s32(a_ptr.add(i)); + vst1q_s32(a_ptr.add(i), mont_mul_4x_i32(ai, d_inv_q, p_d, pinv_d)); + i += 4; + } + } +} + +/// 4-wide pointwise multiply-accumulate for a single CRT limb (i32). +/// +/// `acc[i] = reduce_range(acc[i] + mont_mul(lhs[i], rhs[i]))` for `i in 0..d`. 
+pub(crate) unsafe fn pointwise_mul_acc_i32( + acc: *mut i32, + lhs: *const i32, + rhs: *const i32, + d: usize, + p: i32, + pinv: i32, +) { + let p_d = vdup_n_s32(p); + let pinv_d = vdup_n_s32(pinv); + let p_q = vdupq_n_s32(p); + let mut i = 0; + while i + 4 <= d { + let a = vld1q_s32(acc.add(i)); + let l = vld1q_s32(lhs.add(i)); + let r = vld1q_s32(rhs.add(i)); + let prod = mont_mul_4x_i32(l, r, p_d, pinv_d); + let sum = vaddq_s32(a, prod); + vst1q_s32(acc.add(i), reduce_range_4x_i32(sum, p_q)); + i += 4; + } +} + +/// 4-wide add-and-reduce for a single CRT limb (i32). +/// +/// `acc[i] = reduce_range(acc[i] + other[i])` for `i in 0..d`. +pub(crate) unsafe fn add_reduce_i32(acc: *mut i32, other: *const i32, d: usize, p: i32) { + let p_q = vdupq_n_s32(p); + let mut i = 0; + while i + 4 <= d { + let a = vld1q_s32(acc.add(i)); + let b = vld1q_s32(other.add(i)); + vst1q_s32(acc.add(i), reduce_range_4x_i32(vaddq_s32(a, b), p_q)); + i += 4; + } +} + +/// In-place reduce_range over a full array. +unsafe fn reduce_range_in_place_i32(a: &mut [MontCoeff; D], p_q: int32x4_t) { + let ptr = a.as_mut_ptr() as *mut i32; + let mut i = 0; + while i + 4 <= D { + let val = vld1q_s32(ptr.add(i)); + vst1q_s32(ptr.add(i), reduce_range_4x_i32(val, p_q)); + i += 4; + } +} + +/// 4-wide Montgomery multiply for i16 primes. +/// +/// Natural 4-wide: `vmull_s16` produces `int32x4_t`. +#[inline(always)] +unsafe fn mont_mul_4x_i16(a: int16x4_t, b: int16x4_t, p: int16x4_t, pinv: int16x4_t) -> int16x4_t { + let c = vmull_s16(a, b); + let t = vmul_s16(vmovn_s32(c), pinv); + let tp = vmull_s16(t, p); + vmovn_s32(vshrq_n_s32::<16>(vsubq_s32(c, tp))) +} + +/// 8-wide Montgomery multiply for i16 primes (two 4-wide chains). 
+#[inline(always)] +unsafe fn mont_mul_8x_i16(a: int16x8_t, b: int16x8_t, p: int16x4_t, pinv: int16x4_t) -> int16x8_t { + let r_lo = mont_mul_4x_i16(vget_low_s16(a), vget_low_s16(b), p, pinv); + let r_hi = mont_mul_4x_i16(vget_high_s16(a), vget_high_s16(b), p, pinv); + vcombine_s16(r_lo, r_hi) +} + +/// 8-wide range reduction for i16: `(-2p, 2p)` → `(-p, p)`. +/// +/// Same comparison-first approach as i32 but on `int16x8_t`. +#[inline(always)] +unsafe fn reduce_range_8x_i16(a: int16x8_t, p: int16x8_t) -> int16x8_t { + let zero = vdupq_n_s16(0); + let ge_mask = vcgeq_s16(a, p); + let after_sub = vsubq_s16( + a, + vreinterpretq_s16_u16(vandq_u16(vreinterpretq_u16_s16(p), ge_mask)), + ); + let lt_mask = vcltq_s16(after_sub, zero); + vaddq_s16( + after_sub, + vreinterpretq_s16_u16(vandq_u16(vreinterpretq_u16_s16(p), lt_mask)), + ) +} + +/// NEON-accelerated forward negacyclic NTT for i16 primes. +/// +/// Processes 4 butterflies per iteration when `len >= 4`; +/// scalar fallback for `len < 4`. 
+pub(crate) unsafe fn forward_ntt_i16( + a: &mut [MontCoeff; D], + prime: NttPrime, + tw: &NttTwiddles, +) { + let p_d = vdup_n_s16(prime.p); + let pinv_d = vdup_n_s16(prime.pinv); + let p_q = vdupq_n_s16(prime.p); + let a_ptr = a.as_mut_ptr() as *mut i16; + + // Pre-twist by psi^i + { + let psi_ptr = tw.psi_pows.as_ptr() as *const i16; + let mut i = 0; + while i + 4 <= D { + let ai = vld1_s16(a_ptr.add(i)); + let psi = vld1_s16(psi_ptr.add(i)); + vst1_s16(a_ptr.add(i), mont_mul_4x_i16(ai, psi, p_d, pinv_d)); + i += 4; + } + } + + // DIF butterfly stages + let mut len = D / 2; + while len > 0 { + let twiddle_base = len - 1; + let tw_ptr = tw.fwd_twiddles.as_ptr() as *const i16; + let mut start = 0usize; + while start < D { + if len >= 4 { + let mut j = 0; + while j < len { + let u = vld1_s16(a_ptr.add(start + j)); + let v = vld1_s16(a_ptr.add(start + j + len)); + let w = vld1_s16(tw_ptr.add(twiddle_base + j)); + let sum = vadd_s16(u, v); + let diff = vsub_s16(u, v); + + // reduce_range on 4-wide i16 (use 8-wide by padding) + let sum_q = vcombine_s16(sum, vdup_n_s16(0)); + let sum_reduced = vget_low_s16(reduce_range_8x_i16(sum_q, p_q)); + + let diff_mul_w = mont_mul_4x_i16(diff, w, p_d, pinv_d); + vst1_s16(a_ptr.add(start + j), sum_reduced); + vst1_s16(a_ptr.add(start + j + len), diff_mul_w); + j += 4; + } + } else { + for j in 0..len { + let w = tw.fwd_twiddles[twiddle_base + j]; + let u = a[start + j]; + let v = a[start + j + len]; + let sum = u.raw().wrapping_add(v.raw()); + let diff = u.raw().wrapping_sub(v.raw()); + a[start + j] = prime.reduce_range(MontCoeff::from_raw(sum)); + a[start + j + len] = prime.mul(MontCoeff::from_raw(diff), w); + } + } + start += 2 * len; + } + len /= 2; + } + + // Final reduce_range pass + reduce_range_in_place_i16(a, p_q); +} + +/// NEON-accelerated inverse negacyclic NTT for i16 primes. 
+pub(crate) unsafe fn inverse_ntt_i16( + a: &mut [MontCoeff; D], + prime: NttPrime, + tw: &NttTwiddles, +) { + let p_d = vdup_n_s16(prime.p); + let pinv_d = vdup_n_s16(prime.pinv); + let p_q = vdupq_n_s16(prime.p); + let a_ptr = a.as_mut_ptr() as *mut i16; + + let mut len = 1usize; + while len < D { + let twiddle_base = len - 1; + let tw_ptr = tw.inv_twiddles.as_ptr() as *const i16; + let mut start = 0usize; + while start < D { + if len >= 4 { + let mut j = 0; + while j < len { + let w = vld1_s16(tw_ptr.add(twiddle_base + j)); + let u = vld1_s16(a_ptr.add(start + j)); + let v_raw = vld1_s16(a_ptr.add(start + j + len)); + let v = mont_mul_4x_i16(v_raw, w, p_d, pinv_d); + let sum = vadd_s16(u, v); + let diff = vsub_s16(u, v); + let reduced = reduce_range_8x_i16(vcombine_s16(sum, diff), p_q); + vst1_s16(a_ptr.add(start + j), vget_low_s16(reduced)); + vst1_s16(a_ptr.add(start + j + len), vget_high_s16(reduced)); + j += 4; + } + } else { + for j in 0..len { + let w = tw.inv_twiddles[twiddle_base + j]; + let u = a[start + j]; + let v = prime.mul(a[start + j + len], w); + let sum = u.raw().wrapping_add(v.raw()); + let diff = u.raw().wrapping_sub(v.raw()); + a[start + j] = prime.reduce_range(MontCoeff::from_raw(sum)); + a[start + j + len] = prime.reduce_range(MontCoeff::from_raw(diff)); + } + } + start += 2 * len; + } + len *= 2; + } + + // Fused D^{-1} * psi^{-i} untwist + { + let fused_ptr = tw.d_inv_psi_inv.as_ptr() as *const i16; + let mut i = 0; + while i + 4 <= D { + let ai = vld1_s16(a_ptr.add(i)); + let f = vld1_s16(fused_ptr.add(i)); + vst1_s16(a_ptr.add(i), mont_mul_4x_i16(ai, f, p_d, pinv_d)); + i += 4; + } + } +} + +/// NEON-accelerated forward cyclic NTT for i16. 
+pub(crate) unsafe fn forward_ntt_cyclic_i16( + a: &mut [MontCoeff; D], + prime: NttPrime, + tw: &NttTwiddles, +) { + let p_d = vdup_n_s16(prime.p); + let pinv_d = vdup_n_s16(prime.pinv); + let p_q = vdupq_n_s16(prime.p); + let a_ptr = a.as_mut_ptr() as *mut i16; + + let mut len = D / 2; + while len > 0 { + let twiddle_base = len - 1; + let tw_ptr = tw.fwd_twiddles.as_ptr() as *const i16; + let mut start = 0usize; + while start < D { + if len >= 4 { + let mut j = 0; + while j < len { + let u = vld1_s16(a_ptr.add(start + j)); + let v = vld1_s16(a_ptr.add(start + j + len)); + let w = vld1_s16(tw_ptr.add(twiddle_base + j)); + let sum = vadd_s16(u, v); + let diff = vsub_s16(u, v); + let sum_q = vcombine_s16(sum, vdup_n_s16(0)); + vst1_s16( + a_ptr.add(start + j), + vget_low_s16(reduce_range_8x_i16(sum_q, p_q)), + ); + vst1_s16( + a_ptr.add(start + j + len), + mont_mul_4x_i16(diff, w, p_d, pinv_d), + ); + j += 4; + } + } else { + for j in 0..len { + let w = tw.fwd_twiddles[twiddle_base + j]; + let u = a[start + j]; + let v = a[start + j + len]; + let sum = u.raw().wrapping_add(v.raw()); + let diff = u.raw().wrapping_sub(v.raw()); + a[start + j] = prime.reduce_range(MontCoeff::from_raw(sum)); + a[start + j + len] = prime.mul(MontCoeff::from_raw(diff), w); + } + } + start += 2 * len; + } + len /= 2; + } + reduce_range_in_place_i16(a, p_q); +} + +/// NEON-accelerated inverse cyclic NTT for i16. 
+pub(crate) unsafe fn inverse_ntt_cyclic_i16( + a: &mut [MontCoeff; D], + prime: NttPrime, + tw: &NttTwiddles, +) { + let p_d = vdup_n_s16(prime.p); + let pinv_d = vdup_n_s16(prime.pinv); + let p_q = vdupq_n_s16(prime.p); + let a_ptr = a.as_mut_ptr() as *mut i16; + + let mut len = 1usize; + while len < D { + let twiddle_base = len - 1; + let tw_ptr = tw.inv_twiddles.as_ptr() as *const i16; + let mut start = 0usize; + while start < D { + if len >= 4 { + let mut j = 0; + while j < len { + let w = vld1_s16(tw_ptr.add(twiddle_base + j)); + let u = vld1_s16(a_ptr.add(start + j)); + let v_raw = vld1_s16(a_ptr.add(start + j + len)); + let v = mont_mul_4x_i16(v_raw, w, p_d, pinv_d); + let sum = vadd_s16(u, v); + let diff = vsub_s16(u, v); + let reduced = reduce_range_8x_i16(vcombine_s16(sum, diff), p_q); + vst1_s16(a_ptr.add(start + j), vget_low_s16(reduced)); + vst1_s16(a_ptr.add(start + j + len), vget_high_s16(reduced)); + j += 4; + } + } else { + for j in 0..len { + let w = tw.inv_twiddles[twiddle_base + j]; + let u = a[start + j]; + let v = prime.mul(a[start + j + len], w); + let sum = u.raw().wrapping_add(v.raw()); + let diff = u.raw().wrapping_sub(v.raw()); + a[start + j] = prime.reduce_range(MontCoeff::from_raw(sum)); + a[start + j + len] = prime.reduce_range(MontCoeff::from_raw(diff)); + } + } + start += 2 * len; + } + len *= 2; + } + + // D^{-1} scaling + { + let d_inv = tw.d_inv; + let d_inv_d = vdup_n_s16(d_inv.raw()); + let mut i = 0; + while i + 4 <= D { + let ai = vld1_s16(a_ptr.add(i)); + vst1_s16(a_ptr.add(i), mont_mul_4x_i16(ai, d_inv_d, p_d, pinv_d)); + i += 4; + } + } +} + +/// 8-wide pointwise multiply-accumulate for a single CRT limb (i16). 
+pub(crate) unsafe fn pointwise_mul_acc_i16( + acc: *mut i16, + lhs: *const i16, + rhs: *const i16, + d: usize, + p: i16, + pinv: i16, +) { + let p_d = vdup_n_s16(p); + let pinv_d = vdup_n_s16(pinv); + let p_q = vdupq_n_s16(p); + let mut i = 0; + while i + 8 <= d { + let a = vld1q_s16(acc.add(i)); + let l = vld1q_s16(lhs.add(i)); + let r = vld1q_s16(rhs.add(i)); + let prod = mont_mul_8x_i16(l, r, p_d, pinv_d); + let sum = vaddq_s16(a, prod); + vst1q_s16(acc.add(i), reduce_range_8x_i16(sum, p_q)); + i += 8; + } + while i + 4 <= d { + let a = vld1_s16(acc.add(i)); + let l = vld1_s16(lhs.add(i)); + let r = vld1_s16(rhs.add(i)); + let prod = mont_mul_4x_i16(l, r, p_d, pinv_d); + let sum = vadd_s16(a, prod); + let sum_q = vcombine_s16(sum, vdup_n_s16(0)); + vst1_s16(acc.add(i), vget_low_s16(reduce_range_8x_i16(sum_q, p_q))); + i += 4; + } +} + +/// 8-wide add-and-reduce for a single CRT limb (i16). +pub(crate) unsafe fn add_reduce_i16(acc: *mut i16, other: *const i16, d: usize, p: i16) { + let p_q = vdupq_n_s16(p); + let mut i = 0; + while i + 8 <= d { + let a = vld1q_s16(acc.add(i)); + let b = vld1q_s16(other.add(i)); + vst1q_s16(acc.add(i), reduce_range_8x_i16(vaddq_s16(a, b), p_q)); + i += 8; + } + while i + 4 <= d { + let a = vld1_s16(acc.add(i)); + let b = vld1_s16(other.add(i)); + let sum_q = vcombine_s16(vadd_s16(a, b), vdup_n_s16(0)); + vst1_s16(acc.add(i), vget_low_s16(reduce_range_8x_i16(sum_q, p_q))); + i += 4; + } +} + +/// In-place reduce_range over a full i16 array. 
+unsafe fn reduce_range_in_place_i16(a: &mut [MontCoeff; D], p_q: int16x8_t) { + let ptr = a.as_mut_ptr() as *mut i16; + let mut i = 0; + while i + 8 <= D { + let val = vld1q_s16(ptr.add(i)); + vst1q_s16(ptr.add(i), reduce_range_8x_i16(val, p_q)); + i += 8; + } + while i + 4 <= D { + let val = vld1_s16(ptr.add(i)); + let padded = vcombine_s16(val, vdup_n_s16(0)); + vst1_s16(ptr.add(i), vget_low_s16(reduce_range_8x_i16(padded, p_q))); + i += 4; + } +} + +#[cfg(test)] +mod tests { + use super::super::butterfly::{ + forward_ntt as scalar_forward_ntt, forward_ntt_cyclic as scalar_forward_ntt_cyclic, + inverse_ntt as scalar_inverse_ntt, inverse_ntt_cyclic as scalar_inverse_ntt_cyclic, + NttTwiddles, + }; + use super::super::prime::{MontCoeff, NttPrime}; + use super::*; + + fn random_mont_array_i32( + prime: NttPrime, + seed: u64, + ) -> [MontCoeff; D] { + let mut state = seed; + std::array::from_fn(|_| { + state = state + .wrapping_mul(6364136223846793005) + .wrapping_add(1442695040888963407); + let val = ((state >> 33) as i64 % prime.p as i64) as i32; + prime.from_canonical(val) + }) + } + + fn random_mont_array_i16( + prime: NttPrime, + seed: u64, + ) -> [MontCoeff; D] { + let mut state = seed; + std::array::from_fn(|_| { + state = state + .wrapping_mul(6364136223846793005) + .wrapping_add(1442695040888963407); + let val = ((state >> 33) as i64 % prime.p as i64) as i16; + prime.from_canonical(val) + }) + } + + const TEST_PRIME_I32: i32 = 1073707009; + const TEST_PRIME_I16: i16 = 13697; + + #[test] + fn neon_forward_ntt_i32_matches_scalar() { + let prime = NttPrime::compute(TEST_PRIME_I32); + let tw = NttTwiddles::::compute(prime); + let input = random_mont_array_i32::<512>(prime, 0xCAFE); + + let mut neon_result = input; + unsafe { forward_ntt_i32(&mut neon_result, prime, &tw) }; + + let mut scalar_result = input; + scalar_forward_ntt(&mut scalar_result, prime, &tw); + + for i in 0..512 { + let n = prime.to_canonical(neon_result[i]); + let s = 
prime.to_canonical(scalar_result[i]); + assert_eq!(n, s, "mismatch at index {i}: neon={n}, scalar={s}"); + } + } + + #[test] + fn neon_inverse_ntt_i32_matches_scalar() { + let prime = NttPrime::compute(TEST_PRIME_I32); + let tw = NttTwiddles::::compute(prime); + let input = random_mont_array_i32::<512>(prime, 0xBEEF); + + let mut neon_result = input; + unsafe { inverse_ntt_i32(&mut neon_result, prime, &tw) }; + + let mut scalar_result = input; + scalar_inverse_ntt(&mut scalar_result, prime, &tw); + + for i in 0..512 { + let n = prime.to_canonical(neon_result[i]); + let s = prime.to_canonical(scalar_result[i]); + assert_eq!(n, s, "mismatch at index {i}: neon={n}, scalar={s}"); + } + } + + #[test] + fn neon_forward_inverse_roundtrip_i32() { + let prime = NttPrime::compute(TEST_PRIME_I32); + let tw = NttTwiddles::::compute(prime); + let input = random_mont_array_i32::<512>(prime, 0xDEAD); + let canonical_input: Vec = input.iter().map(|c| prime.to_canonical(*c)).collect(); + + let mut a = input; + unsafe { + forward_ntt_i32(&mut a, prime, &tw); + inverse_ntt_i32(&mut a, prime, &tw); + } + + for i in 0..512 { + let result = prime.to_canonical(a[i]); + assert_eq!( + result, canonical_input[i], + "roundtrip mismatch at index {i}" + ); + } + } + + #[test] + fn neon_cyclic_ntt_i32_matches_scalar() { + let prime = NttPrime::compute(TEST_PRIME_I32); + let tw = NttTwiddles::::compute(prime); + let input = random_mont_array_i32::<512>(prime, 0xFACE); + + let mut neon_fwd = input; + unsafe { forward_ntt_cyclic_i32(&mut neon_fwd, prime, &tw) }; + + let mut scalar_fwd = input; + scalar_forward_ntt_cyclic(&mut scalar_fwd, prime, &tw); + + for i in 0..512 { + let n = prime.to_canonical(neon_fwd[i]); + let s = prime.to_canonical(scalar_fwd[i]); + assert_eq!(n, s, "forward cyclic mismatch at {i}: neon={n}, scalar={s}"); + } + + let mut neon_inv = neon_fwd; + unsafe { inverse_ntt_cyclic_i32(&mut neon_inv, prime, &tw) }; + + let mut scalar_inv = scalar_fwd; + 
scalar_inverse_ntt_cyclic(&mut scalar_inv, prime, &tw); + + for i in 0..512 { + let n = prime.to_canonical(neon_inv[i]); + let s = prime.to_canonical(scalar_inv[i]); + assert_eq!(n, s, "inverse cyclic mismatch at {i}: neon={n}, scalar={s}"); + } + } + + #[test] + fn neon_pointwise_mul_acc_i32_matches_scalar() { + let prime = NttPrime::compute(TEST_PRIME_I32); + const D: usize = 512; + let acc_init = random_mont_array_i32::(prime, 0x1111); + let lhs = random_mont_array_i32::(prime, 0x2222); + let rhs = random_mont_array_i32::(prime, 0x3333); + + let mut neon_acc = acc_init; + unsafe { + pointwise_mul_acc_i32( + neon_acc.as_mut_ptr() as *mut i32, + lhs.as_ptr() as *const i32, + rhs.as_ptr() as *const i32, + D, + prime.p, + prime.pinv, + ); + } + + let mut scalar_acc = acc_init; + for i in 0..D { + let prod = prime.mul(lhs[i], rhs[i]); + let sum = MontCoeff::from_raw(scalar_acc[i].raw().wrapping_add(prod.raw())); + scalar_acc[i] = prime.reduce_range(sum); + } + + for i in 0..D { + let n = prime.to_canonical(neon_acc[i]); + let s = prime.to_canonical(scalar_acc[i]); + assert_eq!(n, s, "pointwise mul acc mismatch at {i}"); + } + } + + #[test] + fn neon_forward_ntt_i16_matches_scalar() { + let prime = NttPrime::compute(TEST_PRIME_I16); + let tw = NttTwiddles::::compute(prime); + let input = random_mont_array_i16::<64>(prime, 0xABCD); + + let mut neon_result = input; + unsafe { forward_ntt_i16(&mut neon_result, prime, &tw) }; + + let mut scalar_result = input; + scalar_forward_ntt(&mut scalar_result, prime, &tw); + + for i in 0..64 { + let n = prime.to_canonical(neon_result[i]); + let s = prime.to_canonical(scalar_result[i]); + assert_eq!(n, s, "i16 forward mismatch at {i}: neon={n}, scalar={s}"); + } + } + + #[test] + fn neon_inverse_ntt_i16_matches_scalar() { + let prime = NttPrime::compute(TEST_PRIME_I16); + let tw = NttTwiddles::::compute(prime); + let input = random_mont_array_i16::<64>(prime, 0xFEED); + + let mut neon_result = input; + unsafe { inverse_ntt_i16(&mut 
neon_result, prime, &tw) }; + + let mut scalar_result = input; + scalar_inverse_ntt(&mut scalar_result, prime, &tw); + + for i in 0..64 { + let n = prime.to_canonical(neon_result[i]); + let s = prime.to_canonical(scalar_result[i]); + assert_eq!(n, s, "i16 inverse mismatch at {i}: neon={n}, scalar={s}"); + } + } + + #[test] + fn neon_forward_inverse_roundtrip_i16() { + let prime = NttPrime::compute(TEST_PRIME_I16); + let tw = NttTwiddles::::compute(prime); + let input = random_mont_array_i16::<64>(prime, 0x7777); + let canonical_input: Vec = input.iter().map(|c| prime.to_canonical(*c)).collect(); + + let mut a = input; + unsafe { + forward_ntt_i16(&mut a, prime, &tw); + inverse_ntt_i16(&mut a, prime, &tw); + } + + for i in 0..64 { + let result = prime.to_canonical(a[i]); + assert_eq!(result, canonical_input[i], "i16 roundtrip mismatch at {i}"); + } + } + + #[test] + fn neon_cyclic_i16_matches_scalar() { + let prime = NttPrime::compute(TEST_PRIME_I16); + let tw = NttTwiddles::::compute(prime); + let input = random_mont_array_i16::<64>(prime, 0x9999); + + let mut neon_fwd = input; + unsafe { forward_ntt_cyclic_i16(&mut neon_fwd, prime, &tw) }; + + let mut scalar_fwd = input; + scalar_forward_ntt_cyclic(&mut scalar_fwd, prime, &tw); + + for i in 0..64 { + let n = prime.to_canonical(neon_fwd[i]); + let s = prime.to_canonical(scalar_fwd[i]); + assert_eq!(n, s, "i16 fwd cyclic mismatch at {i}"); + } + + let mut neon_inv = neon_fwd; + unsafe { inverse_ntt_cyclic_i16(&mut neon_inv, prime, &tw) }; + + let mut scalar_inv = scalar_fwd; + scalar_inverse_ntt_cyclic(&mut scalar_inv, prime, &tw); + + for i in 0..64 { + let n = prime.to_canonical(neon_inv[i]); + let s = prime.to_canonical(scalar_inv[i]); + assert_eq!(n, s, "i16 inv cyclic mismatch at {i}"); + } + } + + #[test] + fn neon_pointwise_mul_acc_i16_matches_scalar() { + let prime = NttPrime::compute(TEST_PRIME_I16); + const D: usize = 64; + let acc_init = random_mont_array_i16::(prime, 0xAAAA); + let lhs = 
random_mont_array_i16::(prime, 0xBBBB); + let rhs = random_mont_array_i16::(prime, 0xCCCC); + + let mut neon_acc = acc_init; + unsafe { + pointwise_mul_acc_i16( + neon_acc.as_mut_ptr() as *mut i16, + lhs.as_ptr() as *const i16, + rhs.as_ptr() as *const i16, + D, + prime.p, + prime.pinv, + ); + } + + let mut scalar_acc = acc_init; + for i in 0..D { + let prod = prime.mul(lhs[i], rhs[i]); + let sum = MontCoeff::from_raw(scalar_acc[i].raw().wrapping_add(prod.raw())); + scalar_acc[i] = prime.reduce_range(sum); + } + + for i in 0..D { + let n = prime.to_canonical(neon_acc[i]); + let s = prime.to_canonical(scalar_acc[i]); + assert_eq!(n, s, "i16 pointwise mul acc mismatch at {i}"); + } + } +} diff --git a/src/algebra/ntt/prime.rs b/src/algebra/ntt/prime.rs new file mode 100644 index 00000000..a5646234 --- /dev/null +++ b/src/algebra/ntt/prime.rs @@ -0,0 +1,382 @@ +//! NTT prime arithmetic kernels generic over coefficient width. +//! +//! Per-prime scalar operations: +//! - Montgomery multiplication ([`NttPrime::mul`]) +//! - Branchless conditional add/sub and range reduction +//! +//! Coefficients in Montgomery domain are wrapped in [`MontCoeff`] to prevent +//! accidental mixing with canonical values. +//! +//! The [`PrimeWidth`] trait abstracts over `i16` (R = 2^16, for primes < 2^14) +//! and `i32` (R = 2^32, for primes < 2^30). All NTT types are generic over +//! `W: PrimeWidth`; monomorphization produces optimal code for each width. + +use std::fmt; + +mod sealed { + pub trait Sealed {} + impl Sealed for i16 {} + impl Sealed for i32 {} +} + +/// Integer width abstraction for NTT prime arithmetic. +/// +/// Sealed with exactly two implementations: `i16` and `i32`. +pub trait PrimeWidth: + sealed::Sealed + Copy + Clone + Eq + Default + fmt::Debug + Send + Sync + 'static +{ + /// Double-width type for intermediate Montgomery products. + type Wide: Copy + Clone; + + /// log2(R) for Montgomery reduction: 16 for `i16`, 32 for `i32`. 
+ const R_LOG: u32; + + /// Widening multiply: `a * b` as `Wide`. + fn wide_mul(a: Self, b: Self) -> Self::Wide; + + /// Truncate wide value to narrow (low half, i.e. mod R). + fn truncate(w: Self::Wide) -> Self; + + /// Arithmetic right shift of wide value by `R_LOG` bits. + fn wide_shift(w: Self::Wide) -> Self; + + /// Wide subtraction (wrapping). + fn wide_sub(a: Self::Wide, b: Self::Wide) -> Self::Wide; + + /// Wrapping addition. + fn wrapping_add(self, rhs: Self) -> Self; + /// Wrapping subtraction. + fn wrapping_sub(self, rhs: Self) -> Self; + /// Wrapping multiplication. + fn wrapping_mul(self, rhs: Self) -> Self; + /// Wrapping negation. + fn wrapping_neg(self) -> Self; + + /// Arithmetic right shift by `BITS - 1`: all-1s if negative, all-0s otherwise. + fn sign_mask(self) -> Self; + + /// Bitwise AND. + fn bitand(self, rhs: Self) -> Self; + + /// Convert from `i64` (truncating). + fn from_i64(v: i64) -> Self; + /// Convert to `i64` (sign-extending). + fn to_i64(self) -> i64; +} + +impl PrimeWidth for i16 { + type Wide = i32; + const R_LOG: u32 = 16; + + #[inline] + fn wide_mul(a: Self, b: Self) -> i32 { + (a as i32) * (b as i32) + } + #[inline] + fn truncate(w: i32) -> Self { + w as i16 + } + #[inline] + fn wide_shift(w: i32) -> Self { + (w >> 16) as i16 + } + #[inline] + fn wide_sub(a: i32, b: i32) -> i32 { + a.wrapping_sub(b) + } + #[inline] + fn wrapping_add(self, rhs: Self) -> Self { + i16::wrapping_add(self, rhs) + } + #[inline] + fn wrapping_sub(self, rhs: Self) -> Self { + i16::wrapping_sub(self, rhs) + } + #[inline] + fn wrapping_mul(self, rhs: Self) -> Self { + i16::wrapping_mul(self, rhs) + } + #[inline] + fn wrapping_neg(self) -> Self { + i16::wrapping_neg(self) + } + #[inline] + fn sign_mask(self) -> Self { + self >> 15 + } + #[inline] + fn bitand(self, rhs: Self) -> Self { + self & rhs + } + #[inline] + fn from_i64(v: i64) -> Self { + v as i16 + } + #[inline] + fn to_i64(self) -> i64 { + self as i64 + } +} + +impl PrimeWidth for i32 { + type 
Wide = i64; + const R_LOG: u32 = 32; + + #[inline] + fn wide_mul(a: Self, b: Self) -> i64 { + (a as i64) * (b as i64) + } + #[inline] + fn truncate(w: i64) -> Self { + w as i32 + } + #[inline] + fn wide_shift(w: i64) -> Self { + (w >> 32) as i32 + } + #[inline] + fn wide_sub(a: i64, b: i64) -> i64 { + a.wrapping_sub(b) + } + #[inline] + fn wrapping_add(self, rhs: Self) -> Self { + i32::wrapping_add(self, rhs) + } + #[inline] + fn wrapping_sub(self, rhs: Self) -> Self { + i32::wrapping_sub(self, rhs) + } + #[inline] + fn wrapping_mul(self, rhs: Self) -> Self { + i32::wrapping_mul(self, rhs) + } + #[inline] + fn wrapping_neg(self) -> Self { + i32::wrapping_neg(self) + } + #[inline] + fn sign_mask(self) -> Self { + self >> 31 + } + #[inline] + fn bitand(self, rhs: Self) -> Self { + self & rhs + } + #[inline] + fn from_i64(v: i64) -> Self { + v as i32 + } + #[inline] + fn to_i64(self) -> i64 { + self as i64 + } +} + +/// A coefficient in Montgomery domain for an NTT prime. +/// +/// Wraps a `W` representing `a * R mod p` (where `R = 2^{W::R_LOG}`). +/// Use [`NttPrime::from_canonical`] to enter and [`NttPrime::to_canonical`] +/// to leave Montgomery domain. +#[derive(Clone, Copy, PartialEq, Eq, Default)] +#[repr(transparent)] +pub struct MontCoeff(W); + +impl MontCoeff { + /// Wrap a raw Montgomery-domain value. + #[inline] + pub fn from_raw(val: W) -> Self { + Self(val) + } + + /// Extract the raw value (still in Montgomery domain). + #[inline] + pub fn raw(self) -> W { + self.0 + } +} + +impl fmt::Debug for MontCoeff { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "Mont({:?})", self.0) + } +} + +/// Per-prime constants for NTT arithmetic. +/// +/// Generic over `W: PrimeWidth` — use `i16` for primes below 2^14 (R = 2^16), +/// or `i32` for primes below 2^30 (R = 2^32). +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct NttPrime { + /// Prime modulus. + pub p: W, + /// `p^{-1} mod R` (centered signed). Used in Montgomery reduction. 
+ pub pinv: W, + /// `R mod p` (centered signed). Montgomery form of 1. + pub mont: W, + /// `R^2 mod p` (centered signed). Used for canonical → Montgomery conversion. + pub montsq: W, +} + +impl NttPrime { + /// Derive all Montgomery constants from a raw prime value. + pub fn compute(p: W) -> Self { + let p_i64 = p.to_i64(); + debug_assert!(p_i64 > 1 && p_i64 % 2 == 1, "NTT prime must be odd and > 1"); + + // pinv via Newton's method: x_{n+1} = x_n * (2 - p * x_n). + // 5 iterations gives correctness mod 2^32 (sufficient for both i16 and i32). + let mut pinv: i64 = 1; + for _ in 0..5 { + pinv = pinv.wrapping_mul(2i64.wrapping_sub(p_i64.wrapping_mul(pinv))); + } + let pinv = W::from_i64(pinv); + + let half = p_i64 / 2; + let center = |x: i64| -> W { W::from_i64(if x > half { x - p_i64 } else { x }) }; + + let r_mod_p = ((1i128 << W::R_LOG) % (p_i64 as i128)) as i64; + let mont = center(r_mod_p); + + let rsq_mod_p = ((1i128 << (2 * W::R_LOG)) % (p_i64 as i128)) as i64; + let montsq = center(rsq_mod_p); + + Self { + p, + pinv, + mont, + montsq, + } + } + + /// Montgomery product: `a * b * R^{-1} mod p`. + #[inline] + pub fn mul(self, a: MontCoeff, b: MontCoeff) -> MontCoeff { + MontCoeff(self.mont_mul_raw(a.0, b.0)) + } + + /// Raw Montgomery multiply on bare `W` values. + #[inline] + pub(crate) fn mont_mul_raw(self, a: W, b: W) -> W { + let c = W::wide_mul(a, b); + let t = W::truncate(c).wrapping_mul(self.pinv); + let tp = W::wide_mul(t, self.p); + W::wide_shift(W::wide_sub(c, tp)) + } + + /// Conditionally subtract `p` if `a >= p` (branchless). + /// + /// Input may be in `(-2p, 2p)` during butterflies. Both i16 and i32 paths + /// widen to avoid signed overflow: i16→i32, i32→i64. 
+ #[inline] + pub fn csubp(self, a: MontCoeff) -> MontCoeff { + if W::R_LOG == 16 { + let ai = a.0.to_i64() as i32; + let pi = self.p.to_i64() as i32; + let diff = ai - pi; + let mask = diff >> 31; + MontCoeff(W::from_i64((diff + (mask & pi)) as i64)) + } else { + let ai = a.0.to_i64(); + let pi = self.p.to_i64(); + let diff = ai - pi; + let mask = diff >> 63; + MontCoeff(W::from_i64(diff + (mask & pi))) + } + } + + /// Conditionally add `p` if `a < 0` (branchless). + /// + /// Widened arithmetic mirrors `csubp` to avoid i16 overflow edge cases. + #[inline] + pub fn caddp(self, a: MontCoeff) -> MontCoeff { + if W::R_LOG == 16 { + let ai = a.0.to_i64() as i32; + let pi = self.p.to_i64() as i32; + let mask = ai >> 31; + MontCoeff(W::from_i64((ai + (mask & pi)) as i64)) + } else { + let ai = a.0.to_i64(); + let pi = self.p.to_i64(); + let mask = ai >> 63; + MontCoeff(W::from_i64(ai + (mask & pi))) + } + } + + /// Range-reduce from `(-2p, 2p)` to `(-p, p)`. + #[inline] + pub fn reduce_range(self, a: MontCoeff) -> MontCoeff { + self.caddp(self.csubp(a)) + } + + /// Fully normalize a Montgomery coefficient to `[0, p)`. + #[inline] + pub fn normalize(self, a: MontCoeff) -> MontCoeff { + self.csubp(self.caddp(a)) + } + + /// Convert a canonical value into Montgomery domain: `a ↦ a * R mod p`. + #[inline] + pub fn from_canonical(self, a: W) -> MontCoeff { + MontCoeff(self.mont_mul_raw(a, self.montsq)) + } + + /// Convert from Montgomery domain to canonical `[0, p)`. + #[inline] + pub fn to_canonical(self, a: MontCoeff) -> W { + let raw = MontCoeff(self.mont_mul_raw(a.0, W::from_i64(1))); + self.normalize(raw).0 + } + + /// Center a canonical value from approximately `(-p, p)` into `[-p/2, p/2)`. 
+ #[inline] + pub fn center(self, a: W) -> W { + let mask_neg = a.sign_mask(); + let canonical = a.wrapping_add(mask_neg.bitand(self.p)); + let half = W::from_i64(self.p.to_i64() / 2); + let needs_sub = half.wrapping_sub(canonical).sign_mask(); + canonical.wrapping_add(needs_sub.bitand(self.p.wrapping_neg())) + } + + /// Pointwise Montgomery multiplication of two coefficient slices. + /// + /// # Panics + /// + /// Panics if slices have different lengths. + #[inline] + pub fn pointwise_mul( + self, + out: &mut [MontCoeff], + lhs: &[MontCoeff], + rhs: &[MontCoeff], + ) { + assert_eq!(out.len(), lhs.len()); + assert_eq!(lhs.len(), rhs.len()); + for ((o, a), b) in out.iter_mut().zip(lhs.iter()).zip(rhs.iter()) { + *o = self.mul(*a, *b); + } + } + + /// In-place Montgomery scaling by a constant. + #[inline] + pub fn scale_in_place(self, coeffs: &mut [MontCoeff], scalar: MontCoeff) { + for c in coeffs { + *c = self.mul(*c, scalar); + } + } + + /// In-place range reduction on a coefficient slice. + #[inline] + pub fn reduce_range_in_place(self, coeffs: &mut [MontCoeff]) { + for c in coeffs { + *c = self.reduce_range(*c); + } + } + + /// In-place centering of canonical values to `[-p/2, p/2)`. + #[inline] + pub fn center_slice(self, coeffs: &mut [W]) { + for c in coeffs { + *c = self.center(*c); + } + } +} diff --git a/src/algebra/ntt/tables.rs b/src/algebra/ntt/tables.rs new file mode 100644 index 00000000..70e35bd1 --- /dev/null +++ b/src/algebra/ntt/tables.rs @@ -0,0 +1,215 @@ +//! Deterministic parameter presets for small-prime CRT arithmetic. +//! +//! Q32: `logq = 32` with six `i16` NTT-friendly primes (D ≤ 64). +//! Q64: `logq = 64` with `i32` NTT-friendly primes (D ≤ 1024). +//! Q128: `logq = 128` with five `i32` NTT-friendly primes (D ≤ 1024). + +use super::crt::GarnerData; +use super::prime::NttPrime; + +/// Polynomial degree for the base ring `Z_q[X]/(X^d + 1)`. +pub const RING_DEGREE: usize = 64; +/// Maximum ring degree covered by the i32 CRT parameter sets. 
+pub const MAX_CRT_RING_DEGREE: usize = 1024; + +/// Number of CRT primes for the `logq = 32` parameter set. +pub const Q32_NUM_PRIMES: usize = 6; + +/// The modulus `q = 2^32 - 99`. +pub const Q32_MODULUS: u64 = (1u64 << 32) - 99; + +/// CRT primes and per-prime Montgomery constants for `logq = 32`. +/// +/// All constants are for `R = 2^16` (i16 width). +pub const Q32_PRIMES: [NttPrime; Q32_NUM_PRIMES] = [ + NttPrime { + p: 13697, + pinv: 2689, + mont: -2949, + montsq: -994, + }, + NttPrime { + p: 13441, + pinv: 2945, + mont: -1669, + montsq: 3274, + }, + NttPrime { + p: 13313, + pinv: -13311, + mont: -1029, + montsq: -6199, + }, + NttPrime { + p: 12289, + pinv: -12287, + mont: 4091, + montsq: -1337, + }, + NttPrime { + p: 12161, + pinv: 4225, + mont: 4731, + montsq: -6040, + }, + NttPrime { + p: 11777, + pinv: -11775, + mont: -5126, + montsq: 1389, + }, +]; + +/// Garner CRT reconstruction constants for Q32. +pub fn q32_garner() -> GarnerData { + GarnerData::compute(&Q32_PRIMES) +} + +/// Number of CRT primes for the `logq = 64` fast profile (`P > q`). +pub const Q64_NUM_PRIMES_FAST: usize = 3; +/// Number of CRT primes for the `logq = 64` conservative profile (`P > 128*q^2`). +pub const Q64_NUM_PRIMES: usize = 5; + +/// The modulus `q = 2^64 - 59`. +pub const Q64_MODULUS: u64 = u64::MAX - 58; + +/// Number of CRT primes for the `logq = 128` parameter set. +pub const Q128_NUM_PRIMES: usize = 5; + +/// The modulus `q = 2^128 - 275`. +pub const Q128_MODULUS: u128 = u128::MAX - 274; + +/// Raw 30-bit primes for D≤1024, each satisfying `2048 | (p - 1)`. +/// +/// They are ordered descending by value. +pub const D1024_RAW_PRIMES: [i32; Q128_NUM_PRIMES] = + [1073707009, 1073698817, 1073692673, 1073682433, 1073668097]; + +/// Raw 30-bit primes for Q64 fast profile (`K=3`, `P > q`). 
+pub const Q64_RAW_PRIMES_FAST: [i32; Q64_NUM_PRIMES_FAST] = [ + D1024_RAW_PRIMES[0], + D1024_RAW_PRIMES[1], + D1024_RAW_PRIMES[2], +]; + +/// Raw 30-bit primes for Q64 conservative profile (`K=5`, `P > 128*q^2`). +pub const Q64_RAW_PRIMES: [i32; Q64_NUM_PRIMES] = D1024_RAW_PRIMES; + +/// Raw 30-bit primes for Q128, each satisfying `2048 | (p - 1)`. +pub const Q128_RAW_PRIMES: [i32; Q128_NUM_PRIMES] = D1024_RAW_PRIMES; + +/// CRT primes and per-prime Montgomery constants for `logq = 64` fast profile. +pub fn q64_primes_fast() -> [NttPrime; Q64_NUM_PRIMES_FAST] { + std::array::from_fn(|k| NttPrime::compute(Q64_RAW_PRIMES_FAST[k])) +} + +/// Garner CRT reconstruction constants for Q64 fast profile. +pub fn q64_garner_fast() -> GarnerData { + let primes = q64_primes_fast(); + GarnerData::compute(&primes) +} + +/// CRT primes and per-prime Montgomery constants for `logq = 64` conservative profile. +pub fn q64_primes() -> [NttPrime; Q64_NUM_PRIMES] { + std::array::from_fn(|k| NttPrime::compute(Q64_RAW_PRIMES[k])) +} + +/// Garner CRT reconstruction constants for Q64 conservative profile. +pub fn q64_garner() -> GarnerData { + let primes = q64_primes(); + GarnerData::compute(&primes) +} + +/// CRT primes and per-prime Montgomery constants for `logq = 128`. +pub fn q128_primes() -> [NttPrime; Q128_NUM_PRIMES] { + std::array::from_fn(|k| NttPrime::compute(Q128_RAW_PRIMES[k])) +} + +/// Garner CRT reconstruction constants for Q128. 
+pub fn q128_garner() -> GarnerData { + let primes = q128_primes(); + GarnerData::compute(&primes) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn verify_q32_prime_derived_constants() { + for prime in &Q32_PRIMES { + let recomputed = NttPrime::compute(prime.p); + assert_eq!( + prime.pinv, recomputed.pinv, + "pinv mismatch for p={}", + prime.p + ); + assert_eq!( + prime.mont, recomputed.mont, + "mont mismatch for p={}", + prime.p + ); + assert_eq!( + prime.montsq, recomputed.montsq, + "montsq mismatch for p={}", + prime.p + ); + } + } + + #[test] + fn verify_q128_primes_are_valid() { + let primes = q128_primes(); + for np in &primes { + let p = np.p as i64; + assert!(p > 1 && p % 2 == 1, "prime must be odd and > 1"); + assert_eq!( + (p - 1) % 2048, + 0, + "2048 must divide p-1 for D=1024 NTT (p={p})" + ); + // Verify pinv: p * pinv ≡ 1 (mod 2^32) + assert_eq!( + np.p.wrapping_mul(np.pinv), + 1, + "pinv verification failed for p={p}" + ); + } + } + + #[test] + fn verify_q64_primes_are_valid() { + let primes = q64_primes(); + for np in &primes { + let p = np.p as i64; + assert!(p > 1 && p % 2 == 1, "prime must be odd and > 1"); + assert_eq!( + (p - 1) % 2048, + 0, + "2048 must divide p-1 for D=1024 NTT (p={p})" + ); + assert_eq!( + np.p.wrapping_mul(np.pinv), + 1, + "pinv verification failed for p={p}" + ); + } + } + + #[test] + fn garner_data_is_consistent() { + let garner = q32_garner(); + for (i, &prime_i) in Q32_PRIMES.iter().enumerate().skip(1) { + let pi = prime_i.p as i64; + for (j, &prime_j) in Q32_PRIMES[..i].iter().enumerate() { + let pj = prime_j.p as i64; + let g = garner.gamma[i][j] as i64; + assert_eq!( + (pj * g) % pi, + 1, + "garner gamma[{i}][{j}] not inverse of p_{j} mod p_{i}" + ); + } + } + } +} diff --git a/src/algebra/poly.rs b/src/algebra/poly.rs new file mode 100644 index 00000000..ee0c120b --- /dev/null +++ b/src/algebra/poly.rs @@ -0,0 +1,240 @@ +//! Polynomial containers and evaluation utilities. 
+ +use crate::algebra::fields::wide::{HasWide, ReduceTo}; +use crate::error::HachiError; +#[cfg(feature = "parallel")] +use crate::parallel::*; +use crate::primitives::serialization::{ + Compress, HachiDeserialize, HachiSerialize, SerializationError, Valid, Validate, +}; +use crate::protocol::sumcheck::eq_poly::EqPolynomial; +use crate::{cfg_fold_reduce, AdditiveGroup, FieldCore, FromSmallInt}; +use std::io::{Read, Write}; +use std::ops::{Add, Neg, Sub}; + +/// A degree-(pub [F; D]); + +impl Poly { + /// Construct the zero polynomial. + pub fn zero() -> Self { + Self([F::zero(); D]) + } +} + +impl Add for Poly { + type Output = Self; + fn add(self, rhs: Self) -> Self::Output { + let mut out = self.0; + for (dst, src) in out.iter_mut().zip(rhs.0.iter()) { + *dst += *src; + } + Self(out) + } +} + +impl Sub for Poly { + type Output = Self; + fn sub(self, rhs: Self) -> Self::Output { + let mut out = self.0; + for (dst, src) in out.iter_mut().zip(rhs.0.iter()) { + *dst -= *src; + } + Self(out) + } +} + +impl Neg for Poly { + type Output = Self; + fn neg(self) -> Self::Output { + let mut out = self.0; + for coeff in &mut out { + *coeff = -*coeff; + } + Self(out) + } +} + +impl Valid for Poly { + fn check(&self) -> Result<(), SerializationError> { + for x in self.0.iter() { + x.check()?; + } + Ok(()) + } +} + +impl HachiSerialize for Poly { + fn serialize_with_mode( + &self, + mut writer: W, + compress: Compress, + ) -> Result<(), SerializationError> { + for x in self.0.iter() { + x.serialize_with_mode(&mut writer, compress)?; + } + Ok(()) + } + + fn serialized_size(&self, compress: Compress) -> usize { + self.0.iter().map(|x| x.serialized_size(compress)).sum() + } +} + +impl HachiDeserialize for Poly { + fn deserialize_with_mode( + mut reader: R, + compress: Compress, + validate: Validate, + ) -> Result { + let mut arr = [F::zero(); D]; + for coeff in &mut arr { + *coeff = F::deserialize_with_mode(&mut reader, compress, validate)?; + } + let out = Self(arr); + if 
matches!(validate, Validate::Yes) { + out.check()?; + } + Ok(out) + } +} + +/// Evaluate the range-check polynomial `Π_{k=−b/2}^{b/2−1} (w − k)`. +/// +/// This polynomial vanishes exactly on the balanced-digit set `{−b/2, …, b/2−1}`, +/// matching the output of `balanced_decompose_pow2`. +/// Total degree in `w` is `b`. +pub fn range_check_eval(w: E, b: usize) -> E { + let half = (b / 2) as i64; + let mut acc = E::one(); + for k in -half..half { + acc = acc * (w - E::from_i64(k)); + } + acc +} + +/// Evaluate a multilinear polynomial (given by boolean-hypercube evaluations in +/// little-endian bit order) at an arbitrary point via iterated folding. +/// +/// # Errors +/// +/// Returns an error if the evaluation table length is not a power of two or +/// does not match `2^point.len()`. +pub fn multilinear_eval(evals: &[E], point: &[E]) -> Result { + if !evals.len().is_power_of_two() { + return Err(HachiError::InvalidSize { + expected: 1 << point.len(), + actual: evals.len(), + }); + } + if evals.len() != 1 << point.len() { + return Err(HachiError::InvalidSize { + expected: 1 << point.len(), + actual: evals.len(), + }); + } + Ok(multilinear_eval_ref(evals, point)) +} + +#[inline] +fn multilinear_eval_ref(evals: &[E], point: &[E]) -> E { + match point.split_last() { + None => { + debug_assert_eq!(evals.len(), 1); + evals[0] + } + Some((&r, rest)) => { + let half = evals.len() / 2; + let lo = multilinear_eval_ref(&evals[..half], rest); + let hi = multilinear_eval_ref(&evals[half..], rest); + lo + r * (hi - lo) + } + } +} + +/// Fold an evaluation table in place by binding its first variable to `r`, +/// halving the table size. +/// +/// # Panics +/// +/// Panics if the evaluation table length is not a power of two or has fewer +/// than 2 elements. This is a prover-only helper where the caller guarantees +/// well-formed input. 
+#[tracing::instrument(skip_all, name = "fold_evals_in_place")] +pub fn fold_evals_in_place(evals: &mut Vec, r: E) { + assert!( + evals.len().is_power_of_two(), + "evals length must be a power of two" + ); + assert!(evals.len() >= 2, "evals must have at least 2 elements"); + let half = evals.len() / 2; + for i in 0..half { + evals[i] = evals[2 * i] + r * (evals[2 * i + 1] - evals[2 * i]); + } + evals.truncate(half); +} + +/// Evaluate a multilinear polynomial with small integer evaluations at a +/// field point, using the split-eq structure with unreduced accumulation. +/// +/// Uses `HasWide::mul_small_to_wide` in the inner loop: each eq table entry +/// is widened, scaled by the small witness value, and accumulated without +/// reduction. The inner sum is reduced once per outer iteration, then +/// multiplied by the outer eq factor and accumulated again in wide form. +/// +/// Overflow budget: each inner accumulation adds at most `0xFFFF * |small|` +/// to each i32 limb. For `|small| ≤ 128` (b ≤ 256), we can safely +/// accumulate 256 products before an i32 limb overflows. +/// +/// # Errors +/// +/// Returns an error if the table length does not match `2^point.len()`. +#[tracing::instrument(skip_all, name = "multilinear_eval_small")] +pub fn multilinear_eval_small( + evals_small: &[i8], + point: &[E], +) -> Result { + let n = point.len(); + if evals_small.len() != 1 << n { + return Err(HachiError::InvalidSize { + expected: 1 << n, + actual: evals_small.len(), + }); + } + if n == 0 { + return Ok(E::from_i64(evals_small[0] as i64)); + } + + let m = n / 2; + let (r_first, r_second) = point.split_at(m); + let eq_first = EqPolynomial::evals(r_first); + let eq_second = EqPolynomial::evals(r_second); + let in_len = eq_first.len(); + + // Max safe accumulations per chunk before i32 overflow. + // Limbs are 16-bit (0..0xFFFF), scaled by |small| ≤ 128 → 23-bit products. + // i32::MAX / (0xFFFF * 128) ≈ 256. 
+ const CHUNK: usize = 256; + + let outer_accum = cfg_fold_reduce!( + 0..eq_second.len(), + || E::Wide::ZERO, + |acc, x_out| { + let base = x_out * in_len; + let mut inner_field = E::zero(); + for chunk_start in (0..in_len).step_by(CHUNK) { + let chunk_end = (chunk_start + CHUNK).min(in_len); + let mut chunk_acc = E::Wide::ZERO; + for x_in in chunk_start..chunk_end { + chunk_acc += eq_first[x_in].mul_small_to_wide(evals_small[base + x_in] as i32); + } + inner_field += >::reduce(chunk_acc); + } + + acc + E::Wide::from(eq_second[x_out] * inner_field) + }, + |a, b| a + b + ); + Ok(>::reduce(outer_accum)) +} diff --git a/src/algebra/ring/crt_ntt_repr.rs b/src/algebra/ring/crt_ntt_repr.rs new file mode 100644 index 00000000..86fe0e86 --- /dev/null +++ b/src/algebra/ring/crt_ntt_repr.rs @@ -0,0 +1,743 @@ +//! CRT+NTT-domain representation of cyclotomic ring elements. + +use std::array::from_fn; +#[cfg(target_arch = "aarch64")] +use std::mem::size_of; + +use crate::algebra::backend::{CrtReconstruct, NttPrimeOps, NttTransform, ScalarBackend}; +use crate::algebra::ntt::butterfly::{ + forward_ntt, forward_ntt_cyclic, inverse_ntt_cyclic, NttTwiddles, +}; +use crate::algebra::ntt::crt::GarnerData; +#[cfg(target_arch = "aarch64")] +use crate::algebra::ntt::neon; +use crate::algebra::ntt::prime::{MontCoeff, NttPrime, PrimeWidth}; +use crate::{CanonicalField, FieldCore}; + +use super::cyclotomic::CyclotomicRing; + +/// CRT+NTT-domain representation of a cyclotomic ring element. +/// +/// Stores `K` arrays of `D` [`MontCoeff`] values, one per CRT prime. +/// Multiplication is pointwise per prime — O(K*D) vs O(D^2) for coefficient form. +/// +/// Generic over: +/// - `W: PrimeWidth` — integer width (`i16` or `i32`) +/// - `K` — number of CRT primes +/// - `D` — polynomial degree +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct CyclotomicCrtNtt { + pub(crate) limbs: [[MontCoeff; D]; K], +} + +/// Field types that can convert to/from the CRT+NTT representation. 
+/// +/// Blanket-implemented for all `FieldCore + CanonicalField` types. +pub trait CrtNttConvertibleField: FieldCore + CanonicalField {} + +impl CrtNttConvertibleField for F {} + +/// Bundled CRT+NTT parameters for a fixed width/prime-count/degree tuple. +/// +/// Keeps primes/twiddles/Garner constants consistent and avoids passing them +/// independently at every call site. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct CrtNttParamSet { + /// CRT primes with Montgomery constants. + pub primes: [NttPrime; K], + /// Per-prime twiddle tables for forward/inverse NTT. + pub twiddles: [NttTwiddles; K], + /// Garner reconstruction constants for CRT lift-back. + pub garner: GarnerData, +} + +/// Precomputed Montgomery forms for small balanced digit values. +/// +/// Covers the full `{-8, ..., 7}` range (16 entries per CRT prime), +/// which is sufficient for any `log_basis <= 4`. Storing the Montgomery +/// representation eliminates one `from_canonical` (a Montgomery multiply) +/// per coefficient in the `from_i8` hot path. +#[derive(Debug, Clone)] +pub struct DigitMontLut { + vals: [[MontCoeff; 16]; K], +} + +/// Precomputed Montgomery forms for centered integer coefficients in +/// `[-max_abs, max_abs]`. +#[derive(Debug, Clone)] +pub struct CenteredMontLut { + vals: [Vec>; K], + offset: i32, +} + +const DIGIT_LUT_HALF_B: i16 = 8; + +impl DigitMontLut { + /// Build the lookup table from CRT primes. + /// + /// Covers digit values in `{-8, ..., 7}` (balanced representation for + /// `log_basis <= 4`). + pub fn new(params: &CrtNttParamSet) -> Self { + let mut vals = [[MontCoeff::from_raw(W::default()); 16]; K]; + for (k, prime) in params.primes.iter().enumerate() { + for v_idx in 0..16u8 { + let v = v_idx as i64 - DIGIT_LUT_HALF_B as i64; + vals[k][v_idx as usize] = prime.from_canonical(W::from_i64(v)); + } + } + Self { vals } + } + + /// Look up the Montgomery form of a balanced digit for CRT prime `k`. 
+ #[inline(always)] + pub fn get(&self, k: usize, digit: i8) -> MontCoeff { + unsafe { + *self + .vals + .get_unchecked(k) + .get_unchecked((digit as i16 + DIGIT_LUT_HALF_B) as usize) + } + } +} + +impl CenteredMontLut { + /// Build a lookup table for all centered coefficients in `[-max_abs, max_abs]`. + pub fn new(params: &CrtNttParamSet, max_abs: i32) -> Self { + let vals = from_fn(|k| { + let prime = params.primes[k]; + (-max_abs..=max_abs) + .map(|v| prime.from_canonical(W::from_i64(v as i64))) + .collect() + }); + Self { + vals, + offset: max_abs, + } + } + + /// Look up the Montgomery form of a centered coefficient for CRT prime `k`. + #[inline(always)] + pub fn get(&self, k: usize, coeff: i32) -> MontCoeff { + unsafe { + *self + .vals + .get_unchecked(k) + .get_unchecked((coeff + self.offset) as usize) + } + } +} + +impl CrtNttParamSet { + /// Build a full parameter set from CRT primes. + /// + /// Computes per-prime twiddles and Garner reconstruction constants. + pub fn new(primes: [NttPrime; K]) -> Self { + let twiddles = from_fn(|k| NttTwiddles::compute(primes[k])); + let garner = GarnerData::compute(&primes); + Self { + primes, + twiddles, + garner, + } + } +} + +impl CyclotomicCrtNtt { + /// The additive identity (all zeros in every CRT limb). + pub fn zero() -> Self { + Self { + limbs: [[MontCoeff::from_raw(W::default()); D]; K], + } + } + + /// Convert a coefficient-form ring element into CRT+NTT domain + /// using the default scalar backend. + pub fn from_ring( + ring: &CyclotomicRing, + primes: &[NttPrime; K], + twiddles: &[NttTwiddles; K], + ) -> Self { + Self::from_ring_with_backend::(ring, primes, twiddles) + } + + /// Convert a coefficient-form ring element into CRT+NTT domain + /// using a bundled parameter set and the scalar backend. 
+ pub fn from_ring_with_params( + ring: &CyclotomicRing, + params: &CrtNttParamSet, + ) -> Self { + Self::from_ring(ring, ¶ms.primes, ¶ms.twiddles) + } + + /// Convert a coefficient-form ring element into CRT+NTT domain + /// through an explicit backend implementation. + pub fn from_ring_with_backend< + F: CrtNttConvertibleField, + B: NttPrimeOps + NttTransform, + >( + ring: &CyclotomicRing, + primes: &[NttPrime; K], + twiddles: &[NttTwiddles; K], + ) -> Self { + let q = (-F::one()).to_canonical_u128() + 1; + let half_q = q / 2; + let centered_coeffs: [i128; D] = from_fn(|i| { + let canonical = ring.coeffs[i].to_canonical_u128(); + if canonical > half_q { + -((q - canonical) as i128) + } else { + canonical as i128 + } + }); + + let mut limbs = [[MontCoeff::from_raw(W::default()); D]; K]; + for ((limb, prime), tw) in limbs.iter_mut().zip(primes.iter()).zip(twiddles.iter()) { + // Interpret coefficients in centered form (-q/2, q/2] before reducing + // into the CRT primes. This makes the reduction map consistent with + // negacyclic subtraction (which naturally produces negative values). + let p = prime.p.to_i64(); + let p_u64 = p as u64; + let r64 = ((1u128 << 64) % p_u64 as u128) as i64; + let half_p = p / 2; + for (dst, centered) in limb.iter_mut().zip(centered_coeffs.iter()) { + let c = *centered; + let lo = (c as u64 % p_u64) as i64; + let hi = ((c >> 64) as i64).rem_euclid(p); + let mut r = (lo + hi * r64) % p; + if r >= half_p { + r -= p; + } + *dst = B::from_canonical(*prime, W::from_i64(r)); + } + B::forward_ntt(limb, *prime, tw); + } + Self { limbs } + } + + /// Convert small integer coefficients (e.g. gadget digits) into + /// negacyclic CRT+NTT domain, bypassing Fp128 centering entirely. + pub fn from_i8_with_params(digits: &[i8; D], params: &CrtNttParamSet) -> Self { + Self::from_i8_negacyclic_backend::(digits, params) + } + + /// Convert centered i32 coefficients into negacyclic CRT+NTT domain. 
+ pub fn from_centered_i32_with_params( + coeffs: &[i32; D], + params: &CrtNttParamSet, + ) -> Self { + Self::from_centered_i32_negacyclic_backend::(coeffs, params) + } + + /// Like [`Self::from_i8_with_params`] but uses a precomputed + /// [`DigitMontLut`] to replace per-coefficient `from_canonical` + /// (Montgomery multiply) with a table lookup. + #[inline] + pub fn from_i8_with_lut( + digits: &[i8; D], + params: &CrtNttParamSet, + lut: &DigitMontLut, + ) -> Self { + let mut limbs = [[MontCoeff::from_raw(W::default()); D]; K]; + for (k, (limb, tw)) in limbs.iter_mut().zip(params.twiddles.iter()).enumerate() { + for (dst, &d) in limb.iter_mut().zip(digits.iter()) { + *dst = lut.get(k, d); + } + forward_ntt(limb, params.primes[k], tw); + } + Self { limbs } + } + + /// Like [`Self::from_i8_cyclic`] but uses a precomputed [`DigitMontLut`]. + #[inline] + pub fn from_i8_cyclic_with_lut( + digits: &[i8; D], + params: &CrtNttParamSet, + lut: &DigitMontLut, + ) -> Self { + let mut limbs = [[MontCoeff::from_raw(W::default()); D]; K]; + for (k, (limb, tw)) in limbs.iter_mut().zip(params.twiddles.iter()).enumerate() { + for (dst, &d) in limb.iter_mut().zip(digits.iter()) { + *dst = lut.get(k, d); + } + forward_ntt_cyclic(limb, params.primes[k], tw); + } + Self { limbs } + } + + /// Convert centered i32 coefficients into cyclic CRT+NTT domain. + pub fn from_centered_i32_cyclic_with_params( + coeffs: &[i32; D], + params: &CrtNttParamSet, + ) -> Self { + Self::from_centered_i32_cyclic_backend::(coeffs, params) + } + + /// Convert centered i32 coefficients into both negacyclic and cyclic + /// CRT+NTT domains while sharing the coefficient preparation step. 
+ pub fn from_centered_i32_pair_with_params( + coeffs: &[i32; D], + params: &CrtNttParamSet, + ) -> (Self, Self) { + Self::from_centered_i32_pair_backend::(coeffs, params, None) + } + + /// Like [`Self::from_centered_i32_pair_with_params`] but uses a precomputed + /// [`CenteredMontLut`] for the coefficient-to-Montgomery conversion. + pub fn from_centered_i32_pair_with_lut( + coeffs: &[i32; D], + params: &CrtNttParamSet, + lut: &CenteredMontLut, + ) -> (Self, Self) { + Self::from_centered_i32_pair_backend::(coeffs, params, Some(lut)) + } + + fn from_i8_negacyclic_backend + NttTransform>( + digits: &[i8; D], + params: &CrtNttParamSet, + ) -> Self { + let mut limbs = [[MontCoeff::from_raw(W::default()); D]; K]; + for ((limb, prime), tw) in limbs + .iter_mut() + .zip(params.primes.iter()) + .zip(params.twiddles.iter()) + { + for (dst, &d) in limb.iter_mut().zip(digits.iter()) { + *dst = B::from_canonical(*prime, W::from_i64(d as i64)); + } + B::forward_ntt(limb, *prime, tw); + } + Self { limbs } + } + + fn from_centered_i32_negacyclic_backend + NttTransform>( + coeffs: &[i32; D], + params: &CrtNttParamSet, + ) -> Self { + let mut limbs = [[MontCoeff::from_raw(W::default()); D]; K]; + for ((limb, prime), tw) in limbs + .iter_mut() + .zip(params.primes.iter()) + .zip(params.twiddles.iter()) + { + let p = prime.p.to_i64(); + let half_p = p / 2; + for (dst, &coeff) in limb.iter_mut().zip(coeffs.iter()) { + let mut r = (coeff as i64).rem_euclid(p); + if r >= half_p { + r -= p; + } + *dst = B::from_canonical(*prime, W::from_i64(r)); + } + B::forward_ntt(limb, *prime, tw); + } + Self { limbs } + } + + /// Convert small integer coefficients into cyclic CRT+NTT domain, + /// bypassing Fp128 centering entirely. 
+ pub fn from_i8_cyclic(digits: &[i8; D], params: &CrtNttParamSet) -> Self { + Self::from_i8_cyclic_backend::(digits, params) + } + + fn from_i8_cyclic_backend>( + digits: &[i8; D], + params: &CrtNttParamSet, + ) -> Self { + let mut limbs = [[MontCoeff::from_raw(W::default()); D]; K]; + for ((limb, prime), tw) in limbs + .iter_mut() + .zip(params.primes.iter()) + .zip(params.twiddles.iter()) + { + for (dst, &d) in limb.iter_mut().zip(digits.iter()) { + *dst = B::from_canonical(*prime, W::from_i64(d as i64)); + } + forward_ntt_cyclic(limb, *prime, tw); + } + Self { limbs } + } + + fn from_centered_i32_cyclic_backend>( + coeffs: &[i32; D], + params: &CrtNttParamSet, + ) -> Self { + let mut limbs = [[MontCoeff::from_raw(W::default()); D]; K]; + for ((limb, prime), tw) in limbs + .iter_mut() + .zip(params.primes.iter()) + .zip(params.twiddles.iter()) + { + let p = prime.p.to_i64(); + let half_p = p / 2; + for (dst, &coeff) in limb.iter_mut().zip(coeffs.iter()) { + let mut r = (coeff as i64).rem_euclid(p); + if r >= half_p { + r -= p; + } + *dst = B::from_canonical(*prime, W::from_i64(r)); + } + forward_ntt_cyclic(limb, *prime, tw); + } + Self { limbs } + } + + fn from_centered_i32_pair_backend>( + coeffs: &[i32; D], + params: &CrtNttParamSet, + lut: Option<&CenteredMontLut>, + ) -> (Self, Self) { + let mut neg_limbs = [[MontCoeff::from_raw(W::default()); D]; K]; + let mut cyc_limbs = [[MontCoeff::from_raw(W::default()); D]; K]; + for (k, (((neg_limb, cyc_limb), prime), tw)) in neg_limbs + .iter_mut() + .zip(cyc_limbs.iter_mut()) + .zip(params.primes.iter()) + .zip(params.twiddles.iter()) + .enumerate() + { + if let Some(lut) = lut { + for (dst, &coeff) in neg_limb.iter_mut().zip(coeffs.iter()) { + *dst = lut.get(k, coeff); + } + } else { + let p = prime.p.to_i64(); + let half_p = p / 2; + for (dst, &coeff) in neg_limb.iter_mut().zip(coeffs.iter()) { + let mut r = (coeff as i64).rem_euclid(p); + if r >= half_p { + r -= p; + } + *dst = B::from_canonical(*prime, 
W::from_i64(r)); + } + } + *cyc_limb = *neg_limb; + forward_ntt(neg_limb, *prime, tw); + forward_ntt_cyclic(cyc_limb, *prime, tw); + } + (Self { limbs: neg_limbs }, Self { limbs: cyc_limbs }) + } + + /// Convert from CRT+NTT domain back to coefficient form + /// using the default scalar backend. + pub fn to_ring( + &self, + primes: &[NttPrime; K], + twiddles: &[NttTwiddles; K], + garner: &GarnerData, + ) -> CyclotomicRing { + self.to_ring_with_backend::(primes, twiddles, garner) + } + + /// Convert from CRT+NTT domain back to coefficient form + /// using a bundled parameter set and the scalar backend. + pub fn to_ring_with_params( + &self, + params: &CrtNttParamSet, + ) -> CyclotomicRing { + self.to_ring(¶ms.primes, ¶ms.twiddles, ¶ms.garner) + } + + /// Convert from CRT+NTT domain back to coefficient form + /// through an explicit backend implementation. + pub fn to_ring_with_backend< + F: CrtNttConvertibleField, + B: NttPrimeOps + NttTransform + CrtReconstruct, + >( + &self, + primes: &[NttPrime; K], + twiddles: &[NttTwiddles; K], + garner: &GarnerData, + ) -> CyclotomicRing { + let mut canonical = [[W::default(); D]; K]; + for (k, ((can, prime), tw)) in canonical + .iter_mut() + .zip(primes.iter()) + .zip(twiddles.iter()) + .enumerate() + { + let mut limb = self.limbs[k]; + B::inverse_ntt(&mut limb, *prime, tw); + for (dst, src) in can.iter_mut().zip(limb.iter()) { + let canon = B::to_canonical(*prime, *src); + *dst = prime.center(canon); + } + } + + let coeffs = B::reconstruct::(primes, &canonical, garner); + + CyclotomicRing::from_coefficients(coeffs) + } + + /// Add another CRT+NTT element and reduce each coefficient with the matching + /// prime to maintain valid Montgomery ranges using the scalar backend. + pub fn add_reduced(&self, rhs: &Self, primes: &[NttPrime; K]) -> Self { + self.add_reduced_with_backend::(rhs, primes) + } + + /// Add another CRT+NTT element and reduce using a bundled parameter set. 
+ pub fn add_reduced_with_params(&self, rhs: &Self, params: &CrtNttParamSet) -> Self { + self.add_reduced(rhs, ¶ms.primes) + } + + /// Add another CRT+NTT element and reduce each coefficient with the matching + /// prime through an explicit backend implementation. + pub fn add_reduced_with_backend>( + &self, + rhs: &Self, + primes: &[NttPrime; K], + ) -> Self { + let mut out = self.clone(); + for (k, (limb, rhs_limb)) in out.limbs.iter_mut().zip(rhs.limbs.iter()).enumerate() { + let prime = primes[k]; + for (a, b) in limb.iter_mut().zip(rhs_limb.iter()) { + let sum = MontCoeff::from_raw(a.raw().wrapping_add(b.raw())); + *a = B::reduce_range(prime, sum); + } + } + out + } + + /// Subtract another CRT+NTT element and reduce using the scalar backend. + pub fn sub_reduced(&self, rhs: &Self, primes: &[NttPrime; K]) -> Self { + self.sub_reduced_with_backend::(rhs, primes) + } + + /// Subtract another CRT+NTT element and reduce using a bundled parameter set. + pub fn sub_reduced_with_params(&self, rhs: &Self, params: &CrtNttParamSet) -> Self { + self.sub_reduced(rhs, ¶ms.primes) + } + + /// Subtract another CRT+NTT element and reduce through an explicit backend. + pub fn sub_reduced_with_backend>( + &self, + rhs: &Self, + primes: &[NttPrime; K], + ) -> Self { + let mut out = self.clone(); + for (k, (limb, rhs_limb)) in out.limbs.iter_mut().zip(rhs.limbs.iter()).enumerate() { + let prime = primes[k]; + for (a, b) in limb.iter_mut().zip(rhs_limb.iter()) { + let diff = MontCoeff::from_raw(a.raw().wrapping_sub(b.raw())); + *a = B::reduce_range(prime, diff); + } + } + out + } + + /// Negate each CRT+NTT coefficient and reduce using the scalar backend. + pub fn neg_reduced(&self, primes: &[NttPrime; K]) -> Self { + self.neg_reduced_with_backend::(primes) + } + + /// Negate each CRT+NTT coefficient and reduce using a bundled parameter set. 
+ pub fn neg_reduced_with_params(&self, params: &CrtNttParamSet) -> Self { + self.neg_reduced(¶ms.primes) + } + + /// Negate each CRT+NTT coefficient and reduce through an explicit backend. + pub fn neg_reduced_with_backend>( + &self, + primes: &[NttPrime; K], + ) -> Self { + let mut out = self.clone(); + for (k, limb) in out.limbs.iter_mut().enumerate() { + let prime = primes[k]; + for a in limb.iter_mut() { + let neg = MontCoeff::from_raw(a.raw().wrapping_neg()); + *a = B::reduce_range(prime, neg); + } + } + out + } + + /// Convert a coefficient-form ring element into CRT+**cyclic** NTT domain. + /// + /// Evaluates at D-th roots of unity (X^D - 1) instead of X^D + 1. + /// Used together with `to_ring_cyclic` to compute unreduced polynomial products. + pub fn from_ring_cyclic( + ring: &CyclotomicRing, + params: &CrtNttParamSet, + ) -> Self { + Self::from_ring_cyclic_with_backend::(ring, params) + } + + /// Convert a coefficient-form ring element into CRT+**cyclic** NTT domain + /// through an explicit backend. 
+ pub fn from_ring_cyclic_with_backend>( + ring: &CyclotomicRing, + params: &CrtNttParamSet, + ) -> Self { + let q = (-F::one()).to_canonical_u128() + 1; + let half_q = q / 2; + let centered_coeffs: [i128; D] = from_fn(|i| { + let canonical = ring.coeffs[i].to_canonical_u128(); + if canonical > half_q { + -((q - canonical) as i128) + } else { + canonical as i128 + } + }); + + let mut limbs = [[MontCoeff::from_raw(W::default()); D]; K]; + for ((limb, prime), tw) in limbs + .iter_mut() + .zip(params.primes.iter()) + .zip(params.twiddles.iter()) + { + let p = prime.p.to_i64(); + let p_u64 = p as u64; + let r64 = ((1u128 << 64) % p_u64 as u128) as i64; + let half_p = p / 2; + for (dst, centered) in limb.iter_mut().zip(centered_coeffs.iter()) { + let c = *centered; + let lo = (c as u64 % p_u64) as i64; + let hi = ((c >> 64) as i64).rem_euclid(p); + let mut r = (lo + hi * r64) % p; + if r >= half_p { + r -= p; + } + *dst = B::from_canonical(*prime, W::from_i64(r)); + } + forward_ntt_cyclic(limb, *prime, tw); + } + Self { limbs } + } + + /// Convert from CRT+**cyclic** NTT domain back to coefficient form. + /// + /// Inverse of `from_ring_cyclic`: applies inverse cyclic NTT then CRT reconstruction. + pub fn to_ring_cyclic( + &self, + params: &CrtNttParamSet, + ) -> CyclotomicRing { + self.to_ring_cyclic_with_backend::(params) + } + + /// Convert from CRT+**cyclic** NTT domain back to coefficient form + /// through an explicit backend. 
+ pub fn to_ring_cyclic_with_backend< + F: CrtNttConvertibleField, + B: NttPrimeOps + CrtReconstruct, + >( + &self, + params: &CrtNttParamSet, + ) -> CyclotomicRing { + let mut canonical = [[W::default(); D]; K]; + for (k, ((can, prime), tw)) in canonical + .iter_mut() + .zip(params.primes.iter()) + .zip(params.twiddles.iter()) + .enumerate() + { + let mut limb = self.limbs[k]; + inverse_ntt_cyclic(&mut limb, *prime, tw); + for (dst, src) in can.iter_mut().zip(limb.iter()) { + let canon = B::to_canonical(*prime, *src); + *dst = prime.center(canon); + } + } + let coeffs = B::reconstruct::(¶ms.primes, &canonical, ¶ms.garner); + CyclotomicRing::from_coefficients(coeffs) + } + + /// Pointwise multiplication in CRT+NTT domain using the scalar backend. + pub fn pointwise_mul(&self, rhs: &Self, primes: &[NttPrime; K]) -> Self { + self.pointwise_mul_with_backend::(rhs, primes) + } + + /// Pointwise multiplication in CRT+NTT domain using a bundled parameter set. + #[inline(always)] + pub fn pointwise_mul_with_params(&self, rhs: &Self, params: &CrtNttParamSet) -> Self { + self.pointwise_mul(rhs, ¶ms.primes) + } + + /// Accumulate `lhs * rhs` into `self` in CRT+NTT domain. + /// + /// On AArch64, this uses the fused NEON pointwise-multiply-accumulate kernel + /// when available; otherwise it falls back to the scalar loop. 
+ #[inline(always)] + pub fn add_assign_pointwise_mul_with_params( + &mut self, + lhs: &Self, + rhs: &Self, + params: &CrtNttParamSet, + ) { + #[cfg(target_arch = "aarch64")] + if neon::use_neon_ntt() { + for k in 0..K { + let prime = params.primes[k]; + unsafe { + if size_of::() == size_of::() { + neon::pointwise_mul_acc_i32( + self.limbs[k].as_mut_ptr() as *mut i32, + lhs.limbs[k].as_ptr() as *const i32, + rhs.limbs[k].as_ptr() as *const i32, + D, + prime.p.to_i64() as i32, + prime.pinv.to_i64() as i32, + ); + } else { + neon::pointwise_mul_acc_i16( + self.limbs[k].as_mut_ptr() as *mut i16, + lhs.limbs[k].as_ptr() as *const i16, + rhs.limbs[k].as_ptr() as *const i16, + D, + prime.p.to_i64() as i16, + prime.pinv.to_i64() as i16, + ); + } + } + } + return; + } + + for k in 0..K { + let prime = params.primes[k]; + let acc_limb = &mut self.limbs[k]; + let lhs_limb = &lhs.limbs[k]; + let rhs_limb = &rhs.limbs[k]; + for ((acc_coeff, lhs_coeff), rhs_coeff) in acc_limb + .iter_mut() + .zip(lhs_limb.iter()) + .zip(rhs_limb.iter()) + { + let prod = prime.mul(*lhs_coeff, *rhs_coeff); + let sum = MontCoeff::from_raw(acc_coeff.raw().wrapping_add(prod.raw())); + *acc_coeff = prime.reduce_range(sum); + } + } + } + + /// Pointwise multiplication in CRT+NTT domain through an explicit backend. + pub fn pointwise_mul_with_backend>( + &self, + rhs: &Self, + primes: &[NttPrime; K], + ) -> Self { + let mut out = [[MontCoeff::from_raw(W::default()); D]; K]; + for (k, ((o, a), b)) in out + .iter_mut() + .zip(self.limbs.iter()) + .zip(rhs.limbs.iter()) + .enumerate() + { + let prime = primes[k]; + B::pointwise_mul(prime, o, a, b); + // Keep coefficients in a bounded range for subsequent inverse NTT. + for c in o.iter_mut() { + *c = B::reduce_range(prime, *c); + } + } + Self { limbs: out } + } + + /// Apply `sigma_{-1}` directly in NTT domain (`slot[j] -> slot[D-1-j]`). + /// + /// This is a pure index permutation per CRT limb and does not negate values. 
+ pub fn conjugation_automorphism_ntt(&self) -> Self { + let limbs = std::array::from_fn(|k| { + std::array::from_fn(|j| self.limbs[k][D.saturating_sub(1) - j]) + }); + Self { limbs } + } +} diff --git a/src/algebra/ring/cyclotomic.rs b/src/algebra/ring/cyclotomic.rs new file mode 100644 index 00000000..c33b4ed8 --- /dev/null +++ b/src/algebra/ring/cyclotomic.rs @@ -0,0 +1,1071 @@ +//! Cyclotomic ring `Z_q[X]/(X^D + 1)` in coefficient form. + +use super::sparse_challenge::SparseChallenge; +use crate::algebra::fields::wide::ReduceTo; +use crate::primitives::serialization::{ + Compress, HachiDeserialize, HachiSerialize, SerializationError, Valid, Validate, +}; +use crate::{AdditiveGroup, CanonicalField, FieldCore, FieldSampling}; +use rand_core::RngCore; +use std::array::from_fn; +use std::io::{Read, Write}; +use std::ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign}; + +/// Element of the cyclotomic ring `Z_q[X]/(X^D + 1)`. +/// +/// Stored as `D` coefficients in the base field `F`, representing +/// `a_0 + a_1*X + ... + a_{D-1}*X^{D-1}`. +/// +/// Multiplication is negacyclic convolution: `X^D = -1`, so a product +/// term at index `i + j >= D` wraps to index `(i + j) - D` with a sign flip. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +#[repr(transparent)] +pub struct CyclotomicRing { + pub(crate) coeffs: [F; D], +} + +impl CyclotomicRing { + /// Construct from a coefficient array. + #[inline] + pub fn from_coefficients(coeffs: [F; D]) -> Self { + Self { coeffs } + } + + /// Construct from a slice, zero-padding if shorter than `D`. + /// + /// Avoids creating a `[F; D]` stack temporary when `D` is large. + #[inline] + pub fn from_slice(slice: &[F]) -> Self { + let mut coeffs = [F::zero(); D]; + let len = slice.len().min(D); + coeffs[..len].copy_from_slice(&slice[..len]); + Self { coeffs } + } + + /// Borrow the coefficient array. + #[inline] + pub fn coefficients(&self) -> &[F; D] { + &self.coeffs + } + + /// Mutably borrow the coefficient array. 
+ #[inline] + pub fn coefficients_mut(&mut self) -> &mut [F; D] { + &mut self.coeffs + } + + /// The additive identity (all-zero polynomial). + #[inline] + pub fn zero() -> Self { + Self { + coeffs: [F::zero(); D], + } + } + + /// The multiplicative identity (`1 + 0*X + ... + 0*X^{D-1}`). + #[inline] + pub fn one() -> Self { + let mut coeffs = [F::zero(); D]; + coeffs[0] = F::one(); + Self { coeffs } + } + + /// The monomial `X` (i.e., `[0, 1, 0, ..., 0]`). + /// + /// # Panics + /// + /// Panics if `D < 2`. + #[inline] + pub fn x() -> Self { + assert!(D >= 2, "ring degree must be at least 2"); + let mut coeffs = [F::zero(); D]; + coeffs[1] = F::one(); + Self { coeffs } + } + + /// Scalar multiplication: multiply every coefficient by `k`. + #[inline] + pub fn scale(&self, k: &F) -> Self { + let mut out = self.coeffs; + for c in &mut out { + *c = *c * *k; + } + Self { coeffs: out } + } + + /// Apply the cyclotomic automorphism `sigma_k: X -> X^k` for odd `k`. + /// + /// In `Z_q[X]/(X^D + 1)`, this permutes/sign-flips coefficients using + /// exponent reduction modulo `2D`. + /// + /// # Panics + /// + /// Panics if `D == 0` or `k` is not odd modulo `2D`. + pub fn sigma(&self, k: usize) -> Self { + assert!(D > 0, "ring degree must be non-zero"); + let two_d = 2 * D; + let k_mod = k % two_d; + assert!(k_mod % 2 == 1, "sigma_k requires odd k in Z_q[X]/(X^D + 1)"); + + let mut out = [F::zero(); D]; + for (j, coeff) in self.coeffs.iter().copied().enumerate() { + let idx = (j * k_mod) % two_d; + if idx < D { + out[idx] += coeff; + } else { + out[idx - D] -= coeff; + } + } + Self { coeffs: out } + } + + /// Apply `sigma_{-1}` (`X -> X^{-1} = X^{2D-1}` in this ring). + /// + /// # Panics + /// + /// Panics if `D == 0`. + pub fn sigma_m1(&self) -> Self { + assert!(D > 0, "ring degree must be non-zero"); + self.sigma(2 * D - 1) + } + + /// Multiply by `X^k` in `Z_q[X]/(X^D + 1)` via O(D) coefficient rotation. 
+ /// + /// Since `X^D = -1`, coefficients that wrap past index `D` get negated. + #[inline] + pub fn negacyclic_shift(&self, k: usize) -> Self { + let k = k % D; + if k == 0 { + return *self; + } + let mut out = [F::zero(); D]; + for i in 0..D { + let target = i + k; + if target < D { + out[target] = self.coeffs[i]; + } else { + out[target - D] = -self.coeffs[i]; + } + } + Self { coeffs: out } + } + + /// Multiply `self` by a sum of monomials `X^{k_1} + X^{k_2} + ...` + /// + /// Each term is a negacyclic shift, so the total cost is + /// `O(positions.len() * D)` field additions with zero multiplications. + pub fn mul_by_monomial_sum(&self, nonzero_positions: &[usize]) -> Self { + let mut result = Self::zero(); + for &k in nonzero_positions { + self.shift_accumulate_into(&mut result, k); + } + result + } + + /// Fused negacyclic shift + accumulate: `dst += self * X^k`. + /// + /// Equivalent to `*dst += self.negacyclic_shift(k)` but avoids + /// allocating a temporary ring element. + #[inline] + pub fn shift_accumulate_into(&self, dst: &mut Self, k: usize) { + let k = k % D; + if k == 0 { + for i in 0..D { + dst.coeffs[i] += self.coeffs[i]; + } + return; + } + for i in 0..D { + let target = i + k; + if target < D { + dst.coeffs[target] += self.coeffs[i]; + } else { + dst.coeffs[target - D] -= self.coeffs[i]; + } + } + } + + /// Fused negacyclic shift + subtract: `dst -= self * X^k`. + /// + /// Equivalent to `*dst -= self.negacyclic_shift(k)` but avoids + /// allocating a temporary ring element. + #[inline] + pub fn shift_sub_into(&self, dst: &mut Self, k: usize) { + let k = k % D; + if k == 0 { + for i in 0..D { + dst.coeffs[i] -= self.coeffs[i]; + } + return; + } + for i in 0..D { + let target = i + k; + if target < D { + dst.coeffs[target] -= self.coeffs[i]; + } else { + dst.coeffs[target - D] += self.coeffs[i]; + } + } + } + + /// Fused negacyclic shift + scaled accumulate: `dst += scale * self * X^k`. 
+ #[inline] + pub fn shift_scale_accumulate_into(&self, dst: &mut Self, k: usize, scale: F) { + if scale.is_zero() { + return; + } + let k = k % D; + if k == 0 { + for i in 0..D { + dst.coeffs[i] += self.coeffs[i] * scale; + } + return; + } + for i in 0..D { + let target = i + k; + let product = self.coeffs[i] * scale; + if target < D { + dst.coeffs[target] += product; + } else { + dst.coeffs[target - D] -= product; + } + } + } + + /// Fused multiply-by-monomial-sum + accumulate: + /// `dst += self * (X^{k_1} + X^{k_2} + ...)`. + /// + /// Equivalent to `*dst += self.mul_by_monomial_sum(positions)` but avoids + /// all intermediate temporaries. + pub fn mul_by_monomial_sum_into(&self, dst: &mut Self, nonzero_positions: &[usize]) { + for &k in nonzero_positions { + self.shift_accumulate_into(dst, k); + } + } + + /// Multiply `self` by a sparse challenge element. + /// + /// Cost: `O(omega * D)` field additions instead of `O(D^2)` multiplications. + /// For `omega=23, D=256` this is 5,888 adds vs 65,536 muls. + pub fn mul_by_sparse(&self, challenge: &SparseChallenge) -> Self + where + F: CanonicalField, + { + let mut result = Self::zero(); + self.mul_by_sparse_into(challenge, &mut result); + result + } + + /// Fused `dst += self * challenge` for a sparse challenge element. + pub fn mul_by_sparse_into(&self, challenge: &SparseChallenge, dst: &mut Self) + where + F: CanonicalField, + { + for (&pos, &coeff) in challenge.positions.iter().zip(challenge.coeffs.iter()) { + match coeff { + 1 => self.shift_accumulate_into(dst, pos as usize), + -1 => self.shift_sub_into(dst, pos as usize), + c => self.shift_scale_accumulate_into(dst, pos as usize, F::from_i64(c as i64)), + } + } + } + + /// Fused `dst += self * rhs` when `rhs` is coefficient-sparse. + /// + /// This is exact for any field coefficients in `rhs`, but runs in + /// `O(hw(rhs) * D)` instead of `O(D^2)`. 
+ pub fn mul_accumulate_sparse_rhs_into(&self, rhs: &Self, dst: &mut Self) + where + F: CanonicalField, + { + for (pos, coeff) in rhs.coeffs.iter().copied().enumerate() { + if coeff.is_zero() { + continue; + } + if coeff == F::one() { + self.shift_accumulate_into(dst, pos); + } else if coeff == -F::one() { + self.shift_sub_into(dst, pos); + } else { + self.shift_scale_accumulate_into(dst, pos, coeff); + } + } + } + + /// Fused `dst += self * rhs` via schoolbook negacyclic convolution. + /// + /// Accumulates the product directly into `dst` without allocating a + /// temporary ring element, saving D additions and one memory pass. + #[inline] + pub fn mul_accumulate_into(&self, rhs: &Self, dst: &mut Self) { + for i in 0..D { + let ai = self.coeffs[i]; + if ai.is_zero() { + continue; + } + + let (dst_wrap, dst_direct) = dst.coeffs.split_at_mut(i); + let (rhs_direct, rhs_wrap) = rhs.coeffs.split_at(D - i); + + for (dst_coeff, rhs_coeff) in dst_direct.iter_mut().zip(rhs_direct.iter()) { + *dst_coeff += ai * *rhs_coeff; + } + for (dst_coeff, rhs_coeff) in dst_wrap.iter_mut().zip(rhs_wrap.iter()) { + *dst_coeff -= ai * *rhs_coeff; + } + } + } + + /// Check whether all coefficients are zero. + #[inline] + pub fn is_zero(&self) -> bool { + self.coeffs.iter().all(|c| c.is_zero()) + } + + /// Count non-zero coefficients. + #[inline] + pub fn hamming_weight(&self) -> usize { + self.coeffs.iter().filter(|c| !c.is_zero()).count() + } + + /// Sample a sparse challenge with exactly `omega` non-zeros in `{+1, -1}`. + /// + /// # Panics + /// + /// Panics if `omega > D` or `D == 0` with non-zero `omega`. 
+ pub fn sample_sparse_pm1(rng: &mut R, omega: usize) -> Self { + assert!(omega <= D, "omega must be <= ring degree"); + assert!(D > 0 || omega == 0, "ring degree must be non-zero"); + + let mut coeffs = [F::zero(); D]; + let mut placed = 0usize; + while placed < omega { + let idx = (rng.next_u64() % (D as u64)) as usize; + if coeffs[idx].is_zero() { + coeffs[idx] = if (rng.next_u32() & 1) == 0 { + F::one() + } else { + -F::one() + }; + placed += 1; + } + } + Self { coeffs } + } +} + +impl CyclotomicRing { + /// Balanced decomposition writing directly into a pre-allocated output slice. + /// + /// `out` must have length exactly `levels`. Each element receives one digit plane. + /// + /// # Panics + /// + /// Panics if `log_basis == 0`, `log_basis >= 128`, or `out.len() * log_basis > 128 + log_basis`. + pub fn balanced_decompose_pow2_into(&self, out: &mut [Self], log_basis: u32) { + let levels = out.len(); + assert!(log_basis > 0 && log_basis < 128, "invalid log_basis"); + assert!( + (levels as u32).saturating_mul(log_basis) <= 128 + log_basis, + "levels * log_basis must be <= 128 + log_basis" + ); + + let half_b = 1i128 << (log_basis - 1); + let b = half_b << 1; + let mask = b - 1; + let q = (-F::one()).to_canonical_u128() + 1; + let half_q = q / 2; + + for plane in out.iter_mut() { + *plane = Self::zero(); + } + + for i in 0..D { + let canonical = self.coeffs[i].to_canonical_u128(); + let mut c: i128 = if canonical > half_q { + -((q - canonical) as i128) + } else { + canonical as i128 + }; + + for plane in out.iter_mut() { + let d = c & mask; + let balanced = if d >= half_b { d - b } else { d }; + c = (c - balanced) >> log_basis; + + plane.coeffs[i] = if balanced >= 0 { + F::from_canonical_u128_reduced(balanced as u128) + } else { + F::from_canonical_u128_reduced(q - ((-balanced) as u128)) + }; + } + } + } + + /// Squared Euclidean norm of centered integer coefficients. 
+ /// + /// Coefficients are centered into `(-q/2, q/2]` and accumulated as + /// `sum_i c_i^2`, using saturating arithmetic. + #[inline] + pub fn coeff_norm_sq(&self) -> u128 + where + F: CanonicalField, + { + let q = (-F::one()).to_canonical_u128() + 1; + let half_q = q / 2; + self.coeffs.iter().fold(0u128, |acc, &coeff| { + let canonical = coeff.to_canonical_u128(); + let centered: i128 = if canonical > half_q { + -((q - canonical) as i128) + } else { + canonical as i128 + }; + let abs = centered.unsigned_abs(); + acc.saturating_add(abs.saturating_mul(abs)) + }) + } + + /// Functional gadget recomposition (`G * digits`) for base `2^log_basis`. + /// + /// Coefficients from each part are interpreted as one digit plane and + /// recombined back into canonical integers (then reduced into the field). + /// + /// # Panics + /// + /// Panics if `log_basis == 0`, `log_basis >= 128`, or `parts.len() * log_basis > 128`. + pub fn gadget_recompose_pow2(parts: &[Self], log_basis: u32) -> Self { + if parts.is_empty() { + return Self::zero(); + } + + assert!( + log_basis > 0 && log_basis <= 128, + "invalid log_basis: {log_basis}" + ); + + if parts.len() == 1 { + return parts[0]; + } + + let b = F::from_canonical_u128_reduced(1u128 << log_basis); + let coeffs = from_fn(|i| { + let mut acc = F::zero(); + let mut power = F::one(); + for part in parts.iter() { + acc += part.coeffs[i] * power; + power = power * b; + } + acc + }); + Self { coeffs } + } + + /// Recompose from i8 digit planes (output of `balanced_decompose_pow2_i8`). + /// + /// # Panics + /// + /// Panics if `log_basis` is zero or >= 128. 
+ pub fn gadget_recompose_pow2_i8(digits: &[[i8; D]], log_basis: u32) -> Self + where + F: CanonicalField, + { + if digits.is_empty() { + return Self::zero(); + } + assert!( + log_basis > 0 && log_basis <= 128, + "invalid log_basis: {log_basis}" + ); + + if digits.len() == 1 { + let coeffs = from_fn(|i| F::from_i64(digits[0][i] as i64)); + return Self { coeffs }; + } + + let b = F::from_canonical_u128_reduced(1u128 << log_basis); + let coeffs = from_fn(|i| { + let mut acc = F::zero(); + let mut power = F::one(); + for plane in digits { + acc += F::from_i64(plane[i] as i64) * power; + power = power * b; + } + acc + }); + Self { coeffs } + } + + /// Balanced (centered) base-`2^log_basis` gadget decomposition: `G^{-1}`. + /// + /// Each coefficient `c` (centered into `(-q/2, q/2]`) is decomposed into + /// `levels` balanced digits `d_k ∈ [-b/2, b/2)` satisfying + /// `c ≡ Σ_k d_k · b^k (mod q)`. + /// + /// Negative digits are stored as their field representation (`q + d`). + /// + /// # Panics + /// + /// Panics if `log_basis == 0`, `log_basis >= 128`, or `levels * log_basis > 128`. 
+ pub fn balanced_decompose_pow2(&self, levels: usize, log_basis: u32) -> Vec { + assert!(log_basis > 0 && log_basis < 128, "invalid log_basis"); + assert!( + (levels as u32).saturating_mul(log_basis) <= 128 + log_basis, + "levels * log_basis must be <= 128 + log_basis" + ); + + let half_b = 1i128 << (log_basis - 1); + let b = half_b << 1; + let mask = b - 1; + let q = (-F::one()).to_canonical_u128() + 1; + let half_q = q / 2; + + let mut digit_planes: Vec<[F; D]> = (0..levels).map(|_| [F::zero(); D]).collect(); + + for i in 0..D { + let canonical = self.coeffs[i].to_canonical_u128(); + let mut c: i128 = if canonical > half_q { + -((q - canonical) as i128) + } else { + canonical as i128 + }; + + for plane in digit_planes.iter_mut() { + let d = c & mask; + let balanced = if d >= half_b { d - b } else { d }; + c = (c - balanced) >> log_basis; + + plane[i] = if balanced >= 0 { + F::from_canonical_u128_reduced(balanced as u128) + } else { + F::from_canonical_u128_reduced(q - ((-balanced) as u128)) + }; + } + } + + digit_planes + .into_iter() + .map(Self::from_coefficients) + .collect() + } + + /// Balanced gadget decomposition into native `i8` digits. + /// + /// Same semantics as [`balanced_decompose_pow2`](Self::balanced_decompose_pow2) + /// but stores each digit as `i8` instead of a field element, avoiding + /// the cost of `F::from_canonical_u128_reduced`. + /// + /// Requires `log_basis <= 7` so digits fit in `[-64, 63]` (i8 range). + /// + /// # Panics + /// + /// Panics if `log_basis` is 0 or > 7, or if `levels * log_basis > 128 + log_basis`. 
+ pub fn balanced_decompose_pow2_i8_into(&self, out: &mut [[i8; D]], log_basis: u32) + where + F: CanonicalField, + { + let levels = out.len(); + assert!( + log_basis > 0 && log_basis <= 7, + "log_basis must be in 1..=7 for i8 output" + ); + assert!( + (levels as u32).saturating_mul(log_basis) <= 128 + log_basis, + "levels * log_basis must be <= 128 + log_basis" + ); + + let half_b = 1i128 << (log_basis - 1); + let b = half_b << 1; + let mask = b - 1; + let q = (-F::one()).to_canonical_u128() + 1; + let half_q = q / 2; + + for i in 0..D { + let canonical = self.coeffs[i].to_canonical_u128(); + let mut c: i128 = if canonical > half_q { + -((q - canonical) as i128) + } else { + canonical as i128 + }; + + for plane in out.iter_mut() { + let d = c & mask; + let balanced = if d >= half_b { d - b } else { d }; + c = (c - balanced) >> log_basis; + plane[i] = balanced as i8; + } + } + } + + /// Allocating variant of [`balanced_decompose_pow2_i8_into`](Self::balanced_decompose_pow2_i8_into). + pub fn balanced_decompose_pow2_i8(&self, levels: usize, log_basis: u32) -> Vec<[i8; D]> + where + F: CanonicalField, + { + let mut digit_planes: Vec<[i8; D]> = vec![[0i8; D]; levels]; + self.balanced_decompose_pow2_i8_into(&mut digit_planes, log_basis); + digit_planes + } + + /// Balanced decomposition where the last digit carries the remainder. + /// + /// The first `levels-1` digits are balanced in `[-b/2, b/2)`, while the + /// final digit is the remaining (possibly larger) centered value. + /// + /// # Panics + /// + /// Panics if `levels` is zero, `log_basis` is zero or >= 128, or + /// `(levels - 1) * log_basis >= 128`. 
+ pub fn balanced_decompose_pow2_with_carry_into(&self, out: &mut [Self], log_basis: u32) + where + F: CanonicalField, + { + let levels = out.len(); + assert!(levels > 0, "levels must be positive"); + assert!( + log_basis > 0 && log_basis <= 128, + "invalid log_basis: {log_basis}" + ); + assert!( + ((levels - 1) as u32).saturating_mul(log_basis) < 128, + "(levels-1) * log_basis must be < 128" + ); + + // When levels==1 every coefficient takes the carry path and b/half_b + // are unused, so skip the shift that would overflow at log_basis==128. + let (b, half_b) = if levels == 1 { + (0i128, 0i128) + } else { + let b = 1i128 << log_basis; + (b, b / 2) + }; + let q = (-F::one()).to_canonical_u128() + 1; + let half_q = q / 2; + + for i in 0..D { + let canonical = self.coeffs[i].to_canonical_u128(); + let mut c: i128 = if canonical > half_q { + -((q - canonical) as i128) + } else { + canonical as i128 + }; + + for (plane_idx, plane) in out.iter_mut().enumerate() { + let balanced = if plane_idx + 1 == levels { + c + } else { + let d = c.rem_euclid(b); + let digit = if d >= half_b { d - b } else { d }; + c = (c - digit) / b; + digit + }; + + plane.coeffs[i] = if balanced >= 0 { + F::from_canonical_u128_reduced(balanced as u128) + } else { + F::from_canonical_u128_reduced(q - ((-balanced) as u128)) + }; + } + } + } + + /// Allocating variant of + /// [`balanced_decompose_pow2_with_carry_into`](Self::balanced_decompose_pow2_with_carry_into). + pub fn balanced_decompose_pow2_with_carry(&self, levels: usize, log_basis: u32) -> Vec + where + F: CanonicalField, + { + let mut out = vec![Self::zero(); levels]; + self.balanced_decompose_pow2_with_carry_into(&mut out, log_basis); + out + } +} + +impl CyclotomicRing { + /// Generate a random ring element. 
+ pub fn random(rng: &mut R) -> Self { + Self { + coeffs: from_fn(|_| F::sample(rng)), + } + } +} + +impl AddAssign for CyclotomicRing { + fn add_assign(&mut self, rhs: Self) { + for (dst, src) in self.coeffs.iter_mut().zip(rhs.coeffs.iter()) { + *dst = *dst + *src; + } + } +} + +impl SubAssign for CyclotomicRing { + fn sub_assign(&mut self, rhs: Self) { + for (dst, src) in self.coeffs.iter_mut().zip(rhs.coeffs.iter()) { + *dst = *dst - *src; + } + } +} + +impl Add for CyclotomicRing { + type Output = Self; + fn add(mut self, rhs: Self) -> Self { + self += rhs; + self + } +} + +impl Sub for CyclotomicRing { + type Output = Self; + fn sub(mut self, rhs: Self) -> Self { + self -= rhs; + self + } +} + +impl Neg for CyclotomicRing { + type Output = Self; + fn neg(self) -> Self { + let mut out = self.coeffs; + for c in &mut out { + *c = -*c; + } + Self { coeffs: out } + } +} + +impl MulAssign for CyclotomicRing { + fn mul_assign(&mut self, rhs: Self) { + *self = *self * rhs; + } +} + +impl<'a, F: FieldCore, const D: usize> Add<&'a Self> for CyclotomicRing { + type Output = Self; + fn add(self, rhs: &'a Self) -> Self { + self + *rhs + } +} + +impl<'a, F: FieldCore, const D: usize> Sub<&'a Self> for CyclotomicRing { + type Output = Self; + fn sub(self, rhs: &'a Self) -> Self { + self - *rhs + } +} + +impl<'a, F: FieldCore, const D: usize> Mul<&'a Self> for CyclotomicRing { + type Output = Self; + fn mul(self, rhs: &'a Self) -> Self { + self * *rhs + } +} + +/// Schoolbook negacyclic convolution: O(D^2). +/// +/// For each pair `(i, j)`: +/// - If `i + j < D`: accumulate `a_i * b_j` at index `i + j`. +/// - If `i + j >= D`: accumulate `-(a_i * b_j)` at index `(i + j) - D`. 
+impl Mul for CyclotomicRing { + type Output = Self; + fn mul(self, rhs: Self) -> Self { + let mut out = [F::zero(); D]; + for i in 0..D { + for j in 0..D { + let product = self.coeffs[i] * rhs.coeffs[j]; + let idx = i + j; + if idx < D { + out[idx] += product; + } else { + out[idx - D] -= product; + } + } + } + Self { coeffs: out } + } +} + +impl Valid for CyclotomicRing { + fn check(&self) -> Result<(), SerializationError> { + for x in self.coeffs.iter() { + x.check()?; + } + Ok(()) + } +} + +impl HachiSerialize for CyclotomicRing { + fn serialize_with_mode( + &self, + mut writer: W, + compress: Compress, + ) -> Result<(), SerializationError> { + for x in self.coeffs.iter() { + x.serialize_with_mode(&mut writer, compress)?; + } + Ok(()) + } + + fn serialized_size(&self, compress: Compress) -> usize { + self.coeffs + .iter() + .map(|x| x.serialized_size(compress)) + .sum() + } +} + +impl HachiDeserialize for CyclotomicRing { + fn deserialize_with_mode( + mut reader: R, + compress: Compress, + validate: Validate, + ) -> Result { + let mut coeffs = [F::zero(); D]; + for c in &mut coeffs { + *c = F::deserialize_with_mode(&mut reader, compress, validate)?; + } + let out = Self { coeffs }; + if matches!(validate, Validate::Yes) { + out.check()?; + } + Ok(out) + } +} + +impl Default for CyclotomicRing { + fn default() -> Self { + Self::zero() + } +} + +/// Wide (unreduced) cyclotomic ring element for carry-free accumulation. +/// +/// Coefficients are wide accumulators (`W: AdditiveGroup`) that support +/// addition/subtraction without modular reduction. After accumulation, +/// call [`reduce`](Self::reduce) to convert back to `CyclotomicRing`. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +#[repr(transparent)] +pub struct WideCyclotomicRing { + pub(crate) coeffs: [W; D], +} + +impl WideCyclotomicRing { + /// The additive identity (all-zero coefficients). + pub const ZERO: Self = Self { + coeffs: [W::ZERO; D], + }; + + /// Returns the zero ring element. 
+ #[inline] + pub fn zero() -> Self { + Self::ZERO + } + + /// Convert a reduced `CyclotomicRing` into wide form. + #[inline] + pub fn from_ring(ring: &CyclotomicRing) -> Self + where + W: From, + { + Self { + coeffs: from_fn(|i| W::from(ring.coeffs[i])), + } + } + + /// Reduce all coefficients back to canonical field form. + #[inline] + pub fn reduce(&self) -> CyclotomicRing + where + W: ReduceTo, + { + CyclotomicRing { + coeffs: from_fn(|i| self.coeffs[i].reduce()), + } + } + + /// Fused negacyclic shift + accumulate: `dst += self * X^k`. + #[inline] + pub fn shift_accumulate_into(&self, dst: &mut Self, k: usize) { + let k = k % D; + if k == 0 { + for i in 0..D { + dst.coeffs[i] += self.coeffs[i]; + } + return; + } + for i in 0..D { + let target = i + k; + if target < D { + dst.coeffs[target] += self.coeffs[i]; + } else { + dst.coeffs[target - D] -= self.coeffs[i]; + } + } + } + + /// Fused negacyclic shift + subtract: `dst -= self * X^k`. + #[inline] + pub fn shift_sub_into(&self, dst: &mut Self, k: usize) { + let k = k % D; + if k == 0 { + for i in 0..D { + dst.coeffs[i] -= self.coeffs[i]; + } + return; + } + for i in 0..D { + let target = i + k; + if target < D { + dst.coeffs[target] -= self.coeffs[i]; + } else { + dst.coeffs[target - D] += self.coeffs[i]; + } + } + } + + /// Fused multiply-by-monomial-sum + accumulate: + /// `dst += self * (X^{k_1} + X^{k_2} + ...)`. 
+ pub fn mul_by_monomial_sum_into(&self, dst: &mut Self, nonzero_positions: &[usize]) { + for &k in nonzero_positions { + self.shift_accumulate_into(dst, k); + } + } +} + +impl Add for WideCyclotomicRing { + type Output = Self; + fn add(mut self, rhs: Self) -> Self { + for i in 0..D { + self.coeffs[i] += rhs.coeffs[i]; + } + self + } +} + +impl AddAssign for WideCyclotomicRing { + fn add_assign(&mut self, rhs: Self) { + for i in 0..D { + self.coeffs[i] += rhs.coeffs[i]; + } + } +} + +impl Sub for WideCyclotomicRing { + type Output = Self; + fn sub(mut self, rhs: Self) -> Self { + for i in 0..D { + self.coeffs[i] -= rhs.coeffs[i]; + } + self + } +} + +impl SubAssign for WideCyclotomicRing { + fn sub_assign(&mut self, rhs: Self) { + for i in 0..D { + self.coeffs[i] -= rhs.coeffs[i]; + } + } +} + +impl Neg for WideCyclotomicRing { + type Output = Self; + fn neg(self) -> Self { + Self { + coeffs: from_fn(|i| -self.coeffs[i]), + } + } +} + +impl Default for WideCyclotomicRing { + fn default() -> Self { + Self::zero() + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::algebra::fields::{Fp128x8i32, Fp64, Fp64x4i32, Prime128M8M4M1M0}; + use rand::rngs::StdRng; + use rand::SeedableRng; + + type F64 = Fp64<4294967197>; + type F128 = Prime128M8M4M1M0; + const D: usize = 64; + + #[test] + fn wide_shift_accumulate_matches_narrow_fp64() { + let mut rng = StdRng::seed_from_u64(0x1234); + let src = CyclotomicRing::::random(&mut rng); + let initial = CyclotomicRing::::random(&mut rng); + + for k in [0, 1, 7, 31, 63] { + let mut narrow = initial; + src.shift_accumulate_into(&mut narrow, k); + + let wide_src = WideCyclotomicRing::::from_ring(&src); + let mut wide_dst = WideCyclotomicRing::::from_ring(&initial); + wide_src.shift_accumulate_into(&mut wide_dst, k); + let wide_reduced: CyclotomicRing = wide_dst.reduce(); + + assert_eq!(narrow, wide_reduced, "shift_accumulate k={k}"); + } + } + + #[test] + fn wide_shift_sub_matches_narrow_fp64() { + let mut rng = 
StdRng::seed_from_u64(0x5678); + let src = CyclotomicRing::::random(&mut rng); + let initial = CyclotomicRing::::random(&mut rng); + + for k in [0, 1, 15, 32, 63] { + let mut narrow = initial; + src.shift_sub_into(&mut narrow, k); + + let wide_src = WideCyclotomicRing::::from_ring(&src); + let mut wide_dst = WideCyclotomicRing::::from_ring(&initial); + wide_src.shift_sub_into(&mut wide_dst, k); + let wide_reduced: CyclotomicRing = wide_dst.reduce(); + + assert_eq!(narrow, wide_reduced, "shift_sub k={k}"); + } + } + + #[test] + fn wide_mul_by_monomial_sum_matches_narrow_fp64() { + let mut rng = StdRng::seed_from_u64(0xabcd); + let src = CyclotomicRing::::random(&mut rng); + let positions = vec![0, 5, 17, 42, 63]; + + let mut narrow = CyclotomicRing::::zero(); + src.mul_by_monomial_sum_into(&mut narrow, &positions); + + let wide_src = WideCyclotomicRing::::from_ring(&src); + let mut wide_dst = WideCyclotomicRing::::zero(); + wide_src.mul_by_monomial_sum_into(&mut wide_dst, &positions); + let wide_reduced: CyclotomicRing = wide_dst.reduce(); + + assert_eq!(narrow, wide_reduced); + } + + #[test] + fn wide_many_accumulations_fp128() { + let mut rng = StdRng::seed_from_u64(0xbeef); + let src = CyclotomicRing::::random(&mut rng); + + let mut narrow = CyclotomicRing::::zero(); + let wide_src = WideCyclotomicRing::::from_ring(&src); + let mut wide_dst = WideCyclotomicRing::::zero(); + + for k in 0..50 { + src.shift_accumulate_into(&mut narrow, k % D); + wide_src.shift_accumulate_into(&mut wide_dst, k % D); + } + for k in 0..30 { + src.shift_sub_into(&mut narrow, k % D); + wide_src.shift_sub_into(&mut wide_dst, k % D); + } + + let wide_reduced: CyclotomicRing = wide_dst.reduce(); + assert_eq!(narrow, wide_reduced); + } +} diff --git a/src/algebra/ring/mod.rs b/src/algebra/ring/mod.rs new file mode 100644 index 00000000..59ba5c64 --- /dev/null +++ b/src/algebra/ring/mod.rs @@ -0,0 +1,17 @@ +//! Cyclotomic ring types and NTT representations. 
+ +pub mod crt_ntt_repr; +pub mod cyclotomic; +pub mod partial_split_ntt; +pub mod sparse_challenge; + +pub use crt_ntt_repr::{ + CenteredMontLut, CrtNttConvertibleField, CrtNttParamSet, CyclotomicCrtNtt, DigitMontLut, +}; +pub use cyclotomic::{CyclotomicRing, WideCyclotomicRing}; +pub use partial_split_ntt::{ + PackedPartialSplitEval32, PackedPartialSplitNtt32, PartialSplitEval32, PartialSplitNtt32, +}; +pub use sparse_challenge::{ + sample_quaternary, sample_ternary, SparseChallenge, SparseChallengeConfig, +}; diff --git a/src/algebra/ring/partial_split_ntt.rs b/src/algebra/ring/partial_split_ntt.rs new file mode 100644 index 00000000..56245621 --- /dev/null +++ b/src/algebra/ring/partial_split_ntt.rs @@ -0,0 +1,800 @@ +//! Prototype `k=32` partial-split NTT for `F_p[X]/(X^64 + 1)`. +//! +//! For primes `p` with `64 | (p - 1)` but `128 ∤ (p - 1)`, the usual size-64 +//! negacyclic NTT does not exist over `F_p`. We can still exploit the +//! `k=32` split: +//! +//! `X^64 + 1 = \prod_{j=0}^{31} (X^2 - r_j)`, +//! +//! where `r_j` ranges over the roots of `Y^32 + 1` with `Y = X^2`. +//! +//! Writing a ring element as `a(X) = a_0(Y) + X a_1(Y)`, multiplication in the +//! quotient ring reduces to: +//! +//! 1. two size-32 negacyclic transforms over `F_p` (for `a_0` and `a_1`), +//! 2. pointwise multiplication in the quadratic factors `F_p[X]/(X^2 - r_j)`, +//! 3. two inverse size-32 negacyclic transforms. +//! +//! This module is intentionally a prototype helper for benchmarking the +//! split-native path against the existing multi-CRT NTT implementation. + +use super::CyclotomicRing; +use crate::algebra::PackedField; +use crate::{CanonicalField, FieldCore}; +use core::ops::{Add, Mul, Sub}; +use std::array::from_fn; + +const CLASS_D: usize = 32; +const RING_D: usize = 64; +const CLASS_LOG_D: usize = CLASS_D.trailing_zeros() as usize; +const CENTERED_I8_LUT_OFFSET: i16 = 128; + +/// Cached `k=32` split-domain representation of a `D=64` ring element. 
+#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct PartialSplitEval32 { + even: [F; CLASS_D], + odd: [F; CLASS_D], +} + +impl PartialSplitEval32 { + /// The additive identity in split-evaluation form. + #[inline(always)] + pub fn zero() -> Self { + Self { + even: [F::zero(); CLASS_D], + odd: [F::zero(); CLASS_D], + } + } + + /// Convert a `D=64` ring element into the cached split-domain form. + #[inline(always)] + pub fn from_ring(split: &PartialSplitNtt32, ring: &CyclotomicRing) -> Self { + let (mut even, mut odd) = split_even_odd(ring.coefficients()); + split.forward_class_pair(&mut even, &mut odd); + Self { even, odd } + } + + /// Convert centered `i8` coefficients into the cached split-domain form. + #[inline(always)] + pub fn from_i8(split: &PartialSplitNtt32, coeffs: &[i8; RING_D]) -> Self { + let (even_i8, odd_i8) = split_even_odd_i8(coeffs); + let (even, odd) = split.forward_class_i8_pair(&even_i8, &odd_i8); + Self { even, odd } + } + + /// Convert the cached split-domain form back to coefficient form. + #[inline(always)] + pub fn to_ring(&self, split: &PartialSplitNtt32) -> CyclotomicRing { + let mut even = self.even; + let mut odd = self.odd; + split.inverse_class_pair(&mut even, &mut odd); + merge_even_odd(even, odd) + } + + /// Pointwise multiplication in split-evaluation form. + #[inline(always)] + pub fn pointwise_mul(&self, rhs: &Self, split: &PartialSplitNtt32) -> Self { + let mut even = [F::zero(); CLASS_D]; + let mut odd = [F::zero(); CLASS_D]; + mul_quadratic_pairs( + &mut even, + &mut odd, + &self.even, + &self.odd, + &rhs.even, + &rhs.odd, + &split.eval_roots, + ); + Self { even, odd } + } + + /// Accumulate a pointwise product directly into `self`. 
+ #[inline(always)] + pub fn add_mul_assign(&mut self, lhs: &Self, rhs: &Self, split: &PartialSplitNtt32) { + add_mul_quadratic_pairs( + &mut self.even, + &mut self.odd, + &lhs.even, + &lhs.odd, + &rhs.even, + &rhs.odd, + &split.eval_roots, + ); + } +} + +/// Packed cached `k=32` split-domain representation of one or more `D=64` +/// ring elements, grouped across SIMD lanes. +#[derive(Clone, Copy)] +pub struct PackedPartialSplitEval32 { + even: [PF; CLASS_D], + odd: [PF; CLASS_D], +} + +impl PackedPartialSplitEval32 { + /// Number of scalar lanes grouped into this packed value. + pub const WIDTH: usize = PF::WIDTH; + + /// The additive identity in packed split-evaluation form. + #[inline(always)] + pub fn zero() -> Self { + let zero = PF::broadcast(PF::Scalar::zero()); + Self { + even: [zero; CLASS_D], + odd: [zero; CLASS_D], + } + } + + /// Pack one scalar split-domain value per lane. + #[inline(always)] + pub fn from_fn(mut f: FN) -> Self + where + PF::Scalar: CanonicalField, + FN: FnMut(usize) -> PartialSplitEval32, + { + let lanes: Vec> = (0..PF::WIDTH).map(&mut f).collect(); + Self::from_chunk(&lanes) + } + + /// Broadcast one scalar split-domain value to every SIMD lane. + #[inline(always)] + pub fn broadcast(value: &PartialSplitEval32) -> Self + where + PF::Scalar: CanonicalField, + { + Self { + even: from_fn(|i| PF::broadcast(value.even[i])), + odd: from_fn(|i| PF::broadcast(value.odd[i])), + } + } + + #[inline(always)] + fn from_chunk(chunk: &[PartialSplitEval32]) -> Self + where + PF::Scalar: CanonicalField, + { + assert_eq!( + chunk.len(), + PF::WIDTH, + "chunk length {} must equal packed width {}", + chunk.len(), + PF::WIDTH + ); + Self { + even: from_fn(|i| PF::from_fn(|lane| chunk[lane].even[i])), + odd: from_fn(|i| PF::from_fn(|lane| chunk[lane].odd[i])), + } + } +} + +/// Pre-broadcast split-NTT tables for packed SIMD execution. 
+#[derive(Clone, Copy)] +pub struct PackedPartialSplitNtt32 { + eval_roots: [PF; CLASS_D], + inv_stage_roots: [PF; CLASS_D], + d_inv_psi_inv: [PF; CLASS_D], +} + +impl PackedPartialSplitNtt32 { + /// Multiply two packed split-domain values lane-wise. + #[inline(always)] + pub fn pointwise_mul( + &self, + lhs: &PackedPartialSplitEval32, + rhs: &PackedPartialSplitEval32, + ) -> PackedPartialSplitEval32 + where + PF::Scalar: CanonicalField, + { + let zero = PF::broadcast(PF::Scalar::zero()); + let mut even = [zero; CLASS_D]; + let mut odd = [zero; CLASS_D]; + mul_quadratic_pairs( + &mut even, + &mut odd, + &lhs.even, + &lhs.odd, + &rhs.even, + &rhs.odd, + &self.eval_roots, + ); + PackedPartialSplitEval32 { even, odd } + } + + /// Accumulate a packed pointwise product directly into `acc`. + #[inline(always)] + pub fn add_mul_assign( + &self, + acc: &mut PackedPartialSplitEval32, + lhs: &PackedPartialSplitEval32, + rhs: &PackedPartialSplitEval32, + ) where + PF::Scalar: CanonicalField, + { + add_mul_quadratic_pairs( + &mut acc.even, + &mut acc.odd, + &lhs.even, + &lhs.odd, + &rhs.even, + &rhs.odd, + &self.eval_roots, + ); + } + + /// Append the coefficient-form outputs for every packed lane. + #[inline(always)] + pub fn append_rings( + &self, + eval: &PackedPartialSplitEval32, + out: &mut Vec>, + ) where + PF::Scalar: CanonicalField, + { + let mut even = eval.even; + let mut odd = eval.odd; + inverse_cyclic_dit_pair_prebroadcast(&mut even, &mut odd, &self.inv_stage_roots); + scale_pair_in_place(&mut even, &mut odd, &self.d_inv_psi_inv); + out.reserve(PF::WIDTH); + for lane in 0..PF::WIDTH { + let even_lane = from_fn(|i| even[i].extract(lane)); + let odd_lane = from_fn(|i| odd[i].extract(lane)); + out.push(merge_even_odd(even_lane, odd_lane)); + } + } +} + +/// Precomputed `k=32` split-NTT data for `F_p[X]/(X^64 + 1)`. 
+#[derive(Debug, Clone, PartialEq, Eq)] +pub struct PartialSplitNtt32 { + fwd_stage_roots: [F; CLASS_D], + inv_stage_roots: [F; CLASS_D], + psi_pows: [F; CLASS_D], + d_inv_psi_inv: [F; CLASS_D], + eval_roots: [F; CLASS_D], + cyclic_fwd_stage_roots: [F; RING_D], + cyclic_inv_stage_roots: [F; RING_D], + cyclic_d_inv: F, + centered_i8_lut: [F; 256], +} + +impl PartialSplitNtt32 { + /// Build the size-32 negacyclic transform tables from the field modulus. + /// + /// # Panics + /// + /// Panics if `64 ∤ (p - 1)`, or if no primitive `64`-th root is found by + /// the small-generator search. + pub fn compute() -> Self { + let q = modulus::(); + assert!( + (q - 1) % (2 * CLASS_D as u128) == 0, + "64 must divide q - 1 for the k=32 split transform" + ); + + let psi = find_primitive_root_2d::(q, CLASS_D); + let omega = psi.square(); + let omega_inv = omega + .inv() + .expect("primitive root must be invertible in the base field"); + let psi_inv = psi + .inv() + .expect("primitive root must be invertible in the base field"); + let d_inv = F::from_u64(CLASS_D as u64) + .inv() + .expect("transform size must be invertible in the base field"); + let cyclic_d_inv = F::from_u64(RING_D as u64) + .inv() + .expect("transform size must be invertible in the base field"); + + let psi_pows = powers(psi); + let psi_inv_pows = powers(psi_inv); + let d_inv_psi_inv = from_fn(|i| d_inv * psi_inv_pows[i]); + let eval_roots = { + let mut natural = [F::zero(); CLASS_D]; + let mut cur = psi; + for root in &mut natural { + *root = cur; + cur = cur * omega; + } + from_fn(|i| natural[bit_reverse_index(i, CLASS_LOG_D)]) + }; + + let mut fwd_stage_roots = [F::zero(); CLASS_D]; + let mut inv_stage_roots = [F::zero(); CLASS_D]; + let mut len = 1usize; + while len < CLASS_D { + let exp = (CLASS_D / (2 * len)) as u128; + fwd_stage_roots[len] = pow_field(omega, exp); + inv_stage_roots[len] = pow_field(omega_inv, exp); + len *= 2; + } + + let mut cyclic_fwd_stage_roots = [F::zero(); RING_D]; + let mut 
cyclic_inv_stage_roots = [F::zero(); RING_D]; + let mut len = 1usize; + while len < RING_D { + let exp = (RING_D / (2 * len)) as u128; + cyclic_fwd_stage_roots[len] = pow_field(psi, exp); + cyclic_inv_stage_roots[len] = pow_field(psi_inv, exp); + len *= 2; + } + + let centered_i8_lut = from_fn(|idx| F::from_i64(idx as i64 - 128)); + + Self { + fwd_stage_roots, + inv_stage_roots, + psi_pows, + d_inv_psi_inv, + eval_roots, + cyclic_fwd_stage_roots, + cyclic_inv_stage_roots, + cyclic_d_inv, + centered_i8_lut, + } + } + + /// Roots `r_j` of `Y^32 + 1` in the same slot order used by the transform. + #[inline(always)] + pub fn eval_roots(&self) -> &[F; CLASS_D] { + &self.eval_roots + } + + /// Broadcast the split-domain tables into a packed SIMD representation. + #[inline(always)] + pub fn packed>(&self) -> PackedPartialSplitNtt32 { + PackedPartialSplitNtt32 { + eval_roots: from_fn(|i| PF::broadcast(self.eval_roots[i])), + inv_stage_roots: from_fn(|i| PF::broadcast(self.inv_stage_roots[i])), + d_inv_psi_inv: from_fn(|i| PF::broadcast(self.d_inv_psi_inv[i])), + } + } + + /// Forward size-32 negacyclic transform on the class ring `F_p[Y]/(Y^32+1)`. + #[inline(always)] + pub fn forward_class(&self, coeffs: &mut [F; CLASS_D]) { + for (coeff, psi) in coeffs.iter_mut().zip(self.psi_pows.iter()) { + *coeff = *coeff * *psi; + } + forward_cyclic_dif(coeffs, &self.fwd_stage_roots); + } + + /// Forward size-32 negacyclic transform on two class polynomials at once. + #[inline(always)] + pub fn forward_class_pair(&self, lhs: &mut [F; CLASS_D], rhs: &mut [F; CLASS_D]) { + scale_pair_in_place(lhs, rhs, &self.psi_pows); + forward_cyclic_dif_pair(lhs, rhs, &self.fwd_stage_roots); + } + + /// Forward size-32 negacyclic transform for two centered-`i8` classes. 
+ #[inline(always)] + pub fn forward_class_i8_pair( + &self, + lhs: &[i8; CLASS_D], + rhs: &[i8; CLASS_D], + ) -> ([F; CLASS_D], [F; CLASS_D]) { + let mut out_lhs = from_fn(|i| self.i8_to_field(lhs[i])); + let mut out_rhs = from_fn(|i| self.i8_to_field(rhs[i])); + self.forward_class_pair(&mut out_lhs, &mut out_rhs); + (out_lhs, out_rhs) + } + + /// Inverse size-32 negacyclic transform on two slot vectors at once. + #[inline(always)] + pub fn inverse_class_pair(&self, lhs: &mut [F; CLASS_D], rhs: &mut [F; CLASS_D]) { + inverse_cyclic_dit_pair(lhs, rhs, &self.inv_stage_roots); + scale_pair_in_place(lhs, rhs, &self.d_inv_psi_inv); + } + + /// Forward size-64 cyclic transform over `F_p[X]/(X^64 - 1)`. + #[inline(always)] + pub fn forward_cyclic_ring(&self, coeffs: &mut [F; RING_D]) { + forward_cyclic_dif64(coeffs, &self.cyclic_fwd_stage_roots); + } + + /// Inverse size-64 cyclic transform over `F_p[X]/(X^64 - 1)`. + #[inline(always)] + pub fn inverse_cyclic_ring(&self, evals: &mut [F; RING_D]) { + inverse_cyclic_dit64(evals, &self.cyclic_inv_stage_roots); + for value in evals.iter_mut() { + *value = *value * self.cyclic_d_inv; + } + } + + /// Multiply two `D=64` ring elements using the `k=32` partial split. 
+ #[inline(always)] + pub fn multiply_d64( + &self, + lhs: &CyclotomicRing, + rhs: &CyclotomicRing, + ) -> CyclotomicRing { + let (mut lhs_even, mut lhs_odd) = split_even_odd(lhs.coefficients()); + let (mut rhs_even, mut rhs_odd) = split_even_odd(rhs.coefficients()); + + self.forward_class_pair(&mut lhs_even, &mut lhs_odd); + self.forward_class_pair(&mut rhs_even, &mut rhs_odd); + + let mut out_even = [F::zero(); CLASS_D]; + let mut out_odd = [F::zero(); CLASS_D]; + mul_quadratic_pairs( + &mut out_even, + &mut out_odd, + &lhs_even, + &lhs_odd, + &rhs_even, + &rhs_odd, + &self.eval_roots, + ); + + self.inverse_class_pair(&mut out_even, &mut out_odd); + merge_even_odd(out_even, out_odd) + } + + /// Multiply a full field-valued `D=64` ring element by centered `i8` + /// coefficients using the `k=32` partial split. + #[inline(always)] + pub fn multiply_d64_rhs_i8( + &self, + lhs: &CyclotomicRing, + rhs: &[i8; RING_D], + ) -> CyclotomicRing { + let (mut lhs_even, mut lhs_odd) = split_even_odd(lhs.coefficients()); + let (rhs_even, rhs_odd) = split_even_odd_i8(rhs); + + self.forward_class_pair(&mut lhs_even, &mut lhs_odd); + let (rhs_even_eval, rhs_odd_eval) = self.forward_class_i8_pair(&rhs_even, &rhs_odd); + + let mut out_even = [F::zero(); CLASS_D]; + let mut out_odd = [F::zero(); CLASS_D]; + mul_quadratic_pairs( + &mut out_even, + &mut out_odd, + &lhs_even, + &lhs_odd, + &rhs_even_eval, + &rhs_odd_eval, + &self.eval_roots, + ); + + self.inverse_class_pair(&mut out_even, &mut out_odd); + merge_even_odd(out_even, out_odd) + } + + /// Multiply two `D=64` coefficient arrays modulo `X^64 - 1`. 
+ #[inline(always)] + pub fn multiply_cyclic_d64( + &self, + lhs: &CyclotomicRing, + rhs: &CyclotomicRing, + ) -> [F; RING_D] { + let mut lhs_eval = *lhs.coefficients(); + let mut rhs_eval = *rhs.coefficients(); + self.forward_cyclic_ring(&mut lhs_eval); + self.forward_cyclic_ring(&mut rhs_eval); + let mut out = [F::zero(); RING_D]; + for i in 0..RING_D { + out[i] = lhs_eval[i] * rhs_eval[i]; + } + self.inverse_cyclic_ring(&mut out); + out + } + + /// Compute the high-half quotient `H` such that + /// `lhs * rhs = H * X^64 + reduced`, with `deg(H) < 64`. + #[inline(always)] + pub fn unreduced_quotient_d64( + &self, + lhs: &CyclotomicRing, + rhs: &CyclotomicRing, + ) -> CyclotomicRing { + let neg = self.multiply_d64(lhs, rhs); + let cyc = self.multiply_cyclic_d64(lhs, rhs); + let neg_coeffs = neg.coefficients(); + let quotient = from_fn(|i| (cyc[i] - neg_coeffs[i]) * F::TWO_INV); + CyclotomicRing::from_coefficients(quotient) + } + + #[inline(always)] + fn i8_to_field(&self, coeff: i8) -> F { + self.centered_i8_lut[(coeff as i16 + CENTERED_I8_LUT_OFFSET) as usize] + } +} + +fn modulus() -> u128 { + (-F::one()).to_canonical_u128() + 1 +} + +fn powers(base: F) -> [F; CLASS_D] { + let mut out = [F::zero(); CLASS_D]; + let mut cur = F::one(); + for value in &mut out { + *value = cur; + cur = cur * base; + } + out +} + +fn pow_field(mut base: F, mut exp: u128) -> F { + let mut acc = F::one(); + while exp > 0 { + if exp & 1 == 1 { + acc = acc * base; + } + exp >>= 1; + if exp > 0 { + base = base.square(); + } + } + acc +} + +fn find_primitive_root_2d(q: u128, d: usize) -> F { + let half = (q - 1) / 2; + let exp = (q - 1) / (2 * d as u128); + let minus_one = -F::one(); + + for candidate in 2u64..(1 << 20) { + let probe = F::from_u64(candidate); + if pow_field(probe, half) == minus_one { + let psi = pow_field(probe, exp); + debug_assert!(pow_field(psi, d as u128) == minus_one, "psi^D != -1"); + debug_assert!(pow_field(psi, 2 * d as u128) == F::one(), "psi^(2D) != 1"); + 
return psi; + } + } + + panic!("no primitive {}-th root found", 2 * d); +} + +#[inline(always)] +fn split_even_odd(coeffs: &[F; RING_D]) -> ([F; CLASS_D], [F; CLASS_D]) { + let even = from_fn(|i| coeffs[2 * i]); + let odd = from_fn(|i| coeffs[2 * i + 1]); + (even, odd) +} + +#[inline(always)] +fn split_even_odd_i8(coeffs: &[i8; RING_D]) -> ([i8; CLASS_D], [i8; CLASS_D]) { + let even = from_fn(|i| coeffs[2 * i]); + let odd = from_fn(|i| coeffs[2 * i + 1]); + (even, odd) +} + +#[inline(always)] +fn mul_quadratic_karatsuba(a0: T, a1: T, b0: T, b1: T, r: T) -> (T, T) +where + T: Copy + Add + Sub + Mul, +{ + let p0 = a0 * b0; + let p1 = a1 * b1; + let p2 = (a0 + a1) * (b0 + b1); + let c0 = p0 + r * p1; + let c1 = p2 - p0 - p1; + (c0, c1) +} + +#[inline(always)] +fn mul_quadratic_pairs( + out_even: &mut [T; CLASS_D], + out_odd: &mut [T; CLASS_D], + lhs_even: &[T; CLASS_D], + lhs_odd: &[T; CLASS_D], + rhs_even: &[T; CLASS_D], + rhs_odd: &[T; CLASS_D], + roots: &[T; CLASS_D], +) where + T: Copy + Add + Sub + Mul, +{ + for i in 0..CLASS_D { + (out_even[i], out_odd[i]) = + mul_quadratic_karatsuba(lhs_even[i], lhs_odd[i], rhs_even[i], rhs_odd[i], roots[i]); + } +} + +#[inline(always)] +fn add_mul_quadratic_pairs( + acc_even: &mut [T; CLASS_D], + acc_odd: &mut [T; CLASS_D], + lhs_even: &[T; CLASS_D], + lhs_odd: &[T; CLASS_D], + rhs_even: &[T; CLASS_D], + rhs_odd: &[T; CLASS_D], + roots: &[T; CLASS_D], +) where + T: Copy + Add + Sub + Mul, +{ + for i in 0..CLASS_D { + let (even, odd) = + mul_quadratic_karatsuba(lhs_even[i], lhs_odd[i], rhs_even[i], rhs_odd[i], roots[i]); + acc_even[i] = acc_even[i] + even; + acc_odd[i] = acc_odd[i] + odd; + } +} + +#[inline(always)] +fn scale_pair_in_place(lhs: &mut [T; CLASS_D], rhs: &mut [T; CLASS_D], scales: &[T; CLASS_D]) +where + T: Copy + Mul, +{ + for ((lhs_i, rhs_i), scale) in lhs.iter_mut().zip(rhs.iter_mut()).zip(scales.iter()) { + *lhs_i = *lhs_i * *scale; + *rhs_i = *rhs_i * *scale; + } +} + +#[inline(always)] +fn merge_even_odd( + 
even: [F; CLASS_D], + odd: [F; CLASS_D], +) -> CyclotomicRing { + let mut coeffs = [F::zero(); RING_D]; + for i in 0..CLASS_D { + coeffs[2 * i] = even[i]; + coeffs[2 * i + 1] = odd[i]; + } + CyclotomicRing::from_coefficients(coeffs) +} + +#[inline(always)] +fn forward_cyclic_dif(a: &mut [F; CLASS_D], stage_roots: &[F; CLASS_D]) { + let mut len = CLASS_D / 2; + while len > 0 { + let step = stage_roots[len]; + let mut start = 0usize; + while start < CLASS_D { + let mut w = F::one(); + for j in 0..len { + let u = a[start + j]; + let v = a[start + j + len]; + a[start + j] = u + v; + a[start + j + len] = (u - v) * w; + w = w * step; + } + start += 2 * len; + } + len /= 2; + } +} + +#[inline(always)] +fn forward_cyclic_dif_pair( + lhs: &mut [F; CLASS_D], + rhs: &mut [F; CLASS_D], + stage_roots: &[F; CLASS_D], +) { + let mut len = CLASS_D / 2; + while len > 0 { + let step = stage_roots[len]; + let mut start = 0usize; + while start < CLASS_D { + let mut w = F::one(); + for j in 0..len { + let lhs_u = lhs[start + j]; + let lhs_v = lhs[start + j + len]; + lhs[start + j] = lhs_u + lhs_v; + lhs[start + j + len] = (lhs_u - lhs_v) * w; + + let rhs_u = rhs[start + j]; + let rhs_v = rhs[start + j + len]; + rhs[start + j] = rhs_u + rhs_v; + rhs[start + j + len] = (rhs_u - rhs_v) * w; + + w = w * step; + } + start += 2 * len; + } + len /= 2; + } +} + +#[inline(always)] +fn inverse_cyclic_dit_pair( + lhs: &mut [F; CLASS_D], + rhs: &mut [F; CLASS_D], + stage_roots: &[F; CLASS_D], +) { + let mut len = 1usize; + while len < CLASS_D { + let step = stage_roots[len]; + let mut start = 0usize; + while start < CLASS_D { + let mut w = F::one(); + for j in 0..len { + let lhs_u = lhs[start + j]; + let lhs_v = lhs[start + j + len] * w; + lhs[start + j] = lhs_u + lhs_v; + lhs[start + j + len] = lhs_u - lhs_v; + + let rhs_u = rhs[start + j]; + let rhs_v = rhs[start + j + len] * w; + rhs[start + j] = rhs_u + rhs_v; + rhs[start + j + len] = rhs_u - rhs_v; + + w = w * step; + } + start += 2 * len; + 
} + len *= 2; + } +} + +#[inline(always)] +fn inverse_cyclic_dit_pair_prebroadcast( + lhs: &mut [PF; CLASS_D], + rhs: &mut [PF; CLASS_D], + stage_roots: &[PF; CLASS_D], +) where + PF::Scalar: CanonicalField, +{ + let mut len = 1usize; + while len < CLASS_D { + let step = stage_roots[len]; + let mut start = 0usize; + while start < CLASS_D { + let mut w = PF::broadcast(PF::Scalar::one()); + for j in 0..len { + let lhs_u = lhs[start + j]; + let lhs_v = lhs[start + j + len] * w; + lhs[start + j] = lhs_u + lhs_v; + lhs[start + j + len] = lhs_u - lhs_v; + + let rhs_u = rhs[start + j]; + let rhs_v = rhs[start + j + len] * w; + rhs[start + j] = rhs_u + rhs_v; + rhs[start + j + len] = rhs_u - rhs_v; + + w = w * step; + } + start += 2 * len; + } + len *= 2; + } +} + +#[inline(always)] +fn forward_cyclic_dif64(a: &mut [F; RING_D], stage_roots: &[F; RING_D]) { + let mut len = RING_D / 2; + while len > 0 { + let step = stage_roots[len]; + let mut start = 0usize; + while start < RING_D { + let mut w = F::one(); + for j in 0..len { + let u = a[start + j]; + let v = a[start + j + len]; + a[start + j] = u + v; + a[start + j + len] = (u - v) * w; + w = w * step; + } + start += 2 * len; + } + len /= 2; + } +} + +#[inline(always)] +fn inverse_cyclic_dit64(a: &mut [F; RING_D], stage_roots: &[F; RING_D]) { + let mut len = 1usize; + while len < RING_D { + let step = stage_roots[len]; + let mut start = 0usize; + while start < RING_D { + let mut w = F::one(); + for j in 0..len { + let u = a[start + j]; + let v = a[start + j + len] * w; + a[start + j] = u + v; + a[start + j + len] = u - v; + w = w * step; + } + start += 2 * len; + } + len *= 2; + } +} + +#[inline(always)] +fn bit_reverse_index(idx: usize, log_n: usize) -> usize { + idx.reverse_bits() >> (usize::BITS as usize - log_n) +} diff --git a/src/algebra/ring/sparse_challenge.rs b/src/algebra/ring/sparse_challenge.rs new file mode 100644 index 00000000..5cf14b9a --- /dev/null +++ b/src/algebra/ring/sparse_challenge.rs @@ -0,0 +1,212 
@@ +//! Sparse ring challenges for cyclotomic protocols. +//! +//! Many lattice protocols sample "short/sparse" ring challenges whose coefficients +//! are mostly zero and whose non-zero coefficients come from a tiny integer alphabet +//! (e.g. `{±1}` or `{±1,±2}`), with a fixed Hamming weight `ω`. +//! +//! This module provides a minimal representation that is: +//! - independent of any specific protocol (Hachi/Greyhound/SuperNeo, etc.), +//! - easy to sample deterministically from Fiat–Shamir at the protocol layer, +//! - and efficient to evaluate at a point `α` using precomputed powers. + +use super::CyclotomicRing; +use crate::algebra::fields::LiftBase; +use crate::{CanonicalField, FieldCore}; +use rand_core::RngCore; + +/// Configuration for sampling a sparse challenge. +/// +/// This intentionally avoids redundant knobs: the distribution is determined by: +/// - exact `weight` (Hamming weight), +/// - and a list of allowed **non-zero** integer coefficients. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct SparseChallengeConfig { + /// Exact Hamming weight ω. + pub weight: usize, + /// Allowed non-zero coefficients (small signed integers). + /// + /// Examples: + /// - `{±1}`: `vec![-1, 1]` + /// - `{±1,±2}`: `vec![-2, -1, 1, 2]` + pub nonzero_coeffs: Vec, +} + +impl SparseChallengeConfig { + /// Validate basic invariants for a given ring degree `D`. + /// + /// # Errors + /// + /// Returns an error if `weight > D`, if `nonzero_coeffs` is empty, or if it + /// contains `0`. + pub fn validate(&self) -> Result<(), &'static str> { + if self.weight > D { + return Err("weight must be <= ring degree D"); + } + if self.nonzero_coeffs.is_empty() { + return Err("nonzero_coeffs must be non-empty"); + } + if self.nonzero_coeffs.contains(&0) { + return Err("nonzero_coeffs must not contain 0"); + } + Ok(()) + } +} + +/// Sparse polynomial in `F[X]/(X^D+1)` represented by its non-zero terms. 
+/// +/// Invariants: +/// - `positions.len() == coeffs.len()` +/// - all positions are `< D` +/// - positions are unique +/// - all coeffs are non-zero +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct SparseChallenge { + /// Coefficient indices (powers of `X`) where the polynomial is non-zero. + pub positions: Vec, + /// Small integer coefficients at the corresponding positions. + pub coeffs: Vec, +} + +impl SparseChallenge { + /// Construct an empty (all-zero) challenge. + #[inline] + pub fn zero() -> Self { + Self { + positions: Vec::new(), + coeffs: Vec::new(), + } + } + + /// Number of non-zero coefficients (Hamming weight). + #[inline] + pub fn hamming_weight(&self) -> usize { + debug_assert_eq!(self.positions.len(), self.coeffs.len()); + self.positions.len() + } + + /// ℓ₁ norm over integers: `Σ |coeff_i|`. + #[inline] + pub fn l1_norm(&self) -> u64 { + self.coeffs + .iter() + .map(|&c| (c as i32).unsigned_abs() as u64) + .sum() + } + + /// Validate structural invariants for a ring degree `D`. + /// + /// # Errors + /// + /// Returns an error if lengths mismatch, if any coefficient is zero, if any + /// position is out of range, or if positions contain duplicates. + pub fn validate(&self) -> Result<(), &'static str> { + if self.positions.len() != self.coeffs.len() { + return Err("positions and coeffs must have same length"); + } + // Check coeffs are non-zero and positions are in range + unique. + let mut seen = vec![false; D]; + for (&pos, &c) in self.positions.iter().zip(self.coeffs.iter()) { + if c == 0 { + return Err("coeffs must not contain 0"); + } + let p = pos as usize; + if p >= D { + return Err("position out of range"); + } + if seen[p] { + return Err("positions must be unique"); + } + seen[p] = true; + } + Ok(()) + } + + /// Convert to a dense ring element by placing coefficients in the canonical + /// coefficient basis. + /// + /// # Errors + /// + /// Returns an error if the sparse representation violates structural invariants. 
+ pub fn to_dense( + &self, + ) -> Result, &'static str> { + self.validate::()?; + let mut out = [F::zero(); D]; + for (&pos, &c) in self.positions.iter().zip(self.coeffs.iter()) { + out[pos as usize] += F::from_i64(c as i64); + } + Ok(CyclotomicRing::from_coefficients(out)) + } + + /// Evaluate this sparse polynomial at `α` in `E`, given precomputed powers + /// `[α^0, α^1, ..., α^{D-1}]`. + /// + /// This is `O(weight)` and is intended to be used for verifier-side oracles + /// where `D` may be large but `weight` is small. + /// + /// # Errors + /// + /// Returns an error if structural invariants fail or if `alpha_pows.len() != D`. + pub fn eval_at_alpha(&self, alpha_pows: &[E]) -> Result + where + F: FieldCore + CanonicalField, + E: FieldCore + LiftBase, + { + self.validate::()?; + if alpha_pows.len() != D { + return Err("alpha_pows length mismatch"); + } + let mut acc = E::zero(); + for (&pos, &c) in self.positions.iter().zip(self.coeffs.iter()) { + let coeff_f = F::from_i64(c as i64); + acc += E::lift_base(coeff_f) * alpha_pows[pos as usize]; + } + Ok(acc) + } +} + +/// Sample a dense ternary ring element with coefficients in `{-1, 0, 1}`. +/// +/// Distribution matches Labrador C's ternary nibble LUT (`0xA815`), yielding +/// probabilities `5/16, 6/16, 5/16` for `-1, 0, 1` respectively. +pub fn sample_ternary( + rng: &mut R, +) -> CyclotomicRing { + const LUT: u16 = 0xA815; + let mut coeffs = [F::zero(); D]; + let mut i = 0usize; + while i < D { + let byte = (rng.next_u32() & 0xFF) as u8; + let lo = (((LUT >> (byte & 0x0F)) & 0x3) as i16) - 1; + coeffs[i] = F::from_i64(lo as i64); + i += 1; + if i < D { + let hi = (((LUT >> (byte >> 4)) & 0x3) as i16) - 1; + coeffs[i] = F::from_i64(hi as i64); + i += 1; + } + } + CyclotomicRing::from_coefficients(coeffs) +} + +/// Sample a dense quaternary ring element with coefficients in `{-2, -1, 0, 1}`. +/// +/// Coefficients are sampled uniformly from two-bit chunks and shifted by `-2`. 
+pub fn sample_quaternary( + rng: &mut R, +) -> CyclotomicRing { + let mut coeffs = [F::zero(); D]; + let mut i = 0usize; + while i < D { + let bits = rng.next_u32(); + for lane in 0..16 { + if i >= D { + break; + } + let val = (((bits >> (2 * lane)) & 0x3) as i16) - 2; + coeffs[i] = F::from_i64(val as i64); + i += 1; + } + } + CyclotomicRing::from_coefficients(coeffs) +} diff --git a/src/lib.rs b/src/lib.rs index e3859c54..ccbfbc1b 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,3 +1,7 @@ +#![cfg_attr( + all(target_arch = "x86_64", target_feature = "avx512f"), + feature(stdarch_x86_avx512) +)] //! # hachi //! //! A high performance and modular implementation of the Hachi polynomial commitment scheme. @@ -17,7 +21,7 @@ //! ### Core Modules //! - [`primitives`] - Core traits and abstractions //! - [`primitives::arithmetic`] - Field and module traits for lattice arithmetic -//! - [`primitives::poly`] - Multilinear polynomial traits and operations +//! - [`primitives::poly`] - Multilinear polynomial utility functions //! - [`primitives::transcript`] - Fiat-Shamir transcript trait //! - [`primitives::serialization`] - Serialization abstractions //! - [`error`] - Error types @@ -35,7 +39,27 @@ pub mod error; /// Primitive traits and operations pub mod primitives; +/// Concrete algebra backends (prime fields, extensions, rings) +pub mod algebra; + +/// Conditional parallelism utilities (`cfg_iter!`, `cfg_into_iter!`, etc.) +#[macro_use] +#[doc(hidden)] +pub mod parallel; + +/// Protocol-layer transcript and commitment abstractions +pub mod protocol; + +/// Shared test configuration and helpers. 
+#[doc(hidden)] +pub mod test_utils; + pub use error::HachiError; -pub use primitives::arithmetic::{Field, HachiRoutines, Module}; -pub use primitives::poly::{MultilinearLagrange, Polynomial}; +pub use primitives::arithmetic::{ + AdditiveGroup, CanonicalField, FieldCore, FieldSampling, FromSmallInt, Invertible, Module, + PseudoMersenneField, +}; pub use primitives::serialization::{HachiDeserialize, HachiSerialize}; +pub use protocol::{ + BasisMode, CommitmentScheme, DensePoly, HachiPolyOps, OneHotIndex, OneHotPoly, Transcript, +}; diff --git a/src/parallel.rs b/src/parallel.rs new file mode 100644 index 00000000..1be93577 --- /dev/null +++ b/src/parallel.rs @@ -0,0 +1,74 @@ +//! Conditional parallelism utilities. +//! +//! When the `parallel` feature is enabled, the `cfg_iter!` family of macros +//! expand to rayon's parallel iterators. Otherwise they fall back to standard +//! sequential iterators. + +#[cfg(feature = "parallel")] +pub use rayon::prelude::*; + +/// Returns `.par_iter()` when `parallel` is enabled, `.iter()` otherwise. +#[macro_export] +macro_rules! cfg_iter { + ($e:expr) => {{ + #[cfg(feature = "parallel")] + let it = $e.par_iter(); + #[cfg(not(feature = "parallel"))] + let it = $e.iter(); + it + }}; +} + +/// Returns `.par_iter_mut()` when `parallel` is enabled, `.iter_mut()` otherwise. +#[macro_export] +macro_rules! cfg_iter_mut { + ($e:expr) => {{ + #[cfg(feature = "parallel")] + let it = $e.par_iter_mut(); + #[cfg(not(feature = "parallel"))] + let it = $e.iter_mut(); + it + }}; +} + +/// Returns `.into_par_iter()` when `parallel` is enabled, `.into_iter()` otherwise. +#[macro_export] +macro_rules! cfg_into_iter { + ($e:expr) => {{ + #[cfg(feature = "parallel")] + let it = $e.into_par_iter(); + #[cfg(not(feature = "parallel"))] + let it = $e.into_iter(); + it + }}; +} + +/// Returns `.par_chunks(n)` when `parallel` is enabled, `.chunks(n)` otherwise. +#[macro_export] +macro_rules! 
cfg_chunks { + ($e:expr, $n:expr) => {{ + #[cfg(feature = "parallel")] + let it = $e.par_chunks($n); + #[cfg(not(feature = "parallel"))] + let it = $e.chunks($n); + it + }}; +} + +/// Parallel fold-reduce over a range. +/// +/// With `parallel`: `range.into_par_iter().fold(identity, fold_op).reduce(identity, reduce_op)`. +/// Without: `range.into_iter().fold(identity(), fold_op)`. +#[macro_export] +macro_rules! cfg_fold_reduce { + ($range:expr, $identity:expr, $fold_op:expr, $reduce_op:expr) => {{ + #[cfg(feature = "parallel")] + let result = $range + .into_par_iter() + .fold($identity, $fold_op) + .reduce($identity, $reduce_op); + #[cfg(not(feature = "parallel"))] + let result = $range.into_iter().fold(($identity)(), $fold_op); + result + }}; +} diff --git a/src/primitives/arithmetic.rs b/src/primitives/arithmetic.rs index ccaaa4ae..655e68ae 100644 --- a/src/primitives/arithmetic.rs +++ b/src/primitives/arithmetic.rs @@ -3,26 +3,42 @@ use super::{HachiDeserialize, HachiSerialize}; use rand_core::RngCore; -/// Field trait for lattice-based arithmetic -pub trait Field: +/// Minimal additive group: add, sub, neg, zero. +/// +/// Satisfied by both reduced field elements (`FieldCore`) and wide unreduced +/// accumulators (`Fp128x8i32`, etc.), enabling generic shift-accumulate +/// operations on `WideCyclotomicRing`. +pub trait AdditiveGroup: Sized + Clone + Copy - + PartialEq + Send + Sync - + HachiSerialize - + HachiDeserialize + std::ops::Add + std::ops::Sub - + std::ops::Mul + std::ops::Neg + + std::ops::AddAssign + + std::ops::SubAssign +{ + /// Additive identity. + const ZERO: Self; +} + +/// Core field operations required across algebra backends. 
+pub trait FieldCore: + AdditiveGroup + + PartialEq + + HachiSerialize + + HachiDeserialize + + std::ops::Mul + for<'a> std::ops::Add<&'a Self, Output = Self> + for<'a> std::ops::Sub<&'a Self, Output = Self> + for<'a> std::ops::Mul<&'a Self, Output = Self> { - /// Additive identity - fn zero() -> Self; + /// Additive identity. + fn zero() -> Self { + Self::ZERO + } /// Multiplicative identity fn one() -> Self; @@ -30,26 +46,149 @@ pub trait Field: /// Check if element is zero fn is_zero(&self) -> bool; - /// Field addition - fn add(&self, rhs: &Self) -> Self; - - /// Field subtraction - fn sub(&self, rhs: &Self) -> Self; - - /// Field multiplication - fn mul(&self, rhs: &Self) -> Self; - - /// Field inversion + /// Field squaring. + /// + /// Default is `self * self`; extension fields override with specialized + /// formulas that use fewer base-field multiplications. + fn square(&self) -> Self { + *self * *self + } + + /// Field inversion. + /// + /// This API may branch on zero-check and is intended for public/non-secret + /// values. For secret-bearing paths, use [`Invertible::inv_or_zero`]. fn inv(self) -> Option; - /// Generate random field element - fn random(rng: &mut R) -> Self; + /// Multiplicative inverse of 2: `(p + 1) / 2` for odd-characteristic fields. + const TWO_INV: Self; +} - /// Convert from u64 +/// Constant-time inversion helper for secret-bearing code paths. +/// +/// Implementations return `0` when the input is `0`, and `x^{-1}` otherwise, +/// without branching on the input value. +pub trait Invertible: FieldCore { + /// Constant-time inversion with zero-mapping behavior. + fn inv_or_zero(self) -> Self; +} + +/// Embed small integers into a field. +/// +/// Every field contains a copy of its prime subfield, and small integers embed +/// into it canonically via reduction modulo the characteristic. This trait is +/// implementable for ALL fields — base and extension alike. 
+/// +/// Only `from_u64` and `from_i64` need concrete implementations; the narrower +/// widths have default impls via lossless widening. +pub trait FromSmallInt: FieldCore { + /// Embed a `u8` into the field. + fn from_u8(val: u8) -> Self { + Self::from_u64(val as u64) + } + + /// Embed an `i8` into the field. + fn from_i8(val: i8) -> Self { + Self::from_i64(val as i64) + } + + /// Embed a `u16` into the field. + fn from_u16(val: u16) -> Self { + Self::from_u64(val as u64) + } + + /// Embed an `i16` into the field. + fn from_i16(val: i16) -> Self { + Self::from_i64(val as i64) + } + + /// Embed a `u32` into the field. + fn from_u32(val: u32) -> Self { + Self::from_u64(val as u64) + } + + /// Embed an `i32` into the field. + fn from_i32(val: i32) -> Self { + Self::from_i64(val as i64) + } + + /// Embed a `u64` into the field (reduce mod characteristic). fn from_u64(val: u64) -> Self; - /// Convert from i64 + /// Embed an `i64` into the field (reduce mod characteristic). fn from_i64(val: i64) -> Self; + + /// Embed an `i128` into the field. + /// + /// Default implementation splits into u64 limbs with field multiplication + /// by `2^64`. Override for base fields that have a direct path. + fn from_i128(val: i128) -> Self { + if val >= 0 { + let lo = val as u64; + let hi = (val >> 64) as u64; + if hi == 0 { + Self::from_u64(lo) + } else { + let two_64 = Self::from_u64(1u64 << 32) * Self::from_u64(1u64 << 32); + Self::from_u64(lo) + Self::from_u64(hi) * two_64 + } + } else { + -Self::from_i128(-val) + } + } + + /// Lookup table mapping balanced digit index → field element. + /// + /// For `log_basis` in `1..=4`, returns a 16-entry table where + /// `table[i]` = `from_i64(i - b/2)` for `i < b = 2^log_basis`, + /// and zero for `i >= b`. + /// + /// Index a digit `d ∈ [-b/2, b/2)` as `table[(d + b/2) as usize]`. 
+ fn digit_lut(log_basis: u32) -> [Self; 16] { + debug_assert!(log_basis > 0 && log_basis <= 4); + let b = 1usize << log_basis; + let half_b = (b >> 1) as i64; + std::array::from_fn(|i| { + if i < b { + Self::from_i64(i as i64 - half_b) + } else { + Self::zero() + } + }) + } +} + +/// Canonical integer representation for prime (base) field elements. +/// +/// Provides a bijection between field elements and integers in `[0, p)`. +/// Only meaningful for base prime fields where elements ARE residues mod p. +/// Extension fields should NOT implement this trait. +pub trait CanonicalField: FromSmallInt { + /// Return canonical integer representation as `u128`. + fn to_canonical_u128(self) -> u128; + + /// Construct from canonical value if it is in range. + fn from_canonical_u128_checked(val: u128) -> Option; + + /// Construct from canonical value reduced modulo the field modulus. + fn from_canonical_u128_reduced(val: u128) -> Self; +} + +/// Optional sampling support for field elements. +/// +/// This is intentionally separate from core field algebra and may evolve. +pub trait FieldSampling: FieldCore { + /// Generate a sampled field element. + fn sample(rng: &mut R) -> Self; +} + +/// Metadata for pseudo-Mersenne style moduli (`2^k - c`). +pub trait PseudoMersenneField: CanonicalField { + /// Exponent `k` in `2^k - c`. + const MODULUS_BITS: u32; + + /// Offset `c` in `2^k - c`. 
+ const MODULUS_OFFSET: u128; } /// Module trait for lattice-based algebraic structures @@ -73,24 +212,18 @@ pub trait Module: + for<'a> std::ops::Sub<&'a Self, Output = Self> { /// Scalar type (field/ring elements) - type Scalar: Field + type Scalar: FieldCore + + CanonicalField + + FieldSampling + std::ops::Mul + for<'a> std::ops::Mul<&'a Self, Output = Self>; /// Zero element fn zero() -> Self; - /// Addition - fn add(&self, rhs: &Self) -> Self; - - /// Negation - fn neg(&self) -> Self; - /// Scalar multiplication fn scale(&self, k: &Self::Scalar) -> Self; /// Generate random module element fn random(rng: &mut R) -> Self; } - -pub trait HachiRoutines {} diff --git a/src/primitives/poly.rs b/src/primitives/poly.rs index 1e9a8b63..6a4185b9 100644 --- a/src/primitives/poly.rs +++ b/src/primitives/poly.rs @@ -1,58 +1,6 @@ -//! Polynomial trait for multilinear polynomials +//! Multilinear polynomial utility functions. -use super::arithmetic::Field; - -/// Trait for multilinear Lagrange polynomial operations -pub trait MultilinearLagrange: Polynomial { - /// Compute multilinear Lagrange basis evaluations at a point - /// - /// For variables (r₀, r₁, ..., r_{n-1}), computes all 2^n basis polynomial evaluations. - /// The i-th basis polynomial evaluates to 1 at the i-th hypercube vertex and 0 elsewhere. - fn lagrange_basis(&self, output: &mut [F], point: &[F]) { - multilinear_lagrange_basis(output, point) - } - - /// Compute vector-matrix product: v = L^T * M - /// - /// Treats coefficients as a 2^nu × 2^sigma matrix. 
- /// For each column j: v\[j\] = Σ_i left_vec\[i\] * coefficients\[i\]\[j\] - fn vector_matrix_product(&self, left_vec: &[F], nu: usize, sigma: usize) -> Vec; - - /// Compute left and right vectors from evaluation point - /// - /// Given a point arranged for matrix evaluation, computes L and R such that: - /// polynomial_evaluation(point) = L^T × M × R - fn compute_evaluation_vectors(&self, point: &[F], nu: usize, sigma: usize) -> (Vec, Vec) { - compute_left_right_vectors(point, nu, sigma) - } -} - -/// Trait for multilinear polynomials -/// -/// Represents a polynomial in evaluation form (coefficients at hypercube points). -pub trait Polynomial { - /// Number of variables - fn num_vars(&self) -> usize; - - /// Total number of coefficients (2^num_vars) - fn len(&self) -> usize { - 1 << self.num_vars() - } - - /// Check if polynomial is empty - fn is_empty(&self) -> bool { - self.len() == 0 - } - - /// Evaluate polynomial at a point - /// - /// # Parameters - /// - `point`: Evaluation point (length must equal num_vars) - /// - /// # Returns - /// Polynomial evaluation result - fn evaluate(&self, point: &[F]) -> F; -} +use super::arithmetic::FieldCore; /// Compute multilinear Lagrange basis evaluations at a point /// @@ -62,7 +10,7 @@ pub trait Polynomial { /// Uses an iterative doubling approach: /// - Start with [1-r₀, r₀] /// - For each variable rᵢ, split each value v into [v*(1-rᵢ), v*rᵢ] -pub(crate) fn multilinear_lagrange_basis(output: &mut [F], point: &[F]) { +pub(crate) fn multilinear_lagrange_basis(output: &mut [F], point: &[F]) { assert!( output.len() <= (1 << point.len()), "Output length must be at most 2^point.len()" @@ -115,7 +63,7 @@ pub(crate) fn multilinear_lagrange_basis(output: &mut [F], point: &[F] /// polynomial_evaluation(point) = L^T × M × R /// /// Splits variables between rows and columns based on sigma and nu. 
-pub fn compute_left_right_vectors( +pub fn compute_left_right_vectors( point: &[F], nu: usize, sigma: usize, diff --git a/src/primitives/serialization.rs b/src/primitives/serialization.rs index f4a0bf1e..ea14bdc3 100644 --- a/src/primitives/serialization.rs +++ b/src/primitives/serialization.rs @@ -182,11 +182,43 @@ mod primitive_impls { impl_primitive_serialization!(u16, 2); impl_primitive_serialization!(u32, 4); impl_primitive_serialization!(u64, 8); - impl_primitive_serialization!(usize, std::mem::size_of::()); + impl_primitive_serialization!(u128, 16); impl_primitive_serialization!(i8, 1); impl_primitive_serialization!(i16, 2); impl_primitive_serialization!(i32, 4); impl_primitive_serialization!(i64, 8); + impl_primitive_serialization!(i128, 16); + + impl Valid for usize { + fn check(&self) -> Result<(), SerializationError> { + Ok(()) + } + } + + impl HachiSerialize for usize { + fn serialize_with_mode( + &self, + writer: W, + compress: Compress, + ) -> Result<(), SerializationError> { + (*self as u64).serialize_with_mode(writer, compress) + } + + fn serialized_size(&self, _compress: Compress) -> usize { + 8 + } + } + + impl HachiDeserialize for usize { + fn deserialize_with_mode( + reader: R, + compress: Compress, + validate: Validate, + ) -> Result { + let val = u64::deserialize_with_mode(reader, compress, validate)?; + Ok(val as usize) + } + } impl Valid for bool { fn check(&self) -> Result<(), SerializationError> { diff --git a/src/primitives/transcript.rs b/src/primitives/transcript.rs index 89ebc29a..ce2e8f53 100644 --- a/src/primitives/transcript.rs +++ b/src/primitives/transcript.rs @@ -2,13 +2,13 @@ #![allow(missing_docs)] -use crate::primitives::arithmetic::{Field, Module}; +use crate::primitives::arithmetic::{CanonicalField, FieldCore, Module}; use crate::primitives::HachiSerialize; /// Transcript for Fiat-Shamir transformations pub trait Transcript { /// Field type for challenges - type Field: Field; + type Field: FieldCore + CanonicalField; /// 
Append raw bytes to the transcript fn append_bytes(&mut self, label: &[u8], bytes: &[u8]); diff --git a/src/protocol/challenges/mod.rs b/src/protocol/challenges/mod.rs new file mode 100644 index 00000000..524941ed --- /dev/null +++ b/src/protocol/challenges/mod.rs @@ -0,0 +1,6 @@ +//! Protocol-level Fiat–Shamir challenge samplers. +//! +//! These utilities derive structured challenges (e.g. sparse ring elements) from +//! the transcript while keeping the low-level representations in the algebra layer. + +pub mod sparse; diff --git a/src/protocol/challenges/sparse.rs b/src/protocol/challenges/sparse.rs new file mode 100644 index 00000000..aca54aaf --- /dev/null +++ b/src/protocol/challenges/sparse.rs @@ -0,0 +1,121 @@ +//! Sparse challenge sampling via Fiat–Shamir. + +use crate::algebra::ring::{CyclotomicRing, SparseChallenge, SparseChallengeConfig}; +use crate::error::HachiError; +use crate::protocol::transcript::labels::{ABSORB_SPARSE_CHALLENGE, CHALLENGE_SPARSE_CHALLENGE}; +use crate::protocol::transcript::Transcript; +use crate::{CanonicalField, FieldCore}; + +/// Sample a sparse ring challenge (exact weight ω) from a transcript. +/// +/// This is intentionally deterministic and label-aware: +/// - first we absorb the sampling context under `ABSORB_SPARSE_CHALLENGE`, +/// - then we derive as many `CHALLENGE_SPARSE_CHALLENGE` scalars as needed. +/// +/// Notes: +/// - Indices are sampled with a simple `mod D` reduction. For the intended +/// regimes (small `D`, cryptographic transcript), any bias is negligible. +/// - Duplicate indices are rejected to enforce exact Hamming weight. +/// +/// # Errors +/// +/// Returns an error if the provided config is invalid for degree `D`. 
+pub fn sparse_challenge_from_transcript( + transcript: &mut T, + label: &[u8], + instance_idx: u64, + cfg: &SparseChallengeConfig, +) -> Result +where + F: FieldCore + CanonicalField, + T: Transcript, +{ + cfg.validate::() + .map_err(|e| HachiError::InvalidInput(format!("invalid sparse challenge config: {e}")))?; + + // Absorb domain-separating context so different call sites can't collide. + transcript.append_bytes(ABSORB_SPARSE_CHALLENGE, label); + transcript.append_bytes(ABSORB_SPARSE_CHALLENGE, &instance_idx.to_le_bytes()); + transcript.append_bytes(ABSORB_SPARSE_CHALLENGE, &(D as u64).to_le_bytes()); + transcript.append_bytes(ABSORB_SPARSE_CHALLENGE, &(cfg.weight as u64).to_le_bytes()); + // Include the coefficient alphabet (as little-endian i16 stream). + let mut coeff_bytes = Vec::with_capacity(cfg.nonzero_coeffs.len() * 2); + for &c in cfg.nonzero_coeffs.iter() { + coeff_bytes.extend_from_slice(&c.to_le_bytes()); + } + transcript.append_bytes(ABSORB_SPARSE_CHALLENGE, &coeff_bytes); + + let mut seen = vec![false; D]; + let mut positions = Vec::with_capacity(cfg.weight); + let mut coeffs = Vec::with_capacity(cfg.weight); + + while positions.len() < cfg.weight { + let r = transcript + .challenge_scalar(CHALLENGE_SPARSE_CHALLENGE) + .to_canonical_u128(); + let lo = r as u64; + let hi = (r >> 64) as u64; + + let pos = (lo % (D as u64)) as usize; + if seen[pos] { + continue; + } + seen[pos] = true; + positions.push(pos as u32); + + let coeff_idx = (hi % (cfg.nonzero_coeffs.len() as u64)) as usize; + let c = cfg.nonzero_coeffs[coeff_idx]; + debug_assert_ne!(c, 0); + coeffs.push(c); + } + + Ok(SparseChallenge { positions, coeffs }) +} + +/// Sample `n` sparse challenges from a transcript, returning the sparse +/// representation directly. +/// +/// # Errors +/// +/// Returns an error if challenge sampling fails. 
+pub fn sample_sparse_challenges( + transcript: &mut T, + label: &[u8], + n: usize, + cfg: &SparseChallengeConfig, +) -> Result, HachiError> +where + F: FieldCore + CanonicalField, + T: Transcript, +{ + (0..n) + .map(|i| sparse_challenge_from_transcript::(transcript, label, i as u64, cfg)) + .collect() +} + +/// Sample `n` sparse challenges from a transcript and convert them to dense +/// `CyclotomicRing` elements. +/// +/// # Errors +/// +/// Returns an error if challenge sampling or dense conversion fails. +pub fn sample_dense_challenges( + transcript: &mut T, + label: &[u8], + n: usize, + cfg: &SparseChallengeConfig, +) -> Result>, HachiError> +where + F: FieldCore + CanonicalField, + T: Transcript, +{ + (0..n) + .map(|i| { + let sparse = + sparse_challenge_from_transcript::(transcript, label, i as u64, cfg)?; + sparse + .to_dense::() + .map_err(|e| HachiError::InvalidInput(e.to_string())) + }) + .collect() +} diff --git a/src/protocol/commitment/commit.rs b/src/protocol/commitment/commit.rs new file mode 100644 index 00000000..679bdfa1 --- /dev/null +++ b/src/protocol/commitment/commit.rs @@ -0,0 +1,1192 @@ +//! Ring-native §4.1 commitment core implementation. 
+ +use super::config::{ + ensure_block_layout, ensure_matrix_shape_ge, ensure_supported_num_vars, + validate_and_derive_layout, HachiCommitmentLayout, +}; +use super::onehot::{inner_ajtai_onehot_wide, map_onehot_to_sparse_blocks}; +use super::schedule::HachiScheduleInputs; +use super::scheme::{CommitWitness, RingCommitmentScheme}; +use super::types::RingCommitment; +#[cfg(feature = "disk-persistence")] +use super::utils::crt_ntt::build_ntt_slots; +use super::utils::crt_ntt::{build_ntt_slot, NttSlotCache}; +use super::utils::flat_matrix::FlatMatrix; +use super::utils::linear::{ + decompose_rows_i8, flatten_i8_blocks, mat_vec_mul_ntt_i8, mat_vec_mul_ntt_single_i8, +}; +use super::utils::matrix::{derive_public_matrix, sample_public_matrix_seed, PublicMatrixSeed}; +use super::CommitmentConfig; +use crate::algebra::fields::wide::HasWide; +use crate::algebra::CyclotomicRing; +use crate::error::HachiError; +#[cfg(feature = "parallel")] +use crate::parallel::*; +use crate::primitives::serialization::{ + Compress, HachiDeserialize, HachiSerialize, SerializationError, Valid, Validate, +}; +use crate::protocol::hachi_poly_ops::OneHotIndex; +use crate::protocol::ring_switch::w_commitment_layout; +use crate::{cfg_into_iter, cfg_iter, CanonicalField, FieldCore, FieldSampling}; +#[cfg(feature = "disk-persistence")] +use std::fs; +use std::io::{Read, Write}; +#[cfg(feature = "disk-persistence")] +use std::path::PathBuf; + +/// Seed-only stage for deterministic setup expansion. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct HachiSetupSeed { + /// Maximum supported variable count. + pub max_num_vars: usize, + /// Runtime commitment layout. + pub layout: HachiCommitmentLayout, + /// Public seed used to derive commitment matrices. + pub public_matrix_seed: PublicMatrixSeed, +} + +/// Expanded setup stage containing coefficient-form matrices stored as +/// D-agnostic flat field-element arrays. 
+/// +/// The same `HachiExpandedSetup` can be viewed at different ring dimensions by +/// calling [`FlatMatrix::view`] or [`FlatMatrix::row`] with the desired +/// const-generic `D`. +#[allow(non_snake_case)] +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct HachiExpandedSetup { + /// Setup seed and runtime layout metadata. + pub seed: HachiSetupSeed, + /// Inner matrix `A`. + pub A: FlatMatrix, + /// Outer matrix `B`. + pub B: FlatMatrix, + /// Prover matrix `D ∈ R_q^{n_D × δ·2^R}` (§4.2). + pub D_mat: FlatMatrix, +} + +/// Prover setup artifact (expanded setup + per-matrix NTT caches). +/// +/// The NTT caches are tied to a specific ring dimension D. +#[allow(non_snake_case)] +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct HachiProverSetup { + /// Expanded matrix stage used by both prover and verifier. + pub expanded: HachiExpandedSetup, + /// NTT cache for the A matrix. + pub ntt_A: NttSlotCache, + /// NTT cache for the B matrix. + pub ntt_B: NttSlotCache, + /// NTT cache for the D matrix. + pub ntt_D: NttSlotCache, +} + +/// Verifier setup artifact derived from prover setup. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct HachiVerifierSetup { + /// Expanded matrix stage used for verification. + pub expanded: HachiExpandedSetup, +} + +impl HachiExpandedSetup { + /// Runtime layout carried by this setup (the max-dimension layout). + pub fn layout(&self) -> HachiCommitmentLayout { + self.seed.layout + } + + /// Derive the Labrador commitment-key seed from this setup's public + /// matrix seed via domain-separated BLAKE2b. + pub fn labrador_comkey_seed(&self) -> crate::protocol::labrador::comkey::LabradorComKeySeed { + crate::protocol::labrador::comkey::derive_labrador_comkey_seed( + &self.seed.public_matrix_seed, + ) + } +} + +impl HachiProverSetup { + /// Runtime layout carried by this setup (the max-dimension layout). 
+ pub fn layout(&self) -> HachiCommitmentLayout { + self.expanded.layout() + } + + /// Panic if `layout`'s matrix dimensions exceed this setup's maximums. + /// + /// # Panics + /// + /// Panics if any of `layout`'s matrix widths (inner, outer, D) exceed + /// those of this setup. + pub fn assert_layout_fits(&self, layout: &HachiCommitmentLayout) { + let max = &self.expanded.seed.layout; + assert!( + layout.inner_width <= max.inner_width, + "A matrix too narrow: need {} but setup has {}", + layout.inner_width, + max.inner_width + ); + assert!( + layout.outer_width <= max.outer_width, + "B matrix too narrow: need {} but setup has {}", + layout.outer_width, + max.outer_width + ); + assert!( + layout.d_matrix_width <= max.d_matrix_width, + "D matrix too narrow: need {} but setup has {}", + layout.d_matrix_width, + max.d_matrix_width + ); + } +} + +impl Valid for HachiSetupSeed { + fn check(&self) -> Result<(), SerializationError> { + self.layout.check() + } +} + +impl HachiSerialize for HachiSetupSeed { + fn serialize_with_mode( + &self, + mut writer: W, + compress: Compress, + ) -> Result<(), SerializationError> { + self.max_num_vars + .serialize_with_mode(&mut writer, compress)?; + self.layout.serialize_with_mode(&mut writer, compress)?; + writer.write_all(&self.public_matrix_seed)?; + Ok(()) + } + + fn serialized_size(&self, compress: Compress) -> usize { + self.max_num_vars.serialized_size(compress) + self.layout.serialized_size(compress) + 32 + } +} + +impl HachiDeserialize for HachiSetupSeed { + fn deserialize_with_mode( + mut reader: R, + compress: Compress, + validate: Validate, + ) -> Result { + let max_num_vars = usize::deserialize_with_mode(&mut reader, compress, validate)?; + let layout = HachiCommitmentLayout::deserialize_with_mode(&mut reader, compress, validate)?; + let mut public_matrix_seed = [0u8; 32]; + reader.read_exact(&mut public_matrix_seed)?; + let out = Self { + max_num_vars, + layout, + public_matrix_seed, + }; + if matches!(validate, 
Validate::Yes) { + out.check()?; + } + Ok(out) + } +} + +impl Valid for HachiExpandedSetup { + fn check(&self) -> Result<(), SerializationError> { + self.seed.check()?; + self.A.check()?; + self.B.check()?; + self.D_mat.check()?; + Ok(()) + } +} + +impl HachiSerialize for HachiExpandedSetup { + fn serialize_with_mode( + &self, + mut writer: W, + compress: Compress, + ) -> Result<(), SerializationError> { + self.seed.serialize_with_mode(&mut writer, compress)?; + self.A.serialize_with_mode(&mut writer, compress)?; + self.B.serialize_with_mode(&mut writer, compress)?; + self.D_mat.serialize_with_mode(&mut writer, compress)?; + Ok(()) + } + + fn serialized_size(&self, compress: Compress) -> usize { + self.seed.serialized_size(compress) + + self.A.serialized_size(compress) + + self.B.serialized_size(compress) + + self.D_mat.serialized_size(compress) + } +} + +impl HachiDeserialize for HachiExpandedSetup { + fn deserialize_with_mode( + mut reader: R, + compress: Compress, + validate: Validate, + ) -> Result { + let out = Self { + seed: HachiSetupSeed::deserialize_with_mode(&mut reader, compress, validate)?, + A: FlatMatrix::deserialize_with_mode(&mut reader, compress, validate)?, + B: FlatMatrix::deserialize_with_mode(&mut reader, compress, validate)?, + D_mat: FlatMatrix::deserialize_with_mode(&mut reader, compress, validate)?, + }; + if matches!(validate, Validate::Yes) { + out.check()?; + } + Ok(out) + } +} + +impl Valid for HachiProverSetup { + fn check(&self) -> Result<(), SerializationError> { + self.expanded.check() + } +} + +impl HachiSerialize for HachiProverSetup { + fn serialize_with_mode( + &self, + _writer: W, + _compress: Compress, + ) -> Result<(), SerializationError> { + Err(SerializationError::InvalidData( + "HachiProverSetup contains runtime NTT caches and is not serializable".into(), + )) + } + + fn serialized_size(&self, _compress: Compress) -> usize { + 0 + } +} + +impl Valid for HachiVerifierSetup { + fn check(&self) -> Result<(), 
SerializationError> { + self.expanded.check() + } +} + +impl HachiSerialize for HachiVerifierSetup { + fn serialize_with_mode( + &self, + writer: W, + compress: Compress, + ) -> Result<(), SerializationError> { + self.expanded.serialize_with_mode(writer, compress) + } + + fn serialized_size(&self, compress: Compress) -> usize { + self.expanded.serialized_size(compress) + } +} + +impl HachiDeserialize for HachiVerifierSetup { + fn deserialize_with_mode( + reader: R, + compress: Compress, + validate: Validate, + ) -> Result { + Ok(Self { + expanded: HachiExpandedSetup::deserialize_with_mode(reader, compress, validate)?, + }) + } +} + +#[cfg(feature = "disk-persistence")] +fn cache_file_name(max_num_vars: usize) -> String { + let family = Cfg::family_key() + .chars() + .map(|ch| if ch.is_ascii_alphanumeric() { ch } else { '_' }) + .collect::(); + format!( + "hachi_{family}_d{}_na{}_nb{}_nd{}_b{}_wb{}_nv{max_num_vars}.setup", + Cfg::D, + Cfg::N_A, + Cfg::N_B, + Cfg::N_D, + Cfg::decomposition().log_basis, + Cfg::w_log_basis(), + ) +} + +#[cfg(feature = "disk-persistence")] +fn get_storage_path(max_num_vars: usize) -> Option { + let cache_directory = if let Ok(local_app_data) = std::env::var("LOCALAPPDATA") { + Some(PathBuf::from(local_app_data)) + } else if let Ok(home) = std::env::var("HOME") { + let mut path = PathBuf::from(&home); + let macos_cache = { + let mut test_path = PathBuf::from(&home); + test_path.push("Library"); + test_path.push("Caches"); + test_path.exists() + }; + if macos_cache { + path.push("Library"); + path.push("Caches"); + } else { + path.push(".cache"); + } + Some(path) + } else { + None + }; + + cache_directory.map(|mut path| { + path.push("hachi"); + path.push(cache_file_name::(max_num_vars)); + path + }) +} + +#[cfg(feature = "disk-persistence")] +fn save_expanded_setup( + setup: &HachiExpandedSetup, + max_num_vars: usize, +) { + let Some(storage_path) = get_storage_path::(max_num_vars) else { + tracing::warn!("Could not determine storage 
directory; skipping setup save"); + return; + }; + + if let Some(parent) = storage_path.parent() { + fs::create_dir_all(parent) + .unwrap_or_else(|e| panic!("Failed to create storage directory: {e}")); + } + + tracing::info!("Saving setup to {}", storage_path.display()); + + let file = fs::File::create(&storage_path) + .unwrap_or_else(|e| panic!("Failed to create setup file: {e}")); + let mut writer = std::io::BufWriter::new(file); + + setup + .serialize_compressed(&mut writer) + .unwrap_or_else(|e| panic!("Failed to serialize setup: {e}")); + + tracing::info!("Successfully saved setup to disk"); +} + +#[cfg(feature = "disk-persistence")] +fn load_expanded_setup( + max_num_vars: usize, +) -> Result, HachiError> { + let storage_path = get_storage_path::(max_num_vars).ok_or_else(|| { + HachiError::InvalidSetup("Failed to determine storage directory".to_string()) + })?; + + if !storage_path.exists() { + return Err(HachiError::InvalidSetup(format!( + "Setup file not found at {}", + storage_path.display() + ))); + } + + tracing::info!("Loading setup from {}", storage_path.display()); + + let file = fs::File::open(&storage_path) + .map_err(|e| HachiError::InvalidSetup(format!("Failed to open setup file: {e}")))?; + let mut reader = std::io::BufReader::new(file); + + let setup = HachiExpandedSetup::deserialize_compressed(&mut reader) + .map_err(|e| HachiError::InvalidSetup(format!("Failed to deserialize setup: {e}")))?; + + tracing::info!("Loaded setup for max_num_vars={max_num_vars}"); + Ok(setup) +} + +/// Build prover and verifier setup from a pre-existing expanded setup by +/// reconstructing the NTT caches. 
+#[cfg(feature = "disk-persistence")] +pub(crate) fn setup_from_expanded( + expanded: HachiExpandedSetup, +) -> Result<(HachiProverSetup, HachiVerifierSetup), HachiError> { + let (ntt_a, ntt_b, ntt_d) = build_ntt_slots( + expanded.A.view::(), + expanded.B.view::(), + expanded.D_mat.view::(), + )?; + let prover_setup = HachiProverSetup { + expanded: expanded.clone(), + ntt_A: ntt_a, + ntt_B: ntt_b, + ntt_D: ntt_d, + }; + let verifier_setup = HachiVerifierSetup { expanded }; + Ok((prover_setup, verifier_setup)) +} + +/// Concrete §4.1 commitment core. +#[derive(Clone, Copy, Default)] +pub struct HachiCommitmentCore; + +impl RingCommitmentScheme for HachiCommitmentCore +where + F: FieldCore + CanonicalField + FieldSampling + HasWide + Valid, + Cfg: CommitmentConfig, +{ + type ProverSetup = HachiProverSetup; + type VerifierSetup = HachiVerifierSetup; + type Commitment = RingCommitment; + + #[tracing::instrument(skip_all, name = "RingCommitmentScheme::setup")] + fn setup(max_num_vars: usize) -> Result<(Self::ProverSetup, Self::VerifierSetup), HachiError> { + let layout = validate_and_derive_layout::(max_num_vars)?; + let root_params = Cfg::level_params(HachiScheduleInputs { + max_num_vars, + level: 0, + current_w_len: 1usize << max_num_vars, + }); + ensure_supported_num_vars(max_num_vars, layout.required_num_vars::()?)?; + + #[cfg(feature = "disk-persistence")] + { + match load_expanded_setup::(max_num_vars) { + Ok(expanded) => { + tracing::info!("Loaded setup from disk, rebuilding NTT caches"); + return setup_from_expanded(expanded); + } + Err(e) => { + if let Some(storage_path) = get_storage_path::(max_num_vars) { + let _ = fs::remove_file(&storage_path); + tracing::warn!( + "Failed to load cached setup from {}: {e}; regenerating", + storage_path.display() + ); + } else { + tracing::warn!("Failed to load cached setup: {e}; regenerating"); + } + } + } + } + + let w_layout = w_commitment_layout::(root_params, layout)?; + let a_cols = 
layout.inner_width.max(w_layout.inner_width); + let b_cols = layout.outer_width.max(w_layout.outer_width); + let d_cols = layout.d_matrix_width.max(w_layout.d_matrix_width); + + let public_matrix_seed = sample_public_matrix_seed(); + let a_matrix = derive_public_matrix::(Cfg::N_A, a_cols, &public_matrix_seed, b"A"); + let b_matrix = derive_public_matrix::(Cfg::N_B, b_cols, &public_matrix_seed, b"B"); + let d_matrix = derive_public_matrix::(Cfg::N_D, d_cols, &public_matrix_seed, b"D"); + + let a_flat = FlatMatrix::from_ring_matrix(&a_matrix); + let b_flat = FlatMatrix::from_ring_matrix(&b_matrix); + let d_flat = FlatMatrix::from_ring_matrix(&d_matrix); + + let ntt_a = build_ntt_slot(a_flat.view::())?; + let ntt_b = build_ntt_slot(b_flat.view::())?; + let ntt_d = build_ntt_slot(d_flat.view::())?; + let expanded = HachiExpandedSetup { + seed: HachiSetupSeed { + max_num_vars, + layout, + public_matrix_seed, + }, + A: a_flat, + B: b_flat, + D_mat: d_flat, + }; + + #[cfg(feature = "disk-persistence")] + save_expanded_setup::(&expanded, max_num_vars); + + let prover_setup = HachiProverSetup { + expanded: expanded.clone(), + ntt_A: ntt_a, + ntt_B: ntt_b, + ntt_D: ntt_d, + }; + let verifier_setup = HachiVerifierSetup { expanded }; + ensure_matrix_shape_ge::( + &prover_setup.expanded.A, + Cfg::N_A, + layout.inner_width, + "A", + )?; + ensure_matrix_shape_ge::( + &prover_setup.expanded.B, + Cfg::N_B, + layout.outer_width, + "B", + )?; + ensure_matrix_shape_ge::( + &prover_setup.expanded.D_mat, + Cfg::N_D, + layout.d_matrix_width, + "D", + )?; + Ok((prover_setup, verifier_setup)) + } + + fn layout(setup: &Self::ProverSetup) -> Result { + Ok(setup.layout()) + } + + #[tracing::instrument(skip_all, name = "RingCommitmentScheme::commit_ring_blocks")] + fn commit_ring_blocks( + f_blocks: &[Vec>], + setup: &Self::ProverSetup, + ) -> Result, HachiError> { + let layout = setup.layout(); + ensure_supported_num_vars( + setup.expanded.seed.max_num_vars, + layout.required_num_vars::()?, + 
)?; + ensure_block_layout(f_blocks, layout)?; + ensure_matrix_shape_ge::(&setup.expanded.A, Cfg::N_A, layout.inner_width, "A")?; + ensure_matrix_shape_ge::(&setup.expanded.B, Cfg::N_B, layout.outer_width, "B")?; + + let depth_commit = layout.num_digits_commit; + let depth_open = layout.num_digits_open; + let log_basis = layout.log_basis; + let block_slices: Vec<&[CyclotomicRing]> = + f_blocks.iter().map(|b| b.as_slice()).collect(); + let t_all = mat_vec_mul_ntt_i8(&setup.ntt_A, &block_slices, depth_commit, log_basis); + let t_hat_all: Vec> = cfg_into_iter!(t_all) + .map(|t_i| decompose_rows_i8(&t_i, depth_open, log_basis)) + .collect(); + + let t_hat_flat = flatten_i8_blocks(&t_hat_all); + + let mut u: Vec> = mat_vec_mul_ntt_single_i8(&setup.ntt_B, &t_hat_flat); + let root_params = Cfg::level_params(HachiScheduleInputs { + max_num_vars: setup.expanded.seed.max_num_vars, + level: 0, + current_w_len: 1usize << setup.expanded.seed.max_num_vars, + }); + u.truncate(root_params.n_b); + Ok(CommitWitness::new(RingCommitment { u }, t_hat_all)) + } + + #[tracing::instrument(skip_all, name = "RingCommitmentScheme::commit_coeffs")] + fn commit_coeffs( + f_coeffs: &[CyclotomicRing], + setup: &Self::ProverSetup, + ) -> Result, HachiError> { + let layout = setup.layout(); + let num_blocks = layout.num_blocks; + let block_len = layout.block_len; + let max_len = num_blocks + .checked_mul(block_len) + .ok_or_else(|| HachiError::InvalidSetup("coefficient length overflow".to_string()))?; + if f_coeffs.len() > max_len { + return Err(HachiError::InvalidSize { + expected: max_len, + actual: f_coeffs.len(), + }); + } + + let depth_commit = layout.num_digits_commit; + let depth_open = layout.num_digits_open; + let log_basis = layout.log_basis; + let coeff_len = f_coeffs.len(); + + let block_slices: Vec<&[CyclotomicRing]> = (0..num_blocks) + .map(|i| { + let start = i * block_len; + if start >= coeff_len { + &[] as &[CyclotomicRing] + } else { + &f_coeffs[start..(start + 
block_len).min(coeff_len)] + } + }) + .collect(); + + let t_all = mat_vec_mul_ntt_i8(&setup.ntt_A, &block_slices, depth_commit, log_basis); + let t_hat_all: Vec> = cfg_into_iter!(t_all) + .map(|t_i| decompose_rows_i8(&t_i, depth_open, log_basis)) + .collect(); + + let t_hat_flat = flatten_i8_blocks(&t_hat_all); + + let mut u: Vec> = mat_vec_mul_ntt_single_i8(&setup.ntt_B, &t_hat_flat); + let root_params = Cfg::level_params(HachiScheduleInputs { + max_num_vars: setup.expanded.seed.max_num_vars, + level: 0, + current_w_len: 1usize << setup.expanded.seed.max_num_vars, + }); + u.truncate(root_params.n_b); + Ok(CommitWitness::new(RingCommitment { u }, t_hat_all)) + } + + #[tracing::instrument(skip_all, name = "RingCommitmentScheme::commit_onehot")] + fn commit_onehot( + onehot_k: usize, + indices: &[Option], + setup: &Self::ProverSetup, + ) -> Result, HachiError> { + let layout = setup.layout(); + ensure_supported_num_vars( + setup.expanded.seed.max_num_vars, + layout.required_num_vars::()?, + )?; + ensure_matrix_shape_ge::(&setup.expanded.A, Cfg::N_A, layout.inner_width, "A")?; + ensure_matrix_shape_ge::(&setup.expanded.B, Cfg::N_B, layout.outer_width, "B")?; + + let sparse_blocks = + map_onehot_to_sparse_blocks(onehot_k, indices, layout.r_vars, layout.m_vars, D)?; + + let depth_commit = layout.num_digits_commit; + let depth_open = layout.num_digits_open; + let log_basis = layout.log_basis; + let zero_block_len = Cfg::N_A.checked_mul(depth_open).unwrap(); + let a_view = setup.expanded.A.view::(); + let block_len = layout.block_len; + + let t_hat_all: Vec> = cfg_iter!(sparse_blocks) + .map(|block_entries| { + if block_entries.is_empty() { + vec![[0i8; D]; zero_block_len] + } else { + let t_i = + inner_ajtai_onehot_wide(&a_view, block_entries, block_len, depth_commit); + decompose_rows_i8(&t_i, depth_open, log_basis) + } + }) + .collect(); + + let t_hat_flat = flatten_i8_blocks(&t_hat_all); + + let mut u: Vec> = mat_vec_mul_ntt_single_i8(&setup.ntt_B, &t_hat_flat); + let 
root_params = Cfg::level_params(HachiScheduleInputs { + max_num_vars: setup.expanded.seed.max_num_vars, + level: 0, + current_w_len: 1usize << setup.expanded.seed.max_num_vars, + }); + u.truncate(root_params.n_b); + Ok(CommitWitness::new(RingCommitment { u }, t_hat_all)) + } +} + +impl HachiCommitmentCore { + #[allow(clippy::too_many_arguments)] + fn layout_envelope( + max_num_vars: usize, + inner_width: usize, + outer_width: usize, + d_matrix_width: usize, + preferred_r_vars: usize, + num_digits_open: usize, + num_digits_fold: usize, + log_basis: u32, + ) -> Result { + let alpha = D.trailing_zeros() as usize; + let outer_vars = max_num_vars.checked_sub(alpha).ok_or_else(|| { + HachiError::InvalidSetup("max_num_vars is smaller than alpha".to_string()) + })?; + let r_vars = preferred_r_vars.min(outer_vars); + let m_vars = outer_vars - r_vars; + let num_blocks = 1usize + .checked_shl(r_vars as u32) + .ok_or_else(|| HachiError::InvalidSetup("num_blocks overflow".to_string()))?; + let block_len = 1usize + .checked_shl(m_vars as u32) + .ok_or_else(|| HachiError::InvalidSetup("block_len overflow".to_string()))?; + + Ok(HachiCommitmentLayout { + m_vars, + r_vars, + num_blocks, + block_len, + inner_width, + outer_width, + d_matrix_width, + // Setup metadata only tracks width envelopes; runtime commits/proofs + // carry their own exact decomposition parameters. + num_digits_commit: 1, + num_digits_open, + num_digits_fold, + log_basis, + }) + } + + /// Create a setup with a caller-specified layout, bypassing + /// `CommitmentConfig::commitment_layout`. + /// + /// Use this when the desired `(m_vars, r_vars)` split differs from what + /// the config's heuristic would choose (e.g. mega-polynomial commitments + /// where each sub-polynomial occupies one block). + /// + /// # Errors + /// + /// Returns `HachiError` on invalid layout or matrix generation failures. 
+ pub fn setup_with_layout( + layout: HachiCommitmentLayout, + ) -> Result<(HachiProverSetup, HachiVerifierSetup), HachiError> + where + F: FieldCore + CanonicalField + FieldSampling, + Cfg: CommitmentConfig, + { + let max_num_vars = layout.required_num_vars::()?; + let public_matrix_seed = sample_public_matrix_seed(); + Self::setup_with_layout_and_seed::(layout, max_num_vars, public_matrix_seed) + } + + /// Create a setup that supports any of the provided runtime layouts. + /// + /// This sizes the public matrices from the exact per-layout maxima + /// (including recursive `w` commitments) instead of inflating through a + /// synthetic max layout. + /// + /// # Errors + /// + /// Returns `HachiError` if `layouts` is empty, uses inconsistent + /// decomposition parameters, or matrix generation fails. + pub fn setup_with_layouts( + layouts: &[HachiCommitmentLayout], + ) -> Result<(HachiProverSetup, HachiVerifierSetup), HachiError> + where + F: FieldCore + CanonicalField + FieldSampling, + Cfg: CommitmentConfig, + { + let Some((&first_layout, _)) = layouts.split_first() else { + return Err(HachiError::InvalidSetup( + "setup_with_layouts requires at least one layout".to_string(), + )); + }; + + let mut max_num_vars = 0usize; + let mut max_inner_width = 0usize; + let mut max_outer_width = 0usize; + let mut max_d_matrix_width = 0usize; + let mut max_r_vars = 0usize; + let mut max_num_digits_open = 0usize; + let mut max_num_digits_fold = 0usize; + + for &layout in layouts { + if layout.log_basis != first_layout.log_basis { + return Err(HachiError::InvalidSetup(format!( + "setup_with_layouts requires a shared log_basis (expected {}, got {})", + first_layout.log_basis, layout.log_basis + ))); + } + + max_num_vars = max_num_vars.max(layout.required_num_vars::()?); + max_inner_width = max_inner_width.max(layout.inner_width); + max_outer_width = max_outer_width.max(layout.outer_width); + max_d_matrix_width = max_d_matrix_width.max(layout.d_matrix_width); + max_r_vars = 
max_r_vars.max(layout.r_vars); + max_num_digits_open = max_num_digits_open.max(layout.num_digits_open); + max_num_digits_fold = max_num_digits_fold.max(layout.num_digits_fold); + + let layout_num_vars = layout.required_num_vars::()?; + let level_params = Cfg::level_params(HachiScheduleInputs { + max_num_vars: layout_num_vars, + level: 0, + current_w_len: 1usize << layout_num_vars, + }); + let w_layout = w_commitment_layout::(level_params, layout)?; + tracing::debug!(?layout, ?w_layout, "setup layout"); + max_inner_width = max_inner_width.max(w_layout.inner_width); + max_outer_width = max_outer_width.max(w_layout.outer_width); + max_d_matrix_width = max_d_matrix_width.max(w_layout.d_matrix_width); + max_r_vars = max_r_vars.max(w_layout.r_vars); + max_num_digits_open = max_num_digits_open.max(w_layout.num_digits_open); + max_num_digits_fold = max_num_digits_fold.max(w_layout.num_digits_fold); + } + + let envelope_layout = Self::layout_envelope::( + max_num_vars, + max_inner_width, + max_outer_width, + max_d_matrix_width, + max_r_vars, + max_num_digits_open, + max_num_digits_fold, + first_layout.log_basis, + )?; + tracing::debug!(?envelope_layout, max_num_vars, "setup envelope"); + let public_matrix_seed = sample_public_matrix_seed(); + Self::setup_with_matrix_widths_and_seed::( + envelope_layout, + max_num_vars, + public_matrix_seed, + max_inner_width, + max_outer_width, + max_d_matrix_width, + ) + } + + /// Like `setup_with_layout` but reuses an existing setup's random seed and + /// A matrix (which depends only on `m_vars`). Only regenerates B and D + /// matrices for the new `r_vars`. + /// + /// Use this when creating a mega-polynomial setup that shares `m_vars` with + /// an individual polynomial setup — avoids re-deriving and NTT-transforming + /// the A matrix. + /// + /// # Errors + /// + /// Returns `HachiError` if the new layout is incompatible with the existing + /// setup or matrix shapes are inconsistent. 
+ pub fn setup_from_existing( + existing: &HachiExpandedSetup, + new_r_vars: usize, + ) -> Result<(HachiProverSetup, HachiVerifierSetup), HachiError> + where + F: FieldCore + CanonicalField + FieldSampling, + Cfg: CommitmentConfig, + { + let old_layout = existing.seed.layout; + let new_layout = HachiCommitmentLayout::new::( + old_layout.m_vars, + new_r_vars, + &Cfg::decomposition(), + )?; + + if new_layout.inner_width != old_layout.inner_width { + return Err(HachiError::InvalidSetup( + "setup_from_existing requires matching m_vars/inner_width".to_string(), + )); + } + + let layout_num_vars = new_layout.required_num_vars::()?; + let level_params = Cfg::level_params(HachiScheduleInputs { + max_num_vars: layout_num_vars, + level: 0, + current_w_len: 1usize << layout_num_vars, + }); + let w_layout = w_commitment_layout::(level_params, new_layout)?; + let a_width = existing.A.first_row_len::(); + if a_width < w_layout.inner_width { + return Err(HachiError::InvalidSetup(format!( + "existing A width {a_width} < w inner_width {}", + w_layout.inner_width + ))); + } + let b_cols = new_layout.outer_width.max(w_layout.outer_width); + + let max_num_vars = new_layout.required_num_vars::()?; + let seed = existing.seed.public_matrix_seed; + + let d_cols = new_layout.d_matrix_width.max(w_layout.d_matrix_width); + let b_matrix = derive_public_matrix::(Cfg::N_B, b_cols, &seed, b"B"); + let d_matrix = derive_public_matrix::(Cfg::N_D, d_cols, &seed, b"D"); + + let b_flat = FlatMatrix::from_ring_matrix(&b_matrix); + let d_flat = FlatMatrix::from_ring_matrix(&d_matrix); + + let ntt_a = build_ntt_slot(existing.A.view::())?; + let ntt_b = build_ntt_slot(b_flat.view::())?; + let ntt_d = build_ntt_slot(d_flat.view::())?; + let expanded = HachiExpandedSetup { + seed: HachiSetupSeed { + max_num_vars, + layout: new_layout, + public_matrix_seed: seed, + }, + A: existing.A.clone(), + B: b_flat, + D_mat: d_flat, + }; + let prover_setup = HachiProverSetup { + expanded: expanded.clone(), + ntt_A: 
ntt_a, + ntt_B: ntt_b, + ntt_D: ntt_d, + }; + let verifier_setup = HachiVerifierSetup { expanded }; + Ok((prover_setup, verifier_setup)) + } + + fn setup_with_layout_and_seed( + layout: HachiCommitmentLayout, + max_num_vars: usize, + public_matrix_seed: PublicMatrixSeed, + ) -> Result<(HachiProverSetup, HachiVerifierSetup), HachiError> + where + F: FieldCore + CanonicalField + FieldSampling, + Cfg: CommitmentConfig, + { + let layout_num_vars = layout.required_num_vars::()?; + let level_params = Cfg::level_params(HachiScheduleInputs { + max_num_vars: layout_num_vars, + level: 0, + current_w_len: 1usize << layout_num_vars, + }); + let w_layout = w_commitment_layout::(level_params, layout)?; + let a_cols = layout.inner_width.max(w_layout.inner_width); + let b_cols = layout.outer_width.max(w_layout.outer_width); + let d_cols = layout.d_matrix_width.max(w_layout.d_matrix_width); + + Self::setup_with_matrix_widths_and_seed::( + layout, + max_num_vars, + public_matrix_seed, + a_cols, + b_cols, + d_cols, + ) + } + + fn setup_with_matrix_widths_and_seed( + layout: HachiCommitmentLayout, + max_num_vars: usize, + public_matrix_seed: PublicMatrixSeed, + a_cols: usize, + b_cols: usize, + d_cols: usize, + ) -> Result<(HachiProverSetup, HachiVerifierSetup), HachiError> + where + F: FieldCore + CanonicalField + FieldSampling, + Cfg: CommitmentConfig, + { + { + let ring_bytes = std::mem::size_of::>(); + let a_raw_mb = (Cfg::N_A * a_cols * ring_bytes) as f64 / (1024.0_f64 * 1024.0_f64); + let b_raw_mb = (Cfg::N_B * b_cols * ring_bytes) as f64 / (1024.0_f64 * 1024.0_f64); + let d_raw_mb = (Cfg::N_D * d_cols * ring_bytes) as f64 / (1024.0_f64 * 1024.0_f64); + tracing::debug!( + a_cols, + b_cols, + d_cols, + ring_bytes, + a_mb = a_raw_mb, + b_mb = b_raw_mb, + d_mb = d_raw_mb, + total_mb = a_raw_mb + b_raw_mb + d_raw_mb, + "setup matrix sizes" + ); + } + let a_matrix = derive_public_matrix::(Cfg::N_A, a_cols, &public_matrix_seed, b"A"); + let b_matrix = derive_public_matrix::(Cfg::N_B, 
b_cols, &public_matrix_seed, b"B"); + let d_matrix = derive_public_matrix::(Cfg::N_D, d_cols, &public_matrix_seed, b"D"); + + let a_flat = FlatMatrix::from_ring_matrix(&a_matrix); + let b_flat = FlatMatrix::from_ring_matrix(&b_matrix); + let d_flat = FlatMatrix::from_ring_matrix(&d_matrix); + + let ntt_a = build_ntt_slot(a_flat.view::())?; + let ntt_b = build_ntt_slot(b_flat.view::())?; + let ntt_d = build_ntt_slot(d_flat.view::())?; + let expanded = HachiExpandedSetup { + seed: HachiSetupSeed { + max_num_vars, + layout, + public_matrix_seed, + }, + A: a_flat, + B: b_flat, + D_mat: d_flat, + }; + let prover_setup = HachiProverSetup { + expanded: expanded.clone(), + ntt_A: ntt_a, + ntt_B: ntt_b, + ntt_D: ntt_d, + }; + let verifier_setup = HachiVerifierSetup { expanded }; + ensure_matrix_shape_ge::( + &prover_setup.expanded.A, + Cfg::N_A, + layout.inner_width, + "A", + )?; + ensure_matrix_shape_ge::( + &prover_setup.expanded.B, + Cfg::N_B, + layout.outer_width, + "B", + )?; + ensure_matrix_shape_ge::( + &prover_setup.expanded.D_mat, + Cfg::N_D, + layout.d_matrix_width, + "D", + )?; + Ok((prover_setup, verifier_setup)) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::primitives::{HachiDeserialize, HachiSerialize}; + use crate::test_utils::{TinyConfig, F as TestF}; + + #[test] + fn expanded_setup_roundtrips_and_derives_same_verifier() { + const TEST_D: usize = 64; + let (prover_setup, verifier_setup) = + >::setup(16) + .unwrap(); + + let mut bytes = Vec::new(); + prover_setup + .expanded + .serialize_compressed(&mut bytes) + .unwrap(); + let decoded = HachiExpandedSetup::::deserialize_compressed(&bytes[..]).unwrap(); + + assert_eq!(decoded, prover_setup.expanded); + + let derived_verifier = HachiVerifierSetup { + expanded: decoded.clone(), + }; + assert_eq!(derived_verifier, verifier_setup); + } + + #[test] + fn setup_with_layouts_uses_exact_width_envelope() { + const TEST_D: usize = 64; + + let layout_a = + HachiCommitmentLayout::new::(4, 2, 
&TinyConfig::decomposition()).unwrap(); + let layout_b = + HachiCommitmentLayout::new::(1, 6, &TinyConfig::decomposition()).unwrap(); + let params_a = TinyConfig::level_params(HachiScheduleInputs { + max_num_vars: layout_a.required_num_vars::().unwrap(), + level: 0, + current_w_len: 1usize << layout_a.required_num_vars::().unwrap(), + }); + let params_b = TinyConfig::level_params(HachiScheduleInputs { + max_num_vars: layout_b.required_num_vars::().unwrap(), + level: 0, + current_w_len: 1usize << layout_b.required_num_vars::().unwrap(), + }); + let w_layout_a = + w_commitment_layout::(params_a, layout_a).unwrap(); + let w_layout_b = + w_commitment_layout::(params_b, layout_b).unwrap(); + + let expected_inner = [ + layout_a.inner_width, + layout_b.inner_width, + w_layout_a.inner_width, + w_layout_b.inner_width, + ] + .into_iter() + .max() + .unwrap(); + let expected_outer = [ + layout_a.outer_width, + layout_b.outer_width, + w_layout_a.outer_width, + w_layout_b.outer_width, + ] + .into_iter() + .max() + .unwrap(); + let expected_d = [ + layout_a.d_matrix_width, + layout_b.d_matrix_width, + w_layout_a.d_matrix_width, + w_layout_b.d_matrix_width, + ] + .into_iter() + .max() + .unwrap(); + let expected_max_num_vars = [ + layout_a.required_num_vars::().unwrap(), + layout_b.required_num_vars::().unwrap(), + ] + .into_iter() + .max() + .unwrap(); + + let (setup, _) = HachiCommitmentCore::setup_with_layouts::(&[ + layout_a, layout_b, + ]) + .unwrap(); + let envelope = setup.layout(); + + assert_eq!(setup.expanded.seed.max_num_vars, expected_max_num_vars); + assert_eq!(envelope.inner_width, expected_inner); + assert_eq!(envelope.outer_width, expected_outer); + assert_eq!(envelope.d_matrix_width, expected_d); + assert_eq!(setup.expanded.A.first_row_len::(), expected_inner); + assert_eq!(setup.expanded.B.first_row_len::(), expected_outer); + assert_eq!(setup.expanded.D_mat.first_row_len::(), expected_d); + } + + #[cfg(feature = "disk-persistence")] + mod disk_persistence { + 
use super::*; + use std::fs; + + fn cleanup_setup_file(max_num_vars: usize) { + if let Some(path) = get_storage_path(max_num_vars) { + let _ = fs::remove_file(path); + } + } + + #[test] + fn save_and_load_roundtrips() { + const TEST_D: usize = 64; + const MAX_VARS: usize = 100; + + cleanup_setup_file(MAX_VARS); + + let (prover_setup, _) = + >::setup( + MAX_VARS, + ) + .unwrap(); + + let loaded = load_expanded_setup::(MAX_VARS).unwrap(); + assert_eq!(loaded, prover_setup.expanded); + + cleanup_setup_file(MAX_VARS); + } + + #[test] + fn setup_uses_cache_on_second_call() { + const TEST_D: usize = 64; + const MAX_VARS: usize = 101; + + cleanup_setup_file(MAX_VARS); + + let (first, _) = + >::setup( + MAX_VARS, + ) + .unwrap(); + + let (second, _) = + >::setup( + MAX_VARS, + ) + .unwrap(); + + assert_eq!(first.expanded, second.expanded); + + cleanup_setup_file(MAX_VARS); + } + + #[test] + fn ntt_caches_rebuilt_correctly_from_disk() { + use crate::algebra::CyclotomicRing; + + const TEST_D: usize = 64; + const MAX_VARS: usize = 102; + + cleanup_setup_file(MAX_VARS); + + let (fresh_setup, _) = + >::setup( + MAX_VARS, + ) + .unwrap(); + + let loaded_expanded = load_expanded_setup::(MAX_VARS).unwrap(); + let (disk_setup, _) = setup_from_expanded::(loaded_expanded).unwrap(); + + let layout = fresh_setup.layout(); + let num_coeffs = layout.num_blocks * layout.block_len; + let coeffs = vec![CyclotomicRing::::zero(); num_coeffs]; + + let fresh_commit = >::commit_coeffs(&coeffs, &fresh_setup) + .unwrap(); + let disk_commit = >::commit_coeffs(&coeffs, &disk_setup) + .unwrap(); + + assert_eq!(fresh_commit.commitment, disk_commit.commitment); + + cleanup_setup_file(MAX_VARS); + } + } +} diff --git a/src/protocol/commitment/config.rs b/src/protocol/commitment/config.rs new file mode 100644 index 00000000..86a06dd1 --- /dev/null +++ b/src/protocol/commitment/config.rs @@ -0,0 +1,1034 @@ +//! Configuration presets for ring-native commitment construction. 
+ +use super::schedule::{default_stage1_challenge_config, HachiLevelParams, HachiScheduleInputs}; +use super::utils::flat_matrix::FlatMatrix; +use super::utils::math::checked_pow2; +use crate::algebra::ring::CyclotomicRing; +use crate::algebra::SparseChallengeConfig; +use crate::error::HachiError; +use crate::primitives::serialization::{ + Compress, HachiDeserialize, HachiSerialize, SerializationError, Valid, Validate, +}; +use crate::FieldCore; +use std::io::{Read, Write}; + +/// Parameters controlling the gadget decomposition depth (called δ in the paper). +/// +/// The gadget base is `b = 2^log_basis`. Each ring coefficient with centered +/// magnitude fitting in `log_commit_bound` bits is decomposed into +/// `ceil(log_commit_bound / log_basis)` balanced digits in `[-b/2, b/2)`. +/// +/// Smaller `log_commit_bound` (when polynomial coefficients are known to be +/// small) yields fewer decomposition levels, which proportionally shrinks the +/// witness vector, the commitment matrices, and the proving cost. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct DecompositionParams { + /// Base-2 logarithm of the gadget base (e.g., 3 for base-8 digits in [-4, 3]). + pub log_basis: u32, + + /// Bit-width of the largest coefficient that the *commitment* decomposition + /// must represent. Controls the commitment-side decomposition depth (δ in + /// the paper): `num_digits = ceil(log_commit_bound / log_basis)`. + /// + /// The centered representation maps each coefficient `c ∈ [0, q)` to the + /// signed value in `(-q/2, q/2]`. A value of `k` means the signed magnitude + /// fits in `k` bits, i.e., lies in `[-2^(k-1), 2^(k-1) - 1]`. + /// + /// Examples: + /// - Binary (0/1) polynomials: 1 + /// - Already range-checked digits in `[-b/2, b/2)`: `log_basis` (one digit) + /// - Arbitrary Fp128 elements: 128 + pub log_commit_bound: u32, + + /// Bit-width of the largest coefficient that the *opening* decomposition + /// must represent (ŵ = G⁻¹(w_folded)). 
+ /// + /// During opening, `fold_blocks` computes inner products with arbitrary + /// field-element weights, so the result always has full-field-size + /// coefficients regardless of the original `log_commit_bound`. When `None`, + /// defaults to `log_commit_bound` (correct when `log_commit_bound` already + /// covers the full field, e.g. 128). Set to the field modulus bit-width + /// when `log_commit_bound` is smaller (e.g. for recursive w commitments + /// where entries are small but fold products are not). + pub log_open_bound: Option, +} + +/// Compute the gadget decomposition depth (δ in the paper) from a +/// coefficient bit-width bound. +/// +/// Returns `ceil(log_bound / log_basis)`, with an extra level when the +/// balanced-digit range would not cover the full bound. +/// +/// # Panics +/// +/// Panics if `log_basis` is 0 or >= 128. +pub fn compute_num_digits(log_bound: u32, log_basis: u32) -> usize { + assert!(log_basis > 0 && log_basis < 128, "invalid log_basis"); + if log_bound == 0 { + return 1; + } + let mut levels = (log_bound as usize).div_ceil(log_basis as usize); + + // When levels * log_basis > log_bound (i.e., not exactly aligned), the + // balanced digit range (b/2-1) * (b^levels - 1)/(b-1) always exceeds + // 2^(log_bound-1) for b >= 4 (log_basis >= 2). Only check when aligned. 
+ let total_bits = (levels as u32).saturating_mul(log_basis); + if total_bits <= log_bound { + let b: u128 = 1u128 << log_basis; + let half_b_minus_1 = b / 2 - 1; + let b_minus_1 = b - 1; + let mut b_pow = 1u128; + for _ in 0..levels { + b_pow = b_pow.saturating_mul(b); + } + let max_positive = half_b_minus_1.saturating_mul(b_pow.saturating_sub(1) / b_minus_1); + let required = if log_bound > 128 { + u128::MAX / 2 + } else if log_bound == 0 { + 0 + } else { + (1u128 << (log_bound - 1)).saturating_sub(1) + }; + if max_positive < required { + levels += 1; + } + } + levels.max(1) +} + +/// Compute the decomposition depth for the folded witness `z_pre` +/// (τ in the paper). +/// +/// The folded witness satisfies `||z_pre||_inf <= β` where +/// `β = 2^r_vars * challenge_weight * 2^(log_basis - 1)`. +/// Returns enough gadget levels to represent values up to `β`. +pub fn compute_num_digits_fold(r_vars: usize, challenge_weight: usize, log_basis: u32) -> usize { + let shift = r_vars + (log_basis as usize) - 1; + if shift >= 127 || challenge_weight == 0 { + return compute_num_digits(128, log_basis); + } + let beta = (challenge_weight as u128).saturating_mul(1u128 << shift); + if beta == 0 { + return 1; + } + let log_beta = 128 - beta.leading_zeros(); + compute_num_digits(log_beta, log_basis) +} + +/// Find the `(m_vars, r_vars)` split that minimizes the level-0 +/// witness-to-polynomial ratio for a given config. +/// +/// The witness ring element count is dominated by: +/// ```text +/// w ≈ 2^r · (δ_open + N_A · δ_commit) + 2^m · δ_commit · δ_fold(r) +/// ``` +/// Multiplying the ratio by `2^(m+r)` (constant for fixed `reduced_vars`) +/// gives an equivalent integer cost: +/// ```text +/// C1 · 2^r + δ_commit · δ_fold(r) · 2^m +/// ``` +/// where `C1 = δ_open + N_A · δ_commit`. This function searches all valid +/// `(m, r)` pairs for the minimum using pure integer arithmetic (no +/// floating-point), so it is safe to run inside a zkVM guest. 
+/// +/// For the full-field config (`δ_commit = 43`), z_pre dominates and the +/// result is near-balanced (`m ≈ r`). For narrow configs (`δ_commit = 1`), +/// the w_hat/t_hat term matters more and the result skews to `m ≈ r + 4`. +pub(super) fn optimal_m_r_split_with_params( + params: HachiLevelParams, + decomp: DecompositionParams, + reduced_vars: usize, +) -> (usize, usize) { + // Guard: for S >= 53, shifts could overflow u64. Fall back to balanced + // split (this threshold is far beyond any practical polynomial size). + if reduced_vars <= 2 || reduced_vars >= 53 { + let r = reduced_vars / 2; + return (reduced_vars - r, r); + } + + let open_bound = decomp.log_open_bound.unwrap_or(decomp.log_commit_bound); + let delta_open = compute_num_digits(open_bound, decomp.log_basis) as u64; + let delta_commit = compute_num_digits(decomp.log_commit_bound, decomp.log_basis) as u64; + let c1 = delta_open + params.n_a as u64 * delta_commit; + + let mut best_r = reduced_vars / 2; + let mut best_cost = u64::MAX; + + for r in 1..reduced_vars { + let m = reduced_vars - r; + let delta_fold = + compute_num_digits_fold(r, params.challenge_weight, decomp.log_basis) as u64; + let cost = c1 * (1u64 << r) + delta_commit * delta_fold * (1u64 << m); + if cost < best_cost { + best_cost = cost; + best_r = r; + } + } + + (reduced_vars - best_r, best_r) +} + +/// Find the `(m_vars, r_vars)` split using the level-0 params of `Cfg`. 
+pub fn optimal_m_r_split(reduced_vars: usize) -> (usize, usize) { + let params = Cfg::level_params(HachiScheduleInputs { + max_num_vars: reduced_vars + Cfg::D.trailing_zeros() as usize, + level: 0, + current_w_len: 1usize << (reduced_vars + Cfg::D.trailing_zeros() as usize), + }); + optimal_m_r_split_with_params(params, Cfg::decomposition(), reduced_vars) +} + +const D64_CHALLENGE_MASS: usize = 54; +const D64_STAGE1_SUPPORT: usize = 21; + +fn d64_stage1_challenge_config(_level_params: HachiLevelParams) -> SparseChallengeConfig { + SparseChallengeConfig { + weight: D64_STAGE1_SUPPORT, + nonzero_coeffs: vec![-2, -1, 1, 2], + } +} + +/// Runtime commitment layout authority for ring-native commitments. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct HachiCommitmentLayout { + /// Number of variables inside each committed block (`2^m_vars` entries). + pub m_vars: usize, + /// Number of block-select variables (`2^r_vars` blocks). + pub r_vars: usize, + /// Number of committed blocks (`2^r_vars`). + pub num_blocks: usize, + /// Number of ring elements per block (`2^m_vars`). + pub block_len: usize, + /// Width of inner matrix `A` (`block_len * num_digits_commit`). + pub inner_width: usize, + /// Width of outer matrix `B` (`n_a * num_digits_open * num_blocks`). + pub outer_width: usize, + /// Width of prover matrix `D` (`num_digits_open * num_blocks`). + pub d_matrix_width: usize, + /// Number of gadget decomposition levels for commitment-time coefficients + /// (δ_commit in the paper). Controls how the original polynomial + /// coefficients are decomposed into balanced base-b digits for the Ajtai + /// commitment. + pub num_digits_commit: usize, + /// Number of gadget decomposition levels for opening-time folded + /// evaluations (δ_open in the paper). 
Folding inner-products with + /// arbitrary field-element weights produces full-field-size coefficients, + /// so this equals `num_digits_commit` when `log_commit_bound` covers + /// the full field, and is larger otherwise (e.g. recursive w witnesses). + pub num_digits_open: usize, + /// Number of gadget decomposition levels for the folded witness `z_pre` + /// (τ in the paper). Derived from the L∞ bound on `z_pre`. + pub num_digits_fold: usize, + /// Base-2 logarithm of gadget decomposition base. + pub log_basis: u32, +} + +impl HachiCommitmentLayout { + /// Build a layout from `(m_vars, r_vars)`, config constants, and decomposition + /// parameters. + /// + /// `num_digits_fold` (τ) is auto-derived from the beta bound + /// (`r_vars`, `challenge_weight`, `log_basis`). + /// + /// # Errors + /// + /// Returns an error when powers or derived widths overflow. + pub fn new( + m_vars: usize, + r_vars: usize, + decomp: &DecompositionParams, + ) -> Result { + let depth_commit = compute_num_digits(decomp.log_commit_bound, decomp.log_basis); + let open_bound = decomp.log_open_bound.unwrap_or(decomp.log_commit_bound); + let depth_open = compute_num_digits(open_bound, decomp.log_basis); + let depth_fold = compute_num_digits_fold( + r_vars, + Cfg::challenge_weight_for_ring_dim(Cfg::D), + decomp.log_basis, + ); + Self::new_with_decomp( + m_vars, + r_vars, + Cfg::N_A, + depth_commit, + depth_open, + depth_fold, + decomp.log_basis, + ) + } + + /// Build a layout from explicit decomposition parameters (no config trait needed). + /// + /// # Errors + /// + /// Returns an error when parameters are invalid or derived widths overflow. 
+ pub fn new_with_decomp( + m_vars: usize, + r_vars: usize, + n_a: usize, + num_digits_commit: usize, + num_digits_open: usize, + num_digits_fold: usize, + log_basis: u32, + ) -> Result { + if log_basis == 0 || log_basis >= 128 { + return Err(HachiError::InvalidSetup("invalid log_basis".to_string())); + } + let num_blocks = checked_pow2(r_vars)?; + let block_len = checked_pow2(m_vars)?; + let inner_width = block_len + .checked_mul(num_digits_commit) + .ok_or_else(|| HachiError::InvalidSetup("inner width overflow".to_string()))?; + let outer_width = n_a + .checked_mul(num_digits_open) + .and_then(|x| x.checked_mul(num_blocks)) + .ok_or_else(|| HachiError::InvalidSetup("outer width overflow".to_string()))?; + let d_matrix_width = num_digits_open + .checked_mul(num_blocks) + .ok_or_else(|| HachiError::InvalidSetup("D-matrix width overflow".to_string()))?; + Ok(Self { + m_vars, + r_vars, + num_blocks, + block_len, + inner_width, + outer_width, + d_matrix_width, + num_digits_commit, + num_digits_open, + num_digits_fold, + log_basis, + }) + } + + /// Total number of outer variables consumed by ring coefficients. + /// + /// # Errors + /// + /// Returns an error if the variable count overflows. + pub fn outer_vars(&self) -> Result { + self.m_vars + .checked_add(self.r_vars) + .ok_or_else(|| HachiError::InvalidSetup("variable count overflow".to_string())) + } + + /// Required polynomial variable count for this layout (`outer + alpha`). + /// + /// # Errors + /// + /// Returns an error if the variable count overflows. + pub fn required_num_vars(&self) -> Result { + let alpha = D.trailing_zeros() as usize; + self.outer_vars()? 
+ .checked_add(alpha) + .ok_or_else(|| HachiError::InvalidSetup("variable count overflow".to_string())) + } +} + +impl Valid for HachiCommitmentLayout { + fn check(&self) -> Result<(), SerializationError> { + if self.num_blocks == 0 || self.block_len == 0 { + return Err(SerializationError::InvalidData( + "invalid zero block layout".to_string(), + )); + } + Ok(()) + } +} + +impl HachiSerialize for HachiCommitmentLayout { + fn serialize_with_mode( + &self, + mut writer: W, + compress: Compress, + ) -> Result<(), SerializationError> { + self.m_vars.serialize_with_mode(&mut writer, compress)?; + self.r_vars.serialize_with_mode(&mut writer, compress)?; + self.num_blocks.serialize_with_mode(&mut writer, compress)?; + self.block_len.serialize_with_mode(&mut writer, compress)?; + self.inner_width + .serialize_with_mode(&mut writer, compress)?; + self.outer_width + .serialize_with_mode(&mut writer, compress)?; + self.d_matrix_width + .serialize_with_mode(&mut writer, compress)?; + self.num_digits_commit + .serialize_with_mode(&mut writer, compress)?; + self.num_digits_open + .serialize_with_mode(&mut writer, compress)?; + self.num_digits_fold + .serialize_with_mode(&mut writer, compress)?; + (self.log_basis as usize).serialize_with_mode(&mut writer, compress)?; + Ok(()) + } + + fn serialized_size(&self, compress: Compress) -> usize { + self.m_vars.serialized_size(compress) + + self.r_vars.serialized_size(compress) + + self.num_blocks.serialized_size(compress) + + self.block_len.serialized_size(compress) + + self.inner_width.serialized_size(compress) + + self.outer_width.serialized_size(compress) + + self.d_matrix_width.serialized_size(compress) + + self.num_digits_commit.serialized_size(compress) + + self.num_digits_open.serialized_size(compress) + + self.num_digits_fold.serialized_size(compress) + + (self.log_basis as usize).serialized_size(compress) + } +} + +impl HachiDeserialize for HachiCommitmentLayout { + fn deserialize_with_mode( + mut reader: R, + compress: 
Compress, + validate: Validate, + ) -> Result { + let out = Self { + m_vars: usize::deserialize_with_mode(&mut reader, compress, validate)?, + r_vars: usize::deserialize_with_mode(&mut reader, compress, validate)?, + num_blocks: usize::deserialize_with_mode(&mut reader, compress, validate)?, + block_len: usize::deserialize_with_mode(&mut reader, compress, validate)?, + inner_width: usize::deserialize_with_mode(&mut reader, compress, validate)?, + outer_width: usize::deserialize_with_mode(&mut reader, compress, validate)?, + d_matrix_width: usize::deserialize_with_mode(&mut reader, compress, validate)?, + num_digits_commit: usize::deserialize_with_mode(&mut reader, compress, validate)?, + num_digits_open: usize::deserialize_with_mode(&mut reader, compress, validate)?, + num_digits_fold: usize::deserialize_with_mode(&mut reader, compress, validate)?, + log_basis: usize::deserialize_with_mode(&mut reader, compress, validate)? as u32, + }; + if matches!(validate, Validate::Yes) { + out.check()?; + } + Ok(out) + } +} + +/// Parameter bundle for the ring-native commitment core (§4.1–§4.2). +/// +/// Security parameters (`N_A`, `N_B`, `N_D`, `CHALLENGE_WEIGHT`) are +/// compile-time constants fixed for a given security level. Decomposition +/// parameters (gadget depths, `log_basis`) are runtime values derived from +/// [`DecompositionParams`] and live in [`HachiCommitmentLayout`]. +pub trait CommitmentConfig: Clone + Send + Sync + 'static { + /// Ring degree used by `CyclotomicRing`. + const D: usize; + /// Inner Ajtai matrix row count. + const N_A: usize; + /// Outer commitment matrix row count. + const N_B: usize; + /// Prover commitment matrix `D` row count (§4.2). + const N_D: usize; + /// Hamming weight of sparse challenges (`ω` in the paper). + const CHALLENGE_WEIGHT: usize; + + /// Decomposition parameters (gadget base and coefficient bounds). + fn decomposition() -> DecompositionParams; + + /// Stable identifier for setup-cache versioning and fixture selection. 
+ fn family_key() -> &'static str { + std::any::type_name::() + } + + /// Deterministically derive the active params for one level from public inputs. + fn level_params(inputs: HachiScheduleInputs) -> HachiLevelParams { + HachiLevelParams { + d: Self::d_at_level(inputs.level, inputs.current_w_len), + log_basis: if inputs.level == 0 { + Self::decomposition().log_basis + } else { + Self::w_log_basis() + }, + n_a: Self::n_a_at_level(inputs.level), + n_b: Self::n_b_at_level(inputs.level, inputs.max_num_vars, inputs.current_w_len), + n_d: Self::n_d_at_level(inputs.level, inputs.max_num_vars, inputs.current_w_len), + challenge_weight: Self::challenge_weight_for_ring_dim(Self::d_at_level( + inputs.level, + inputs.current_w_len, + )), + } + } + + /// Choose the runtime commitment layout for `max_num_vars`. + /// + /// # Errors + /// + /// Returns an error if `max_num_vars` does not admit a valid layout. + fn commitment_layout(max_num_vars: usize) -> Result; + + /// Runtime L∞ bound for `z` (`β`) used by stage-1 folding checks. + /// + /// # Errors + /// + /// Returns an error on invalid parameters or arithmetic overflow. + fn beta_bound(layout: HachiCommitmentLayout) -> Result { + beta_linf_fold_bound( + layout.r_vars, + Self::challenge_weight_for_ring_dim(Self::D), + layout.log_basis, + ) + } + + /// Ring dimension to use at a given fold level. + /// + /// `level` is 0-indexed (level 0 is the initial polynomial). + /// `_w_num_vars` is the number of variables in the witness at this level. + /// + /// The default implementation returns `Self::D` at all levels (constant D). + /// Override for decreasing-D schedules. + fn d_at_level(_level: usize, _current_w_len: usize) -> usize { + Self::D + } + + /// Module rank (inner Ajtai row count) at a given fold level. + /// + /// Must satisfy `d_at_level(level) * n_a_at_level(level) >= security_dim` + /// for the target security level. The default returns `Self::N_A` at all levels. 
+ fn n_a_at_level(_level: usize) -> usize { + Self::N_A + } + + /// Outer commitment matrix rank at a given fold level. + fn n_b_at_level(_level: usize, _max_num_vars: usize, _current_w_len: usize) -> usize { + Self::N_B + } + + /// Prover D-matrix rank at a given fold level. + fn n_d_at_level(_level: usize, _max_num_vars: usize, _current_w_len: usize) -> usize { + Self::N_D + } + + /// Challenge weight (Hamming weight ω) appropriate for ring dimension `d`. + /// + /// The default returns `Self::CHALLENGE_WEIGHT` for any `d`, which is + /// correct for constant-D configs. Override for varying-D schedules where + /// the optimal weight depends on the ring dimension (e.g., to maintain + /// ≥128 bits of challenge entropy as D decreases). + fn challenge_weight_for_ring_dim(_d: usize) -> usize { + Self::CHALLENGE_WEIGHT + } + + /// Sparse challenge family used at this level. + fn stage1_challenge_config(level_params: HachiLevelParams) -> SparseChallengeConfig { + default_stage1_challenge_config(level_params) + } + + /// Gadget base for recursive w-opening levels (levels 1+). + /// + /// At level 0, the decomposition uses `Self::decomposition().log_basis`. + /// At recursive levels, the `WCommitmentConfig` uses this value instead. + /// A larger basis at w-levels reduces decomposition depth (`delta_open`) + /// at the cost of a higher-degree norm sumcheck — acceptable because the + /// w-level witness is much smaller than the level-0 witness. + /// + /// Default: same as the level-0 basis. + fn w_log_basis() -> u32 { + Self::decomposition().log_basis + } + + /// Witness length (in i8 digits) above which the prover hands off to + /// Labrador (D'=64) instead of sending the witness directly. + /// + /// The default returns 65 536 (64 Ki). Override to a lower value in test + /// configs to exercise the Labrador tail path with smaller polynomials. 
+ fn labrador_handoff_threshold() -> usize { + 65_536 + } +} + +/// Deterministic upper bound for the stage-1 folded-witness infinity norm. +/// +/// This encodes the bound used in `QuadraticEquation::compute_z_hat`: +/// `||z||_inf <= 2^R * ω * (b/2)` where `b = 2^LOG_BASIS`. +/// +/// # Errors +/// +/// Returns an error when parameters are out of range or intermediate products +/// overflow `u128`. +pub fn beta_linf_fold_bound( + r: usize, + challenge_weight: usize, + log_basis: u32, +) -> Result { + if !(1..128).contains(&log_basis) { + return Err(HachiError::InvalidSetup("invalid LOG_BASIS".to_string())); + } + if r >= 128 { + return Err(HachiError::InvalidSetup("r_vars must be < 128".to_string())); + } + + let blocks = 1u128 << r; + let b = 1u128 << log_basis; + let half_b = b / 2; + + let term = blocks + .checked_mul(challenge_weight as u128) + .ok_or_else(|| HachiError::InvalidSetup("beta bound overflow".to_string()))?; + term.checked_mul(half_b) + .ok_or_else(|| HachiError::InvalidSetup("beta bound overflow".to_string())) +} + +/// Validate static config invariants and derive runtime dimensions. +/// +/// # Errors +/// +/// Returns an error when config constants are inconsistent or overflow. +pub(super) fn validate_and_derive_layout( + max_num_vars: usize, +) -> Result { + if D != Cfg::D { + return Err(HachiError::InvalidSetup(format!( + "const D={D} mismatches config D={}", + Cfg::D + ))); + } + Cfg::commitment_layout(max_num_vars) +} + +/// Ensure `max_num_vars` is sufficient for config dimensions. +/// +/// # Errors +/// +/// Returns an error when `max_num_vars < required_vars`. +pub(super) fn ensure_supported_num_vars( + max_num_vars: usize, + required_vars: usize, +) -> Result<(), HachiError> { + if max_num_vars < required_vars { + return Err(HachiError::InvalidSetup(format!( + "max_num_vars {max_num_vars} is smaller than required {required_vars}" + ))); + } + Ok(()) +} + +/// Ensure input blocks match the expected config-derived layout. 
///
/// # Errors
///
/// Returns an error when block count or per-block size mismatch.
pub(super) fn ensure_block_layout(
    // NOTE(review): element type garbled in transit (`&[Vec>]`); presumably
    // `&[Vec<CyclotomicRing<F, D>>]` — confirm against the original source.
    f_blocks: &[Vec>],
    layout: HachiCommitmentLayout,
) -> Result<(), HachiError> {
    // Outer dimension: exactly 2^R blocks.
    if f_blocks.len() != layout.num_blocks {
        return Err(HachiError::InvalidSize {
            expected: layout.num_blocks,
            actual: f_blocks.len(),
        });
    }
    // Inner dimension: every block must hold exactly 2^M ring elements.
    for block in f_blocks {
        if block.len() != layout.block_len {
            return Err(HachiError::InvalidSize {
                expected: layout.block_len,
                actual: block.len(),
            });
        }
    }
    Ok(())
}

/// Ensure matrix has at least the expected dimensions.
///
/// Matrix envelopes may have more rows and columns than the active level uses.
/// Callers are expected to consume only the active row prefix.
///
/// # Errors
///
/// Returns an error if row count is too small or any row is too narrow.
pub(super) fn ensure_matrix_shape_ge(
    mat: &FlatMatrix,
    expected_rows: usize,
    min_cols: usize,
    name: &str,
) -> Result<(), HachiError> {
    if mat.num_rows() < expected_rows {
        return Err(HachiError::InvalidSize {
            expected: expected_rows,
            actual: mat.num_rows(),
        });
    }
    // NOTE(review): turbofish argument garbled in transit (`num_cols_at::()`)
    // — confirm against the original source.
    let actual_cols = mat.num_cols_at::();
    if actual_cols < min_cols {
        return Err(HachiError::InvalidSetup(format!(
            "{name} has width {actual_cols}, expected >= {min_cols}",
        )));
    }
    Ok(())
}

/// Small correctness-first config for tests and local benchmarks.
///
/// Fixed layout (m_vars=4, r_vars=2) for fast test iteration. For larger
/// polynomials, use [`DynamicSmallTestCommitmentConfig`] instead.
#[derive(Clone, Copy, Debug, Default)]
pub struct SmallTestCommitmentConfig;

impl CommitmentConfig for SmallTestCommitmentConfig {
    const D: usize = 16;
    const N_A: usize = 8;
    const N_B: usize = 4;
    const N_D: usize = 4;
    const CHALLENGE_WEIGHT: usize = 3;

    fn decomposition() -> DecompositionParams {
        DecompositionParams {
            log_basis: 3,
            log_commit_bound: 32,
            log_open_bound: None,
        }
    }

    // Fixed split: ignores `max_num_vars` and always uses m_vars=4, r_vars=2.
    // NOTE(review): return type / turbofish arguments garbled in transit —
    // confirm against the original source.
    fn commitment_layout(_max_num_vars: usize) -> Result {
        HachiCommitmentLayout::new::(4, 2, &Self::decomposition())
    }
}

/// D=16 config with dynamic layout that adapts to polynomial size.
///
/// Uses the same D=16 ring dimension as [`SmallTestCommitmentConfig`] but
/// derives `m_vars`/`r_vars` from `max_num_vars`, so it can commit
/// polynomials with an arbitrary number of variables.
#[derive(Clone, Copy, Debug, Default)]
pub struct DynamicSmallTestCommitmentConfig;

impl CommitmentConfig for DynamicSmallTestCommitmentConfig {
    const D: usize = 16;
    const N_A: usize = 8;
    const N_B: usize = 4;
    const N_D: usize = 4;
    const CHALLENGE_WEIGHT: usize = 3;

    fn decomposition() -> DecompositionParams {
        DecompositionParams {
            log_basis: 3,
            log_commit_bound: 32,
            log_open_bound: None,
        }
    }

    fn commitment_layout(max_num_vars: usize) -> Result {
        // alpha = log2(D): the variables absorbed by the ring dimension.
        let alpha = Self::D.trailing_zeros() as usize;
        let reduced_vars = max_num_vars.checked_sub(alpha).ok_or_else(|| {
            HachiError::InvalidSetup("max_num_vars is smaller than alpha".to_string())
        })?;
        if reduced_vars == 0 {
            return Err(HachiError::InvalidSetup(
                "max_num_vars must leave at least one outer variable".to_string(),
            ));
        }
        let (m_vars, r_vars) = optimal_m_r_split::(reduced_vars);
        HachiCommitmentLayout::new::(m_vars, r_vars, &Self::decomposition())
    }
}

/// Production-oriented profile for 128-bit base fields (`Fp128`),
/// parameterized by the coefficient bound and gadget basis.
///
/// This profile targets the `D = 256`, `n_A = n_B = n_D = 1` regime with
/// balanced decomposition over ~128-bit moduli.
///
/// - `LOG_COMMIT_BOUND`: bit-width of the largest polynomial coefficient the
///   commitment decomposition must represent.
/// - `LOG_BASIS`: base-2 log of the gadget base at level 0.
/// - `W_LOG_BASIS`: base-2 log of the gadget base at recursive w-opening
///   levels (levels 1+). A larger w-basis reduces `delta_open` (fewer
///   decomposition digits) at the cost of a higher-degree norm sumcheck at
///   those levels — acceptable because the w-level witness is much smaller.
///
/// # Aliases
///
/// - [`Fp128FullCommitmentConfig`] = `<128, 3, 3>`
/// - [`Fp128OneHotCommitmentConfig`] — now an alias for the adaptive D=64
///   onehot preset ([`Fp128AdaptiveOneHotCommitmentConfig`]), no longer the
///   `<1, 3, 3>` instantiation of this profile.
/// - [`Fp128LogBasisCommitmentConfig`] = `<3, 3, 3>`
/// - [`Fp128CommitmentConfig`] — alias for `Fp128FullCommitmentConfig`
///
/// # β derivation (stage-1 folded witness `z`)
///
/// `||z||_inf <= 2^R * ω * (b/2)` where `b = 2^LOG_BASIS`.
#[derive(Clone, Copy, Debug, Default)]
pub struct Fp128BoundedCommitmentConfig<
    const LOG_COMMIT_BOUND: u32,
    const LOG_BASIS: u32,
    const W_LOG_BASIS: u32 = LOG_BASIS,
>;

// NOTE(review): the generic parameter list on this impl (and the turbofish
// arguments inside `commitment_layout`) were garbled in transit — confirm
// against the original source.
impl CommitmentConfig
    for Fp128BoundedCommitmentConfig
{
    const D: usize = 256;
    const N_A: usize = 1;
    const N_B: usize = 1;
    const N_D: usize = 1;
    const CHALLENGE_WEIGHT: usize = 23;

    fn decomposition() -> DecompositionParams {
        DecompositionParams {
            log_basis: LOG_BASIS,
            log_commit_bound: LOG_COMMIT_BOUND,
            // Sub-full-width commit bounds still open against the full
            // 128-bit range.
            log_open_bound: if LOG_COMMIT_BOUND < 128 {
                Some(128)
            } else {
                None
            },
        }
    }

    fn w_log_basis() -> u32 {
        W_LOG_BASIS
    }

    fn commitment_layout(max_num_vars: usize) -> Result {
        // alpha = log2(D): the variables absorbed by the ring dimension.
        let alpha = Self::D.trailing_zeros() as usize;
        let reduced_vars = max_num_vars.checked_sub(alpha).ok_or_else(|| {
            HachiError::InvalidSetup("max_num_vars is smaller than alpha".to_string())
        })?;
        if reduced_vars == 0 {
            return Err(HachiError::InvalidSetup(
                "max_num_vars must leave at least one outer variable".to_string(),
            ));
        }
        let (m_vars, r_vars) = optimal_m_r_split::(reduced_vars);
        HachiCommitmentLayout::new::(m_vars, r_vars, &Self::decomposition())
    }
}

/// D=64, rank-1 everywhere.
#[derive(Clone, Copy, Debug, Default)]
pub struct Fp128D64BoundedCommitmentConfig<
    const LOG_COMMIT_BOUND: u32,
    const LOG_BASIS: u32,
    const W_LOG_BASIS: u32 = LOG_BASIS,
>;

// NOTE(review): impl generics / turbofish arguments garbled in transit —
// confirm against the original source.
impl CommitmentConfig
    for Fp128D64BoundedCommitmentConfig
{
    const D: usize = 64;
    const N_A: usize = 1;
    const N_B: usize = 1;
    const N_D: usize = 1;
    const CHALLENGE_WEIGHT: usize = D64_CHALLENGE_MASS;

    // Decomposition is D-independent, so reuse the D=256 profile's params.
    fn decomposition() -> DecompositionParams {
        Fp128BoundedCommitmentConfig::::decomposition()
    }

    fn commitment_layout(max_num_vars: usize) -> Result {
        // Layout must be derived from this config's own D=64 (alpha = 6).
        let alpha = Self::D.trailing_zeros() as usize;
        let reduced_vars = max_num_vars.checked_sub(alpha).ok_or_else(|| {
            HachiError::InvalidSetup("max_num_vars is smaller than alpha".to_string())
        })?;
        if reduced_vars == 0 {
            return Err(HachiError::InvalidSetup(
                "max_num_vars must leave at least one outer variable".to_string(),
            ));
        }
        let (m_vars, r_vars) = optimal_m_r_split::(reduced_vars);
        HachiCommitmentLayout::new::(m_vars, r_vars, &Self::decomposition())
    }

    fn w_log_basis() -> u32 {
        W_LOG_BASIS
    }

    fn stage1_challenge_config(level_params: HachiLevelParams) -> SparseChallengeConfig {
        d64_stage1_challenge_config(level_params)
    }

    // Never hand off to Labrador: this config already runs at D'=64.
    fn labrador_handoff_threshold() -> usize {
        usize::MAX
    }
}

/// Full-field (128-bit) coefficient bound, base-8 decomposition.
pub type Fp128FullCommitmentConfig = Fp128BoundedCommitmentConfig<128, 3>;

/// Binary (1-bit) D=64 onehot preset with the coarse adaptive outer-rank
/// schedule.
pub type Fp128OneHotCommitmentConfig = Fp128AdaptiveOneHotCommitmentConfig;

/// Log-basis (3-bit) coefficient bound, base-8 decomposition.
///
/// For recursive w-openings where entries are already balanced digits.
pub type Fp128LogBasisCommitmentConfig = Fp128BoundedCommitmentConfig<3, 3>;

/// Alias for [`Fp128FullCommitmentConfig`].
pub type Fp128CommitmentConfig = Fp128FullCommitmentConfig;

/// D=64, rank-2 everywhere.
+#[derive(Clone, Copy, Debug, Default)] +pub struct Fp128Rank2BoundedCommitmentConfig< + const LOG_COMMIT_BOUND: u32, + const LOG_BASIS: u32, + const W_LOG_BASIS: u32 = LOG_BASIS, +>; + +impl CommitmentConfig + for Fp128Rank2BoundedCommitmentConfig +{ + const D: usize = 64; + const N_A: usize = 1; + const N_B: usize = 2; + const N_D: usize = 2; + const CHALLENGE_WEIGHT: usize = D64_CHALLENGE_MASS; + + fn decomposition() -> DecompositionParams { + Fp128BoundedCommitmentConfig::::decomposition() + } + + fn commitment_layout(max_num_vars: usize) -> Result { + Fp128BoundedCommitmentConfig::::commitment_layout( + max_num_vars, + ) + } + + fn w_log_basis() -> u32 { + W_LOG_BASIS + } + + fn n_b_at_level(_level: usize, _max_num_vars: usize, _current_w_len: usize) -> usize { + 2 + } + + fn n_d_at_level(_level: usize, _max_num_vars: usize, _current_w_len: usize) -> usize { + 2 + } + + fn stage1_challenge_config(level_params: HachiLevelParams) -> SparseChallengeConfig { + d64_stage1_challenge_config(level_params) + } + + fn labrador_handoff_threshold() -> usize { + usize::MAX + } +} + +/// D=64 onehot preset with the coarse adaptive outer-rank schedule from the +/// current local planning note: rank-2 only in the short early window. 
#[derive(Clone, Copy, Debug, Default)]
pub struct Fp128AdaptiveOneHotCommitmentConfig;

impl CommitmentConfig for Fp128AdaptiveOneHotCommitmentConfig {
    const D: usize = 64;
    const N_A: usize = 1;
    const N_B: usize = 2;
    const N_D: usize = 2;
    const CHALLENGE_WEIGHT: usize = D64_CHALLENGE_MASS;

    fn decomposition() -> DecompositionParams {
        DecompositionParams {
            log_basis: 3,
            // One-hot witnesses have {0,1} coefficients, so a 1-bit commit
            // bound suffices; openings still cover the full 128-bit range.
            log_commit_bound: 1,
            log_open_bound: Some(128),
        }
    }

    // NOTE(review): turbofish arguments garbled in transit — confirm against
    // the original source.
    fn commitment_layout(max_num_vars: usize) -> Result {
        let alpha = Self::D.trailing_zeros() as usize;
        let reduced_vars = max_num_vars.checked_sub(alpha).ok_or_else(|| {
            HachiError::InvalidSetup("max_num_vars is smaller than alpha".to_string())
        })?;
        if reduced_vars == 0 {
            return Err(HachiError::InvalidSetup(
                "max_num_vars must leave at least one outer variable".to_string(),
            ));
        }
        let (m_vars, r_vars) = optimal_m_r_split::(reduced_vars);
        HachiCommitmentLayout::new::(m_vars, r_vars, &Self::decomposition())
    }

    // Coarse adaptive schedule: rank-2 only in a short early window whose
    // length grows with the root size, rank-1 afterwards. The 44/38 variable
    // thresholds come from the local planning note referenced in the type doc
    // — presumably tuned empirically; confirm before retuning.
    fn n_b_at_level(level: usize, max_num_vars: usize, _current_w_len: usize) -> usize {
        if max_num_vars >= 44 {
            if level <= 1 {
                2
            } else {
                1
            }
        } else if max_num_vars >= 38 {
            if level == 0 {
                2
            } else {
                1
            }
        } else {
            1
        }
    }

    // D-matrix rank tracks the outer commitment rank exactly.
    fn n_d_at_level(level: usize, max_num_vars: usize, current_w_len: usize) -> usize {
        Self::n_b_at_level(level, max_num_vars, current_w_len)
    }

    fn stage1_challenge_config(level_params: HachiLevelParams) -> SparseChallengeConfig {
        d64_stage1_challenge_config(level_params)
    }

    // Never hand off to Labrador: this config already runs at D'=64.
    fn labrador_handoff_threshold() -> usize {
        usize::MAX
    }
}

/// Halving-D commitment config for Fp128 (D=256 → 128).
///
/// Uses `d_at_level` and `n_a_at_level` to halve the ring dimension at each
/// fold level while doubling the module rank to maintain D×N_A ≥ 256.
/// Stops halving at D=128, which is the minimum ring dimension
/// for which sparse ternary challenges provide sufficient security.
///
/// Challenge weights are scaled per ring dimension to maintain ≥128 bits
/// of challenge entropy (log₂(C(D,ω) · 2^ω) ≥ 128):
/// D=256: ω=23 (~131 bits), D=128: ω=31 (~130 bits).
#[derive(Clone, Copy, Debug, Default)]
pub struct Fp128HalvingDCommitmentConfig;

impl CommitmentConfig for Fp128HalvingDCommitmentConfig {
    const D: usize = 256;
    const N_A: usize = 1;
    const N_B: usize = 1;
    const N_D: usize = 1;
    const CHALLENGE_WEIGHT: usize = 23;

    fn decomposition() -> DecompositionParams {
        DecompositionParams {
            log_basis: 3,
            log_commit_bound: 128,
            log_open_bound: None,
        }
    }

    // NOTE(review): turbofish arguments garbled in transit — confirm against
    // the original source.
    fn commitment_layout(max_num_vars: usize) -> Result {
        let alpha = Self::D.trailing_zeros() as usize;
        let reduced_vars = max_num_vars.checked_sub(alpha).ok_or_else(|| {
            HachiError::InvalidSetup("max_num_vars is smaller than alpha".to_string())
        })?;
        if reduced_vars == 0 {
            return Err(HachiError::InvalidSetup(
                "max_num_vars must leave at least one outer variable".to_string(),
            ));
        }
        let (m_vars, r_vars) = optimal_m_r_split::(reduced_vars);
        HachiCommitmentLayout::new::(m_vars, r_vars, &Self::decomposition())
    }

    // Halve D once (256 → 128) and stop there; see the type-level doc.
    fn d_at_level(level: usize, _w_num_vars: usize) -> usize {
        match level {
            0 => 256,
            _ => 128,
        }
    }

    // Double the module rank as D halves so that D × N_A stays ≥ 256.
    fn n_a_at_level(level: usize) -> usize {
        match level {
            0 => 1,
            _ => 2,
        }
    }

    // ω per ring dimension, keeping challenge entropy ≥ 128 bits (see doc).
    fn challenge_weight_for_ring_dim(d: usize) -> usize {
        match d {
            256 => 23,
            128 => 31,
            _ => panic!("Fp128HalvingDCommitmentConfig: unsupported ring dim {d}"),
        }
    }
}
diff --git a/src/protocol/commitment/mod.rs b/src/protocol/commitment/mod.rs
new file mode 100644
index 00000000..b1574a73
--- /dev/null
+++ b/src/protocol/commitment/mod.rs
@@ -0,0 +1,32 @@
//! Protocol commitment abstraction layer.
+ +mod commit; +mod config; +pub mod onehot; +mod schedule; +mod scheme; +pub(crate) mod transcript_append; +mod types; +pub mod utils; + +pub use commit::{ + HachiCommitmentCore, HachiExpandedSetup, HachiProverSetup, HachiSetupSeed, HachiVerifierSetup, +}; +pub use config::optimal_m_r_split; +pub use config::{ + beta_linf_fold_bound, compute_num_digits, compute_num_digits_fold, CommitmentConfig, + DecompositionParams, DynamicSmallTestCommitmentConfig, Fp128AdaptiveOneHotCommitmentConfig, + Fp128BoundedCommitmentConfig, Fp128CommitmentConfig, Fp128D64BoundedCommitmentConfig, + Fp128FullCommitmentConfig, Fp128HalvingDCommitmentConfig, Fp128LogBasisCommitmentConfig, + Fp128OneHotCommitmentConfig, Fp128Rank2BoundedCommitmentConfig, HachiCommitmentLayout, + SmallTestCommitmentConfig, +}; +pub use onehot::{map_onehot_to_sparse_blocks, SparseBlockEntry}; +pub use schedule::{ + hachi_level_layout, hachi_root_level_layout, HachiLevelParams, HachiScheduleInputs, +}; +pub use scheme::{CommitWitness, CommitmentScheme, RingCommitmentScheme}; +pub use transcript_append::AppendToTranscript; +pub use types::{ + DummyProof, HachiCommitment, HachiOpeningClaim, HachiOpeningPoint, RingCommitment, +}; diff --git a/src/protocol/commitment/onehot.rs b/src/protocol/commitment/onehot.rs new file mode 100644 index 00000000..470d3bd4 --- /dev/null +++ b/src/protocol/commitment/onehot.rs @@ -0,0 +1,336 @@ +//! One-hot commitment path for regular one-hot ring elements. +//! +//! Exploits the sparsity of one-hot witnesses (coefficients in {0,1}) to +//! eliminate all inner ring multiplications. The inner Ajtai `t = A * s` +//! reduces to summing selected columns of `A` with negacyclic rotations. 

use std::collections::BTreeMap;

use crate::algebra::fields::wide::{HasWide, ReduceTo};
use crate::algebra::ring::{CyclotomicRing, WideCyclotomicRing};
use crate::error::HachiError;
use crate::protocol::commitment::utils::flat_matrix::RingMatrixView;
use crate::protocol::hachi_poly_ops::OneHotIndex;
use crate::{AdditiveGroup, CanonicalField, FieldCore};

/// Describes a nonzero ring element within one block of the commitment layout.
#[derive(Debug, Clone, PartialEq)]
pub struct SparseBlockEntry {
    /// Position within the block (0..2^M).
    pub pos_in_block: usize,
    /// Coefficient indices that are 1 within this ring element.
    // NOTE(review): element type garbled in transit (`Vec`); usage below
    // indexes coefficients with usize — presumably `Vec<usize>`; confirm.
    pub nonzero_coeffs: Vec,
}

/// Map a regular one-hot witness to sparse ring block entries.
///
/// - `onehot_k`: chunk size K. The witness has T chunks of K field elements,
///   each chunk containing exactly one 1.
/// - `indices`: length-T slice where `indices[c]` is the hot position in
///   chunk `c` (must be in `[0, K)`).
/// - `r`, `m`: commitment config parameters (2^R blocks of 2^M ring elements).
/// - `D`: ring degree (const generic on caller side, passed as runtime here).
///
/// Returns one `Vec` per block (outer len = 2^R).
///
/// # Errors
///
/// Returns an error if K and D are not "nicely matched" (one must divide
/// the other), if any index is out of range, or if the dimensions don't
/// fill the commitment layout.
// NOTE(review): generic list and the `Option` payload type were garbled in
// transit; `idx_raw.as_usize()` below suggests `Option<I>` with
// `I: OneHotIndex` — confirm against the original source.
pub fn map_onehot_to_sparse_blocks(
    onehot_k: usize,
    indices: &[Option],
    r: usize,
    m: usize,
    d: usize,
) -> Result>, HachiError> {
    if onehot_k == 0 || d == 0 {
        return Err(HachiError::InvalidInput(
            "onehot_k and D must be nonzero".into(),
        ));
    }
    // "Nicely matched": each ring element is covered by whole chunks, or each
    // chunk by whole ring elements — never a partial overlap.
    if !(onehot_k % d == 0 || d % onehot_k == 0) {
        return Err(HachiError::InvalidInput(format!(
            "K={onehot_k} and D={d} must be nicely matched (one divides the other)"
        )));
    }

    let num_chunks = indices.len();
    let total_field_elems = num_chunks
        .checked_mul(onehot_k)
        .ok_or_else(|| HachiError::InvalidInput("T*K overflow".into()))?;
    if total_field_elems % d != 0 {
        return Err(HachiError::InvalidInput(format!(
            "T*K={total_field_elems} is not divisible by D={d}"
        )));
    }
    let total_ring_elems = total_field_elems / d;
    let num_blocks = 1usize << r;
    let block_len = 1usize << m;
    // The witness must fill the 2^R × 2^M layout exactly.
    if total_ring_elems != num_blocks * block_len {
        return Err(HachiError::InvalidSize {
            expected: num_blocks * block_len,
            actual: total_ring_elems,
        });
    }

    // Group hot coefficient positions by the ring element they land in.
    // BTreeMap keeps ring elements in ascending order for the block pass.
    let mut ring_elem_map: BTreeMap> = BTreeMap::new();
    for (c, opt) in indices.iter().enumerate() {
        // `None` chunks contribute nothing (all-zero chunk).
        let Some(&idx_raw) = opt.as_ref() else {
            continue;
        };
        let idx = idx_raw.as_usize();
        if idx >= onehot_k {
            return Err(HachiError::InvalidInput(format!(
                "index {idx} out of range for chunk size K={onehot_k} at position {c}"
            )));
        }
        // Flat field-element position → (ring element, coefficient) pair.
        let field_pos = c * onehot_k + idx;
        let ring_elem_idx = field_pos / d;
        let coeff_idx = field_pos % d;
        ring_elem_map
            .entry(ring_elem_idx)
            .or_default()
            .push(coeff_idx);
    }

    // Sequential block layout matching commit_coeffs: block i = ring elements
    // [i*block_len, (i+1)*block_len).
    let mut blocks: Vec> = vec![Vec::new(); num_blocks];
    for (ring_elem_idx, nonzero_coeffs) in ring_elem_map {
        let block_idx = ring_elem_idx / block_len;
        let pos_in_block = ring_elem_idx % block_len;
        blocks[block_idx].push(SparseBlockEntry {
            pos_in_block,
            nonzero_coeffs,
        });
    }

    Ok(blocks)
}

/// Sparse inner Ajtai: compute `t = A * s` for a one-hot block.
///
/// Instead of materializing the full decomposed vector `s` and doing a dense
/// matvec, we accumulate only the nonzero contributions using fused
/// shift-accumulate (no intermediate temporaries):
///
/// ```text
/// t[a] += A[a][entry.pos * num_digits] * (X^{k_1} + X^{k_2} + ...)
/// ```
#[cfg(test)]
#[allow(non_snake_case)]
pub(crate) fn inner_ajtai_onehot_t_only(
    A: &[Vec>],
    sparse_entries: &[SparseBlockEntry],
    _block_len: usize,
    num_digits: usize,
) -> Vec> {
    let n_a = A.len();

    let mut t = vec![CyclotomicRing::::zero(); n_a];
    for entry in sparse_entries {
        // A {0,1} coefficient decomposes to digit 1 in plane 0, so only the
        // first digit column of this position contributes.
        let col = entry.pos_in_block * num_digits;
        for a in 0..n_a {
            A[a][col].mul_by_monomial_sum_into(&mut t[a], &entry.nonzero_coeffs);
        }
    }

    t
}

/// Wide-accumulator variant of [`inner_ajtai_onehot_t_only`].
///
/// Accumulates into `WideCyclotomicRing` (carry-free i32 additions),
/// then reduces once at the end. This avoids per-addition modular reduction.
// NOTE(review): the generic list and several type arguments in this signature
// were garbled in transit (`Vec>`, `From`, `ReduceTo`) — confirm against the
// original source.
#[allow(non_snake_case)]
pub(crate) fn inner_ajtai_onehot_wide(
    A: &RingMatrixView<'_, F, D>,
    sparse_entries: &[SparseBlockEntry],
    _block_len: usize,
    num_digits: usize,
) -> Vec>
where
    F: FieldCore + CanonicalField + HasWide,
    F::Wide: AdditiveGroup + From + ReduceTo,
{
    let n_a = A.num_rows();
    let mut t_wide = vec![WideCyclotomicRing::::zero(); n_a];

    for entry in sparse_entries {
        // Same column selection as the reference: first digit plane of the
        // entry's position.
        let col = entry.pos_in_block * num_digits;
        for (a_idx, t_w) in t_wide.iter_mut().enumerate() {
            let a_wide = WideCyclotomicRing::from_ring(&A.row(a_idx)[col]);
            a_wide.mul_by_monomial_sum_into(t_w, &entry.nonzero_coeffs);
        }
    }

    // Single modular reduction per output element, after all accumulation.
    t_wide.into_iter().map(|w| w.reduce()).collect()
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::algebra::fields::{Fp64, Prime128M8M4M1M0};
    use crate::protocol::commitment::utils::flat_matrix::FlatMatrix;
    use rand::rngs::StdRng;
    use rand::SeedableRng;

    #[test]
    fn map_onehot_k_gt_d() {
        // K=16, D=4, T=2 chunks => 32 field elements => 8 ring elements
        // R=1 (2 blocks), M=2 (4 per block) => 8 ring elements total
        let k = 16;
        let d = 4;
        let indices: Vec> = vec![Some(3), Some(10)];
        let blocks = map_onehot_to_sparse_blocks(k, &indices, 1, 2, d).unwrap();

        assert_eq!(blocks.len(), 2);
        let total_entries: usize = blocks.iter().map(|b| b.len()).sum();
        assert_eq!(total_entries, 2, "T=2 nonzero ring elements");

        for block in &blocks {
            for entry in block {
                assert_eq!(entry.nonzero_coeffs.len(), 1, "K>D => single monomial");
            }
        }
    }

    #[test]
    fn map_onehot_k_eq_d() {
        // K=4, D=4, T=4 chunks => 16 field elements => 4 ring elements
        // R=1 (2 blocks), M=1 (2 per block)
        let k = 4;
        let d = 4;
        let indices: Vec> = vec![Some(0), Some(2), Some(3), Some(1)];
        let blocks = map_onehot_to_sparse_blocks(k, &indices, 1, 1, d).unwrap();

        assert_eq!(blocks.len(), 2);
        let total_entries: usize = blocks.iter().map(|b| b.len()).sum();
        assert_eq!(total_entries, 4, "K=D => every ring element is nonzero");

        for block in &blocks {
            for entry in block {
                assert_eq!(entry.nonzero_coeffs.len(), 1, "K=D => single monomial");
            }
        }
    }

    #[test]
    fn map_onehot_k_lt_d() {
        // K=4, D=8, T=8 chunks => 32 field elements => 4 ring elements
        // R=1 (2 blocks), M=1 (2 per block)
        let k = 4;
        let d = 8;
        let indices: Vec> = vec![
            Some(0),
            Some(2),
            Some(3),
            Some(1),
            Some(0),
            Some(0),
            Some(3),
            Some(3),
        ];
        let blocks = map_onehot_to_sparse_blocks(k, &indices, 1, 1, d).unwrap();

        assert_eq!(blocks.len(), 2);
        let total_entries: usize = blocks.iter().map(|b| b.len()).sum();
        assert_eq!(total_entries, 4, "D>K => all ring elements nonzero");

        for block in &blocks {
            for entry in block {
                assert_eq!(
                    entry.nonzero_coeffs.len(),
                    2,
                    "D=2K => 2 nonzero coeffs per ring element"
                );
            }
        }
    }

    #[test]
    fn map_onehot_rejects_non_divisible() {
        // K=3 and D=4: neither divides the other.
        let result = map_onehot_to_sparse_blocks(3, &[Some(0usize), Some(1)], 0, 1, 4);
        assert!(result.is_err());
    }

    #[test]
    fn wide_matches_reference() {
        type F = Fp64<4294967197>;
        const D: usize = 64;

        let mut rng = StdRng::seed_from_u64(0xdead_beef);
        let n_a = 3;
        let block_len = 4;
        let num_digits = 5;
        let a_matrix: Vec>> = (0..n_a)
            .map(|_| {
                (0..block_len * num_digits)
                    .map(|_| CyclotomicRing::random(&mut rng))
                    .collect()
            })
            .collect();

        let entries = vec![
            SparseBlockEntry {
                pos_in_block: 0,
                nonzero_coeffs: vec![1, 7, 15],
            },
            SparseBlockEntry {
                pos_in_block: 2,
                nonzero_coeffs: vec![0, 63],
            },
        ];

        let a_flat = FlatMatrix::from_ring_matrix(&a_matrix);
        let a_view = a_flat.view::();
        let ref_result = inner_ajtai_onehot_t_only(&a_matrix, &entries, block_len, num_digits);
        let wide_result = inner_ajtai_onehot_wide(&a_view, &entries, block_len, num_digits);

        assert_eq!(ref_result.len(), wide_result.len());
        for (r, w) in ref_result.iter().zip(wide_result.iter()) {
            assert_eq!(r, w, "wide result must match reference");
        }
    }

    #[test]
    fn wide_matches_reference_fp128() {
        type F = Prime128M8M4M1M0;
        const D: usize = 64;

        let mut rng = StdRng::seed_from_u64(0xcafe_1234);
        let n_a = 2;
        let block_len = 2;
        let num_digits = 3;
        let a_matrix: Vec>> = (0..n_a)
            .map(|_| {
                (0..block_len * num_digits)
                    .map(|_| CyclotomicRing::random(&mut rng))
                    .collect()
            })
            .collect();

        let entries = vec![
            SparseBlockEntry {
                pos_in_block: 0,
                nonzero_coeffs: vec![0, 5, 32, 63],
            },
            SparseBlockEntry {
                pos_in_block: 1,
                nonzero_coeffs: vec![10],
            },
        ];

        let a_flat = FlatMatrix::from_ring_matrix(&a_matrix);
        let a_view = a_flat.view::();
        let ref_result = inner_ajtai_onehot_t_only(&a_matrix, &entries, block_len, num_digits);
        let wide_result = inner_ajtai_onehot_wide(&a_view, &entries, block_len, num_digits);

        assert_eq!(ref_result.len(), wide_result.len());
        for (r, w) in ref_result.iter().zip(wide_result.iter()) {
            assert_eq!(r, w, "wide result must match reference (Fp128)");
        }
    }
}
diff --git a/src/protocol/commitment/schedule.rs b/src/protocol/commitment/schedule.rs
new file mode 100644
index 00000000..52b8c8fb
--- /dev/null
+++ b/src/protocol/commitment/schedule.rs
@@ -0,0 +1,140 @@
use super::config::{
    compute_num_digits, compute_num_digits_fold, optimal_m_r_split_with_params, CommitmentConfig,
    DecompositionParams, HachiCommitmentLayout,
};
use crate::algebra::SparseChallengeConfig;
use crate::error::HachiError;

/// Public inputs that deterministically select one level's active Hachi params.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct HachiScheduleInputs {
    /// Root polynomial variable count.
    pub max_num_vars: usize,
    /// Fold level, where `0` is the original polynomial.
    pub level: usize,
    /// Current witness length in field elements before this level runs.
    pub current_w_len: usize,
}

/// Runtime source of truth for one Hachi level.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct HachiLevelParams {
    /// Ring dimension at this level.
    pub d: usize,
    /// Gadget base exponent.
    pub log_basis: u32,
    /// Active inner Ajtai rank.
    pub n_a: usize,
    /// Active outer commitment rank.
    pub n_b: usize,
    /// Active D-matrix rank.
    pub n_d: usize,
    /// Conservative sparse-challenge mass used by folded-norm bounds.
    pub challenge_weight: usize,
}

impl HachiLevelParams {
    /// Total number of quotient / relation rows in `M`.
    // NOTE(review): the "+ 2" accounts for two fixed relation rows beyond the
    // rank-dependent ones — the specific rows are not visible here; confirm.
    pub fn m_row_count(self) -> usize {
        self.n_d + self.n_b + 2 + self.n_a
    }
}

// Return a copy of `decomp` with its gadget base exponent overridden.
fn with_log_basis(mut decomp: DecompositionParams, log_basis: u32) -> DecompositionParams {
    decomp.log_basis = log_basis;
    decomp
}

// Level-0 decomposition: the config's own params with the schedule's basis.
// NOTE(review): generic list garbled in transit (`Cfg::decomposition()` in the
// body implies a `Cfg: CommitmentConfig` parameter) — confirm.
fn main_level_decomposition(
    params: HachiLevelParams,
) -> DecompositionParams {
    with_log_basis(Cfg::decomposition(), params.log_basis)
}

// Recursive w-level decomposition: commit bound equals the basis (entries are
// already balanced digits), while the open bound is inherited from the parent.
fn recursive_level_decomposition(
    params: HachiLevelParams,
) -> DecompositionParams {
    let parent = Cfg::decomposition();
    let parent_open = parent.log_open_bound.unwrap_or(parent.log_commit_bound);
    DecompositionParams {
        log_basis: params.log_basis,
        log_commit_bound: params.log_basis,
        log_open_bound: Some(parent_open),
    }
}

// Assemble a layout from an (m, r) split plus active level params.
// NOTE(review): return type garbled in transit — presumably
// `Result<HachiCommitmentLayout, HachiError>`; confirm.
fn layout_from_params(
    m_vars: usize,
    r_vars: usize,
    params: HachiLevelParams,
    decomp: DecompositionParams,
) -> Result {
    let depth_commit = compute_num_digits(decomp.log_commit_bound, decomp.log_basis);
    let open_bound = decomp.log_open_bound.unwrap_or(decomp.log_commit_bound);
    let depth_open = compute_num_digits(open_bound, decomp.log_basis);
    let depth_fold = compute_num_digits_fold(r_vars, params.challenge_weight, decomp.log_basis);
    HachiCommitmentLayout::new_with_decomp(
        m_vars,
        r_vars,
        params.n_a,
        depth_commit,
        depth_open,
        depth_fold,
        decomp.log_basis,
    )
}

/// Derive the root level's active params and layout.
///
/// # Errors
///
/// Returns an error if the root variable split is invalid or overflows.
// NOTE(review): generic list garbled in transit (`Cfg::level_params` in the
// body implies a `Cfg: CommitmentConfig` parameter) — confirm.
pub fn hachi_root_level_layout(
    max_num_vars: usize,
) -> Result<(HachiLevelParams, HachiCommitmentLayout), HachiError> {
    // Level 0: the full witness of 2^max_num_vars field elements.
    let params = Cfg::level_params(HachiScheduleInputs {
        max_num_vars,
        level: 0,
        current_w_len: 1usize << max_num_vars,
    });
    // alpha = log2(d): variables absorbed by the ring dimension.
    let alpha = params.d.trailing_zeros() as usize;
    let reduced_vars = max_num_vars.checked_sub(alpha).ok_or_else(|| {
        HachiError::InvalidSetup("max_num_vars is smaller than alpha".to_string())
    })?;
    if reduced_vars == 0 {
        return Err(HachiError::InvalidSetup(
            "max_num_vars must leave at least one outer variable".to_string(),
        ));
    }
    let decomp = main_level_decomposition::(params);
    let (m_vars, r_vars) = optimal_m_r_split_with_params(params, decomp, reduced_vars);
    Ok((params, layout_from_params(m_vars, r_vars, params, decomp)?))
}

/// Derive a recursive `w`-opening level's active params and layout.
///
/// # Errors
///
/// Returns an error if the recursive layout derivation overflows.
pub fn hachi_level_layout(
    inputs: HachiScheduleInputs,
) -> Result<(HachiLevelParams, HachiCommitmentLayout), HachiError> {
    let params = Cfg::level_params(inputs);
    // Round the ring-element count up to a power of two so it fits a
    // 2^m × 2^r layout; `.max(1)` guards the degenerate empty-witness case.
    let num_ring_elems = inputs.current_w_len / params.d;
    let total = num_ring_elems.next_power_of_two().max(1);
    let alpha = params.d.trailing_zeros() as usize;
    let reduced_vars = total.trailing_zeros() as usize;
    let max_num_vars = reduced_vars + alpha;
    let decomp = recursive_level_decomposition::(params);
    let (m_vars, r_vars) = optimal_m_r_split_with_params(params, decomp, reduced_vars);
    let layout = layout_from_params(m_vars, r_vars, params, decomp)?;
    // The split must account for every variable: m + r + log2(d).
    debug_assert_eq!(layout.m_vars + layout.r_vars + alpha, max_num_vars);
    Ok((params, layout))
}

/// Default stage-1 challenge family: ternary ±1 coefficients of the level's
/// challenge weight.
pub(crate) fn default_stage1_challenge_config(params: HachiLevelParams) -> SparseChallengeConfig {
    SparseChallengeConfig {
        weight: params.challenge_weight,
        nonzero_coeffs: vec![-1, 1],
    }
}
diff --git a/src/protocol/commitment/scheme.rs
new file mode 100644
index 00000000..eeebb313
--- /dev/null
+++ b/src/protocol/commitment/scheme.rs
@@ -0,0 +1,239 @@
//! Commitment-scheme trait surface for Hachi protocol code.

use super::config::{CommitmentConfig, HachiCommitmentLayout};
use super::transcript_append::AppendToTranscript;
use crate::algebra::CyclotomicRing;
use crate::error::HachiError;
use crate::protocol::hachi_poly_ops::{HachiPolyOps, OneHotIndex};
use crate::protocol::opening_point::BasisMode;
use crate::protocol::transcript::Transcript;
use crate::{CanonicalField, FieldCore};

/// Witness data produced alongside a ring-native commitment.
///
/// Contains the commitment itself plus `t_hat` (basis-decomposed inner Ajtai
/// output) from the two-layer Ajtai construction (§4.1). The decomposed input
/// vectors `s` are NOT stored; they are recomputed from the polynomial during
/// proving via `HachiPolyOps`.
// NOTE(review): generic parameter lists on this struct/impl were garbled in
// transit (`Vec>`, bare `PhantomData`) — confirm against the original source.
pub struct CommitWitness {
    /// The ring commitment (outer Ajtai output `u = B · t̂`).
    pub commitment: C,
    /// Per-block basis-decomposed inner Ajtai output vectors as i8 digit planes.
    pub t_hat: Vec>,
    // Zero-sized marker tying the struct to its garbled type parameter(s).
    _marker: std::marker::PhantomData,
}

impl CommitWitness {
    /// Construct a new commit witness.
    pub fn new(commitment: C, t_hat: Vec>) -> Self {
        Self {
            commitment,
            t_hat,
            _marker: std::marker::PhantomData,
        }
    }
}

/// Commitment-scheme interface used by Hachi protocol code.
///
/// Generic over field `F` and cyclotomic ring degree `D`.
/// Polynomials are provided as `impl HachiPolyOps`.
pub trait CommitmentScheme: Clone + Send + Sync + 'static
where
    F: FieldCore + CanonicalField,
{
    /// Prover setup parameters.
    type ProverSetup: Clone + Send + Sync;
    /// Verifier setup parameters.
    type VerifierSetup: Clone + Send + Sync;
    /// Commitment object.
    type Commitment: Clone + PartialEq + Send + Sync + AppendToTranscript;
    /// Evaluation/opening proof object.
    type Proof: Clone + Send + Sync;
    /// Prover-side hint produced at commitment time.
    type CommitHint: Clone + Send + Sync;

    /// Build prover setup for maximum polynomial dimension.
    ///
    /// # Panics
    ///
    /// Panics if internal setup fails (programming error, not adversarial input).
    fn setup_prover(max_num_vars: usize) -> Self::ProverSetup;

    /// Derive verifier setup from prover setup.
    fn setup_verifier(setup: &Self::ProverSetup) -> Self::VerifierSetup;

    /// Commit to one polynomial with a caller-specified layout.
    ///
    /// The layout's matrix dimensions must not exceed the setup's max dimensions.
    /// Callers control `num_digits_commit` via the layout to reduce decomposition
    /// depth for polynomials with bounded coefficients (e.g. delta=1 for {0,1}).
    ///
    /// # Errors
    ///
    /// Returns an error when setup/parameter constraints are not satisfied.
    fn commit>(
        poly: &P,
        setup: &Self::ProverSetup,
        layout: &HachiCommitmentLayout,
    ) -> Result<(Self::Commitment, Self::CommitHint), HachiError>;

    /// Produce an opening proof at `opening_point` with a caller-specified layout.
    ///
    /// The layout must match the one used during commitment. Recursive w-opening
    /// levels derive their own layouts internally via `WCommitmentConfig`.
    ///
    /// `basis` selects the polynomial representation (see [`BasisMode`]).
    ///
    /// # Errors
    ///
    /// Returns an error if the opening point is invalid or proof generation fails.
    #[allow(clippy::too_many_arguments)]
    fn prove, P: HachiPolyOps>(
        setup: &Self::ProverSetup,
        poly: &P,
        opening_point: &[F],
        hint: Self::CommitHint,
        transcript: &mut T,
        commitment: &Self::Commitment,
        basis: BasisMode,
        layout: &HachiCommitmentLayout,
    ) -> Result;

    /// Verify an opening proof with a caller-specified layout.
    ///
    /// The layout must be reconstructed deterministically by the verifier —
    /// never deserialized from the proof. It must match the layout used by the
    /// prover for commitment and proving.
    ///
    /// `basis` must match the mode used by the prover (see [`BasisMode`]).
    ///
    /// # Errors
    ///
    /// Returns an error when verification fails.
    #[allow(clippy::too_many_arguments)]
    fn verify>(
        proof: &Self::Proof,
        setup: &Self::VerifierSetup,
        transcript: &mut T,
        opening_point: &[F],
        opening: &F,
        commitment: &Self::Commitment,
        basis: BasisMode,
        layout: &HachiCommitmentLayout,
    ) -> Result<(), HachiError>;

    /// Protocol identifier.
    fn protocol_name() -> &'static [u8];
}

/// Ring-native commitment interface for §4.1 implementation work.
pub trait RingCommitmentScheme: Clone + Send + Sync + 'static
where
    F: FieldCore + CanonicalField,
    Cfg: CommitmentConfig,
{
    /// Prover setup parameters.
    type ProverSetup: Clone + Send + Sync;
    /// Verifier setup parameters.
    type VerifierSetup: Clone + Send + Sync;
    /// Ring-native commitment type.
    type Commitment: Clone + PartialEq + Send + Sync;

    /// Construct commitment setup for at most `max_num_vars` variables.
    ///
    /// # Errors
    ///
    /// Returns an error if dimensions are inconsistent with `Cfg`.
    fn setup(max_num_vars: usize) -> Result<(Self::ProverSetup, Self::VerifierSetup), HachiError>;

    /// Read the runtime layout carried by `setup`.
    ///
    /// # Errors
    ///
    /// Returns an error when setup metadata is inconsistent.
    fn layout(setup: &Self::ProverSetup) -> Result;

    /// Commit to ring blocks arranged as `2^R` vectors of length `2^M`.
    ///
    /// # Errors
    ///
    /// Returns an error if block layout mismatches config or commitment fails.
    fn commit_ring_blocks(
        f_blocks: &[Vec>],
        setup: &Self::ProverSetup,
    ) -> Result, HachiError>;

    /// Commit to a flat coefficient table `(f_i)_{i∈{0,1}^ℓ}` in ring form.
+ /// + /// # Errors + /// + /// Returns an error if `f_coeffs.len()` does not match the configured block + /// layout or if the underlying commitment routine fails. + fn commit_coeffs( + f_coeffs: &[CyclotomicRing], + setup: &Self::ProverSetup, + ) -> Result, HachiError> { + let layout = Self::layout(setup)?; + let num_blocks = layout.num_blocks; + let block_len = layout.block_len; + let expected_len = num_blocks + .checked_mul(block_len) + .ok_or_else(|| HachiError::InvalidSetup("coefficient length overflow".to_string()))?; + if f_coeffs.len() != expected_len { + return Err(HachiError::InvalidSize { + expected: expected_len, + actual: f_coeffs.len(), + }); + } + + let blocks: Vec>> = f_coeffs + .chunks_exact(block_len) + .map(|chunk| chunk.to_vec()) + .collect(); + + Self::commit_ring_blocks(&blocks, setup) + } + + /// Commit to a regular one-hot witness. + /// + /// # Errors + /// + /// Returns an error if dimensions are inconsistent or any index is out + /// of range. + fn commit_onehot( + onehot_k: usize, + indices: &[Option], + setup: &Self::ProverSetup, + ) -> Result, HachiError> { + let num_chunks = indices.len(); + let total_field_elems = num_chunks + .checked_mul(onehot_k) + .ok_or_else(|| HachiError::InvalidInput("T*K overflow".into()))?; + if total_field_elems % D != 0 { + return Err(HachiError::InvalidInput(format!( + "T*K={total_field_elems} is not divisible by D={D}" + ))); + } + + let total_ring_elems = total_field_elems / D; + let mut ring_coeffs = vec![CyclotomicRing::::zero(); total_ring_elems]; + for (c, opt) in indices.iter().enumerate() { + let Some(&idx_raw) = opt.as_ref() else { + continue; + }; + let idx = idx_raw.as_usize(); + if idx >= onehot_k { + return Err(HachiError::InvalidInput(format!( + "index {idx} out of range for chunk size K={onehot_k} at position {c}" + ))); + } + let field_pos = c * onehot_k + idx; + let ring_idx = field_pos / D; + let coeff_idx = field_pos % D; + ring_coeffs[ring_idx].coeffs[coeff_idx] = F::one(); + } + + 
Self::commit_coeffs(&ring_coeffs, setup) + } +} diff --git a/src/protocol/commitment/transcript_append.rs b/src/protocol/commitment/transcript_append.rs new file mode 100644 index 00000000..6510ccb5 --- /dev/null +++ b/src/protocol/commitment/transcript_append.rs @@ -0,0 +1,13 @@ +//! Traits for appending commitment objects to protocol transcripts. + +use crate::protocol::transcript::Transcript; +use crate::{CanonicalField, FieldCore}; + +/// Protocol object that can be absorbed into a transcript. +pub trait AppendToTranscript +where + F: FieldCore + CanonicalField, +{ + /// Append this object to a transcript using the provided event label. + fn append_to_transcript>(&self, label: &[u8], transcript: &mut T); +} diff --git a/src/protocol/commitment/types.rs b/src/protocol/commitment/types.rs new file mode 100644 index 00000000..339a1a03 --- /dev/null +++ b/src/protocol/commitment/types.rs @@ -0,0 +1,157 @@ +//! Protocol commitment/opening wrapper types. + +use super::transcript_append::AppendToTranscript; +use crate::algebra::ring::CyclotomicRing; +use crate::primitives::serialization::{ + Compress, HachiDeserialize, HachiSerialize, SerializationError, Valid, Validate, +}; +use crate::protocol::transcript::Transcript; +use crate::{CanonicalField, FieldCore}; +use std::io::{Read, Write}; + +/// A Hachi opening point represented as field coordinates. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct HachiOpeningPoint { + /// Point coordinates used for multilinear opening. + pub r: Vec, +} + +/// A Hachi opening claim `(point, value)`. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct HachiOpeningClaim { + /// Opening point. + pub point: HachiOpeningPoint, + /// Claimed value at `point`. + pub value: F, +} + +/// Minimal commitment wrapper used by protocol traits/tests. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] +pub struct HachiCommitment(pub u128); + +/// Minimal proof wrapper used by protocol trait stubs and tests. 
+#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] +pub struct DummyProof(pub u128); + +impl Valid for HachiCommitment { + fn check(&self) -> Result<(), SerializationError> { + Ok(()) + } +} + +impl HachiSerialize for HachiCommitment { + fn serialize_with_mode( + &self, + mut writer: W, + _compress: Compress, + ) -> Result<(), SerializationError> { + self.0.serialize_with_mode(&mut writer, Compress::No) + } + + fn serialized_size(&self, _compress: Compress) -> usize { + 16 + } +} + +impl HachiDeserialize for HachiCommitment { + fn deserialize_with_mode( + mut reader: R, + _compress: Compress, + validate: Validate, + ) -> Result { + let value = u128::deserialize_with_mode(&mut reader, Compress::No, validate)?; + Ok(Self(value)) + } +} + +impl Valid for DummyProof { + fn check(&self) -> Result<(), SerializationError> { + Ok(()) + } +} + +impl HachiSerialize for DummyProof { + fn serialize_with_mode( + &self, + mut writer: W, + _compress: Compress, + ) -> Result<(), SerializationError> { + self.0.serialize_with_mode(&mut writer, Compress::No) + } + + fn serialized_size(&self, _compress: Compress) -> usize { + 16 + } +} + +impl HachiDeserialize for DummyProof { + fn deserialize_with_mode( + mut reader: R, + _compress: Compress, + validate: Validate, + ) -> Result { + let value = u128::deserialize_with_mode(&mut reader, Compress::No, validate)?; + Ok(Self(value)) + } +} + +impl AppendToTranscript for HachiCommitment +where + F: FieldCore + CanonicalField, +{ + fn append_to_transcript>(&self, label: &[u8], transcript: &mut T) { + transcript.append_serde(label, self); + } +} + +/// Ring-native commitment object `u in R_q^{n_B}` used by §4.1. +#[derive(Debug, Clone, PartialEq, Eq, Default)] +pub struct RingCommitment { + /// Outer commitment vector. 
+ pub u: Vec>, +} + +impl Valid for RingCommitment { + fn check(&self) -> Result<(), SerializationError> { + self.u.check() + } +} + +impl HachiSerialize for RingCommitment { + fn serialize_with_mode( + &self, + mut writer: W, + compress: Compress, + ) -> Result<(), SerializationError> { + self.u.serialize_with_mode(&mut writer, compress) + } + + fn serialized_size(&self, compress: Compress) -> usize { + self.u.serialized_size(compress) + } +} + +impl HachiDeserialize for RingCommitment { + fn deserialize_with_mode( + mut reader: R, + compress: Compress, + validate: Validate, + ) -> Result { + let u = + Vec::>::deserialize_with_mode(&mut reader, compress, validate)?; + let out = Self { u }; + if matches!(validate, Validate::Yes) { + out.check()?; + } + Ok(out) + } +} + +impl AppendToTranscript for RingCommitment +where + F: FieldCore + CanonicalField, +{ + fn append_to_transcript>(&self, label: &[u8], transcript: &mut T) { + transcript.append_serde(label, self); + } +} diff --git a/src/protocol/commitment/utils/crt_ntt.rs b/src/protocol/commitment/utils/crt_ntt.rs new file mode 100644 index 00000000..28f3ff14 --- /dev/null +++ b/src/protocol/commitment/utils/crt_ntt.rs @@ -0,0 +1,187 @@ +//! Protocol-facing CRT+NTT parameter dispatch and matrix caching. + +use crate::algebra::ntt::prime::PrimeWidth; +use crate::algebra::ntt::tables::{ + q128_primes, q64_primes, MAX_CRT_RING_DEGREE, Q128_MODULUS, Q128_NUM_PRIMES, Q32_MODULUS, + Q32_NUM_PRIMES, Q32_PRIMES, Q64_MODULUS, Q64_NUM_PRIMES, RING_DEGREE, +}; +use crate::algebra::ring::{CrtNttParamSet, CyclotomicCrtNtt}; +use crate::cfg_into_iter; +use crate::error::HachiError; +#[cfg(feature = "parallel")] +use crate::parallel::*; +use crate::{CanonicalField, FieldCore}; + +use super::flat_matrix::RingMatrixView; +use super::norm::detect_field_modulus; + +/// Supported protocol CRT+NTT parameter families. 
+#[derive(Clone)] +#[allow(clippy::large_enum_variant)] +pub(crate) enum ProtocolCrtNttParams { + Q32(CrtNttParamSet), + Q64(CrtNttParamSet), + Q128(CrtNttParamSet), +} + +/// Select a CRT+NTT parameter set from field modulus and ring degree. +/// +/// Dispatch policy: +/// - `q <= 2^32-99` and `D <= 64`: Q32 (`i16`) +/// - `q <= 2^64-59` and `D <= 1024`: Q64 (`i32`, conservative K=5) +/// - `q == 2^128-275` and `D <= 1024`: Q128 (`i32`, K=5) +/// - otherwise: explicit setup error +pub(crate) fn select_crt_ntt_params( +) -> Result, HachiError> { + if !D.is_power_of_two() { + return Err(HachiError::InvalidSetup(format!( + "CRT+NTT requires power-of-two ring degree, got D={D}" + ))); + } + if D > MAX_CRT_RING_DEGREE { + return Err(HachiError::InvalidSetup(format!( + "CRT+NTT supports D <= {MAX_CRT_RING_DEGREE}, got D={D}" + ))); + } + + let modulus = detect_field_modulus::(); + + if modulus == Q128_MODULUS { + return Ok(ProtocolCrtNttParams::Q128(CrtNttParamSet::new( + q128_primes(), + ))); + } + + if modulus <= Q32_MODULUS as u128 { + if D <= RING_DEGREE { + return Ok(ProtocolCrtNttParams::Q32(CrtNttParamSet::new(Q32_PRIMES))); + } + return Ok(ProtocolCrtNttParams::Q64(CrtNttParamSet::new(q64_primes()))); + } + + if modulus <= Q64_MODULUS as u128 { + return Ok(ProtocolCrtNttParams::Q64(CrtNttParamSet::new(q64_primes()))); + } + + Err(HachiError::InvalidSetup(format!( + "no CRT+NTT parameter set for modulus {modulus} and D={D}; supported ranges: <= {Q64_MODULUS} (with Q32/Q64 dispatch) or {Q128_MODULUS}" + ))) +} + +/// Pre-converted CRT+NTT cache for a single matrix, keyed by parameter family. +/// +/// Stores both negacyclic (for mat-vec) and cyclic (for quotient) representations +/// to avoid repeated coefficient-to-NTT conversion. +#[derive(Debug, Clone, PartialEq, Eq)] +#[allow(missing_docs, clippy::large_enum_variant)] +pub enum NttSlotCache { + /// 32-bit CRT primes. + Q32 { + neg: Vec>>, + cyc: Vec>>, + params: CrtNttParamSet, + }, + /// 64-bit CRT primes. 
+ Q64 { + neg: Vec>>, + cyc: Vec>>, + params: CrtNttParamSet, + }, + /// 128-bit CRT primes. + Q128 { + neg: Vec>>, + cyc: Vec>>, + params: CrtNttParamSet, + }, +} + +fn convert_mat( + mat: RingMatrixView<'_, F, D>, + params: &CrtNttParamSet, +) -> Vec>> +where + F: FieldCore + CanonicalField, + W: PrimeWidth, +{ + cfg_into_iter!(0..mat.num_rows()) + .map(|i| { + mat.row(i) + .iter() + .map(|a| CyclotomicCrtNtt::from_ring_with_params(a, params)) + .collect() + }) + .collect() +} + +fn convert_mat_cyclic( + mat: RingMatrixView<'_, F, D>, + params: &CrtNttParamSet, +) -> Vec>> +where + F: FieldCore + CanonicalField, + W: PrimeWidth, +{ + cfg_into_iter!(0..mat.num_rows()) + .map(|i| { + mat.row(i) + .iter() + .map(|a| CyclotomicCrtNtt::from_ring_cyclic(a, params)) + .collect() + }) + .collect() +} + +/// Build an NTT slot cache for a single matrix. +/// +/// # Errors +/// +/// Returns an error if no CRT+NTT parameter set matches the field modulus and ring degree. +#[tracing::instrument(skip_all, name = "build_ntt_slot")] +pub fn build_ntt_slot( + mat: RingMatrixView<'_, F, D>, +) -> Result, HachiError> { + let params = select_crt_ntt_params::()?; + Ok(build_ntt_slot_from_params(mat, params)) +} + +fn build_ntt_slot_from_params( + mat: RingMatrixView<'_, F, D>, + params: ProtocolCrtNttParams, +) -> NttSlotCache { + match params { + ProtocolCrtNttParams::Q32(p) => NttSlotCache::Q32 { + neg: convert_mat(mat, &p), + cyc: convert_mat_cyclic(mat, &p), + params: p, + }, + ProtocolCrtNttParams::Q64(p) => NttSlotCache::Q64 { + neg: convert_mat(mat, &p), + cyc: convert_mat_cyclic(mat, &p), + params: p, + }, + ProtocolCrtNttParams::Q128(p) => NttSlotCache::Q128 { + neg: convert_mat(mat, &p), + cyc: convert_mat_cyclic(mat, &p), + params: p, + }, + } +} + +/// Build NTT slot caches for three matrices, computing CRT+NTT parameters once. +/// +/// # Errors +/// +/// Returns an error if no CRT+NTT parameter set matches the field modulus and ring degree. 
+#[tracing::instrument(skip_all, name = "build_ntt_slots")] +#[allow(non_snake_case)] +pub fn build_ntt_slots( + A: RingMatrixView<'_, F, D>, + B: RingMatrixView<'_, F, D>, + D_mat: RingMatrixView<'_, F, D>, +) -> Result<(NttSlotCache, NttSlotCache, NttSlotCache), HachiError> { + let params = select_crt_ntt_params::()?; + let slot_a = build_ntt_slot_from_params(A, params.clone()); + let slot_b = build_ntt_slot_from_params(B, params.clone()); + let slot_d = build_ntt_slot_from_params(D_mat, params); + Ok((slot_a, slot_b, slot_d)) +} diff --git a/src/protocol/commitment/utils/flat_matrix.rs b/src/protocol/commitment/utils/flat_matrix.rs new file mode 100644 index 00000000..47946a10 --- /dev/null +++ b/src/protocol/commitment/utils/flat_matrix.rs @@ -0,0 +1,426 @@ +//! D-agnostic flat matrix storage with typed ring-element views. +//! +//! [`FlatMatrix`] stores matrix entries as raw field elements, independent of +//! any ring dimension. A [`RingMatrixView`] borrows the flat data and +//! interprets it as `CyclotomicRing` slices, enabling the same +//! underlying matrix to be viewed at different ring dimensions. + +use crate::algebra::CyclotomicRing; +use crate::primitives::serialization::{ + Compress, HachiDeserialize, HachiSerialize, SerializationError, Valid, Validate, +}; +use crate::FieldCore; +use std::io::{Read, Write}; + +/// Row-major matrix of field elements, independent of ring dimension. +/// +/// Each row contains `cols_ring * gen_ring_dim` contiguous field elements, +/// where `cols_ring` is the number of ring elements per row at the dimension +/// (`gen_ring_dim`) used when the matrix was generated. +/// +/// To view with a smaller ring dimension D' (where D' divides `gen_ring_dim`), +/// each row is re-chunked into `cols_ring * gen_ring_dim / D'` ring elements. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct FlatMatrix { + data: Vec, + num_rows: usize, + /// Number of ring elements per row at the generation dimension. 
+ cols_ring: usize, + /// Ring dimension used when generating (D_max). + gen_ring_dim: usize, +} + +impl FlatMatrix { + /// Number of rows. + #[inline] + pub fn num_rows(&self) -> usize { + self.num_rows + } + + /// Number of ring-element columns at the generation dimension. + #[inline] + pub fn cols_ring(&self) -> usize { + self.cols_ring + } + + /// Ring dimension used during generation. + #[inline] + pub fn gen_ring_dim(&self) -> usize { + self.gen_ring_dim + } + + /// Number of field elements per row. + #[inline] + pub fn row_field_len(&self) -> usize { + self.cols_ring * self.gen_ring_dim + } + + /// Build from a `Vec>>`, flattening ring elements + /// into contiguous field-element storage. + pub fn from_ring_matrix(mat: &[Vec>]) -> Self { + let num_rows = mat.len(); + let cols_ring = if num_rows > 0 { mat[0].len() } else { 0 }; + let row_len = cols_ring * D; + let mut data = Vec::with_capacity(num_rows * row_len); + for row in mat { + debug_assert_eq!(row.len(), cols_ring); + for ring_elem in row { + data.extend_from_slice(&ring_elem.coeffs); + } + } + Self { + data, + num_rows, + cols_ring, + gen_ring_dim: D, + } + } + + /// Create a typed view at ring dimension D. + /// + /// D must divide `gen_ring_dim`. The view re-chunks each row so that + /// `cols_at_d = cols_ring * gen_ring_dim / D`. + /// + /// # Panics + /// + /// Panics if `D == 0`, D does not divide `gen_ring_dim`, or the matrix is + /// empty with inconsistent metadata. + pub fn view(&self) -> RingMatrixView<'_, F, D> { + assert!(D > 0, "ring dimension must be positive"); + assert!( + self.gen_ring_dim % D == 0, + "D={D} does not divide gen_ring_dim={}", + self.gen_ring_dim + ); + let scale = self.gen_ring_dim / D; + let cols_at_d = self.cols_ring * scale; + RingMatrixView { + data: &self.data, + num_rows: self.num_rows, + num_cols: cols_at_d, + } + } + + /// Borrow the raw field-element data. 
+ #[inline] + pub fn raw_data(&self) -> &[F] { + &self.data + } + + /// Number of ring-element columns when viewed at dimension D. + #[inline] + pub fn num_cols_at(&self) -> usize { + debug_assert!(D > 0 && self.gen_ring_dim % D == 0); + self.cols_ring * (self.gen_ring_dim / D) + } + + /// Borrow a single row as a slice of ring elements at dimension D (zero-copy). + /// + /// # Panics + /// + /// Panics if `row >= num_rows` or D does not divide `gen_ring_dim`. + #[inline] + pub fn row(&self, row: usize) -> &[CyclotomicRing] { + assert!(D > 0 && self.gen_ring_dim % D == 0); + assert!(row < self.num_rows, "row {row} out of bounds"); + let row_field_len = self.cols_ring * self.gen_ring_dim; + let start = row * row_field_len; + let field_slice = &self.data[start..start + row_field_len]; + let num_cols = row_field_len / D; + // SAFETY: CyclotomicRing is #[repr(transparent)] over [F; D]. + unsafe { + std::slice::from_raw_parts( + field_slice.as_ptr() as *const CyclotomicRing, + num_cols, + ) + } + } + + /// Whether the matrix has zero rows. + #[inline] + pub fn is_empty(&self) -> bool { + self.num_rows == 0 + } + + /// Convenience: number of ring-element columns in the first row at dimension D, + /// or 0 if empty. 
+ #[inline] + pub fn first_row_len(&self) -> usize { + if self.is_empty() { + 0 + } else { + self.num_cols_at::() + } + } +} + +impl Valid for FlatMatrix { + fn check(&self) -> Result<(), SerializationError> { + for f in &self.data { + f.check()?; + } + Ok(()) + } +} + +impl HachiSerialize for FlatMatrix { + fn serialize_with_mode( + &self, + mut writer: W, + compress: Compress, + ) -> Result<(), SerializationError> { + self.num_rows.serialize_with_mode(&mut writer, compress)?; + self.cols_ring.serialize_with_mode(&mut writer, compress)?; + self.gen_ring_dim + .serialize_with_mode(&mut writer, compress)?; + for f in &self.data { + f.serialize_with_mode(&mut writer, compress)?; + } + Ok(()) + } + + fn serialized_size(&self, compress: Compress) -> usize { + 3 * std::mem::size_of::() + + self + .data + .iter() + .map(|f| f.serialized_size(compress)) + .sum::() + } +} + +impl HachiDeserialize for FlatMatrix { + fn deserialize_with_mode( + mut reader: R, + compress: Compress, + validate: Validate, + ) -> Result { + let num_rows = usize::deserialize_with_mode(&mut reader, compress, validate)?; + let cols_ring = usize::deserialize_with_mode(&mut reader, compress, validate)?; + let gen_ring_dim = usize::deserialize_with_mode(&mut reader, compress, validate)?; + let total = num_rows * cols_ring * gen_ring_dim; + let mut data = Vec::with_capacity(total); + for _ in 0..total { + data.push(F::deserialize_with_mode(&mut reader, compress, validate)?); + } + let out = Self { + data, + num_rows, + cols_ring, + gen_ring_dim, + }; + if matches!(validate, Validate::Yes) { + out.check()?; + } + Ok(out) + } +} + +/// Typed read-only view of a [`FlatMatrix`] at a specific ring dimension D. +/// +/// Provides zero-copy access to rows as `&[CyclotomicRing]` by +/// transmuting the underlying `&[F]` slice (safe because `CyclotomicRing` +/// is `#[repr(transparent)]` over `[F; D]`). 
+#[derive(Debug, Clone, Copy)] +pub struct RingMatrixView<'a, F: FieldCore, const D: usize> { + data: &'a [F], + num_rows: usize, + num_cols: usize, +} + +impl<'a, F: FieldCore, const D: usize> RingMatrixView<'a, F, D> { + /// Number of rows in the view. + #[inline] + pub fn num_rows(&self) -> usize { + self.num_rows + } + + /// Number of ring-element columns per row. + #[inline] + pub fn num_cols(&self) -> usize { + self.num_cols + } + + /// Borrow a single row as a slice of ring elements (zero-copy). + /// + /// # Panics + /// + /// Panics if `row >= num_rows`. + #[inline] + pub fn row(&self, row: usize) -> &'a [CyclotomicRing] { + assert!(row < self.num_rows, "row {row} out of bounds"); + let row_field_len = self.num_cols * D; + let start = row * row_field_len; + let field_slice = &self.data[start..start + row_field_len]; + // SAFETY: CyclotomicRing is #[repr(transparent)] over [F; D], + // so a contiguous &[F] of length num_cols*D has the same layout as + // &[CyclotomicRing] of length num_cols. + unsafe { + std::slice::from_raw_parts( + field_slice.as_ptr() as *const CyclotomicRing, + self.num_cols, + ) + } + } + + /// Iterate over all rows. + pub fn rows(&self) -> impl Iterator]> + '_ { + (0..self.num_rows).map(move |i| self.row(i)) + } + + /// Take a sub-view: first `n_rows` rows, first `n_cols` ring-element columns. + /// + /// This cannot produce a contiguous sub-view because rows are not + /// contiguous after column truncation. Instead it returns a + /// [`SubMatrixView`] that copies on access. + /// + /// # Panics + /// + /// Panics if `n_rows > self.num_rows` or `n_cols > self.num_cols`. + pub fn submatrix(&self, n_rows: usize, n_cols: usize) -> SubMatrixView<'a, F, D> { + assert!(n_rows <= self.num_rows); + assert!(n_cols <= self.num_cols); + SubMatrixView { + parent: *self, + n_rows, + n_cols, + } + } + + /// Collect into the legacy `Vec>>` representation. 
+ pub fn to_vec_vec(&self) -> Vec>> { + (0..self.num_rows).map(|i| self.row(i).to_vec()).collect() + } +} + +/// A non-contiguous sub-view that yields column-truncated rows. +#[derive(Debug, Clone, Copy)] +pub struct SubMatrixView<'a, F: FieldCore, const D: usize> { + parent: RingMatrixView<'a, F, D>, + n_rows: usize, + n_cols: usize, +} + +impl<'a, F: FieldCore, const D: usize> SubMatrixView<'a, F, D> { + /// Number of rows. + #[inline] + pub fn num_rows(&self) -> usize { + self.n_rows + } + + /// Number of ring-element columns. + #[inline] + pub fn num_cols(&self) -> usize { + self.n_cols + } + + /// Borrow a row, truncated to `n_cols` ring elements. + /// + /// # Panics + /// + /// Panics if `row >= n_rows`. + #[inline] + pub fn row(&self, row: usize) -> &'a [CyclotomicRing] { + assert!(row < self.n_rows, "row {row} out of bounds"); + &self.parent.row(row)[..self.n_cols] + } + + /// Iterate over rows. + pub fn rows(&self) -> impl Iterator]> + '_ { + (0..self.n_rows).map(move |i| self.row(i)) + } + + /// Collect into the legacy `Vec>>` representation. 
+ pub fn to_vec_vec(&self) -> Vec>> { + self.rows().map(|r| r.to_vec()).collect() + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::algebra::fields::Prime128M8M4M1M0; + use rand::rngs::StdRng; + use rand::SeedableRng; + + type F = Prime128M8M4M1M0; + + #[test] + fn roundtrip_from_ring_matrix_and_view() { + let mut rng = StdRng::seed_from_u64(42); + let rows = 3usize; + let cols = 5usize; + let mat: Vec>> = (0..rows) + .map(|_| { + (0..cols) + .map(|_| CyclotomicRing::random(&mut rng)) + .collect() + }) + .collect(); + + let flat = FlatMatrix::from_ring_matrix(&mat); + assert_eq!(flat.num_rows(), rows); + assert_eq!(flat.cols_ring(), cols); + assert_eq!(flat.gen_ring_dim(), 64); + + let view = flat.view::<64>(); + assert_eq!(view.num_rows(), rows); + assert_eq!(view.num_cols(), cols); + + for (i, orig_row) in mat.iter().enumerate() { + let view_row = view.row(i); + assert_eq!(view_row, orig_row.as_slice()); + } + } + + #[test] + fn view_at_smaller_d_rechunks_correctly() { + let mut rng = StdRng::seed_from_u64(99); + let rows = 2usize; + let cols = 4usize; + let mat: Vec>> = (0..rows) + .map(|_| { + (0..cols) + .map(|_| CyclotomicRing::random(&mut rng)) + .collect() + }) + .collect(); + + let flat = FlatMatrix::from_ring_matrix(&mat); + + // View at D=32: each D=64 element becomes 2 D=32 elements + let view32 = flat.view::<32>(); + assert_eq!(view32.num_rows(), rows); + assert_eq!(view32.num_cols(), cols * 2); + + // Verify field elements are the same + for r in 0..rows { + let ring32_row = view32.row(r); + let orig_row = flat.view::<64>().row(r); + for (j, orig_ring) in orig_row.iter().enumerate() { + let lo = &ring32_row[j * 2]; + let hi = &ring32_row[j * 2 + 1]; + assert_eq!(&orig_ring.coeffs[..32], lo.coefficients()); + assert_eq!(&orig_ring.coeffs[32..], hi.coefficients()); + } + } + } + + #[test] + fn submatrix_truncates_correctly() { + let mut rng = StdRng::seed_from_u64(7); + let mat: Vec>> = (0..4) + .map(|_| (0..8).map(|_| 
CyclotomicRing::random(&mut rng)).collect()) + .collect(); + + let flat = FlatMatrix::from_ring_matrix(&mat); + let view = flat.view::<64>(); + let sub = view.submatrix(2, 5); + + assert_eq!(sub.num_rows(), 2); + assert_eq!(sub.num_cols(), 5); + for (r, row) in mat.iter().enumerate().take(2) { + assert_eq!(sub.row(r), &row[..5]); + } + } +} diff --git a/src/protocol/commitment/utils/linear.rs b/src/protocol/commitment/utils/linear.rs new file mode 100644 index 00000000..3102fde6 --- /dev/null +++ b/src/protocol/commitment/utils/linear.rs @@ -0,0 +1,1505 @@ +//! Linear algebra helpers for ring commitment. + +#[cfg(target_arch = "aarch64")] +use crate::algebra::ntt::neon; +use crate::algebra::ntt::{MontCoeff, PrimeWidth}; +use crate::algebra::{ + CenteredMontLut, CrtNttParamSet, CyclotomicCrtNtt, CyclotomicRing, DigitMontLut, +}; +use crate::error::HachiError; +#[cfg(feature = "parallel")] +use crate::parallel::*; +use crate::{cfg_fold_reduce, cfg_into_iter, cfg_iter}; +use crate::{CanonicalField, FieldCore}; +use std::array::from_fn; +use std::mem::size_of; + +use super::crt_ntt::NttSlotCache; +use super::crt_ntt::{select_crt_ntt_params, ProtocolCrtNttParams}; + +#[inline(always)] +fn try_centered_i8(coeff: F, q: u128, half_q: u128) -> Option { + let canonical = coeff.to_canonical_u128(); + let centered = if canonical > half_q { + -((q - canonical) as i128) + } else { + canonical as i128 + }; + if (i8::MIN as i128..=i8::MAX as i128).contains(¢ered) { + Some(centered as i8) + } else { + None + } +} + +const FAST_I8_DIGIT_MIN: i8 = -8; +const FAST_I8_DIGIT_MAX: i8 = 7; + +pub(crate) fn try_centered_i8_cache_from_ring_coeffs( + coeffs: &[CyclotomicRing], +) -> Option> { + let q = (-F::one()).to_canonical_u128() + 1; + let half_q = q / 2; + let mut out = Vec::with_capacity(coeffs.len()); + + for ring in coeffs { + let mut digits = [0i8; D]; + for (dst, coeff) in digits.iter_mut().zip(ring.coeffs.iter()) { + let centered = try_centered_i8(*coeff, q, half_q)?; + // The 
small-digit CRT+NTT fast path uses a fixed LUT for [-8, 7]. + // Larger centered coefficients must fall back to the generic path. + if !(FAST_I8_DIGIT_MIN..=FAST_I8_DIGIT_MAX).contains(¢ered) { + return None; + } + *dst = centered; + } + out.push(digits); + } + + Some(out) +} + +#[cfg(test)] +pub(crate) fn mat_vec_mul_unchecked( + mat: &[Vec>], + vec: &[CyclotomicRing], +) -> Vec> { + let mut out = Vec::with_capacity(mat.len()); + for row in mat { + debug_assert_eq!(row.len(), vec.len()); + let mut acc = CyclotomicRing::::zero(); + for (a, x) in row.iter().zip(vec.iter()) { + acc += *a * *x; + } + out.push(acc); + } + out +} + +#[inline] +fn accumulate_pointwise_product_into( + acc: &mut CyclotomicCrtNtt, + lhs: &CyclotomicCrtNtt, + rhs: &CyclotomicCrtNtt, + params: &CrtNttParamSet, +) { + acc.add_assign_pointwise_mul_with_params(lhs, rhs, params); +} + +fn precompute_dense_mat_ntt_with_params< + F: FieldCore + CanonicalField, + W: PrimeWidth, + const K: usize, + const D: usize, +>( + mat: &[Vec>], + params: &CrtNttParamSet, +) -> Vec>> { + cfg_iter!(mat) + .map(|row| { + row.iter() + .map(|a| CyclotomicCrtNtt::from_ring_with_params(a, params)) + .collect() + }) + .collect() +} + +fn mat_vec_mul_dense_i8_many_with_params< + F: FieldCore + CanonicalField, + W: PrimeWidth, + const K: usize, + const D: usize, +>( + mat: &[Vec>], + vecs: &[Vec<[i8; D]>], + params: &CrtNttParamSet, +) -> Vec>> { + let ntt_mat = precompute_dense_mat_ntt_with_params(mat, params); + let blocks: Vec<&[[i8; D]]> = vecs.iter().map(Vec::as_slice).collect(); + mat_vec_mul_digits_i8_with_params(&ntt_mat, &blocks, params) +} + +#[cfg(test)] +fn mat_vec_mul_dense_with_params< + F: FieldCore + CanonicalField, + W: PrimeWidth, + const K: usize, + const D: usize, +>( + mat: &[Vec>], + vec: &[CyclotomicRing], + params: &CrtNttParamSet, +) -> Vec> { + let ntt_vec: Vec> = vec + .iter() + .map(|v| CyclotomicCrtNtt::from_ring_with_params(v, params)) + .collect(); + + mat.iter() + .map(|row| { + 
debug_assert_eq!(row.len(), ntt_vec.len()); + let mut acc = CyclotomicCrtNtt::::zero(); + for (a, x_ntt) in row.iter().zip(ntt_vec.iter()) { + let a_ntt = CyclotomicCrtNtt::from_ring_with_params(a, params); + accumulate_pointwise_product_into(&mut acc, &a_ntt, x_ntt, params); + } + acc.to_ring_with_params(params) + }) + .collect() +} + +#[cfg(test)] +fn mat_vec_mul_dense_many_with_params< + F: FieldCore + CanonicalField, + W: PrimeWidth, + const K: usize, + const D: usize, +>( + mat: &[Vec>], + vecs: &[Vec>], + params: &CrtNttParamSet, +) -> Vec>> { + let ntt_mat = precompute_dense_mat_ntt_with_params(mat, params); + vecs.iter() + .map(|vec| { + let ntt_vec: Vec> = vec + .iter() + .map(|v| CyclotomicCrtNtt::from_ring_with_params(v, params)) + .collect(); + + ntt_mat + .iter() + .map(|row_ntt| { + debug_assert_eq!(row_ntt.len(), ntt_vec.len()); + let mut acc = CyclotomicCrtNtt::::zero(); + for (a_ntt, x_ntt) in row_ntt.iter().zip(ntt_vec.iter()) { + accumulate_pointwise_product_into(&mut acc, a_ntt, x_ntt, params); + } + acc.to_ring_with_params(params) + }) + .collect() + }) + .collect() +} + +#[cfg(test)] +pub(crate) fn mat_vec_mul_crt_ntt( + mat: &[Vec>], + vec: &[CyclotomicRing], +) -> Result>, HachiError> { + let params = select_crt_ntt_params::()?; + let out = match ¶ms { + ProtocolCrtNttParams::Q32(p) => mat_vec_mul_dense_with_params(mat, vec, p), + ProtocolCrtNttParams::Q64(p) => mat_vec_mul_dense_with_params(mat, vec, p), + ProtocolCrtNttParams::Q128(p) => mat_vec_mul_dense_with_params(mat, vec, p), + }; + Ok(out) +} + +#[cfg(test)] +pub(crate) fn mat_vec_mul_crt_ntt_many( + mat: &[Vec>], + vecs: &[Vec>], +) -> Result>>, HachiError> { + let params = select_crt_ntt_params::()?; + let out = match ¶ms { + ProtocolCrtNttParams::Q32(p) => mat_vec_mul_dense_many_with_params(mat, vecs, p), + ProtocolCrtNttParams::Q64(p) => mat_vec_mul_dense_many_with_params(mat, vecs, p), + ProtocolCrtNttParams::Q128(p) => mat_vec_mul_dense_many_with_params(mat, vecs, p), + }; + 
Ok(out) +} + +#[tracing::instrument(skip_all, name = "mat_vec_mul_crt_ntt_i8_many")] +pub(crate) fn mat_vec_mul_crt_ntt_i8_many( + mat: &[Vec>], + vecs: &[Vec<[i8; D]>], +) -> Result>>, HachiError> { + let params = select_crt_ntt_params::()?; + let out = match ¶ms { + ProtocolCrtNttParams::Q32(p) => mat_vec_mul_dense_i8_many_with_params(mat, vecs, p), + ProtocolCrtNttParams::Q64(p) => mat_vec_mul_dense_i8_many_with_params(mat, vecs, p), + ProtocolCrtNttParams::Q128(p) => mat_vec_mul_dense_i8_many_with_params(mat, vecs, p), + }; + Ok(out) +} + +fn mat_vec_mul_i8_single_with_params< + F: FieldCore + CanonicalField, + W: PrimeWidth, + const K: usize, + const D: usize, +>( + row: &[CyclotomicRing], + vec: &[[i8; D]], + params: &CrtNttParamSet, +) -> CyclotomicRing { + let width = row.len(); + if width == 0 { + return CyclotomicRing::::zero(); + } + + let ntt_row: Vec> = cfg_iter!(row) + .map(|a| CyclotomicCrtNtt::from_ring_with_params(a, params)) + .collect(); + let lut = DigitMontLut::new(params); + let tw = (TARGET_L2_CACHE_BYTES / (K * D * size_of::())).max(1); + let num_tiles = width.div_ceil(tw); + + let final_acc: CyclotomicCrtNtt = cfg_fold_reduce!( + 0..num_tiles, + || CyclotomicCrtNtt::::zero(), + |mut acc: CyclotomicCrtNtt, tile_idx| { + let tile_start = tile_idx * tw; + let tile_end = (tile_start + tw).min(width); + for (j, digit) in vec[tile_start..tile_end].iter().enumerate() { + if is_zero_plane(digit) { + continue; + } + let ntt_d = CyclotomicCrtNtt::from_i8_with_lut(digit, params, &lut); + accumulate_pointwise_product_into( + &mut acc, + &ntt_row[tile_start + j], + &ntt_d, + params, + ); + } + acc + }, + |mut a: CyclotomicCrtNtt, b| { + add_ntt_into(&mut a, &b, params); + a + } + ); + final_acc.to_ring_with_params(params) +} + +#[tracing::instrument(skip_all, name = "mat_vec_mul_crt_ntt_i8_single")] +pub(crate) fn mat_vec_mul_crt_ntt_i8_single( + row: &[CyclotomicRing], + vec: &[[i8; D]], +) -> Result, HachiError> { + if row.len() != vec.len() { + 
return Err(HachiError::InvalidInput( + "single i8 mat-vec requires equal row/vector width".to_string(), + )); + } + let params = select_crt_ntt_params::()?; + let out = match ¶ms { + ProtocolCrtNttParams::Q32(p) => mat_vec_mul_i8_single_with_params(row, vec, p), + ProtocolCrtNttParams::Q64(p) => mat_vec_mul_i8_single_with_params(row, vec, p), + ProtocolCrtNttParams::Q128(p) => mat_vec_mul_i8_single_with_params(row, vec, p), + }; + Ok(out) +} + +fn mat_vec_mul_i8_labrador_cross_with_params< + F: FieldCore + CanonicalField, + W: PrimeWidth, + const K: usize, + const D: usize, +>( + mat: &[Vec>], + vecs: &[Vec<[i8; D]>], + params: &CrtNttParamSet, +) -> Vec> { + let r = mat.len(); + let pair_count = r * (r + 1) / 2; + if r == 0 { + return Vec::new(); + } + let width = mat.first().map_or(0, Vec::len); + if width == 0 { + return vec![CyclotomicRing::::zero(); pair_count]; + } + + let ntt_mat = precompute_dense_mat_ntt_with_params(mat, params); + let lut = DigitMontLut::new(params); + let vecs_ntt: Vec>> = cfg_iter!(vecs) + .map(|vec| { + vec.iter() + .map(|digit| CyclotomicCrtNtt::from_i8_with_lut(digit, params, &lut)) + .collect() + }) + .collect(); + + let pairs: Vec<(usize, usize)> = (0..r).flat_map(|i| (i..r).map(move |j| (i, j))).collect(); + cfg_into_iter!(pairs) + .map(|(i, j)| { + let mut acc_ij = CyclotomicCrtNtt::::zero(); + if i == j { + for col in 0..width { + accumulate_pointwise_product_into( + &mut acc_ij, + &ntt_mat[i][col], + &vecs_ntt[i][col], + params, + ); + } + acc_ij.to_ring_with_params(params) + } else { + let mut acc_ji = CyclotomicCrtNtt::::zero(); + for col in 0..width { + accumulate_pointwise_product_into( + &mut acc_ij, + &ntt_mat[i][col], + &vecs_ntt[j][col], + params, + ); + accumulate_pointwise_product_into( + &mut acc_ji, + &ntt_mat[j][col], + &vecs_ntt[i][col], + params, + ); + } + let mut out = acc_ij.to_ring_with_params(params); + out += acc_ji.to_ring_with_params(params); + out + } + }) + .collect() +} + 
+#[tracing::instrument(skip_all, name = "mat_vec_mul_crt_ntt_i8_labrador_cross")] +pub(crate) fn mat_vec_mul_crt_ntt_i8_labrador_cross< + F: FieldCore + CanonicalField, + const D: usize, +>( + mat: &[Vec>], + vecs: &[Vec<[i8; D]>], +) -> Result>, HachiError> { + if mat.len() != vecs.len() { + return Err(HachiError::InvalidInput( + "labrador cross expects mat rows == vec block count".to_string(), + )); + } + let width = mat.first().map_or(0, Vec::len); + if mat.iter().any(|row| row.len() != width) { + return Err(HachiError::InvalidInput( + "labrador cross requires rectangular matrix".to_string(), + )); + } + if vecs.iter().any(|vec| vec.len() != width) { + return Err(HachiError::InvalidInput( + "labrador cross requires vec widths to match matrix width".to_string(), + )); + } + + let params = select_crt_ntt_params::()?; + let out = match ¶ms { + ProtocolCrtNttParams::Q32(p) => mat_vec_mul_i8_labrador_cross_with_params(mat, vecs, p), + ProtocolCrtNttParams::Q64(p) => mat_vec_mul_i8_labrador_cross_with_params(mat, vecs, p), + ProtocolCrtNttParams::Q128(p) => mat_vec_mul_i8_labrador_cross_with_params(mat, vecs, p), + }; + Ok(out) +} + +fn unreduced_quotient_ntt( + ntt_row: &[CyclotomicCrtNtt], + cyc_row: &[CyclotomicCrtNtt], + vec_neg: &[CyclotomicCrtNtt], + vec_cyc: &[CyclotomicCrtNtt], + params: &CrtNttParamSet, +) -> CyclotomicRing +where + F: FieldCore + CanonicalField, + W: PrimeWidth, +{ + let n = ntt_row.len().min(vec_neg.len()); + + let mut acc_neg = CyclotomicCrtNtt::::zero(); + let mut acc_cyc = CyclotomicCrtNtt::::zero(); + + for j in 0..n { + accumulate_pointwise_product_into(&mut acc_neg, &ntt_row[j], &vec_neg[j], params); + accumulate_pointwise_product_into(&mut acc_cyc, &cyc_row[j], &vec_cyc[j], params); + } + + let neg_ring: CyclotomicRing = acc_neg.to_ring_with_params(params); + let cyc_ring: CyclotomicRing = acc_cyc.to_ring_cyclic(params); + + let neg_coeffs = neg_ring.coefficients(); + let cyc_coeffs = cyc_ring.coefficients(); + let quotient: [F; D] = 
from_fn(|k| (cyc_coeffs[k] - neg_coeffs[k]) * F::TWO_INV); + CyclotomicRing::from_coefficients(quotient) +} + +macro_rules! dispatch_slot_quotient { + ($slot:expr, $vec:expr, $convert_neg:ident, $convert_cyc:ident, $quotient_fn:ident) => {{ + match $slot { + NttSlotCache::Q32 { + neg, + cyc, + params: p, + } => { + let v = $vec; + let n = neg.first().map_or(0, |r| r.len().min(v.len())); + let v_neg: Vec<_> = cfg_iter!(v[..n]) + .map(|x| CyclotomicCrtNtt::$convert_neg(x, p)) + .collect(); + let v_cyc: Vec<_> = cfg_iter!(v[..n]) + .map(|x| CyclotomicCrtNtt::$convert_cyc(x, p)) + .collect(); + cfg_into_iter!(0..neg.len()) + .map(|i| $quotient_fn(&neg[i], &cyc[i], &v_neg, &v_cyc, p)) + .collect() + } + NttSlotCache::Q64 { + neg, + cyc, + params: p, + } => { + let v = $vec; + let n = neg.first().map_or(0, |r| r.len().min(v.len())); + let v_neg: Vec<_> = cfg_iter!(v[..n]) + .map(|x| CyclotomicCrtNtt::$convert_neg(x, p)) + .collect(); + let v_cyc: Vec<_> = cfg_iter!(v[..n]) + .map(|x| CyclotomicCrtNtt::$convert_cyc(x, p)) + .collect(); + cfg_into_iter!(0..neg.len()) + .map(|i| $quotient_fn(&neg[i], &cyc[i], &v_neg, &v_cyc, p)) + .collect() + } + NttSlotCache::Q128 { + neg, + cyc, + params: p, + } => { + let v = $vec; + let n = neg.first().map_or(0, |r| r.len().min(v.len())); + let v_neg: Vec<_> = cfg_iter!(v[..n]) + .map(|x| CyclotomicCrtNtt::$convert_neg(x, p)) + .collect(); + let v_cyc: Vec<_> = cfg_iter!(v[..n]) + .map(|x| CyclotomicCrtNtt::$convert_cyc(x, p)) + .collect(); + cfg_into_iter!(0..neg.len()) + .map(|i| $quotient_fn(&neg[i], &cyc[i], &v_neg, &v_cyc, p)) + .collect() + } + } + }}; +} + +/// Compute unreduced quotients for matrix rows against a witness vector. +/// +/// For each row: `r_i = high_part(sum_j row_ij * vec_j) = (cyc - neg) / 2`. +/// Vec NTT conversions and matrix cyclic NTT are precomputed once (not per-row). 
+pub fn unreduced_quotient_rows_ntt_cached( + slot: &NttSlotCache, + vec: &[CyclotomicRing], +) -> Vec> { + dispatch_slot_quotient!( + slot, + vec, + from_ring_with_params, + from_ring_cyclic, + unreduced_quotient_ntt + ) +} + +/// Like [`unreduced_quotient_rows_ntt_cached`] but accepts centered i32 +/// coefficient rows instead of field-backed ring elements. +#[tracing::instrument(skip_all, name = "unreduced_quotient_rows_ntt_cached_centered_i32")] +pub fn unreduced_quotient_rows_ntt_cached_centered_i32< + F: FieldCore + CanonicalField, + const D: usize, +>( + slot: &NttSlotCache, + vec: &[[i32; D]], + max_abs: u32, +) -> Vec> { + match slot { + NttSlotCache::Q32 { + neg, + cyc, + params: p, + } => quotient_single_centered_i32_with_params(neg, cyc, vec, max_abs, p), + NttSlotCache::Q64 { + neg, + cyc, + params: p, + } => quotient_single_centered_i32_with_params(neg, cyc, vec, max_abs, p), + NttSlotCache::Q128 { + neg, + cyc, + params: p, + } => quotient_single_centered_i32_with_params(neg, cyc, vec, max_abs, p), + } +} + +macro_rules! dispatch_slot { + ($slot:expr, $func:ident $(, $arg:expr)*) => {{ + match $slot { + NttSlotCache::Q32 { neg, params: p, .. } => $func(neg, $($arg,)* p), + NttSlotCache::Q64 { neg, params: p, .. } => $func(neg, $($arg,)* p), + NttSlotCache::Q128 { neg, params: p, .. } => $func(neg, $($arg,)* p), + } + }}; +} + +/// Flatten a nested `Vec>` into a contiguous `Vec<[i8; D]>` using +/// bulk memcpy per block, avoiding element-by-element iteration. +pub fn flatten_i8_blocks(blocks: &[Vec<[i8; D]>]) -> Vec<[i8; D]> { + let total: usize = blocks.iter().map(|b| b.len()).sum(); + let mut flat = Vec::with_capacity(total); + for block in blocks { + flat.extend_from_slice(block); + } + flat +} + +/// Basis-decompose a block of ring elements into `block.len() * num_digits` gadget components. 
+pub fn decompose_block( + block: &[CyclotomicRing], + num_digits: usize, + log_basis: u32, +) -> Vec> { + let mut out = vec![CyclotomicRing::::zero(); block.len() * num_digits]; + for (i, coeff_vec) in block.iter().enumerate() { + coeff_vec.balanced_decompose_pow2_into( + &mut out[i * num_digits..(i + 1) * num_digits], + log_basis, + ); + } + out +} + +/// Decompose each ring element in `rows` into `num_digits` gadget components. +pub fn decompose_rows( + rows: &[CyclotomicRing], + num_digits: usize, + log_basis: u32, +) -> Vec> { + let mut out = vec![CyclotomicRing::::zero(); rows.len() * num_digits]; + for (i, row) in rows.iter().enumerate() { + row.balanced_decompose_pow2_into(&mut out[i * num_digits..(i + 1) * num_digits], log_basis); + } + out +} + +/// Decompose each ring element where the last digit carries the remainder. +/// +/// # Panics +/// +/// Panics if `delta == 0`. +pub fn decompose_rows_with_carry( + rows: &[CyclotomicRing], + delta: usize, + log_basis: u32, +) -> Vec> { + if rows.is_empty() { + return Vec::new(); + } + assert!(delta > 0, "levels must be positive"); + + let mut out = vec![CyclotomicRing::::zero(); rows.len() * delta]; + + #[cfg(feature = "parallel")] + out.par_chunks_mut(delta) + .zip(rows.par_iter()) + .for_each(|(dst_chunk, row)| { + row.balanced_decompose_pow2_with_carry_into(dst_chunk, log_basis) + }); + + #[cfg(not(feature = "parallel"))] + out.chunks_mut(delta) + .zip(rows.iter()) + .for_each(|(dst_chunk, row)| { + row.balanced_decompose_pow2_with_carry_into(dst_chunk, log_basis) + }); + + out +} + +/// Like [`decompose_block`] but outputs `[i8; D]` digit planes instead of ring elements. 
+pub fn decompose_block_i8( + block: &[CyclotomicRing], + num_digits: usize, + log_basis: u32, +) -> Vec<[i8; D]> { + let mut out = Vec::with_capacity(block.len() * num_digits); + for coeff_vec in block { + out.extend(coeff_vec.balanced_decompose_pow2_i8(num_digits, log_basis)); + } + out +} + +/// Like [`decompose_rows`] but outputs `[i8; D]` digit planes instead of ring elements. +pub fn decompose_rows_i8( + rows: &[CyclotomicRing], + num_digits: usize, + log_basis: u32, +) -> Vec<[i8; D]> { + let mut out = Vec::with_capacity(rows.len() * num_digits); + for row in rows { + out.extend(row.balanced_decompose_pow2_i8(num_digits, log_basis)); + } + out +} + +#[inline] +fn is_zero_plane(plane: &[i8; D]) -> bool { + plane.iter().all(|&d| d == 0) +} + +#[inline] +fn is_zero_centered_row(row: &[i32; D]) -> bool { + row.iter().all(|&d| d == 0) +} + +#[cfg(target_arch = "aarch64")] +const TARGET_L2_CACHE_BYTES: usize = 4 * 1024 * 1024; +#[cfg(target_arch = "x86_64")] +const TARGET_L2_CACHE_BYTES: usize = 1024 * 1024; +#[cfg(not(any(target_arch = "aarch64", target_arch = "x86_64")))] +const TARGET_L2_CACHE_BYTES: usize = 1024 * 1024; +const CENTERED_LUT_MAX_ABS: u32 = (1 << 16) - 1; + +#[inline] +#[allow(dead_code)] +fn add_ntt_into( + acc: &mut CyclotomicCrtNtt, + other: &CyclotomicCrtNtt, + params: &CrtNttParamSet, +) { + #[cfg(target_arch = "aarch64")] + if neon::use_neon_ntt() { + for k in 0..K { + let prime = params.primes[k]; + unsafe { + if size_of::() == size_of::() { + neon::add_reduce_i32( + acc.limbs[k].as_mut_ptr() as *mut i32, + other.limbs[k].as_ptr() as *const i32, + D, + prime.p.to_i64() as i32, + ); + } else { + neon::add_reduce_i16( + acc.limbs[k].as_mut_ptr() as *mut i16, + other.limbs[k].as_ptr() as *const i16, + D, + prime.p.to_i64() as i16, + ); + } + } + } + return; + } + + for k in 0..K { + let prime = params.primes[k]; + for d in 0..D { + let sum = + MontCoeff::from_raw(acc.limbs[k][d].raw().wrapping_add(other.limbs[k][d].raw())); + acc.limbs[k][d] 
= prime.reduce_range(sum); + } + } +} + +/// Column-tiled A*x across multiple blocks simultaneously. +/// +/// Each rayon thread owns one column tile of `ntt_mat` (sized to fit in L2 +/// cache) and iterates over all blocks, accumulating partial NTT results. +/// The matrix is loaded from DRAM exactly once. A final reduction sums +/// partial accumulators across tiles for each block. +/// +/// Accepts raw ring-coefficient slices per block. Decomposes to i8 digits +/// on-the-fly per tile to avoid materializing all digits at once. +/// Tile width is auto-computed from ring parameters and target L2 cache size. +#[tracing::instrument(skip_all, name = "mat_vec_mul_ntt_i8")] +pub fn mat_vec_mul_ntt_i8( + slot: &NttSlotCache, + blocks: &[&[CyclotomicRing]], + num_digits: usize, + log_basis: u32, +) -> Vec>> { + dispatch_slot!( + slot, + mat_vec_mul_i8_with_params, + blocks, + num_digits, + log_basis + ) +} + +/// Column-tiled A*x across multiple blocks of pre-decomposed i8 digit planes. +/// +/// This is the `num_digits_commit = 1` specialization of +/// [`mat_vec_mul_ntt_i8`]. It skips the `CyclotomicRing -> i8 digit plane` +/// decomposition entirely because the caller already holds each coefficient as a +/// balanced digit plane. 
+#[tracing::instrument(skip_all, name = "mat_vec_mul_ntt_digits_i8")] +pub fn mat_vec_mul_ntt_digits_i8( + slot: &NttSlotCache, + blocks: &[&[[i8; D]]], +) -> Vec>> { + dispatch_slot!(slot, mat_vec_mul_digits_i8_with_params, blocks) +} + +fn mat_vec_mul_digits_i8_with_params< + F: FieldCore + CanonicalField, + W: PrimeWidth, + const K: usize, + const D: usize, +>( + ntt_mat: &[Vec>], + blocks: &[&[[i8; D]]], + params: &CrtNttParamSet, +) -> Vec>> { + let num_blocks = blocks.len(); + if num_blocks == 0 { + return vec![]; + } + let n_a = ntt_mat.len(); + let inner_width = ntt_mat.first().map_or(0, |row| row.len()); + if inner_width == 0 || n_a == 0 { + return vec![vec![CyclotomicRing::::zero(); n_a]; num_blocks]; + } + + let lut = DigitMontLut::new(params); + let tw = (TARGET_L2_CACHE_BYTES / (K * D * size_of::())).max(1); + let num_tiles = inner_width.div_ceil(tw); + + let final_accs: Vec>> = cfg_fold_reduce!( + 0..num_tiles, + || vec![vec![CyclotomicCrtNtt::::zero(); n_a]; num_blocks], + |mut accs: Vec>>, tile_idx| { + let tile_start = tile_idx * tw; + let tile_end = (tile_start + tw).min(inner_width); + + for block_idx in 0..num_blocks { + let block = blocks[block_idx]; + if tile_start >= block.len() { + continue; + } + let block_tile_end = tile_end.min(block.len()); + for (j, digit) in block[tile_start..block_tile_end].iter().enumerate() { + if is_zero_plane(digit) { + continue; + } + let ntt_d = CyclotomicCrtNtt::from_i8_with_lut(digit, params, &lut); + for (acc, mat_row) in accs[block_idx].iter_mut().zip(ntt_mat.iter()) { + accumulate_pointwise_product_into( + acc, + &mat_row[tile_start + j], + &ntt_d, + params, + ); + } + } + } + accs + }, + |mut a: Vec>>, b| { + for block_idx in 0..num_blocks { + for row in 0..n_a { + add_ntt_into(&mut a[block_idx][row], &b[block_idx][row], params); + } + } + a + } + ); + + cfg_into_iter!(final_accs) + .map(|row_accs| { + row_accs + .into_iter() + .map(|acc| acc.to_ring_with_params(params)) + .collect() + }) + .collect() +} + 
+fn mat_vec_mul_i8_with_params< + F: FieldCore + CanonicalField, + W: PrimeWidth, + const K: usize, + const D: usize, +>( + ntt_mat: &[Vec>], + blocks: &[&[CyclotomicRing]], + num_digits: usize, + log_basis: u32, + params: &CrtNttParamSet, +) -> Vec>> { + let num_blocks = blocks.len(); + if num_blocks == 0 { + return vec![]; + } + let n_a = ntt_mat.len(); + let inner_width = ntt_mat.first().map_or(0, |row| row.len()); + if inner_width == 0 || n_a == 0 { + return vec![vec![CyclotomicRing::::zero(); n_a]; num_blocks]; + } + + let lut = DigitMontLut::new(params); + let tw = (TARGET_L2_CACHE_BYTES / (K * D * size_of::())).max(1); + let num_tiles = inner_width.div_ceil(tw); + + let final_accs: Vec>> = cfg_fold_reduce!( + 0..num_tiles, + || vec![vec![CyclotomicCrtNtt::::zero(); n_a]; num_blocks], + |mut accs: Vec>>, tile_idx| { + let tile_start = tile_idx * tw; + let tile_end = (tile_start + tw).min(inner_width); + let ring_start = tile_start / num_digits; + let ring_end = ((tile_end - 1) / num_digits) + 1; + let digit_offset = tile_start - ring_start * num_digits; + let tile_len = tile_end - tile_start; + + for block_idx in 0..num_blocks { + let block = blocks[block_idx]; + if ring_start >= block.len() { + continue; + } + let block_ring_end = ring_end.min(block.len()); + let partial_coeffs = &block[ring_start..block_ring_end]; + let all_digits = decompose_block_i8(partial_coeffs, num_digits, log_basis); + let available = all_digits.len().saturating_sub(digit_offset); + let n = tile_len.min(available); + + for (j, digit) in all_digits[digit_offset..digit_offset + n] + .iter() + .enumerate() + { + if is_zero_plane(digit) { + continue; + } + let ntt_d = CyclotomicCrtNtt::from_i8_with_lut(digit, params, &lut); + for (acc, mat_row) in accs[block_idx].iter_mut().zip(ntt_mat.iter()) { + accumulate_pointwise_product_into( + acc, + &mat_row[tile_start + j], + &ntt_d, + params, + ); + } + } + } + accs + }, + |mut a: Vec>>, b| { + for block_idx in 0..num_blocks { + for row in 
0..n_a { + add_ntt_into(&mut a[block_idx][row], &b[block_idx][row], params); + } + } + a + } + ); + + cfg_into_iter!(final_accs) + .map(|row_accs| { + row_accs + .into_iter() + .map(|acc| acc.to_ring_with_params(params)) + .collect() + }) + .collect() +} + +/// Column-tiled mat-vec for a single pre-decomposed i8 digit vector. +/// +/// Same tiling strategy as [`mat_vec_mul_ntt_i8`] but for a single +/// input vector of i8 digit planes (already decomposed). Tiles the matrix +/// columns to keep each tile in L2, eliminating the full `ntt_vec` +/// materialization of the non-tiled path. +/// Tile width is auto-computed from ring parameters and target L2 cache size. +#[tracing::instrument(skip_all, name = "mat_vec_mul_ntt_single_i8")] +pub fn mat_vec_mul_ntt_single_i8( + slot: &NttSlotCache, + vec: &[[i8; D]], +) -> Vec> { + match slot { + NttSlotCache::Q32 { neg, params: p, .. } => mat_vec_mul_single_i8_with_params(neg, vec, p), + NttSlotCache::Q64 { neg, params: p, .. } => mat_vec_mul_single_i8_with_params(neg, vec, p), + NttSlotCache::Q128 { neg, params: p, .. } => mat_vec_mul_single_i8_with_params(neg, vec, p), + } +} + +/// Cyclic-domain variant of [`mat_vec_mul_ntt_single_i8`]. +#[tracing::instrument(skip_all, name = "mat_vec_mul_ntt_single_i8_cyclic")] +pub fn mat_vec_mul_ntt_single_i8_cyclic( + slot: &NttSlotCache, + vec: &[[i8; D]], +) -> Vec> { + match slot { + NttSlotCache::Q32 { cyc, params: p, .. } => { + mat_vec_mul_single_i8_cyclic_with_params(cyc, vec, p) + } + NttSlotCache::Q64 { cyc, params: p, .. } => { + mat_vec_mul_single_i8_cyclic_with_params(cyc, vec, p) + } + NttSlotCache::Q128 { cyc, params: p, .. 
} => { + mat_vec_mul_single_i8_cyclic_with_params(cyc, vec, p) + } + } +} + +fn mat_vec_mul_single_i8_with_params< + F: FieldCore + CanonicalField, + W: PrimeWidth, + const K: usize, + const D: usize, +>( + ntt_mat: &[Vec>], + vec: &[[i8; D]], + params: &CrtNttParamSet, +) -> Vec> { + let n_a = ntt_mat.len(); + let inner_width = ntt_mat.first().map_or(0, |row| row.len()); + if inner_width == 0 || n_a == 0 { + return vec![CyclotomicRing::::zero(); n_a]; + } + + let lut = DigitMontLut::new(params); + let vec_len = vec.len().min(inner_width); + let tw = (TARGET_L2_CACHE_BYTES / (K * D * size_of::())).max(1); + let num_tiles = vec_len.div_ceil(tw); + + let final_accs: Vec> = cfg_fold_reduce!( + 0..num_tiles, + || vec![CyclotomicCrtNtt::::zero(); n_a], + |mut accs: Vec>, tile_idx| { + let tile_start = tile_idx * tw; + let tile_end = (tile_start + tw).min(vec_len); + for (j, digit) in vec[tile_start..tile_end].iter().enumerate() { + if is_zero_plane(digit) { + continue; + } + let ntt_d = CyclotomicCrtNtt::from_i8_with_lut(digit, params, &lut); + for (acc, mat_row) in accs.iter_mut().zip(ntt_mat.iter()) { + accumulate_pointwise_product_into( + acc, + &mat_row[tile_start + j], + &ntt_d, + params, + ); + } + } + accs + }, + |mut a: Vec>, b| { + for row in 0..n_a { + add_ntt_into(&mut a[row], &b[row], params); + } + a + } + ); + + final_accs + .into_iter() + .map(|acc| acc.to_ring_with_params(params)) + .collect() +} + +fn mat_vec_mul_single_i8_cyclic_with_params< + F: FieldCore + CanonicalField, + W: PrimeWidth, + const K: usize, + const D: usize, +>( + ntt_mat: &[Vec>], + vec: &[[i8; D]], + params: &CrtNttParamSet, +) -> Vec> { + let n_a = ntt_mat.len(); + let inner_width = ntt_mat.first().map_or(0, |row| row.len()); + if inner_width == 0 || n_a == 0 { + return vec![CyclotomicRing::::zero(); n_a]; + } + + let lut = DigitMontLut::new(params); + let vec_len = vec.len().min(inner_width); + let tw = (TARGET_L2_CACHE_BYTES / (K * D * size_of::())).max(1); + let num_tiles = 
vec_len.div_ceil(tw); + + let final_accs: Vec> = cfg_fold_reduce!( + 0..num_tiles, + || vec![CyclotomicCrtNtt::::zero(); n_a], + |mut accs: Vec>, tile_idx| { + let tile_start = tile_idx * tw; + let tile_end = (tile_start + tw).min(vec_len); + for (j, digit) in vec[tile_start..tile_end].iter().enumerate() { + if is_zero_plane(digit) { + continue; + } + let ntt_d = CyclotomicCrtNtt::from_i8_cyclic_with_lut(digit, params, &lut); + for (acc, mat_row) in accs.iter_mut().zip(ntt_mat.iter()) { + accumulate_pointwise_product_into( + acc, + &mat_row[tile_start + j], + &ntt_d, + params, + ); + } + } + accs + }, + |mut a: Vec>, b| { + for row in 0..n_a { + add_ntt_into(&mut a[row], &b[row], params); + } + a + } + ); + + final_accs + .into_iter() + .map(|acc| acc.to_ring_cyclic(params)) + .collect() +} + +/// Like [`unreduced_quotient_rows_ntt_cached`] but accepts i8 digit planes +/// instead of ring elements, using direct i8 -> CRT+NTT conversion. +/// Column-tiled with zero-skip for all-zero digit planes. 
+#[tracing::instrument(skip_all, name = "unreduced_quotient_rows_ntt_cached_i8")] +pub fn unreduced_quotient_rows_ntt_cached_i8( + slot: &NttSlotCache, + vec: &[[i8; D]], +) -> Vec> { + match slot { + NttSlotCache::Q32 { + neg, + cyc, + params: p, + } => quotient_single_i8_with_params(neg, cyc, vec, p), + NttSlotCache::Q64 { + neg, + cyc, + params: p, + } => quotient_single_i8_with_params(neg, cyc, vec, p), + NttSlotCache::Q128 { + neg, + cyc, + params: p, + } => quotient_single_i8_with_params(neg, cyc, vec, p), + } +} + +fn quotient_single_i8_with_params< + F: FieldCore + CanonicalField, + W: PrimeWidth, + const K: usize, + const D: usize, +>( + ntt_neg: &[Vec>], + ntt_cyc: &[Vec>], + vec: &[[i8; D]], + params: &CrtNttParamSet, +) -> Vec> { + let n_a = ntt_neg.len(); + let inner_width = ntt_neg.first().map_or(0, |row| row.len()); + if inner_width == 0 || n_a == 0 { + return vec![CyclotomicRing::::zero(); n_a]; + } + + let lut = DigitMontLut::new(params); + let vec_len = vec.len().min(inner_width); + let tw = (TARGET_L2_CACHE_BYTES / (K * D * size_of::())).max(1); + let num_tiles = vec_len.div_ceil(tw); + + let zero = CyclotomicCrtNtt::::zero(); + + let (final_neg, final_cyc): ( + Vec>, + Vec>, + ) = cfg_fold_reduce!( + 0..num_tiles, + || (vec![zero.clone(); n_a], vec![zero.clone(); n_a]), + |mut accs: ( + Vec>, + Vec> + ), + tile_idx| { + let tile_start = tile_idx * tw; + let tile_end = (tile_start + tw).min(vec_len); + for (j, digit) in vec[tile_start..tile_end].iter().enumerate() { + if is_zero_plane(digit) { + continue; + } + let ntt_d_neg = CyclotomicCrtNtt::from_i8_with_lut(digit, params, &lut); + let ntt_d_cyc = CyclotomicCrtNtt::from_i8_cyclic_with_lut(digit, params, &lut); + let col = tile_start + j; + for (row, (acc_neg, acc_cyc)) in + accs.0.iter_mut().zip(accs.1.iter_mut()).enumerate() + { + accumulate_pointwise_product_into( + acc_neg, + &ntt_neg[row][col], + &ntt_d_neg, + params, + ); + accumulate_pointwise_product_into( + acc_cyc, + 
&ntt_cyc[row][col], + &ntt_d_cyc, + params, + ); + } + } + accs + }, + |mut a: ( + Vec>, + Vec> + ), + b| { + for row in 0..n_a { + add_ntt_into(&mut a.0[row], &b.0[row], params); + add_ntt_into(&mut a.1[row], &b.1[row], params); + } + a + } + ); + + final_neg + .into_iter() + .zip(final_cyc) + .map(|(neg_acc, cyc_acc)| { + let neg_ring: CyclotomicRing = neg_acc.to_ring_with_params(params); + let cyc_ring: CyclotomicRing = cyc_acc.to_ring_cyclic(params); + let neg_c = neg_ring.coefficients(); + let cyc_c = cyc_ring.coefficients(); + let q: [F; D] = from_fn(|k| (cyc_c[k] - neg_c[k]) * F::TWO_INV); + CyclotomicRing::from_coefficients(q) + }) + .collect() +} + +fn quotient_single_centered_i32_with_params< + F: FieldCore + CanonicalField, + W: PrimeWidth, + const K: usize, + const D: usize, +>( + ntt_neg: &[Vec>], + ntt_cyc: &[Vec>], + vec: &[[i32; D]], + max_abs: u32, + params: &CrtNttParamSet, +) -> Vec> { + let n_a = ntt_neg.len(); + let inner_width = ntt_neg.first().map_or(0, |row| row.len()); + if inner_width == 0 || n_a == 0 { + return vec![CyclotomicRing::::zero(); n_a]; + } + + let vec_len = vec.len().min(inner_width); + let tw = (TARGET_L2_CACHE_BYTES / (K * D * size_of::())).max(1); + let num_tiles = vec_len.div_ceil(tw); + let zero = CyclotomicCrtNtt::::zero(); + let centered_lut = (max_abs <= CENTERED_LUT_MAX_ABS) + .then(|| CenteredMontLut::::new(params, max_abs as i32)); + + let (final_neg, final_cyc): ( + Vec>, + Vec>, + ) = cfg_fold_reduce!( + 0..num_tiles, + || (vec![zero.clone(); n_a], vec![zero.clone(); n_a]), + |mut accs: ( + Vec>, + Vec> + ), + tile_idx| { + let tile_start = tile_idx * tw; + let tile_end = (tile_start + tw).min(vec_len); + for (j, coeffs) in vec[tile_start..tile_end].iter().enumerate() { + if is_zero_centered_row(coeffs) { + continue; + } + let (ntt_d_neg, ntt_d_cyc) = if let Some(lut) = centered_lut.as_ref() { + CyclotomicCrtNtt::from_centered_i32_pair_with_lut(coeffs, params, lut) + } else { + 
CyclotomicCrtNtt::from_centered_i32_pair_with_params(coeffs, params) + }; + let col = tile_start + j; + for (row, (acc_neg, acc_cyc)) in + accs.0.iter_mut().zip(accs.1.iter_mut()).enumerate() + { + accumulate_pointwise_product_into( + acc_neg, + &ntt_neg[row][col], + &ntt_d_neg, + params, + ); + accumulate_pointwise_product_into( + acc_cyc, + &ntt_cyc[row][col], + &ntt_d_cyc, + params, + ); + } + } + accs + }, + |mut a: ( + Vec>, + Vec> + ), + b| { + for row in 0..n_a { + add_ntt_into(&mut a.0[row], &b.0[row], params); + add_ntt_into(&mut a.1[row], &b.1[row], params); + } + a + } + ); + + final_neg + .into_iter() + .zip(final_cyc) + .map(|(neg_acc, cyc_acc)| { + let neg_ring: CyclotomicRing = neg_acc.to_ring_with_params(params); + let cyc_ring: CyclotomicRing = cyc_acc.to_ring_cyclic(params); + let neg_c = neg_ring.coefficients(); + let cyc_c = cyc_ring.coefficients(); + let q: [F; D] = from_fn(|k| (cyc_c[k] - neg_c[k]) * F::TWO_INV); + CyclotomicRing::from_coefficients(q) + }) + .collect() +} + +#[cfg(test)] +mod tests { + use super::{ + mat_vec_mul_crt_ntt, mat_vec_mul_crt_ntt_i8_labrador_cross, mat_vec_mul_crt_ntt_i8_many, + mat_vec_mul_crt_ntt_i8_single, mat_vec_mul_crt_ntt_many, mat_vec_mul_digits_i8_with_params, + mat_vec_mul_i8_with_params, mat_vec_mul_unchecked, precompute_dense_mat_ntt_with_params, + }; + use crate::algebra::{CyclotomicRing, Fp64}; + use crate::protocol::commitment::utils::crt_ntt::{ + select_crt_ntt_params, ProtocolCrtNttParams, + }; + use crate::FromSmallInt; + + #[test] + fn dense_mat_vec_matches_schoolbook_q32_d64() { + type F = Fp64<4294967197>; + const D: usize = 64; + let mat: Vec>> = (0..3) + .map(|i| { + (0..4) + .map(|j| { + let coeffs = std::array::from_fn(|k| { + F::from_u64((i as u64 * 10_000 + j as u64 * 100 + k as u64 + 1) % 97) + }); + CyclotomicRing::from_coefficients(coeffs) + }) + .collect() + }) + .collect(); + let vec: Vec> = (0..4) + .map(|j| { + let coeffs = + std::array::from_fn(|k| F::from_u64((j as u64 * 50 + k as 
u64 + 3) % 89)); + CyclotomicRing::from_coefficients(coeffs) + }) + .collect(); + + let schoolbook = mat_vec_mul_unchecked(&mat, &vec); + let crt_ntt = mat_vec_mul_crt_ntt(&mat, &vec).expect("Q32 dispatch should succeed"); + assert_eq!(schoolbook, crt_ntt); + } + + #[test] + fn dense_mat_vec_matches_schoolbook_q64_dispatch_for_large_d() { + type F = Fp64<4294967197>; + const D: usize = 128; + let mat: Vec>> = (0..2) + .map(|i| { + (0..2) + .map(|j| { + let coeffs = std::array::from_fn(|k| { + F::from_u64((i as u64 * 20_000 + j as u64 * 300 + k as u64 + 7) % 113) + }); + CyclotomicRing::from_coefficients(coeffs) + }) + .collect() + }) + .collect(); + let vec: Vec> = (0..2) + .map(|j| { + let coeffs = + std::array::from_fn(|k| F::from_u64((j as u64 * 70 + k as u64 + 11) % 101)); + CyclotomicRing::from_coefficients(coeffs) + }) + .collect(); + + let schoolbook = mat_vec_mul_unchecked(&mat, &vec); + let crt_ntt = mat_vec_mul_crt_ntt(&mat, &vec).expect("Q64 dispatch should succeed"); + assert_eq!(schoolbook, crt_ntt); + } + + #[test] + fn dense_mat_vec_many_matches_individual_crt_ntt_q32_d64() { + type F = Fp64<4294967197>; + const D: usize = 64; + let mat: Vec>> = (0..3) + .map(|i| { + (0..4) + .map(|j| { + let coeffs = std::array::from_fn(|k| { + F::from_u64((i as u64 * 10_000 + j as u64 * 100 + k as u64 + 1) % 97) + }); + CyclotomicRing::from_coefficients(coeffs) + }) + .collect() + }) + .collect(); + + let vecs: Vec>> = (0..3) + .map(|seed| { + (0..4) + .map(|j| { + let coeffs = std::array::from_fn(|k| { + F::from_u64((seed as u64 * 700 + j as u64 * 50 + k as u64 + 3) % 89) + }); + CyclotomicRing::from_coefficients(coeffs) + }) + .collect() + }) + .collect(); + + let expected: Vec>> = vecs + .iter() + .map(|v| mat_vec_mul_crt_ntt(&mat, v).expect("single CRT+NTT mat-vec should succeed")) + .collect(); + + let got = + mat_vec_mul_crt_ntt_many(&mat, &vecs).expect("batched CRT+NTT mat-vec should succeed"); + assert_eq!(expected, got); + } + + #[test] + fn 
labrador_cross_i8_matches_generic_i8_many_pack_q32_d64() { + type F = Fp64<4294967197>; + const D: usize = 64; + let r = 4usize; + let w = 6usize; + + let mat: Vec>> = (0..r) + .map(|i| { + (0..w) + .map(|j| { + CyclotomicRing::from_coefficients(std::array::from_fn(|k| { + let raw = ((i as i64 * 13 + j as i64 * 7 + k as i64) % 9) - 4; + F::from_i64(raw) + })) + }) + .collect() + }) + .collect(); + let vecs: Vec> = (0..r) + .map(|i| { + (0..w) + .map(|j| std::array::from_fn(|k| ((i + j + k) % 7) as i8 - 3)) + .collect() + }) + .collect(); + + let cross = + mat_vec_mul_crt_ntt_i8_many(&mat, &vecs).expect("generic i8 CRT+NTT should succeed"); + let mut expected = Vec::with_capacity(r * (r + 1) / 2); + for i in 0..r { + expected.push(cross[i][i]); + for j in i + 1..r { + expected.push(cross[i][j] + cross[j][i]); + } + } + + let got = mat_vec_mul_crt_ntt_i8_labrador_cross(&mat, &vecs) + .expect("labrador cross i8 kernel should succeed"); + assert_eq!(got, expected); + } + + #[test] + fn single_i8_matvec_matches_generic_i8_many_q32_d64() { + type F = Fp64<4294967197>; + const D: usize = 64; + let width = 9usize; + let row: Vec> = (0..width) + .map(|j| { + CyclotomicRing::from_coefficients(std::array::from_fn(|k| { + let raw = ((11 * j as i64 + 5 * k as i64) % 11) - 5; + F::from_i64(raw) + })) + }) + .collect(); + let vec: Vec<[i8; D]> = (0..width) + .map(|j| std::array::from_fn(|k| ((3 * j + k) % 9) as i8 - 4)) + .collect(); + + let expected = + mat_vec_mul_crt_ntt_i8_many(std::slice::from_ref(&row), std::slice::from_ref(&vec)) + .expect("generic i8-many mat-vec should succeed")[0][0]; + let got = + mat_vec_mul_crt_ntt_i8_single(&row, &vec).expect("single i8 mat-vec should succeed"); + assert_eq!(got, expected); + } + + #[test] + fn mat_vec_mul_digits_i8_matches_num_digits_one_roundtrip() { + type F = Fp64<4294967197>; + const D: usize = 64; + let log_basis = 3; + + let mat: Vec>> = (0..3) + .map(|i| { + (0..6) + .map(|j| { + let coeffs = std::array::from_fn(|k| { + let 
raw = (i as i64 * 19 + j as i64 * 7 + k as i64) % 7; + F::from_i64(raw - 3) + }); + CyclotomicRing::from_coefficients(coeffs) + }) + .collect() + }) + .collect(); + + let digit_blocks: Vec> = vec![ + (0..6) + .map(|j| std::array::from_fn(|k| ((j + 2 * k) % 7) as i8 - 3)) + .collect(), + (0..4) + .map(|j| std::array::from_fn(|k| ((2 * j + k) % 7) as i8 - 3)) + .collect(), + vec![], + ]; + + let ring_blocks: Vec>> = digit_blocks + .iter() + .map(|block| { + block + .iter() + .map(|digit| { + let coeffs = std::array::from_fn(|k| F::from_i64(digit[k] as i64)); + CyclotomicRing::from_coefficients(coeffs) + }) + .collect() + }) + .collect(); + + let ring_block_slices: Vec<&[CyclotomicRing]> = + ring_blocks.iter().map(Vec::as_slice).collect(); + let digit_block_slices: Vec<&[[i8; D]]> = digit_blocks.iter().map(Vec::as_slice).collect(); + + match select_crt_ntt_params::().expect("CRT+NTT params should exist") { + ProtocolCrtNttParams::Q32(params) => { + let ntt_mat = precompute_dense_mat_ntt_with_params(&mat, ¶ms); + let via_roundtrip = + mat_vec_mul_i8_with_params(&ntt_mat, &ring_block_slices, 1, log_basis, ¶ms); + let direct = + mat_vec_mul_digits_i8_with_params(&ntt_mat, &digit_block_slices, ¶ms); + assert_eq!(via_roundtrip, direct); + } + _ => panic!("unexpected parameter family"), + } + } +} diff --git a/src/protocol/commitment/utils/math.rs b/src/protocol/commitment/utils/math.rs new file mode 100644 index 00000000..16cf8618 --- /dev/null +++ b/src/protocol/commitment/utils/math.rs @@ -0,0 +1,14 @@ +//! Small math helpers for commitment internals. + +use crate::error::HachiError; + +/// Compute `2^exp` with overflow checks. +/// +/// # Errors +/// +/// Returns `InvalidSetup` if `2^exp` does not fit in `usize`. 
+pub(in crate::protocol::commitment) fn checked_pow2(exp: usize) -> Result { + 1usize + .checked_shl(exp as u32) + .ok_or_else(|| HachiError::InvalidSetup(format!("2^{exp} does not fit usize"))) +} diff --git a/src/protocol/commitment/utils/matrix.rs b/src/protocol/commitment/utils/matrix.rs new file mode 100644 index 00000000..9be13b11 --- /dev/null +++ b/src/protocol/commitment/utils/matrix.rs @@ -0,0 +1,134 @@ +//! Matrix sampling helpers for setup. + +use crate::algebra::ring::CyclotomicRing; +use crate::{FieldCore, FieldSampling}; +use rand_core::{CryptoRng, RngCore}; +use sha3::digest::{ExtendableOutput, Update, XofReader}; +use sha3::Shake256; + +/// Public seed used to derive commitment matrices. +pub(crate) type PublicMatrixSeed = [u8; 32]; + +const PUBLIC_MATRIX_DOMAIN: &[u8] = b"hachi/commitment/public-matrix"; + +/// Fixed public seed for deterministic, reproducible setup. +pub(crate) fn sample_public_matrix_seed() -> PublicMatrixSeed { + let mut seed = [0u8; 32]; + seed[..8].copy_from_slice(&0xDEAD_BEEF_CAFE_BABEu64.to_le_bytes()); + seed +} + +/// Derive a public matrix from a seed using domain-separated SHAKE expansion. +/// +/// This follows the same high-level pattern used in NIST lattice specs: +/// derive deterministic public structure from a seed + indices, then sample +/// coefficients via rejection-sampling at the field layer. +/// +/// NOTE: Potential future hardening: +/// move toward stricter ML-KEM/ML-DSA-style byte layout and parsing rules +/// (fixed-format seed/index encoding and scheme-specific expansion details) +/// if we decide to maximize standards-shape interoperability. 
+pub(crate) fn derive_public_matrix( + rows: usize, + cols: usize, + seed: &PublicMatrixSeed, + matrix_label: &[u8], +) -> Vec>> { + (0..rows) + .map(|r| { + (0..cols) + .map(|c| { + let mut entry_rng = ShakeXofRng::new(seed, matrix_label, rows, cols, r, c); + CyclotomicRing::random(&mut entry_rng) + }) + .collect() + }) + .collect() +} + +struct ShakeXofRng { + reader: Box, +} + +impl ShakeXofRng { + // Dimensions (`rows`, `cols`) are intentionally excluded from the domain + // separator so that a matrix derived at one size is a prefix of the same + // matrix derived at a larger size. Each entry is uniquely identified by + // `(seed, matrix_label, row, col)`, which is sufficient for collision + // resistance while enabling setup reuse across poly/mega-poly layouts. + fn new( + seed: &PublicMatrixSeed, + matrix_label: &[u8], + _rows: usize, + _cols: usize, + row: usize, + col: usize, + ) -> Self { + let mut xof = Shake256::default(); + absorb_len_prefixed(&mut xof, b"domain", PUBLIC_MATRIX_DOMAIN); + absorb_len_prefixed(&mut xof, b"seed", seed); + absorb_len_prefixed(&mut xof, b"matrix", matrix_label); + absorb_len_prefixed(&mut xof, b"row", &(row as u64).to_le_bytes()); + absorb_len_prefixed(&mut xof, b"col", &(col as u64).to_le_bytes()); + Self { + reader: Box::new(xof.finalize_xof()), + } + } +} + +impl RngCore for ShakeXofRng { + fn next_u32(&mut self) -> u32 { + let mut buf = [0u8; 4]; + self.fill_bytes(&mut buf); + u32::from_le_bytes(buf) + } + + fn next_u64(&mut self) -> u64 { + let mut buf = [0u8; 8]; + self.fill_bytes(&mut buf); + u64::from_le_bytes(buf) + } + + fn fill_bytes(&mut self, dest: &mut [u8]) { + self.reader.read(dest); + } + + fn try_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), rand_core::Error> { + self.fill_bytes(dest); + Ok(()) + } +} + +impl CryptoRng for ShakeXofRng {} + +fn absorb_len_prefixed(xof: &mut Shake256, label: &[u8], data: &[u8]) { + xof.update(&(label.len() as u64).to_le_bytes()); + xof.update(label); + 
xof.update(&(data.len() as u64).to_le_bytes()); + xof.update(data); +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::algebra::Fp64; + + type F = Fp64<4294967197>; + const D: usize = 64; + + #[test] + fn matrix_derivation_is_deterministic_for_same_seed() { + let seed = [42u8; 32]; + let m1 = derive_public_matrix::(3, 5, &seed, b"A"); + let m2 = derive_public_matrix::(3, 5, &seed, b"A"); + assert_eq!(m1, m2); + } + + #[test] + fn matrix_derivation_domain_separates_labels() { + let seed = [7u8; 32]; + let a = derive_public_matrix::(2, 3, &seed, b"A"); + let b = derive_public_matrix::(2, 3, &seed, b"B"); + assert_ne!(a, b); + } +} diff --git a/src/protocol/commitment/utils/mod.rs b/src/protocol/commitment/utils/mod.rs new file mode 100644 index 00000000..6e768fb2 --- /dev/null +++ b/src/protocol/commitment/utils/mod.rs @@ -0,0 +1,9 @@ +//! Utility helpers for commitment internals. + +pub mod crt_ntt; +pub mod flat_matrix; +pub mod linear; +pub(crate) mod math; +pub(crate) mod matrix; +pub(crate) mod norm; +pub mod ntt_cache; diff --git a/src/protocol/commitment/utils/norm.rs b/src/protocol/commitment/utils/norm.rs new file mode 100644 index 00000000..5a2d4030 --- /dev/null +++ b/src/protocol/commitment/utils/norm.rs @@ -0,0 +1,51 @@ +//! Infinity norm utilities for ring elements over Z_q. + +use crate::algebra::ring::CyclotomicRing; +use crate::CanonicalField; + +/// Detect the field modulus from the canonical representation. +/// +/// Uses the identity: the canonical form of `−1` in `Z_q` is `q − 1`. +pub(crate) fn detect_field_modulus() -> u128 { + (-F::one()).to_canonical_u128() + 1 +} + +/// Centered absolute value of a field element. +/// +/// Maps canonical representation `v ∈ [0, q)` to `min(v, q − v)`. 
+#[inline] +#[allow(dead_code)] +pub(crate) fn centered_abs(x: F, modulus: u128) -> u128 { + let v = x.to_canonical_u128(); + let half = modulus / 2; + if v <= half { + v + } else { + modulus - v + } +} + +/// L∞ norm of a single ring element (maximum centered coefficient magnitude). +#[allow(dead_code)] +pub(crate) fn ring_inf_norm( + r: &CyclotomicRing, + modulus: u128, +) -> u128 { + r.coefficients() + .iter() + .map(|c| centered_abs(*c, modulus)) + .max() + .unwrap_or(0) +} + +/// L∞ norm of a vector of ring elements. +#[allow(dead_code)] +pub(crate) fn vec_inf_norm( + v: &[CyclotomicRing], + modulus: u128, +) -> u128 { + v.iter() + .map(|r| ring_inf_norm(r, modulus)) + .max() + .unwrap_or(0) +} diff --git a/src/protocol/commitment/utils/ntt_cache.rs b/src/protocol/commitment/utils/ntt_cache.rs new file mode 100644 index 00000000..6e29bd26 --- /dev/null +++ b/src/protocol/commitment/utils/ntt_cache.rs @@ -0,0 +1,105 @@ +//! Multi-D NTT cache management. +//! +//! Wraps per-D [`NttSlotCache`] bundles with lazy computation and memoization. +//! A single [`MultiDNttCaches`] can hold NTT caches for any subset of supported +//! ring dimensions, built on demand from a shared [`FlatMatrix`]. + +use super::crt_ntt::{build_ntt_slot, NttSlotCache}; +use super::flat_matrix::FlatMatrix; +use crate::error::HachiError; +use crate::{CanonicalField, FieldCore}; + +/// Per-matrix NTT caches for multiple ring dimensions. +/// +/// Each field is lazily populated by the `get_or_build_*` methods. +/// Fields use `Box>` to keep the struct's inline size +/// small: `NttSlotCache<1024>` alone is ~80 KB due to inline twiddle +/// arrays, so storing them unboxed would make this struct ~155 KB. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct MultiDNttCaches { + /// Cache for D=64. + pub d64: Option>>, + /// Cache for D=128. + pub d128: Option>>, + /// Cache for D=256. + pub d256: Option>>, + /// Cache for D=512. + pub d512: Option>>, + /// Cache for D=1024. 
+ pub d1024: Option>>, +} + +macro_rules! impl_get_or_build { + ($fn_name:ident, $field:ident, $d_val:expr) => { + /// Get (or build and memoize) the NTT cache for this ring dimension. + /// + /// # Errors + /// + /// Returns an error if no CRT+NTT parameter set matches the field and D. + pub fn $fn_name( + &mut self, + mat: &FlatMatrix, + ) -> Result<&NttSlotCache<$d_val>, HachiError> { + if self.$field.is_none() { + self.$field = Some(Box::new(build_ntt_slot(mat.view::<$d_val>())?)); + } + Ok(self.$field.as_deref().unwrap()) + } + }; +} + +impl MultiDNttCaches { + /// Empty cache set. + pub fn new() -> Self { + Self { + d64: None, + d128: None, + d256: None, + d512: None, + d1024: None, + } + } + + impl_get_or_build!(get_or_build_64, d64, 64); + impl_get_or_build!(get_or_build_128, d128, 128); + impl_get_or_build!(get_or_build_256, d256, 256); + impl_get_or_build!(get_or_build_512, d512, 512); + impl_get_or_build!(get_or_build_1024, d1024, 1024); + + /// Check if a cache for dimension `d` is already populated. + pub fn has(&self, d: usize) -> bool { + match d { + 64 => self.d64.is_some(), + 128 => self.d128.is_some(), + 256 => self.d256.is_some(), + 512 => self.d512.is_some(), + 1024 => self.d1024.is_some(), + _ => false, + } + } +} + +impl Default for MultiDNttCaches { + fn default() -> Self { + Self::new() + } +} + +/// Bundle of three multi-D NTT caches for the A, B, and D matrices. +#[derive(Debug, Clone, Default, PartialEq, Eq)] +#[allow(non_snake_case)] +pub struct MultiDNttBundle { + /// NTT caches for the A matrix at various ring dimensions. + pub A: MultiDNttCaches, + /// NTT caches for the B matrix at various ring dimensions. + pub B: MultiDNttCaches, + /// NTT caches for the D matrix at various ring dimensions. + pub D_mat: MultiDNttCaches, +} + +impl MultiDNttBundle { + /// Empty bundle. 
+ pub fn new() -> Self { + Self::default() + } +} diff --git a/src/protocol/commitment_scheme.rs b/src/protocol/commitment_scheme.rs new file mode 100644 index 00000000..d199c1cb --- /dev/null +++ b/src/protocol/commitment_scheme.rs @@ -0,0 +1,2295 @@ +//! Commitment scheme trait implementation. + +use crate::algebra::fields::wide::HasWide; +use crate::algebra::fields::HasUnreducedOps; +use crate::algebra::CyclotomicRing; +#[cfg(debug_assertions)] +use crate::algebra::SparseChallenge; +use crate::error::HachiError; +use crate::primitives::serialization::Valid; +use crate::protocol::commitment::utils::crt_ntt::NttSlotCache; +use crate::protocol::commitment::utils::linear::{flatten_i8_blocks, mat_vec_mul_ntt_single_i8}; +use crate::protocol::commitment::utils::ntt_cache::{MultiDNttBundle, MultiDNttCaches}; +use crate::protocol::commitment::{ + AppendToTranscript, CommitmentConfig, CommitmentScheme, HachiCommitmentCore, + HachiCommitmentLayout, HachiExpandedSetup, HachiLevelParams, HachiProverSetup, + HachiScheduleInputs, HachiVerifierSetup, RingCommitment, RingCommitmentScheme, +}; +use crate::protocol::hachi_poly_ops::{BalancedDigitPoly, HachiPolyOps}; +use crate::protocol::labrador_handoff::{labrador_handoff_prove, labrador_handoff_verify}; +#[cfg(debug_assertions)] +use crate::protocol::opening_point::RingOpeningPoint; +use crate::protocol::opening_point::{ + reduce_inner_opening_to_ring_element, ring_opening_point_from_field, BasisMode, +}; +use crate::protocol::proof::{ + FlatCommitmentHint, FlatRingVec, HachiCommitmentHint, HachiLevelProof, HachiProof, + HachiProofTail, LabradorTail, PackedDigits, +}; +#[cfg(any(test, debug_assertions))] +use crate::protocol::quadratic_equation::compute_m_a_reference; +use crate::protocol::quadratic_equation::QuadraticEquation; +#[cfg(debug_assertions)] +use crate::protocol::ring_switch::eval_ring_at; +#[cfg(debug_assertions)] +use crate::protocol::ring_switch::m_row_count; +use crate::protocol::ring_switch::{ + build_w_evals, 
commit_w, ring_switch_build_w, ring_switch_finalize, ring_switch_verifier, + w_ring_element_count, RingSwitchOutput, WCommitmentConfig, +}; +#[cfg(debug_assertions)] +use crate::protocol::sumcheck::eq_poly::EqPolynomial; +#[cfg(debug_assertions)] +use crate::protocol::sumcheck::hachi_stage1::{ + range_check_eval_from_s, HachiStage1Prover, HachiStage1Verifier, +}; +#[cfg(not(debug_assertions))] +use crate::protocol::sumcheck::hachi_stage1::{HachiStage1Prover, HachiStage1Verifier}; +use crate::protocol::sumcheck::hachi_stage2::{ + relation_claim_from_rows, HachiStage2Prover, HachiStage2Verifier, +}; +#[cfg(debug_assertions)] +use crate::protocol::sumcheck::multilinear_eval; +use crate::protocol::sumcheck::{ + prove_sumcheck, verify_sumcheck, SumcheckInstanceProver, SumcheckInstanceVerifier, +}; +use crate::protocol::transcript::labels::{ + ABSORB_COMMITMENT, ABSORB_EVALUATION_CLAIMS, ABSORB_SUMCHECK_S_CLAIM, CHALLENGE_SUMCHECK_BATCH, + CHALLENGE_SUMCHECK_ROUND, +}; +use crate::protocol::transcript::Transcript; +use crate::{dispatch_ring_dim, dispatch_with_d_ntt, dispatch_with_ntt}; +use crate::{CanonicalField, FieldCore, FieldSampling, FromSmallInt}; +#[cfg(debug_assertions)] +use std::iter; +use std::marker::PhantomData; +use std::time::Instant; + +#[cfg(test)] +use crate::protocol::SmallTestCommitmentConfig; +#[cfg(test)] +use crate::{HachiDeserialize, HachiSerialize}; + +/// Minimum w vector length (in field elements) below which further folding +/// is not beneficial. When `w.len() <= MIN_W_LEN_FOR_FOLDING`, the prover +/// sends `w` directly instead of recursing. +const MIN_W_LEN_FOR_FOLDING: usize = 4096; + +/// Minimum shrink ratio (next_w / prev_w) below which further folding +/// stops being worthwhile. If the w vector doesn't shrink by at least +/// this factor, the overhead of another fold level outweighs the saving. +const MIN_SHRINK_RATIO: f64 = 0.5; + +/// End-to-end PCS wrapper, generic over ring degree `D` and config `Cfg`. 
+#[derive(Clone, Copy, Debug, Default)] +pub struct HachiCommitmentScheme { + _cfg: PhantomData, +} + +/// Output from a single prove level, needed to chain into the next level. +/// +/// D-agnostic: ring elements are erased into [`HachiLevelProof`] and +/// the commitment hint is stored as [`FlatCommitmentHint`]. +struct ProveLevelOutput { + level_proof: HachiLevelProof, + w: Vec, + w_hint: FlatCommitmentHint, + sumcheck_challenges: Vec, + num_u: usize, + num_l: usize, +} + +/// Prove one fold level: quad_eq -> ring_switch -> sumcheck. +/// +/// Generic over the commitment config so it works for both the original +/// polynomial (using `Cfg`) and recursive w-openings (using `WCommitmentConfig`). +type CommitFn<'a, F> = Box< + dyn FnOnce( + &[i8], + HachiScheduleInputs, + ) -> Result<(FlatRingVec, FlatCommitmentHint), HachiError> + + 'a, +>; + +#[cfg(debug_assertions)] +#[allow(clippy::too_many_arguments)] +#[inline(never)] +fn prove_level_diagnostic( + expanded: &HachiExpandedSetup, + opening_point: &RingOpeningPoint, + challenges: &[SparseChallenge], + rs: &RingSwitchOutput, + v: &[CyclotomicRing], + u: &[CyclotomicRing], + y_ring: &CyclotomicRing, + level_params: HachiLevelParams, + layout: HachiCommitmentLayout, + level: usize, +) where + F: FieldCore + CanonicalField + FieldSampling, +{ + let m_a = compute_m_a_reference::( + expanded, + opening_point, + challenges, + &rs.alpha, + level_params, + layout, + ) + .expect("compute_m_a diagnostic failed"); + + let x_len = 1usize << rs.num_u; + let d = D; + + let mut w_at_alpha = vec![F::zero(); x_len]; + for (x, w_at_alpha_x) in w_at_alpha.iter_mut().enumerate() { + let mut val = F::zero(); + for y in 0..d { + let idx = x + y * x_len; + if idx < rs.w_evals_compact.len() { + val += rs.alpha_evals_y[y] * F::from_i64(rs.w_evals_compact[idx] as i64); + } + } + *w_at_alpha_x = val; + } + + let num_rows = m_row_count(level_params); + let y_full: Vec = v + .iter() + .chain(u.iter()) + .chain(iter::once(y_ring)) + .map(|r| 
eval_ring_at(r, &rs.alpha)) + .collect(); + + tracing::debug!( + level, + num_rows, + x_len, + m_a_cols = m_a.first().map_or(0, |r| r.len()), + "per-row M*w=y diagnostic" + ); + for i in 0..num_rows { + let mw_i: F = m_a[i] + .iter() + .enumerate() + .fold(F::zero(), |acc, (x, &m_ix)| { + acc + m_ix * w_at_alpha.get(x).copied().unwrap_or(F::zero()) + }); + let y_i = if i < y_full.len() { + y_full[i] + } else { + F::zero() + }; + let residual = mw_i - y_i; + let row_name = match i { + _ if i < level_params.n_d => "D", + _ if i < level_params.n_d + level_params.n_b => "B", + _ if i == level_params.n_d + level_params.n_b => "bTw", + _ if i == level_params.n_d + level_params.n_b + 1 => "challenge_fold", + _ => "A", + }; + tracing::debug!( + row = i, + row_name, + matches = residual.is_zero(), + residual_is_zero = residual.is_zero(), + mw_is_zero = mw_i.is_zero(), + y_is_zero = y_i.is_zero(), + "diagnostic row" + ); + } + + let verifier_claim = relation_claim_from_rows::(&rs.tau1, rs.alpha, v, u, y_ring); + let x_mask = x_len - 1; + let mut prover_claim = F::zero(); + for (idx, &w) in rs.w_evals_compact.iter().enumerate() { + prover_claim += + F::from_i64(w as i64) * rs.alpha_evals_y[idx >> rs.num_u] * rs.m_evals_x[idx & x_mask]; + } + tracing::debug!( + level, + claims_match = (verifier_claim == prover_claim), + prover_is_zero = prover_claim.is_zero(), + verifier_is_zero = verifier_claim.is_zero(), + "relation_claim cross-check" + ); +} + +#[cfg(debug_assertions)] +#[allow(clippy::too_many_arguments)] +#[inline(never)] +fn prove_stage1_selfcheck( + tau0: &[F], + stage1_challenges: &[F], + s_claim: F, + b: usize, + final_claim: F, + level: usize, +) { + let eq_val = EqPolynomial::mle(tau0, stage1_challenges); + let oracle = eq_val * range_check_eval_from_s(s_claim, b); + if oracle != final_claim { + tracing::warn!( + level, + "PROVER stage-1 self-check FAILED: expected != final_claim" + ); + } else { + tracing::debug!(level, "PROVER stage-1 self-check OK"); + } +} + 
+#[cfg(debug_assertions)] +#[allow(clippy::too_many_arguments)] +#[inline(never)] +fn prove_stage2_selfcheck( + r_stage1: &[F], + sumcheck_challenges: &[F], + w_eval: F, + batching_coeff: F, + alpha_evals_y: &[F], + m_evals_x: &[F], + num_u: usize, + final_claim: F, + level: usize, +) { + let eq_val = EqPolynomial::mle(r_stage1, sumcheck_challenges); + let virtual_oracle = eq_val * w_eval * (w_eval + F::one()); + let (x_ch, y_ch) = sumcheck_challenges.split_at(num_u); + let alpha_val = multilinear_eval(alpha_evals_y, y_ch).unwrap(); + let m_val = multilinear_eval(m_evals_x, x_ch).unwrap(); + let relation_oracle = w_eval * alpha_val * m_val; + let prover_expected = batching_coeff * virtual_oracle + relation_oracle; + if prover_expected != final_claim { + tracing::warn!( + level, + "PROVER stage-2 self-check FAILED: expected != final_claim" + ); + } else { + tracing::debug!(level, "PROVER stage-2 self-check OK"); + } +} + +#[allow(clippy::too_many_arguments)] +#[inline(never)] +fn prove_one_level( + expanded: &HachiExpandedSetup, + ntt_a: &NttSlotCache, + ntt_b: &NttSlotCache, + ntt_d: &NttSlotCache, + commit_w_fn: CommitFn<'_, F>, + poly: &P, + max_num_vars: usize, + opening_point: &[F], + hint: HachiCommitmentHint, + transcript: &mut T, + commitment: &RingCommitment, + basis: BasisMode, + level: usize, + level_params: HachiLevelParams, + layout: HachiCommitmentLayout, +) -> Result, HachiError> +where + F: FieldCore + CanonicalField + FieldSampling + HasUnreducedOps + HasWide, + T: Transcript, + Cfg: CommitmentConfig, + P: HachiPolyOps, +{ + { + let x: u8 = 0; + tracing::trace!( + stack_ptr = format_args!("{:#x}", &x as *const u8 as usize), + level, + "prove_one_level" + ); + } + let alpha = level_params.d.trailing_zeros() as usize; + if opening_point.len() < alpha { + return Err(HachiError::InvalidPointDimension { + expected: alpha, + actual: opening_point.len(), + }); + } + let target_num_vars = layout.m_vars + layout.r_vars + alpha; + let mut padded_point = 
opening_point.to_vec(); + padded_point.resize(target_num_vars, F::zero()); + let outer_point = &padded_point[alpha..]; + + let ring_opening_point = { + let _span = tracing::info_span!("ring_opening_point", level).entered(); + ring_opening_point_from_field::(outer_point, layout.r_vars, layout.m_vars, basis)? + }; + + let fold_scalars = &ring_opening_point.a; + let eval_outer_scalars = &ring_opening_point.b; + let (y_ring, w_folded) = { + let _span = tracing::info_span!( + "evaluate_and_fold", + level, + num_ring_elems = poly.num_ring_elems() + ) + .entered(); + poly.evaluate_and_fold(eval_outer_scalars, fold_scalars, layout.block_len) + }; + + commitment.append_to_transcript(ABSORB_COMMITMENT, transcript); + for pt in &padded_point { + transcript.append_field(ABSORB_EVALUATION_CLAIMS, pt); + } + transcript.append_serde(ABSORB_EVALUATION_CLAIMS, &y_ring); + + let mut quad_eq = Box::new(QuadraticEquation::::new_prover( + ntt_d, + ring_opening_point, + poly, + w_folded, + level_params, + hint, + transcript, + commitment, + &y_ring, + layout, + )?); + + let w = ring_switch_build_w::( + &mut quad_eq, + expanded, + ntt_a, + ntt_b, + ntt_d, + level_params, + layout, + )?; + let next_inputs = HachiScheduleInputs { + max_num_vars, + level: level + 1, + current_w_len: w.len(), + }; + + let (w_commitment_flat, w_hint_flat) = { + let _span = tracing::info_span!("commit_w_level", level).entered(); + commit_w_fn(&w, next_inputs)? 
+ }; + + let rs = ring_switch_finalize::( + &quad_eq, + expanded, + transcript, + w, + w_commitment_flat, + w_hint_flat, + level_params, + layout, + )?; + + #[cfg(debug_assertions)] + prove_level_diagnostic::( + expanded, + quad_eq.opening_point(), + &quad_eq.challenges, + &rs, + &quad_eq.v, + &commitment.u, + &y_ring, + level_params, + layout, + level, + ); + + let relation_claim = + relation_claim_from_rows::(&rs.tau1, rs.alpha, &quad_eq.v, &commitment.u, &y_ring); + let RingSwitchOutput { + w, + w_commitment, + w_hint, + w_evals_compact, + live_x_cols, + m_evals_x, + alpha_evals_y, + num_u, + num_l, + tau0, + tau1: _, + b, + alpha: _, + } = rs; + #[cfg(debug_assertions)] + let alpha_evals_y_debug = alpha_evals_y.clone(); + #[cfg(debug_assertions)] + let m_evals_x_debug = m_evals_x.clone(); + let (stage1_sumcheck, r_stage1, s_claim) = { + let _sumcheck_span = tracing::info_span!("stage1_sumcheck").entered(); + let mut stage1_prover = + HachiStage1Prover::new(&w_evals_compact, &tau0, b, live_x_cols, num_u, num_l); + let (stage1_sumcheck, r_stage1, stage1_final_claim) = + prove_sumcheck::(&mut stage1_prover, transcript, |tr| { + tr.challenge_scalar(CHALLENGE_SUMCHECK_ROUND) + })?; + let s_claim = stage1_prover.final_s_claim(); + #[cfg(not(debug_assertions))] + let _ = stage1_final_claim; + + #[cfg(debug_assertions)] + prove_stage1_selfcheck(&tau0, &r_stage1, s_claim, b, stage1_final_claim, level); + + (stage1_sumcheck, r_stage1, s_claim) + }; + + transcript.append_serde(ABSORB_SUMCHECK_S_CLAIM, &s_claim); + let batching_coeff: F = transcript.challenge_scalar(CHALLENGE_SUMCHECK_BATCH); + let stage2_input_claim = batching_coeff * s_claim + relation_claim; + let (stage2_sumcheck, sumcheck_challenges, stage2_final_claim, w_eval) = { + let _sumcheck_span = tracing::info_span!("stage2_sumcheck").entered(); + let mut stage2_prover = HachiStage2Prover::new( + batching_coeff, + w_evals_compact, + &r_stage1, + s_claim, + alpha_evals_y, + m_evals_x, + live_x_cols, + num_u, + 
num_l, + relation_claim, + ); + debug_assert!(stage2_input_claim == SumcheckInstanceProver::input_claim(&stage2_prover)); + let (stage2_sumcheck, sumcheck_challenges, stage2_final_claim) = + prove_sumcheck::(&mut stage2_prover, transcript, |tr| { + tr.challenge_scalar(CHALLENGE_SUMCHECK_ROUND) + })?; + #[cfg(not(debug_assertions))] + let _ = stage2_final_claim; + + let w_eval = { + let _span = tracing::info_span!("multilinear_eval", level).entered(); + stage2_prover.final_w_eval() + }; + ( + stage2_sumcheck, + sumcheck_challenges, + stage2_final_claim, + w_eval, + ) + }; + + #[cfg(debug_assertions)] + prove_stage2_selfcheck( + &r_stage1, + &sumcheck_challenges, + w_eval, + batching_coeff, + &alpha_evals_y_debug, + &m_evals_x_debug, + num_u, + stage2_final_claim, + level, + ); + #[cfg(not(debug_assertions))] + let _ = stage2_final_claim; + + Ok(ProveLevelOutput { + level_proof: HachiLevelProof::new::( + y_ring, + quad_eq.v, + stage1_sumcheck, + s_claim, + stage2_sumcheck, + w_commitment, + w_eval, + ), + w, + w_hint, + sumcheck_challenges, + num_u, + num_l, + }) +} + +/// Whether the prover should stop folding and send `w` directly. +/// +/// `prev_w_len` is the polynomial length at the previous level (or the +/// original polynomial's field-element count for level 0). +fn should_stop_folding(w_len: usize, prev_w_len: usize) -> bool { + if w_len <= MIN_W_LEN_FOR_FOLDING { + return true; + } + let ratio = w_len as f64 / prev_w_len as f64; + ratio > MIN_SHRINK_RATIO +} + +/// Derive the opening point for the next fold level from the sumcheck +/// challenges of the current level. +/// +/// Sumcheck challenges are ordered `[x_0..x_{num_u-1}, y_0..y_{num_l-1}]` +/// where x selects ring elements and y selects coefficients. +/// The PCS opening point is `[inner, outer]` = `[y, x]`. 
+pub(crate) fn next_level_opening_point( + sumcheck_challenges: &[F], + num_u: usize, + num_l: usize, +) -> Vec { + let (x, y) = sumcheck_challenges.split_at(num_u); + debug_assert_eq!(y.len(), num_l); + let mut point = Vec::with_capacity(num_u + num_l); + point.extend_from_slice(y); + point.extend_from_slice(x); + point +} + +/// Dispatch a commit-w operation to the correct ring dimension. +/// +/// Each match arm builds NTT caches for the target D and calls `commit_w`. +/// `#[inline(never)]` isolates the match arms in their own stack frame, +/// preventing debug-mode stack bloat from monomorphized arms. +#[allow(clippy::too_many_arguments)] +#[inline(never)] +fn dispatch_commit( + commit_params: HachiLevelParams, + commit_ntt_bundle: &mut MultiDNttBundle, + expanded: &HachiExpandedSetup, + w: &[i8], +) -> Result<(FlatRingVec, FlatCommitmentHint), HachiError> +where + F: FieldCore + CanonicalField + FieldSampling, + Cfg: CommitmentConfig, +{ + let commit_d = commit_params.d; + dispatch_with_ntt!( + commit_d, + commit_ntt_bundle, + expanded, + |D_COMMIT, ca, cb, _cd| { + let (wc, wh) = commit_w::>( + w, + ca, + cb, + commit_params, + )?; + Ok(( + FlatRingVec::from_commitment(&wc), + FlatCommitmentHint::from_typed(wh), + )) + } + ) +} + +/// Dispatch a prove-level operation to the correct ring dimension. +/// +/// Handles the fast-path (`level_d == D`) and the dynamic dispatch path. +/// `#[inline(never)]` isolates the monomorphized match arms in their own +/// stack frame. 
+#[allow(clippy::too_many_arguments)] +#[inline(never)] +fn dispatch_prove_level( + level_d: usize, + ntt_bundle: &mut MultiDNttBundle, + expanded: &HachiExpandedSetup, + setup_ntt_a: &NttSlotCache, + setup_ntt_b: &NttSlotCache, + setup_ntt_d: &NttSlotCache, + commit_ntt_bundle: &mut MultiDNttBundle, + max_num_vars: usize, + current_w: &[i8], + current_hint: &FlatCommitmentHint, + current_challenges: &[F], + current_num_u: usize, + current_num_l: usize, + last_w_commitment: &FlatRingVec, + last_w_eval: F, + transcript: &mut T, + level: usize, + level_params: HachiLevelParams, +) -> Result, HachiError> +where + F: FieldCore + CanonicalField + FieldSampling + HasUnreducedOps + HasWide, + T: Transcript, + Cfg: CommitmentConfig, +{ + if level_d == D { + prove_subsequent_level::( + expanded, + setup_ntt_a, + setup_ntt_b, + setup_ntt_d, + commit_ntt_bundle, + max_num_vars, + current_w, + current_hint, + current_challenges, + current_num_u, + current_num_l, + last_w_commitment, + last_w_eval, + transcript, + level, + level_params, + ) + } else { + dispatch_with_ntt!( + level_d, + ntt_bundle, + expanded, + |D_LEVEL, ntt_a, ntt_b, ntt_d| { + prove_subsequent_level::( + expanded, + ntt_a, + ntt_b, + ntt_d, + commit_ntt_bundle, + max_num_vars, + current_w, + current_hint, + current_challenges, + current_num_u, + current_num_l, + last_w_commitment, + last_w_eval, + transcript, + level, + level_params, + ) + } + ) + } +} + +/// Dispatch a verify-level operation to the correct ring dimension. +/// +/// Each match arm converts the D-erased commitment to a typed one, +/// derives the w-commitment layout, and calls `verify_one_level`. +/// `#[inline(never)]` isolates the monomorphized match arms. 
+#[allow(clippy::too_many_arguments)] +#[inline(never)] +fn dispatch_verify_level( + level_d: usize, + level_proof: &HachiLevelProof, + setup: &HachiVerifierSetup, + transcript: &mut T, + opening_point: &[F], + opening: &F, + current_commitment: &FlatRingVec, + basis: BasisMode, + is_last: bool, + final_w: Option<&[F]>, + level_params: HachiLevelParams, +) -> Result, HachiError> +where + F: FieldCore + CanonicalField + FieldSampling, + T: Transcript, + Cfg: CommitmentConfig, +{ + dispatch_ring_dim!(level_d, |D_LEVEL| { + let typed_commitment: RingCommitment = + current_commitment.try_to_ring_commitment()?; + let w_layout = + >::commitment_layout(opening_point.len())?; + verify_one_level::>( + level_proof, + setup, + transcript, + opening_point, + opening, + &typed_commitment, + basis, + is_last, + final_w, + level_params, + w_layout, + ) + }) +} + +#[allow(clippy::too_many_arguments)] +#[inline(never)] +fn dispatch_labrador_handoff_prove( + current_w: &[i8], + current_hint: &FlatCommitmentHint, + current_challenges: &[F], + current_num_u: usize, + current_num_l: usize, + current_commitment: &FlatRingVec, + setup: &HachiProverSetup, + handoff_ntt_d_cache: &mut MultiDNttCaches, + transcript: &mut T, +) -> Result, HachiError> +where + F: FieldCore + CanonicalField + FieldSampling + FromSmallInt + Valid, + T: Transcript, + Cfg: CommitmentConfig, +{ + let handoff_d = current_commitment.ring_dim(); + if current_hint.ring_dim() != handoff_d { + return Err(HachiError::InvalidInput(format!( + "handoff hint/commitment D mismatch: hint={}, commitment={handoff_d}", + current_hint.ring_dim() + ))); + } + + if handoff_d == D { + let typed_hint: HachiCommitmentHint = current_hint.to_typed(); + let typed_commitment: RingCommitment = current_commitment.to_ring_commitment(); + return labrador_handoff_prove::( + current_w, + &typed_hint, + &typed_commitment, + current_challenges, + current_num_u, + current_num_l, + &setup.expanded, + &setup.ntt_D, + transcript, + ); + } + + 
dispatch_with_d_ntt!( + handoff_d, + handoff_ntt_d_cache, + &setup.expanded, + |D_HANDOFF, ntt_d| { + let typed_hint: HachiCommitmentHint = current_hint.to_typed(); + let typed_commitment: RingCommitment = + current_commitment.to_ring_commitment(); + labrador_handoff_prove::( + current_w, + &typed_hint, + &typed_commitment, + current_challenges, + current_num_u, + current_num_l, + &setup.expanded, + ntt_d, + transcript, + ) + } + ) +} + +#[allow(clippy::too_many_arguments)] +#[inline(never)] +fn dispatch_labrador_handoff_verify( + tail: &LabradorTail, + opening_point: &[F], + opening: &F, + current_commitment: &FlatRingVec, + expanded_setup: &HachiExpandedSetup, + transcript: &mut T, +) -> Result<(), HachiError> +where + F: FieldCore + CanonicalField + FieldSampling + FromSmallInt + Valid, + T: Transcript, + Cfg: CommitmentConfig, +{ + let handoff_d = current_commitment.ring_dim(); + if tail.v.ring_dim() != handoff_d + || tail.y_ring.ring_dim() != handoff_d + || tail.labrador_proof.levels.iter().any(|level| { + level.inner_opening_payload.ring_dim() != handoff_d + || level.linear_garbage_payload.ring_dim() != handoff_d + || level.jl_lift_residuals.ring_dim() != handoff_d + }) + || tail + .labrador_proof + .final_opening_witness + .rows + .iter() + .any(|row| row.ring_dim() != handoff_d) + { + return Err(HachiError::InvalidProof); + } + + dispatch_ring_dim!(handoff_d, |D_HANDOFF| { + let typed_commitment: RingCommitment = + current_commitment.try_to_ring_commitment()?; + labrador_handoff_verify::( + tail, + opening_point, + opening, + &typed_commitment, + expanded_setup, + transcript, + ) + }) +} + +/// Single subsequent (recursive) prove level, extracted so that the +/// dispatch match arms contain only a function call. 
+#[allow(clippy::too_many_arguments)] +#[inline(never)] +fn prove_subsequent_level( + expanded: &HachiExpandedSetup, + ntt_a: &NttSlotCache, + ntt_b: &NttSlotCache, + ntt_d: &NttSlotCache, + commit_ntt_bundle: &mut MultiDNttBundle, + max_num_vars: usize, + current_w: &[i8], + current_hint: &FlatCommitmentHint, + current_challenges: &[F], + current_num_u: usize, + current_num_l: usize, + last_w_commitment: &FlatRingVec, + #[cfg_attr(not(debug_assertions), allow(unused_variables))] last_w_eval: F, + transcript: &mut T, + level: usize, + level_params: HachiLevelParams, +) -> Result, HachiError> +where + F: FieldCore + CanonicalField + FieldSampling + HasUnreducedOps + HasWide, + T: Transcript, + Cfg: CommitmentConfig, +{ + let w_poly = BalancedDigitPoly::::from_i8_digits(current_w)?; + let opening_point = next_level_opening_point(current_challenges, current_num_u, current_num_l); + + #[cfg(debug_assertions)] + { + let mut field_evals: Vec = current_w.iter().map(|&d| F::from_i8(d)).collect(); + field_evals.resize(w_poly.num_ring_elems() * D_LEVEL, F::zero()); + let direct_eval = multilinear_eval(&field_evals, &opening_point).unwrap(); + if last_w_eval != direct_eval { + tracing::error!( + level, + ring_elems = w_poly.num_ring_elems(), + field_len = field_evals.len(), + point_len = opening_point.len(), + "BUG: w_eval mismatch! 
prev_level w_eval != w_poly eval at opening_point" + ); + } else { + tracing::debug!(level, "w_eval consistency OK"); + } + } + + let w_commitment: RingCommitment = last_w_commitment.to_ring_commitment(); + let typed_hint: HachiCommitmentHint = current_hint.to_typed(); + + let commit_fn: CommitFn<'_, F> = Box::new( + |w: &[i8], + next_inputs: HachiScheduleInputs| + -> Result<(FlatRingVec, FlatCommitmentHint), HachiError> { + let next_params = Cfg::level_params(next_inputs); + if next_params.d == D_LEVEL { + let (wc, wh) = commit_w::>( + w, + ntt_a, + ntt_b, + next_params, + )?; + Ok(( + FlatRingVec::from_commitment(&wc), + FlatCommitmentHint::from_typed(wh), + )) + } else { + dispatch_commit::(next_params, commit_ntt_bundle, expanded, w) + } + }, + ); + + let w_layout = >::commitment_layout(opening_point.len())?; + prove_one_level::, _>( + expanded, + ntt_a, + ntt_b, + ntt_d, + commit_fn, + &w_poly, + max_num_vars, + &opening_point, + typed_hint, + transcript, + &w_commitment, + BasisMode::Lagrange, + level, + level_params, + w_layout, + ) +} + +impl CommitmentScheme for HachiCommitmentScheme +where + F: FieldCore + CanonicalField + FieldSampling + HasWide + HasUnreducedOps + Valid, + Cfg: CommitmentConfig, +{ + type ProverSetup = HachiProverSetup; + type VerifierSetup = HachiVerifierSetup; + type Commitment = RingCommitment; + type Proof = HachiProof; + type CommitHint = HachiCommitmentHint; + + #[tracing::instrument(skip_all, name = "HachiCommitmentScheme::setup_prover")] + fn setup_prover(max_num_vars: usize) -> Self::ProverSetup { + let (setup, _) = + >::setup(max_num_vars) + .expect("commitment setup failed"); + setup + } + + fn setup_verifier(setup: &Self::ProverSetup) -> Self::VerifierSetup { + HachiVerifierSetup { + expanded: setup.expanded.clone(), + } + } + + #[tracing::instrument(skip_all, name = "HachiCommitmentScheme::commit")] + fn commit>( + poly: &P, + setup: &Self::ProverSetup, + layout: &HachiCommitmentLayout, + ) -> Result<(Self::Commitment, 
Self::CommitHint), HachiError> { + setup.assert_layout_fits(layout); + let inner = poly.commit_inner_witness( + &setup.expanded.A, + &setup.ntt_A, + layout.block_len, + layout.num_digits_commit, + layout.num_digits_open, + layout.log_basis, + )?; + let inner_opening_digits_flat = flatten_i8_blocks(&inner.t_hat); + let mut u: Vec> = + mat_vec_mul_ntt_single_i8(&setup.ntt_B, &inner_opening_digits_flat); + let root_params = Cfg::level_params(HachiScheduleInputs { + max_num_vars: setup.expanded.seed.max_num_vars, + level: 0, + current_w_len: poly.num_ring_elems() * D, + }); + u.truncate(root_params.n_b); + let hint = HachiCommitmentHint::with_t(inner.t_hat, inner.t); + Ok((RingCommitment { u }, hint)) + } + + #[tracing::instrument(skip_all, name = "HachiCommitmentScheme::prove")] + fn prove, P: HachiPolyOps>( + setup: &Self::ProverSetup, + poly: &P, + opening_point: &[F], + hint: Self::CommitHint, + transcript: &mut T, + commitment: &Self::Commitment, + basis: BasisMode, + layout: &HachiCommitmentLayout, + ) -> Result { + let t_prove_total = Instant::now(); + let mut levels = Vec::new(); + + let mut ntt_bundle = MultiDNttBundle::new(); + let mut commit_ntt_bundle = MultiDNttBundle::new(); + let max_num_vars = setup.expanded.seed.max_num_vars; + let root_w_len = poly.num_ring_elems() * D; + let root_params = Cfg::level_params(HachiScheduleInputs { + max_num_vars, + level: 0, + current_w_len: root_w_len, + }); + + // Level 0: original polynomial with caller-provided layout. + // The w-commitment is produced at the next level's params, derived from + // public state once `w` has been built. 
+ let commit_fn_0: CommitFn<'_, F> = Box::new( + |w: &[i8], + next_inputs: HachiScheduleInputs| + -> Result<(FlatRingVec, FlatCommitmentHint), HachiError> { + let next_params = Cfg::level_params(next_inputs); + if next_params.d == D { + let (wc, wh) = + commit_w::(w, &setup.ntt_A, &setup.ntt_B, next_params)?; + Ok(( + FlatRingVec::from_commitment(&wc), + FlatCommitmentHint::from_typed(wh), + )) + } else { + dispatch_commit::( + next_params, + &mut commit_ntt_bundle, + &setup.expanded, + w, + ) + } + }, + ); + let out = prove_one_level::( + &setup.expanded, + &setup.ntt_A, + &setup.ntt_B, + &setup.ntt_D, + commit_fn_0, + poly, + max_num_vars, + opening_point, + hint, + transcript, + commitment, + basis, + 0, + root_params, + *layout, + )?; + levels.push(out.level_proof); + + let mut prev_poly_len = poly.num_ring_elems() * D; + let mut current_w = out.w; + let mut current_hint = out.w_hint; + let mut current_challenges = out.sumcheck_challenges; + let mut current_num_u = out.num_u; + let mut current_num_l = out.num_l; + let mut level = 1usize; + + // Subsequent levels: recursive w-opening with WCommitmentConfig. + // Each level dispatches to the ring dimension from Cfg::d_at_level. + // The w-commitment is produced at the NEXT level's D. 
+ while !should_stop_folding(current_w.len(), prev_poly_len) { + let level_params = Cfg::level_params(HachiScheduleInputs { + max_num_vars, + level, + current_w_len: current_w.len(), + }); + let level_d = level_params.d; + + let last_w_eval = levels.last().unwrap().stage2.next_w_eval; + let last_w_commitment = &levels.last().unwrap().stage2.next_w_commitment; + let out = dispatch_prove_level::( + level_d, + &mut ntt_bundle, + &setup.expanded, + &setup.ntt_A, + &setup.ntt_B, + &setup.ntt_D, + &mut commit_ntt_bundle, + max_num_vars, + ¤t_w, + ¤t_hint, + ¤t_challenges, + current_num_u, + current_num_l, + last_w_commitment, + last_w_eval, + transcript, + level, + level_params, + )?; + + levels.push(out.level_proof); + + prev_poly_len = current_w.len(); + current_w = out.w; + current_hint = out.w_hint; + current_challenges = out.sumcheck_challenges; + current_num_u = out.num_u; + current_num_l = out.num_l; + level += 1; + } + + tracing::info!( + levels = level, + elapsed_s = t_prove_total.elapsed().as_secs_f64(), + "hachi prove complete" + ); + + // let handoff_ring_dim = current_hint.ring_dim(); + let labrador_enabled = current_w.len() > Cfg::labrador_handoff_threshold() + // && handoff_ring_dim <= 64 + && std::env::var("HACHI_NO_LABRADOR").as_deref() != Ok("1"); + let final_w_basis = if level > 1 { + Cfg::w_log_basis() + } else { + Cfg::decomposition().log_basis + }; + + let tail = if labrador_enabled { + tracing::info!("labrador handoff started"); + dispatch_labrador_handoff_prove::( + ¤t_w, + ¤t_hint, + ¤t_challenges, + current_num_u, + current_num_l, + &levels.last().unwrap().stage2.next_w_commitment, + setup, + &mut commit_ntt_bundle.D_mat, + transcript, + )? 
+ } else { + let final_w = PackedDigits::from_i8_digits(¤t_w, final_w_basis); + HachiProofTail::Direct(final_w) + }; + + Ok(HachiProof { levels, tail }) + } + + #[tracing::instrument(skip_all, name = "HachiCommitmentScheme::verify")] + fn verify>( + proof: &Self::Proof, + setup: &Self::VerifierSetup, + transcript: &mut T, + opening_point: &[F], + opening: &F, + commitment: &Self::Commitment, + basis: BasisMode, + layout: &HachiCommitmentLayout, + ) -> Result<(), HachiError> { + if proof.levels.is_empty() { + return Err(HachiError::InvalidProof); + } + let t_verify_hachi = Instant::now(); + + let num_levels = proof.levels.len(); + let has_handoff_tail = proof.has_handoff_tail(); + + let final_w_elems: Option> = match &proof.tail { + HachiProofTail::Direct(pw) => Some(pw.to_field_elems()), + HachiProofTail::Labrador(_) => None, + }; + + // State carried between levels. + // Commitment is D-erased so the loop can handle varying D per level. + let mut current_point = opening_point.to_vec(); + let mut current_opening = *opening; + let mut current_commitment = FlatRingVec::from_commitment(commitment); + let mut current_basis = basis; + let max_num_vars = setup.expanded.seed.max_num_vars; + let mut current_w_len = 1usize << max_num_vars; + + for (i, level_proof) in proof.levels.iter().enumerate() { + let is_last_hachi = i == num_levels - 1; + // With a handoff tail, the last Hachi level is NOT the + // final level -- verification continues in Labrador. + let is_last = is_last_hachi && !has_handoff_tail; + let level_params = Cfg::level_params(HachiScheduleInputs { + max_num_vars, + level: i, + current_w_len, + }); + let level_d = level_params.d; + let current_layout = if i == 0 { + *layout + } else { + >::commitment_layout(current_point.len())? 
+ }; + if level_proof.level_d() != level_d || current_commitment.ring_dim() != level_d { + return Err(HachiError::InvalidProof); + } + tracing::debug!( + level = i, + is_last, + point_len = current_point.len(), + D = level_d, + "verify level" + ); + + let fw_ref = final_w_elems.as_deref(); + let challenges = if i == 0 { + let typed_commitment: RingCommitment = + current_commitment.try_to_ring_commitment()?; + verify_one_level::( + level_proof, + setup, + transcript, + ¤t_point, + ¤t_opening, + &typed_commitment, + current_basis, + is_last, + if is_last { fw_ref } else { None }, + level_params, + current_layout, + )? + } else { + dispatch_verify_level::( + level_d, + level_proof, + setup, + transcript, + ¤t_point, + ¤t_opening, + ¤t_commitment, + current_basis, + is_last, + if is_last { fw_ref } else { None }, + level_params, + )? + }; + + if !is_last { + let alpha_bits = level_d.trailing_zeros() as usize; + let num_l = alpha_bits; + let num_u = challenges.len() - num_l; + let next_w_len = w_ring_element_count::(level_params, current_layout) * level_d; + + if i + 1 < num_levels { + let next_level_d = Cfg::level_params(HachiScheduleInputs { + max_num_vars, + level: i + 1, + current_w_len: next_w_len, + }) + .d; + if level_proof.w_commit_d() != next_level_d { + return Err(HachiError::InvalidProof); + } + } + current_point = next_level_opening_point(&challenges, num_u, num_l); + current_opening = level_proof.stage2.next_w_eval; + current_commitment = level_proof.stage2.next_w_commitment.clone(); + current_basis = BasisMode::Lagrange; + current_w_len = next_w_len; + } + } + + tracing::info!( + levels = num_levels, + elapsed_s = t_verify_hachi.elapsed().as_secs_f64(), + "hachi verify complete" + ); + + match &proof.tail { + HachiProofTail::Labrador(ref tail) => { + dispatch_labrador_handoff_verify::( + tail, + ¤t_point, + ¤t_opening, + ¤t_commitment, + &setup.expanded, + transcript, + )?; + } + HachiProofTail::Direct(_) => {} + } + + Ok(()) + } + + fn protocol_name() -> 
&'static [u8] { + unimplemented!() + } +} + +/// Verify one fold level. +/// +/// At the final level, `final_w` is provided and the verifier checks w_val +/// from it directly. At intermediate levels, `level_proof.stage2.next_w_eval` is used. +/// +/// Returns the sumcheck challenges for chaining into the next level. +#[allow(clippy::too_many_arguments)] +#[inline(never)] +fn verify_one_level( + level_proof: &HachiLevelProof, + setup: &HachiVerifierSetup, + transcript: &mut T, + opening_point: &[F], + opening: &F, + commitment: &RingCommitment, + basis: BasisMode, + is_last: bool, + final_w: Option<&[F]>, + level_params: HachiLevelParams, + layout: HachiCommitmentLayout, +) -> Result, HachiError> +where + F: FieldCore + CanonicalField + FieldSampling, + T: Transcript, + Cfg: CommitmentConfig, +{ + let y_ring: CyclotomicRing = level_proof.try_y_ring_typed()?; + let v_typed: Vec> = level_proof.try_v_typed()?; + + let alpha_bits = level_params.d.trailing_zeros() as usize; + if opening_point.len() < alpha_bits { + return Err(HachiError::InvalidSetup( + "opening point length underflow".to_string(), + )); + } + let target_num_vars = layout.m_vars + layout.r_vars + alpha_bits; + let mut padded_point = opening_point.to_vec(); + padded_point.resize(target_num_vars, F::zero()); + let inner_point = &padded_point[..alpha_bits]; + let reduced_opening_point = &padded_point[alpha_bits..]; + + commitment.append_to_transcript(ABSORB_COMMITMENT, transcript); + for pt in &padded_point { + transcript.append_field(ABSORB_EVALUATION_CLAIMS, pt); + } + transcript.append_serde(ABSORB_EVALUATION_CLAIMS, &y_ring); + + let v = reduce_inner_opening_to_ring_element::(inner_point, basis)?; + let d = F::from_u64(level_params.d as u64); + let trace_lhs = trace::(&(y_ring * v.sigma_m1())); + let trace_rhs = d * *opening; + if trace_lhs != trace_rhs { + return Err(HachiError::InvalidProof); + } + + let ring_opening_point = ring_opening_point_from_field::( + reduced_opening_point, + layout.r_vars, + 
layout.m_vars, + basis, + )?; + let quad_eq = Box::new(QuadraticEquation::::new_verifier( + ring_opening_point, + v_typed.clone(), + level_params, + transcript, + commitment, + &y_ring, + layout, + )?); + + let w_len = if is_last { + final_w.map_or(0, |fw| fw.len()) + } else { + w_ring_element_count::(level_params, layout) * D + }; + tracing::debug!(w_len, is_last, "verify ring_switch"); + + let rs = ring_switch_verifier::( + &quad_eq, + &setup.expanded, + w_len, + &level_proof.stage2.next_w_commitment, + transcript, + level_params, + layout, + )?; + + let stage1_verifier = + HachiStage1Verifier::new(rs.tau0.clone(), level_proof.stage1.s_claim, rs.b); + let r_stage1 = { + let _sumcheck_span = tracing::info_span!("stage1_sumcheck").entered(); + verify_sumcheck::( + &level_proof.stage1.sumcheck, + &stage1_verifier, + transcript, + |tr| tr.challenge_scalar(CHALLENGE_SUMCHECK_ROUND), + )? + }; + + transcript.append_serde(ABSORB_SUMCHECK_S_CLAIM, &level_proof.stage1.s_claim); + let batching_coeff: F = transcript.challenge_scalar(CHALLENGE_SUMCHECK_BATCH); + let relation_claim = + relation_claim_from_rows(&rs.tau1, rs.alpha, &v_typed, &commitment.u, &y_ring); + let stage2_input_claim = batching_coeff * level_proof.stage1.s_claim + relation_claim; + + let stage2_verifier = if is_last { + let fw = final_w.ok_or(HachiError::InvalidProof)?; + let (w_evals_full, _, _) = build_w_evals(fw, level_params.d)?; + HachiStage2Verifier::new_with_full_witness( + batching_coeff, + level_proof.stage1.s_claim, + w_evals_full, + r_stage1.clone(), + rs.alpha_evals_y, + rs.m_evals_x, + rs.tau1, + v_typed, + commitment.u.clone(), + y_ring, + rs.alpha, + rs.num_u, + rs.num_l, + ) + } else { + HachiStage2Verifier::new_with_claimed_w_eval( + batching_coeff, + level_proof.stage1.s_claim, + level_proof.stage2.next_w_eval, + r_stage1.clone(), + rs.alpha_evals_y, + rs.m_evals_x, + rs.tau1, + v_typed, + commitment.u.clone(), + y_ring, + rs.alpha, + rs.num_u, + rs.num_l, + ) + }; + if 
stage2_input_claim != SumcheckInstanceVerifier::input_claim(&stage2_verifier) { + return Err(HachiError::InvalidProof); + } + + let challenges = { + let _sumcheck_span = tracing::info_span!("stage2_sumcheck").entered(); + verify_sumcheck::( + &level_proof.stage2.sumcheck, + &stage2_verifier, + transcript, + |tr| tr.challenge_scalar(CHALLENGE_SUMCHECK_ROUND), + )? + }; + + Ok(challenges) +} + +fn trace(u: &CyclotomicRing) -> F { + let d = F::from_u64(D as u64); + u.coefficients()[0] * d +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::primitives::serialization::Compress; + use crate::protocol::commitment::CommitmentConfig; + use crate::protocol::hachi_poly_ops::DensePoly; + use crate::protocol::opening_point::{lagrange_weights, monomial_weights}; + use crate::protocol::transcript::Blake2bTranscript; + use crate::test_utils::F; + use crate::{CommitmentScheme, FromSmallInt}; + use std::sync::OnceLock; + + type Cfg = SmallTestCommitmentConfig; + const D: usize = Cfg::D; + type Scheme = HachiCommitmentScheme; + + fn make_dense_poly(num_vars: usize) -> (DensePoly, Vec) { + let len = 1usize << num_vars; + let evals: Vec = (0..len).map(|i| F::from_u64(i as u64)).collect(); + let poly = DensePoly::::from_field_evals(num_vars, &evals).unwrap(); + (poly, evals) + } + + fn make_verify_fixture( + num_vars: usize, + ) -> ( + HachiVerifierSetup, + RingCommitment, + HachiProof, + Vec, + F, + HachiCommitmentLayout, + ) { + let alpha = D.trailing_zeros() as usize; + let layout = Cfg::commitment_layout(num_vars).unwrap(); + let full_num_vars = layout.m_vars + layout.r_vars + alpha; + + let (poly, evals) = make_dense_poly(full_num_vars); + let setup = >::setup_prover(full_num_vars); + let verifier_setup = >::setup_verifier(&setup); + let (commitment, hint) = + >::commit(&poly, &setup, &layout).unwrap(); + + let opening_point: Vec = (0..full_num_vars) + .map(|i| F::from_u64((i + 2) as u64)) + .collect(); + let lw = lagrange_weights(&opening_point); + let opening: F = evals 
+ .iter() + .zip(lw.iter()) + .fold(F::zero(), |a, (&c, &w)| a + c * w); + + let mut prover_transcript = Blake2bTranscript::::new(b"test/prove"); + let proof = >::prove( + &setup, + &poly, + &opening_point, + hint, + &mut prover_transcript, + &commitment, + BasisMode::Lagrange, + &layout, + ) + .unwrap(); + + ( + verifier_setup, + commitment, + proof, + opening_point, + opening, + layout, + ) + } + + fn serialize_uncompressed_proof(proof: &HachiProof) -> Vec { + let mut bytes = Vec::new(); + proof.serialize_uncompressed(&mut bytes).unwrap(); + bytes + } + + fn level0_next_w_commitment_ring_dim_offset(proof: &HachiProof) -> usize { + let level0 = &proof.levels[0]; + 4 + level0.y_ring.serialized_size(Compress::No) + + level0.v.serialized_size(Compress::No) + + level0.stage1.sumcheck.serialized_size(Compress::No) + + level0.stage1.s_claim.serialized_size(Compress::No) + + level0.stage2.sumcheck.serialized_size(Compress::No) + } + + #[test] + fn verify_passes_for_consistent_opening() { + let alpha = D.trailing_zeros() as usize; + let layout = Cfg::commitment_layout(16).unwrap(); + let num_vars = layout.m_vars + layout.r_vars + alpha; + + let (poly, evals) = make_dense_poly(num_vars); + + let setup = >::setup_prover(num_vars); + let verifier_setup = >::setup_verifier(&setup); + + let (commitment, hint) = + >::commit(&poly, &setup, &layout).unwrap(); + + let opening_point: Vec = (0..num_vars).map(|i| F::from_u64((i + 2) as u64)).collect(); + let lw = lagrange_weights(&opening_point); + let opening: F = evals + .iter() + .zip(lw.iter()) + .fold(F::zero(), |a, (&c, &w)| a + c * w); + + let mut prover_transcript = Blake2bTranscript::::new(b"test/prove"); + let proof = >::prove( + &setup, + &poly, + &opening_point, + hint, + &mut prover_transcript, + &commitment, + BasisMode::Lagrange, + &layout, + ) + .unwrap(); + + let mut verifier_transcript = Blake2bTranscript::::new(b"test/prove"); + let result = >::verify( + &proof, + &verifier_setup, + &mut verifier_transcript, + 
&opening_point, + &opening, + &commitment, + BasisMode::Lagrange, + &layout, + ); + + assert!(result.is_ok()); + } + + #[test] + fn verify_rejects_wrong_opening() { + let alpha = D.trailing_zeros() as usize; + let layout = Cfg::commitment_layout(16).unwrap(); + let num_vars = layout.m_vars + layout.r_vars + alpha; + + let (poly, evals) = make_dense_poly(num_vars); + + let setup = >::setup_prover(num_vars); + let verifier_setup = >::setup_verifier(&setup); + + let (commitment, hint) = + >::commit(&poly, &setup, &layout).unwrap(); + + let opening_point: Vec = (0..num_vars).map(|i| F::from_u64((i + 2) as u64)).collect(); + let lw = lagrange_weights(&opening_point); + let opening: F = evals + .iter() + .zip(lw.iter()) + .fold(F::zero(), |a, (&c, &w)| a + c * w); + + let mut prover_transcript = Blake2bTranscript::::new(b"test/prove"); + let proof = >::prove( + &setup, + &poly, + &opening_point, + hint, + &mut prover_transcript, + &commitment, + BasisMode::Lagrange, + &layout, + ) + .unwrap(); + + let wrong_opening = opening + F::one(); + let mut verifier_transcript = Blake2bTranscript::::new(b"test/prove"); + let result = >::verify( + &proof, + &verifier_setup, + &mut verifier_transcript, + &opening_point, + &wrong_opening, + &commitment, + BasisMode::Lagrange, + &layout, + ); + + assert!( + result.is_err(), + "verify must reject an incorrect opening value" + ); + } + + #[test] + fn verify_rejects_malformed_y_ring_dimension_without_panicking() { + let (verifier_setup, commitment, proof, opening_point, opening, layout) = + make_verify_fixture(16); + let mut bytes = serialize_uncompressed_proof(&proof); + let bad_d = if D == 1 { 2 } else { 1 }; + bytes[4..8].copy_from_slice(&(bad_d as u32).to_le_bytes()); + let malformed = HachiProof::::deserialize_uncompressed_unchecked(&bytes[..]).unwrap(); + + let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| { + let mut verifier_transcript = Blake2bTranscript::::new(b"test/prove"); + >::verify( + &malformed, + 
&verifier_setup, + &mut verifier_transcript, + &opening_point, + &opening, + &commitment, + BasisMode::Lagrange, + &layout, + ) + })); + + assert!(matches!(result, Ok(Err(HachiError::InvalidProof)))); + } + + #[test] + fn verify_rejects_malformed_next_commitment_dimension_without_panicking() { + let HandoffFixture { + verifier_setup, + commitment, + proof, + opening_point, + opening, + layout, + } = handoff_fixture(); + let mut bytes = serialize_uncompressed_proof(&proof); + let offset = level0_next_w_commitment_ring_dim_offset(&proof); + let current_d = proof.levels[0].w_commit_d(); + let bad_d = if current_d == 1 { 2 } else { 1 }; + bytes[offset..offset + 4].copy_from_slice(&(bad_d as u32).to_le_bytes()); + let malformed = + HachiProof::::deserialize_uncompressed_unchecked(&bytes[..]).unwrap(); + + let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| { + let mut verifier_transcript = + Blake2bTranscript::::new(HANDOFF_FIXTURE_LABEL); + >::verify( + &malformed, + &verifier_setup, + &mut verifier_transcript, + &opening_point, + &opening, + &commitment, + BasisMode::Lagrange, + &layout, + ) + })); + + assert!(matches!(result, Ok(Err(HachiError::InvalidProof)))); + } + + #[test] + fn monomial_basis_prove_verify_round_trip() { + let alpha = D.trailing_zeros() as usize; + let layout = Cfg::commitment_layout(16).unwrap(); + let num_vars = layout.m_vars + layout.r_vars + alpha; + let len = 1usize << num_vars; + + let coeffs: Vec = (0..len).map(|i| F::from_u64(i as u64)).collect(); + let poly = DensePoly::::from_field_evals(num_vars, &coeffs).unwrap(); + + let setup = >::setup_prover(num_vars); + let verifier_setup = >::setup_verifier(&setup); + + let (commitment, hint) = + >::commit(&poly, &setup, &layout).unwrap(); + + let opening_point: Vec = (0..num_vars).map(|i| F::from_u64((i + 2) as u64)).collect(); + + let mw = monomial_weights(&opening_point); + let opening: F = coeffs + .iter() + .zip(mw.iter()) + .fold(F::zero(), |acc, (&c, &w)| acc + c * w); + 
+ let mut prover_transcript = Blake2bTranscript::::new(b"test/monomial"); + let proof = >::prove( + &setup, + &poly, + &opening_point, + hint, + &mut prover_transcript, + &commitment, + BasisMode::Monomial, + &layout, + ) + .unwrap(); + + let mut verifier_transcript = Blake2bTranscript::::new(b"test/monomial"); + let result = >::verify( + &proof, + &verifier_setup, + &mut verifier_transcript, + &opening_point, + &opening, + &commitment, + BasisMode::Monomial, + &layout, + ); + + assert!( + result.is_ok(), + "monomial-basis proof should verify: {result:?}" + ); + } + + /// A config identical to `DynamicSmallTestCommitmentConfig` but with a + /// handoff threshold of 0 (always hand off to Labrador). + #[derive(Clone, Copy, Debug, Default)] + struct HandoffTestConfig; + + impl CommitmentConfig for HandoffTestConfig { + const D: usize = 64; + const N_A: usize = 8; + const N_B: usize = 4; + const N_D: usize = 4; + const CHALLENGE_WEIGHT: usize = 3; + + fn decomposition() -> crate::protocol::commitment::DecompositionParams { + crate::protocol::commitment::DecompositionParams { + log_basis: 3, + log_commit_bound: 32, + log_open_bound: Some(128), + } + } + + fn commitment_layout( + max_num_vars: usize, + ) -> Result + { + let alpha = Self::D.trailing_zeros() as usize; + let reduced_vars = max_num_vars.checked_sub(alpha).ok_or_else(|| { + crate::error::HachiError::InvalidSetup( + "max_num_vars is smaller than alpha".to_string(), + ) + })?; + if reduced_vars == 0 { + return Err(crate::error::HachiError::InvalidSetup( + "need at least 1 reduced variable".to_string(), + )); + } + let m_vars = reduced_vars.div_ceil(2); + let r_vars = reduced_vars - m_vars; + crate::protocol::commitment::HachiCommitmentLayout::new::( + m_vars, + r_vars, + &Self::decomposition(), + ) + } + + fn labrador_handoff_threshold() -> usize { + 0 + } + } + + type HandoffField = crate::algebra::Fp128<0xfffffffffffffffffffffffffffffeed>; + type HandoffScheme = HachiCommitmentScheme<{ HandoffTestConfig::D }, 
HandoffTestConfig>; + const HANDOFF_FIXTURE_LABEL: &[u8] = b"test/labrador-tail-fixture"; + const HANDOFF_SPLICE_LABEL: &[u8] = b"test/labrador-tail-splice"; + + #[derive(Clone)] + struct HandoffFixture { + verifier_setup: HachiVerifierSetup, + layout: HachiCommitmentLayout, + commitment: RingCommitment, + opening_point: Vec, + opening: HandoffField, + proof: HachiProof, + } + + #[derive(Clone, Copy, Debug, Default)] + struct VariableDHandoffTestConfig; + + impl CommitmentConfig for VariableDHandoffTestConfig { + const D: usize = 256; + const N_A: usize = 1; + const N_B: usize = 1; + const N_D: usize = 1; + const CHALLENGE_WEIGHT: usize = 23; + + fn decomposition() -> crate::protocol::commitment::DecompositionParams { + crate::protocol::commitment::Fp128HalvingDCommitmentConfig::decomposition() + } + + fn commitment_layout( + max_num_vars: usize, + ) -> Result + { + let alpha = Self::D.trailing_zeros() as usize; + let reduced_vars = max_num_vars.checked_sub(alpha).ok_or_else(|| { + crate::error::HachiError::InvalidSetup( + "max_num_vars is smaller than alpha".to_string(), + ) + })?; + if reduced_vars == 0 { + return Err(crate::error::HachiError::InvalidSetup( + "max_num_vars must leave at least one outer variable".to_string(), + )); + } + let m_vars = reduced_vars.div_ceil(2); + let r_vars = reduced_vars - m_vars; + crate::protocol::commitment::HachiCommitmentLayout::new::( + m_vars, + r_vars, + &Self::decomposition(), + ) + } + + fn d_at_level(level: usize, _w_num_vars: usize) -> usize { + match level { + 0 => 256, + _ => 128, + } + } + + fn n_a_at_level(level: usize) -> usize { + match level { + 0 => 1, + _ => 2, + } + } + + fn challenge_weight_for_ring_dim(d: usize) -> usize { + match d { + 256 => 23, + 128 => 31, + _ => panic!("VariableDHandoffTestConfig: unsupported ring dim {d}"), + } + } + + fn labrador_handoff_threshold() -> usize { + 0 + } + } + + fn purge_test_setup_cache(_max_num_vars: usize) { + #[cfg(feature = "disk-persistence")] + { + let cache_dir = 
std::env::var("LOCALAPPDATA") + .map(std::path::PathBuf::from) + .or_else(|_| { + std::env::var("HOME").map(|home| { + let mut p = std::path::PathBuf::from(&home); + if p.join("Library/Caches").exists() { + p.push("Library/Caches"); + } else { + p.push(".cache"); + } + p + }) + }); + if let Ok(mut path) = cache_dir { + path.push("hachi"); + path.push(format!("hachi_{_max_num_vars}.setup")); + let _ = std::fs::remove_file(&path); + } + } + } + + fn make_handoff_fixture(eval_offset: u64, transcript_label: &[u8]) -> HandoffFixture { + const MAX_NUM_VARS: usize = 11; + const D: usize = HandoffTestConfig::D; + + let layout = HandoffTestConfig::commitment_layout(MAX_NUM_VARS).unwrap(); + let alpha = D.trailing_zeros() as usize; + let num_vars = layout.m_vars + layout.r_vars + alpha; + + purge_test_setup_cache(num_vars); + + let len = 1usize << num_vars; + let evals: Vec = (0..len) + .map(|i| HandoffField::from_u64(i as u64 + eval_offset)) + .collect(); + let poly = DensePoly::::from_field_evals(num_vars, &evals).unwrap(); + + let setup = >::setup_prover(num_vars); + let verifier_setup = + >::setup_verifier(&setup); + + let (commitment, hint) = + >::commit(&poly, &setup, &layout) + .unwrap(); + + let opening_point: Vec = (0..num_vars) + .map(|i| HandoffField::from_u64((i + 2) as u64)) + .collect(); + let lw = lagrange_weights(&opening_point); + let opening = evals + .iter() + .zip(lw.iter()) + .fold(HandoffField::zero(), |a, (&c, &w)| a + c * w); + + let mut prover_transcript = Blake2bTranscript::::new(transcript_label); + let proof = >::prove( + &setup, + &poly, + &opening_point, + hint, + &mut prover_transcript, + &commitment, + BasisMode::Lagrange, + &layout, + ) + .unwrap(); + + HandoffFixture { + verifier_setup, + layout, + commitment, + opening_point, + opening, + proof, + } + } + + fn handoff_fixture() -> HandoffFixture { + static FIXTURE: OnceLock = OnceLock::new(); + FIXTURE + .get_or_init(|| make_handoff_fixture(0, HANDOFF_FIXTURE_LABEL)) + .clone() + } + + fn 
handoff_splice_fixture_a() -> HandoffFixture { + static FIXTURE: OnceLock = OnceLock::new(); + FIXTURE + .get_or_init(|| make_handoff_fixture(0, HANDOFF_SPLICE_LABEL)) + .clone() + } + + fn handoff_splice_fixture_b() -> HandoffFixture { + static FIXTURE: OnceLock = OnceLock::new(); + FIXTURE + .get_or_init(|| make_handoff_fixture(17, HANDOFF_SPLICE_LABEL)) + .clone() + } + + fn mutate_labrador_tail_fixture( + mutator: impl FnOnce(&mut LabradorTail), + ) -> Option { + let mut fixture = handoff_fixture(); + let HachiProofTail::Labrador(tail) = &mut fixture.proof.tail else { + return None; + }; + mutator(tail); + Some(fixture) + } + + #[test] + fn labrador_tail_prove_verify_round_trip() { + let HandoffFixture { + verifier_setup, + layout, + commitment, + opening_point, + opening, + proof, + } = handoff_fixture(); + + let mut verifier_transcript = Blake2bTranscript::::new(HANDOFF_FIXTURE_LABEL); + let result = + >::verify( + &proof, + &verifier_setup, + &mut verifier_transcript, + &opening_point, + &opening, + &commitment, + BasisMode::Lagrange, + &layout, + ); + + assert!(result.is_ok(), "handoff proof should verify: {result:?}"); + } + + #[test] + fn labrador_tail_serialization_round_trip() { + let HandoffFixture { + verifier_setup, + layout, + commitment, + opening_point, + opening, + proof, + } = handoff_fixture(); + + let mut bytes = Vec::new(); + proof.serialize_uncompressed(&mut bytes).unwrap(); + assert_eq!(bytes.len(), proof.size()); + assert_eq!(bytes.len(), proof.serialized_size(Compress::No)); + + let decoded = HachiProof::::deserialize_uncompressed(&bytes[..]).unwrap(); + assert_eq!(decoded, proof); + + let mut verifier_transcript = Blake2bTranscript::::new(HANDOFF_FIXTURE_LABEL); + let result = + >::verify( + &decoded, + &verifier_setup, + &mut verifier_transcript, + &opening_point, + &opening, + &commitment, + BasisMode::Lagrange, + &layout, + ); + assert!( + result.is_ok(), + "round-tripped proof should verify: {result:?}" + ); + } + + #[test] + fn 
labrador_tail_rejects_spliced_or_mutated_payloads() { + let HandoffFixture { + verifier_setup, + layout, + commitment, + opening_point, + opening, + mut proof, + } = handoff_splice_fixture_a(); + let proof_b = handoff_splice_fixture_b().proof; + + proof.tail = proof_b.tail.clone(); + let mut verifier_transcript = Blake2bTranscript::::new(HANDOFF_SPLICE_LABEL); + let result = + >::verify( + &proof, + &verifier_setup, + &mut verifier_transcript, + &opening_point, + &opening, + &commitment, + BasisMode::Lagrange, + &layout, + ); + assert!(result.is_err(), "spliced handoff tail must be rejected"); + + let Some(HandoffFixture { + verifier_setup, + layout, + commitment, + opening_point, + opening, + proof, + }) = mutate_labrador_tail_fixture(|tail| { + let mut y_ring = tail.y_ring.to_single::<{ HandoffTestConfig::D }>(); + y_ring.coefficients_mut()[0] += HandoffField::one(); + tail.y_ring = FlatRingVec::from_single(&y_ring); + }) + else { + return; + }; + + let mut verifier_transcript = Blake2bTranscript::::new(HANDOFF_FIXTURE_LABEL); + let result = + >::verify( + &proof, + &verifier_setup, + &mut verifier_transcript, + &opening_point, + &opening, + &commitment, + BasisMode::Lagrange, + &layout, + ); + assert!(result.is_err(), "modified y_ring must be rejected"); + } + + #[test] + fn labrador_tail_rejects_malformed_tail_metadata() { + let Some(HandoffFixture { + verifier_setup, + layout, + commitment, + opening_point, + opening, + proof, + }) = mutate_labrador_tail_fixture(|tail| { + let last_level = tail + .labrador_proof + .levels + .last_mut() + .expect("tail proof should contain a Labrador level"); + last_level.config.tail = false; + }) + else { + return; + }; + + let mut verifier_transcript = Blake2bTranscript::::new(HANDOFF_FIXTURE_LABEL); + let result = + >::verify( + &proof, + &verifier_setup, + &mut verifier_transcript, + &opening_point, + &opening, + &commitment, + BasisMode::Lagrange, + &layout, + ); + assert!(result.is_err(), "tail/config mismatch must be 
rejected"); + + let Some(HandoffFixture { + verifier_setup, + layout, + commitment, + opening_point, + opening, + proof, + }) = mutate_labrador_tail_fixture(|tail| { + let last_level = tail + .labrador_proof + .levels + .last_mut() + .expect("tail proof should contain a Labrador level"); + last_level.jl_nonce = + crate::protocol::labrador::guardrails::LABRADOR_MAX_JL_NONCE_RETRIES + 1; + }) + else { + return; + }; + + let mut verifier_transcript = Blake2bTranscript::::new(HANDOFF_FIXTURE_LABEL); + let result = + >::verify( + &proof, + &verifier_setup, + &mut verifier_transcript, + &opening_point, + &opening, + &commitment, + BasisMode::Lagrange, + &layout, + ); + assert!(result.is_err(), "oversized JL nonce must be rejected"); + + let Some(HandoffFixture { + verifier_setup, + layout, + commitment, + opening_point, + opening, + proof, + }) = mutate_labrador_tail_fixture(|tail| { + let last_level = tail + .labrador_proof + .levels + .last_mut() + .expect("tail proof should contain a Labrador level"); + last_level.virtual_row_len = 1; + }) + else { + return; + }; + + let mut verifier_transcript = Blake2bTranscript::::new(HANDOFF_FIXTURE_LABEL); + let result = + >::verify( + &proof, + &verifier_setup, + &mut verifier_transcript, + &opening_point, + &opening, + &commitment, + BasisMode::Lagrange, + &layout, + ); + assert!(result.is_err(), "lossy reshape metadata must be rejected"); + } + + #[test] + fn variable_d_handoff_uses_current_commitment_dimension() { + std::thread::Builder::new() + .stack_size(256 * 1024 * 1024) + .spawn(|| { + type VarScheme = HachiCommitmentScheme< + { VariableDHandoffTestConfig::D }, + VariableDHandoffTestConfig, + >; + type GF = HandoffField; + + const MAX_NUM_VARS: usize = 10; + let layout = VariableDHandoffTestConfig::commitment_layout(MAX_NUM_VARS).unwrap(); + let alpha = VariableDHandoffTestConfig::D.trailing_zeros() as usize; + let num_vars = layout.m_vars + layout.r_vars + alpha; + purge_test_setup_cache(num_vars); + + let len = 1usize 
<< num_vars; + let evals: Vec = (0..len).map(|i| GF::from_u64(i as u64)).collect(); + let poly = DensePoly::::from_field_evals( + num_vars, &evals, + ) + .unwrap(); + + let setup = >::setup_prover(num_vars); + let (commitment, hint) = >::commit(&poly, &setup, &layout) + .unwrap(); + + let opening_point: Vec = (0..num_vars) + .map(|i| GF::from_u64((i + 2) as u64)) + .collect(); + let lw = lagrange_weights(&opening_point); + let _opening = evals + .iter() + .zip(lw.iter()) + .fold(GF::zero(), |a, (&c, &w)| a + c * w); + + let mut prover_transcript = + Blake2bTranscript::::new(b"test/variable-d-labrador-tail"); + let proof = + >::prove( + &setup, + &poly, + &opening_point, + hint, + &mut prover_transcript, + &commitment, + BasisMode::Lagrange, + &layout, + ) + .unwrap(); + + let verifier_setup = >::setup_verifier(&setup); + let opening = evals + .iter() + .zip(lw.iter()) + .fold(GF::zero(), |a, (&c, &w)| a + c * w); + let mut verifier_transcript = + Blake2bTranscript::::new(b"test/variable-d-labrador-tail"); + >::verify( + &proof, + &verifier_setup, + &mut verifier_transcript, + &opening_point, + &opening, + &commitment, + BasisMode::Lagrange, + &layout, + ) + .unwrap(); + + let carried_d = proof + .levels + .last() + .expect("expected at least one Hachi level") + .w_commit_d(); + if let HachiProofTail::Labrador(tail) = &proof.tail { + assert_eq!(tail.v.ring_dim(), carried_d); + assert_ne!(tail.v.ring_dim(), 64); + } + }) + .expect("failed to spawn variable-D handoff test") + .join() + .expect("variable-D handoff test panicked"); + } +} diff --git a/src/protocol/dispatch.rs b/src/protocol/dispatch.rs new file mode 100644 index 00000000..9fb22850 --- /dev/null +++ b/src/protocol/dispatch.rs @@ -0,0 +1,183 @@ +//! Runtime-to-const-generic dispatch for ring dimension D. +//! +//! The supported D values (all powers of 2 that admit a CRT+NTT decomposition) +//! are: 64, 128, 256, 512, 1024. + +/// Bridge a runtime `d: usize` to a const-generic `D` context. 
+/// +/// Calls `$body` with the matched const `D`. Inside `$body`, `D` is available +/// as a const generic parameter (via the generated function). +/// +/// # Supported dimensions +/// +/// 64, 128, 256, 512, 1024. +/// +/// # Panics +/// +/// Panics at runtime if `d` is not one of the supported values. +/// +/// # Examples +/// +/// ``` +/// use hachi_pcs::dispatch_ring_dim; +/// let ring_dim: usize = 256; +/// let result = dispatch_ring_dim!(ring_dim, |D| D * 2); +/// assert_eq!(result, 512); +/// ``` +#[macro_export] +macro_rules! dispatch_ring_dim { + ($d:expr, |$D:ident| $body:expr) => {{ + let __d = $d; + match __d { + 64 => { + const $D: usize = 64; + $body + } + 128 => { + const $D: usize = 128; + $body + } + 256 => { + const $D: usize = 256; + $body + } + 512 => { + const $D: usize = 512; + $body + } + 1024 => { + const $D: usize = 1024; + $body + } + _ => panic!("unsupported ring dimension: {__d}"), + } + }}; +} + +/// Like [`dispatch_ring_dim!`] but also lazily builds NTT caches for the +/// matched ring dimension from a [`crate::protocol::commitment::utils::ntt_cache::MultiDNttBundle`] and +/// [`crate::protocol::commitment::HachiExpandedSetup`]. +/// +/// Inside the body, `$D` is a const ring dimension and `$ntt_a`, `$ntt_b`, +/// `$ntt_d` are `&NttSlotCache` references. +/// +/// # Panics +/// +/// Panics at runtime if `d` is not one of the supported values. +#[macro_export] +macro_rules! 
dispatch_with_ntt { + ($d:expr, $ntt:expr, $expanded:expr, + |$D:ident, $ntt_a:ident, $ntt_b:ident, $ntt_d:ident| $body:expr) => {{ + let __d = $d; + match __d { + 64 => { + const $D: usize = 64; + let $ntt_a = ($ntt).A.get_or_build_64(&($expanded).A)?; + let $ntt_b = ($ntt).B.get_or_build_64(&($expanded).B)?; + let $ntt_d = ($ntt).D_mat.get_or_build_64(&($expanded).D_mat)?; + $body + } + 128 => { + const $D: usize = 128; + let $ntt_a = ($ntt).A.get_or_build_128(&($expanded).A)?; + let $ntt_b = ($ntt).B.get_or_build_128(&($expanded).B)?; + let $ntt_d = ($ntt).D_mat.get_or_build_128(&($expanded).D_mat)?; + $body + } + 256 => { + const $D: usize = 256; + let $ntt_a = ($ntt).A.get_or_build_256(&($expanded).A)?; + let $ntt_b = ($ntt).B.get_or_build_256(&($expanded).B)?; + let $ntt_d = ($ntt).D_mat.get_or_build_256(&($expanded).D_mat)?; + $body + } + 512 => { + const $D: usize = 512; + let $ntt_a = ($ntt).A.get_or_build_512(&($expanded).A)?; + let $ntt_b = ($ntt).B.get_or_build_512(&($expanded).B)?; + let $ntt_d = ($ntt).D_mat.get_or_build_512(&($expanded).D_mat)?; + $body + } + 1024 => { + const $D: usize = 1024; + let $ntt_a = ($ntt).A.get_or_build_1024(&($expanded).A)?; + let $ntt_b = ($ntt).B.get_or_build_1024(&($expanded).B)?; + let $ntt_d = ($ntt).D_mat.get_or_build_1024(&($expanded).D_mat)?; + $body + } + _ => panic!("unsupported ring dimension: {__d}"), + } + }}; +} + +/// Like [`dispatch_ring_dim!`] but lazily builds only the `D_mat` NTT cache for +/// the matched ring dimension from a +/// [`crate::protocol::commitment::utils::ntt_cache::MultiDNttCaches`] and +/// [`crate::protocol::commitment::HachiExpandedSetup`]. +/// +/// Inside the body, `$D` is a const ring dimension and `$ntt_d` is a +/// `&NttSlotCache` reference. +/// +/// # Panics +/// +/// Panics at runtime if `d` is not one of the supported values. +#[macro_export] +macro_rules! 
dispatch_with_d_ntt { + ($d:expr, $ntt_d_cache:expr, $expanded:expr, |$D:ident, $ntt_d:ident| $body:expr) => {{ + let __d = $d; + match __d { + 64 => { + const $D: usize = 64; + let $ntt_d = ($ntt_d_cache).get_or_build_64(&($expanded).D_mat)?; + $body + } + 128 => { + const $D: usize = 128; + let $ntt_d = ($ntt_d_cache).get_or_build_128(&($expanded).D_mat)?; + $body + } + 256 => { + const $D: usize = 256; + let $ntt_d = ($ntt_d_cache).get_or_build_256(&($expanded).D_mat)?; + $body + } + 512 => { + const $D: usize = 512; + let $ntt_d = ($ntt_d_cache).get_or_build_512(&($expanded).D_mat)?; + $body + } + 1024 => { + const $D: usize = 1024; + let $ntt_d = ($ntt_d_cache).get_or_build_1024(&($expanded).D_mat)?; + $body + } + _ => panic!("unsupported ring dimension: {__d}"), + } + }}; +} + +/// The set of supported ring dimensions for [`dispatch_ring_dim!`]. +pub const SUPPORTED_RING_DIMS: &[usize] = &[64, 128, 256, 512, 1024]; + +/// Returns true if `d` is one of the [`SUPPORTED_RING_DIMS`]. +#[inline] +pub fn is_supported_ring_dim(d: usize) -> bool { + SUPPORTED_RING_DIMS.contains(&d) +} + +#[cfg(test)] +mod tests { + #[test] + fn dispatch_ring_dim_basic() { + for &d in super::SUPPORTED_RING_DIMS { + let result = dispatch_ring_dim!(d, |D| D); + assert_eq!(result, d); + } + } + + #[test] + #[should_panic(expected = "unsupported ring dimension")] + fn dispatch_ring_dim_unsupported_panics() { + let _ = dispatch_ring_dim!(42, |D| D); + } +} diff --git a/src/protocol/hachi_poly_ops/decompose_fold_neon.rs b/src/protocol/hachi_poly_ops/decompose_fold_neon.rs new file mode 100644 index 00000000..0ee84394 --- /dev/null +++ b/src/protocol/hachi_poly_ops/decompose_fold_neon.rs @@ -0,0 +1,143 @@ +//! AArch64 NEON kernel for sparse-multiply-accumulate in `decompose_fold`. +//! +//! Rotates an i8 digit plane by each challenge position and accumulates +//! into an i32 accumulator using widening add/sub (`SADDW` / `SSUBW`). 
+ +use std::arch::aarch64::*; + +/// NEON sparse-multiply-accumulate. +/// +/// For each challenge term `(pos, coeff)`, rotates the `digit_plane` by `pos` +/// positions in the negacyclic ring (X^D + 1) and adds or subtracts the +/// widened i8 values into the i32 `acc`. +/// +/// # Safety +/// +/// - `digit_plane` must point to at least `d` valid i8 values. +/// - `acc` must point to at least `d` valid i32 values. +/// - `d` must be a multiple of 16. +#[target_feature(enable = "neon")] +pub(super) unsafe fn sparse_mul_acc_neon( + digit_plane: *const i8, + acc: *mut i32, + d: usize, + positions: &[u32], + coeffs: &[i16], +) { + debug_assert!(d % 16 == 0); + + for (&pos, &coeff) in positions.iter().zip(coeffs.iter()) { + let p = pos as usize; + let split = d - p; + + if coeff > 0 { + acc_rotated_add(digit_plane, acc, d, p, split); + } else { + acc_rotated_sub(digit_plane, acc, d, p, split); + } + } +} + +/// Add rotated digit plane: acc[i+p] += digits[i] for i in [0, split), +/// acc[i-split] -= digits[i] for i in [split, D) (negacyclic wrap). +#[inline(always)] +unsafe fn acc_rotated_add(digits: *const i8, acc: *mut i32, d: usize, p: usize, split: usize) { + // First segment: digits[0..split] -> acc[p..D], ADD + acc_segment_add(digits, acc.add(p), split); + // Second segment: digits[split..D] -> acc[0..p], SUB (negacyclic) + if p > 0 { + acc_segment_sub(digits.add(split), acc, p); + } + let _ = d; +} + +/// Sub rotated digit plane: acc[i+p] -= digits[i] for i in [0, split), +/// acc[i-split] += digits[i] for i in [split, D) (negacyclic wrap). +#[inline(always)] +unsafe fn acc_rotated_sub(digits: *const i8, acc: *mut i32, d: usize, p: usize, split: usize) { + // First segment: digits[0..split] -> acc[p..D], SUB + acc_segment_sub(digits, acc.add(p), split); + // Second segment: digits[split..D] -> acc[0..p], ADD (negacyclic) + if p > 0 { + acc_segment_add(digits.add(split), acc, p); + } + let _ = d; +} + +/// Widen i8 source values to i32 and ADD into accumulator. 
+/// Handles arbitrary length (processes 16 at a time, then remainder). +#[inline(always)] +unsafe fn acc_segment_add(src: *const i8, dst: *mut i32, len: usize) { + let chunks = len / 16; + let rem = len % 16; + + for i in 0..chunks { + let offset = i * 16; + let v = vld1q_s8(src.add(offset)); + + let lo8 = vget_low_s8(v); + let hi8 = vget_high_s8(v); + let lo16 = vmovl_s8(lo8); + let hi16 = vmovl_s8(hi8); + + let s0 = vmovl_s16(vget_low_s16(lo16)); + let s1 = vmovl_s16(vget_high_s16(lo16)); + let s2 = vmovl_s16(vget_low_s16(hi16)); + let s3 = vmovl_s16(vget_high_s16(hi16)); + + let d0 = vld1q_s32(dst.add(offset)); + let d1 = vld1q_s32(dst.add(offset + 4)); + let d2 = vld1q_s32(dst.add(offset + 8)); + let d3 = vld1q_s32(dst.add(offset + 12)); + + vst1q_s32(dst.add(offset), vaddq_s32(d0, s0)); + vst1q_s32(dst.add(offset + 4), vaddq_s32(d1, s1)); + vst1q_s32(dst.add(offset + 8), vaddq_s32(d2, s2)); + vst1q_s32(dst.add(offset + 12), vaddq_s32(d3, s3)); + } + + let base = chunks * 16; + for i in 0..rem { + let val = *src.add(base + i) as i32; + *dst.add(base + i) += val; + } +} + +/// Widen i8 source values to i32 and SUB from accumulator. +/// Handles arbitrary length (processes 16 at a time, then remainder). 
+#[inline(always)] +unsafe fn acc_segment_sub(src: *const i8, dst: *mut i32, len: usize) { + let chunks = len / 16; + let rem = len % 16; + + for i in 0..chunks { + let offset = i * 16; + let v = vld1q_s8(src.add(offset)); + + let lo8 = vget_low_s8(v); + let hi8 = vget_high_s8(v); + let lo16 = vmovl_s8(lo8); + let hi16 = vmovl_s8(hi8); + + let s0 = vmovl_s16(vget_low_s16(lo16)); + let s1 = vmovl_s16(vget_high_s16(lo16)); + let s2 = vmovl_s16(vget_low_s16(hi16)); + let s3 = vmovl_s16(vget_high_s16(hi16)); + + let d0 = vld1q_s32(dst.add(offset)); + let d1 = vld1q_s32(dst.add(offset + 4)); + let d2 = vld1q_s32(dst.add(offset + 8)); + let d3 = vld1q_s32(dst.add(offset + 12)); + + vst1q_s32(dst.add(offset), vsubq_s32(d0, s0)); + vst1q_s32(dst.add(offset + 4), vsubq_s32(d1, s1)); + vst1q_s32(dst.add(offset + 8), vsubq_s32(d2, s2)); + vst1q_s32(dst.add(offset + 12), vsubq_s32(d3, s3)); + } + + let base = chunks * 16; + for i in 0..rem { + let val = *src.add(base + i) as i32; + *dst.add(base + i) -= val; + } +} diff --git a/src/protocol/hachi_poly_ops/mod.rs b/src/protocol/hachi_poly_ops/mod.rs new file mode 100644 index 00000000..fbb1f5e8 --- /dev/null +++ b/src/protocol/hachi_poly_ops/mod.rs @@ -0,0 +1,1508 @@ +//! Operation-centric polynomial trait for the Hachi commitment scheme. +//! +//! [`HachiPolyOps`] exposes the four operations the Hachi commit/prove paths +//! need from a polynomial, rather than raw coefficient access. Each +//! implementation handles every operation in its own optimal way: +//! +//! - [`DensePoly`] — standard dense algorithms (decompose + NTT matvec). +//! - [`OneHotPoly`] — sparse monomial tricks, avoids all inner ring +//! multiplications. +//! +//! # Extensibility +//! +//! This trait is coupled to power-of-2 cyclotomic rings +//! ([`CyclotomicRing`]). When non-power-of-2 rings are added, the trait +//! signature will change. Additional operation methods may be added as the +//! protocol evolves. 
+ +use crate::algebra::fields::wide::HasWide; +use crate::algebra::ring::sparse_challenge::SparseChallenge; +use crate::algebra::CyclotomicRing; +use crate::error::HachiError; +#[cfg(feature = "parallel")] +use crate::parallel::*; +use crate::protocol::commitment::onehot::{ + inner_ajtai_onehot_wide, map_onehot_to_sparse_blocks, SparseBlockEntry, +}; +use crate::protocol::commitment::utils::crt_ntt::NttSlotCache; +use crate::protocol::commitment::utils::flat_matrix::FlatMatrix; +use crate::protocol::commitment::utils::linear::{ + decompose_rows_i8, mat_vec_mul_ntt_digits_i8, mat_vec_mul_ntt_i8, +}; +use crate::{cfg_fold_reduce, cfg_into_iter, cfg_iter, CanonicalField, FieldCore}; +use std::array::from_fn; +use std::marker::PhantomData; + +#[cfg(target_arch = "aarch64")] +use crate::algebra::ntt::neon; + +#[cfg(target_arch = "aarch64")] +mod decompose_fold_neon; + +/// Precomputed constants for balanced base-b decomposition. +struct DecomposeParams { + half_q: u128, + q: u128, + mask: i128, + half_b: i128, + b_val: i128, + log_basis: u32, +} + +/// Decompose all D coefficients of a ring element into balanced base-b digits, +/// storing results in digit-major order for subsequent SIMD scatter. +/// +/// Uses K=3 interleaved carry chains to saturate ALU throughput (3x ILP gain +/// over processing one coefficient at a time on out-of-order cores). +/// +/// `digit_buf` is `[num_digits][D]` in i8, OVERWRITTEN (not accumulated). 
+#[inline(never)] +fn decompose_ring_interleaved( + ring: &CyclotomicRing, + digit_buf: &mut [Vec], + num_digits: usize, + p: &DecomposeParams, +) { + let bulk_end = D - (D % 3); + + for base in (0..bulk_end).step_by(3) { + let mut c0 = to_signed(ring.coeffs[base].to_canonical_u128(), p); + let mut c1 = to_signed(ring.coeffs[base + 1].to_canonical_u128(), p); + let mut c2 = to_signed(ring.coeffs[base + 2].to_canonical_u128(), p); + + for plane in digit_buf.iter_mut().take(num_digits) { + let d0 = extract_balanced_digit(&mut c0, p); + let d1 = extract_balanced_digit(&mut c1, p); + let d2 = extract_balanced_digit(&mut c2, p); + plane[base] = d0 as i8; + plane[base + 1] = d1 as i8; + plane[base + 2] = d2 as i8; + } + } + + for idx in bulk_end..D { + let mut c = to_signed(ring.coeffs[idx].to_canonical_u128(), p); + for plane in digit_buf.iter_mut().take(num_digits) { + plane[idx] = extract_balanced_digit(&mut c, p) as i8; + } + } +} + +#[inline(never)] +fn decompose_ring_single_digit( + ring: &CyclotomicRing, + digit_plane: &mut [i8; D], + p: &DecomposeParams, +) { + for (dst, coeff) in digit_plane.iter_mut().zip(ring.coeffs.iter()) { + let centered = to_signed(coeff.to_canonical_u128(), p); + debug_assert!( + centered >= -(1i128 << (p.log_basis - 1)) && centered < (1i128 << (p.log_basis - 1)) + ); + *dst = centered as i8; + } +} + +#[inline(always)] +fn to_signed(canonical: u128, p: &DecomposeParams) -> i128 { + if canonical > p.half_q { + -((p.q - canonical) as i128) + } else { + canonical as i128 + } +} + +#[inline(always)] +fn try_centered_i8(coeff: F, q: u128, half_q: u128) -> Option { + let canonical = coeff.to_canonical_u128(); + let centered = if canonical > half_q { + -((q - canonical) as i128) + } else { + canonical as i128 + }; + if (i8::MIN as i128..=i8::MAX as i128).contains(¢ered) { + Some(centered as i8) + } else { + None + } +} + +fn try_small_i8_cache_from_ring_coeffs( + coeffs: &[CyclotomicRing], +) -> Option> { + let q = 
(-F::one()).to_canonical_u128() + 1; + let half_q = q / 2; + let mut out = Vec::with_capacity(coeffs.len()); + + for ring in coeffs { + let mut digits = [0i8; D]; + for (dst, coeff) in digits.iter_mut().zip(ring.coeffs.iter()) { + *dst = try_centered_i8(*coeff, q, half_q)?; + } + out.push(digits); + } + + Some(out) +} + +#[inline(always)] +fn extract_balanced_digit(c: &mut i128, p: &DecomposeParams) -> i32 { + let d = *c & p.mask; + let balanced = if d >= p.half_b { d - p.b_val } else { d }; + *c = (*c - balanced) >> p.log_basis; + balanced as i32 +} + +/// Scalar sparse-multiply-accumulate: accumulate `challenge * digit_plane` +/// into `acc` using the rotate-and-add formulation. +/// +/// `digit_plane` is `[i8; D]`, `acc` is `[i32; D]`. +/// Each challenge term rotates the digit plane and adds/subtracts contiguously. +fn sparse_mul_acc_scalar( + digit_plane: &[i8], + challenge: &SparseChallenge, + acc: &mut [i32; D], +) { + for (&pos, &coeff) in challenge.positions.iter().zip(challenge.coeffs.iter()) { + let p = pos as usize; + let split = D - p; + if coeff > 0 { + for i in 0..split { + acc[i + p] += digit_plane[i] as i32; + } + for i in split..D { + acc[i - split] -= digit_plane[i] as i32; + } + } else { + for i in 0..split { + acc[i + p] -= digit_plane[i] as i32; + } + for i in split..D { + acc[i - split] += digit_plane[i] as i32; + } + } + } +} + +/// Dispatch to NEON or scalar sparse-multiply-accumulate. 
+#[inline(always)] +fn sparse_mul_acc( + digit_plane: &[i8], + challenge: &SparseChallenge, + acc: &mut [i32; D], +) { + #[cfg(target_arch = "aarch64")] + { + if neon::use_neon_ntt() { + unsafe { + decompose_fold_neon::sparse_mul_acc_neon( + digit_plane.as_ptr(), + acc.as_mut_ptr(), + D, + &challenge.positions, + &challenge.coeffs, + ); + } + return; + } + } + sparse_mul_acc_scalar::(digit_plane, challenge, acc); +} + +#[inline(always)] +fn accum_onehot_coeff( + acc: &mut [i32; D], + coeff_idx: usize, + challenge: &SparseChallenge, +) { + debug_assert!(coeff_idx < D); + for (&pos, &coeff) in challenge.positions.iter().zip(challenge.coeffs.iter()) { + let target = coeff_idx + pos as usize; + if target < D { + acc[target] += coeff as i32; + } else { + acc[target - D] -= coeff as i32; + } + } +} + +#[inline(always)] +fn accum_onehot_entry( + acc: &mut [i32; D], + entry: &SparseBlockEntry, + challenge: &SparseChallenge, +) { + for &coeff_idx in &entry.nonzero_coeffs { + accum_onehot_coeff::(acc, coeff_idx, challenge); + } +} + +fn signed_accum_to_ring( + coeff_accum: [i32; D], + modulus: u128, +) -> CyclotomicRing { + let coeffs = from_fn(|k| { + let v = coeff_accum[k]; + if v >= 0 { + F::from_canonical_u128_reduced(v as u128) + } else { + F::from_canonical_u128_reduced(modulus - ((-v) as u128)) + } + }); + CyclotomicRing::from_coefficients(coeffs) +} + +fn build_decompose_fold_witness( + centered_coeffs: Vec<[i32; D]>, + modulus: u128, +) -> DecomposeFoldWitness { + let centered_inf_norm = centered_coeffs + .iter() + .flat_map(|row| row.iter()) + .map(|coeff| coeff.unsigned_abs()) + .max() + .unwrap_or(0); + let z_pre = cfg_iter!(centered_coeffs) + .map(|coeff_accum| signed_accum_to_ring::(*coeff_accum, modulus)) + .collect(); + DecomposeFoldWitness { + z_pre, + centered_coeffs, + centered_inf_norm, + } +} + +fn recompose_commit_inner_blocks( + t_hat_blocks: &[Vec<[i8; D]>], + num_digits_open: usize, + log_basis: u32, +) -> Result>>, HachiError> { + if num_digits_open 
== 0 { + return Err(HachiError::InvalidSetup( + "num_digits_open must be nonzero when recomposing commit witness".to_string(), + )); + } + t_hat_blocks + .iter() + .map(|block| { + if block.len() % num_digits_open != 0 { + return Err(HachiError::InvalidSetup(format!( + "t_hat block has {} planes, expected a multiple of num_digits_open={num_digits_open}", + block.len() + ))); + } + Ok(block + .chunks(num_digits_open) + .map(|digits| CyclotomicRing::gadget_recompose_pow2_i8(digits, log_basis)) + .collect()) + }) + .collect() +} + +/// Prover-side output of the decompose + challenge-fold step. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct DecomposeFoldWitness { + /// Folded witness rows in ring form. + pub z_pre: Vec>, + /// Centered integer coefficients for each `z_pre` row. + pub centered_coeffs: Vec<[i32; D]>, + /// Infinity norm of `centered_coeffs`. + pub centered_inf_norm: u32, +} + +/// Prover-side output of the inner Ajtai commit step. +pub struct CommitInnerWitness { + /// Undecomposed `t_i = A * s_i` rows, grouped by block. + pub t: Vec>>, + /// Decomposed `t_hat_i = G^{-1}(t_i)` rows, grouped by block. + pub t_hat: Vec>, +} + +/// Operations the Hachi commitment scheme needs from a polynomial. +/// +/// The four methods correspond to the four places in commit/prove that consume +/// polynomial data. Implementations decide *how* to carry out each operation +/// (dense decompose + NTT, sparse monomial tricks, streaming, etc.). +pub trait HachiPolyOps: Clone + Send + Sync { + /// Per-polynomial cache type for the A-matrix commit path. + /// + /// `DensePoly` uses `NttSlotCache` (CRT+NTT of A for dense mat-vec). + /// `OneHotPoly` uses `()` (one-hot commit bypasses NTT entirely). + type CommitCache: Send + Sync; + + /// Total number of ring elements in the polynomial. + fn num_ring_elems(&self) -> usize; + + /// **Op 1 — prove: ring-space evaluation.** + /// + /// Computes the global weighted sum `y = Σᵢ scalars[i] · self[i]`. 
+ /// + /// `scalars` has length >= `num_ring_elems`; excess entries are ignored. + fn evaluate_ring(&self, scalars: &[F]) -> CyclotomicRing; + + /// **Op 2 — prove: per-block fold.** + /// + /// For each contiguous block of `block_len` ring elements, computes + /// `Σⱼ scalars[j] · self[i·block_len + j]`. + /// + /// Returns one ring element per block (total `ceil(num_ring_elems / block_len)`). + /// `scalars` has length `block_len`. + fn fold_blocks(&self, scalars: &[F], block_len: usize) -> Vec>; + + /// Fused fold + evaluation in a single pass over the polynomial. + /// + /// `eval_outer_scalars` is the per-block weight vector `b` (size `num_blocks`). + /// `fold_scalars` is the per-element-in-block weight vector `a` (size `block_len`). + /// + /// The full evaluation scalars factor as `outer_weights[i*block_len + j] = b[i] * a[j]`, + /// so `eval = Σ_i b[i] * fold(a)[i]` — derived from the fold result without + /// materializing the full `2^(m_vars + r_vars)` weight vector. + fn evaluate_and_fold( + &self, + eval_outer_scalars: &[F], + fold_scalars: &[F], + block_len: usize, + ) -> (CyclotomicRing, Vec>) { + let folded = self.fold_blocks(fold_scalars, block_len); + let eval = folded + .iter() + .zip(eval_outer_scalars.iter()) + .fold(CyclotomicRing::::zero(), |acc, (f_i, s_i)| { + acc + f_i.scale(s_i) + }); + (eval, folded) + } + + /// **Op 3 — prove: decompose + challenge-fold.** + /// + /// For each block of `block_len` ring elements: + /// 1. Decompose: `sᵢ = G⁻¹(blockᵢ)` via `balanced_decompose_pow2(num_digits, log_basis)`. + /// 2. Accumulate: `z += cᵢ ⊗ sᵢ` (sparse challenge multiplication). + /// + /// Returns the folded witness `z_pre` of length `block_len · num_digits` + /// together with centered coefficient rows that later prover steps can reuse. 
+ fn decompose_fold( + &self, + challenges: &[SparseChallenge], + block_len: usize, + num_digits: usize, + log_basis: u32, + ) -> DecomposeFoldWitness; + + /// **Op 4 — commit: per-block inner Ajtai.** + /// + /// For each block of `block_len` ring elements: + /// 1. `sᵢ = G⁻¹(blockᵢ)` with `num_digits_commit` levels. + /// 2. `tᵢ = A · sᵢ` (matrix-vector multiply via NTT cache or sparse path). + /// 3. `t̂ᵢ = G⁻¹(tᵢ)` with `num_digits_open` levels (t has full-field + /// coefficients regardless of s's digit count). + /// + /// Returns one `t̂ᵢ` vector per block as `[i8; D]` digit planes. + /// + /// # Errors + /// + /// Returns an error if the cached matrix-vector multiply fails. + fn commit_inner( + &self, + a_matrix: &FlatMatrix, + ntt_a: &NttSlotCache, + block_len: usize, + num_digits_commit: usize, + num_digits_open: usize, + log_basis: u32, + ) -> Result>, HachiError>; + + /// Like [`commit_inner`](Self::commit_inner), but also preserves the + /// undecomposed `t_i` rows for prover-side consumers that would otherwise + /// need to recompose `t_hat`. + /// + /// # Errors + /// + /// Returns an error if [`commit_inner`](Self::commit_inner) fails or if the + /// resulting `t_hat` blocks cannot be recomposed into full `t_i` rows. + fn commit_inner_witness( + &self, + a_matrix: &FlatMatrix, + ntt_a: &NttSlotCache, + block_len: usize, + num_digits_commit: usize, + num_digits_open: usize, + log_basis: u32, + ) -> Result, HachiError> + where + F: CanonicalField, + { + let t_hat = self.commit_inner( + a_matrix, + ntt_a, + block_len, + num_digits_commit, + num_digits_open, + log_basis, + )?; + let t = recompose_commit_inner_blocks::(&t_hat, num_digits_open, log_basis)?; + Ok(CommitInnerWitness { t, t_hat }) + } +} + +/// Dense polynomial: all ring coefficients materialized in memory. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct DensePoly { + /// Ring coefficients in sequential block order. 
+ pub coeffs: Vec>, + small_i8_coeffs: Option>, +} + +impl DensePoly { + /// Pack field-element evaluations into ring elements. + /// + /// The first `α = log₂(D)` variables become coefficient slots within each + /// ring element; the remaining variables index ring elements. + /// + /// # Errors + /// + /// Returns an error if `D` is not a power of two, `num_vars < log₂(D)`, or + /// `evals.len() != 2^num_vars`. + pub fn from_field_evals(num_vars: usize, evals: &[F]) -> Result { + if D == 0 || !D.is_power_of_two() { + return Err(HachiError::InvalidInput(format!( + "ring degree D={D} is not a power of two" + ))); + } + let alpha = D.trailing_zeros() as usize; + if num_vars < alpha { + return Err(HachiError::InvalidInput(format!( + "num_vars {num_vars} is smaller than alpha {alpha}" + ))); + } + let expected_len = 1usize + .checked_shl(num_vars as u32) + .ok_or_else(|| HachiError::InvalidInput(format!("2^{num_vars} does not fit usize")))?; + if evals.len() != expected_len { + return Err(HachiError::InvalidSize { + expected: expected_len, + actual: evals.len(), + }); + } + + let outer_len = expected_len / D; + let q = (-F::one()).to_canonical_u128() + 1; + let half_q = q / 2; + let mut coeffs = Vec::with_capacity(outer_len); + let mut small_i8_coeffs = Vec::with_capacity(outer_len); + let mut all_small_i8 = true; + + for i in 0..outer_len { + let slice = &evals[i * D..(i + 1) * D]; + coeffs.push(CyclotomicRing::from_slice(slice)); + + if all_small_i8 { + let mut digits = [0i8; D]; + for (dst, coeff) in digits.iter_mut().zip(slice.iter()) { + if let Some(centered) = try_centered_i8(*coeff, q, half_q) { + *dst = centered; + } else { + all_small_i8 = false; + break; + } + } + if all_small_i8 { + small_i8_coeffs.push(digits); + } + } + } + + Ok(Self { + coeffs, + small_i8_coeffs: all_small_i8.then_some(small_i8_coeffs), + }) + } + + /// Wrap an existing vector of ring elements. 
+ pub fn from_ring_coeffs(coeffs: Vec>) -> Self { + let small_i8_coeffs = try_small_i8_cache_from_ring_coeffs(&coeffs); + Self { + coeffs, + small_i8_coeffs, + } + } +} + +impl HachiPolyOps for DensePoly +where + F: FieldCore + CanonicalField, +{ + type CommitCache = NttSlotCache; + + fn num_ring_elems(&self) -> usize { + self.coeffs.len() + } + + fn evaluate_ring(&self, scalars: &[F]) -> CyclotomicRing { + #[cfg(feature = "parallel")] + { + self.coeffs + .par_iter() + .zip(scalars.par_iter()) + .fold( + || CyclotomicRing::::zero(), + |acc, (f_i, w_i)| acc + f_i.scale(w_i), + ) + .reduce(|| CyclotomicRing::::zero(), |a, b| a + b) + } + #[cfg(not(feature = "parallel"))] + { + self.coeffs + .iter() + .zip(scalars.iter()) + .fold(CyclotomicRing::::zero(), |acc, (f_i, w_i)| { + acc + f_i.scale(w_i) + }) + } + } + + fn fold_blocks(&self, scalars: &[F], block_len: usize) -> Vec> { + let n = self.coeffs.len(); + let num_blocks = n.div_ceil(block_len); + cfg_into_iter!(0..num_blocks) + .map(|i| { + let start = i * block_len; + let end = (start + block_len).min(n); + let block = &self.coeffs[start..end]; + let mut acc = CyclotomicRing::::zero(); + for (b_j, &a_j) in block.iter().zip(scalars.iter()) { + acc += b_j.scale(&a_j); + } + acc + }) + .collect() + } + + #[tracing::instrument(skip_all, name = "DensePoly::decompose_fold")] + fn decompose_fold( + &self, + challenges: &[SparseChallenge], + block_len: usize, + num_digits: usize, + log_basis: u32, + ) -> DecomposeFoldWitness { + let n = self.coeffs.len(); + let coeffs = &self.coeffs; + + let q = (-F::one()).to_canonical_u128() + 1; + let params = DecomposeParams { + half_q: q / 2, + q, + mask: (1i128 << log_basis) - 1, + half_b: 1i128 << (log_basis - 1), + b_val: 1i128 << log_basis, + log_basis, + }; + + // Single-digit dense configs (e.g. logbasis) can skip the generic + // multi-digit decomposition buffers and accumulate one centered digit + // plane per ring element directly. 
+ if num_digits == 1 { + if let Some(small_coeffs) = &self.small_i8_coeffs { + let coeff_accum: Vec<[i32; D]> = { + let _span = + tracing::info_span!("dense_single_digit_cached_accumulate").entered(); + cfg_into_iter!(0..block_len) + .map(|elem_idx| { + let mut z_local = [0i32; D]; + + for (block_idx, c_i) in challenges.iter().enumerate() { + let global_idx = block_idx * block_len + elem_idx; + if global_idx >= small_coeffs.len() { + continue; + } + sparse_mul_acc::(&small_coeffs[global_idx], c_i, &mut z_local); + } + + z_local + }) + .collect() + }; + + let _span = tracing::info_span!("dense_single_digit_convert").entered(); + return build_decompose_fold_witness::(coeff_accum, params.q); + } + + let coeff_accum: Vec<[i32; D]> = { + let _span = tracing::info_span!("dense_single_digit_accumulate").entered(); + cfg_into_iter!(0..block_len) + .map(|elem_idx| { + let mut z_local = [0i32; D]; + let mut digit_plane = [0i8; D]; + + for (block_idx, c_i) in challenges.iter().enumerate() { + let global_idx = block_idx * block_len + elem_idx; + if global_idx >= n { + continue; + } + let ring = &coeffs[global_idx]; + decompose_ring_single_digit::(ring, &mut digit_plane, ¶ms); + sparse_mul_acc::(&digit_plane, c_i, &mut z_local); + } + + z_local + }) + .collect() + }; + + let _span = tracing::info_span!("dense_single_digit_convert").entered(); + return build_decompose_fold_witness::(coeff_accum, params.q); + } + + // Two-phase approach: decompose ring element coefficients into i8 digit + // planes, then scatter via sparse polynomial multiply. 
+ let z_chunks: Vec> = { + let _span = tracing::info_span!("dense_multi_digit_accumulate").entered(); + cfg_into_iter!(0..block_len) + .map(|elem_idx| { + let mut z_local: Vec<[i32; D]> = vec![[0i32; D]; num_digits]; + let mut digit_buf: Vec> = vec![vec![0i8; D]; num_digits]; + + for (block_idx, c_i) in challenges.iter().enumerate() { + let global_idx = block_idx * block_len + elem_idx; + if global_idx >= n { + continue; + } + let ring = &coeffs[global_idx]; + decompose_ring_interleaved::( + ring, + &mut digit_buf, + num_digits, + ¶ms, + ); + + for digit in 0..num_digits { + sparse_mul_acc::(&digit_buf[digit], c_i, &mut z_local[digit]); + } + } + + z_local + }) + .collect() + }; + + let _span = tracing::info_span!("dense_multi_digit_convert").entered(); + let mut centered_coeffs = Vec::with_capacity(block_len * num_digits); + for chunk in z_chunks { + centered_coeffs.extend(chunk); + } + build_decompose_fold_witness::(centered_coeffs, params.q) + } + + #[tracing::instrument(skip_all, name = "DensePoly::commit_inner")] + fn commit_inner( + &self, + _a_matrix: &FlatMatrix, + ntt_a: &NttSlotCache, + block_len: usize, + num_digits_commit: usize, + num_digits_open: usize, + log_basis: u32, + ) -> Result>, HachiError> { + let n = self.coeffs.len(); + let num_blocks = n.div_ceil(block_len); + + let block_slices: Vec<&[CyclotomicRing]> = (0..num_blocks) + .map(|i| { + let start = i * block_len; + if start >= n { + &[] as &[CyclotomicRing] + } else { + &self.coeffs[start..(start + block_len).min(n)] + } + }) + .collect(); + + let t_all = mat_vec_mul_ntt_i8(ntt_a, &block_slices, num_digits_commit, log_basis); + + let results: Vec> = cfg_into_iter!(t_all) + .map(|t_i| decompose_rows_i8(&t_i, num_digits_open, log_basis)) + .collect(); + + Ok(results) + } + + fn commit_inner_witness( + &self, + _a_matrix: &FlatMatrix, + ntt_a: &NttSlotCache, + block_len: usize, + num_digits_commit: usize, + num_digits_open: usize, + log_basis: u32, + ) -> Result, HachiError> { + let n = 
self.coeffs.len(); + let num_blocks = n.div_ceil(block_len); + + let block_slices: Vec<&[CyclotomicRing]> = (0..num_blocks) + .map(|i| { + let start = i * block_len; + if start >= n { + &[] as &[CyclotomicRing] + } else { + &self.coeffs[start..(start + block_len).min(n)] + } + }) + .collect(); + + let t = mat_vec_mul_ntt_i8(ntt_a, &block_slices, num_digits_commit, log_basis); + let t_hat = cfg_iter!(t) + .map(|t_i| decompose_rows_i8(t_i, num_digits_open, log_basis)) + .collect(); + Ok(CommitInnerWitness { t, t_hat }) + } +} + +/// Ring polynomial whose coefficients are already balanced base-`2^log_basis` +/// digits. +/// +/// This is the recursive `w` witness used by Hachi's later prove levels. Unlike +/// [`DensePoly`], it can skip the `i8 -> field -> dense ring` round-trip and +/// operate on the digit planes directly. +#[derive(Debug, Clone)] +pub(crate) struct BalancedDigitPoly<'a, F: FieldCore, const D: usize> { + coeffs: &'a [[i8; D]], + padded_ring_elems: usize, + _marker: PhantomData, +} + +impl<'a, F: FieldCore, const D: usize> BalancedDigitPoly<'a, F, D> { + /// Wrap a flat digit vector laid out as consecutive ring coefficients. 
+ pub(crate) fn from_i8_digits(digits: &'a [i8]) -> Result { + let (coeffs, remainder) = digits.as_chunks::(); + if !remainder.is_empty() { + return Err(HachiError::InvalidSize { + expected: D, + actual: digits.len(), + }); + } + + Ok(Self { + coeffs, + padded_ring_elems: coeffs.len().next_power_of_two().max(1), + _marker: PhantomData, + }) + } + + #[inline] + fn block_slice(&self, block_idx: usize, block_len: usize) -> &'a [[i8; D]] { + let start = block_idx * block_len; + if start >= self.coeffs.len() { + &[] + } else { + &self.coeffs[start..(start + block_len).min(self.coeffs.len())] + } + } +} + +impl<'a, F, const D: usize> HachiPolyOps for BalancedDigitPoly<'a, F, D> +where + F: FieldCore + CanonicalField, +{ + type CommitCache = NttSlotCache; + + fn num_ring_elems(&self) -> usize { + self.padded_ring_elems + } + + fn evaluate_ring(&self, scalars: &[F]) -> CyclotomicRing { + let total = cfg_fold_reduce!( + 0..self.coeffs.len().min(scalars.len()), + || [F::zero(); D], + |mut acc: [F; D], idx| { + let scalar = scalars[idx]; + let digit = &self.coeffs[idx]; + for (coeff, &d) in acc.iter_mut().zip(digit.iter()) { + if d != 0 { + *coeff += scalar * F::from_i8(d); + } + } + acc + }, + |mut a: [F; D], b: [F; D]| { + for (a_coeff, b_coeff) in a.iter_mut().zip(b.iter()) { + *a_coeff += *b_coeff; + } + a + } + ); + CyclotomicRing::from_coefficients(total) + } + + fn fold_blocks(&self, scalars: &[F], block_len: usize) -> Vec> { + let num_blocks = self.num_ring_elems().div_ceil(block_len); + cfg_into_iter!(0..num_blocks) + .map(|block_idx| { + let mut acc = [F::zero(); D]; + for (ring, &scalar) in self + .block_slice(block_idx, block_len) + .iter() + .zip(scalars.iter()) + { + for (coeff, &d) in acc.iter_mut().zip(ring.iter()) { + if d != 0 { + *coeff += scalar * F::from_i8(d); + } + } + } + CyclotomicRing::from_coefficients(acc) + }) + .collect() + } + + fn decompose_fold( + &self, + challenges: &[SparseChallenge], + block_len: usize, + num_digits: usize, + _log_basis: 
u32, + ) -> DecomposeFoldWitness { + let inner_width = block_len * num_digits; + let num_blocks = self.num_ring_elems().div_ceil(block_len); + + let q = (-F::one()).to_canonical_u128() + 1; + let coeff_accum = cfg_fold_reduce!( + 0..challenges.len().min(num_blocks), + || vec![[0i32; D]; inner_width], + |mut z_local: Vec<[i32; D]>, block_idx| { + let challenge = &challenges[block_idx]; + for (elem_idx, digit_plane) in + self.block_slice(block_idx, block_len).iter().enumerate() + { + sparse_mul_acc::( + digit_plane, + challenge, + &mut z_local[elem_idx * num_digits], + ); + } + z_local + }, + |mut a: Vec<[i32; D]>, b: Vec<[i32; D]>| { + for (ai, bi) in a.iter_mut().zip(b.iter()) { + for (a_coeff, b_coeff) in ai.iter_mut().zip(bi.iter()) { + *a_coeff += *b_coeff; + } + } + a + } + ); + build_decompose_fold_witness::(coeff_accum, q) + } + + fn commit_inner( + &self, + _a_matrix: &FlatMatrix, + ntt_a: &NttSlotCache, + block_len: usize, + num_digits_commit: usize, + num_digits_open: usize, + log_basis: u32, + ) -> Result>, HachiError> { + let num_blocks = self.num_ring_elems().div_ceil(block_len); + let coeff_len = self.coeffs.len(); + + let t_all = if num_digits_commit == 1 { + let block_slices: Vec<&[[i8; D]]> = (0..num_blocks) + .map(|block_idx| self.block_slice(block_idx, block_len)) + .collect(); + mat_vec_mul_ntt_digits_i8(ntt_a, &block_slices) + } else { + let ring_elems: Vec> = self + .coeffs + .iter() + .map(|digit| { + let coeffs = from_fn(|k| F::from_i8(digit[k])); + CyclotomicRing::from_coefficients(coeffs) + }) + .collect(); + let block_slices: Vec<&[CyclotomicRing]> = (0..num_blocks) + .map(|block_idx| { + let start = block_idx * block_len; + if start >= coeff_len { + &[] as &[CyclotomicRing] + } else { + &ring_elems[start..(start + block_len).min(coeff_len)] + } + }) + .collect(); + mat_vec_mul_ntt_i8(ntt_a, &block_slices, num_digits_commit, log_basis) + }; + + let results = cfg_into_iter!(t_all) + .map(|t_i| decompose_rows_i8(&t_i, num_digits_open, 
log_basis)) + .collect(); + Ok(results) + } + + fn commit_inner_witness( + &self, + _a_matrix: &FlatMatrix, + ntt_a: &NttSlotCache, + block_len: usize, + num_digits_commit: usize, + num_digits_open: usize, + log_basis: u32, + ) -> Result, HachiError> { + let num_blocks = self.num_ring_elems().div_ceil(block_len); + let coeff_len = self.coeffs.len(); + + let t = if num_digits_commit == 1 { + let block_slices: Vec<&[[i8; D]]> = (0..num_blocks) + .map(|block_idx| self.block_slice(block_idx, block_len)) + .collect(); + mat_vec_mul_ntt_digits_i8(ntt_a, &block_slices) + } else { + let ring_elems: Vec> = self + .coeffs + .iter() + .map(|digit| { + let coeffs = from_fn(|k| F::from_i8(digit[k])); + CyclotomicRing::from_coefficients(coeffs) + }) + .collect(); + let block_slices: Vec<&[CyclotomicRing]> = (0..num_blocks) + .map(|block_idx| { + let start = block_idx * block_len; + if start >= coeff_len { + &[] as &[CyclotomicRing] + } else { + &ring_elems[start..(start + block_len).min(coeff_len)] + } + }) + .collect(); + mat_vec_mul_ntt_i8(ntt_a, &block_slices, num_digits_commit, log_basis) + }; + + let t_hat = cfg_iter!(t) + .map(|t_i| decompose_rows_i8(t_i, num_digits_open, log_basis)) + .collect(); + Ok(CommitInnerWitness { t, t_hat }) + } +} + +/// Types usable as one-hot position indices. +/// +/// Implemented for `u8`, `u16`, `u32`, and `usize`. +pub trait OneHotIndex: Copy + Send + Sync + std::fmt::Debug + 'static { + /// Convert to `usize` for indexing. + fn as_usize(self) -> usize; +} + +impl OneHotIndex for u8 { + #[inline] + fn as_usize(self) -> usize { + self as usize + } +} + +impl OneHotIndex for u16 { + #[inline] + fn as_usize(self) -> usize { + self as usize + } +} + +impl OneHotIndex for u32 { + #[inline] + fn as_usize(self) -> usize { + self as usize + } +} + +impl OneHotIndex for usize { + #[inline] + fn as_usize(self) -> usize { + self + } +} + +/// One-hot polynomial: sparse witness with at most one nonzero field element +/// per chunk of size `onehot_k`. 
+/// +/// Exploits sparsity in all four operations, avoiding inner ring +/// multiplications during commit and decomposing only nonzero monomials. +/// +/// Generic over `I`: the index type stored per chunk. Use `u8` when +/// `onehot_k <= 256` to cut per-entry memory from 16 bytes to 2 bytes. +#[derive(Debug, Clone)] +pub struct OneHotPoly { + onehot_k: usize, + indices: Vec>, + m_vars: usize, + sparse_blocks: Vec>, + _marker: PhantomData, +} + +impl OneHotPoly { + /// Build a one-hot polynomial from chunk size and hot-position indices. + /// + /// `indices[c]` is the hot position in chunk `c` (`None` for all-zero chunks). + /// + /// # Errors + /// + /// Returns an error if dimensions are inconsistent or any index is out of range. + pub fn new( + onehot_k: usize, + indices: Vec>, + r_vars: usize, + m_vars: usize, + ) -> Result { + let sparse_blocks = map_onehot_to_sparse_blocks(onehot_k, &indices, r_vars, m_vars, D)?; + Ok(Self { + onehot_k, + indices, + m_vars, + sparse_blocks, + _marker: PhantomData, + }) + } + + fn total_ring_elems(&self) -> usize { + let total_field = self.indices.len() * self.onehot_k; + total_field / D + } + + fn decompose_fold_regular_onehot( + &self, + challenges: &[SparseChallenge], + block_len: usize, + ) -> DecomposeFoldWitness + where + F: CanonicalField, + { + let num_blocks = challenges.len().min(self.sparse_blocks.len()); + let modulus = (-F::one()).to_canonical_u128() + 1; + let indices = &self.indices; + debug_assert_eq!(indices.len(), self.total_ring_elems()); + + let coeff_accum: Vec<[i32; D]> = { + let _span = tracing::info_span!("onehot_regular_accumulate").entered(); + cfg_into_iter!(0..block_len) + .map(|elem_idx| { + let mut coeffs = [0i32; D]; + let mut ring_idx = elem_idx; + for challenge in challenges.iter().take(num_blocks) { + if let Some(hot_idx) = indices[ring_idx] { + accum_onehot_coeff::(&mut coeffs, hot_idx.as_usize(), challenge); + } + ring_idx += block_len; + } + coeffs + }) + .collect() + }; + + let _span = 
tracing::info_span!("onehot_regular_convert").entered(); + build_decompose_fold_witness::(coeff_accum, modulus) + } + + fn decompose_fold_sparse_onehot( + &self, + challenges: &[SparseChallenge], + block_len: usize, + num_digits: usize, + ) -> DecomposeFoldWitness + where + F: CanonicalField, + { + let inner_width = block_len * num_digits; + let num_blocks = challenges.len().min(self.sparse_blocks.len()); + let modulus = (-F::one()).to_canonical_u128() + 1; + + let coeff_accum = { + let _span = tracing::info_span!("onehot_sparse_accumulate").entered(); + cfg_fold_reduce!( + 0..num_blocks, + || vec![[0i32; D]; inner_width], + |mut z_local: Vec<[i32; D]>, block_idx: usize| { + let challenge = &challenges[block_idx]; + for entry in &self.sparse_blocks[block_idx] { + let z_coeffs = &mut z_local[entry.pos_in_block * num_digits]; + accum_onehot_entry::(z_coeffs, entry, challenge); + } + z_local + }, + |mut a: Vec<[i32; D]>, b: Vec<[i32; D]>| { + for (ai, bi) in a.iter_mut().zip(b.iter()) { + for (a_coeff, b_coeff) in ai.iter_mut().zip(bi.iter()) { + *a_coeff += *b_coeff; + } + } + a + } + ) + }; + + let _span = tracing::info_span!("onehot_sparse_convert").entered(); + build_decompose_fold_witness::(coeff_accum, modulus) + } +} + +impl HachiPolyOps for OneHotPoly +where + F: FieldCore + CanonicalField + HasWide, +{ + type CommitCache = NttSlotCache; + + fn num_ring_elems(&self) -> usize { + self.total_ring_elems() + } + + fn evaluate_ring(&self, scalars: &[F]) -> CyclotomicRing { + let block_len = 1usize << self.m_vars; + cfg_fold_reduce!( + 0..self.sparse_blocks.len(), + || CyclotomicRing::::zero(), + |mut acc: CyclotomicRing, block_idx: usize| { + let block_offset = block_idx * block_len; + for entry in &self.sparse_blocks[block_idx] { + let ring_idx = block_offset + entry.pos_in_block; + if ring_idx < scalars.len() { + let s = scalars[ring_idx]; + for &ci in &entry.nonzero_coeffs { + acc.coeffs[ci] += s; + } + } + } + acc + }, + |a, b| a + b + ) + } + + fn 
fold_blocks(&self, scalars: &[F], block_len: usize) -> Vec> { + cfg_iter!(self.sparse_blocks) + .map(|entries| { + let mut coeffs_acc = [F::zero(); D]; + for entry in entries { + if entry.pos_in_block < scalars.len() && entry.pos_in_block < block_len { + let s = scalars[entry.pos_in_block]; + for &ci in &entry.nonzero_coeffs { + coeffs_acc[ci] += s; + } + } + } + CyclotomicRing::from_coefficients(coeffs_acc) + }) + .collect() + } + + #[tracing::instrument(skip_all, name = "OneHotPoly::decompose_fold")] + fn decompose_fold( + &self, + challenges: &[SparseChallenge], + block_len: usize, + num_digits: usize, + _log_basis: u32, + ) -> DecomposeFoldWitness { + // In the common regular one-hot case used by the large onehot profile, + // each chunk is exactly one ring element with one hot coefficient. + // Build each output ring independently instead of reducing full z + // vectors across blocks. + if num_digits == 1 && self.onehot_k == D { + self.decompose_fold_regular_onehot(challenges, block_len) + } else { + self.decompose_fold_sparse_onehot(challenges, block_len, num_digits) + } + } + + #[tracing::instrument(skip_all, name = "OneHotPoly::commit_inner")] + fn commit_inner( + &self, + a_matrix: &FlatMatrix, + _ntt_a: &NttSlotCache, + block_len: usize, + num_digits_commit: usize, + num_digits_open: usize, + log_basis: u32, + ) -> Result>, HachiError> { + let a_view = a_matrix.view::(); + let n_a = a_view.num_rows(); + let zero_block_len = n_a.checked_mul(num_digits_open).unwrap(); + + let t_hat_all: Vec> = cfg_iter!(self.sparse_blocks) + .map(|block_entries| { + if block_entries.is_empty() { + vec![[0i8; D]; zero_block_len] + } else { + let t_i = inner_ajtai_onehot_wide( + &a_view, + block_entries, + block_len, + num_digits_commit, + ); + decompose_rows_i8(&t_i, num_digits_open, log_basis) + } + }) + .collect(); + + Ok(t_hat_all) + } + + fn commit_inner_witness( + &self, + a_matrix: &FlatMatrix, + _ntt_a: &NttSlotCache, + block_len: usize, + num_digits_commit: usize, + 
num_digits_open: usize, + log_basis: u32, + ) -> Result, HachiError> { + let a_view = a_matrix.view::(); + let n_a = a_view.num_rows(); + let zero_block_len = n_a.checked_mul(num_digits_open).unwrap(); + + let per_block = cfg_iter!(self.sparse_blocks) + .map(|block_entries| { + if block_entries.is_empty() { + ( + vec![CyclotomicRing::::zero(); n_a], + vec![[0i8; D]; zero_block_len], + ) + } else { + let t_i = inner_ajtai_onehot_wide( + &a_view, + block_entries, + block_len, + num_digits_commit, + ); + let t_hat_i = decompose_rows_i8(&t_i, num_digits_open, log_basis); + (t_i, t_hat_i) + } + }) + .collect::>(); + let (t, t_hat): (Vec<_>, Vec<_>) = per_block.into_iter().unzip(); + Ok(CommitInnerWitness { t, t_hat }) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::protocol::commitment::{ + CommitmentConfig, HachiCommitmentCore, HachiScheduleInputs, RingCommitmentScheme, + }; + use crate::protocol::ring_switch::w_commitment_layout; + use crate::test_utils::{TinyConfig, D as TestD, F as TestF}; + use crate::FromSmallInt; + + #[test] + fn dense_poly_from_field_evals_roundtrip() { + let num_vars = 10; + let len = 1usize << num_vars; + let evals: Vec = (0..len).map(|i| TestF::from_u64(i as u64)).collect(); + let poly = DensePoly::::from_field_evals(num_vars, &evals).unwrap(); + assert_eq!(poly.num_ring_elems(), len / TestD); + } + + #[test] + fn dense_commit_inner_matches_ring_commit() { + let (setup, _) = + >::setup(16) + .unwrap(); + let layout = setup.layout(); + let num_ring = layout.num_blocks * layout.block_len; + let evals: Vec = (0..num_ring * TestD) + .map(|i| TestF::from_u64(i as u64)) + .collect(); + + let alpha = TestD.trailing_zeros() as usize; + let num_vars = alpha + layout.m_vars + layout.r_vars; + let poly = DensePoly::::from_field_evals(num_vars, &evals).unwrap(); + + let t_hat_poly = poly + .commit_inner( + &setup.expanded.A, + &setup.ntt_A, + layout.block_len, + layout.num_digits_commit, + layout.num_digits_open, + layout.log_basis, + ) 
+ .unwrap(); + + let w = + >::commit_coeffs( + &poly.coeffs, + &setup, + ) + .unwrap(); + + assert_eq!(t_hat_poly, w.t_hat); + } + + #[test] + fn onehot_commit_inner_matches_ring_commit_onehot() { + let (setup, _) = + >::setup(16) + .unwrap(); + let layout = setup.layout(); + let total_ring = layout.num_blocks * layout.block_len; + let onehot_k = TestD; + let num_chunks = total_ring; + let indices: Vec> = (0..num_chunks).map(|i| Some(i % onehot_k)).collect(); + + let poly = OneHotPoly::::new( + onehot_k, + indices.clone(), + layout.r_vars, + layout.m_vars, + ) + .unwrap(); + + let t_hat_poly = poly + .commit_inner( + &setup.expanded.A, + &setup.ntt_A, + layout.block_len, + layout.num_digits_commit, + layout.num_digits_open, + layout.log_basis, + ) + .unwrap(); + + let w = + >::commit_onehot( + onehot_k, &indices, &setup, + ) + .unwrap(); + + assert_eq!(t_hat_poly, w.t_hat); + } + + #[test] + fn onehot_decompose_fold_matches_dense_regular_onehot() { + let (setup, _) = + >::setup(16) + .unwrap(); + let layout = setup.layout(); + let total_ring = layout.num_blocks * layout.block_len; + let onehot_k = TestD; + let indices: Vec> = (0..total_ring) + .map(|i| (i % 11 != 0).then_some((i * 7 + 3) % onehot_k)) + .collect(); + + let poly = OneHotPoly::::new( + onehot_k, + indices.clone(), + layout.r_vars, + layout.m_vars, + ) + .unwrap(); + + let mut evals = vec![TestF::zero(); total_ring * onehot_k]; + for (chunk_idx, hot_idx) in indices.into_iter().enumerate() { + if let Some(hot_idx) = hot_idx { + evals[chunk_idx * onehot_k + hot_idx] = TestF::from_u64(1); + } + } + + let alpha = TestD.trailing_zeros() as usize; + let num_vars = alpha + layout.m_vars + layout.r_vars; + let dense = DensePoly::::from_field_evals(num_vars, &evals).unwrap(); + let challenges: Vec = (0..layout.num_blocks) + .map(|i| SparseChallenge { + positions: vec![ + 0u32, + ((i * 5 + 1) % TestD) as u32, + ((i * 9 + 2) % TestD) as u32, + ], + coeffs: vec![1, -1, 1], + }) + .collect(); + + let got = 
poly.decompose_fold(&challenges, layout.block_len, 1, layout.log_basis); + let expected = dense.decompose_fold(&challenges, layout.block_len, 1, layout.log_basis); + assert_eq!(got.z_pre, expected.z_pre); + assert_eq!(got.centered_coeffs, expected.centered_coeffs); + assert_eq!(got.centered_inf_norm, expected.centered_inf_norm); + } + + #[test] + fn balanced_digit_poly_matches_dense_recursive_w_ops() { + let log_basis = TinyConfig::decomposition().log_basis; + let digits: Vec = (0..(3 * TestD)).map(|i| (i % 7) as i8 - 3).collect(); + let field_evals: Vec = digits.iter().map(|&d| TestF::from_i64(d as i64)).collect(); + let total_coeffs = digits.len().next_power_of_two().max(TestD); + let mut padded = field_evals.clone(); + padded.resize(total_coeffs, TestF::zero()); + + let dense = DensePoly::::from_field_evals( + total_coeffs.trailing_zeros() as usize, + &padded, + ) + .unwrap(); + let digit_poly = BalancedDigitPoly::::from_i8_digits(&digits).unwrap(); + + assert_eq!(digit_poly.num_ring_elems(), dense.num_ring_elems()); + + let eval_scalars: Vec = (0..digit_poly.num_ring_elems()) + .map(|i| TestF::from_u64((i + 2) as u64)) + .collect(); + assert_eq!( + digit_poly.evaluate_ring(&eval_scalars), + dense.evaluate_ring(&eval_scalars) + ); + + let block_len = 2; + let fold_scalars: Vec = (0..block_len) + .map(|i| TestF::from_u64((i + 5) as u64)) + .collect(); + assert_eq!( + digit_poly.fold_blocks(&fold_scalars, block_len), + dense.fold_blocks(&fold_scalars, block_len) + ); + + let num_blocks = digit_poly.num_ring_elems().div_ceil(block_len); + let challenges: Vec = (0..num_blocks) + .map(|i| SparseChallenge { + positions: vec![0u32, ((i + 3) % TestD) as u32], + coeffs: vec![1, -1], + }) + .collect(); + let got = digit_poly.decompose_fold(&challenges, block_len, 1, log_basis); + let expected = dense.decompose_fold(&challenges, block_len, 1, log_basis); + assert_eq!(got.z_pre, expected.z_pre); + assert_eq!(got.centered_coeffs, expected.centered_coeffs); + 
assert_eq!(got.centered_inf_norm, expected.centered_inf_norm); + + let (setup, _) = + >::setup(16) + .unwrap(); + let layout = setup.layout(); + let level_params = TinyConfig::level_params(HachiScheduleInputs { + max_num_vars: setup.expanded.seed.max_num_vars, + level: 0, + current_w_len: layout.num_blocks * layout.block_len * TestD, + }); + let w_layout = + w_commitment_layout::(level_params, layout).unwrap(); + let digit_commit = digit_poly + .commit_inner( + &setup.expanded.A, + &setup.ntt_A, + w_layout.block_len, + w_layout.num_digits_commit, + w_layout.num_digits_open, + w_layout.log_basis, + ) + .unwrap(); + let dense_commit = dense + .commit_inner( + &setup.expanded.A, + &setup.ntt_A, + w_layout.block_len, + w_layout.num_digits_commit, + w_layout.num_digits_open, + w_layout.log_basis, + ) + .unwrap(); + + assert_eq!(digit_commit, dense_commit); + } +} diff --git a/src/protocol/labrador/aggregation.rs b/src/protocol/labrador/aggregation.rs new file mode 100644 index 00000000..bea0db22 --- /dev/null +++ b/src/protocol/labrador/aggregation.rs @@ -0,0 +1,960 @@ +//! Constraint aggregation for the Labrador protocol. +//! +//! Implements the "Aggregating" step from Section 5.2 of the LaBRADOR paper: +//! the 256 JL projection constraints and the existing statement constraints +//! are folded into a single aggregated constraint (φ_i, b) via random challenges. +//! +//! # Paper reference +//! +//! The JL constraints are first collapsed into ⌈128/log q⌉ functions using +//! 256 independent scalar collapse challenges per lift. Each collapsed +//! function produces a polynomial b''(k) whose constant term the verifier +//! can check. The prover sends b''(k) with the constant term zeroed out. +//! Note: If the field size is 128-bit or larger, instead of aggregating constraints +//! with random ring elements, we aggregate them with random field elements. +//! This diverges from the paper's protocol. For smaller fields, we precisely +//! follow the paper protocol. 
+ +use crate::algebra::ring::CyclotomicRing; +use crate::algebra::SparseChallenge; +use crate::error::HachiError; +#[cfg(feature = "parallel")] +use crate::parallel::*; +use crate::protocol::commitment::utils::linear::{ + mat_vec_mul_crt_ntt_i8_single, try_centered_i8_cache_from_ring_coeffs, +}; +use crate::protocol::labrador::config::jl_lifts; +use crate::protocol::labrador::constraints::{pair_index, LabradorConstraint, NextWitnessLayout}; +use crate::protocol::labrador::johnson_lindenstrauss::{ + restore_constant_term, zero_constant_term_for_proof, LabradorJlMatrix, +}; +use crate::protocol::labrador::types::{ + LabradorReducedConstraintPlan, LabradorStatement, LabradorWitness, +}; +use crate::protocol::labrador::utils::pow2_field; +use crate::protocol::transcript::labels; +use crate::protocol::transcript::{challenge_ring_element, Transcript}; +use crate::{CanonicalField, FieldCore, FromSmallInt}; + +#[derive(Clone, Copy)] +enum AggregationRandomness { + /// Field is 128-bit or larger: aggregate with random field elements. + Scalar(F), + /// Field is smaller than 128-bit: aggregate with random ring elements. + Ring(CyclotomicRing), +} + +#[inline] +/// Whether constraint aggregation may safely replace ring randomness with scalar randomness. +/// +/// Security note: for prime moduli with bit-length greater than 128, we can +/// replace a ring-element challenge with a scalar field challenge and still keep +/// the claimed security level for the aggregation step. 
+pub(crate) fn safe_to_use_scalar_randomness() -> bool { + let modulus = (-F::one()).to_canonical_u128() + 1; + let bits = u128::BITS - modulus.leading_zeros(); + bits == 128 +} + +#[inline] +fn sample_aggregation_randomness( + transcript: &mut T, +) -> AggregationRandomness +where + F: FieldCore + CanonicalField, + T: Transcript, +{ + if safe_to_use_scalar_randomness::() { + AggregationRandomness::Scalar( + transcript.challenge_scalar(labels::CHALLENGE_LABRADOR_AGGREGATION), + ) + } else { + AggregationRandomness::Ring(challenge_ring_element( + transcript, + labels::CHALLENGE_LABRADOR_AGGREGATION, + )) + } +} + +type AggregatedConstraintSystem = + (Vec>>, CyclotomicRing); + +#[cfg(feature = "parallel")] +const STATEMENT_ROW_CHUNK_LEN: usize = 256; +const SPARSE_RING_MUL_MAX_WEIGHT: usize = 48; + +/// Inner product of two ring-element slices. +pub(crate) fn dot_product( + lhs: &[CyclotomicRing], + rhs: &[CyclotomicRing], +) -> CyclotomicRing { + let len = lhs.len().min(rhs.len()); + cfg_fold_reduce!( + (0..len), + || CyclotomicRing::::zero(), + |acc, i| acc + lhs[i] * rhs[i], + |a, b| a + b + ) +} + +/// Element-wise accumulate a flat `other` view into row-structured `acc`. 
+#[tracing::instrument(skip_all, name = "labrador::add_phi_flat_in_place")] +pub(crate) fn add_phi_flat_in_place( + acc: &mut [Vec>], + other_flat: &[CyclotomicRing], +) -> Result<(), HachiError> { + let mut ranges = Vec::with_capacity(acc.len()); + let mut cursor = 0usize; + for row in acc.iter() { + let start = cursor; + cursor += row.len(); + ranges.push((start, cursor)); + } + if cursor != other_flat.len() { + return Err(HachiError::InvalidInput( + "flat phi length mismatch".to_string(), + )); + } + + cfg_iter_mut!(acc) + .zip(cfg_iter!(ranges)) + .for_each(|(row_acc, &(start, end))| { + for (dst, src) in row_acc.iter_mut().zip(other_flat[start..end].iter()) { + *dst += *src; + } + }); + Ok(()) +} + +#[inline] +fn scalar_to_ring(scalar: F) -> CyclotomicRing { + let mut coeffs = [F::zero(); D]; + coeffs[0] = scalar; + CyclotomicRing::from_coefficients(coeffs) +} + +#[inline] +fn mul_accumulate_with_alpha( + alpha: &AggregationRandomness, + coeff: &CyclotomicRing, + dst: &mut CyclotomicRing, +) { + match alpha { + AggregationRandomness::Scalar(alpha_scalar) => { + if alpha_scalar.is_zero() { + return; + } + for (dst_coeff, src_coeff) in + dst.coefficients_mut().iter_mut().zip(coeff.coefficients()) + { + *dst_coeff += *src_coeff * *alpha_scalar; + } + } + AggregationRandomness::Ring(alpha_ring) => { + mul_accumulate_term_coeff(alpha_ring, coeff, dst) + } + } +} + +#[inline] +fn accumulate_rhs_with_alpha( + alpha: &AggregationRandomness, + target: &CyclotomicRing, + acc: &mut CyclotomicRing, +) { + match alpha { + AggregationRandomness::Scalar(alpha_scalar) => *acc += target.scale(alpha_scalar), + AggregationRandomness::Ring(alpha_ring) => alpha_ring.mul_accumulate_into(target, acc), + } +} + +#[inline] +fn alpha_mul_sparse( + alpha: &AggregationRandomness, + challenge: &SparseChallenge, +) -> CyclotomicRing { + match alpha { + AggregationRandomness::Scalar(alpha_scalar) => { + let mut coeffs = [F::zero(); D]; + for (&pos, &coeff) in 
challenge.positions.iter().zip(challenge.coeffs.iter()) { + coeffs[pos as usize] += *alpha_scalar * F::from_i64(coeff as i64); + } + CyclotomicRing::from_coefficients(coeffs) + } + AggregationRandomness::Ring(alpha_ring) => alpha_ring.mul_by_sparse(challenge), + } +} + +#[inline] +fn alpha_scale_to_ring( + alpha: &AggregationRandomness, + scale: F, +) -> CyclotomicRing { + match alpha { + AggregationRandomness::Scalar(alpha_scalar) => { + scalar_to_ring::(*alpha_scalar * scale) + } + AggregationRandomness::Ring(alpha_ring) => alpha_ring.scale(&scale), + } +} + +/// Sample 256 scalar collapse challenges (legacy schedule). +fn sample_jl_collapse_challenge(transcript: &mut T) -> [F; 256] +where + F: FieldCore + CanonicalField, + T: Transcript, +{ + std::array::from_fn(|_| transcript.challenge_scalar(labels::CHALLENGE_LABRADOR_JL_COLLAPSE)) +} + +/// Collapse JL projection coordinates with signed challenge weights directly in +/// the field, avoiding host-integer saturation. +fn collapse_to_field(projection: &[i64; 256], alpha: &[F; 256]) -> F +where + F: FieldCore + FromSmallInt, +{ + projection + .iter() + .zip(alpha.iter()) + .fold(F::zero(), |acc, (&p, &a)| acc + a * F::from_i64(p)) +} + +fn validate_matrix_cols(matrix: &LabradorJlMatrix, cols: usize) -> Result<(), HachiError> { + if !matrix.is_well_formed() || matrix.cols() != cols { + return Err(HachiError::InvalidInput( + "JL matrix row length mismatch".to_string(), + )); + } + Ok(()) +} + +#[inline] +fn build_four_russians_lookup_field( + alpha0: F, + alpha1: F, + alpha2: F, + alpha3: F, +) -> [F; 256] { + let mut lookup = [F::zero(); 256]; + for packed in 0u16..256 { + let packed = packed as u8; + let pair0 = packed & 0b11; + let pair1 = (packed >> 2) & 0b11; + let pair2 = (packed >> 4) & 0b11; + let pair3 = (packed >> 6) & 0b11; + let mut acc = F::zero(); + match pair0 { + 0b00 => acc -= alpha0, + 0b11 => acc += alpha0, + _ => {} + } + match pair1 { + 0b00 => acc -= alpha1, + 0b11 => acc += alpha1, + _ => {} 
+ } + match pair2 { + 0b00 => acc -= alpha2, + 0b11 => acc += alpha2, + _ => {} + } + match pair3 { + 0b00 => acc -= alpha3, + 0b11 => acc += alpha3, + _ => {} + } + lookup[packed as usize] = acc; + } + lookup +} + +#[inline] +fn accumulate_field_weight_contribution( + coeffs: &mut [F], + local_idx: usize, + contribution: F, +) { + if contribution.is_zero() { + return; + } + if local_idx == 0 { + coeffs[0] += contribution; + } else { + coeffs[D - local_idx] -= contribution; + } +} + +type JlGroupRows<'a> = [&'a [u8]; 4]; + +#[inline] +fn apply_four_russians_group4_to_elem( + elem: &mut CyclotomicRing, + elem_idx: usize, + rows: JlGroupRows<'_>, + lookup: &[F; 256], + bytes_per_ring: usize, +) { + let start = elem_idx * bytes_per_ring; + let end = start + bytes_per_ring; + let sign_bytes0 = &rows[0][start..end]; + let sign_bytes1 = &rows[1][start..end]; + let sign_bytes2 = &rows[2][start..end]; + let sign_bytes3 = &rows[3][start..end]; + + let coeffs = elem.coefficients_mut(); + let mut local_idx = 0usize; + + for (((&byte0, &byte1), &byte2), &byte3) in sign_bytes0 + .iter() + .zip(sign_bytes1.iter()) + .zip(sign_bytes2.iter()) + .zip(sign_bytes3.iter()) + { + let packed0 = + (byte0 & 0b11) | ((byte1 & 0b11) << 2) | ((byte2 & 0b11) << 4) | ((byte3 & 0b11) << 6); + let packed1 = ((byte0 >> 2) & 0b11) + | (((byte1 >> 2) & 0b11) << 2) + | (((byte2 >> 2) & 0b11) << 4) + | (((byte3 >> 2) & 0b11) << 6); + let packed2 = ((byte0 >> 4) & 0b11) + | (((byte1 >> 4) & 0b11) << 2) + | (((byte2 >> 4) & 0b11) << 4) + | (((byte3 >> 4) & 0b11) << 6); + let packed3 = ((byte0 >> 6) & 0b11) + | (((byte1 >> 6) & 0b11) << 2) + | (((byte2 >> 6) & 0b11) << 4) + | (((byte3 >> 6) & 0b11) << 6); + + accumulate_field_weight_contribution::(coeffs, local_idx, lookup[packed0 as usize]); + accumulate_field_weight_contribution::( + coeffs, + local_idx + 1, + lookup[packed1 as usize], + ); + accumulate_field_weight_contribution::( + coeffs, + local_idx + 2, + lookup[packed2 as usize], + ); + 
accumulate_field_weight_contribution::( + coeffs, + local_idx + 3, + lookup[packed3 as usize], + ); + local_idx += 4; + } +} + +/// Collapse 256 JL rows × omega into JL phi coefficients using +/// field arithmetic only. +/// +/// # Errors +/// +/// Returns [`HachiError::InvalidInput`] if the matrix dimensions are invalid. +/// +/// # Panics +/// +/// Panics if `D` is not divisible by 4, which is required by the Four Russians +/// JL collapse implementation. +#[tracing::instrument(skip_all, name = "labrador::aggregate_jl_contraints_one_lift")] +pub fn aggregate_jl_contraints_one_lift( + matrix: &LabradorJlMatrix, + omega: &[F; 256], +) -> Result>, HachiError> { + let cols = matrix.cols(); + validate_matrix_cols(matrix, cols)?; + if D % 4 != 0 { + panic!("Four Russians field collapse requires D divisible by 4, got D={D}"); + } + let num_elems = cols / D; + let bytes_per_ring = D / 4; + debug_assert_eq!(omega.len() % 4, 0); + let group_rows: Vec> = (0..omega.len()) + .step_by(4) + .map(|group_start| { + [ + matrix.packed_rows[group_start].as_slice(), + matrix.packed_rows[group_start + 1].as_slice(), + matrix.packed_rows[group_start + 2].as_slice(), + matrix.packed_rows[group_start + 3].as_slice(), + ] + }) + .collect(); + let lookups: Vec<[F; 256]> = (0..omega.len()) + .step_by(4) + .map(|group_start| { + build_four_russians_lookup_field( + omega[group_start], + omega[group_start + 1], + omega[group_start + 2], + omega[group_start + 3], + ) + }) + .collect(); + + Ok(cfg_into_iter!(0..num_elems) + .map(|elem_idx| { + let mut elem = CyclotomicRing::::zero(); + for (rows, lookup) in group_rows.iter().zip(lookups.iter()) { + apply_four_russians_group4_to_elem::( + &mut elem, + elem_idx, + *rows, + lookup, + bytes_per_ring, + ); + } + elem + }) + .collect()) +} + +/// Pre-flattened witness layout used during JL aggregation. 
// NOTE(review): generic parameter lists in this block were lost during
// extraction; the `<...>` below are reconstructed from call sites — verify
// against the original file before relying on exact signatures.
struct FlatWitness<F: FieldCore, const D: usize> {
    // All witness rows concatenated in row-major order.
    rings: Vec<CyclotomicRing<F, D>>,
}

impl<F: FieldCore, const D: usize> FlatWitness<F, D> {
    /// Flatten the witness rows into a single contiguous vector
    /// (row-major, copied — ring elements are `Copy`).
    #[tracing::instrument(skip_all, name = "labrador::flat_witness_new")]
    fn new(witness: &LabradorWitness<F, D>) -> Self {
        let mut rings = Vec::new();
        for row in witness.rows() {
            rings.extend(row.iter().copied());
        }
        Self { rings }
    }
}

/// Accumulate `beta · phi_flat` into `phi_total_flat`, element-wise.
///
/// Scalar randomness uses a cheap coefficient-wise multiply-add (skipped
/// entirely when the scalar is zero); ring randomness goes through the
/// sparse-aware ring multiply-accumulate.
fn accumulate_phi_flat<F: FieldCore, const D: usize>(
    phi_total_flat: &mut [CyclotomicRing<F, D>],
    phi_flat: &[CyclotomicRing<F, D>],
    beta: &AggregationRandomness<F, D>,
) {
    debug_assert_eq!(phi_total_flat.len(), phi_flat.len());
    match beta {
        AggregationRandomness::Scalar(s) => {
            if s.is_zero() {
                return;
            }
            cfg_iter_mut!(phi_total_flat)
                .zip(cfg_iter!(phi_flat))
                .for_each(|(dst, src)| {
                    for (dst_coeff, src_coeff) in
                        dst.coefficients_mut().iter_mut().zip(src.coefficients())
                    {
                        *dst_coeff += *src_coeff * *s;
                    }
                });
        }
        AggregationRandomness::Ring(r) => {
            cfg_iter_mut!(phi_total_flat)
                .zip(cfg_iter!(phi_flat))
                .for_each(|(dst, src)| mul_accumulate_term_coeff(r, src, dst));
        }
    }
}

/// Aggregate JL projection constraints on the prover side.
///
/// For each of the ⌈128/log q⌉ lifts:
/// 1. Sample ω_j^(k) ∈ Z_q for `j = 1..256`.
/// 2. Collapse the JL matrix rows → φ^''(k) ring-element vector.
/// 3. Compute b^''(k) = ⟨φ^''(k), s⟩ and verify its constant term.
/// 4. Transmit b^''(k) (constant term zeroed) and absorb into transcript.
/// 5. If `q` is 128-bit, use scalar β_k; otherwise use ring β_k.
///
/// Returns `(phi_total_flat, jl_lift_residuals)` where
/// `phi_total_flat` is flattened in row-major order and
/// `jl_lift_residuals` holds the transmitted polynomials.
#[allow(clippy::type_complexity)]
#[tracing::instrument(skip_all, name = "labrador::aggregate_jl_constraints_prover")]
pub(crate) fn aggregate_jl_constraints_prover<F, T, const D: usize>(
    witness: &LabradorWitness<F, D>,
    matrix: &LabradorJlMatrix<D>,
    transcript: &mut T,
) -> Result<(Vec<CyclotomicRing<F, D>>, Vec<CyclotomicRing<F, D>>), HachiError>
where
    F: FieldCore + CanonicalField + FromSmallInt,
    T: Transcript,
{
    let flat_witness = FlatWitness::new(witness);
    // Fast path: if every witness coefficient fits a centered i8, a CRT/NTT
    // i8 kernel computes the inner product; otherwise fall back to the
    // generic dot product.
    let flat_witness_i8 = try_centered_i8_cache_from_ring_coeffs(&flat_witness.rings);

    let mut phi_total_flat = vec![CyclotomicRing::<F, D>::zero(); flat_witness.rings.len()];
    let lifts = jl_lifts::<F>();
    let mut jl_lift_residuals = Vec::with_capacity(lifts);

    for _ in 0..lifts {
        // Fiat-Shamir ordering: sample omega, collapse, absorb the residual,
        // then sample beta — the verifier mirrors this exact sequence.
        let omega = sample_jl_collapse_challenge::<F, T>(transcript);
        let phi_flat = aggregate_jl_contraints_one_lift::<F, D>(matrix, &omega)?;
        let b_full = if let Some(witness_i8) = flat_witness_i8.as_ref() {
            mat_vec_mul_crt_ntt_i8_single(&phi_flat, witness_i8)
                .ok()
                .unwrap_or_else(|| dot_product(&phi_flat, &flat_witness.rings))
        } else {
            dot_product(&phi_flat, &flat_witness.rings)
        };

        // Only the non-constant coefficients are transmitted; the verifier
        // restores the constant term from the JL projection.
        let (b_tx, _c0) = zero_constant_term_for_proof(b_full);
        jl_lift_residuals.push(b_tx);
        transcript.append_serde(labels::ABSORB_LABRADOR_JL_LIFT_RESIDUALS, &b_tx);

        let beta = sample_aggregation_randomness::<F, T, D>(transcript);
        accumulate_phi_flat(&mut phi_total_flat, &phi_flat, &beta);
    }

    Ok((phi_total_flat, jl_lift_residuals))
}

/// Aggregate JL projection constraints on the verifier side.
///
/// Same transcript flow as the prover variant, but reconstructs the full
/// polynomial b^''(k) by restoring the constant term from the projection
/// and the transmitted `jl_lift_residuals[k]`. Returns a flattened `phi_total`.
// NOTE(review): generic parameter lists in this block were lost during
// extraction; the `<...>` below are reconstructed from call sites — verify
// against the original file before relying on exact signatures.
#[allow(clippy::type_complexity)]
#[tracing::instrument(skip_all, name = "labrador::aggregate_jl_constraints_verifier")]
pub(crate) fn aggregate_jl_constraints_verifier<F, T, const D: usize>(
    row_lengths: &[usize],
    jl_projection: &[i64; 256],
    matrix: &LabradorJlMatrix<D>,
    jl_lift_residuals: &[CyclotomicRing<F, D>],
    transcript: &mut T,
) -> Result<(Vec<CyclotomicRing<F, D>>, CyclotomicRing<F, D>), HachiError>
where
    F: FieldCore + CanonicalField + FromSmallInt,
    T: Transcript,
{
    let lifts = jl_lifts::<F>();
    // The proof must carry exactly one residual per lift.
    if jl_lift_residuals.len() != lifts {
        return Err(HachiError::InvalidProof);
    }
    let total_phi_elems: usize = row_lengths.iter().sum();
    let mut phi_total_flat = vec![CyclotomicRing::<F, D>::zero(); total_phi_elems];
    let mut aggregated_rhs = CyclotomicRing::<F, D>::zero();

    for jl_lift_residual in jl_lift_residuals.iter() {
        // Mirror the prover's transcript schedule: omega, absorb residual,
        // then beta.
        let omega = sample_jl_collapse_challenge::<F, T>(transcript);
        let phi_flat = aggregate_jl_contraints_one_lift::<F, D>(matrix, &omega)?;
        // Rebuild b^''(k): constant term comes from collapsing the public JL
        // projection with omega, the rest from the transmitted residual.
        let b_full = restore_constant_term(
            *jl_lift_residual,
            collapse_to_field::<F>(jl_projection, &omega),
        );
        transcript.append_serde(labels::ABSORB_LABRADOR_JL_LIFT_RESIDUALS, jl_lift_residual);
        let beta = sample_aggregation_randomness::<F, T, D>(transcript);
        accumulate_rhs_with_alpha(&beta, &b_full, &mut aggregated_rhs);
        accumulate_phi_flat(&mut phi_total_flat, &phi_flat, &beta);
    }

    Ok((phi_total_flat, aggregated_rhs))
}

/// `dst += alpha * coeff`, picking the sparse-RHS kernel when `coeff` has at
/// most `SPARSE_RING_MUL_MAX_WEIGHT` nonzero coefficients.
#[inline]
fn mul_accumulate_term_coeff<F: FieldCore, const D: usize>(
    alpha: &CyclotomicRing<F, D>,
    coeff: &CyclotomicRing<F, D>,
    dst: &mut CyclotomicRing<F, D>,
) {
    if coeff.hamming_weight() <= SPARSE_RING_MUL_MAX_WEIGHT {
        alpha.mul_accumulate_sparse_rhs_into(coeff, dst);
    } else {
        alpha.mul_accumulate_into(coeff, dst);
    }
}

/// `dst += (alpha * scale) * src`, element-wise over two equal-length rows.
///
/// Scalar randomness folds the scale into a single field multiplier (and
/// skips the row entirely if it is zero); ring randomness pre-scales the
/// ring element once, then multiply-accumulates per element.
fn accumulate_scaled_row<F: FieldCore, const D: usize>(
    dst: &mut [CyclotomicRing<F, D>],
    src: &[CyclotomicRing<F, D>],
    alpha: &AggregationRandomness<F, D>,
    scale: F,
) {
    debug_assert_eq!(dst.len(), src.len());
    match alpha {
        AggregationRandomness::Scalar(alpha_scalar) => {
            let scaled_alpha = *alpha_scalar * scale;
            if scaled_alpha.is_zero() {
                return;
            }
            cfg_iter_mut!(dst)
                .zip(cfg_iter!(src))
                .for_each(|(dst_elem, src_elem)| {
                    for (dst_coeff, src_coeff) in dst_elem
                        .coefficients_mut()
                        .iter_mut()
                        .zip(src_elem.coefficients())
                    {
                        *dst_coeff += *src_coeff * scaled_alpha;
                    }
                });
        }
        AggregationRandomness::Ring(alpha_ring) => {
            // Scale the randomness once, outside the per-element loop.
            let scaled_alpha = alpha_ring.scale(&scale);
            cfg_iter_mut!(dst)
                .zip(cfg_iter!(src))
                .for_each(|(dst_elem, src_elem)| {
                    mul_accumulate_term_coeff(&scaled_alpha, src_elem, dst_elem)
                });
        }
    }
}

/// Apply all `(constraint, term)` pairs targeting one row to that row.
///
/// With the `parallel` feature, the row is split into fixed-size chunks and
/// every chunk scans the full work list, clipping each term to its chunk
/// window — chunks never alias, so no synchronisation is needed. The serial
/// fallback applies each term over its full span directly.
fn accumulate_statement_row_work<F: FieldCore, const D: usize>(
    row: &mut [CyclotomicRing<F, D>],
    work: &[(usize, usize)],
    constraints: &[LabradorConstraint<F, D>],
    alphas: &[AggregationRandomness<F, D>],
) {
    #[cfg(feature = "parallel")]
    row.par_chunks_mut(STATEMENT_ROW_CHUNK_LEN)
        .enumerate()
        .for_each(|(chunk_idx, chunk)| {
            let chunk_start = chunk_idx * STATEMENT_ROW_CHUNK_LEN;
            let chunk_end = chunk_start + chunk.len();
            for &(ci, ti) in work {
                let term = &constraints[ci].terms[ti];
                let term_end = term.offset + term.coefficients.len();
                // Intersect the term's span with this chunk's window.
                let start = chunk_start.max(term.offset);
                let end = chunk_end.min(term_end);
                if start >= end {
                    continue;
                }
                let alpha = &alphas[ci];
                let src = &term.coefficients[start - term.offset..end - term.offset];
                let dst = &mut chunk[start - chunk_start..end - chunk_start];
                for (dst_elem, src_elem) in dst.iter_mut().zip(src.iter()) {
                    mul_accumulate_with_alpha(alpha, src_elem, dst_elem);
                }
            }
        });

    #[cfg(not(feature = "parallel"))]
    for &(ci, ti) in work {
        let term = &constraints[ci].terms[ti];
        let alpha = &alphas[ci];
        for (dst_elem, src_elem) in row[term.offset..term.offset + term.coefficients.len()]
            .iter_mut()
            .zip(term.coefficients.iter())
        {
            mul_accumulate_with_alpha(alpha, src_elem, dst_elem);
        }
    }
}

/// Fold a reduced (pre-planned) statement into aggregated (φ, b) using
/// transcript challenges. See the body for the per-payload schedule.
#[tracing::instrument(skip_all, name = "labrador::aggregate_reduced_statement_constraints")]
fn aggregate_reduced_statement_constraints<F, T, const D: usize>(
    statement: &LabradorStatement<F, D>,
    plan: &LabradorReducedConstraintPlan<F, D>,
    row_lengths: &[usize],
    transcript: &mut T,
    // NOTE(review): the return type was lost during extraction; reconstructed
    // as the (phi rows, aggregated rhs) pair that `aggregate_statement`
    // forwards — verify against the original file.
) -> Result<(Vec<Vec<CyclotomicRing<F, D>>>, CyclotomicRing<F, D>), HachiError>
where
    F: FieldCore + CanonicalField + FromSmallInt,
    T: Transcript,
{
    // Validate the caller-supplied row layout against the plan before any
    // transcript interaction.
    let layout = NextWitnessLayout::new(plan.row_count, &plan.config);
    if row_lengths.len() != layout.num_rows() {
        return Err(HachiError::InvalidInput(
            "reduced statement row count mismatch".to_string(),
        ));
    }
    if row_lengths
        .iter()
        .take(plan.config.witness_digit_parts)
        .any(|&len| len != plan.max_len)
        || row_lengths[layout.aux_row] != layout.aux_row_len()
    {
        return Err(HachiError::InvalidInput(
            "reduced statement row layout mismatch".to_string(),
        ));
    }
    if statement.inner_opening_payload.len() != plan.config.outer_commit_rank
        || statement.linear_garbage_payload.len() != plan.config.outer_commit_rank
    {
        return Err(HachiError::InvalidInput(
            "reduced statement payload length mismatch".to_string(),
        ));
    }

    // Powers of the digit bases used to recompose witness / aux digits.
    let pow_b: Vec<F> = (0..plan.config.witness_digit_parts)
        .map(|idx| pow2_field::<F>(plan.config.witness_digit_bits * idx))
        .collect();
    let pow_bu: Vec<F> = (0..plan.config.aux_digit_parts)
        .map(|idx| pow2_field::<F>(plan.config.aux_digit_bits * idx))
        .collect();

    let mut phi_total: Vec<Vec<CyclotomicRing<F, D>>> = row_lengths
        .iter()
        .map(|&len| vec![CyclotomicRing::zero(); len])
        .collect();
    // z rows hold the witness digit parts; the single aux row follows them.
    let (z_rows, aux_rows) = phi_total.split_at_mut(plan.config.witness_digit_parts);
    let aux_row = aux_rows.first_mut().ok_or_else(|| {
        HachiError::InvalidInput("missing auxiliary row in reduced statement".to_string())
    })?;
    let inner_opening_start = layout.inner_opening_digits_range().start;
    let linear_garbage_start = layout.linear_garbage_digits_range().start;
    let mut aggregated_rhs = CyclotomicRing::<F, D>::zero();

    // Inner-opening payload: one challenge per B-matrix row, folded into the
    // inner-opening segment of the aux row.
    for (b_row, target) in plan
        .setup
        .b_mat
        .iter()
        .zip(statement.inner_opening_payload.iter())
    {
        let alpha = sample_aggregation_randomness::<F, T, D>(transcript);
        accumulate_rhs_with_alpha(&alpha, target, &mut aggregated_rhs);
        let dst = &mut aux_row[inner_opening_start..linear_garbage_start];
        for (dst, src) in dst.iter_mut().zip(b_row.iter()) {
            mul_accumulate_with_alpha(&alpha, src, dst);
        }
    }

    // Linear-garbage payload: same schedule against the D matrix, folded into
    // the linear-garbage segment of the aux row.
    for (d_row, target) in plan
        .setup
        .d_mat
        .iter()
        .zip(statement.linear_garbage_payload.iter())
    {
        let alpha = sample_aggregation_randomness::<F, T, D>(transcript);
        accumulate_rhs_with_alpha(&alpha, target, &mut aggregated_rhs);
        let dst = &mut aux_row[linear_garbage_start..];
        for (dst, src) in dst.iter_mut().zip(d_row.iter()) {
            mul_accumulate_with_alpha(&alpha, src, dst);
        }
    }

    // Inner-commitment consistency: A-matrix rows scaled by witness digit
    // bases on the z side, cancelled by challenge-weighted digits on the aux
    // side (hence the subtraction).
    for output_idx in 0..plan.config.inner_commit_rank {
        let alpha = sample_aggregation_randomness::<F, T, D>(transcript);
        let a_row = &plan.setup.a_mat[output_idx];
        for (part_idx, &scale) in pow_b.iter().enumerate() {
            accumulate_scaled_row(&mut z_rows[part_idx], a_row, &alpha, scale);
        }

        for (row_idx, challenge) in plan.challenges.iter().enumerate() {
            let base = alpha_mul_sparse(&alpha, challenge);
            for (part_idx, &scale) in pow_bu.iter().enumerate() {
                let idx = inner_opening_start
                    + row_idx * plan.config.inner_commit_rank * plan.config.aux_digit_parts
                    + output_idx * plan.config.aux_digit_parts
                    + part_idx;
                aux_row[idx] -= base.scale(&scale);
            }
        }
    }

    // Amortized linear-garbage constraint: one challenge covers the amortized
    // phi row plus all (i, j) challenge pairs, i <= j.
    let alpha_lg = sample_aggregation_randomness::<F, T, D>(transcript);
    for (part_idx, &scale) in pow_b.iter().enumerate() {
        accumulate_scaled_row(&mut z_rows[part_idx], &plan.amortized_phi, &alpha_lg, scale);
    }
    for i in 0..plan.challenges.len() {
        for j in i..plan.challenges.len() {
            let mut base = alpha_mul_sparse(&alpha_lg, &plan.challenges[i]);
            base = base.mul_by_sparse(&plan.challenges[j]);
            let pair = pair_index(i, j, plan.challenges.len());
            for (part_idx, &scale) in pow_bu.iter().enumerate() {
                let idx = linear_garbage_start + pair * plan.config.aux_digit_parts + part_idx;
                aux_row[idx] -= base.scale(&scale);
            }
        }
    }

    // Diagonal constraint: binds the plan's pre-aggregated rhs via the
    // diagonal (i, i) garbage slots.
    let alpha_diag = sample_aggregation_randomness::<F, T, D>(transcript);
    accumulate_rhs_with_alpha(&alpha_diag, &plan.aggregated_rhs, &mut aggregated_rhs);
    for i in 0..plan.row_count {
        let pair = pair_index(i, i, plan.row_count);
        for (part_idx, &scale) in pow_bu.iter().enumerate() {
            let idx = linear_garbage_start + pair * plan.config.aux_digit_parts + part_idx;
            aux_row[idx] += alpha_scale_to_ring(&alpha_diag, scale);
        }
    }

    Ok((phi_total, aggregated_rhs))
}

/// Dispatch statement aggregation: reduced (pre-planned) path when a plan is
/// attached, plain constraint folding otherwise. Both paths consume the same
/// transcript interface and produce the same (phi rows, aggregated rhs) pair.
pub(crate) fn aggregate_statement<F, T, const D: usize>(
    statement: &LabradorStatement<F, D>,
    row_lengths: &[usize],
    transcript: &mut T,
) -> Result<(Vec<Vec<CyclotomicRing<F, D>>>, CyclotomicRing<F, D>), HachiError>
where
    F: FieldCore + CanonicalField + FromSmallInt,
    T: Transcript,
{
    if let Some(plan) = statement.reduced_constraints.as_deref() {
        aggregate_reduced_statement_constraints(statement, plan, row_lengths, transcript)
    } else {
        aggregate_statement_constraints(&statement.constraints, row_lengths, transcript)
    }
}

/// Fold statement constraints into aggregated (φ, b) using transcript challenges.
///
/// Each scalar constraint is folded with one fresh dense challenge α: its
/// coefficient terms are fused-accumulated into `phi_total`, while `α · target`
/// is accumulated into `aggregated_rhs`.
#[allow(clippy::type_complexity)]
#[tracing::instrument(skip_all, name = "labrador::aggregate_statement_constraints")]
pub(crate) fn aggregate_statement_constraints<F, T, const D: usize>(
    constraints: &[LabradorConstraint<F, D>],
    row_lengths: &[usize],
    transcript: &mut T,
) -> Result<(Vec<Vec<CyclotomicRing<F, D>>>, CyclotomicRing<F, D>), HachiError>
where
    F: FieldCore + CanonicalField + FromSmallInt,
    T: Transcript,
{
    // No constraints: return zero rows and a zero rhs without touching the
    // transcript (no challenges are consumed).
    if constraints.is_empty() {
        let phi_total: Vec<Vec<CyclotomicRing<F, D>>> = row_lengths
            .iter()
            .map(|&len| vec![CyclotomicRing::zero(); len])
            .collect();
        return Ok((phi_total, CyclotomicRing::zero()));
    }

    let num_rows = row_lengths.len();

    // Phase 1: sample all challenges sequentially (Fiat-Shamir ordering).
    let alphas: Vec<AggregationRandomness<F, D>> = constraints
        .iter()
        .map(|_| sample_aggregation_randomness::<F, T, D>(transcript))
        .collect();

    // Phase 2: validate bounds (cheap, allows early `?` return).
+ for cnst in constraints { + for term in &cnst.terms { + if term.row >= num_rows { + return Err(HachiError::InvalidInput( + "constraint row index out of bounds".to_string(), + )); + } + if term.offset + term.coefficients.len() > row_lengths[term.row] { + return Err(HachiError::InvalidInput( + "constraint term exceeds row length".to_string(), + )); + } + } + } + + // Phase 3: aggregated_rhs — parallel fold-reduce over constraints. + let aggregated_rhs = cfg_fold_reduce!( + (0..constraints.len()), + || CyclotomicRing::::zero(), + |mut acc, i| { + accumulate_rhs_with_alpha(&alphas[i], &constraints[i].target, &mut acc); + acc + }, + |mut a, b| { + a += b; + a + } + ); + + // Phase 4: phi_total — group work by target row, then parallel over rows. + let mut row_work: Vec> = vec![Vec::new(); num_rows]; + for (ci, cnst) in constraints.iter().enumerate() { + for (ti, term) in cnst.terms.iter().enumerate() { + row_work[term.row].push((ci, ti)); + } + } + + let phi_total: Vec>> = cfg_into_iter!(row_work) + .zip(cfg_iter!(row_lengths).copied()) + .map(|(work, len)| { + let mut row = vec![CyclotomicRing::::zero(); len]; + accumulate_statement_row_work(&mut row, &work, constraints, &alphas); + row + }) + .collect(); + + Ok((phi_total, aggregated_rhs)) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::algebra::fields::Prime128M13M4P0; + use crate::algebra::{Pow2Offset32Field, Pow2Offset64Field}; + use crate::protocol::transcript::labels::DOMAIN_LABRADOR_RECURSION; + use crate::protocol::transcript::Blake2bTranscript; + + const D: usize = 64; + const TEST_RING_ELEMS: usize = 16; + const TEST_COLS: usize = TEST_RING_ELEMS * D; + + #[test] + fn safe_to_use_scalar_randomness_only_for_128_bit_fields() { + assert!( + !safe_to_use_scalar_randomness::(), + "32-bit field must use legacy JL aggregation schedule" + ); + assert!( + !safe_to_use_scalar_randomness::(), + "64-bit field must use legacy JL aggregation schedule" + ); + assert!( + safe_to_use_scalar_randomness::(), + 
"128-bit field should use scalar JL aggregation schedule" + ); + } + + fn assert_aggregate_jl_contraints_one_lift_matches_naive< + F: FieldCore + CanonicalField + FromSmallInt, + >( + matrix: &LabradorJlMatrix, + omega: [F; 256], + ) { + let cols = matrix.cols(); + assert_eq!(cols, TEST_COLS); + + let got = aggregate_jl_contraints_one_lift::(matrix, &omega).unwrap(); + let mut expected = vec![CyclotomicRing::::zero(); cols / D]; + // Naive paper-style reference: + // φ''_i = Σ_j ω_j · σ_{-1}(π_i^(j)) + for (elem_idx, elem) in expected.iter_mut().enumerate() { + for (row_idx, &alpha) in omega.iter().enumerate() { + let row = &matrix.packed_rows[row_idx]; + let pi = std::array::from_fn(|local_idx| { + let col_idx = elem_idx * D + local_idx; + let shift = (col_idx & 0b11) << 1; + let pair = (row[col_idx >> 2] >> shift) & 0b11; + let sign = match pair { + 0b00 => -1i64, + 0b11 => 1i64, + _ => 0i64, + }; + F::from_i64(sign) + }); + *elem += CyclotomicRing::::from_coefficients(pi) + .sigma_m1() + .scale(&alpha); + } + } + + assert!(got == expected); + } + + #[test] + fn aggregate_jl_contraints_one_lift_matches_naive_fp32() { + let mut transcript = Blake2bTranscript::::new(DOMAIN_LABRADOR_RECURSION); + let matrix = + LabradorJlMatrix::generate::(&mut transcript, TEST_COLS).unwrap(); + let omega = sample_jl_collapse_challenge::(&mut transcript); + assert_aggregate_jl_contraints_one_lift_matches_naive::(&matrix, omega); + } + + #[test] + fn aggregate_jl_contraints_one_lift_matches_naive_fp64() { + type F64 = Pow2Offset64Field; + let mut transcript = Blake2bTranscript::::new(DOMAIN_LABRADOR_RECURSION); + let matrix = LabradorJlMatrix::generate::(&mut transcript, TEST_COLS).unwrap(); + let omega = sample_jl_collapse_challenge::(&mut transcript); + assert_aggregate_jl_contraints_one_lift_matches_naive::(&matrix, omega); + } + + #[test] + fn aggregate_jl_contraints_one_lift_matches_naive_fp128() { + type F128 = Prime128M13M4P0; + let mut transcript = 
Blake2bTranscript::::new(DOMAIN_LABRADOR_RECURSION); + let matrix = LabradorJlMatrix::generate::(&mut transcript, TEST_COLS).unwrap(); + let omega = sample_jl_collapse_challenge::(&mut transcript); + assert_aggregate_jl_contraints_one_lift_matches_naive::(&matrix, omega); + } +} diff --git a/src/protocol/labrador/challenge.rs b/src/protocol/labrador/challenge.rs new file mode 100644 index 00000000..39f895ac --- /dev/null +++ b/src/protocol/labrador/challenge.rs @@ -0,0 +1,457 @@ +//! Labrador challenge sampler (C-parity oriented). +//! +//! This ports the `polyvec_challenge` rejection sampler from the C reference. + +use crate::algebra::ring::CyclotomicRing; +use crate::algebra::SparseChallenge; +use crate::error::HachiError; +use crate::protocol::labrador::guardrails::{ + checked_add, checked_mul, ensure_power_of_two, ensure_temp_allocation_limit, + LABRADOR_MAX_CHALLENGE_POLYS, +}; +use crate::{CanonicalField, FieldCore, FromSmallInt}; +use sha3::digest::{ExtendableOutput, Update, XofReader}; +use sha3::Shake128; +use std::sync::OnceLock; + +/// Number of `±1` coefficients in a challenge polynomial. +pub const LABRADOR_TAU1: usize = 32; +/// Number of `±2` coefficients in a challenge polynomial. +pub const LABRADOR_TAU2: usize = 8; +/// Operator norm bound used by C's challenge rejection sampler. +pub const LABRADOR_CHALLENGE_OPNORM_BOUND: f64 = 14.0; +const LABRADOR_CHALLENGE_OPNORM_BOUND_SQ: f64 = + LABRADOR_CHALLENGE_OPNORM_BOUND * LABRADOR_CHALLENGE_OPNORM_BOUND; + +const SHAKE128_RATE: usize = 168; +const SINGLE_CHALLENGE_BLOCKS: usize = 2; +const SINGLE_CHALLENGE_BLOCK_BYTES: usize = SINGLE_CHALLENGE_BLOCKS * SHAKE128_RATE; + +/// Sample Labrador challenge polynomials as signed coefficient arrays. +/// +/// The output follows C `polyvec_challenge`: each polynomial has exactly +/// `LABRADOR_TAU1` coefficients in `{±1}`, `LABRADOR_TAU2` coefficients in +/// `{±2}`, all other coefficients 0, and must satisfy operator-norm bound. 
+/// +/// # Errors +/// +/// Returns an error if ring parameters are incompatible with the C algorithm. +pub fn sample_labrador_challenge_coeffs( + len: usize, + seed: &[u8; 16], + stream_id: u64, +) -> Result, HachiError> { + validate_challenge_params::()?; + if len > LABRADOR_MAX_CHALLENGE_POLYS { + return Err(HachiError::InvalidInput(format!( + "requested too many challenge polynomials: {len} (max {LABRADOR_MAX_CHALLENGE_POLYS})" + ))); + } + + let mut xof = Shake128::default(); + xof.update(seed); + xof.update(&stream_id.to_le_bytes()); + let mut reader = xof.finalize_xof(); + + let mut out = Vec::with_capacity(len); + let mut remaining = len; + + if remaining == 1 { + while remaining > 0 { + let mut buf = [0u8; SINGLE_CHALLENGE_BLOCK_BYTES]; + reader.read(&mut buf); + let produced = consume_challenge_buffer::(&mut out, remaining, &buf); + remaining -= produced; + } + return Ok(out); + } + + while remaining >= 10 { + let bytes = checked_mul(17, SHAKE128_RATE, "challenge block bytes")?; + ensure_temp_allocation_limit(bytes, "challenge sampler")?; + let mut buf = vec![0u8; bytes]; + reader.read(&mut buf); + let produced = consume_challenge_buffer::(&mut out, 10, &buf); + remaining -= produced; + } + + while remaining > 0 { + let scaled = checked_mul(remaining, 17, "scaled tail blocks numerator")?; + let scaled = checked_add(scaled, 9, "tail blocks numerator rounding")?; + let blocks = scaled / 10; + let bytes = checked_mul(blocks, SHAKE128_RATE, "tail block bytes")?; + ensure_temp_allocation_limit(bytes, "challenge sampler tail")?; + let mut buf = vec![0u8; bytes]; + reader.read(&mut buf); + let produced = consume_challenge_buffer::(&mut out, remaining, &buf); + remaining -= produced; + } + + Ok(out) +} + +/// Sample Labrador challenge polynomials as dense ring elements. +/// +/// # Errors +/// +/// Returns an error if parameter checks fail. 
// NOTE(review): generic parameter lists in this block were lost during
// extraction; the `<...>` below are reconstructed from call sites — verify
// against the original file before relying on exact signatures.
pub fn sample_labrador_challenges<F, const D: usize>(
    len: usize,
    seed: &[u8; 16],
    stream_id: u64,
) -> Result<Vec<CyclotomicRing<F, D>>, HachiError>
where
    F: FieldCore + CanonicalField + FromSmallInt,
{
    let coeffs = sample_labrador_challenge_coeffs::<D>(len, seed, stream_id)?;
    // Lift each signed i16 coefficient array into the field via from_i64.
    Ok(coeffs
        .into_iter()
        .map(|poly| {
            CyclotomicRing::from_coefficients(std::array::from_fn(|i| F::from_i64(poly[i] as i64)))
        })
        .collect())
}

/// Sample Labrador challenge polynomials as sparse ring elements.
///
/// # Errors
///
/// Returns an error if parameter checks fail.
pub fn sample_labrador_sparse_challenges<const D: usize>(
    len: usize,
    seed: &[u8; 16],
    stream_id: u64,
) -> Result<Vec<SparseChallenge>, HachiError> {
    Ok(sample_labrador_challenge_coeffs::<D>(len, seed, stream_id)?
        .into_iter()
        .map(|poly| {
            // Exactly TAU1 + TAU2 coefficients are nonzero by construction.
            let mut positions = Vec::with_capacity(LABRADOR_TAU1 + LABRADOR_TAU2);
            let mut coeffs = Vec::with_capacity(LABRADOR_TAU1 + LABRADOR_TAU2);
            for (idx, coeff) in poly.into_iter().enumerate() {
                if coeff != 0 {
                    positions.push(idx as u32);
                    coeffs.push(coeff);
                }
            }
            SparseChallenge { positions, coeffs }
        })
        .collect())
}

/// Check that the ring degree is compatible with the C reference algorithm:
/// power of two, at most 256, and large enough to hold TAU1 + TAU2 entries.
fn validate_challenge_params<const D: usize>() -> Result<(), HachiError> {
    ensure_power_of_two(D, "challenge sampler degree D")?;
    if D > 256 {
        return Err(HachiError::InvalidInput(format!(
            "challenge sampler expects D <= 256, got {D}"
        )));
    }
    if LABRADOR_TAU1 + LABRADOR_TAU2 > D {
        return Err(HachiError::InvalidInput(format!(
            "tau1 + tau2 exceeds ring degree: {LABRADOR_TAU1} + {LABRADOR_TAU2} > {D}"
        )));
    }
    Ok(())
}

/// Drain `buf` into accepted challenge polynomials, pushing at most
/// `target_len` onto `out`; returns how many were produced.
///
/// Each candidate consumes sign bits first, then places ±1 / ±2 values via
/// the inside-out-shuffle placement from C `polyvec_challenge`; a candidate
/// is accepted only if placement completed (`k == D`) and the operator-norm
/// bound holds. Candidates that run out of buffer bytes are discarded.
fn consume_challenge_buffer<const D: usize>(
    out: &mut Vec<[i16; D]>,
    target_len: usize,
    buf: &[u8],
) -> usize {
    let sign_bytes = (LABRADOR_TAU1 + LABRADOR_TAU2).div_ceil(8);
    let min_bytes = LABRADOR_TAU1 + LABRADOR_TAU2 + sign_bytes;
    let mut produced = 0usize;
    let mut cursor = 0usize;

    while produced < target_len && cursor <= buf.len().saturating_sub(min_bytes) {
        // Little-endian sign word: one bit per nonzero coefficient.
        let mut signs = 0u64;
        for k in 0..sign_bytes {
            signs |= (buf[cursor] as u64) << (8 * k);
            cursor += 1;
        }

        let mut poly = [0i16; D];
        let mut k = D - LABRADOR_TAU1 - LABRADOR_TAU2;
        while k < D && cursor < buf.len() {
            // Rejection-sample an index b in [0, D); b > k is retried.
            let b = (buf[cursor] as usize) & (D - 1);
            cursor += 1;
            if b <= k {
                // Inside-out Fisher-Yates style placement: positions
                // below D - TAU2 get ±1, the last TAU2 slots get ±2.
                poly[k] = poly[b];
                let mut value = if k < D - LABRADOR_TAU2 { 1 } else { 2 };
                if (signs & 1) == 1 {
                    value = -value;
                }
                poly[b] = value;
                signs >>= 1;
                k += 1;
            }
        }

        if k == D
            && challenge_operator_norm_with_bound::<D>(&poly, LABRADOR_CHALLENGE_OPNORM_BOUND_SQ)
        {
            out.push(poly);
            produced += 1;
        }
    }

    produced
}

/// Precomputed cos/sin evaluation grids (D × D, row-major) for the
/// operator-norm check at the 2D-th roots of unity.
struct ChallengeOpNormTable {
    cos: Vec<f64>,
    sin: Vec<f64>,
}

/// Build the D × D cos/sin table: row i holds angle (2i+1)π/D times column j.
fn build_challenge_opnorm_table(d: usize) -> ChallengeOpNormTable {
    let mut cos = Vec::with_capacity(d * d);
    let mut sin = Vec::with_capacity(d * d);
    let d_f = d as f64;
    for i in 0..d {
        let theta = ((2 * i + 1) as f64) * std::f64::consts::PI / d_f;
        for j in 0..d {
            let angle = theta * (j as f64);
            cos.push(angle.cos());
            sin.push(angle.sin());
        }
    }
    ChallengeOpNormTable { cos, sin }
}

/// Lazily-initialised, process-wide opnorm table for the given degree.
///
/// One static `OnceLock` per supported power-of-two degree, because Rust
/// does not allow statics to depend on the const generic `D`.
fn challenge_opnorm_table<const D: usize>() -> &'static ChallengeOpNormTable {
    match D {
        1 => {
            static TABLE: OnceLock<ChallengeOpNormTable> = OnceLock::new();
            TABLE.get_or_init(|| build_challenge_opnorm_table(1))
        }
        2 => {
            static TABLE: OnceLock<ChallengeOpNormTable> = OnceLock::new();
            TABLE.get_or_init(|| build_challenge_opnorm_table(2))
        }
        4 => {
            static TABLE: OnceLock<ChallengeOpNormTable> = OnceLock::new();
            TABLE.get_or_init(|| build_challenge_opnorm_table(4))
        }
        8 => {
            static TABLE: OnceLock<ChallengeOpNormTable> = OnceLock::new();
            TABLE.get_or_init(|| build_challenge_opnorm_table(8))
        }
        16 => {
            static TABLE: OnceLock<ChallengeOpNormTable> = OnceLock::new();
            TABLE.get_or_init(|| build_challenge_opnorm_table(16))
        }
        32 => {
            static TABLE: OnceLock<ChallengeOpNormTable> = OnceLock::new();
            TABLE.get_or_init(|| build_challenge_opnorm_table(32))
        }
        64 => {
            static TABLE: OnceLock<ChallengeOpNormTable> = OnceLock::new();
            TABLE.get_or_init(|| build_challenge_opnorm_table(64))
        }
        128 => {
            static TABLE: OnceLock<ChallengeOpNormTable> = OnceLock::new();
            TABLE.get_or_init(|| build_challenge_opnorm_table(128))
        }
        256 => {
            static TABLE: OnceLock<ChallengeOpNormTable> = OnceLock::new();
            TABLE.get_or_init(|| build_challenge_opnorm_table(256))
        }
        _ => panic!("unsupported challenge sampler degree {D}"),
    }
}

/// Dense O(D^2) reference: max |evaluation| over the D odd roots of unity.
/// Test-only ground truth for the sparse implementations below.
#[cfg(test)]
fn challenge_operator_norm_dense_reference<const D: usize>(coeffs: &[i16; D]) -> f64 {
    let mut max_norm = 0.0f64;
    let d_f = D as f64;
    for i in 0..D {
        let theta = ((2 * i + 1) as f64) * std::f64::consts::PI / d_f;
        let mut re = 0.0f64;
        let mut im = 0.0f64;
        for (j, &coeff) in coeffs.iter().enumerate() {
            let angle = theta * (j as f64);
            let c = coeff as f64;
            re += c * angle.cos();
            im += c * angle.sin();
        }
        let norm = (re * re + im * im).sqrt();
        if norm > max_norm {
            max_norm = norm;
        }
    }
    max_norm
}

/// Sparse operator norm: exploits that challenge polys have at most
/// TAU1 + TAU2 nonzero coefficients, scanning only the support per root.
#[cfg(test)]
fn challenge_operator_norm<const D: usize>(coeffs: &[i16; D]) -> f64 {
    let table = challenge_opnorm_table::<D>();
    let mut support_idx = [0usize; LABRADOR_TAU1 + LABRADOR_TAU2];
    let mut support_coeff = [0.0f64; LABRADOR_TAU1 + LABRADOR_TAU2];
    let mut support_len = 0usize;
    for (idx, &coeff) in coeffs.iter().enumerate() {
        if coeff == 0 {
            continue;
        }
        if support_len == support_idx.len() {
            // Denser input than a valid challenge: fall back (tests) or
            // treat as a contract violation (release).
            // NOTE(review): this fn is itself #[cfg(test)], so the
            // #[cfg(not(test))] arm below appears unreachable — presumably
            // kept in sync with the _with_bound twin; confirm intent.
            #[cfg(test)]
            {
                return challenge_operator_norm_dense_reference(coeffs);
            }
            #[cfg(not(test))]
            {
                panic!("challenge support exceeded expected sparsity");
            }
        }
        support_idx[support_len] = idx;
        support_coeff[support_len] = coeff as f64;
        support_len += 1;
    }

    let mut max_norm = 0.0f64;
    for i in 0..D {
        let row_base = i * D;
        let mut re = 0.0f64;
        let mut im = 0.0f64;
        for idx in 0..support_len {
            let coeff = support_coeff[idx];
            let col = support_idx[idx];
            re += coeff * table.cos[row_base + col];
            im += coeff * table.sin[row_base + col];
        }
        let norm = (re * re + im * im).sqrt();
        if norm > max_norm {
            max_norm = norm;
        }
    }
    max_norm
}

/// Bound check used by the hot rejection-sampling path: like
/// `challenge_operator_norm`, but compares squared norms (no sqrt) and
/// early-exits as soon as any root exceeds `bound_sq`.
fn challenge_operator_norm_with_bound<const D: usize>(coeffs: &[i16; D], bound_sq: f64) -> bool {
    let table = challenge_opnorm_table::<D>();
    let mut support_idx = [0usize; LABRADOR_TAU1 + LABRADOR_TAU2];
    let mut support_coeff = [0.0f64; LABRADOR_TAU1 + LABRADOR_TAU2];
    let mut support_len = 0usize;
    for (idx, &coeff) in coeffs.iter().enumerate() {
        if coeff == 0 {
            continue;
        }
        if support_len == support_idx.len() {
            #[cfg(test)]
            {
                let norm = challenge_operator_norm_dense_reference(coeffs);
                return norm * norm <= bound_sq;
            }
            #[cfg(not(test))]
            {
                panic!("challenge support exceeded expected sparsity");
            }
        }
        support_idx[support_len] = idx;
        support_coeff[support_len] = coeff as f64;
        support_len += 1;
    }

    for i in 0..D {
        let row_base = i * D;
        let mut re = 0.0f64;
        let mut im = 0.0f64;
        for idx in 0..support_len {
            let coeff = support_coeff[idx];
            let col = support_idx[idx];
            re += coeff * table.cos[row_base + col];
            im += coeff * table.sin[row_base + col];
        }
        if re * re + im * im > bound_sq {
            return false;
        }
    }

    true
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::algebra::fields::Fp32;

    type F = Fp32<4294967197>;
    const D: usize = 64;

    // Fixed test seeds and stream IDs for deterministic replay.
+ const TEST_SEED_A: [u8; 16] = [7u8; 16]; + const TEST_SEED_B: [u8; 16] = [11u8; 16]; + const TEST_SEED_C: [u8; 16] = [5u8; 16]; + const TEST_STREAM_ID_A: u64 = 9; + const TEST_STREAM_ID_B: u64 = 17; + const TEST_STREAM_ID_C: u64 = 4; + const TEST_STREAM_ID_REF: u64 = 7; + + #[test] + fn challenge_sampler_is_deterministic() { + let c1 = sample_labrador_challenge_coeffs::(3, &TEST_SEED_A, TEST_STREAM_ID_A).unwrap(); + let c2 = sample_labrador_challenge_coeffs::(3, &TEST_SEED_A, TEST_STREAM_ID_A).unwrap(); + assert_eq!(c1, c2); + } + + #[test] + fn challenge_sampler_obeys_operator_norm_bound() { + let samples = + sample_labrador_challenge_coeffs::(8, &TEST_SEED_B, TEST_STREAM_ID_B).unwrap(); + assert_eq!(samples.len(), 8); + for poly in &samples { + assert!(challenge_operator_norm(poly) <= LABRADOR_CHALLENGE_OPNORM_BOUND); + } + } + + #[test] + fn challenge_sampler_supports_dense_ring_conversion() { + let dense = sample_labrador_challenges::(2, &TEST_SEED_C, TEST_STREAM_ID_C).unwrap(); + assert_eq!(dense.len(), 2); + } + + #[test] + fn challenge_sampler_matches_transliterated_reference_vector() { + // Captured from the C-reference algorithm semantics (`polyvec_challenge`) + // for seed = [0,1,2,...,15], stream ID = 7, len = 1. 
+ let seed: [u8; 16] = std::array::from_fn(|i| i as u8); + let coeffs = sample_labrador_challenge_coeffs::(1, &seed, TEST_STREAM_ID_REF).unwrap(); + let got = coeffs[0]; + let expected: [i16; D] = [ + 1, 1, 0, 1, 0, 0, 2, -1, 0, 0, 2, 1, 1, -1, -1, 1, -2, 0, 1, 0, -1, -1, 1, 0, 1, -1, 1, + 1, 0, -1, 0, -1, 2, 1, 1, -1, -2, 0, 0, 1, 0, 0, 1, 1, -2, 1, 0, 0, 0, 0, 0, 0, 1, 0, + -1, -1, 2, -1, 0, 1, -2, 1, 0, 0, + ]; + assert_eq!(got, expected); + } + + #[test] + fn sparse_operator_norm_matches_dense_reference() { + for stream_id in [1u64, 3, 7, 9, 17, 29] { + let polys = + sample_labrador_challenge_coeffs::(6, &TEST_SEED_A, stream_id).expect("sample"); + for poly in polys { + let sparse = challenge_operator_norm(&poly); + let dense = challenge_operator_norm_dense_reference(&poly); + assert_eq!(sparse.to_bits(), dense.to_bits()); + assert_eq!( + sparse <= LABRADOR_CHALLENGE_OPNORM_BOUND, + dense <= LABRADOR_CHALLENGE_OPNORM_BOUND + ); + } + } + } + + #[test] + fn operator_norm_bound_check_matches_full_norm() { + for stream_id in [2u64, 5, 11, 19] { + let polys = + sample_labrador_challenge_coeffs::(6, &TEST_SEED_B, stream_id).expect("sample"); + for poly in polys { + assert_eq!( + challenge_operator_norm(&poly) <= LABRADOR_CHALLENGE_OPNORM_BOUND, + challenge_operator_norm_with_bound(&poly, LABRADOR_CHALLENGE_OPNORM_BOUND_SQ) + ); + } + } + } +} diff --git a/src/protocol/labrador/comkey.rs b/src/protocol/labrador/comkey.rs new file mode 100644 index 00000000..c4b6f16e --- /dev/null +++ b/src/protocol/labrador/comkey.rs @@ -0,0 +1,95 @@ +//! Prefix-stable extendable commitment-key derivation for Labrador. +//! +//! Unlike setup matrices that bind full `(rows, cols)` shape, this derivation +//! binds only `(matrix_label, row, col)` so extending dimensions preserves the +//! previously derived prefix exactly. 
+ +use blake2::digest::consts::U32; +use blake2::digest::Digest; +use blake2::Blake2b; + +use crate::algebra::ring::CyclotomicRing; +#[cfg(feature = "parallel")] +use crate::parallel::*; +use crate::protocol::prg::{MatrixPrgContext, Shake256Backend}; +use crate::{FieldCore, FieldSampling}; + +/// Public seed used to derive extendable Labrador commitment keys. +pub type LabradorComKeySeed = [u8; 32]; + +/// Derive a Labrador commitment-key seed from the Hachi public-matrix seed. +/// +/// Uses domain-separated BLAKE2b-256 so that the Labrador key space is +/// independent of the Hachi commitment-matrix key space. +pub fn derive_labrador_comkey_seed(hachi_public_matrix_seed: &[u8; 32]) -> LabradorComKeySeed { + let mut hasher = Blake2b::::new(); + hasher.update(b"hachi/labrador/comkey-seed"); + hasher.update(hachi_public_matrix_seed); + let hash = hasher.finalize(); + let mut seed = [0u8; 32]; + seed.copy_from_slice(&hash); + seed +} + +/// Derive a prefix-stable matrix for Labrador commitment keys. +/// +/// Prefix-stable means: if `M_small = derive(rows, cols)` and +/// `M_large = derive(rows2, cols2)` with `rows2 >= rows`, `cols2 >= cols`, +/// then `M_large[r][c] == M_small[r][c]` for all `r < rows`, `c < cols`. 
+pub fn derive_extendable_comkey_matrix( + rows: usize, + cols: usize, + seed: &LabradorComKeySeed, + matrix_label: &[u8], +) -> Vec>> { + use crate::protocol::prg::MatrixPrgBackend; + + cfg_into_iter!(0..rows) + .map(|r| { + (0..cols) + .map(|c| { + let context = MatrixPrgContext { + seed, + matrix_label, + rows: 0, + cols: 0, + row: r, + col: c, + }; + let mut rng = Shake256Backend.entry_rng(&context); + CyclotomicRing::random(&mut rng) + }) + .collect() + }) + .collect() +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::algebra::Fp64; + + type F = Fp64<4294967197>; + const D: usize = 64; + + #[test] + fn extendable_derivation_has_prefix_stability() { + let seed = [19u8; 32]; + let small = derive_extendable_comkey_matrix::(3, 4, &seed, b"comkey/A"); + let large = derive_extendable_comkey_matrix::(5, 7, &seed, b"comkey/A"); + + for r in 0..3 { + for c in 0..4 { + assert_eq!(small[r][c], large[r][c]); + } + } + } + + #[test] + fn extendable_derivation_domain_separates_labels() { + let seed = [7u8; 32]; + let a = derive_extendable_comkey_matrix::(2, 3, &seed, b"comkey/A"); + let b = derive_extendable_comkey_matrix::(2, 3, &seed, b"comkey/B"); + assert_ne!(a, b); + } +} diff --git a/src/protocol/labrador/commit.rs b/src/protocol/labrador/commit.rs new file mode 100644 index 00000000..2afbe137 --- /dev/null +++ b/src/protocol/labrador/commit.rs @@ -0,0 +1,480 @@ +//! Two-tier commitment helpers for Labrador. 
+ +use crate::algebra::ring::CyclotomicRing; +use crate::error::HachiError; +#[cfg(feature = "parallel")] +use crate::parallel::*; +use crate::protocol::commitment::utils::linear::{ + decompose_rows_with_carry, mat_vec_mul_crt_ntt_i8_many, +}; +use crate::protocol::labrador::comkey::{derive_extendable_comkey_matrix, LabradorComKeySeed}; +use crate::protocol::labrador::types::{LabradorReductionConfig, LabradorWitness}; +use crate::protocol::labrador::utils::{mat_vec_mul, try_centered_i8_rows}; +use crate::{cfg_iter, CanonicalField, FieldCore, FieldSampling}; + +/// Commitment artifacts needed by downstream Labrador flows. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct LabradorCommitmentArtifacts { + /// Per-row inner commitments. + pub u_inner: Vec>>, + /// Opening-side payload (formerly `u1`). + pub inner_opening_payload: Vec>, + /// Linear-garbage-side payload (formerly `u2`). + pub linear_garbage_payload: Vec>, + /// Decomposed witness rows. + pub decomposed_witness: Vec>>, + /// Decomposed inner commitments. + pub decomposed_inner: Vec>>, + /// Linear garbage terms `h_{ij}` (always present in linear-only mode). + pub linear_garbage: Vec>, +} + +/// Commit witness rows in linear-only Labrador mode. +/// +/// # Errors +/// +/// Returns an error if dimensions/config are invalid. 
+pub fn commit_linear_only( + witness: &LabradorWitness, + config: &LabradorReductionConfig, + comkey_seed: &LabradorComKeySeed, +) -> Result, HachiError> +where + F: FieldCore + CanonicalField + FieldSampling, +{ + if witness.rows().is_empty() { + return Err(HachiError::InvalidInput( + "cannot commit empty Labrador witness".to_string(), + )); + } + if config.aux_digit_parts == 0 || config.aux_digit_bits == 0 || config.inner_commit_rank == 0 { + return Err(HachiError::InvalidInput( + "invalid Labrador commitment config".to_string(), + )); + } + + #[allow(clippy::type_complexity)] + let per_row: Vec<( + Vec>, + Vec>, + Vec>, + )> = cfg_iter!(witness.rows()) + .map(|row| { + let a = derive_extendable_comkey_matrix::( + config.inner_commit_rank, + row.len(), + comkey_seed, + b"labrador/comkey/A", + ); + let t = mat_vec_mul(&a, row); + let inner_opening_digits = + decompose_rows_with_carry(&t, config.aux_digit_parts, config.aux_digit_bits as u32); + let witness_digits = decompose_rows_with_carry( + row, + config.witness_digit_parts, + config.witness_digit_bits as u32, + ); + (t, inner_opening_digits, witness_digits) + }) + .collect(); + + let mut u_inner = Vec::with_capacity(per_row.len()); + let mut decomposed_inner = Vec::with_capacity(per_row.len()); + let mut decomposed_witness = Vec::with_capacity(per_row.len()); + for (t, inner_opening_digits, witness_digits) in per_row { + if t.is_empty() { + return Err(HachiError::InvalidInput( + "inner commitment row produced empty vector".to_string(), + )); + } + u_inner.push(t); + decomposed_inner.push(inner_opening_digits); + decomposed_witness.push(witness_digits); + } + + let mut inner_opening_digits_flat = Vec::new(); + for inner_opening_digits in &decomposed_inner { + inner_opening_digits_flat.extend(inner_opening_digits.iter().copied()); + } + + let inner_opening_payload = if config.tail || config.outer_commit_rank == 0 { + u_inner.iter().flat_map(|v| v.iter().copied()).collect() + } else { + let b = 
derive_extendable_comkey_matrix::( + config.outer_commit_rank, + inner_opening_digits_flat.len(), + comkey_seed, + b"labrador/comkey/B", + ); + mat_vec_mul(&b, &inner_opening_digits_flat) + }; + + let linear_garbage = build_linear_garbage(witness); + let linear_garbage_payload = if config.tail || config.outer_commit_rank == 0 { + linear_garbage.clone() + } else { + let b2 = derive_extendable_comkey_matrix::( + config.outer_commit_rank, + linear_garbage.len(), + comkey_seed, + b"labrador/comkey/U2", + ); + mat_vec_mul(&b2, &linear_garbage) + }; + + Ok(LabradorCommitmentArtifacts { + u_inner, + inner_opening_payload, + linear_garbage_payload, + decomposed_witness, + decomposed_inner, + linear_garbage, + }) +} + +type RingVec = Vec>; +type TwoTierResult = Result<(RingVec, RingVec), HachiError>; +pub(crate) const OUTER_NTT_LOG_BASIS: u32 = 4; + +fn max_centered_coeff_bits( + rows: &[CyclotomicRing], +) -> usize { + let q = (-F::one()).to_canonical_u128() + 1; + let half_q = q / 2; + let mut max_abs = 0u128; + + for row in rows { + for coeff in row.coeffs.iter() { + let canonical = coeff.to_canonical_u128(); + let signed = if canonical > half_q { + -((q - canonical) as i128) + } else { + canonical as i128 + }; + let abs = signed.unsigned_abs(); + if abs > max_abs { + max_abs = abs; + } + } + } + + if max_abs == 0 { + 1 + } else { + (u128::BITS - max_abs.leading_zeros()) as usize + } +} + +pub(crate) fn outer_ntt_digit_levels( + rows: &[CyclotomicRing], +) -> usize { + let coeff_bits = max_centered_coeff_bits(rows); + coeff_bits.div_ceil(OUTER_NTT_LOG_BASIS as usize) + 1 +} + +fn witness_ntt_digit_levels( + witness: &[Vec>], +) -> usize { + witness + .iter() + .map(|row| outer_ntt_digit_levels(row)) + .max() + .unwrap_or(1) +} + +fn pow2_field(exp: u32) -> F { + let two = F::one() + F::one(); + let mut acc = F::one(); + for _ in 0..exp { + acc = acc * two; + } + acc +} + +pub(crate) fn expand_matrix_for_i8_digits( + matrix: &[Vec>], + num_digits: usize, + log_basis: u32, 
+) -> Vec>> { + let scale_step = pow2_field::(log_basis); + let mut scales = Vec::with_capacity(num_digits); + let mut scale = F::one(); + for _ in 0..num_digits { + scales.push(scale); + scale = scale * scale_step; + } + + cfg_iter!(matrix) + .map(|row| { + let mut expanded = Vec::with_capacity(row.len() * num_digits); + for entry in row { + for scale in &scales { + expanded.push(entry.scale(scale)); + } + } + expanded + }) + .collect() +} + +pub(crate) fn decompose_rows_ntt_i8_exact( + rows: &[CyclotomicRing], + num_digits: usize, + log_basis: u32, +) -> Vec<[i8; D]> { + let mut out = Vec::with_capacity(rows.len() * num_digits); + for row in rows { + out.extend(row.balanced_decompose_pow2_i8(num_digits, log_basis)); + } + out +} + +#[tracing::instrument(skip_all, name = "labrador::commit_witness_ntt")] +fn commit_witness_ntt( + matrix: &[Vec>], + witness: &[Vec>], +) -> Result>>, HachiError> { + if matrix.is_empty() { + return Ok(vec![vec![]; witness.len()]); + } + + if let Some(witness_i8) = try_centered_i8_rows(witness) { + return mat_vec_mul_crt_ntt_i8_many(matrix, &witness_i8); + } + + // Large decomposed witness rows can exceed the safe reconstruction range of + // the generic ring-element NTT multiply. Re-expand them into balanced i8 + // planes and scale A by powers of two so the shared CRT backend stays exact. 
+ let witness_digit_levels = witness_ntt_digit_levels(witness); + let expanded_matrix = + expand_matrix_for_i8_digits(matrix, witness_digit_levels, OUTER_NTT_LOG_BASIS); + let witness_digits: Vec> = cfg_iter!(witness) + .map(|row| decompose_rows_ntt_i8_exact(row, witness_digit_levels, OUTER_NTT_LOG_BASIS)) + .collect(); + mat_vec_mul_crt_ntt_i8_many(&expanded_matrix, &witness_digits) +} + +#[tracing::instrument(skip_all, name = "labrador::commit_inner_ntt")] +fn commit_inner_ntt( + matrix: &[Vec>], + inner_commitment: &[Vec>], + num_digits: usize, + decompose_modulus: u32, +) -> TwoTierResult { + let inner_opening_digits_per_row: Vec>> = cfg_iter!(inner_commitment) + .map(|t| decompose_rows_with_carry(t, num_digits, decompose_modulus)) + .collect(); + let inner_opening_digits: Vec> = + inner_opening_digits_per_row.into_iter().flatten().collect(); + + if matrix.is_empty() { + return Ok((inner_opening_digits.clone(), inner_opening_digits)); + } + + // The outer B-multiply sees arbitrary Labrador key coefficients times + // decomposed carry digits. Re-expand those digits into small i8 planes and + // expand B by powers of two so the product stays within the conservative + // CRT range used by the shared NTT backend. + let outer_digit_levels = outer_ntt_digit_levels(&inner_opening_digits); + let expanded_matrix = + expand_matrix_for_i8_digits(matrix, outer_digit_levels, OUTER_NTT_LOG_BASIS); + let inner_opening_digits_i8 = decompose_rows_ntt_i8_exact( + &inner_opening_digits, + outer_digit_levels, + OUTER_NTT_LOG_BASIS, + ); + let u0 = mat_vec_mul_crt_ntt_i8_many(&expanded_matrix, &[inner_opening_digits_i8])? + .into_iter() + .next() + .unwrap_or_default(); + Ok((inner_opening_digits, u0)) +} + +/// NTT-accelerated two-tier commitment: `witness → t = A·w → t̂ → u = B·t̂`. +/// +/// Returns `(inner_opening_digits, payload)` where `inner_opening_digits` is +/// the flattened decomposed inner commitment +/// and `u` is the outer commitment. 
+/// +/// # Errors +/// +/// Propagates NTT or matrix shape errors. +#[tracing::instrument(skip_all, name = "labrador::ntt_two_tier_commit")] +pub fn ntt_two_tier_commit( + a_mat: &[Vec>], + b_mat: &[Vec>], + witness: &[Vec>], + num_digits: usize, + decompose_modulus: u32, +) -> TwoTierResult { + let t = commit_witness_ntt(a_mat, witness)?; + commit_inner_ntt(b_mat, &t, num_digits, decompose_modulus) +} + +fn build_linear_garbage( + witness: &LabradorWitness, +) -> Vec> { + let r = witness.rows().len(); + let pairs: Vec<(usize, usize)> = (0..r).flat_map(|i| (i..r).map(move |j| (i, j))).collect(); + cfg_iter!(pairs) + .map(|&(i, j)| { + let len = witness.rows()[i].len().min(witness.rows()[j].len()); + let mut acc = CyclotomicRing::::zero(); + for k in 0..len { + acc += witness.rows()[i][k] * witness.rows()[j][k]; + } + acc + }) + .collect() +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::algebra::fields::{Fp128, Fp64}; + use crate::protocol::commitment::utils::linear::mat_vec_mul_crt_ntt_many; + use crate::protocol::labrador::setup::LabradorSetup; + use crate::protocol::labrador::types::LabradorReductionConfig; + use crate::protocol::labrador::utils::mat_vec_mul; + use crate::FromSmallInt; + + type F = Fp64<4294967197>; + const D: usize = 64; + + fn sample_witness() -> LabradorWitness { + let row = |len: usize| -> Vec> { + (0..len) + .map(|i| { + CyclotomicRing::from_coefficients(std::array::from_fn(|j| { + F::from_i64(((i + j) as i64 % 9) - 4) + })) + }) + .collect() + }; + LabradorWitness::new(vec![row(4), row(4), row(4)]) + } + + #[test] + fn commit_linear_only_is_deterministic() { + let witness = sample_witness(); + let cfg = LabradorReductionConfig { + witness_digit_parts: 1, + witness_digit_bits: 8, + aux_digit_parts: 2, + aux_digit_bits: 10, + inner_commit_rank: 3, + outer_commit_rank: 2, + tail: false, + }; + let seed = [3u8; 32]; + let a = commit_linear_only(&witness, &cfg, &seed).unwrap(); + let b = commit_linear_only(&witness, &cfg, 
&seed).unwrap(); + assert_eq!(a, b); + assert!( + !a.linear_garbage_payload.is_empty(), + "linear garbage payload must exist" + ); + } + + #[test] + fn ntt_two_tier_commit_matches_schoolbook_fp128_non_tail() { + type F128 = Fp128<0xfffffffffffffffffffffffffffffeed>; + const D128: usize = 256; + + let row = |seed: i64, len: usize| -> Vec> { + (0..len) + .map(|j| { + CyclotomicRing::from_coefficients(std::array::from_fn(|k| { + let raw = (seed + j as i64 * 7 + k as i64 * 11) % 17; + F128::from_i64(raw - 8) + })) + }) + .collect() + }; + let mut second_row = row(2, 36); + second_row.resize(48, CyclotomicRing::::zero()); + let witness = vec![row(1, 48), second_row]; + + let cfg = LabradorReductionConfig { + witness_digit_parts: 1, + witness_digit_bits: 35, + aux_digit_parts: 4, + aux_digit_bits: 32, + inner_commit_rank: 3, + outer_commit_rank: 3, + tail: false, + }; + let comkey_seed = [17u8; 32]; + let setup = LabradorSetup::::new(&cfg, witness.len(), 48, &comkey_seed); + + let t_schoolbook: Vec>> = witness + .iter() + .map(|row| mat_vec_mul(&setup.matrices.a_mat, row)) + .collect(); + let t_direct_ntt = mat_vec_mul_crt_ntt_many(&setup.matrices.a_mat, &witness).unwrap(); + assert_eq!(t_direct_ntt, t_schoolbook); + + let t_cached_ntt = commit_witness_ntt(&setup.matrices.a_mat, &witness).unwrap(); + assert_eq!(t_cached_ntt, t_schoolbook); + + let inner_opening_digits_schoolbook: Vec> = t_schoolbook + .iter() + .flat_map(|row| { + decompose_rows_with_carry(row, cfg.aux_digit_parts, cfg.aux_digit_bits as u32) + }) + .collect(); + let inner_opening_payload_schoolbook = + mat_vec_mul(&setup.matrices.b_mat, &inner_opening_digits_schoolbook); + + let (inner_opening_digits_ntt, inner_opening_payload_ntt) = ntt_two_tier_commit( + &setup.matrices.a_mat, + &setup.matrices.b_mat, + &witness, + cfg.aux_digit_parts, + cfg.aux_digit_bits as u32, + ) + .unwrap(); + + assert_eq!(inner_opening_digits_ntt, inner_opening_digits_schoolbook); + assert_eq!(inner_opening_payload_ntt, 
inner_opening_payload_schoolbook); + } + + #[test] + fn commit_witness_ntt_matches_schoolbook_on_large_tail_digits() { + type F128 = Fp128<0xfffffffffffffffffffffffffffffeed>; + const D128: usize = 256; + + let row = |seed: i64, scale_exp: u32| -> Vec> { + let scale = pow2_field::(scale_exp); + (0..28) + .map(|j| { + CyclotomicRing::from_coefficients(std::array::from_fn(|k| { + let raw = (seed + j as i64 * 7 + k as i64 * 11) % 17; + F128::from_i64(raw - 8) * scale + })) + }) + .collect() + }; + let witness = vec![row(1, 35), row(2, 32), row(3, 35)]; + + let cfg = LabradorReductionConfig { + witness_digit_parts: 1, + witness_digit_bits: 39, + aux_digit_parts: 1, + aux_digit_bits: 128, + inner_commit_rank: 4, + outer_commit_rank: 0, + tail: true, + }; + let comkey_seed = [23u8; 32]; + let setup = LabradorSetup::::new(&cfg, witness.len(), 28, &comkey_seed); + + let t_schoolbook: Vec>> = witness + .iter() + .map(|row| mat_vec_mul(&setup.matrices.a_mat, row)) + .collect(); + let t_cached_ntt = commit_witness_ntt(&setup.matrices.a_mat, &witness).unwrap(); + + assert_eq!(t_cached_ntt, t_schoolbook); + } +} diff --git a/src/protocol/labrador/config.rs b/src/protocol/labrador/config.rs new file mode 100644 index 00000000..fcea7e0c --- /dev/null +++ b/src/protocol/labrador/config.rs @@ -0,0 +1,1044 @@ +//! Labrador parameter-selection and security checks. 
+ +use crate::error::HachiError; +use crate::primitives::serialization::Compress; +use crate::protocol::commitment::utils::norm::detect_field_modulus; +use crate::protocol::labrador::guardrails::LABRADOR_MAX_LEVELS; +use crate::protocol::labrador::types::{LabradorReductionConfig, LabradorWitness}; +use crate::{CanonicalField, FieldCore, HachiSerialize}; +use std::f64::consts::{E, PI}; +const LABRADOR_LOGDELTA: f64 = 0.00639138757765197; // log2(1.00444) +const LABRADOR_T: f64 = 14.0; +const LABRADOR_SLACK: f64 = 2.0; +const LABRADOR_TAU1: f64 = 32.0; +const LABRADOR_TAU2: f64 = 8.0; + +/// Full fold-level plan: security parameters plus witness reshaping layout. +#[derive(Debug, Clone)] +pub struct LabradorFoldPlan { + /// Security parameters (formerly `f`, `b`, `fu`, `bu`, `kappa`, + /// `kappa1`, `tail`). + pub config: LabradorReductionConfig, + /// Virtual row length after reshaping (formerly `nn`). + pub virtual_row_len: usize, + /// Per-original-row split count. `0` = continuation (concatenate with next + /// row), `>0` = boundary that terminates a group and splits it into this + /// many virtual rows of length `virtual_row_len` (formerly `nu`). 
+ pub row_split_counts: Vec, +} + +const MAX_WITNESS_DIGIT_PARTS: usize = 8; +const MAX_COMMITMENT_RANK: usize = 32; + +#[derive(Debug, Clone)] +struct LabradorWitnessPlanningProfile { + row_lengths: Vec, + norm_sum: f64, + coeff_bit_bound: Option, +} + +#[derive(Debug, Clone)] +pub(crate) struct LabradorFoldEstimate { + pub plan: LabradorFoldPlan, + pub level_payload_bytes: usize, + pub next_witness_bytes: usize, + pub transition_bytes: usize, + next_row_lengths: Vec, + next_norm_sum: f64, +} + +impl LabradorFoldEstimate { + fn next_profile(&self) -> Result { + LabradorWitnessPlanningProfile::new(self.next_row_lengths.clone(), self.next_norm_sum, None) + } +} + +#[derive(Debug, Clone)] +pub(crate) struct LabradorRecursiveSizeEstimate { + pub initial_plan: LabradorFoldPlan, + pub proof_bytes: usize, + pub final_witness_bytes: usize, + pub level_count: usize, +} + +impl LabradorWitnessPlanningProfile { + fn new( + row_lengths: Vec, + norm_sum: f64, + coeff_bit_bound: Option, + ) -> Result { + if row_lengths.is_empty() { + return Err(HachiError::InvalidInput( + "cannot select config for empty Labrador witness".to_string(), + )); + } + if row_lengths.iter().sum::() == 0 { + return Err(HachiError::InvalidInput( + "cannot select config for zero-length Labrador witness".to_string(), + )); + } + if !norm_sum.is_finite() || norm_sum < 0.0 { + return Err(HachiError::InvalidInput( + "cannot select config for non-finite Labrador witness norm".to_string(), + )); + } + Ok(Self { + row_lengths, + norm_sum, + coeff_bit_bound, + }) + } + + fn from_witness( + witness: &LabradorWitness, + ) -> Result { + let row_lengths = witness.rows().iter().map(|r| r.len()).collect(); + Self::new(row_lengths, witness.norm() as f64, None) + } + + fn from_handoff_witness( + witness: &LabradorWitness, + coeff_bit_bound: usize, + ) -> Result { + let row_lengths = witness.rows().iter().map(|r| r.len()).collect(); + Self::new( + row_lengths, + witness.norm() as f64, + Some(coeff_bit_bound.max(1)), + ) 
+ } + + fn total_len(&self) -> usize { + self.row_lengths.iter().sum() + } +} + +/// Euclidean SIS estimate for a flattened Module-SIS instance. +/// +/// This mirrors the `norm == 2` path in the sibling lattice-estimator: +/// flatten rank-`rank` over ring degree `D` to SIS with `n = rank * D`, +/// flatten `width_ring_elems` ring columns to `m = width_ring_elems * D`, +/// solve for the required root-Hermite factor at the optimizer's preferred +/// attack dimension, and convert that to an approximate BKZ block size using +/// the Chen-style `delta(beta)` relation. +#[derive(Debug, Clone, Copy, PartialEq)] +pub struct SisEuclideanLatticeEstimate { + /// Exact field modulus used for the estimate. + pub modulus: u128, + /// `log2(modulus)` from the exact modulus value. + pub logq: f64, + /// Flattened SIS row dimension `n = rank * D`. + pub sis_dimension: usize, + /// Flattened SIS width `m = width_ring_elems * D`. + pub sis_width: usize, + /// Euclidean norm bound supplied to the estimator. + pub norm: f64, + /// Optimized attack sublattice dimension `d_att <= m`. + pub attack_dimension: usize, + /// Required root-Hermite factor `delta_req`. + pub required_delta: f64, + /// BKZ block size implied by `delta_req`, rounded up. + pub bkz_beta: usize, + /// Whether the estimator considers lattice reduction feasible. + pub reduction_possible: bool, + /// `log2(lb)` for the estimator's lower-bound predicate. + pub log2_solution_lower_bound: f64, + /// Whether the supplied norm exceeds the estimator's lower bound. + pub solution_exists: bool, + /// Approximate `log2(rop)` under the BDGL16 asymptotic cost model. + /// + /// This is `+inf` when the Euclidean estimator would reject the instance + /// as not attackable under its feasibility predicate. + pub log2_rop_bdgl16: f64, +} + +/// Module-SIS security check used by the C reference. +/// +/// Returns `true` when `log2(norm) < min(LOGQ, 2*sqrt(LOGQ*LOGDELTA*N)*sqrt(rank))`. 
+pub fn sis_secure(rank: usize, norm: f64) -> bool { + sis_secure_with_params(rank, norm, logq_bits::() as f64, D as f64) +} + +/// Approximate the sibling lattice-estimator's Euclidean SIS attack model for +/// a flattened Module-SIS instance. +/// +/// The input `width_ring_elems` is the number of ring columns before +/// flattening. Internally this becomes SIS width `m = width_ring_elems * D`. +/// +/// # Errors +/// +/// Returns an error on zero dimensions, non-positive / non-finite norms, +/// modulus detection failure, or when the bound is trivially large compared to +/// the modulus (matching the estimator's Euclidean guardrail). +pub fn estimate_module_sis_euclidean( + rank: usize, + width_ring_elems: usize, + norm: f64, +) -> Result { + if rank == 0 { + return Err(HachiError::InvalidInput( + "SIS estimate requires rank > 0".to_string(), + )); + } + if width_ring_elems == 0 { + return Err(HachiError::InvalidInput( + "SIS estimate requires width_ring_elems > 0".to_string(), + )); + } + if !norm.is_finite() || norm <= 0.0 { + return Err(HachiError::InvalidInput( + "SIS estimate requires a finite positive norm".to_string(), + )); + } + + let modulus = detect_field_modulus::(); + if modulus <= 1 { + return Err(HachiError::InvalidInput( + "SIS estimate requires modulus > 1".to_string(), + )); + } + let modulus_f = modulus as f64; + let logq = modulus_f.log2(); + let sis_dimension = rank + .checked_mul(D) + .ok_or_else(|| HachiError::InvalidInput("SIS estimate dimension overflow".to_string()))?; + let sis_width = width_ring_elems + .checked_mul(D) + .ok_or_else(|| HachiError::InvalidInput("SIS estimate width overflow".to_string()))?; + + if norm >= (modulus_f - 1.0) / 2.0 { + return Err(HachiError::InvalidInput( + "SIS estimate expects norm < (q-1)/2".to_string(), + )); + } + + let log2_norm = norm.log2(); + let log_delta = if log2_norm == 0.0 { + 0.0 + } else { + (log2_norm * log2_norm) / (4.0 * sis_dimension as f64 * logq) + }; + let opt_attack_dimension = if 
log_delta > 0.0 { + ((sis_dimension as f64 * logq) / log_delta).sqrt().floor() as usize + } else { + sis_width + }; + let attack_dimension = opt_attack_dimension.clamp(2, sis_width.max(2)); + + let root_volume = sis_dimension as f64 * logq / attack_dimension as f64; + let required_delta_log2 = + (log2_norm - root_volume) / (attack_dimension.saturating_sub(1) as f64); + let required_delta = 2f64.powf(required_delta_log2); + let beta = beta_from_root_hermite(required_delta).unwrap_or(usize::MAX); + let reduction_possible = required_delta >= 1.0 && beta <= attack_dimension; + let bkz_beta = if reduction_possible { + beta + } else { + attack_dimension + }; + + // Matches the Euclidean estimator's lower-bound feasibility gate: + // lb = min(sqrt(n * ln(q)), sqrt(d) * q^(n/d)). + let log2_lb_gaussian = 0.5 * ((sis_dimension as f64) * modulus_f.ln()).log2(); + let log2_lb_qary = 0.5 * (attack_dimension as f64).log2() + root_volume; + let log2_solution_lower_bound = log2_lb_gaussian.min(log2_lb_qary); + let solution_exists = log2_norm > log2_solution_lower_bound; + + let log2_rop_bdgl16 = if reduction_possible && solution_exists { + let repeat = if bkz_beta < attack_dimension { + 8.0 * attack_dimension as f64 + } else { + 1.0 + }; + let lll_log2 = 3.0 * (attack_dimension as f64).log2(); + let sieve_log2 = 0.292 * bkz_beta as f64 + 16.4 + repeat.log2(); + log2_add_exp(lll_log2, sieve_log2) + } else { + f64::INFINITY + }; + + Ok(SisEuclideanLatticeEstimate { + modulus, + logq, + sis_dimension, + sis_width, + norm, + attack_dimension, + required_delta, + bkz_beta, + reduction_possible, + log2_solution_lower_bound, + solution_exists, + log2_rop_bdgl16, + }) +} + +/// Select a linear-only Labrador fold plan (non-tail mode). +/// +/// Mirrors the C `init_proof` parameter selection path with `quadratic=0`, +/// including the row-split k-loop that determines the optimal +/// `virtual_row_len`. 
+/// +/// # Errors +/// +/// Returns an error if witness metadata is empty/invalid or if no secure +/// commitment ranks are found within supported bounds. +pub fn select_config( + witness: &LabradorWitness, +) -> Result { + plan_fold::(witness, false).map(|p| p.config) +} + +/// Select a linear-only Labrador fold plan with explicit tail flag. +/// +/// # Errors +/// +/// Returns an error if witness metadata is empty/invalid or if no secure +/// commitment ranks are found within supported bounds. +pub fn select_config_with_mode( + witness: &LabradorWitness, + tail: bool, +) -> Result { + plan_fold::(witness, tail).map(|p| p.config) +} + +/// Compute a full Labrador fold plan (config + reshaping layout). +/// +/// Mirrors the C `init_proof` algorithm with `quadratic=0`: all input rows +/// are placed in a single group (boundary at the last row only). The k-loop +/// searches from k=15 down to k=1, keeps every secure candidate whose +/// commitment overhead fits within `1.1 × virtual_row_len`, and returns the candidate with +/// the smallest carried witness size for the next transition. +/// +/// # Errors +/// +/// Returns an error if the witness is empty or no secure parameters exist. 
+#[tracing::instrument(skip_all, name = "labrador::plan_fold")] +pub fn plan_fold( + witness: &LabradorWitness, + tail: bool, +) -> Result { + let profile = LabradorWitnessPlanningProfile::from_witness(witness)?; + plan_fold_with_profile::(&profile, tail) +} + +fn plan_fold_with_profile( + profile: &LabradorWitnessPlanningProfile, + tail: bool, +) -> Result { + search_best_estimate_with_profile::(profile, tail).map(|estimate| estimate.plan) +} + +fn coeff_varz_cap(coeff_bit_bound: usize) -> Option { + let exp = coeff_bit_bound.checked_mul(2)?; + if exp > i32::MAX as usize { + return None; + } + let cap = 2f64.powi(exp as i32) / 12.0 * (LABRADOR_TAU1 + 4.0 * LABRADOR_TAU2); + (cap.is_finite() && cap > 0.0).then_some(cap) +} + +fn search_best_estimate_with_profile< + F: FieldCore + CanonicalField + HachiSerialize, + const D: usize, +>( + profile: &LabradorWitnessPlanningProfile, + tail: bool, +) -> Result { + let row_lengths = &profile.row_lengths; + let r = row_lengths.len(); + let total_len = profile.total_len(); + let logq_bits = logq_bits::(); + let logq = logq_bits as f64; + let d = D as f64; + + let mut last_plan = None; + let mut best_estimate = None; + let mut best_score = usize::MAX; + + for k in (1..=15usize).rev() { + let virtual_row_len = total_len.div_ceil(k); + let virtual_row_count = k; + let virtual_row_count_f = virtual_row_count as f64; + + let mut varz = profile.norm_sum / (virtual_row_len as f64 * d); + varz *= LABRADOR_TAU1 + 4.0 * LABRADOR_TAU2; + if let Some(coeff_bit_bound) = profile.coeff_bit_bound { + if let Some(varz_cap) = coeff_varz_cap(coeff_bit_bound) { + varz = varz.min(varz_cap); + } + } + if !varz.is_finite() || varz <= 0.0 { + varz = 1.0; + } + + let witness_digit_part_range = if tail { + 1..=1 + } else { + 1..=MAX_WITNESS_DIGIT_PARTS + }; + + for witness_digit_parts in witness_digit_part_range { + let mut witness_digit_bits = ((12.0f64.log2() + varz.log2()) + / (2.0 * witness_digit_parts as f64)) + .round() as isize; + 
witness_digit_bits = witness_digit_bits.clamp(1, logq_bits as isize); + + let (aux_digit_parts, aux_digit_bits) = if tail { + (1usize, logq_bits.max(1)) + } else { + let aux_digit_parts = ((logq_bits + 2 * (witness_digit_bits as usize) / 3) + / (witness_digit_bits as usize)) + .max(1); + let aux_digit_bits = ((logq_bits + aux_digit_parts / 2) / aux_digit_parts).max(1); + (aux_digit_parts, aux_digit_bits) + }; + + let mut found_inner_commit_rank = None; + let mut last_normsq = 0.0f64; + + for inner_commit_rank in 1..=MAX_COMMITMENT_RANK { + let mut normsq = (2f64.powi(2 * witness_digit_bits as i32) / 12.0 + * ((witness_digit_parts - 1) as f64) + + varz + / 2f64.powi( + 2 * (witness_digit_parts - 1) as i32 * witness_digit_bits as i32, + )) + * virtual_row_len as f64; + if !tail { + let hi_exp = logq_bits as isize + - (aux_digit_parts.saturating_sub(1) * aux_digit_bits) as isize; + let hi_exp = hi_exp.max(0) as i32; + normsq += (2f64.powi(2 * aux_digit_bits as i32) + * ((aux_digit_parts - 1) as f64) + + 2f64.powi(2 * hi_exp)) + / 12.0 + * (virtual_row_count_f * inner_commit_rank as f64 + + (virtual_row_count_f * virtual_row_count_f + virtual_row_count_f) + / 2.0); + } + normsq *= d; + last_normsq = normsq; + + if sis_secure_with_params( + inner_commit_rank, + 6.0 * LABRADOR_T + * LABRADOR_SLACK + * 2f64 + .powi(((witness_digit_parts - 1) * witness_digit_bits as usize) as i32) + * normsq.sqrt(), + logq, + d, + ) { + found_inner_commit_rank = Some(inner_commit_rank); + break; + } + } + + let inner_commit_rank = match found_inner_commit_rank { + Some(rank) => rank, + None => { + last_plan = Some(build_plan( + LabradorReductionConfig { + witness_digit_parts, + witness_digit_bits: witness_digit_bits as usize, + aux_digit_parts, + aux_digit_bits, + inner_commit_rank: MAX_COMMITMENT_RANK, + outer_commit_rank: 0, + tail, + }, + virtual_row_len, + r, + virtual_row_count, + )); + continue; + } + }; + + let outer_commit_rank = if tail { + 0 + } else { + match 
(1..=MAX_COMMITMENT_RANK).find(|&rank| { + sis_secure_with_params(rank, 2.0 * LABRADOR_SLACK * last_normsq.sqrt(), logq, d) + }) { + Some(rank) => rank, + None => { + last_plan = Some(build_plan( + LabradorReductionConfig { + witness_digit_parts, + witness_digit_bits: witness_digit_bits as usize, + aux_digit_parts, + aux_digit_bits, + inner_commit_rank, + outer_commit_rank: MAX_COMMITMENT_RANK, + tail, + }, + virtual_row_len, + r, + virtual_row_count, + )); + continue; + } + } + }; + + let plan = build_plan( + LabradorReductionConfig { + witness_digit_parts, + witness_digit_bits: witness_digit_bits as usize, + aux_digit_parts, + aux_digit_bits, + inner_commit_rank, + outer_commit_rank, + tail, + }, + virtual_row_len, + r, + virtual_row_count, + ); + let estimate = estimate_plan_with_profile::(profile, &plan, last_normsq)?; + let score = estimate.transition_bytes; + last_plan = Some(plan); + maybe_take_better_estimate(&mut best_estimate, &mut best_score, score, estimate); + } + } + + if let Some(estimate) = best_estimate { + return Ok(estimate); + } + + last_plan.map_or_else( + || { + Err(HachiError::InvalidInput( + "failed to find secure Labrador fold parameters".to_string(), + )) + }, + |plan| estimate_plan_with_profile::(profile, &plan, profile.norm_sum.max(1.0)), + ) +} + +fn estimate_plan_with_profile( + profile: &LabradorWitnessPlanningProfile, + plan: &LabradorFoldPlan, + next_norm_sum: f64, +) -> Result { + let virtual_row_count: usize = plan.row_split_counts.iter().sum(); + let next_row_lengths = + estimate_next_row_lengths(virtual_row_count, plan.virtual_row_len, &plan.config); + let level_payload_bytes = estimate_level_payload_bytes::( + profile.row_lengths.len(), + virtual_row_count, + plan.virtual_row_len, + &plan.config, + ); + let next_witness_bytes = estimate_witness_bytes_from_row_lengths::(&next_row_lengths); + Ok(LabradorFoldEstimate { + plan: plan.clone(), + level_payload_bytes, + next_witness_bytes, + transition_bytes: level_payload_bytes + 
next_witness_bytes, + next_row_lengths, + next_norm_sum: next_norm_sum.max(1.0), + }) +} + +fn estimate_next_norm_sum_for_config( + profile: &LabradorWitnessPlanningProfile, + virtual_row_len: usize, + virtual_row_count: usize, + config: &LabradorReductionConfig, +) -> f64 { + let logq_bits = logq_bits::(); + let d = D as f64; + let mut varz = profile.norm_sum / (virtual_row_len as f64 * d); + varz *= LABRADOR_TAU1 + 4.0 * LABRADOR_TAU2; + if let Some(coeff_bit_bound) = profile.coeff_bit_bound { + if let Some(varz_cap) = coeff_varz_cap(coeff_bit_bound) { + varz = varz.min(varz_cap); + } + } + if !varz.is_finite() || varz <= 0.0 { + varz = 1.0; + } + + let mut normsq = (2f64.powi(2 * config.witness_digit_bits as i32) / 12.0 + * ((config.witness_digit_parts - 1) as f64) + + varz + / 2f64.powi( + 2 * (config.witness_digit_parts - 1) as i32 * config.witness_digit_bits as i32, + )) + * virtual_row_len as f64; + if !config.tail { + let hi_exp = logq_bits as isize + - (config.aux_digit_parts.saturating_sub(1) * config.aux_digit_bits) as isize; + let hi_exp = hi_exp.max(0) as i32; + let virtual_row_count_f = virtual_row_count as f64; + normsq += (2f64.powi(2 * config.aux_digit_bits as i32) + * ((config.aux_digit_parts - 1) as f64) + + 2f64.powi(2 * hi_exp)) + / 12.0 + * (virtual_row_count_f * config.inner_commit_rank as f64 + + (virtual_row_count_f * virtual_row_count_f + virtual_row_count_f) / 2.0); + } + (normsq * d).max(1.0) +} + +pub(crate) fn estimate_fold_step( + witness: &LabradorWitness, + tail: bool, +) -> Result { + let profile = LabradorWitnessPlanningProfile::from_witness(witness)?; + search_best_estimate_with_profile::(&profile, tail) +} + +pub(crate) fn estimate_selected_fold_step< + F: FieldCore + CanonicalField + HachiSerialize, + const D: usize, +>( + witness: &LabradorWitness, + plan: &LabradorFoldPlan, +) -> Result { + let profile = LabradorWitnessPlanningProfile::from_witness(witness)?; + let virtual_row_count: usize = 
plan.row_split_counts.iter().sum(); + let next_norm_sum = estimate_next_norm_sum_for_config::( + &profile, + plan.virtual_row_len, + virtual_row_count, + &plan.config, + ); + estimate_plan_with_profile::(&profile, plan, next_norm_sum) +} + +#[cfg(test)] +pub(crate) fn estimate_recursive_proof_with_plan< + F: FieldCore + CanonicalField + HachiSerialize, + const D: usize, +>( + witness: &LabradorWitness, + initial_plan: &LabradorFoldPlan, +) -> Result { + let profile = LabradorWitnessPlanningProfile::from_witness(witness)?; + let virtual_row_count: usize = initial_plan.row_split_counts.iter().sum(); + let next_norm_sum = estimate_next_norm_sum_for_config::( + &profile, + initial_plan.virtual_row_len, + virtual_row_count, + &initial_plan.config, + ); + let initial_estimate = + estimate_plan_with_profile::(&profile, initial_plan, next_norm_sum)?; + let (proof_bytes, final_witness_bytes, level_count) = + simulate_recursive_proof_bytes::(profile, Some(initial_estimate))?; + Ok(LabradorRecursiveSizeEstimate { + initial_plan: initial_plan.clone(), + proof_bytes, + final_witness_bytes, + level_count, + }) +} + +pub(crate) fn estimate_handoff_recursive_proof< + F: FieldCore + CanonicalField + HachiSerialize, + const D: usize, +>( + witness: &LabradorWitness, + coeff_bit_bound: usize, +) -> Result { + let profile = LabradorWitnessPlanningProfile::from_handoff_witness(witness, coeff_bit_bound)?; + let initial_estimate = search_best_estimate_with_profile::(&profile, false)?; + let initial_plan = initial_estimate.plan.clone(); + let (proof_bytes, final_witness_bytes, level_count) = + simulate_recursive_proof_bytes::(profile, Some(initial_estimate))?; + Ok(LabradorRecursiveSizeEstimate { + initial_plan, + proof_bytes, + final_witness_bytes, + level_count, + }) +} + +fn simulate_recursive_proof_bytes< + F: FieldCore + CanonicalField + HachiSerialize, + const D: usize, +>( + mut profile: LabradorWitnessPlanningProfile, + mut first_non_tail: Option, +) -> Result<(usize, usize, 
usize), HachiError> { + let mut level_payload_total = 0usize; + let mut level_count = 0usize; + + while level_count + 1 < LABRADOR_MAX_LEVELS { + let before_bytes = estimate_witness_bytes_from_row_lengths::(&profile.row_lengths); + if before_bytes == 0 || profile.row_lengths.len() <= 1 { + break; + } + let estimate = match first_non_tail.take() { + Some(estimate) => estimate, + None => search_best_estimate_with_profile::(&profile, false)?, + }; + if estimate.transition_bytes >= before_bytes { + break; + } + level_payload_total += estimate.level_payload_bytes; + profile = estimate.next_profile()?; + level_count += 1; + } + + if level_count + 1 < LABRADOR_MAX_LEVELS { + let before_bytes = estimate_witness_bytes_from_row_lengths::(&profile.row_lengths); + if before_bytes > 0 && profile.row_lengths.len() > 1 { + let tail_estimate = search_best_estimate_with_profile::(&profile, true)?; + if tail_estimate.transition_bytes < before_bytes { + level_payload_total += tail_estimate.level_payload_bytes; + profile = tail_estimate.next_profile()?; + level_count += 1; + } + } + } + + let final_witness_bytes = estimate_witness_bytes_from_row_lengths::(&profile.row_lengths); + Ok(( + 4 + level_payload_total + final_witness_bytes, + final_witness_bytes, + level_count, + )) +} + +fn estimate_next_row_lengths( + virtual_row_count: usize, + virtual_row_len: usize, + config: &LabradorReductionConfig, +) -> Vec { + let mut row_lengths = vec![virtual_row_len; config.witness_digit_parts]; + if !config.tail { + row_lengths.push( + virtual_row_count * config.inner_commit_rank * config.aux_digit_parts + + virtual_row_count * (virtual_row_count + 1) / 2 * config.aux_digit_parts, + ); + } + row_lengths +} + +fn estimate_witness_bytes_from_row_lengths( + row_lengths: &[usize], +) -> usize { + 4 + row_lengths + .iter() + .map(|&ring_elems| estimate_flat_ring_vec_bytes::(ring_elems)) + .sum::() +} + +fn estimate_level_payload_bytes( + input_row_count: usize, + virtual_row_count: usize, + 
virtual_row_len: usize, + config: &LabradorReductionConfig, +) -> usize { + let inner_payload_ring_elems = if config.tail || config.outer_commit_rank == 0 { + virtual_row_count * config.inner_commit_rank * config.aux_digit_parts + } else { + config.outer_commit_rank + }; + let linear_payload_ring_elems = if config.tail || config.outer_commit_rank == 0 { + virtual_row_count * (virtual_row_count + 1) / 2 * config.aux_digit_parts + } else { + config.outer_commit_rank + }; + + 1 + estimate_vec_usize_bytes(input_row_count) + + config.serialized_size(Compress::No) + + virtual_row_len.serialized_size(Compress::No) + + estimate_vec_usize_bytes(input_row_count) + + estimate_flat_ring_vec_bytes::(inner_payload_ring_elems) + + estimate_flat_ring_vec_bytes::(linear_payload_ring_elems) + + jl_projection_bytes() + + 8 + + estimate_flat_ring_vec_bytes::(jl_lifts::()) + + 16 +} + +fn estimate_flat_ring_vec_bytes( + ring_elems: usize, +) -> usize { + 4 + 8 + ring_elems * D * F::zero().serialized_size(Compress::No) +} + +fn estimate_vec_usize_bytes(len: usize) -> usize { + 8 + len * 8 +} + +fn jl_projection_bytes() -> usize { + 256 * std::mem::size_of::() +} + +fn maybe_take_better_estimate( + best_estimate: &mut Option, + best_score: &mut usize, + score: usize, + candidate: LabradorFoldEstimate, +) { + if score < *best_score + || (score == *best_score + && best_estimate.as_ref().is_none_or(|best| { + candidate.plan.row_split_counts.iter().sum::() + < best.plan.row_split_counts.iter().sum::() + })) + { + *best_score = score; + *best_estimate = Some(candidate); + } +} + +fn build_plan( + config: LabradorReductionConfig, + virtual_row_len: usize, + input_row_count: usize, + virtual_row_count: usize, +) -> LabradorFoldPlan { + let mut row_split_counts = vec![0usize; input_row_count]; + if !row_split_counts.is_empty() { + row_split_counts[input_row_count - 1] = virtual_row_count; + } + LabradorFoldPlan { + config, + virtual_row_len, + row_split_counts, + } +} + +/// Build a trivial fold 
plan (no reshaping) from a config and row lengths. +/// +/// All rows keep their original lengths; `virtual_row_len = max(row_lengths)` +/// and `row_split_counts` +/// marks each row as its own virtual row. +pub fn trivial_plan(config: LabradorReductionConfig, row_lengths: &[usize]) -> LabradorFoldPlan { + let virtual_row_len = row_lengths.iter().copied().max().unwrap_or(0); + let row_split_counts: Vec = row_lengths.iter().map(|_| 1).collect(); + LabradorFoldPlan { + config, + virtual_row_len, + row_split_counts, + } +} + +/// Compute a full Labrador fold plan for the Hachi→Labrador handoff witness. +/// +/// Unlike the generic recursive planner, the handoff planner is seeded from the +/// actual witness rows and their squared norm, rather than collapsing the input +/// to only `(row_count, max_row_len)`. +/// +/// # Errors +/// +/// Returns an error if the witness is empty or no secure parameter +/// combination exists within the supported bounds. +pub fn plan_handoff( + witness: &LabradorWitness, + coeff_bit_bound: usize, +) -> Result { + let profile = LabradorWitnessPlanningProfile::from_handoff_witness(witness, coeff_bit_bound)?; + plan_fold_with_profile::(&profile, false) +} + +/// Select Labrador reduction config for the Hachi→Labrador handoff witness. +/// +/// # Errors +/// +/// Returns an error if the witness is empty or no secure parameter +/// combination exists within the supported bounds. 
+pub fn select_handoff_config( + witness: &LabradorWitness, + coeff_bit_bound: usize, +) -> Result { + plan_handoff::(witness, coeff_bit_bound).map(|plan| plan.config) +} + +pub(crate) fn logq_bits() -> usize { + let modulus = detect_field_modulus::(); + if modulus <= 1 { + return 1; + } + 128 - (modulus.saturating_sub(1)).leading_zeros() as usize +} + +pub(crate) fn jl_lifts() -> usize { + 128_usize.div_ceil(logq_bits::().max(1)) +} + +fn sis_secure_with_params(rank: usize, norm: f64, logq: f64, ring_degree: f64) -> bool { + if rank == 0 || !norm.is_finite() || norm <= 0.0 { + return false; + } + let mut maxlog = 2.0 * (logq * LABRADOR_LOGDELTA * ring_degree).sqrt() * (rank as f64).sqrt(); + maxlog = maxlog.min(logq); + norm.log2() < maxlog +} + +fn root_hermite_from_beta(beta: f64) -> f64 { + ((beta / (2.0 * PI * E)) * (PI * beta).powf(1.0 / beta)).powf(1.0 / (2.0 * (beta - 1.0))) +} + +fn beta_from_root_hermite(delta: f64) -> Option { + const MIN_BETA: usize = 40; + const MAX_BETA: usize = 1 << 16; + + if !delta.is_finite() || delta <= 1.0 { + return None; + } + if root_hermite_from_beta(MIN_BETA as f64) < delta { + return Some(MIN_BETA); + } + + let mut beta = MIN_BETA; + while beta < MAX_BETA / 2 && root_hermite_from_beta((2 * beta) as f64) > delta { + beta *= 2; + } + while beta + 10 < MAX_BETA && root_hermite_from_beta((beta + 10) as f64) > delta { + beta += 10; + } + while beta < MAX_BETA && root_hermite_from_beta(beta as f64) >= delta { + beta += 1; + } + + (beta < MAX_BETA).then_some(beta) +} + +fn log2_add_exp(a: f64, b: f64) -> f64 { + let hi = a.max(b); + let lo = a.min(b); + hi + (1.0 + 2f64.powf(lo - hi)).log2() +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::algebra::fields::Fp128; + use crate::algebra::fields::Fp64; + use crate::algebra::ring::CyclotomicRing; + use crate::protocol::commitment::Fp128FullCommitmentConfig; + use crate::protocol::CommitmentConfig; + use crate::FromSmallInt; + + type F = Fp64<4294967197>; + const D: usize = 
64; + + fn row(len: usize) -> Vec> { + (0..len) + .map(|i| { + CyclotomicRing::from_coefficients(std::array::from_fn(|j| { + F::from_i64(((i + j) as i64 % 5) - 2) + })) + }) + .collect() + } + + #[test] + fn sis_secure_rejects_non_positive_norm() { + assert!(!sis_secure::(4, 0.0)); + assert!(!sis_secure::(4, -1.0)); + } + + #[test] + fn select_config_returns_valid_ranges() { + let witness = LabradorWitness::new(vec![row(32), row(32), row(32)]); + let cfg = select_config::(&witness).unwrap(); + assert!((1..=MAX_WITNESS_DIGIT_PARTS).contains(&cfg.witness_digit_parts)); + assert!(cfg.witness_digit_bits > 0); + assert!(cfg.aux_digit_parts > 0); + assert!(cfg.aux_digit_bits > 0); + assert!((1..=MAX_COMMITMENT_RANK).contains(&cfg.inner_commit_rank)); + assert!((1..=MAX_COMMITMENT_RANK).contains(&cfg.outer_commit_rank)); + assert!(!cfg.tail); + } + + #[test] + fn handoff_estimate_is_not_worse_than_generic_plan_on_small_coeffs() { + let witness = LabradorWitness::new(vec![row(48), row(48), row(48)]); + + let generic_plan = plan_fold::(&witness, false).unwrap(); + let generic_estimate = + estimate_recursive_proof_with_plan::(&witness, &generic_plan).unwrap(); + let handoff_estimate = estimate_handoff_recursive_proof::(&witness, 3).unwrap(); + + assert!(handoff_estimate.proof_bytes <= generic_estimate.proof_bytes); + } + + #[test] + fn planner_can_search_more_than_two_z_parts() { + let row_lengths = vec![512usize, 512usize, 512usize]; + let found = (20..=80).any(|exp| { + let profile = + LabradorWitnessPlanningProfile::new(row_lengths.clone(), 2f64.powi(exp), None) + .unwrap(); + plan_fold_with_profile::(&profile, false) + .map(|plan| plan.config.witness_digit_parts > 2) + .unwrap_or(false) + }); + assert!( + found, + "expected planner search to reach witness_digit_parts > 2" + ); + } + + #[test] + fn sis_estimate_rejects_invalid_inputs() { + assert!(estimate_module_sis_euclidean::(0, 10, 1.0).is_err()); + assert!(estimate_module_sis_euclidean::(1, 0, 1.0).is_err()); + 
assert!(estimate_module_sis_euclidean::(1, 10, 0.0).is_err()); + } + + #[test] + fn print_profile_style_handoff_sis_summary() { + type F128 = Fp128<0xfffffffffffffffffffffffffffffeed>; + type Cfg = Fp128FullCommitmentConfig; + const D128: usize = Cfg::D; + const MAX_NUM_VARS: usize = 25; + + let layout = Cfg::commitment_layout(MAX_NUM_VARS).unwrap(); + let rank = Cfg::N_D + Cfg::N_B + 2 + Cfg::N_A; + let width_ring_elems = layout.d_matrix_width + + layout.outer_width + + layout.inner_width * layout.num_digits_fold; + let beta_inf = (1usize << layout.r_vars) + * Cfg::challenge_weight_for_ring_dim(D128) + * (1usize << (layout.log_basis - 1)); + let collision_inf = (2 * beta_inf) as f64; + let width_coords = width_ring_elems * D128; + let l2_bound = (width_coords as f64).sqrt() * collision_inf; + + let heuristic_secure = + sis_secure_with_params(rank, l2_bound, logq_bits::() as f64, D128 as f64); + let heuristic_max_log2 = (2.0 + * (logq_bits::() as f64 * LABRADOR_LOGDELTA * D128 as f64).sqrt() + * (rank as f64).sqrt()) + .min(logq_bits::() as f64); + let estimate = + estimate_module_sis_euclidean::(rank, width_ring_elems, l2_bound).unwrap(); + + eprintln!( + "[labrador::config] profile-style handoff SIS summary: \ + max_num_vars={MAX_NUM_VARS}, D={D128}, r_vars={}, m_vars={}, \ + width_ring_elems={}, width_coords={}, rank={}, \ + log2(bound)={:.2}, heuristic_max_log2={:.2}, heuristic_secure={}, \ + d_att={}, delta_req={:.6}, beta={}, solution_exists={}, log2(lb)={:.2}, log2(rop_bdgl16)={:.2}", + layout.r_vars, + layout.m_vars, + width_ring_elems, + width_coords, + rank, + l2_bound.log2(), + heuristic_max_log2, + heuristic_secure, + estimate.attack_dimension, + estimate.required_delta, + estimate.bkz_beta, + estimate.solution_exists, + estimate.log2_solution_lower_bound, + estimate.log2_rop_bdgl16, + ); + + assert_eq!(estimate.sis_dimension, rank * D128); + assert_eq!(estimate.sis_width, width_coords); + assert!(estimate.required_delta.is_finite()); + } +} diff --git 
a/src/protocol/labrador/constraints.rs b/src/protocol/labrador/constraints.rs new file mode 100644 index 00000000..2ed55a5f --- /dev/null +++ b/src/protocol/labrador/constraints.rs @@ -0,0 +1,522 @@ +//! Labrador constraint types and shared recursive builders. + +use crate::algebra::ring::CyclotomicRing; +use crate::algebra::SparseChallenge; +use crate::error::HachiError; +#[cfg(feature = "parallel")] +use crate::parallel::*; +use crate::protocol::labrador::setup::LabradorSetupMatrices; +use crate::protocol::labrador::types::{LabradorReducedConstraintPlan, LabradorReductionConfig}; +use crate::protocol::labrador::utils::pow2_field; +use crate::{cfg_into_iter, CanonicalField, FieldCore, FromSmallInt}; +use std::ops::Range; +use std::sync::Arc; + +type PreparedNextConstraintInputs = + (NextWitnessLayout, Vec, Vec, Vec>); + +/// One sparse linear term in a Labrador constraint. +/// +/// This encodes the paper-style contribution ``, except that +/// `offset` allows the term to start inside a packed witness row. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct LabradorConstraintTerm { + /// Witness row used by this term. + pub row: usize, + /// Starting column within the witness row. + pub offset: usize, + /// Coefficients dotted against the witness row slice. + pub coefficients: Vec>, +} + +impl LabradorConstraintTerm { + /// Build one sparse term ``. + pub fn new(row: usize, offset: usize, coefficients: Vec>) -> Self { + Self { + row, + offset, + coefficients, + } + } +} + +/// One scalar Labrador linear constraint. +/// +/// Ignoring the quadratic paper term `a_ij`, this stores one equation of the form +/// `sum_terms = b`, where `target` is the single ring element `b`. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct LabradorConstraint { + /// Sparse row terms contributing to the constraint. + pub terms: Vec>, + /// Right-hand side ring element. + pub target: CyclotomicRing, +} + +impl LabradorConstraint { + /// Build a scalar Labrador constraint. 
+ pub fn new(terms: Vec>, target: CyclotomicRing) -> Self { + Self { terms, target } + } +} + +/// Layout of the next-level witness emitted by one standard Labrador fold. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub(crate) struct NextWitnessLayout { + /// Number of rows used by the decomposed `z` witness. + pub z_part_rows: usize, + /// Row index holding `inner_opening_digits || linear_garbage_digits`. + pub aux_row: usize, + /// Number of decomposed inner-commitment entries. + pub inner_opening_digits_len: usize, + /// Number of decomposed linear-garbage entries. + pub linear_garbage_digits_len: usize, +} + +impl NextWitnessLayout { + /// Derive the next-witness layout from input row count and config. + pub(crate) fn new(input_rows: usize, config: &LabradorReductionConfig) -> Self { + let inner_opening_digits_len = + input_rows * config.inner_commit_rank * config.aux_digit_parts; + let linear_garbage_digits_len = input_rows * (input_rows + 1) / 2 * config.aux_digit_parts; + Self { + z_part_rows: config.witness_digit_parts, + aux_row: config.witness_digit_parts, + inner_opening_digits_len, + linear_garbage_digits_len, + } + } + + /// Total number of witness rows at the next recursion level. + pub(crate) fn num_rows(self) -> usize { + self.z_part_rows + 1 + } + + /// Total length of the auxiliary row. + pub(crate) fn aux_row_len(self) -> usize { + self.inner_opening_digits_len + self.linear_garbage_digits_len + } + + /// Slice of the auxiliary row occupied by `inner_opening_digits`. + pub(crate) fn inner_opening_digits_range(self) -> Range { + 0..self.inner_opening_digits_len + } + + /// Slice of the auxiliary row occupied by `linear_garbage_digits`. + pub(crate) fn linear_garbage_digits_range(self) -> Range { + self.inner_opening_digits_len..self.aux_row_len() + } +} + +/// Build the recursive target relation for the next Labrador level. 
+fn prepare_next_constraint_inputs( + phi_total: &[Vec>], + challenges: &[SparseChallenge], + row_lengths: &[usize], + max_len: usize, + config: &LabradorReductionConfig, +) -> Result, HachiError> +where + F: FieldCore + CanonicalField + FromSmallInt, +{ + let r = row_lengths.len(); + if r == 0 || challenges.len() != r { + return Err(HachiError::InvalidInput( + "challenge row count mismatch".to_string(), + )); + } + if config.witness_digit_parts == 0 { + return Err(HachiError::InvalidInput( + "cannot build next constraints with witness_digit_parts=0".to_string(), + )); + } + + let layout = NextWitnessLayout::new(r, config); + let pow_b: Vec = (0..config.witness_digit_parts) + .map(|idx| pow2_field::(config.witness_digit_bits * idx)) + .collect(); + let pow_bu: Vec = (0..config.aux_digit_parts) + .map(|idx| pow2_field::(config.aux_digit_bits * idx)) + .collect(); + let amortized_phi = combine_phi(phi_total, challenges, max_len); + Ok((layout, pow_b, pow_bu, amortized_phi)) +} + +#[allow(clippy::too_many_arguments)] +fn build_constraints_from_prepared( + layout: NextWitnessLayout, + config: &LabradorReductionConfig, + challenges: &[SparseChallenge], + pow_b: &[F], + pow_bu: &[F], + amortized_phi: &[CyclotomicRing], + aggregated_rhs: &CyclotomicRing, + inner_opening_payload: &[CyclotomicRing], + linear_garbage_payload: &[CyclotomicRing], + setup: &LabradorSetupMatrices, +) -> Result>, HachiError> +where + F: FieldCore + CanonicalField + FromSmallInt, +{ + let mut constraints = Vec::new(); + let dense_challenges: Vec> = challenges + .iter() + .map(|challenge| { + challenge + .to_dense::() + .expect("sampler outputs valid challenges") + }) + .collect(); + if config.outer_commit_rank > 0 { + if inner_opening_payload.len() != config.outer_commit_rank + || linear_garbage_payload.len() != config.outer_commit_rank + { + return Err(HachiError::InvalidInput( + "payload length mismatch for next statement".to_string(), + )); + } + 
constraints.extend(build_outer_commitment_constraints( + layout, + setup, + inner_opening_payload, + )); + constraints.extend(build_linear_garbage_commitment_constraints( + layout, + setup, + linear_garbage_payload, + )); + } + constraints.extend(build_amortized_opening_constraints( + layout, + &dense_challenges, + config, + pow_b, + pow_bu, + setup, + )); + constraints.push(build_linear_garbage_constraint( + layout, + &dense_challenges, + config, + pow_b, + pow_bu, + amortized_phi, + )); + constraints.push(build_diagonal_constraint( + layout, + aggregated_rhs, + challenges.len(), + config, + pow_bu, + )); + Ok(constraints) +} + +/// Build the recursive target relation for the next Labrador level. +#[allow(clippy::too_many_arguments)] +#[allow(dead_code)] +#[tracing::instrument(skip_all, name = "labrador::build_next_constraints")] +pub(crate) fn build_next_constraints( + phi_total: &[Vec>], + aggregated_rhs: &CyclotomicRing, + challenges: &[SparseChallenge], + row_lengths: &[usize], + max_len: usize, + config: &LabradorReductionConfig, + inner_opening_payload: &[CyclotomicRing], + linear_garbage_payload: &[CyclotomicRing], + setup: &LabradorSetupMatrices, +) -> Result>, HachiError> +where + F: FieldCore + CanonicalField + FromSmallInt, +{ + let (layout, pow_b, pow_bu, amortized_phi) = + prepare_next_constraint_inputs(phi_total, challenges, row_lengths, max_len, config)?; + build_constraints_from_prepared( + layout, + config, + challenges, + &pow_b, + &pow_bu, + &amortized_phi, + aggregated_rhs, + inner_opening_payload, + linear_garbage_payload, + setup, + ) +} + +#[tracing::instrument(skip_all, name = "labrador::build_next_constraint_plan")] +pub(crate) fn build_next_constraint_plan( + phi_total: &[Vec>], + aggregated_rhs: &CyclotomicRing, + challenges: &[SparseChallenge], + row_lengths: &[usize], + max_len: usize, + config: &LabradorReductionConfig, + setup: Arc>, +) -> Result, HachiError> +where + F: FieldCore + CanonicalField + FromSmallInt, +{ + let (_layout, 
_pow_b, _pow_bu, amortized_phi) = + prepare_next_constraint_inputs(phi_total, challenges, row_lengths, max_len, config)?; + Ok(LabradorReducedConstraintPlan { + row_count: row_lengths.len(), + max_len, + config: *config, + challenges: challenges.to_vec(), + amortized_phi, + aggregated_rhs: *aggregated_rhs, + setup, + }) +} + +#[tracing::instrument(skip_all, name = "labrador::materialize_reduced_constraints")] +pub(crate) fn materialize_reduced_constraints( + plan: &LabradorReducedConstraintPlan, + inner_opening_payload: &[CyclotomicRing], + linear_garbage_payload: &[CyclotomicRing], +) -> Result>, HachiError> +where + F: FieldCore + CanonicalField + FromSmallInt, +{ + let layout = NextWitnessLayout::new(plan.row_count, &plan.config); + let pow_b: Vec = (0..plan.config.witness_digit_parts) + .map(|idx| pow2_field::(plan.config.witness_digit_bits * idx)) + .collect(); + let pow_bu: Vec = (0..plan.config.aux_digit_parts) + .map(|idx| pow2_field::(plan.config.aux_digit_bits * idx)) + .collect(); + build_constraints_from_prepared( + layout, + &plan.config, + &plan.challenges, + &pow_b, + &pow_bu, + &plan.amortized_phi, + &plan.aggregated_rhs, + inner_opening_payload, + linear_garbage_payload, + plan.setup.as_ref(), + ) +} + +/// Build the paper's outer-commitment check (Fig. 3, line 19) +/// `inner_opening_payload = B * inner_opening_digits`, with the quadratic +/// `g_ij` contribution omitted. +/// +/// The paper writes this as one vector equation. Here it is scalarized into one +/// `LabradorConstraint` per row of `B` / entry of the opening-side payload, all +/// reading the `inner_opening_digits` prefix of the auxiliary witness row. 
+fn build_outer_commitment_constraints( + layout: NextWitnessLayout, + setup: &LabradorSetupMatrices, + inner_opening_payload: &[CyclotomicRing], +) -> Vec> { + setup + .b_mat + .iter() + .zip(inner_opening_payload.iter()) + .map(|(b_row, target)| { + LabradorConstraint::new( + vec![LabradorConstraintTerm::new( + layout.aux_row, + layout.inner_opening_digits_range().start, + b_row.clone(), + )], + *target, + ) + }) + .collect() +} + +/// Build the linear-garbage commitment check (Fig. 3, line 20) +/// `linear_garbage_payload = D * linear_garbage_digits`. +/// +/// As with the opening-side payload, the paper presents a vector equation; this +/// implementation expands it into one scalar constraint per row of `D` / entry +/// of the linear-garbage-side payload, reading the +/// `linear_garbage_digits` suffix of the auxiliary witness row. +fn build_linear_garbage_commitment_constraints( + layout: NextWitnessLayout, + setup: &LabradorSetupMatrices, + linear_garbage_payload: &[CyclotomicRing], +) -> Vec> { + setup + .d_mat + .iter() + .zip(linear_garbage_payload.iter()) + .map(|(d_row, target)| { + LabradorConstraint::new( + vec![LabradorConstraintTerm::new( + layout.aux_row, + layout.linear_garbage_digits_range().start, + d_row.clone(), + )], + *target, + ) + }) + .collect() +} + +/// Build the amortized opening relation (Fig. 3, line 15) +/// `A * z_tilde = sum_i c_i * t_tilde_i`. +/// +/// The paper's equation is `inner_commit_rank`-dimensional, so this function +/// emits one scalar constraint per row of `A`. The first +/// `witness_digit_parts` witness rows reconstruct the decomposed +/// `z_tilde = sum_k 2^(k * witness_digit_bits) z^(k)`, while the +/// `inner_opening_digits` slice of the +/// auxiliary row reconstructs each decomposed `t_tilde_i`. 
+fn build_amortized_opening_constraints( + layout: NextWitnessLayout, + challenges: &[CyclotomicRing], + config: &LabradorReductionConfig, + pow_b: &[F], + pow_bu: &[F], + setup: &LabradorSetupMatrices, +) -> Vec> { + (0..config.inner_commit_rank) + .map(|output_idx| { + let mut terms = Vec::with_capacity(config.witness_digit_parts + 1); + for (part_idx, scale) in pow_b.iter().copied().enumerate() { + let coeffs = setup.a_mat[output_idx] + .iter() + .map(|elem| elem.scale(&scale)) + .collect(); + terms.push(LabradorConstraintTerm::new(part_idx, 0, coeffs)); + } + + let mut aux_coeffs = + vec![CyclotomicRing::::zero(); layout.inner_opening_digits_len]; + for (row_idx, challenge) in challenges.iter().enumerate() { + for (part_idx, &scale) in pow_bu.iter().enumerate() { + let idx = row_idx * config.inner_commit_rank * config.aux_digit_parts + + output_idx * config.aux_digit_parts + + part_idx; + aux_coeffs[idx] = -(challenge.scale(&scale)); + } + } + terms.push(LabradorConstraintTerm::new( + layout.aux_row, + layout.inner_opening_digits_range().start, + aux_coeffs, + )); + + LabradorConstraint::new(terms, CyclotomicRing::::zero()) + }) + .collect() +} + +/// Build the linear-only garbage relation (Fig. 3, line 17) +/// `sum_i c_i * = sum_{i <= j} c_i c_j * h_ij`. +/// +/// `amortized_phi` already equals `sum_i c_i * phi_i`, so the left-hand side is +/// reconstructed from the decomposed `z_tilde` rows. The right-hand side is +/// reconstructed from the packed upper-triangular `linear_garbage_digits` +/// entries stored in the +/// auxiliary row. 
+fn build_linear_garbage_constraint( + layout: NextWitnessLayout, + challenges: &[CyclotomicRing], + config: &LabradorReductionConfig, + pow_b: &[F], + pow_bu: &[F], + amortized_phi: &[CyclotomicRing], +) -> LabradorConstraint { + let mut terms = Vec::with_capacity(config.witness_digit_parts + 1); + for (part_idx, scale) in pow_b.iter().copied().enumerate() { + let coeffs = amortized_phi + .iter() + .map(|elem| elem.scale(&scale)) + .collect(); + terms.push(LabradorConstraintTerm::new(part_idx, 0, coeffs)); + } + + let mut h_coeffs = vec![CyclotomicRing::::zero(); layout.linear_garbage_digits_len]; + for i in 0..challenges.len() { + for j in i..challenges.len() { + let coeff = challenges[i] * challenges[j]; + let pair = pair_index(i, j, challenges.len()); + for (part_idx, &scale) in pow_bu.iter().enumerate() { + let idx = pair * config.aux_digit_parts + part_idx; + h_coeffs[idx] = -(coeff.scale(&scale)); + } + } + } + terms.push(LabradorConstraintTerm::new( + layout.aux_row, + layout.linear_garbage_digits_range().start, + h_coeffs, + )); + + LabradorConstraint::new(terms, CyclotomicRing::::zero()) +} + +/// Build the linear-only diagonal relation (Fig. 3, line 18 with no `a_ij` +/// or `g_ij`) +/// `sum_i h_ii = aggregated_rhs`. +/// +/// Only diagonal packed `linear_garbage_digits` entries contribute here. Their +/// decomposed digits are reweighted by powers of `2^aux_digit_bits` to +/// reconstruct each `h_ii`. 
+fn build_diagonal_constraint( + layout: NextWitnessLayout, + aggregated_rhs: &CyclotomicRing, + input_rows: usize, + config: &LabradorReductionConfig, + pow_bu: &[F], +) -> LabradorConstraint { + let mut diag_coeffs = vec![CyclotomicRing::::zero(); layout.linear_garbage_digits_len]; + for i in 0..input_rows { + let pair = pair_index(i, i, input_rows); + for (part_idx, &scale) in pow_bu.iter().enumerate() { + let idx = pair * config.aux_digit_parts + part_idx; + diag_coeffs[idx] = constant_poly(scale); + } + } + + LabradorConstraint::new( + vec![LabradorConstraintTerm::new( + layout.aux_row, + layout.linear_garbage_digits_range().start, + diag_coeffs, + )], + *aggregated_rhs, + ) +} + +fn combine_phi( + phi_total: &[Vec>], + challenges: &[SparseChallenge], + max_len: usize, +) -> Vec> { + cfg_into_iter!(0..max_len) + .map(|j| { + let mut acc = CyclotomicRing::::zero(); + for (row_phi, challenge) in phi_total.iter().zip(challenges.iter()) { + if let Some(elem) = row_phi.get(j) { + elem.mul_by_sparse_into(challenge, &mut acc); + } + } + acc + }) + .collect() +} + +fn constant_poly(value: F) -> CyclotomicRing { + CyclotomicRing::from_coefficients(std::array::from_fn( + |i| { + if i == 0 { + value + } else { + F::zero() + } + }, + )) +} + +pub(crate) fn pair_index(i: usize, j: usize, r: usize) -> usize { + debug_assert!(i <= j && j < r); + i * (2 * r - i + 1) / 2 + (j - i) +} diff --git a/src/protocol/labrador/fold.rs b/src/protocol/labrador/fold.rs new file mode 100644 index 00000000..6ce60f21 --- /dev/null +++ b/src/protocol/labrador/fold.rs @@ -0,0 +1,1421 @@ +//! Labrador amortization transitions (standard and tail levels). 
+ +use crate::algebra::ring::CyclotomicRing; +use crate::algebra::SparseChallenge; +use crate::error::HachiError; +#[cfg(feature = "parallel")] +use crate::parallel::*; +use crate::protocol::commitment::utils::linear::{ + decompose_rows_with_carry, mat_vec_mul_crt_ntt_i8_labrador_cross, mat_vec_mul_crt_ntt_i8_many, + mat_vec_mul_ntt_single_i8, +}; +use crate::protocol::labrador::aggregation::{ + add_phi_flat_in_place, aggregate_jl_constraints_prover, aggregate_statement, +}; +use crate::protocol::labrador::commit::{ + decompose_rows_ntt_i8_exact, expand_matrix_for_i8_digits, ntt_two_tier_commit, + outer_ntt_digit_levels, OUTER_NTT_LOG_BASIS, +}; +use crate::protocol::labrador::config::LabradorFoldPlan; +use crate::protocol::labrador::constraints::{build_next_constraint_plan, pair_index}; +use crate::protocol::labrador::johnson_lindenstrauss::project; +use crate::protocol::labrador::setup::LabradorSetup; +use crate::protocol::labrador::transcript::{ + absorb_labrador_jl_projection, absorb_labrador_level_context, LabradorLevelTranscriptContext, +}; +use crate::protocol::labrador::types::{ + LabradorLevelProof, LabradorReductionConfig, LabradorStatement, LabradorWitness, +}; +use crate::protocol::labrador::utils::{mat_vec_mul, try_centered_i8_rows}; +use crate::protocol::transcript::labels; +use crate::protocol::transcript::{challenge_sparse_ring_elements_rejection_sampled, Transcript}; +use crate::{CanonicalField, FieldCore, FieldSampling, FromSmallInt}; +use std::sync::Arc; + +/// Output of one Labrador fold transition. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct LabradorFoldResult { + /// Next witness after amortization. + pub next_witness: LabradorWitness, + /// Replay-complete level proof record. + pub level_proof: LabradorLevelProof, + /// Reduced statement consumed by the next verifier step. + pub statement: LabradorStatement, +} + +/// Perform one Labrador fold level (standard or tail, determined by +/// `config.tail`). 
+/// +/// Follows the C Labrador protocol phases: +/// 1. Reshape witness according to `plan.row_split_counts` into virtual rows +/// of length `plan.virtual_row_len` +/// 2. Commit: inner + outer two-tier commitment +/// 3. Project: JL projection → p\[256\], nonce +/// 4. LIFTS × (collapse + lift): build linear constraints from JL +/// 5. Amortize: absorb into transcript, sample ring-element challenges, +/// fold z = sum_i c_i * s_i, decompose z → output witness +/// +/// # Errors +/// +/// Returns `HachiError::InvalidInput` if the witness is empty or +/// `config.witness_digit_parts` is zero. +/// Propagates errors from commitment, projection, or hashing. +#[tracing::instrument( + skip_all, + name = "labrador::prove_level", + fields(level_index, tail = config.tail) +)] +pub fn prove_level( + witness: &LabradorWitness, + statement: &LabradorStatement, + config: &LabradorReductionConfig, + plan: &LabradorFoldPlan, + setup: &Arc>, + level_index: usize, + transcript: &mut T, +) -> Result, HachiError> +where + F: FieldCore + CanonicalField + FieldSampling + FromSmallInt, + T: Transcript, +{ + if witness.rows().is_empty() { + return Err(HachiError::InvalidInput( + "cannot fold empty Labrador witness".to_string(), + )); + } + if config.witness_digit_parts == 0 { + return Err(HachiError::InvalidInput( + "Labrador fold requires witness_digit_parts > 0".to_string(), + )); + } + + let orig_row_lengths: Vec = witness.rows().iter().map(|row| row.len()).collect(); + + // Phase 0: Reshape witness according to the row-split plan. + let reshaped = reshape_rows(witness.rows(), &plan.row_split_counts, plan.virtual_row_len); + let virtual_witness = LabradorWitness::new_unchecked(reshaped); + let virt_row_lengths: Vec = virtual_witness.rows().iter().map(|r| r.len()).collect(); + let virtual_row_count = virt_row_lengths.len(); + let virtual_row_len = plan.virtual_row_len; + + // Phase 1: Inner commitments and opening-side payload. 
+ let (inner_opening_digits, inner_opening_payload) = ntt_two_tier_commit( + &setup.matrices.a_mat, + &setup.matrices.b_mat, + virtual_witness.rows(), + config.aux_digit_parts, + config.aux_digit_bits as u32, + )?; + + // Absorb level context and the opening-side payload before deriving the JL seed. + absorb_labrador_level_context( + transcript, + &LabradorLevelTranscriptContext { + level_index, + tail: config.tail, + input_row_lengths: orig_row_lengths.clone(), + witness_digit_parts: config.witness_digit_parts, + witness_digit_bits: config.witness_digit_bits, + aux_digit_parts: config.aux_digit_parts, + aux_digit_bits: config.aux_digit_bits, + inner_commit_rank: config.inner_commit_rank, + outer_commit_rank: config.outer_commit_rank, + }, + )?; + + transcript.append_serde( + labels::ABSORB_LABRADOR_INNER_OPENING_PAYLOAD, + &inner_opening_payload, + ); + + // Phase 2: JL Projection — nonce + matrix squeezed from transcript. + let (jl_projection, jl_nonce, jl_matrix) = project(&virtual_witness, transcript)?; + + absorb_labrador_jl_projection(transcript, &jl_projection); + + // Phase 3: JL lift constraints and aggregation (on virtual rows). + let (phi_jl_flat, jl_lift_residuals) = + aggregate_jl_constraints_prover(&virtual_witness, &jl_matrix, transcript)?; + + // Aggregate statement constraints on ORIGINAL rows, then reshape phi. + let (phi_stmt_orig, _statement_aggregated_rhs) = + aggregate_statement(statement, &orig_row_lengths, transcript)?; + let phi_stmt = reshape_phi::(&phi_stmt_orig, &plan.row_split_counts, virtual_row_len); + + let mut phi_total = phi_stmt; + add_phi_flat_in_place(&mut phi_total, &phi_jl_flat)?; + + // Linear garbage h_ij from aggregated phi and virtual witness. 
+ let h = compute_linear_garbage(&phi_total, &virtual_witness)?; + + let aggregated_rhs = { + let r = virtual_witness.rows().len(); + cfg_fold_reduce!( + (0..r), + || CyclotomicRing::::zero(), + |acc, i| acc + h[pair_index(i, i, r)], + |a, b| a + b + ) + }; + + let linear_garbage_digits = + tracing::info_span!("labrador::decompose_linear_garbage").in_scope(|| { + decompose_rows_with_carry(&h, config.aux_digit_parts, config.aux_digit_bits as u32) + }); + + let linear_garbage_payload = build_linear_garbage_payload(setup, &linear_garbage_digits); + + // Absorb the linear-garbage-side payload before amortization challenges. + transcript.append_serde( + labels::ABSORB_LABRADOR_LINEAR_GARBAGE_PAYLOAD, + &linear_garbage_payload, + ); + + // Phase 4: Amortize — sample challenge ring-elements from the transcript, fold. + let challenges = sample_amortize_challenges::(transcript, virtual_row_count)?; + tracing::debug!( + level_index, + tail = config.tail, + ?challenges, + "labrador prover amortize challenges" + ); + let z = amortize_witness(&virtual_witness, &challenges, virtual_row_len); + + let decomposed_z = + tracing::info_span!("labrador::decompose_amortized_witness").in_scope(|| { + decompose_rows_with_carry( + &z, + config.witness_digit_parts, + config.witness_digit_bits as u32, + ) + }); + let z_rows = split_decomposed_rows(&decomposed_z, config.witness_digit_parts, z.len())?; + let next_witness = assemble_output_witness( + z_rows, + &inner_opening_digits, + &linear_garbage_digits, + config.tail, + ); + + let next_witness_norm_sq: u128 = + tracing::info_span!("labrador::next_witness_norm").in_scope(|| next_witness.norm()); + + let reduced_constraints = if config.tail { + None + } else { + Some(Box::new(build_next_constraint_plan( + &phi_total, + &aggregated_rhs, + &challenges, + &virt_row_lengths, + virtual_row_len, + config, + setup.verifier_setup(), + )?)) + }; + + let level_proof = LabradorLevelProof { + tail: config.tail, + input_row_lengths: orig_row_lengths, + 
config: *config, + virtual_row_len, + row_split_counts: plan.row_split_counts.clone(), + inner_opening_payload: inner_opening_payload.clone(), + linear_garbage_payload: linear_garbage_payload.clone(), + jl_projection, + jl_nonce, + jl_lift_residuals, + next_witness_norm_sq, + }; + + let statement = LabradorStatement { + inner_opening_payload, + linear_garbage_payload, + challenges: challenges.clone(), + constraints: Vec::new(), + reduced_constraints, + witness_norm_bound_sq: next_witness_norm_sq, + }; + + Ok(LabradorFoldResult { + next_witness, + level_proof, + statement, + }) +} + +/// Reshape witness rows according to `row_split_counts` into virtual rows of +/// length `virtual_row_len`. +#[tracing::instrument(skip_all, name = "labrador::reshape_rows")] +fn reshape_rows( + rows: &[Vec>], + row_split_counts: &[usize], + virtual_row_len: usize, +) -> Vec>> { + let mut result = Vec::with_capacity(row_split_counts.iter().copied().sum()); + let mut group: Vec> = Vec::new(); + + for (i, row) in rows.iter().enumerate() { + group.extend_from_slice(row); + let splits = if i < row_split_counts.len() { + row_split_counts[i] + } else { + 0 + }; + if splits > 0 { + for chunk_idx in 0..splits { + let start = chunk_idx * virtual_row_len; + if start + virtual_row_len <= group.len() { + result.push(group[start..start + virtual_row_len].to_vec()); + } else { + let mut virtual_row = vec![CyclotomicRing::::zero(); virtual_row_len]; + let available = group.len().saturating_sub(start).min(virtual_row_len); + if available > 0 { + virtual_row[..available].copy_from_slice(&group[start..start + available]); + } + result.push(virtual_row); + } + } + group.clear(); + } + } + result +} + +/// Reshape phi vectors (same layout as witness reshaping). 
+fn reshape_phi( + phi: &[Vec>], + row_split_counts: &[usize], + virtual_row_len: usize, +) -> Vec>> { + reshape_rows(phi, row_split_counts, virtual_row_len) +} + +#[tracing::instrument(skip_all, name = "labrador::split_decomposed_rows")] +fn split_decomposed_rows( + flat: &[CyclotomicRing], + parts: usize, + len: usize, +) -> Result>>, HachiError> { + if parts == 0 { + return Err(HachiError::InvalidInput( + "cannot split decomposition with zero parts".to_string(), + )); + } + if flat.len() != len * parts { + return Err(HachiError::InvalidInput(format!( + "decomposition length mismatch: got {}, expected {}", + flat.len(), + len * parts + ))); + } + let rows: Vec>> = cfg_into_iter!(0..parts) + .map(|part| { + let mut row = Vec::with_capacity(len); + for idx in 0..len { + row.push(flat[idx * parts + part]); + } + row + }) + .collect(); + Ok(rows) +} + +#[tracing::instrument(skip_all, name = "labrador::compute_linear_garbage")] +fn compute_linear_garbage( + phi: &[Vec>], + witness: &LabradorWitness, +) -> Result>, HachiError> { + let r = witness.rows().len(); + if phi.len() != r { + return Err(HachiError::InvalidInput( + "phi row count mismatch".to_string(), + )); + } + for (phi_row, witness_row) in phi.iter().zip(witness.rows().iter()) { + if phi_row.len() != witness_row.len() { + return Err(HachiError::InvalidInput( + "phi row length mismatch".to_string(), + )); + } + } + let rows = witness.rows(); + + if let Some(rows_i8) = try_centered_i8_rows(rows) { + match mat_vec_mul_crt_ntt_i8_labrador_cross(phi, &rows_i8) { + Ok(out) => { + tracing::debug!("linear garbage via direct i8 Labrador-cross CRT+NTT path"); + return Ok(out); + } + Err(err) => { + tracing::debug!( + error = %err, + "labrador-cross i8 kernel unavailable; falling back to generic i8 CRT+NTT" + ); + } + } + if let Ok(cross) = mat_vec_mul_crt_ntt_i8_many(phi, &rows_i8) { + tracing::debug!("linear garbage via generic direct i8 CRT+NTT path"); + return Ok(pack_linear_garbage_from_cross::(&cross)); + } + } + 
+ // For large coefficients, avoid generic ring×ring CRT+NTT because its + // reconstruction bound can be exceeded. Instead, decompose witness rows to + // bounded i8 planes and expand phi columns by powers of two, keeping the + // shared i8 backend exact (same strategy as commitment NTT path). + let ntt_digit_levels = rows + .iter() + .map(|row| outer_ntt_digit_levels(row)) + .max() + .unwrap_or(1); + let expanded_phi = expand_matrix_for_i8_digits(phi, ntt_digit_levels, OUTER_NTT_LOG_BASIS); + let rows_digits: Vec> = cfg_iter!(rows) + .map(|row| decompose_rows_ntt_i8_exact(row, ntt_digit_levels, OUTER_NTT_LOG_BASIS)) + .collect(); + match mat_vec_mul_crt_ntt_i8_labrador_cross(&expanded_phi, &rows_digits) { + Ok(out) => { + tracing::debug!( + ntt_digit_levels, + "linear garbage via decomposed i8 Labrador-cross CRT+NTT path" + ); + return Ok(out); + } + Err(err) => { + tracing::debug!( + error = %err, + ntt_digit_levels, + "labrador-cross decomposed i8 kernel unavailable; falling back to generic i8 CRT+NTT" + ); + } + } + if let Ok(cross) = mat_vec_mul_crt_ntt_i8_many(&expanded_phi, &rows_digits) { + tracing::debug!( + ntt_digit_levels, + "linear garbage via generic decomposed i8 CRT+NTT path" + ); + return Ok(pack_linear_garbage_from_cross::(&cross)); + } + + tracing::debug!("linear garbage via pair-parallel schoolbook fallback"); + let pairs: Vec<(usize, usize)> = (0..r).flat_map(|i| (i..r).map(move |j| (i, j))).collect(); + let out: Vec> = cfg_into_iter!(pairs) + .map(|(i, j)| { + let mut acc = CyclotomicRing::::zero(); + let phi_i = &phi[i]; + let row_i = &rows[i]; + if i == j { + for (lhs, rhs) in phi_i.iter().zip(row_i.iter()) { + lhs.mul_accumulate_into(rhs, &mut acc); + } + } else { + let phi_j = &phi[j]; + let row_j = &rows[j]; + for ((lhs_ij, rhs_ij), (lhs_ji, rhs_ji)) in phi_i + .iter() + .zip(row_j.iter()) + .zip(phi_j.iter().zip(row_i.iter())) + { + lhs_ij.mul_accumulate_into(rhs_ij, &mut acc); + lhs_ji.mul_accumulate_into(rhs_ji, &mut acc); + } + } + 
acc + }) + .collect(); + Ok(out) +} + +#[tracing::instrument(skip_all, name = "labrador::pack_linear_garbage_from_cross")] +fn pack_linear_garbage_from_cross( + cross: &[Vec>], +) -> Vec> { + let r = cross.len(); + let packed_rows: Vec>> = cfg_into_iter!(0..r) + .map(|i| { + debug_assert_eq!(cross[i].len(), r); + let mut packed = Vec::with_capacity(r - i); + packed.push(cross[i][i]); + for j in i + 1..r { + packed.push(cross[i][j] + cross[j][i]); + } + packed + }) + .collect(); + let mut out = Vec::with_capacity(r * (r + 1) / 2); + for packed in packed_rows { + out.extend(packed); + } + out +} + +/// Compute z = sum_i c_i * s_i (all-row linear combination). +#[tracing::instrument(skip_all, name = "labrador::amortize_witness")] +fn amortize_witness( + witness: &LabradorWitness, + challenges: &[SparseChallenge], + max_len: usize, +) -> Vec> { + cfg_into_iter!(0..max_len) + .map(|j| { + let mut acc = CyclotomicRing::::zero(); + for (row, challenge) in witness.rows().iter().zip(challenges.iter()) { + if let Some(elem) = row.get(j) { + elem.mul_by_sparse_into(challenge, &mut acc); + } + } + acc + }) + .collect() +} + +#[tracing::instrument(skip_all, name = "labrador::build_linear_garbage_payload")] +fn build_linear_garbage_payload( + setup: &LabradorSetup, + linear_garbage_digits: &[CyclotomicRing], +) -> Vec> { + if !setup.ntt_d_scaled_levels.is_empty() { + let outer_digit_levels = outer_ntt_digit_levels(linear_garbage_digits); + if outer_digit_levels <= setup.ntt_d_scaled_levels.len() { + let mut digit_planes = + vec![vec![[0i8; D]; linear_garbage_digits.len()]; outer_digit_levels]; + for (idx, ring) in linear_garbage_digits.iter().enumerate() { + let digits = + ring.balanced_decompose_pow2_i8(outer_digit_levels, OUTER_NTT_LOG_BASIS); + for (level, digit) in digits.into_iter().enumerate() { + digit_planes[level][idx] = digit; + } + } + let mut payload = vec![CyclotomicRing::::zero(); setup.matrices.d_mat.len()]; + for (slot, level_digits) in setup + 
.ntt_d_scaled_levels + .iter() + .zip(digit_planes.iter()) + .take(outer_digit_levels) + { + let partial = mat_vec_mul_ntt_single_i8(slot, level_digits); + for (dst, src) in payload.iter_mut().zip(partial.into_iter()) { + *dst += src; + } + } + return payload; + } + } + if !setup.matrices.d_mat.is_empty() { + mat_vec_mul(&setup.matrices.d_mat, linear_garbage_digits) + } else { + linear_garbage_digits.to_vec() + } +} + +#[tracing::instrument(skip_all, name = "labrador::sample_amortize_challenges")] +fn sample_amortize_challenges( + transcript: &mut T, + rows: usize, +) -> Result, HachiError> +where + F: FieldCore + CanonicalField + FromSmallInt, + T: Transcript, +{ + challenge_sparse_ring_elements_rejection_sampled::( + transcript, + labels::CHALLENGE_LABRADOR_AMORTIZE, + rows, + ) +} + +#[tracing::instrument(skip_all, name = "labrador::assemble_output_witness")] +fn assemble_output_witness( + mut z_rows: Vec>>, + inner_opening_digits: &[CyclotomicRing], + linear_garbage_digits: &[CyclotomicRing], + tail: bool, +) -> LabradorWitness { + if !tail { + let mut aux = Vec::with_capacity(inner_opening_digits.len() + linear_garbage_digits.len()); + aux.extend_from_slice(inner_opening_digits); + aux.extend_from_slice(linear_garbage_digits); + z_rows.push(aux); + } + LabradorWitness::new_unchecked(z_rows) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::algebra::fields::Fp64; + use crate::protocol::labrador::aggregation::aggregate_jl_constraints_verifier; + use crate::protocol::labrador::commit::ntt_two_tier_commit; + use crate::protocol::labrador::config::trivial_plan; + use crate::protocol::labrador::constraints::{LabradorConstraint, LabradorConstraintTerm}; + use crate::protocol::labrador::johnson_lindenstrauss::LabradorJlMatrix; + use crate::protocol::labrador::types::LabradorReductionConfig; + use crate::protocol::labrador::{verify, LabradorProof}; + use crate::protocol::transcript::labels::DOMAIN_LABRADOR_RECURSION; + use 
crate::protocol::transcript::Blake2bTranscript; + use crate::FromSmallInt; + + type F = Fp64<4294967197>; + const D: usize = 64; + + fn sample_witness() -> LabradorWitness { + let row = |len: usize| -> Vec> { + (0..len) + .map(|i| { + CyclotomicRing::from_coefficients(std::array::from_fn(|j| { + F::from_i64(((i + j) as i64 % 5) - 2) + })) + }) + .collect() + }; + LabradorWitness::new(vec![row(4), row(4), row(4)]) + } + + fn make_plan( + cfg: &LabradorReductionConfig, + witness: &LabradorWitness, + ) -> LabradorFoldPlan { + let row_lengths: Vec = witness.rows().iter().map(|r| r.len()).collect(); + trivial_plan(*cfg, &row_lengths) + } + + fn replay_amortize_challenges_for_level( + statement: &LabradorStatement, + level: &LabradorLevelProof, + ) -> Vec { + let virtual_row_count = level.row_split_counts.iter().sum::(); + let virt_row_lengths = vec![level.virtual_row_len; virtual_row_count]; + let jl_cols = virtual_row_count * level.virtual_row_len * D; + let mut transcript = Blake2bTranscript::::new(DOMAIN_LABRADOR_RECURSION); + absorb_labrador_level_context( + &mut transcript, + &LabradorLevelTranscriptContext { + level_index: 0, + tail: level.tail, + input_row_lengths: level.input_row_lengths.clone(), + witness_digit_parts: level.config.witness_digit_parts, + witness_digit_bits: level.config.witness_digit_bits, + aux_digit_parts: level.config.aux_digit_parts, + aux_digit_bits: level.config.aux_digit_bits, + inner_commit_rank: level.config.inner_commit_rank, + outer_commit_rank: level.config.outer_commit_rank, + }, + ) + .unwrap(); + transcript.append_serde( + labels::ABSORB_LABRADOR_INNER_OPENING_PAYLOAD, + &level.inner_opening_payload, + ); + let jl_matrix = LabradorJlMatrix::replay_nonce_search::>( + &mut transcript, + level.jl_nonce, + jl_cols, + ) + .unwrap(); + absorb_labrador_jl_projection(&mut transcript, &level.jl_projection); + aggregate_jl_constraints_verifier( + &virt_row_lengths, + &level.jl_projection, + &jl_matrix, + &level.jl_lift_residuals, + &mut 
transcript, + ) + .unwrap(); + aggregate_statement(statement, &level.input_row_lengths, &mut transcript).unwrap(); + transcript.append_serde( + labels::ABSORB_LABRADOR_LINEAR_GARBAGE_PAYLOAD, + &level.linear_garbage_payload, + ); + sample_amortize_challenges::, D>(&mut transcript, virtual_row_count) + .unwrap() + } + + #[test] + fn standard_fold_produces_decomposed_output() { + let witness = sample_witness(); + let statement = LabradorStatement { + inner_opening_payload: Vec::new(), + linear_garbage_payload: Vec::new(), + challenges: Vec::new(), + constraints: Vec::new(), + reduced_constraints: None, + witness_norm_bound_sq: 1 << 20, + }; + let cfg = LabradorReductionConfig { + witness_digit_parts: 1, + witness_digit_bits: 8, + aux_digit_parts: 2, + aux_digit_bits: 10, + inner_commit_rank: 3, + outer_commit_rank: 2, + tail: false, + }; + let plan = make_plan(&cfg, &witness); + let seed = [1u8; 32]; + let virtual_row_count = plan.row_split_counts.iter().sum::(); + let setup = std::sync::Arc::new(LabradorSetup::new( + &cfg, + virtual_row_count, + plan.virtual_row_len, + &seed, + )); + let mut transcript = Blake2bTranscript::::new(DOMAIN_LABRADOR_RECURSION); + let out = prove_level( + &witness, + &statement, + &cfg, + &plan, + &setup, + 0, + &mut transcript, + ) + .unwrap(); + assert!( + !out.next_witness.rows().is_empty(), + "fold must produce output witness" + ); + assert_eq!(out.next_witness.rows().len(), cfg.witness_digit_parts + 1); + assert!(!out.level_proof.linear_garbage_payload.is_empty()); + } + + #[test] + fn tail_fold_produces_decomposed_output() { + let witness = sample_witness(); + let statement = LabradorStatement { + inner_opening_payload: Vec::new(), + linear_garbage_payload: Vec::new(), + challenges: Vec::new(), + constraints: Vec::new(), + reduced_constraints: None, + witness_norm_bound_sq: 1 << 20, + }; + let cfg = LabradorReductionConfig { + witness_digit_parts: 1, + witness_digit_bits: 8, + aux_digit_parts: 1, + aux_digit_bits: 32, + 
inner_commit_rank: 2, + outer_commit_rank: 0, + tail: true, + }; + let plan = make_plan(&cfg, &witness); + let seed = [3u8; 32]; + let virtual_row_count = plan.row_split_counts.iter().sum::(); + let setup = std::sync::Arc::new(LabradorSetup::new( + &cfg, + virtual_row_count, + plan.virtual_row_len, + &seed, + )); + let mut transcript = Blake2bTranscript::::new(DOMAIN_LABRADOR_RECURSION); + let out = prove_level( + &witness, + &statement, + &cfg, + &plan, + &setup, + 1, + &mut transcript, + ) + .unwrap(); + assert!( + !out.next_witness.rows().is_empty(), + "tail fold must produce output" + ); + assert_eq!(out.next_witness.rows().len(), cfg.witness_digit_parts); + assert!(out.level_proof.tail); + } + + #[test] + fn tail_fold_roundtrip_verifies() { + let row = |seed: i64| -> Vec> { + (0..28) + .map(|j| { + CyclotomicRing::from_coefficients(std::array::from_fn(|k| { + let raw = (seed + j as i64 * 3 + k as i64 * 5) % 11; + F::from_i64(raw - 5) + })) + }) + .collect() + }; + let witness = LabradorWitness::new(vec![row(1), row(2), row(3)]); + let statement = LabradorStatement { + inner_opening_payload: Vec::new(), + linear_garbage_payload: Vec::new(), + challenges: Vec::new(), + constraints: Vec::new(), + reduced_constraints: None, + witness_norm_bound_sq: 1 << 80, + }; + let cfg = LabradorReductionConfig { + witness_digit_parts: 1, + witness_digit_bits: 39, + aux_digit_parts: 1, + aux_digit_bits: 128, + inner_commit_rank: 4, + outer_commit_rank: 0, + tail: true, + }; + let plan = make_plan(&cfg, &witness); + let comkey_seed = [13u8; 32]; + let virtual_row_count = plan.row_split_counts.iter().sum::(); + let setup = std::sync::Arc::new(LabradorSetup::new( + &cfg, + virtual_row_count, + plan.virtual_row_len, + &comkey_seed, + )); + let mut transcript = Blake2bTranscript::::new(DOMAIN_LABRADOR_RECURSION); + let fold = prove_level( + &witness, + &statement, + &cfg, + &plan, + &setup, + 0, + &mut transcript, + ) + .unwrap(); + + let proof = LabradorProof { + levels: 
vec![fold.level_proof], + final_opening_witness: fold.next_witness, + }; + let mut verify_transcript = Blake2bTranscript::::new(DOMAIN_LABRADOR_RECURSION); + let verify_result = + verify(&statement, &proof, &comkey_seed, &mut verify_transcript).unwrap(); + assert_eq!(verify_result.terminal_statement, statement); + assert_eq!( + verify_result.final_opening_witness, + proof.final_opening_witness + ); + } + + #[test] + fn amortize_challenges_replay_and_bind_transcript_inputs() { + let mk_ring = |c: i64| { + CyclotomicRing::::from_coefficients(std::array::from_fn(|i| { + if i == 0 { + F::from_i64(c) + } else { + F::zero() + } + })) + }; + let witness = LabradorWitness::new(vec![ + vec![mk_ring(1), mk_ring(2)], + vec![mk_ring(3), mk_ring(-1)], + ]); + let target = witness.rows()[0][0] + witness.rows()[1][1]; + let statement = LabradorStatement { + inner_opening_payload: Vec::new(), + linear_garbage_payload: Vec::new(), + challenges: Vec::new(), + constraints: vec![LabradorConstraint::new( + vec![ + LabradorConstraintTerm::new(0, 0, vec![mk_ring(1), mk_ring(0)]), + LabradorConstraintTerm::new(1, 0, vec![mk_ring(0), mk_ring(1)]), + ], + target, + )], + reduced_constraints: None, + witness_norm_bound_sq: 1 << 40, + }; + let cfg = LabradorReductionConfig { + witness_digit_parts: 4, + witness_digit_bits: 8, + aux_digit_parts: 4, + aux_digit_bits: 8, + inner_commit_rank: 2, + outer_commit_rank: 2, + tail: false, + }; + let plan = make_plan(&cfg, &witness); + let comkey_seed = [9u8; 32]; + let virtual_row_count = plan.row_split_counts.iter().sum::(); + let setup = std::sync::Arc::new(LabradorSetup::new( + &cfg, + virtual_row_count, + plan.virtual_row_len, + &comkey_seed, + )); + let mut transcript = Blake2bTranscript::::new(DOMAIN_LABRADOR_RECURSION); + let fold = prove_level( + &witness, + &statement, + &cfg, + &plan, + &setup, + 0, + &mut transcript, + ) + .unwrap(); + + let replayed = replay_amortize_challenges_for_level(&statement, &fold.level_proof); + 
assert_eq!(replayed, fold.statement.challenges); + + let mut mutated_linear_garbage_payload = fold.level_proof.clone(); + mutated_linear_garbage_payload.linear_garbage_payload[0] += mk_ring(1); + let replayed_linear_garbage_payload = + replay_amortize_challenges_for_level(&statement, &mutated_linear_garbage_payload); + assert_ne!(replayed_linear_garbage_payload, fold.statement.challenges); + + let mut mutated_nonce = fold.level_proof.clone(); + mutated_nonce.jl_nonce = if mutated_nonce.jl_nonce == 1 { 2 } else { 1 }; + let replayed_nonce = replay_amortize_challenges_for_level(&statement, &mutated_nonce); + assert_ne!(replayed_nonce, fold.statement.challenges); + } + + #[test] + fn amortize_is_linear_combination() { + let witness = sample_witness(); + let one = SparseChallenge { + positions: vec![0], + coeffs: vec![1], + }; + let challenges = vec![one; witness.rows().len()]; + let max_len = witness.rows().iter().map(|r| r.len()).max().unwrap(); + let z = amortize_witness(&witness, &challenges, max_len); + + for (j, z_elem) in z.iter().enumerate().take(max_len) { + let expected = witness + .rows() + .iter() + .map(|row| { + row.get(j) + .copied() + .unwrap_or_else(CyclotomicRing::::zero) + }) + .fold(CyclotomicRing::::zero(), |a, b| a + b); + assert_eq!(*z_elem, expected); + } + } + + #[test] + fn linear_garbage_decomposed_ntt_matches_reference() { + let r = 3usize; + let len = 6usize; + let mk_ring = |seed: i64| { + CyclotomicRing::::from_coefficients(std::array::from_fn(|k| { + let raw = (seed + 5 * k as i64) % 97; + F::from_i64(raw - 48) + })) + }; + + let phi: Vec>> = (0..r) + .map(|i| { + (0..len) + .map(|j| mk_ring(17 * i as i64 + 11 * j as i64)) + .collect() + }) + .collect(); + let witness_rows: Vec>> = (0..r) + .map(|i| { + (0..len) + .map(|j| mk_ring(23 * i as i64 + 7 * j as i64 + 3)) + .collect() + }) + .collect(); + let witness = LabradorWitness::new_unchecked(witness_rows.clone()); + + // Ensure this test exercises the decomposed-i8 NTT path, not the 
direct i8 path. + assert!(try_centered_i8_rows(witness.rows()).is_none()); + + let got = compute_linear_garbage(&phi, &witness).unwrap(); + + let mut expected = vec![CyclotomicRing::::zero(); r * (r + 1) / 2]; + for i in 0..r { + for j in i..r { + let pair = pair_index(i, j, r); + let mut acc = CyclotomicRing::::zero(); + for col in 0..len { + phi[i][col].mul_accumulate_into(&witness_rows[j][col], &mut acc); + if i != j { + phi[j][col].mul_accumulate_into(&witness_rows[i][col], &mut acc); + } + } + expected[pair] = acc; + } + } + assert_eq!(got, expected); + } + + #[test] + fn standard_fold_roundtrip_verifies() { + let mk_ring = |c: i64| { + CyclotomicRing::::from_coefficients(std::array::from_fn(|i| { + if i == 0 { + F::from_i64(c) + } else { + F::zero() + } + })) + }; + let witness = LabradorWitness::new(vec![ + vec![mk_ring(1), mk_ring(2)], + vec![mk_ring(3), mk_ring(-1)], + ]); + let target = witness.rows()[0][0] + witness.rows()[1][1]; + let statement = LabradorStatement { + inner_opening_payload: Vec::new(), + linear_garbage_payload: Vec::new(), + challenges: Vec::new(), + constraints: vec![LabradorConstraint::new( + vec![ + LabradorConstraintTerm::new(0, 0, vec![mk_ring(1), mk_ring(0)]), + LabradorConstraintTerm::new(1, 0, vec![mk_ring(0), mk_ring(1)]), + ], + target, + )], + reduced_constraints: None, + witness_norm_bound_sq: 1 << 40, + }; + let cfg = LabradorReductionConfig { + witness_digit_parts: 4, + witness_digit_bits: 8, + aux_digit_parts: 4, + aux_digit_bits: 8, + inner_commit_rank: 2, + outer_commit_rank: 2, + tail: false, + }; + let plan = make_plan(&cfg, &witness); + let comkey_seed = [9u8; 32]; + let virtual_row_count = plan.row_split_counts.iter().sum::(); + let setup = std::sync::Arc::new(LabradorSetup::new( + &cfg, + virtual_row_count, + plan.virtual_row_len, + &comkey_seed, + )); + let mut transcript = Blake2bTranscript::::new(DOMAIN_LABRADOR_RECURSION); + let fold = prove_level( + &witness, + &statement, + &cfg, + &plan, + &setup, + 0, + 
&mut transcript, + ) + .unwrap(); + + let proof = LabradorProof { + levels: vec![fold.level_proof], + final_opening_witness: fold.next_witness, + }; + let mut verify_transcript = Blake2bTranscript::::new(DOMAIN_LABRADOR_RECURSION); + verify(&statement, &proof, &comkey_seed, &mut verify_transcript).unwrap(); + + let base_proof = LabradorProof { + levels: Vec::new(), + final_opening_witness: proof.final_opening_witness.clone(), + }; + let mut base_transcript = Blake2bTranscript::::new(DOMAIN_LABRADOR_RECURSION); + verify( + &fold.statement, + &base_proof, + &comkey_seed, + &mut base_transcript, + ) + .unwrap(); + } + + #[test] + fn non_tail_grouped_rows_statement_roundtrip_verifies() { + let row = |seed: i64, len: usize| -> Vec> { + (0..len) + .map(|j| { + CyclotomicRing::from_coefficients(std::array::from_fn(|k| { + let raw = (seed + j as i64 * 3 + k as i64 * 5) % 11; + F::from_i64(raw - 5) + })) + }) + .collect() + }; + let witness = LabradorWitness::new_unchecked(vec![row(1, 48), row(2, 36)]); + let statement = LabradorStatement { + inner_opening_payload: Vec::new(), + linear_garbage_payload: Vec::new(), + challenges: Vec::new(), + constraints: Vec::new(), + reduced_constraints: None, + witness_norm_bound_sq: 1 << 100, + }; + let cfg = LabradorReductionConfig { + witness_digit_parts: 1, + witness_digit_bits: 35, + aux_digit_parts: 4, + aux_digit_bits: 32, + inner_commit_rank: 3, + outer_commit_rank: 3, + tail: false, + }; + let plan = LabradorFoldPlan { + config: cfg, + virtual_row_len: 48, + row_split_counts: vec![0, 2], + }; + let comkey_seed = [17u8; 32]; + let setup = std::sync::Arc::new(LabradorSetup::new(&cfg, 2, 48, &comkey_seed)); + let mut transcript = Blake2bTranscript::::new(DOMAIN_LABRADOR_RECURSION); + let fold = prove_level( + &witness, + &statement, + &cfg, + &plan, + &setup, + 0, + &mut transcript, + ) + .unwrap(); + + let base_proof = LabradorProof { + levels: Vec::new(), + final_opening_witness: fold.next_witness.clone(), + }; + let mut 
base_transcript = Blake2bTranscript::::new(DOMAIN_LABRADOR_RECURSION); + verify( + &fold.statement, + &base_proof, + &comkey_seed, + &mut base_transcript, + ) + .unwrap(); + } + + #[test] + fn two_level_fold_roundtrip_verifies() { + let mk_ring = |c: i64| { + CyclotomicRing::::from_coefficients(std::array::from_fn(|i| { + if i == 0 { + F::from_i64(c) + } else { + F::zero() + } + })) + }; + let witness = LabradorWitness::new(vec![ + vec![mk_ring(1), mk_ring(2)], + vec![mk_ring(3), mk_ring(-1)], + ]); + let target = witness.rows()[0][0] + witness.rows()[1][1]; + let statement = LabradorStatement { + inner_opening_payload: Vec::new(), + linear_garbage_payload: Vec::new(), + challenges: Vec::new(), + constraints: vec![LabradorConstraint::new( + vec![ + LabradorConstraintTerm::new(0, 0, vec![mk_ring(1), mk_ring(0)]), + LabradorConstraintTerm::new(1, 0, vec![mk_ring(0), mk_ring(1)]), + ], + target, + )], + reduced_constraints: None, + witness_norm_bound_sq: 1 << 40, + }; + let cfg = LabradorReductionConfig { + witness_digit_parts: 4, + witness_digit_bits: 8, + aux_digit_parts: 4, + aux_digit_bits: 8, + inner_commit_rank: 2, + outer_commit_rank: 2, + tail: false, + }; + let comkey_seed = [9u8; 32]; + let plan1 = make_plan(&cfg, &witness); + let virtual_row_count1 = plan1.row_split_counts.iter().sum::(); + let setup1 = std::sync::Arc::new(LabradorSetup::new( + &cfg, + virtual_row_count1, + plan1.virtual_row_len, + &comkey_seed, + )); + let mut transcript = Blake2bTranscript::::new(DOMAIN_LABRADOR_RECURSION); + let fold1 = prove_level( + &witness, + &statement, + &cfg, + &plan1, + &setup1, + 0, + &mut transcript, + ) + .unwrap(); + let plan2 = make_plan(&cfg, &fold1.next_witness); + let virtual_row_count2 = plan2.row_split_counts.iter().sum::(); + let setup2 = std::sync::Arc::new(LabradorSetup::new( + &cfg, + virtual_row_count2, + plan2.virtual_row_len, + &comkey_seed, + )); + let fold2 = prove_level( + &fold1.next_witness, + &fold1.statement, + &cfg, + &plan2, + &setup2, + 
1, + &mut transcript, + ) + .unwrap(); + + let r = fold2.level_proof.input_row_lengths.len(); + let challenges = &fold2.statement.challenges; + let aux_row = &fold2.next_witness.rows()[cfg.witness_digit_parts]; + let inner_opening_digits_len = r * cfg.inner_commit_rank * cfg.aux_digit_parts; + let inner_opening_digits = &aux_row[..inner_opening_digits_len]; + let mut t_flat = Vec::with_capacity(r * cfg.inner_commit_rank); + for chunk in inner_opening_digits.chunks(cfg.aux_digit_parts) { + t_flat.push(CyclotomicRing::gadget_recompose_pow2( + chunk, + cfg.aux_digit_bits as u32, + )); + } + let z_parts: Vec>> = + fold2.next_witness.rows()[..cfg.witness_digit_parts].to_vec(); + let mut z = Vec::with_capacity(z_parts[0].len()); + for idx in 0..z_parts[0].len() { + let mut slice = Vec::with_capacity(cfg.witness_digit_parts); + for part in &z_parts { + slice.push(part[idx]); + } + z.push(CyclotomicRing::gadget_recompose_pow2( + &slice, + cfg.witness_digit_bits as u32, + )); + } + let az = mat_vec_mul(&setup2.matrices.a_mat, &z); + let mut rhs = vec![CyclotomicRing::::zero(); cfg.inner_commit_rank]; + for (row_idx, t_row) in t_flat.chunks(cfg.inner_commit_rank).enumerate() { + for k in 0..cfg.inner_commit_rank { + t_row[k].mul_by_sparse_into(&challenges[row_idx], &mut rhs[k]); + } + } + assert_eq!(az, rhs); + + let proof = LabradorProof { + levels: vec![fold1.level_proof, fold2.level_proof], + final_opening_witness: fold2.next_witness, + }; + let mut verify_transcript = Blake2bTranscript::::new(DOMAIN_LABRADOR_RECURSION); + let verify_result = + verify(&statement, &proof, &comkey_seed, &mut verify_transcript).unwrap(); + assert_eq!(verify_result.terminal_statement, fold2.statement); + assert_eq!( + verify_result.final_opening_witness, + proof.final_opening_witness + ); + } + + #[test] + fn tail_linear_relation_matches_schoolbook_on_virtual_rows() { + let row = |seed: i64| -> Vec> { + (0..28) + .map(|j| { + CyclotomicRing::from_coefficients(std::array::from_fn(|k| { + let 
raw = (seed + j as i64 * 3 + k as i64 * 5) % 11; + F::from_i64(raw - 5) + })) + }) + .collect() + }; + let virtual_witness = LabradorWitness::new(vec![row(1), row(2), row(3)]); + let cfg = LabradorReductionConfig { + witness_digit_parts: 1, + witness_digit_bits: 39, + aux_digit_parts: 1, + aux_digit_bits: 128, + inner_commit_rank: 4, + outer_commit_rank: 0, + tail: true, + }; + let comkey_seed = [13u8; 32]; + let setup = LabradorSetup::new(&cfg, 3, 28, &comkey_seed); + let (inner_opening_digits, inner_opening_payload) = ntt_two_tier_commit( + &setup.matrices.a_mat, + &setup.matrices.b_mat, + virtual_witness.rows(), + cfg.aux_digit_parts, + cfg.aux_digit_bits as u32, + ) + .unwrap(); + assert_eq!(inner_opening_digits, inner_opening_payload); + + let challenges = vec![ + SparseChallenge { + positions: vec![0, 7, 11], + coeffs: vec![1, -1, 2], + }, + SparseChallenge { + positions: vec![3, 9, 17], + coeffs: vec![-1, 1, -2], + }, + SparseChallenge { + positions: vec![5, 12, 21], + coeffs: vec![2, -1, 1], + }, + ]; + let z = amortize_witness(&virtual_witness, &challenges, 28); + let az = mat_vec_mul(&setup.matrices.a_mat, &z); + + let mut rhs_payload = vec![CyclotomicRing::::zero(); cfg.inner_commit_rank]; + for (row_idx, chunk) in inner_opening_digits + .chunks(cfg.inner_commit_rank * cfg.aux_digit_parts) + .enumerate() + { + let challenge = &challenges[row_idx]; + for (k, rhs_k) in rhs_payload.iter_mut().enumerate() { + let start = k * cfg.aux_digit_parts; + let t = CyclotomicRing::gadget_recompose_pow2( + &chunk[start..start + cfg.aux_digit_parts], + cfg.aux_digit_bits as u32, + ); + t.mul_by_sparse_into(challenge, rhs_k); + } + } + + let mut rhs_schoolbook = vec![CyclotomicRing::::zero(); cfg.inner_commit_rank]; + for (row, challenge) in virtual_witness.rows().iter().zip(challenges.iter()) { + let t_row = mat_vec_mul(&setup.matrices.a_mat, row); + for (rhs_k, t_k) in rhs_schoolbook.iter_mut().zip(t_row.iter()) { + t_k.mul_by_sparse_into(challenge, rhs_k); + } + } + + 
assert_eq!(az, rhs_schoolbook);
        assert_eq!(rhs_payload, rhs_schoolbook);
    }
}

// NOTE(review): generic parameter lists appear to have been stripped in
// transit (e.g. `Blake2bTranscript::::new`, bare `Result`); tokens are kept
// as-is below — restore the original `<...>` arguments from VCS.
#[cfg(test)]
mod malicious_prover {
    use super::*;
    use crate::algebra::fields::Fp64;
    use crate::protocol::labrador::config::trivial_plan;
    use crate::protocol::labrador::constraints::{LabradorConstraint, LabradorConstraintTerm};
    use crate::protocol::labrador::types::LabradorReductionConfig;
    use crate::protocol::labrador::{verify, LabradorProof};
    use crate::protocol::transcript::labels::DOMAIN_LABRADOR_RECURSION;
    use crate::protocol::transcript::Blake2bTranscript;
    use crate::FromSmallInt;

    type F = Fp64<4294967197>;
    const D: usize = 64;

    /// Constant polynomial: `c` in the constant slot, zero elsewhere.
    fn mk_ring(c: i64) -> CyclotomicRing {
        CyclotomicRing::::from_coefficients(std::array::from_fn(|i| {
            if i == 0 {
                F::from_i64(c)
            } else {
                F::zero()
            }
        }))
    }

    /// Build a small honest instance: a 2x2 witness, one linear constraint
    /// pinning `w[0][0] + w[1][1]` to `target`, and a single accepted fold
    /// level produced by `prove_level`. Returns (statement, proof, comkey seed)
    /// so each test below can corrupt exactly one proof field.
    fn valid_single_level_proof() -> (LabradorStatement, LabradorProof, [u8; 32]) {
        let witness = LabradorWitness::new(vec![
            vec![mk_ring(1), mk_ring(2)],
            vec![mk_ring(3), mk_ring(-1)],
        ]);
        let target = witness.rows()[0][0] + witness.rows()[1][1];
        let statement = LabradorStatement {
            inner_opening_payload: Vec::new(),
            linear_garbage_payload: Vec::new(),
            challenges: Vec::new(),
            constraints: vec![LabradorConstraint::new(
                vec![
                    // Selects w[0][0] (coefficient 1 on row 0) ...
                    LabradorConstraintTerm::new(0, 0, vec![mk_ring(1), mk_ring(0)]),
                    // ... plus w[1][1] (coefficient 1 on row 1, column 1).
                    LabradorConstraintTerm::new(1, 0, vec![mk_ring(0), mk_ring(1)]),
                ],
                target,
            )],
            reduced_constraints: None,
            witness_norm_bound_sq: 1 << 40,
        };
        let cfg = LabradorReductionConfig {
            witness_digit_parts: 4,
            witness_digit_bits: 8,
            aux_digit_parts: 4,
            aux_digit_bits: 8,
            inner_commit_rank: 2,
            outer_commit_rank: 2,
            tail: false,
        };
        let comkey_seed = [9u8; 32];
        let row_lengths: Vec = witness.rows().iter().map(|r| r.len()).collect();
        let plan = trivial_plan(cfg, &row_lengths);
        let virtual_row_count = plan.row_split_counts.iter().sum::();
        let setup = std::sync::Arc::new(LabradorSetup::new(
            &cfg,
            virtual_row_count,
            plan.virtual_row_len,
            &comkey_seed,
        ));
        let mut transcript = Blake2bTranscript::::new(DOMAIN_LABRADOR_RECURSION);
        let fold = prove_level(
            &witness,
            &statement,
            &cfg,
            &plan,
            &setup,
            0,
            &mut transcript,
        )
        .unwrap();
        let proof = LabradorProof {
            levels: vec![fold.level_proof],
            final_opening_witness: fold.next_witness,
        };
        (statement, proof, comkey_seed)
    }

    /// Run the verifier on a fresh transcript and require rejection.
    fn assert_verification_fails(
        statement: &LabradorStatement,
        proof: &LabradorProof,
        comkey_seed: &[u8; 32],
    ) {
        let mut verify_transcript = Blake2bTranscript::::new(DOMAIN_LABRADOR_RECURSION);
        assert!(
            verify(statement, proof, comkey_seed, &mut verify_transcript).is_err(),
            "maliciously altered proof should fail verification"
        );
    }

    // Each test below takes the honest proof and flips exactly one field,
    // checking soundness of the corresponding verifier check.

    #[test]
    fn malicious_inner_opening_payload_fails_verification() {
        let (statement, mut proof, comkey_seed) = valid_single_level_proof();
        // Perturb one coefficient of the inner commitment opening.
        proof.levels[0].inner_opening_payload[0].coefficients_mut()[0] += F::one();
        assert_verification_fails(&statement, &proof, &comkey_seed);
    }

    #[test]
    fn malicious_linear_garbage_payload_fails_verification() {
        let (statement, mut proof, comkey_seed) = valid_single_level_proof();
        // Perturb one coefficient of the linear garbage terms.
        proof.levels[0].linear_garbage_payload[0].coefficients_mut()[0] += F::one();
        assert_verification_fails(&statement, &proof, &comkey_seed);
    }

    #[test]
    fn malicious_jl_projection_fails_verification() {
        let (statement, mut proof, comkey_seed) = valid_single_level_proof();
        // Replace a JL projection entry with an out-of-bound value.
        proof.levels[0].jl_projection[0] = i64::MAX;
        assert_verification_fails(&statement, &proof, &comkey_seed);
    }

    #[test]
    fn malicious_jl_nonce_fails_verification() {
        let (statement, mut proof, comkey_seed) = valid_single_level_proof();
        // A different nonce replays a different JL matrix on the verifier side.
        proof.levels[0].jl_nonce += 1;
        assert_verification_fails(&statement, &proof, &comkey_seed);
    }

    #[test]
    fn malicious_jl_lift_residuals_fail_verification() {
        let (statement, mut proof, comkey_seed) = valid_single_level_proof();
        // Perturb one coefficient of the JL lift residual polynomials.
        proof.levels[0].jl_lift_residuals[0].coefficients_mut()[0] += F::one();
        assert_verification_fails(&statement, &proof, &comkey_seed);
    }
}
diff --git a/src/protocol/labrador/guardrails.rs
new file mode 100644
index 00000000..c5d9341d
--- /dev/null
+++ b/src/protocol/labrador/guardrails.rs
@@ -0,0 +1,89 @@
//! Guardrails for Labrador protocol plumbing.

use crate::error::HachiError;

/// Maximum recursion levels accepted by the protocol.
///
/// Mirrors the fixed upper bound used by the C reference (`proof *pi[16]`).
pub const LABRADOR_MAX_LEVELS: usize = 16;
/// Upper bound for JL nonce search attempts.
pub const LABRADOR_MAX_JL_NONCE_RETRIES: u64 = 1 << 20;
/// Upper bound on challenge polynomials sampled per call.
pub const LABRADOR_MAX_CHALLENGE_POLYS: usize = 1 << 12;
/// Upper bound for temporary byte allocations in Labrador helpers.
pub const LABRADOR_MAX_TEMP_BYTES: usize = 1 << 27; // 128 MiB

/// Checked conversion from `usize` to `u64`.
///
/// # Errors
///
/// Returns an error when `value` does not fit into `u64`.
pub fn checked_usize_to_u64(value: usize, what: &'static str) -> Result {
    u64::try_from(value)
        .map_err(|_| HachiError::InvalidInput(format!("{what} does not fit in u64: {value}")))
}

/// Ensure a value is a power of two.
///
/// # Errors
///
/// Returns an error if `value` is not a power of two.
pub fn ensure_power_of_two(value: usize, what: &'static str) -> Result<(), HachiError> {
    if !value.is_power_of_two() {
        return Err(HachiError::InvalidInput(format!(
            "{what} must be a power of two, got {value}"
        )));
    }
    Ok(())
}

/// Checked `a * b` for allocation sizing.
///
/// # Errors
///
/// Returns an error if multiplication overflows `usize`.
+pub fn checked_mul(a: usize, b: usize, what: &'static str) -> Result { + a.checked_mul(b) + .ok_or_else(|| HachiError::InvalidInput(format!("overflow while computing {what}"))) +} + +/// Checked `a + b` for allocation sizing. +/// +/// # Errors +/// +/// Returns an error if addition overflows `usize`. +pub fn checked_add(a: usize, b: usize, what: &'static str) -> Result { + a.checked_add(b) + .ok_or_else(|| HachiError::InvalidInput(format!("overflow while computing {what}"))) +} + +/// Validate temporary allocation size against guardrail cap. +/// +/// # Errors +/// +/// Returns an error if `bytes > LABRADOR_MAX_TEMP_BYTES`. +pub fn ensure_temp_allocation_limit(bytes: usize, what: &'static str) -> Result<(), HachiError> { + if bytes > LABRADOR_MAX_TEMP_BYTES { + return Err(HachiError::InvalidInput(format!( + "{what} temporary allocation too large: {bytes} bytes (max {LABRADOR_MAX_TEMP_BYTES})" + ))); + } + Ok(()) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn checked_mul_detects_overflow() { + let err = checked_mul(usize::MAX, 2, "overflow-test").unwrap_err(); + assert!(matches!(err, HachiError::InvalidInput(_))); + } + + #[test] + fn temp_limit_enforced() { + let err = ensure_temp_allocation_limit(LABRADOR_MAX_TEMP_BYTES + 1, "tmp").unwrap_err(); + assert!(matches!(err, HachiError::InvalidInput(_))); + } +} diff --git a/src/protocol/labrador/johnson_lindenstrauss.rs b/src/protocol/labrador/johnson_lindenstrauss.rs new file mode 100644 index 00000000..3d8e02de --- /dev/null +++ b/src/protocol/labrador/johnson_lindenstrauss.rs @@ -0,0 +1,699 @@ +//! Johnson-Lindenstrauss helpers for Labrador reduction. 

use crate::algebra::ring::CyclotomicRing;
use crate::error::HachiError;
#[cfg(feature = "parallel")]
use crate::parallel::*;
use crate::protocol::labrador::guardrails::LABRADOR_MAX_JL_NONCE_RETRIES;
use crate::protocol::labrador::types::LabradorWitness;
use crate::protocol::transcript::{labels, Transcript};
use crate::{CanonicalField, FieldCore};
use sha3::digest::{ExtendableOutput, Update, XofReader};
use sha3::Shake128;

// NOTE(review): generic parameter lists appear stripped in transit
// (`Vec>`, bare `Result`); tokens kept verbatim — restore `<...>` from VCS.

// A JL matrix always has 256 rows; columns = flattened witness coefficients.
const JL_ROWS: usize = 256;
const JL_ROW_XOF_DOMAIN: &[u8] = b"hachi/labrador-jl-matrix-row-v1";

/// Derive one packed JL row by hashing (domain, seed, cols, row_idx) with
/// SHAKE-128 and squeezing `row_bytes` bytes.
fn expand_jl_seed_row(seed: &[u8], cols: usize, row_idx: usize, row_bytes: usize) -> Vec {
    let mut xof = Shake128::default();
    xof.update(JL_ROW_XOF_DOMAIN);
    xof.update(seed);
    xof.update(&(cols as u64).to_le_bytes());
    xof.update(&(row_idx as u64).to_le_bytes());
    let mut reader = xof.finalize_xof();
    let mut row = vec![0u8; row_bytes];
    reader.read(&mut row);
    row
}

/// Expand all `JL_ROWS` packed rows from one seed.
fn expand_jl_seed(seed: &[u8], cols: usize, row_bytes: usize) -> Vec> {
    // Each row is derived from (seed, cols, row_idx), so generation is
    // deterministic and safe to parallelize without shared XOF state.
    cfg_into_iter!(0..JL_ROWS)
        .map(|row_idx| expand_jl_seed_row(seed, cols, row_idx, row_bytes))
        .collect()
}

/// Bytes per packed row: 2 bits per column, 4 columns per byte.
fn jl_row_bytes(cols: usize) -> Result {
    if cols == 0 {
        return Err(HachiError::InvalidInput(
            "JL matrix requires non-zero column count".to_string(),
        ));
    }
    Ok((cols * 2).div_ceil(8))
}

/// Verifier-side replay of the prover's accepted nonce: range-check the
/// nonce, absorb it, then squeeze the 32-byte JL seed from the transcript.
/// Returns `(row_bytes, seed)`.
pub(crate) fn replay_nonce_search_seed(
    transcript: &mut T,
    jl_nonce: u64,
    cols: usize,
) -> Result<(usize, [u8; 32]), HachiError>
where
    F: FieldCore + CanonicalField,
    T: Transcript,
{
    if !(1..=LABRADOR_MAX_JL_NONCE_RETRIES).contains(&jl_nonce) {
        return Err(HachiError::InvalidInput(format!(
            "JL nonce out of range: {jl_nonce}"
        )));
    }
    let row_bytes = jl_row_bytes(cols)?;
    transcript.append_bytes(labels::ABSORB_LABRADOR_JL_NONCE, &jl_nonce.to_le_bytes());
    let seed_vec = transcript.challenge_bytes(labels::CHALLENGE_LABRADOR_JL_SEED, 32);
    let seed: [u8; 32] = seed_vec
        .try_into()
        .map_err(|_| HachiError::InvalidInput("JL seed length mismatch".to_string()))?;
    Ok((row_bytes, seed))
}

/// Map a canonical residue in `[0, q)` to its signed centered representative
/// (negative when `canonical > q/2`).
fn centered_from_canonical(
    canonical: u128,
    modulus: u128,
    half_modulus: u128,
) -> Result {
    let magnitude = centered_magnitude(canonical, modulus, half_modulus);
    let magnitude = i128::try_from(magnitude).map_err(|_| {
        HachiError::InvalidInput("JL centered coefficient exceeds i128 range".to_string())
    })?;
    Ok(if canonical > half_modulus {
        -magnitude
    } else {
        magnitude
    })
}

/// Absolute value of the centered representative of `canonical`.
fn centered_magnitude(canonical: u128, modulus: u128, half_modulus: u128) -> u128 {
    if canonical > half_modulus {
        modulus - canonical
    } else {
        canonical
    }
}

/// Centered witness coefficients in one of two widths:
/// `I64` is the flat fast path; `I128` is used when any centered magnitude
/// exceeds `i64::MAX`, carrying `sum_abs` so the projection can decide
/// whether checked arithmetic is required.
enum CenteredWitness {
    I64 {
        coeffs: Vec,
    },
    I128 {
        rings: Vec<[i128; D]>,
        sum_abs: u128,
    },
}

impl CenteredWitness {
    // Number of ring elements represented (I64 stores D coeffs per ring).
    fn ring_len(&self) -> usize {
        match self {
            Self::I64 { coeffs, .. } => coeffs.len() / D,
            Self::I128 { rings, .. } => rings.len(),
        }
    }
}

/// Decode a 2-bit pair to a ternary sign: `0b00 -> -1`, `0b11 -> +1`, else 0.
#[inline]
fn jl_pair_to_sign(pair: u8) -> i8 {
    ((pair == 0b11) as i8) - ((pair == 0b00) as i8)
}

/// Extract the 2-bit pair for column `col` from a packed row.
#[inline]
fn jl_pair_at(row: &[u8], col: usize) -> u8 {
    let shift = (col & 0b11) << 1;
    (row[col >> 2] >> shift) & 0b11
}

/// Center all witness coefficients around zero.
///
/// Pass 1 detects whether any centered magnitude needs more than `i64`;
/// pass 2 converts into the matching `CenteredWitness` variant.
#[tracing::instrument(skip_all, name = "labrador::center_witness")]
fn center_witness_by_ring(
    witness: &LabradorWitness,
) -> Result, HachiError> {
    // q is recovered from the field: (-1).to_canonical() + 1.
    let q = (-F::one()).to_canonical_u128() + 1;
    let half_q = q / 2;
    let total_rings: usize = witness.rows().iter().map(Vec::len).sum();

    let mut requires_i128 = false;
    'detect_width: for row in witness.rows() {
        for ring in row {
            for coeff in ring.coefficients() {
                let canonical = coeff.to_canonical_u128();
                let magnitude = centered_magnitude(canonical, q, half_q);
                if magnitude > i64::MAX as u128 {
                    requires_i128 = true;
                    break 'detect_width;
                }
            }
        }
    }

    if requires_i128 {
        let mut centered = Vec::with_capacity(total_rings);
        // Track the total absolute mass so the projection can decide whether
        // unchecked i128 accumulation could overflow.
        let mut sum_abs = 0u128;
        for row in witness.rows() {
            for ring in row {
                let mut coeffs = [0i128; D];
                for (idx, coeff) in ring.coefficients().iter().enumerate() {
                    coeffs[idx] = centered_from_canonical(coeff.to_canonical_u128(), q, half_q)?;
                    sum_abs = sum_abs.saturating_add(coeffs[idx].unsigned_abs());
                }
                centered.push(coeffs);
            }
        }
        Ok(CenteredWitness::I128 {
            rings: centered,
            sum_abs,
        })
    } else {
        let mut centered = Vec::with_capacity(total_rings * D);
        for row in witness.rows() {
            for ring in row {
                for coeff in ring.coefficients() {
                    let centered_i128 =
                        centered_from_canonical(coeff.to_canonical_u128(), q, half_q)?;
                    centered.push(i64::try_from(centered_i128).map_err(|_| {
                        HachiError::InvalidInput(
                            "JL centered coefficient unexpectedly exceeds i64 range".to_string(),
                        )
                    })?);
                }
            }
        }
        Ok(CenteredWitness::I64 { coeffs: centered })
    }
}

/// Packed ternary JL matrix with entries in `{-1, 0, +1}`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct LabradorJlMatrix {
    // Flattened witness coefficient count (columns per row).
    cols: usize,
    // Packed length of each row: ceil(cols * 2 / 8) bytes.
    row_bytes: usize,
    // JL_ROWS rows, each a 2-bit-per-entry packed sign vector.
    pub(crate) packed_rows: Vec>,
}

impl LabradorJlMatrix {
    /// Number of columns in each JL row.
    pub fn cols(&self) -> usize {
        self.cols
    }

    // Shape check used before streaming projections: non-zero columns,
    // exactly JL_ROWS rows, and every row at the packed length.
    pub(crate) fn is_well_formed(&self) -> bool {
        self.cols > 0
            && self.packed_rows.len() == JL_ROWS
            && self
                .packed_rows
                .iter()
                .all(|row| row.len() == self.row_bytes)
    }

    // Test-only constructor from explicit sign rows.
    // Encoding: -1 -> 0b00, 0 -> 0b01, +1 -> 0b11 (matches jl_pair_to_sign).
    #[cfg(test)]
    fn from_sign_rows(signs: Vec>) -> Result {
        if signs.len() != JL_ROWS {
            return Err(HachiError::InvalidInput(format!(
                "JL matrix requires exactly {JL_ROWS} rows"
            )));
        }
        let cols = signs.first().map_or(0, Vec::len);
        if cols == 0 {
            return Err(HachiError::InvalidInput(
                "JL matrix requires non-zero column count".to_string(),
            ));
        }
        if signs.iter().any(|row| row.len() != cols) {
            return Err(HachiError::InvalidInput(
                "JL matrix row length mismatch".to_string(),
            ));
        }

        let row_bytes = (cols * 2).div_ceil(8);
        let mut packed_rows = vec![vec![0u8; row_bytes]; JL_ROWS];
        for (row_idx, row) in signs.iter().enumerate() {
            for (col_idx, &sign) in row.iter().enumerate() {
                let pair = match sign {
                    -1 => 0b00,
                    0 => 0b01,
                    1 => 0b11,
                    _ => {
                        return Err(HachiError::InvalidInput(
                            "JL matrix entries must be in {-1, 0, +1}".to_string(),
                        ))
                    }
                };
                // 4 columns per byte, 2 bits per column.
                packed_rows[row_idx][col_idx >> 2] |= pair << ((col_idx & 0b11) << 1);
            }
        }

        Ok(Self {
            cols,
            row_bytes,
            packed_rows,
        })
    }

    // Test-only accessor: decoded sign at (row, col), None out of range.
    #[cfg(test)]
    fn sign_at(&self, row_idx: usize, col_idx: usize) -> Option {
        if row_idx >= JL_ROWS || col_idx >= self.cols {
            return None;
        }
        Some(jl_pair_to_sign(jl_pair_at(
            &self.packed_rows[row_idx],
            col_idx,
        )))
    }

    /// Squeeze a JL matrix directly from the transcript.
    ///
    /// # Errors
    ///
    /// Returns an error if `cols` is zero.
    #[tracing::instrument(skip_all, name = "labrador::jl_matrix_generate")]
    pub fn generate(transcript: &mut T, cols: usize) -> Result
    where
        F: FieldCore + CanonicalField,
        T: Transcript,
    {
        let row_bytes = jl_row_bytes(cols)?;
        let seed = transcript.challenge_bytes(labels::CHALLENGE_LABRADOR_JL_SEED, 32);
        let packed_rows = expand_jl_seed(&seed, cols, row_bytes);
        Ok(Self {
            cols,
            row_bytes,
            packed_rows,
        })
    }

    /// Reconstruct the accepted JL matrix from the prover-chosen nonce.
    ///
    /// The prover now performs nonce search on cloned transcript states and
    /// only commits the accepted nonce back into the real transcript. The
    /// verifier therefore absorbs exactly that accepted nonce once.
    ///
    /// # Errors
    ///
    /// Returns an error if `cols` is zero or `jl_nonce` is out of range.
    #[tracing::instrument(skip_all, name = "labrador::jl_matrix_replay")]
    pub fn replay_nonce_search(
        transcript: &mut T,
        jl_nonce: u64,
        cols: usize,
    ) -> Result
    where
        F: FieldCore + CanonicalField,
        T: Transcript,
    {
        let (row_bytes, seed) = replay_nonce_search_seed::(transcript, jl_nonce, cols)?;
        let packed_rows = expand_jl_seed(&seed, cols, row_bytes);
        Ok(Self {
            cols,
            row_bytes,
            packed_rows,
        })
    }
}

/// Project a witness into 256 JL coordinates and return the nonce used.
///
/// Each nonce attempt runs on a cloned transcript. Only the accepted nonce
/// is committed to the real transcript, which keeps verifier replay bounded
/// and prevents rejected attempts from perturbing later challenges.
///
/// # Errors
///
/// Returns an error if the witness is empty or if no valid projection is found
/// within the nonce search limit.
+#[tracing::instrument(skip_all, name = "labrador::project")] +pub fn project( + witness: &LabradorWitness, + transcript: &mut T, +) -> Result<([i64; 256], u64, LabradorJlMatrix), HachiError> +where + F: FieldCore + CanonicalField, + T: Transcript, +{ + let total_coeffs: usize = witness.rows().iter().map(|row| row.len() * D).sum(); + if total_coeffs == 0 { + return Err(HachiError::InvalidInput( + "cannot JL-project empty witness".to_string(), + )); + } + let centered_witness = center_witness_by_ring(witness)?; + if centered_witness.ring_len() * D != total_coeffs { + return Err(HachiError::InvalidInput( + "centered witness length mismatch".to_string(), + )); + } + + let witness_norm: u128 = witness.norm(); + let norm_bound = 256u128.saturating_mul(witness_norm); + let component_bound = next_power_of_two_u64(4.0 * (witness_norm as f64).sqrt()); + + for nonce in 1..=LABRADOR_MAX_JL_NONCE_RETRIES { + let mut nonce_transcript = transcript.clone(); + nonce_transcript.append_bytes(labels::ABSORB_LABRADOR_JL_NONCE, &nonce.to_le_bytes()); + let matrix = LabradorJlMatrix::generate::(&mut nonce_transcript, total_coeffs)?; + if let Some(proj) = project_streaming::(&matrix, ¢ered_witness, total_coeffs) { + if proj.iter().any(|&p| p.unsigned_abs() >= component_bound) { + continue; + } + let proj_norm: u128 = proj.iter().fold(0u128, |acc, &p| { + acc + p.unsigned_abs() as u128 * p.unsigned_abs() as u128 + }); + if proj_norm > norm_bound { + continue; + } + *transcript = nonce_transcript; + return Ok((proj, nonce, matrix)); + } + } + Err(HachiError::InvalidInput(format!( + "failed JL projection nonce search after {LABRADOR_MAX_JL_NONCE_RETRIES} attempts" + ))) +} + +fn next_power_of_two_u64(x: f64) -> u64 { + if x <= 1.0 { + return 1; + } + let bits = x.log2().ceil() as u32; + if bits >= 64 { + return u64::MAX; + } + 1u64 << bits +} + +/// Collapse a JL projection with challenge coefficients. +/// +/// Returns the linear target value `sum_i alpha[i] * projection[i]`. 
pub fn collapse(projection: &[i64; 256], alpha: &[i64; 256]) -> i64 {
    // Accumulate in i128, then saturate to the i64 range via clamp.
    projection
        .iter()
        .zip(alpha.iter())
        .fold(0i128, |acc, (&p, &a)| acc + (p as i128) * (a as i128))
        .clamp(i64::MIN as i128, i64::MAX as i128) as i64
}

/// Zero out a polynomial constant term for proof transmission.
///
/// Returns the modified polynomial and the removed constant term.
pub fn zero_constant_term_for_proof(
    mut poly: CyclotomicRing,
) -> (CyclotomicRing, F) {
    let coeffs = poly.coefficients_mut();
    let c0 = coeffs[0];
    coeffs[0] = F::zero();
    (poly, c0)
}

/// Restore a polynomial constant term during verifier-side reduction.
pub fn restore_constant_term(
    mut transmitted: CyclotomicRing,
    constant: F,
) -> CyclotomicRing {
    transmitted.coefficients_mut()[0] = constant;
    transmitted
}

/// Compute the JL projection by streaming over witness coefficients without
/// materializing the full flattened vector.
#[inline]
fn project_row_i64(row: &[u8], coeffs: &[i64], cols: usize) -> Option {
    // Process 4 packed columns per byte, then the <4 leftover lanes.
    let full_bytes = cols >> 2;
    let remainder = cols & 0b11;
    let mut coeff_idx = 0usize;
    let mut acc = 0i128;

    for &byte in row.iter().take(full_bytes) {
        let pair0 = byte & 0b11;
        let pair1 = (byte >> 2) & 0b11;
        let pair2 = (byte >> 4) & 0b11;
        let pair3 = (byte >> 6) & 0b11;

        acc += (jl_pair_to_sign(pair0) as i128) * (coeffs[coeff_idx] as i128);
        acc += (jl_pair_to_sign(pair1) as i128) * (coeffs[coeff_idx + 1] as i128);
        acc += (jl_pair_to_sign(pair2) as i128) * (coeffs[coeff_idx + 2] as i128);
        acc += (jl_pair_to_sign(pair3) as i128) * (coeffs[coeff_idx + 3] as i128);
        coeff_idx += 4;
    }

    if remainder > 0 {
        let byte = row[full_bytes];
        for lane in 0..remainder {
            let pair = (byte >> (lane << 1)) & 0b11;
            acc += (jl_pair_to_sign(pair) as i128) * (coeffs[coeff_idx] as i128);
            coeff_idx += 1;
        }
    }

    // None if the row sum does not fit in i64.
    i64::try_from(acc).ok()
}

// Wide-coefficient row product. `use_checked` selects overflow-checked
// accumulation (None on overflow) when the caller's sum_abs bound says
// unchecked i128 math could wrap.
fn project_row_i128(
    row: &[u8],
    rings: &[[i128; D]],
    cols: usize,
    use_checked: bool,
) -> Option {
    let mut acc = 0i128;
    let mut col_idx = 0usize;

    for coeff_chunk in rings {
        for &value in coeff_chunk {
            let pair = jl_pair_at(row, col_idx);
            if use_checked {
                match jl_pair_to_sign(pair) {
                    -1 => acc = acc.checked_sub(value)?,
                    0 => {}
                    1 => acc = acc.checked_add(value)?,
                    _ => unreachable!(),
                }
            } else {
                acc += (jl_pair_to_sign(pair) as i128) * value;
            }
            col_idx += 1;
        }
    }
    debug_assert_eq!(col_idx, cols);
    i64::try_from(acc).ok()
}

// Shape-validates, then computes all JL_ROWS row products (parallel when the
// feature is enabled). Returns None on any shape mismatch or row overflow.
#[tracing::instrument(skip_all, name = "labrador::project_streaming")]
fn project_streaming(
    matrix: &LabradorJlMatrix,
    centered_witness: &CenteredWitness,
    total_coeffs: usize,
) -> Option<[i64; 256]> {
    if !matrix.is_well_formed()
        || matrix.cols() != total_coeffs
        || centered_witness.ring_len() * D != total_coeffs
    {
        return None;
    }
    let results: Vec> = match centered_witness {
        CenteredWitness::I64 { coeffs } => cfg_into_iter!(0..JL_ROWS)
            .map(|row_idx| project_row_i64(&matrix.packed_rows[row_idx], coeffs, total_coeffs))
            .collect(),
        CenteredWitness::I128 { rings, sum_abs } => {
            // Checked arithmetic only when the absolute mass could overflow.
            let use_checked = *sum_abs > i128::MAX as u128;
            cfg_into_iter!(0..JL_ROWS)
                .map(|row_idx| {
                    project_row_i128(
                        &matrix.packed_rows[row_idx],
                        rings,
                        total_coeffs,
                        use_checked,
                    )
                })
                .collect()
        }
    };
    let mut out = [0i64; 256];
    for (i, val) in results.into_iter().enumerate() {
        out[i] = val?;
    }
    Some(out)
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::algebra::fields::{Fp64, Prime128M13M4P0};
    use crate::protocol::transcript::labels::DOMAIN_LABRADOR_RECURSION;
    use crate::protocol::transcript::Blake2bTranscript;
    use crate::FromSmallInt;

    type F = Fp64<4294967197>;
    type F128 = Prime128M13M4P0;
    const D: usize = 64;

    // Deterministic pseudo-random witness with small coefficients in [-5, 5].
    fn sample_witness_from_seed_generic(seed: u64) -> LabradorWitness
    where
        G: FieldCore + CanonicalField + FromSmallInt,
    {
        let num_rows = 2 + (seed % 3) as usize;
        let row_len = 3 + (seed % 5) as usize;
        let rows: Vec>> = (0..num_rows)
            .map(|r| {
                (0..row_len)
                    .map(|i| {
                        CyclotomicRing::from_coefficients(std::array::from_fn(|j| {
                            let mix = seed
                                .wrapping_mul(6364136223846793005)
                                .wrapping_add(r as u64 * 997 + i as u64 * 31 + j as u64);
                            G::from_i64(((mix % 11) as i64) - 5)
                        }))
                    })
                    .collect()
            })
            .collect();
        LabradorWitness::new(rows)
    }

    fn sample_witness_from_seed(seed: u64) -> LabradorWitness {
        sample_witness_from_seed_generic::(seed)
    }

    // Reference squared Euclidean norm over centered coefficients.
    fn witness_squared_norm(
        witness: &LabradorWitness,
    ) -> i128 {
        let q = (-G::one()).to_canonical_u128() + 1;
        let half_q = q / 2;
        let mut norm_sq = 0i128;
        for row in witness.rows() {
            for ring in row {
                for coeff in ring.coefficients() {
                    let c = coeff.to_canonical_u128();
                    let v = if c > half_q {
                        -((q - c) as i128)
                    } else {
                        c as i128
                    };
                    norm_sq += v * v;
                }
            }
        }
        norm_sq
    }

    #[test]
    fn project_is_deterministic_and_replayable() {
        // Same witness + same transcript domain => identical projection/nonce.
        let witness = sample_witness_from_seed(42);
        let mut t1 = Blake2bTranscript::::new(DOMAIN_LABRADOR_RECURSION);
        let mut t2 = Blake2bTranscript::::new(DOMAIN_LABRADOR_RECURSION);
        let (p1, n1, _) = project(&witness, &mut t1).unwrap();
        let (p2, n2, _) = project(&witness, &mut t2).unwrap();
        assert_eq!(p1, p2);
        assert_eq!(n1, n2);
    }

    #[test]
    fn project_fp128_is_deterministic_and_replayable() {
        // Same determinism property over the 128-bit field.
        let witness = sample_witness_from_seed_generic::(42);
        let mut t1 = Blake2bTranscript::::new(DOMAIN_LABRADOR_RECURSION);
        let mut t2 = Blake2bTranscript::::new(DOMAIN_LABRADOR_RECURSION);
        let (p1, n1, _) = project(&witness, &mut t1).unwrap();
        let (p2, n2, _) = project(&witness, &mut t2).unwrap();
        assert_eq!(p1, p2);
        assert_eq!(n1, n2);
    }

    #[test]
    fn project_streaming_handles_fp128_centered_values_beyond_i64() {
        // Force the I128 centering path with a magnitude just above q/2.
        let q = (-F128::one()).to_canonical_u128() + 1;
        let large = q / 2 + 17;
        let ring = CyclotomicRing::::from_coefficients(std::array::from_fn(|idx| {
            if idx == 0 || idx == 1 {
                F128::from_canonical_u128_reduced(large)
            } else {
                F128::zero()
            }
        }));
        let witness = LabradorWitness::new(vec![vec![ring]]);
        let centered = center_witness_by_ring(&witness).unwrap();
        let centered_abs = match &centered {
            CenteredWitness::I64 { coeffs, .. } => coeffs[0].unsigned_abs() as u128,
            CenteredWitness::I128 { rings, .. } => rings[0][0].unsigned_abs(),
        };
        assert!(centered_abs > i64::MAX as u128);

        // Row 0 computes coeff[0] - coeff[1]; equal large values cancel to 0.
        let signs: Vec> = (0..JL_ROWS)
            .map(|row_idx| {
                let mut row = vec![0i8; D];
                if row_idx == 0 {
                    row[0] = 1;
                    row[1] = -1;
                }
                row
            })
            .collect();
        let matrix = LabradorJlMatrix::from_sign_rows(signs).unwrap();
        let projection = project_streaming::(&matrix, &centered, D).unwrap();
        assert_eq!(projection[0], 0);
        assert!(projection.iter().skip(1).all(|&v| v == 0));
    }

    #[test]
    fn packed_matrix_roundtrips_manual_signs() {
        // Pack then decode every entry; must round-trip exactly.
        let signs: Vec> = (0..JL_ROWS)
            .map(|row_idx| {
                (0..7)
                    .map(|col_idx| match (row_idx + col_idx) % 3 {
                        0 => -1,
                        1 => 0,
                        _ => 1,
                    })
                    .collect()
            })
            .collect();
        let matrix = LabradorJlMatrix::from_sign_rows(signs.clone()).unwrap();
        for (row_idx, row) in signs.iter().enumerate() {
            for (col_idx, &sign) in row.iter().enumerate() {
                assert_eq!(matrix.sign_at(row_idx, col_idx), Some(sign));
            }
        }
    }

    #[test]
    fn project_norm_bound_over_multiple_witnesses() {
        // Accepted projections must satisfy the sqrt(128 * beta) entry bound.
        for seed in 1..=10u64 {
            let witness = sample_witness_from_seed(seed);
            let mut transcript = Blake2bTranscript::::new(DOMAIN_LABRADOR_RECURSION);
            let (projection, nonce, _) = project(&witness, &mut transcript).unwrap();

            let beta = witness_squared_norm(&witness);
            let p_norm_sq: i128 = projection.iter().map(|&v| (v as i128) * (v as i128)).sum();
            let p_inf: i128 = projection.iter().map(|&v| (v as i128).abs()).max().unwrap();
            let entry_bound = ((128.0 * beta as f64).sqrt()) as i128;

            tracing::debug!(
                seed,
                nonce,
                p_norm_sq,
                p_inf,
                entry_bound,
                beta,
                "JL projection check"
            );
            assert!(
                p_inf <= entry_bound,
                "seed={seed}: ||p||_inf={p_inf} exceeds sqrt(128β)={entry_bound}"
            );
        }
    }

    #[test]
    fn collapse_matches_dot_product() {
        // For small inputs the clamped i128 path equals the plain i64 dot.
        let projection = std::array::from_fn(|i| i as i64 - 10);
        let alpha = std::array::from_fn(|i| (2 * i as i64) - 7);
        let got = collapse(&projection, &alpha);
        let expected = projection
            .iter()
            .zip(alpha.iter())
            .fold(0i64, |acc, (&p, &a)| acc + p * a);
        assert_eq!(got, expected);
    }

    #[test]
    fn lift_zero_and_restore_constant_term() {
        // zero_constant_term_for_proof / restore_constant_term round-trip.
        let poly: CyclotomicRing =
            CyclotomicRing::from_coefficients(std::array::from_fn(|i| F::from_i64(i as i64 - 5)));
        let (tx, c0) = zero_constant_term_for_proof(poly);
        assert!(tx.coefficients()[0].is_zero());
        let restored = restore_constant_term(tx, c0);
        assert_eq!(restored, poly);
    }
}
diff --git a/src/protocol/labrador/mod.rs
new file mode 100644
index 00000000..bd22c16f
--- /dev/null
+++ b/src/protocol/labrador/mod.rs
@@ -0,0 +1,38 @@
//! Labrador recursive proof sub-protocol.
//!
//! This module hosts the Labrador recursive proof sub-protocol used by Hachi's
//! handoff path.
+ +pub mod aggregation; +pub mod challenge; +pub mod comkey; +pub mod commit; +pub mod config; +mod constraints; +pub mod fold; +pub mod guardrails; +pub mod johnson_lindenstrauss; +pub mod prover; +pub mod setup; +pub mod transcript; +pub mod types; +pub mod utils; +pub mod verifier; + +pub use comkey::{derive_labrador_comkey_seed, LabradorComKeySeed}; +pub use commit::{commit_linear_only, LabradorCommitmentArtifacts}; +pub use config::{ + estimate_module_sis_euclidean, plan_fold, plan_handoff, select_config, select_config_with_mode, + sis_secure, LabradorFoldPlan, SisEuclideanLatticeEstimate, +}; +pub use constraints::{LabradorConstraint, LabradorConstraintTerm}; +pub use fold::{prove_level, LabradorFoldResult}; +pub use johnson_lindenstrauss::{ + collapse, project, restore_constant_term, zero_constant_term_for_proof, LabradorJlMatrix, +}; +pub use prover::{prove, prove_with_plan}; +pub use setup::LabradorSetup; +pub use types::{ + LabradorLevelProof, LabradorProof, LabradorReductionConfig, LabradorStatement, LabradorWitness, +}; +pub use verifier::{verify, LabradorVerifyResult}; diff --git a/src/protocol/labrador/prover.rs b/src/protocol/labrador/prover.rs new file mode 100644 index 00000000..0afa1b9a --- /dev/null +++ b/src/protocol/labrador/prover.rs @@ -0,0 +1,387 @@ +//! Labrador prover loop. 
+ +use crate::error::HachiError; +use crate::primitives::serialization::Compress; +use crate::protocol::labrador::comkey::LabradorComKeySeed; +use crate::protocol::labrador::config::{ + estimate_fold_step, estimate_selected_fold_step, logq_bits, plan_fold, trivial_plan, + LabradorFoldPlan, +}; +use crate::protocol::labrador::fold::prove_level; +use crate::protocol::labrador::guardrails::LABRADOR_MAX_LEVELS; +use crate::protocol::labrador::setup::LabradorSetup; +use crate::protocol::labrador::types::{LabradorProof, LabradorStatement, LabradorWitness}; +use crate::protocol::labrador::LabradorReductionConfig; +use crate::protocol::proof::FlatLabradorWitness; +use crate::protocol::transcript::Transcript; +use crate::{CanonicalField, FieldCore, FieldSampling, FromSmallInt, HachiSerialize}; +use std::sync::Arc; + +/// Build a recursive Labrador proof with optional tail acceptance. +/// +/// Standard levels are applied while witness size decreases. Tail mode is then +/// attempted once and accepted only if total `(proof + witness)` size improves. +/// +/// # Errors +/// +/// Returns an error if folding fails or if recursion limits are exceeded. 
+#[tracing::instrument(skip_all, name = "labrador::prove")] +pub fn prove( + initial_witness: LabradorWitness, + initial_statement: &LabradorStatement, + comkey_seed: &LabradorComKeySeed, + transcript: &mut T, +) -> Result, HachiError> +where + F: FieldCore + CanonicalField + FieldSampling + FromSmallInt + HachiSerialize, + T: Transcript, +{ + if initial_witness.rows().is_empty() { + return Err(HachiError::InvalidInput( + "cannot prove with empty Labrador witness".to_string(), + )); + } + + let mut levels = Vec::new(); + let mut witness = initial_witness; + let mut _statement = initial_statement.clone(); + let mut level_idx = 0usize; + + while level_idx + 1 < LABRADOR_MAX_LEVELS { + let before_bytes = witness_size_bytes::(&witness); + if before_bytes == 0 || witness.rows().len() <= 1 { + break; + } + + let estimate = estimate_fold_step::(&witness, false)?; + if estimate.transition_bytes >= before_bytes { + break; + } + let plan = estimate.plan; + let cfg = plan.config; + let virtual_row_count: usize = plan.row_split_counts.iter().sum(); + let setup = Arc::new(LabradorSetup::new( + &cfg, + virtual_row_count, + plan.virtual_row_len, + comkey_seed, + )); + let fold = prove_level( + &witness, + &_statement, + &cfg, + &plan, + &setup, + level_idx, + transcript, + )?; + levels.push(fold.level_proof); + _statement = fold.statement; + witness = fold.next_witness; + level_idx += 1; + } + + if level_idx + 1 < LABRADOR_MAX_LEVELS { + let baseline_bytes = witness_size_bytes::(&witness); + let tail_estimate = estimate_fold_step::(&witness, true)?; + if tail_estimate.transition_bytes >= baseline_bytes { + return Ok(LabradorProof { + levels, + final_opening_witness: witness, + }); + } + let tail_plan = tail_estimate.plan; + let tail_cfg = tail_plan.config; + + let virtual_row_count: usize = tail_plan.row_split_counts.iter().sum(); + let tail_setup = Arc::new(LabradorSetup::new( + &tail_cfg, + virtual_row_count, + tail_plan.virtual_row_len, + comkey_seed, + )); + let mut 
tail_transcript = transcript.clone(); + if let Ok(tail) = prove_level( + &witness, + &_statement, + &tail_cfg, + &tail_plan, + &tail_setup, + level_idx, + &mut tail_transcript, + ) { + levels.push(tail.level_proof); + _statement = tail.statement; + witness = tail.next_witness; + *transcript = tail_transcript; + } + } + + Ok(LabradorProof { + levels, + final_opening_witness: witness, + }) +} + +/// Build a recursive Labrador proof using a caller-supplied initial plan. +/// +/// The initial plan is used for the first fold level. Later levels fall back to +/// the last accepted config if `plan_fold` fails. +/// +/// # Errors +/// +/// Returns [`HachiError`] if any fold level fails (e.g. empty witness, +/// invalid config, or transcript errors). +/// +/// # Panics +/// +/// Panics if estimating a trivial follow-on fold unexpectedly fails while +/// proving a previously accepted recursive step. +#[tracing::instrument(skip_all, name = "labrador::prove_with_plan")] +pub fn prove_with_plan( + initial_witness: LabradorWitness, + initial_statement: &LabradorStatement, + initial_plan: &LabradorFoldPlan, + comkey_seed: &LabradorComKeySeed, + transcript: &mut T, +) -> Result, HachiError> +where + F: FieldCore + CanonicalField + FieldSampling + FromSmallInt + HachiSerialize, + T: Transcript, +{ + if initial_witness.rows().is_empty() { + return Err(HachiError::InvalidInput( + "cannot prove with empty Labrador witness".to_string(), + )); + } + + let mut levels = Vec::new(); + let mut witness = initial_witness; + let mut statement = initial_statement.clone(); + let mut level_idx = 0usize; + let mut fallback_cfg = initial_plan.config; + let initial_row_lengths: Vec = witness.rows().iter().map(|row| row.len()).collect(); + let initial_ring_elems: usize = initial_row_lengths.iter().sum(); + let initial_witness_bytes = witness_size_bytes::(&witness); + tracing::debug!( + ?initial_row_lengths, + total_ring_elems = initial_ring_elems, + witness_bytes = initial_witness_bytes, + 
serialized_bytes = initial_witness_bytes, + virtual_row_len = initial_plan.virtual_row_len, + row_split_counts = ?initial_plan.row_split_counts, + witness_digit_parts = initial_plan.config.witness_digit_parts, + witness_digit_bits = initial_plan.config.witness_digit_bits, + aux_digit_parts = initial_plan.config.aux_digit_parts, + aux_digit_bits = initial_plan.config.aux_digit_bits, + inner_commit_rank = initial_plan.config.inner_commit_rank, + outer_commit_rank = initial_plan.config.outer_commit_rank, + tail = initial_plan.config.tail, + "labrador initial witness" + ); + + while level_idx + 1 < LABRADOR_MAX_LEVELS { + let before_bytes = witness_size_bytes::(&witness); + if before_bytes == 0 || witness.rows().len() <= 1 { + break; + } + + let estimate = if level_idx == 0 { + estimate_selected_fold_step::(&witness, initial_plan)? + } else { + estimate_fold_step::(&witness, false).unwrap_or_else(|_| { + let row_lengths: Vec = witness.rows().iter().map(|r| r.len()).collect(); + let plan = trivial_plan(fallback_cfg, &row_lengths); + estimate_selected_fold_step::(&witness, &plan) + .expect("trivial fold estimate must succeed") + }) + }; + if estimate.transition_bytes >= before_bytes { + break; + } + let plan = estimate.plan; + let cfg = plan.config; + let virtual_row_count: usize = plan.row_split_counts.iter().sum(); + let setup = Arc::new(LabradorSetup::new( + &cfg, + virtual_row_count, + plan.virtual_row_len, + comkey_seed, + )); + + let mut attempt_transcript = transcript.clone(); + let fold = prove_level( + &witness, + &statement, + &cfg, + &plan, + &setup, + level_idx, + &mut attempt_transcript, + )?; + tracing::debug!( + current_bytes = before_bytes, + estimated_level_bytes = estimate.level_payload_bytes, + estimated_next_witness_bytes = estimate.next_witness_bytes, + estimated_candidate_bytes = estimate.transition_bytes, + accept = estimate.transition_bytes < before_bytes, + virtual_row_len = plan.virtual_row_len, + virtual_row_count, + row_split_counts = 
?plan.row_split_counts, + witness_digit_parts = cfg.witness_digit_parts, + witness_digit_bits = cfg.witness_digit_bits, + aux_digit_parts = cfg.aux_digit_parts, + aux_digit_bits = cfg.aux_digit_bits, + inner_commit_rank = cfg.inner_commit_rank, + outer_commit_rank = cfg.outer_commit_rank, + tail = cfg.tail, + "labrador non-tail candidate" + ); + + *transcript = attempt_transcript; + levels.push(fold.level_proof); + statement = fold.statement; + witness = fold.next_witness; + fallback_cfg = cfg; + level_idx += 1; + } + + if level_idx + 1 < LABRADOR_MAX_LEVELS { + let tail_plan = plan_fold::(&witness, true).unwrap_or_else(|_| { + let row_lengths: Vec = witness.rows().iter().map(|r| r.len()).collect(); + trivial_plan( + LabradorReductionConfig { + tail: true, + outer_commit_rank: 0, + aux_digit_parts: 1, + aux_digit_bits: logq_bits::(), + ..fallback_cfg + }, + &row_lengths, + ) + }); + let baseline_bytes = witness_size_bytes::(&witness); + let tail_estimate = estimate_fold_step::(&witness, true).unwrap_or_else(|_| { + estimate_selected_fold_step::(&witness, &tail_plan) + .expect("tail trivial estimate must succeed") + }); + if tail_estimate.transition_bytes >= baseline_bytes { + return Ok(LabradorProof { + levels, + final_opening_witness: witness, + }); + } + let tail_cfg = tail_plan.config; + + let virtual_row_count: usize = tail_plan.row_split_counts.iter().sum(); + let tail_setup = Arc::new(LabradorSetup::new( + &tail_cfg, + virtual_row_count, + tail_plan.virtual_row_len, + comkey_seed, + )); + let mut tail_transcript = transcript.clone(); + if let Ok(tail) = prove_level( + &witness, + &statement, + &tail_cfg, + &tail_plan, + &tail_setup, + level_idx, + &mut tail_transcript, + ) { + tracing::debug!( + baseline_bytes, + estimated_level_bytes = tail_estimate.level_payload_bytes, + estimated_next_witness_bytes = tail_estimate.next_witness_bytes, + estimated_candidate_bytes = tail_estimate.transition_bytes, + accept = tail_estimate.transition_bytes < baseline_bytes, + 
virtual_row_len = tail_plan.virtual_row_len, + virtual_row_count, + row_split_counts = ?tail_plan.row_split_counts, + witness_digit_parts = tail_cfg.witness_digit_parts, + witness_digit_bits = tail_cfg.witness_digit_bits, + aux_digit_parts = tail_cfg.aux_digit_parts, + aux_digit_bits = tail_cfg.aux_digit_bits, + inner_commit_rank = tail_cfg.inner_commit_rank, + outer_commit_rank = tail_cfg.outer_commit_rank, + tail = tail_cfg.tail, + "labrador final tail compare" + ); + levels.push(tail.level_proof); + witness = tail.next_witness; + *transcript = tail_transcript; + } + } + + Ok(LabradorProof { + levels, + final_opening_witness: witness, + }) +} + +fn witness_size_bytes( + witness: &LabradorWitness, +) -> usize { + FlatLabradorWitness::from_typed(witness).serialized_size(Compress::No) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::algebra::fields::Fp64; + use crate::algebra::ring::CyclotomicRing; + use crate::protocol::labrador::{verify, LabradorStatement}; + use crate::protocol::transcript::labels::DOMAIN_LABRADOR_RECURSION; + use crate::protocol::transcript::Blake2bTranscript; + use crate::FromSmallInt; + + type F = Fp64<4294967197>; + const D: usize = 64; + + fn sample_witness() -> LabradorWitness { + let row = |len: usize| -> Vec> { + (0..len) + .map(|i| { + CyclotomicRing::from_coefficients(std::array::from_fn(|j| { + F::from_i64(((i + j) as i64 % 7) - 3) + })) + }) + .collect() + }; + LabradorWitness::new(vec![row(6), row(6), row(6)]) + } + + #[test] + fn prover_loop_returns_final_opening_witness() { + let statement = LabradorStatement { + inner_opening_payload: Vec::new(), + linear_garbage_payload: Vec::new(), + challenges: Vec::new(), + constraints: Vec::new(), + reduced_constraints: None, + witness_norm_bound_sq: 1024, + }; + let mut transcript = Blake2bTranscript::::new(DOMAIN_LABRADOR_RECURSION); + let proof = prove(sample_witness(), &statement, &[1u8; 32], &mut transcript).unwrap(); + assert!(!proof.final_opening_witness.rows().is_empty()); 
+ assert!(proof.levels.len() <= LABRADOR_MAX_LEVELS); + } + + #[test] + fn prover_proof_verifies() { + let statement = LabradorStatement { + inner_opening_payload: Vec::new(), + linear_garbage_payload: Vec::new(), + challenges: Vec::new(), + constraints: Vec::new(), + reduced_constraints: None, + witness_norm_bound_sq: 1 << 30, + }; + let mut transcript = Blake2bTranscript::::new(DOMAIN_LABRADOR_RECURSION); + let proof = prove(sample_witness(), &statement, &[1u8; 32], &mut transcript).unwrap(); + + let mut verify_transcript = Blake2bTranscript::::new(DOMAIN_LABRADOR_RECURSION); + verify(&statement, &proof, &[1u8; 32], &mut verify_transcript).unwrap(); + } +} diff --git a/src/protocol/labrador/setup.rs b/src/protocol/labrador/setup.rs new file mode 100644 index 00000000..cda020a6 --- /dev/null +++ b/src/protocol/labrador/setup.rs @@ -0,0 +1,233 @@ +//! Labrador commitment key setup. + +use crate::algebra::ring::CyclotomicRing; +use crate::protocol::commitment::utils::crt_ntt::{build_ntt_slot, NttSlotCache}; +use crate::protocol::commitment::utils::flat_matrix::FlatMatrix; +use crate::protocol::labrador::comkey::{derive_extendable_comkey_matrix, LabradorComKeySeed}; +use crate::protocol::labrador::commit::OUTER_NTT_LOG_BASIS; +use crate::protocol::labrador::types::LabradorReductionConfig; +use crate::{CanonicalField, FieldCore, FieldSampling}; +use std::sync::Arc; + +/// Matrix-only Labrador setup shared by prover and verifier recursion. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct LabradorSetupMatrices { + /// Inner commitment matrix A. + pub a_mat: Vec>>, + /// Outer commitment matrix B. Not needed for the last fold proof. + pub b_mat: Vec>>, + /// Linear-garbage commitment matrix D. Not needed for the last fold proof. + pub d_mat: Vec>>, +} + +impl LabradorSetupMatrices { + /// Derive the commitment-key matrices for a single Labrador level. 
+ #[tracing::instrument(skip_all, name = "labrador::setup_matrices")] + pub fn new( + config: &LabradorReductionConfig, + num_witness_rows: usize, + max_witness_len: usize, + comkey_seed: &LabradorComKeySeed, + ) -> Self { + let a_mat = derive_extendable_comkey_matrix::( + config.inner_commit_rank, + max_witness_len, + comkey_seed, + b"labrador/comkey/A", + ); + + let (b_mat, d_mat) = if config.outer_commit_rank > 0 && !config.tail { + let inner_opening_digits_len = + num_witness_rows * config.inner_commit_rank * config.aux_digit_parts; + let linear_garbage_digits_len = + num_witness_rows * (num_witness_rows + 1) / 2 * config.aux_digit_parts; + + let b = derive_extendable_comkey_matrix::( + config.outer_commit_rank, + inner_opening_digits_len, + comkey_seed, + b"labrador/comkey/B", + ); + let d = derive_extendable_comkey_matrix::( + config.outer_commit_rank, + linear_garbage_digits_len, + comkey_seed, + b"labrador/comkey/U2", + ); + (b, d) + } else { + (Vec::new(), Vec::new()) + }; + + Self { + a_mat, + b_mat, + d_mat, + } + } +} + +#[inline] +fn pow2_field(exp: u32) -> F { + let two = F::one() + F::one(); + let mut acc = F::one(); + for _ in 0..exp { + acc = acc * two; + } + acc +} + +#[inline] +fn max_linear_garbage_ntt_levels(config: &LabradorReductionConfig) -> usize { + if config.aux_digit_parts == 0 || config.aux_digit_bits == 0 { + return 0; + } + let modulus = (-F::one()).to_canonical_u128() + 1; + let field_bits = (u128::BITS - modulus.leading_zeros()) as usize; + let aux_bits = config.aux_digit_bits; + let carry_shift = aux_bits.saturating_mul(config.aux_digit_parts.saturating_sub(1)); + let carry_bits = field_bits.saturating_sub(carry_shift).max(1); + let max_digit_bits = aux_bits.max(carry_bits); + max_digit_bits.div_ceil(OUTER_NTT_LOG_BASIS as usize) + 1 +} + +/// Pre-derived commitment-key matrices for one Labrador level. 
+#[derive(Debug, Clone, PartialEq, Eq)] +pub struct LabradorSetup { + /// Shared matrix payload for prover and verifier-side recursion. + pub matrices: Arc>, + /// Precomputed NTT caches for D scaled by `2^{k*OUTER_NTT_LOG_BASIS}`. + /// + /// Index `k` corresponds to level `k` in i8-basis decomposition of + /// linear-garbage digits. + pub ntt_d_scaled_levels: Vec>, +} + +impl LabradorSetup { + /// Derive all commitment-key matrices for a single Labrador level. + #[tracing::instrument(skip_all, name = "labrador::setup")] + pub fn new( + config: &LabradorReductionConfig, + num_witness_rows: usize, + max_witness_len: usize, + comkey_seed: &LabradorComKeySeed, + ) -> Self { + let matrices = Arc::new(LabradorSetupMatrices::new( + config, + num_witness_rows, + max_witness_len, + comkey_seed, + )); + let ntt_d_scaled_levels = if matrices.d_mat.is_empty() { + Vec::new() + } else { + let max_levels = max_linear_garbage_ntt_levels::(config); + let mut slots = Vec::with_capacity(max_levels); + let scale_step = pow2_field::(OUTER_NTT_LOG_BASIS); + let mut scale = F::one(); + for _ in 0..max_levels { + let scaled_d: Vec>> = matrices + .d_mat + .iter() + .map(|row| row.iter().map(|entry| entry.scale(&scale)).collect()) + .collect(); + let scaled_d_flat = FlatMatrix::from_ring_matrix(&scaled_d); + match build_ntt_slot(scaled_d_flat.view::()) { + Ok(slot) => slots.push(slot), + Err(err) => { + tracing::debug!( + error = %err, + "failed to precompute Labrador D-matrix scaled NTT caches; using runtime fallback" + ); + slots.clear(); + break; + } + } + scale = scale * scale_step; + } + slots + }; + Self { + matrices, + ntt_d_scaled_levels, + } + } + + /// Return the matrix-only setup used by verifier-side recursion. 
+ pub fn verifier_setup(&self) -> Arc> { + Arc::clone(&self.matrices) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::algebra::fields::Fp64; + use crate::protocol::labrador::types::LabradorReductionConfig; + + type F = Fp64<4294967197>; + const D: usize = 64; + const SEED: [u8; 32] = [7u8; 32]; + + const NUM_ROWS: usize = 5; + const MAX_LEN: usize = 12; + + fn standard_config() -> LabradorReductionConfig { + LabradorReductionConfig { + witness_digit_parts: 2, + witness_digit_bits: 8, + aux_digit_parts: 3, + aux_digit_bits: 10, + inner_commit_rank: 4, + outer_commit_rank: 3, + tail: false, + } + } + + fn tail_config() -> LabradorReductionConfig { + LabradorReductionConfig { + tail: true, + outer_commit_rank: 0, + ..standard_config() + } + } + + #[test] + fn standard_setup_matrix_dimensions() { + let cfg = standard_config(); + let setup = LabradorSetup::::new(&cfg, NUM_ROWS, MAX_LEN, &SEED); + + assert_eq!(setup.matrices.a_mat.len(), cfg.inner_commit_rank); + assert!(setup.matrices.a_mat.iter().all(|row| row.len() == MAX_LEN)); + + let inner_opening_digits_len = NUM_ROWS * cfg.inner_commit_rank * cfg.aux_digit_parts; + assert_eq!(setup.matrices.b_mat.len(), cfg.outer_commit_rank); + assert!(setup + .matrices + .b_mat + .iter() + .all(|row| row.len() == inner_opening_digits_len)); + + let linear_garbage_digits_len = NUM_ROWS * (NUM_ROWS + 1) / 2 * cfg.aux_digit_parts; + assert_eq!(setup.matrices.d_mat.len(), cfg.outer_commit_rank); + assert!(setup + .matrices + .d_mat + .iter() + .all(|row| row.len() == linear_garbage_digits_len)); + assert!(!setup.ntt_d_scaled_levels.is_empty()); + } + + #[test] + fn tail_setup_has_empty_outer_matrices() { + let cfg = tail_config(); + let setup = LabradorSetup::::new(&cfg, NUM_ROWS, MAX_LEN, &SEED); + + assert_eq!(setup.matrices.a_mat.len(), cfg.inner_commit_rank); + assert!(setup.matrices.a_mat.iter().all(|row| row.len() == MAX_LEN)); + + assert!(setup.matrices.b_mat.is_empty()); + 
assert!(setup.matrices.d_mat.is_empty()); + assert!(setup.ntt_d_scaled_levels.is_empty()); + } +} diff --git a/src/protocol/labrador/transcript.rs b/src/protocol/labrador/transcript.rs new file mode 100644 index 00000000..ca612269 --- /dev/null +++ b/src/protocol/labrador/transcript.rs @@ -0,0 +1,377 @@ +//! Canonical transcript schedule helpers for Greyhound/Labrador. +//! +//! These helpers centralize byte-level encoding for prover/verifier replay: +//! dimension binding and nonce encoding. + +use crate::algebra::ring::CyclotomicRing; +use crate::error::HachiError; +use crate::protocol::labrador::guardrails::checked_usize_to_u64; +use crate::protocol::transcript::labels; +use crate::protocol::transcript::Transcript; +use crate::{CanonicalField, FieldCore, HachiSerialize}; + +/// Greyhound evaluation transcript context. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct GreyhoundEvalTranscriptContext { + /// Matrix rows for reshaped witness. + pub m_rows: usize, + /// Matrix columns for reshaped witness. + pub n_cols: usize, + /// Number of "inner" multilinear variables. + pub inner_vars: usize, + /// Length of the evaluation point vector. + pub eval_point_len: usize, +} + +/// Labrador level transcript context. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct LabradorLevelTranscriptContext { + /// Zero-based recursion level index. + pub level_index: usize, + /// Whether this level is in tail mode. + pub tail: bool, + /// Input witness row lengths (`n[i]` in the C reference). + pub input_row_lengths: Vec, + /// Witness decomposition parts (formerly `f`). + pub witness_digit_parts: usize, + /// Witness decomposition basis log2 (formerly `b`). + pub witness_digit_bits: usize, + /// Auxiliary decomposition parts (formerly `fu`). + pub aux_digit_parts: usize, + /// Auxiliary decomposition basis log2 (formerly `bu`). + pub aux_digit_bits: usize, + /// Inner commitment rank (formerly `kappa`). 
+ pub inner_commit_rank: usize, + /// Outer commitment rank (formerly `kappa1`). + pub outer_commit_rank: usize, +} + +fn append_u64_le(buf: &mut Vec, value: u64) { + buf.extend_from_slice(&value.to_le_bytes()); +} + +fn encode_usize_slice(buf: &mut Vec, values: &[usize]) -> Result<(), HachiError> { + append_u64_le(buf, checked_usize_to_u64(values.len(), "slice length")?); + for &v in values { + append_u64_le(buf, checked_usize_to_u64(v, "slice element")?); + } + Ok(()) +} + +fn encode_greyhound_eval_context( + ctx: &GreyhoundEvalTranscriptContext, +) -> Result, HachiError> { + let mut bytes = Vec::with_capacity(2 + 8 * 4); + // Versioned payload for deterministic replay stability. + bytes.push(1u8); + bytes.push(0u8); // backend id removed + append_u64_le(&mut bytes, checked_usize_to_u64(ctx.m_rows, "m_rows")?); + append_u64_le(&mut bytes, checked_usize_to_u64(ctx.n_cols, "n_cols")?); + append_u64_le( + &mut bytes, + checked_usize_to_u64(ctx.inner_vars, "inner_vars")?, + ); + append_u64_le( + &mut bytes, + checked_usize_to_u64(ctx.eval_point_len, "eval_point_len")?, + ); + Ok(bytes) +} + +fn encode_labrador_level_context( + ctx: &LabradorLevelTranscriptContext, +) -> Result, HachiError> { + let mut bytes = Vec::with_capacity(4 + 8 * (7 + ctx.input_row_lengths.len())); + // Versioned payload for deterministic replay stability. 
+ bytes.push(1u8); + bytes.push(u8::from(ctx.tail)); + bytes.push(0u8); // backend id removed + bytes.push(0u8); // reserved + append_u64_le( + &mut bytes, + checked_usize_to_u64(ctx.level_index, "level_index")?, + ); + append_u64_le( + &mut bytes, + checked_usize_to_u64(ctx.witness_digit_parts, "witness_digit_parts")?, + ); + append_u64_le( + &mut bytes, + checked_usize_to_u64(ctx.witness_digit_bits, "witness_digit_bits")?, + ); + append_u64_le( + &mut bytes, + checked_usize_to_u64(ctx.aux_digit_parts, "aux_digit_parts")?, + ); + append_u64_le( + &mut bytes, + checked_usize_to_u64(ctx.aux_digit_bits, "aux_digit_bits")?, + ); + append_u64_le( + &mut bytes, + checked_usize_to_u64(ctx.inner_commit_rank, "inner_commit_rank")?, + ); + append_u64_le( + &mut bytes, + checked_usize_to_u64(ctx.outer_commit_rank, "outer_commit_rank")?, + ); + encode_usize_slice(&mut bytes, &ctx.input_row_lengths)?; + Ok(bytes) +} + +/// Absorb canonical Greyhound evaluation context bytes. +/// +/// # Errors +/// +/// Returns an error if any dimension does not fit in `u64`. +pub fn absorb_greyhound_eval_context( + transcript: &mut T, + ctx: &GreyhoundEvalTranscriptContext, +) -> Result<(), HachiError> +where + F: FieldCore + CanonicalField, + T: Transcript, +{ + let bytes = encode_greyhound_eval_context(ctx)?; + transcript.append_bytes(labels::ABSORB_GREYHOUND_EVAL_CONTEXT, &bytes); + Ok(()) +} + +/// Absorb canonical Greyhound evaluation claim bytes (`r` and ring-valued `v`). +/// +/// Absorbs each coordinate of the evaluation point, then all D coefficients of +/// the ring-valued evaluation target. 
+pub fn absorb_greyhound_eval_claim( + transcript: &mut T, + eval_point: &[F], + eval_target: &CyclotomicRing, +) where + F: FieldCore + CanonicalField, + T: Transcript, +{ + for coord in eval_point { + transcript.append_field(labels::ABSORB_GREYHOUND_EVAL_POINT, coord); + } + for coeff in eval_target.coefficients() { + transcript.append_field(labels::ABSORB_GREYHOUND_EVAL_VALUE, coeff); + } +} + +/// Absorb Greyhound commitment payload `u2`. +pub fn absorb_greyhound_u2(transcript: &mut T, u2: &S) +where + F: FieldCore + CanonicalField, + T: Transcript, + S: HachiSerialize, +{ + transcript.append_serde(labels::ABSORB_GREYHOUND_U2, u2); +} + +/// Sample a Greyhound fold challenge. +pub fn sample_greyhound_fold_challenge(transcript: &mut T) -> F +where + F: FieldCore + CanonicalField, + T: Transcript, +{ + transcript.challenge_scalar(labels::CHALLENGE_GREYHOUND_FOLD) +} + +/// Absorb canonical Labrador level context bytes. +/// +/// # Errors +/// +/// Returns an error if any dimension does not fit in `u64`. +#[tracing::instrument(skip_all, name = "labrador::absorb_level_context")] +pub fn absorb_labrador_level_context( + transcript: &mut T, + ctx: &LabradorLevelTranscriptContext, +) -> Result<(), HachiError> +where + F: FieldCore + CanonicalField, + T: Transcript, +{ + let bytes = encode_labrador_level_context(ctx)?; + transcript.append_bytes(labels::ABSORB_LABRADOR_RECURSION_CONTEXT, &bytes); + Ok(()) +} + +/// Absorb Labrador JL projection vector bytes (`i64` little-endian). +#[tracing::instrument(skip_all, name = "labrador::absorb_jl_projection")] +pub fn absorb_labrador_jl_projection(transcript: &mut T, projection: &[i64; 256]) +where + F: FieldCore + CanonicalField, + T: Transcript, +{ + let mut bytes = Vec::with_capacity(256 * std::mem::size_of::()); + for coeff in projection { + bytes.extend_from_slice(&coeff.to_le_bytes()); + } + transcript.append_bytes(labels::ABSORB_LABRADOR_JL_PROJECTION, &bytes); +} + +/// Absorb Labrador JL nonce (`u64` little-endian). 
+pub fn absorb_labrador_jl_nonce(transcript: &mut T, nonce: u64) +where + F: FieldCore + CanonicalField, + T: Transcript, +{ + transcript.append_bytes(labels::ABSORB_LABRADOR_JL_NONCE, &nonce.to_le_bytes()); +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::algebra::fields::Fp64; + use crate::protocol::transcript::Blake2bTranscript; + use crate::FromSmallInt; + + type F = Fp64<4294967197>; + const D: usize = 64; + + fn scalar_ring(s: F) -> CyclotomicRing { + CyclotomicRing::from_coefficients(std::array::from_fn( + |i| { + if i == 0 { + s + } else { + F::zero() + } + }, + )) + } + + // Fixed test nonces for deterministic replay. + const TEST_NONCE_LOW: u64 = 1; + const TEST_NONCE_HIGH: u64 = 2; + const TEST_NONCE_REPLAY: u64 = 42; + + #[test] + fn greyhound_context_replay_is_deterministic() { + let ctx = GreyhoundEvalTranscriptContext { + m_rows: 64, + n_cols: 128, + inner_vars: 6, + eval_point_len: 13, + }; + let eval_point: Vec = (0..13).map(|i| F::from_u64((i + 3) as u64)).collect(); + let eval_target = scalar_ring(F::from_u64(77)); + let u2 = vec![F::from_u64(9), F::from_u64(11), F::from_u64(13)]; + + let mut t1 = Blake2bTranscript::::new(labels::DOMAIN_GREYHOUND_EVAL); + absorb_greyhound_eval_context::(&mut t1, &ctx).unwrap(); + absorb_greyhound_eval_claim::(&mut t1, &eval_point, &eval_target); + absorb_greyhound_u2::(&mut t1, &u2); + let c1 = sample_greyhound_fold_challenge::(&mut t1); + + let mut t2 = Blake2bTranscript::::new(labels::DOMAIN_GREYHOUND_EVAL); + absorb_greyhound_eval_context::(&mut t2, &ctx).unwrap(); + absorb_greyhound_eval_claim::(&mut t2, &eval_point, &eval_target); + absorb_greyhound_u2::(&mut t2, &u2); + let c2 = sample_greyhound_fold_challenge::(&mut t2); + + assert_eq!(c1, c2, "same transcript schedule must replay identically"); + } + + #[test] + fn greyhound_context_binds_dimensions() { + let eval_point: Vec = (0..10).map(|i| F::from_u64((i + 5) as u64)).collect(); + let eval_target = scalar_ring(F::from_u64(17)); + let u2 = 
vec![F::from_u64(1), F::from_u64(2)]; + + let mut t1 = Blake2bTranscript::::new(labels::DOMAIN_GREYHOUND_EVAL); + absorb_greyhound_eval_context::( + &mut t1, + &GreyhoundEvalTranscriptContext { + m_rows: 32, + n_cols: 32, + inner_vars: 5, + eval_point_len: 10, + }, + ) + .unwrap(); + absorb_greyhound_eval_claim::(&mut t1, &eval_point, &eval_target); + absorb_greyhound_u2::(&mut t1, &u2); + let c1 = sample_greyhound_fold_challenge::(&mut t1); + + let mut t2 = Blake2bTranscript::::new(labels::DOMAIN_GREYHOUND_EVAL); + absorb_greyhound_eval_context::( + &mut t2, + &GreyhoundEvalTranscriptContext { + m_rows: 32, + n_cols: 64, // dimension changed + inner_vars: 5, + eval_point_len: 10, + }, + ) + .unwrap(); + absorb_greyhound_eval_claim::(&mut t2, &eval_point, &eval_target); + absorb_greyhound_u2::(&mut t2, &u2); + let c2 = sample_greyhound_fold_challenge::(&mut t2); + + assert_ne!( + c1, c2, + "dimension changes must affect transcript challenges" + ); + } + + #[test] + fn labrador_context_and_nonce_replay_is_deterministic() { + let ctx = LabradorLevelTranscriptContext { + level_index: 2, + tail: false, + input_row_lengths: vec![1024, 2048, 128, 64], + witness_digit_parts: 2, + witness_digit_bits: 8, + aux_digit_parts: 3, + aux_digit_bits: 10, + inner_commit_rank: 12, + outer_commit_rank: 6, + }; + let projection = std::array::from_fn(|i| i as i64 - 127); + let nonce = TEST_NONCE_REPLAY; + + let mut t1 = Blake2bTranscript::::new(labels::DOMAIN_LABRADOR_RECURSION); + absorb_labrador_level_context::(&mut t1, &ctx).unwrap(); + absorb_labrador_jl_projection::(&mut t1, &projection); + absorb_labrador_jl_nonce::(&mut t1, nonce); + let c1 = t1.challenge_scalar(labels::CHALLENGE_LABRADOR_AGGREGATION); + + let mut t2 = Blake2bTranscript::::new(labels::DOMAIN_LABRADOR_RECURSION); + absorb_labrador_level_context::(&mut t2, &ctx).unwrap(); + absorb_labrador_jl_projection::(&mut t2, &projection); + absorb_labrador_jl_nonce::(&mut t2, nonce); + let c2 = 
t2.challenge_scalar(labels::CHALLENGE_LABRADOR_AGGREGATION); + + assert_eq!(c1, c2, "identical schedule must be replay deterministic"); + } + + #[test] + fn labrador_nonce_binding_changes_challenge() { + let ctx = LabradorLevelTranscriptContext { + level_index: 0, + tail: true, + input_row_lengths: vec![64, 32], + witness_digit_parts: 1, + witness_digit_bits: 8, + aux_digit_parts: 2, + aux_digit_bits: 10, + inner_commit_rank: 4, + outer_commit_rank: 0, + }; + let projection = [0i64; 256]; + + let mut t1 = Blake2bTranscript::::new(labels::DOMAIN_LABRADOR_RECURSION); + absorb_labrador_level_context::(&mut t1, &ctx).unwrap(); + absorb_labrador_jl_projection::(&mut t1, &projection); + absorb_labrador_jl_nonce::(&mut t1, TEST_NONCE_LOW); + let c1 = t1.challenge_scalar(labels::CHALLENGE_LABRADOR_AGGREGATION); + + let mut t2 = Blake2bTranscript::::new(labels::DOMAIN_LABRADOR_RECURSION); + absorb_labrador_level_context::(&mut t2, &ctx).unwrap(); + absorb_labrador_jl_projection::(&mut t2, &projection); + absorb_labrador_jl_nonce::(&mut t2, TEST_NONCE_HIGH); + let c2 = t2.challenge_scalar(labels::CHALLENGE_LABRADOR_AGGREGATION); + + assert_ne!(c1, c2, "nonce must be transcript-binding"); + } +} diff --git a/src/protocol/labrador/types.rs b/src/protocol/labrador/types.rs new file mode 100644 index 00000000..cf4fb065 --- /dev/null +++ b/src/protocol/labrador/types.rs @@ -0,0 +1,206 @@ +//! Core Labrador witness/statement/proof types. + +use crate::algebra::ring::CyclotomicRing; +use crate::algebra::SparseChallenge; +#[cfg(feature = "parallel")] +use crate::parallel::*; +use crate::protocol::labrador::constraints::LabradorConstraint; +use crate::protocol::labrador::setup::LabradorSetupMatrices; +use crate::{cfg_fold_reduce, CanonicalField, FieldCore}; +use std::sync::Arc; + +/// Witness object for a Labrador statement, holding the `s_i` row vectors. 
+#[derive(Debug, Clone, PartialEq, Eq, Default)] +pub struct LabradorWitness { + rows: Vec>>, +} + +impl LabradorWitness { + /// Build a witness from row vectors, all of which must share the same length. + /// + /// # Panics + /// + /// Panics if any two rows differ in length. + pub fn new(rows: Vec>>) -> Self { + if let Some(first_len) = rows.first().map(|r| r.len()) { + assert!( + rows.iter().all(|r| r.len() == first_len), + "all witness rows must have the same length" + ); + } + Self { rows } + } + + /// Build a witness without asserting uniform row length. + /// + /// Use only where the protocol produces rows of mixed length + /// (e.g. z-decomposition rows plus an auxiliary row). + pub(crate) fn new_unchecked(rows: Vec>>) -> Self { + Self { rows } + } + + /// Borrow the underlying row slices. + pub fn rows(&self) -> &[Vec>] { + &self.rows + } +} + +impl LabradorWitness { + /// Squared coefficient norm summed over every ring element in the witness. + pub fn norm(&self) -> u128 { + cfg_fold_reduce!( + (0..self.rows.len()), + || 0u128, + |acc, i| { + let row_sum = self.rows[i] + .iter() + .map(|ring| ring.coeff_norm_sq()) + .fold(0u128, |a, v| a.saturating_add(v)); + acc.saturating_add(row_sum) + }, + |a, b| a.saturating_add(b) + ) + } +} + +/// Compact recipe for the next-level Labrador statement. +/// +/// This keeps the dominant recursive structure factored so the next level can +/// aggregate it directly without first materializing a full sparse constraint +/// vector. Explicit constraints are only reconstructed when they are actually +/// needed (for example, at terminal verification). +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct LabradorReducedConstraintPlan { + /// Number of virtual input rows reduced at the previous level. + pub row_count: usize, + /// Length of each decomposed z-row in the next witness. + pub max_len: usize, + /// Reduction parameters that define the next witness layout. 
+ pub config: LabradorReductionConfig, + /// Amortization challenges from the previous level. + pub challenges: Vec, + /// Amortized `sum_i c_i * phi_i` relation carried into the next level + /// (formerly `combined_phi`). + pub amortized_phi: Vec>, + /// Aggregated right-hand side for the diagonal relation + /// (formerly `b_total`). + pub aggregated_rhs: CyclotomicRing, + /// Commitment matrices needed to replay the reduced statement. + pub setup: Arc>, +} + +/// Public statement reduced to Labrador recursion. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct LabradorStatement { + /// Opening-side payload for the current round (formerly `u1`). + /// + /// This is an outer commitment in standard rounds and the raw opening-side + /// digits in tail mode. + pub inner_opening_payload: Vec>, + /// Linear-garbage-side payload for the current round (formerly `u2`). + /// + /// This is an outer commitment in standard rounds and the raw + /// linear-garbage digits in tail mode. + pub linear_garbage_payload: Vec>, + /// Amortization challenges (per input witness row). + pub challenges: Vec, + /// Sparse constraints checked by reducer/verifier. + pub constraints: Vec>, + /// Compact recursive statement representation used between Labrador levels. + pub reduced_constraints: Option>>, + /// Squared witness norm bound (formerly `beta_sq`). + pub witness_norm_bound_sq: u128, +} + +/// Per-level reduction parameters. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct LabradorReductionConfig { + /// Number of witness-side digit parts (formerly `f`). + pub witness_digit_parts: usize, + /// Bit width of each witness-side digit (formerly `b`). + pub witness_digit_bits: usize, + /// Number of auxiliary digit parts (formerly `fu`). + pub aux_digit_parts: usize, + /// Bit width of each auxiliary digit (formerly `bu`). + pub aux_digit_bits: usize, + /// Inner commitment rank (formerly `kappa`). 
+ pub inner_commit_rank: usize, + /// Outer commitment rank (formerly `kappa1`, `0` in tail mode). + pub outer_commit_rank: usize, + /// Tail-mode marker. + pub tail: bool, +} + +/// One recursive Labrador level proof payload. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct LabradorLevelProof { + /// Whether this level uses tail semantics. + pub tail: bool, + /// Input row lengths (`n[i]` in C). + pub input_row_lengths: Vec, + /// Configuration selected for this level. + pub config: LabradorReductionConfig, + /// Virtual row length after reshaping (formerly `nn`). + pub virtual_row_len: usize, + /// Per-original-row split counts from the fold plan (formerly `nu`). + pub row_split_counts: Vec, + /// Opening-side payload for this level (formerly `u1`). + pub inner_opening_payload: Vec>, + /// Linear-garbage-side payload for this level (formerly `u2`). + pub linear_garbage_payload: Vec>, + /// JL projection vector. + pub jl_projection: [i64; 256], + /// JL nonce used to regenerate projection matrix. + pub jl_nonce: u64, + /// JL lift residuals with constant term zeroed in the proof + /// (formerly `bb`). + pub jl_lift_residuals: Vec>, + /// Output witness norm bound after reduction (formerly `norm_sq`). + pub next_witness_norm_sq: u128, +} + +/// Full recursive Labrador proof plus final clear opening witness. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct LabradorProof { + /// Recursive level payloads. + pub levels: Vec>, + /// Final clear witness opened at recursion termination. + pub final_opening_witness: LabradorWitness, +} + +impl LabradorLevelProof { + /// Serialized size of this level in bytes. 
+ pub fn size(&self) -> usize { + let ring_bytes = std::mem::size_of::>(); + let ring_count = self.inner_opening_payload.len() + + self.linear_garbage_payload.len() + + self.jl_lift_residuals.len(); + ring_count * ring_bytes + + self.jl_projection.len() * std::mem::size_of::() + + std::mem::size_of::() // jl_nonce + + std::mem::size_of::() // next_witness_norm_sq + } +} + +impl LabradorProof { + /// Construct an empty proof (used when Labrador is disabled). + pub fn empty() -> Self { + Self { + levels: Vec::new(), + final_opening_witness: LabradorWitness { rows: Vec::new() }, + } + } + + /// Total serialized size of the proof in bytes. + pub fn size(&self) -> usize { + let ring_bytes = std::mem::size_of::>(); + let levels_size: usize = self.levels.iter().map(|l| l.size()).sum(); + let witness_rings: usize = self + .final_opening_witness + .rows + .iter() + .map(|r| r.len()) + .sum(); + levels_size + witness_rings * ring_bytes + } +} diff --git a/src/protocol/labrador/utils.rs b/src/protocol/labrador/utils.rs new file mode 100644 index 00000000..5a5230f7 --- /dev/null +++ b/src/protocol/labrador/utils.rs @@ -0,0 +1,40 @@ +//! Shared utility helpers for the Labrador sub-protocol. 
+ +use crate::algebra::ring::CyclotomicRing; +#[cfg(feature = "parallel")] +use crate::parallel::*; +use crate::protocol::commitment::utils::linear::try_centered_i8_cache_from_ring_coeffs; +use crate::{CanonicalField, FieldCore, FromSmallInt}; + +pub(crate) fn mat_vec_mul( + mat: &[Vec>], + vec: &[CyclotomicRing], +) -> Vec> { + cfg_iter!(mat) + .map(|row| { + debug_assert_eq!(row.len(), vec.len()); + let mut acc = CyclotomicRing::::zero(); + for (a, x) in row.iter().zip(vec.iter()) { + acc += *a * *x; + } + acc + }) + .collect() +} + +pub(crate) fn try_centered_i8_rows( + rows: &[Vec>], +) -> Option>> { + rows.iter() + .map(|row| try_centered_i8_cache_from_ring_coeffs(row)) + .collect() +} + +pub(crate) fn pow2_field(exp: usize) -> F { + let two = F::from_u64(2); + let mut acc = F::one(); + for _ in 0..exp { + acc = acc * two; + } + acc +} diff --git a/src/protocol/labrador/verifier.rs b/src/protocol/labrador/verifier.rs new file mode 100644 index 00000000..0d4a389e --- /dev/null +++ b/src/protocol/labrador/verifier.rs @@ -0,0 +1,1559 @@ +//! Labrador verifier/reducer loop. 
+ +use crate::algebra::ring::CyclotomicRing; +use crate::algebra::SparseChallenge; +use crate::error::HachiError; +#[cfg(feature = "parallel")] +use crate::parallel::*; +use crate::protocol::commitment::utils::linear::mat_vec_mul_crt_ntt_i8_many; +use crate::protocol::labrador::aggregation::{ + aggregate_jl_constraints_verifier, aggregate_statement, safe_to_use_scalar_randomness, +}; +use crate::protocol::labrador::comkey::LabradorComKeySeed; +use crate::protocol::labrador::constraints::{ + materialize_reduced_constraints, pair_index, LabradorConstraint, NextWitnessLayout, +}; +use crate::protocol::labrador::guardrails::LABRADOR_MAX_LEVELS; +use crate::protocol::labrador::johnson_lindenstrauss::LabradorJlMatrix; +use crate::protocol::labrador::setup::LabradorSetupMatrices; +use crate::protocol::labrador::transcript::{ + absorb_labrador_jl_projection, absorb_labrador_level_context, LabradorLevelTranscriptContext, +}; +use crate::protocol::labrador::types::{ + LabradorLevelProof, LabradorProof, LabradorReducedConstraintPlan, LabradorStatement, + LabradorWitness, +}; +use crate::protocol::labrador::utils::{mat_vec_mul, pow2_field, try_centered_i8_rows}; +use crate::protocol::transcript::labels; +use crate::protocol::transcript::{ + challenge_ring_element, challenge_sparse_ring_elements_rejection_sampled, Transcript, +}; +use crate::{CanonicalField, FieldCore, FieldSampling, FromSmallInt}; +use std::sync::Arc; + +/// Output of verifier-side Labrador reduction. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct LabradorVerifyResult { + /// Statement after replaying all reduction levels. + pub terminal_statement: LabradorStatement, + /// Final clear opening witness from the proof payload. + pub final_opening_witness: LabradorWitness, +} + +/// Verify Labrador proof and return terminal reduction state. +/// +/// Currently supports a single Labrador level; recursive reduction is +/// intentionally deferred until the folding statement update is implemented. 
+/// +/// # Errors +/// +/// Returns [`HachiError::InvalidProof`] on structural inconsistencies, +/// norm bound violations, or constraint failures. +#[tracing::instrument(skip_all, name = "labrador::verify")] +pub fn verify( + initial_statement: &LabradorStatement, + proof: &LabradorProof, + comkey_seed: &LabradorComKeySeed, + transcript: &mut T, +) -> Result, HachiError> +where + F: FieldCore + CanonicalField + FieldSampling + FromSmallInt, + T: Transcript, +{ + if proof.levels.len() > LABRADOR_MAX_LEVELS || proof.final_opening_witness.rows().is_empty() { + return Err(HachiError::InvalidProof); + } + + if proof.levels.is_empty() { + let final_norm = proof.final_opening_witness.norm(); + if final_norm > initial_statement.witness_norm_bound_sq { + return Err(HachiError::InvalidProof); + } + let constraints = explicit_constraints(initial_statement)?; + verify_constraints(&constraints, &proof.final_opening_witness)?; + return Ok(LabradorVerifyResult { + terminal_statement: initial_statement.clone(), + final_opening_witness: proof.final_opening_witness.clone(), + }); + } + + let mut statement = initial_statement.clone(); + let last_idx = proof.levels.len() - 1; + for (idx, level) in proof.levels.iter().enumerate() { + if level.tail { + if idx != last_idx { + return Err(HachiError::InvalidProof); + } + verify_tail_level( + &statement, + level, + &proof.final_opening_witness, + comkey_seed, + transcript, + idx, + )?; + return Ok(LabradorVerifyResult { + terminal_statement: statement, + final_opening_witness: proof.final_opening_witness.clone(), + }); + } + statement = reduce_statement(&statement, level, comkey_seed, transcript, idx)?; + } + + let final_norm = proof.final_opening_witness.norm(); + if final_norm > statement.witness_norm_bound_sq { + return Err(HachiError::InvalidProof); + } + let constraints = explicit_constraints(&statement)?; + verify_constraints(&constraints, &proof.final_opening_witness)?; + + Ok(LabradorVerifyResult { + terminal_statement: statement, 
+ final_opening_witness: proof.final_opening_witness.clone(), + }) +} + +#[tracing::instrument(skip_all, name = "labrador::explicit_constraints")] +fn explicit_constraints( + statement: &LabradorStatement, +) -> Result>, HachiError> +where + F: FieldCore + CanonicalField + FromSmallInt, +{ + if let Some(plan) = statement.reduced_constraints.as_deref() { + materialize_reduced_constraints( + plan, + &statement.inner_opening_payload, + &statement.linear_garbage_payload, + ) + } else { + Ok(statement.constraints.clone()) + } +} + +#[tracing::instrument( + skip_all, + name = "labrador::reduce_statement", + fields(level_index, tail = level.tail) +)] +fn reduce_statement( + statement: &LabradorStatement, + level: &LabradorLevelProof, + comkey_seed: &LabradorComKeySeed, + transcript: &mut T, + level_index: usize, +) -> Result, HachiError> +where + F: FieldCore + CanonicalField + FieldSampling + FromSmallInt, + T: Transcript, +{ + let virtual_row_count = validate_level_shape(level, false)?; + let virtual_row_len = level.virtual_row_len; + let virt_row_lengths = vec![virtual_row_len; virtual_row_count]; + + absorb_labrador_level_context( + transcript, + &LabradorLevelTranscriptContext { + level_index, + tail: level.tail, + input_row_lengths: level.input_row_lengths.clone(), + witness_digit_parts: level.config.witness_digit_parts, + witness_digit_bits: level.config.witness_digit_bits, + aux_digit_parts: level.config.aux_digit_parts, + aux_digit_bits: level.config.aux_digit_bits, + inner_commit_rank: level.config.inner_commit_rank, + outer_commit_rank: level.config.outer_commit_rank, + }, + )?; + transcript.append_serde( + labels::ABSORB_LABRADOR_INNER_OPENING_PAYLOAD, + &level.inner_opening_payload, + ); + + let total_len: usize = virt_row_lengths.iter().sum(); + let jl_cols = total_len * D; + let jl_matrix = + LabradorJlMatrix::replay_nonce_search::(transcript, level.jl_nonce, jl_cols)?; + absorb_labrador_jl_projection(transcript, &level.jl_projection); + + let (phi_jl_flat, 
b_jl) = aggregate_jl_constraints_verifier( + &virt_row_lengths, + &level.jl_projection, + &jl_matrix, + &level.jl_lift_residuals, + transcript, + )?; + let explicit_aggregation = if statement.reduced_constraints.is_none() { + Some(aggregate_statement( + statement, + &level.input_row_lengths, + transcript, + )?) + } else { + None + }; + let reduced_aggregation = statement + .reduced_constraints + .as_deref() + .map(|plan| prepare_reduced_statement_aggregation(statement, plan, transcript)) + .transpose()?; + + transcript.append_serde( + labels::ABSORB_LABRADOR_LINEAR_GARBAGE_PAYLOAD, + &level.linear_garbage_payload, + ); + let challenges = replay_amortize_challenges::(transcript, virtual_row_count)?; + tracing::debug!( + level_index, + tail = false, + ?challenges, + "labrador verifier amortize challenges" + ); + let mut amortized_phi = + if let Some((phi_stmt_orig, _statement_rhs)) = explicit_aggregation.as_ref() { + let phi_stmt = reshape_phi_verifier::( + phi_stmt_orig, + &level.input_row_lengths, + &level.row_split_counts, + virtual_row_len, + )?; + let mut phi_total = phi_stmt; + add_phi_flat_in_place(&mut phi_total, &phi_jl_flat)?; + combine_virtual_rows(&phi_total, &challenges, virtual_row_len)? + } else { + let plan = statement + .reduced_constraints + .as_deref() + .ok_or(HachiError::InvalidProof)?; + let aggregation = reduced_aggregation + .as_ref() + .ok_or(HachiError::InvalidProof)?; + let mut amortized_phi = finalize_reduced_statement_aggregation( + plan, + aggregation, + &level.input_row_lengths, + &level.row_split_counts, + virtual_row_len, + &challenges, + )?; + let amortized_phi_jl = combine_flat_rows(&phi_jl_flat, &challenges, virtual_row_len)?; + add_amortized_phi_in_place(&mut amortized_phi, &amortized_phi_jl)?; + amortized_phi + }; + let statement_rhs = if let Some((_, statement_rhs)) = explicit_aggregation { + statement_rhs + } else { + reduced_aggregation + .as_ref() + .ok_or(HachiError::InvalidProof)? 
+ .aggregated_rhs + }; + let aggregated_rhs = statement_rhs + b_jl; + + let setup = Arc::new(LabradorSetupMatrices::new( + &level.config, + virtual_row_count, + virtual_row_len, + comkey_seed, + )); + let reduced_constraints = LabradorReducedConstraintPlan { + row_count: virt_row_lengths.len(), + max_len: virtual_row_len, + config: level.config, + challenges: challenges.clone(), + amortized_phi: std::mem::take(&mut amortized_phi), + aggregated_rhs, + setup, + }; + + Ok(LabradorStatement { + inner_opening_payload: level.inner_opening_payload.clone(), + linear_garbage_payload: level.linear_garbage_payload.clone(), + challenges, + constraints: Vec::new(), + reduced_constraints: Some(Box::new(reduced_constraints)), + witness_norm_bound_sq: level.next_witness_norm_sq, + }) +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +struct ReshapeCombineSegment { + src_start: usize, + dst_start: usize, + len: usize, + challenge_idx: usize, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +struct ReducedStatementAggregationReplay { + b_alphas: Vec>, + d_alphas: Vec>, + a_alphas: Vec>, + alpha_lg: CyclotomicRing, + alpha_diag: CyclotomicRing, + aggregated_rhs: CyclotomicRing, +} + +#[inline] +fn scalar_to_ring(scalar: F) -> CyclotomicRing { + let mut coeffs = [F::zero(); D]; + coeffs[0] = scalar; + CyclotomicRing::from_coefficients(coeffs) +} + +#[tracing::instrument(skip_all, name = "labrador::prepare_reduced_statement_aggregation")] +fn prepare_reduced_statement_aggregation( + statement: &LabradorStatement, + plan: &LabradorReducedConstraintPlan, + transcript: &mut T, +) -> Result, HachiError> +where + F: FieldCore + CanonicalField + FromSmallInt, + T: Transcript, +{ + if plan.setup.a_mat.len() != plan.config.inner_commit_rank + || plan.setup.b_mat.len() != statement.inner_opening_payload.len() + || plan.setup.d_mat.len() != statement.linear_garbage_payload.len() + { + return Err(HachiError::InvalidProof); + } + + let mut aggregated_rhs = CyclotomicRing::::zero(); + + let b_alphas: 
Vec> = statement + .inner_opening_payload + .iter() + .map(|target| { + if safe_to_use_scalar_randomness::() { + let alpha_scalar = + transcript.challenge_scalar(labels::CHALLENGE_LABRADOR_AGGREGATION); + aggregated_rhs += target.scale(&alpha_scalar); + scalar_to_ring::(alpha_scalar) + } else { + let alpha = + challenge_ring_element(transcript, labels::CHALLENGE_LABRADOR_AGGREGATION); + aggregated_rhs += alpha * *target; + alpha + } + }) + .collect(); + let d_alphas: Vec> = statement + .linear_garbage_payload + .iter() + .map(|target| { + if safe_to_use_scalar_randomness::() { + let alpha_scalar = + transcript.challenge_scalar(labels::CHALLENGE_LABRADOR_AGGREGATION); + aggregated_rhs += target.scale(&alpha_scalar); + scalar_to_ring::(alpha_scalar) + } else { + let alpha = + challenge_ring_element(transcript, labels::CHALLENGE_LABRADOR_AGGREGATION); + aggregated_rhs += alpha * *target; + alpha + } + }) + .collect(); + let a_alphas = (0..plan.config.inner_commit_rank) + .map(|_| { + if safe_to_use_scalar_randomness::() { + let alpha_scalar = + transcript.challenge_scalar(labels::CHALLENGE_LABRADOR_AGGREGATION); + scalar_to_ring::(alpha_scalar) + } else { + challenge_ring_element(transcript, labels::CHALLENGE_LABRADOR_AGGREGATION) + } + }) + .collect(); + let alpha_lg = if safe_to_use_scalar_randomness::() { + let alpha_scalar = transcript.challenge_scalar(labels::CHALLENGE_LABRADOR_AGGREGATION); + scalar_to_ring::(alpha_scalar) + } else { + challenge_ring_element(transcript, labels::CHALLENGE_LABRADOR_AGGREGATION) + }; + let alpha_diag = if safe_to_use_scalar_randomness::() { + let alpha_scalar = transcript.challenge_scalar(labels::CHALLENGE_LABRADOR_AGGREGATION); + aggregated_rhs += plan.aggregated_rhs.scale(&alpha_scalar); + scalar_to_ring::(alpha_scalar) + } else { + let alpha_diag = challenge_ring_element(transcript, labels::CHALLENGE_LABRADOR_AGGREGATION); + aggregated_rhs += alpha_diag * plan.aggregated_rhs; + alpha_diag + }; + + 
Ok(ReducedStatementAggregationReplay { + b_alphas, + d_alphas, + a_alphas, + alpha_lg, + alpha_diag, + aggregated_rhs, + }) +} + +fn build_reshape_combine_plan( + row_lengths: &[usize], + row_split_counts: &[usize], + virtual_row_len: usize, + challenges: &[SparseChallenge], +) -> Result>, HachiError> { + let virtual_row_count = + validate_reshape_metadata(row_lengths, row_split_counts, virtual_row_len)?; + if challenges.len() != virtual_row_count { + return Err(HachiError::InvalidProof); + } + + let mut row_segments = vec![Vec::new(); row_lengths.len()]; + let mut group_rows = Vec::new(); + let mut challenge_cursor = 0usize; + + for (row_idx, &row_len) in row_lengths.iter().enumerate() { + group_rows.push((row_idx, row_len)); + let splits = row_split_counts[row_idx]; + if splits == 0 { + continue; + } + + let group_len: usize = group_rows.iter().map(|(_, len)| *len).sum(); + if group_len > splits * virtual_row_len { + return Err(HachiError::InvalidProof); + } + + let mut group_pos = 0usize; + for &(group_row_idx, len) in &group_rows { + let mut row_offset = 0usize; + while row_offset < len { + let challenge_idx = challenge_cursor + group_pos / virtual_row_len; + let dst_start = group_pos % virtual_row_len; + let take = (virtual_row_len - dst_start).min(len - row_offset); + row_segments[group_row_idx].push(ReshapeCombineSegment { + src_start: row_offset, + dst_start, + len: take, + challenge_idx, + }); + row_offset += take; + group_pos += take; + } + } + + challenge_cursor += splits; + group_rows.clear(); + } + + if !group_rows.is_empty() || challenge_cursor != virtual_row_count { + return Err(HachiError::InvalidProof); + } + + Ok(row_segments) +} + +fn accumulate_row_slice_into_amortized_phi( + amortized_phi: &mut [CyclotomicRing], + segments: &[ReshapeCombineSegment], + challenges: &[SparseChallenge], + row_offset: usize, + coeffs: &[CyclotomicRing], + alpha: &CyclotomicRing, +) -> Result<(), HachiError> { + let row_end = row_offset + .checked_add(coeffs.len()) + 
.ok_or(HachiError::InvalidProof)?; + let mut covered = 0usize; + + for segment in segments { + let seg_start = segment.src_start; + let seg_end = seg_start + segment.len; + let start = row_offset.max(seg_start); + let end = row_end.min(seg_end); + if start >= end { + continue; + } + + let coeff_start = start - row_offset; + let dst_start = segment.dst_start + (start - seg_start); + let weight = alpha.mul_by_sparse(&challenges[segment.challenge_idx]); + cfg_iter_mut!(amortized_phi[dst_start..dst_start + (end - start)]) + .zip(cfg_iter!(coeffs[coeff_start..coeff_start + (end - start)])) + .for_each(|(dst, src)| weight.mul_accumulate_into(src, dst)); + covered += end - start; + } + + if covered != coeffs.len() { + return Err(HachiError::InvalidProof); + } + Ok(()) +} + +fn accumulate_point_into_amortized_phi( + amortized_phi: &mut [CyclotomicRing], + segments: &[ReshapeCombineSegment], + challenges: &[SparseChallenge], + position: usize, + value: &CyclotomicRing, +) -> Result<(), HachiError> { + for segment in segments { + let seg_end = segment.src_start + segment.len; + if !(segment.src_start..seg_end).contains(&position) { + continue; + } + + let dst_idx = segment.dst_start + (position - segment.src_start); + value.mul_by_sparse_into( + &challenges[segment.challenge_idx], + &mut amortized_phi[dst_idx], + ); + return Ok(()); + } + Err(HachiError::InvalidProof) +} + +fn combine_virtual_rows( + rows: &[Vec>], + challenges: &[SparseChallenge], + virtual_row_len: usize, +) -> Result>, HachiError> { + if rows.len() != challenges.len() { + return Err(HachiError::InvalidProof); + } + + let mut combined = vec![CyclotomicRing::::zero(); virtual_row_len]; + for (row, challenge) in rows.iter().zip(challenges.iter()) { + if row.len() != virtual_row_len { + return Err(HachiError::InvalidProof); + } + for (dst, src) in combined.iter_mut().zip(row.iter()) { + src.mul_by_sparse_into(challenge, dst); + } + } + Ok(combined) +} + +fn combine_flat_rows( + rows_flat: &[CyclotomicRing], + 
challenges: &[SparseChallenge], + virtual_row_len: usize, +) -> Result>, HachiError> { + if rows_flat.len() != challenges.len() * virtual_row_len { + return Err(HachiError::InvalidProof); + } + + let mut combined = vec![CyclotomicRing::::zero(); virtual_row_len]; + for (row, challenge) in rows_flat.chunks(virtual_row_len).zip(challenges.iter()) { + for (dst, src) in combined.iter_mut().zip(row.iter()) { + src.mul_by_sparse_into(challenge, dst); + } + } + Ok(combined) +} + +fn add_amortized_phi_in_place( + dst: &mut [CyclotomicRing], + src: &[CyclotomicRing], +) -> Result<(), HachiError> { + if dst.len() != src.len() { + return Err(HachiError::InvalidProof); + } + cfg_iter_mut!(dst) + .zip(cfg_iter!(src)) + .for_each(|(dst_elem, src_elem)| *dst_elem += *src_elem); + Ok(()) +} + +#[tracing::instrument(skip_all, name = "labrador::finalize_reduced_statement_aggregation")] +fn finalize_reduced_statement_aggregation( + plan: &LabradorReducedConstraintPlan, + aggregation: &ReducedStatementAggregationReplay, + row_lengths: &[usize], + row_split_counts: &[usize], + virtual_row_len: usize, + challenges: &[SparseChallenge], +) -> Result>, HachiError> +where + F: FieldCore + CanonicalField + FromSmallInt, +{ + let layout = NextWitnessLayout::new(plan.row_count, &plan.config); + if row_lengths.len() != layout.num_rows() { + return Err(HachiError::InvalidProof); + } + if row_lengths + .iter() + .take(plan.config.witness_digit_parts) + .any(|&len| len != plan.max_len) + || row_lengths[layout.aux_row] != layout.aux_row_len() + { + return Err(HachiError::InvalidProof); + } + + let row_segments = + build_reshape_combine_plan(row_lengths, row_split_counts, virtual_row_len, challenges)?; + let aux_segments = row_segments + .get(layout.aux_row) + .ok_or(HachiError::InvalidProof)?; + let mut amortized_phi = vec![CyclotomicRing::::zero(); virtual_row_len]; + let pow_witness_bits: Vec = (0..plan.config.witness_digit_parts) + .map(|idx| pow2_field::(plan.config.witness_digit_bits * idx)) + 
.collect(); + let pow_aux_bits: Vec = (0..plan.config.aux_digit_parts) + .map(|idx| pow2_field::(plan.config.aux_digit_bits * idx)) + .collect(); + let inner_opening_start = layout.inner_opening_digits_range().start; + let linear_garbage_start = layout.linear_garbage_digits_range().start; + + for (alpha, b_row) in aggregation.b_alphas.iter().zip(plan.setup.b_mat.iter()) { + accumulate_row_slice_into_amortized_phi( + &mut amortized_phi, + aux_segments, + challenges, + inner_opening_start, + b_row, + alpha, + )?; + } + for (alpha, d_row) in aggregation.d_alphas.iter().zip(plan.setup.d_mat.iter()) { + accumulate_row_slice_into_amortized_phi( + &mut amortized_phi, + aux_segments, + challenges, + linear_garbage_start, + d_row, + alpha, + )?; + } + + for (output_idx, alpha) in aggregation.a_alphas.iter().enumerate() { + let a_row = &plan.setup.a_mat[output_idx]; + for (part_idx, &scale) in pow_witness_bits.iter().enumerate() { + let scaled_alpha = alpha.scale(&scale); + accumulate_row_slice_into_amortized_phi( + &mut amortized_phi, + &row_segments[part_idx], + challenges, + 0, + a_row, + &scaled_alpha, + )?; + } + + for (row_idx, challenge) in plan.challenges.iter().enumerate() { + let base = alpha.mul_by_sparse(challenge); + for (part_idx, &scale) in pow_aux_bits.iter().enumerate() { + let idx = inner_opening_start + + row_idx * plan.config.inner_commit_rank * plan.config.aux_digit_parts + + output_idx * plan.config.aux_digit_parts + + part_idx; + let value = -(base.scale(&scale)); + accumulate_point_into_amortized_phi( + &mut amortized_phi, + aux_segments, + challenges, + idx, + &value, + )?; + } + } + } + + for (part_idx, &scale) in pow_witness_bits.iter().enumerate() { + let scaled_alpha = aggregation.alpha_lg.scale(&scale); + accumulate_row_slice_into_amortized_phi( + &mut amortized_phi, + &row_segments[part_idx], + challenges, + 0, + &plan.amortized_phi, + &scaled_alpha, + )?; + } + for i in 0..plan.challenges.len() { + for j in i..plan.challenges.len() { + let 
base = aggregation + .alpha_lg + .mul_by_sparse(&plan.challenges[i]) + .mul_by_sparse(&plan.challenges[j]); + let pair = pair_index(i, j, plan.challenges.len()); + for (part_idx, &scale) in pow_aux_bits.iter().enumerate() { + let idx = linear_garbage_start + pair * plan.config.aux_digit_parts + part_idx; + let value = -(base.scale(&scale)); + accumulate_point_into_amortized_phi( + &mut amortized_phi, + aux_segments, + challenges, + idx, + &value, + )?; + } + } + } + + for i in 0..plan.row_count { + let pair = pair_index(i, i, plan.row_count); + for (part_idx, &scale) in pow_aux_bits.iter().enumerate() { + let idx = linear_garbage_start + pair * plan.config.aux_digit_parts + part_idx; + let value = aggregation.alpha_diag.scale(&scale); + accumulate_point_into_amortized_phi( + &mut amortized_phi, + aux_segments, + challenges, + idx, + &value, + )?; + } + } + + Ok(amortized_phi) +} + +#[tracing::instrument(skip_all, name = "labrador::reshape_phi_verifier")] +fn reshape_phi_verifier( + phi: &[Vec>], + row_lengths: &[usize], + row_split_counts: &[usize], + virtual_row_len: usize, +) -> Result>>, HachiError> { + let virtual_row_count = + validate_reshape_metadata(row_lengths, row_split_counts, virtual_row_len)?; + let mut result = Vec::new(); + let mut group: Vec> = Vec::new(); + + for (i, row) in phi.iter().enumerate() { + if i >= row_lengths.len() || row.len() != row_lengths[i] { + return Err(HachiError::InvalidProof); + } + group.extend(row.iter().copied()); + let splits = if i < row_split_counts.len() { + row_split_counts[i] + } else { + 0 + }; + if splits > 0 { + if group.len() > splits * virtual_row_len { + return Err(HachiError::InvalidProof); + } + for chunk_idx in 0..splits { + let start = chunk_idx * virtual_row_len; + let mut virtual_row = vec![CyclotomicRing::::zero(); virtual_row_len]; + for (j, val) in group.iter().enumerate().skip(start).take(virtual_row_len) { + virtual_row[j - start] = *val; + } + result.push(virtual_row); + } + group.clear(); + } + } + 
if !group.is_empty() || result.len() != virtual_row_count { + return Err(HachiError::InvalidProof); + } + Ok(result) +} + +#[tracing::instrument(skip_all, name = "labrador::replay_amortize_challenges")] +fn replay_amortize_challenges( + transcript: &mut T, + rows: usize, +) -> Result, HachiError> +where + F: FieldCore + CanonicalField + FromSmallInt, + T: Transcript, +{ + challenge_sparse_ring_elements_rejection_sampled::( + transcript, + labels::CHALLENGE_LABRADOR_AMORTIZE, + rows, + ) +} + +#[allow(clippy::too_many_lines, clippy::too_many_arguments)] +#[tracing::instrument( + skip_all, + name = "labrador::verify_tail_level", + fields(level_index, tail = level.tail) +)] +fn verify_tail_level( + statement: &LabradorStatement, + level: &LabradorLevelProof, + witness: &LabradorWitness, + comkey_seed: &LabradorComKeySeed, + transcript: &mut T, + level_index: usize, +) -> Result<(), HachiError> +where + F: FieldCore + CanonicalField + FieldSampling + FromSmallInt, + T: Transcript, +{ + let virtual_row_len = level.virtual_row_len; + let virtual_row_count = validate_level_shape(level, true)?; + + if witness.rows().len() != level.config.witness_digit_parts { + return Err(HachiError::InvalidProof); + } + for row in witness.rows().iter() { + if row.len() != virtual_row_len { + return Err(HachiError::InvalidProof); + } + } + + let inner_opening_digits_len = + virtual_row_count * level.config.inner_commit_rank * level.config.aux_digit_parts; + let linear_garbage_digits_len = + virtual_row_count * (virtual_row_count + 1) / 2 * level.config.aux_digit_parts; + if level.inner_opening_payload.len() != inner_opening_digits_len + || level.linear_garbage_payload.len() != linear_garbage_digits_len + { + return Err(HachiError::InvalidProof); + } + let inner_opening_digits = &level.inner_opening_payload; + let linear_garbage_digits = &level.linear_garbage_payload; + + absorb_labrador_level_context( + transcript, + &LabradorLevelTranscriptContext { + level_index, + tail: level.tail, + 
input_row_lengths: level.input_row_lengths.clone(), + witness_digit_parts: level.config.witness_digit_parts, + witness_digit_bits: level.config.witness_digit_bits, + aux_digit_parts: level.config.aux_digit_parts, + aux_digit_bits: level.config.aux_digit_bits, + inner_commit_rank: level.config.inner_commit_rank, + outer_commit_rank: level.config.outer_commit_rank, + }, + )?; + + transcript.append_serde( + labels::ABSORB_LABRADOR_INNER_OPENING_PAYLOAD, + &level.inner_opening_payload, + ); + + let virt_total_len = virtual_row_count * virtual_row_len; + let jl_cols = virt_total_len * D; + let jl_matrix = + LabradorJlMatrix::replay_nonce_search::(transcript, level.jl_nonce, jl_cols)?; + + absorb_labrador_jl_projection(transcript, &level.jl_projection); + + let virt_row_lengths = vec![virtual_row_len; virtual_row_count]; + let (phi_jl_flat, b_jl) = aggregate_jl_constraints_verifier( + &virt_row_lengths, + &level.jl_projection, + &jl_matrix, + &level.jl_lift_residuals, + transcript, + )?; + + let explicit_aggregation = if statement.reduced_constraints.is_none() { + Some(aggregate_statement( + statement, + &level.input_row_lengths, + transcript, + )?) 
+ } else { + None + }; + let reduced_aggregation = statement + .reduced_constraints + .as_deref() + .map(|plan| prepare_reduced_statement_aggregation(statement, plan, transcript)) + .transpose()?; + + transcript.append_serde( + labels::ABSORB_LABRADOR_LINEAR_GARBAGE_PAYLOAD, + &level.linear_garbage_payload, + ); + let challenges = replay_amortize_challenges::(transcript, virtual_row_count)?; + tracing::debug!( + level_index, + tail = true, + ?challenges, + "labrador verifier amortize challenges" + ); + let amortized_phi = if let Some((phi_stmt_orig, _)) = explicit_aggregation.as_ref() { + let phi_stmt = reshape_phi_verifier::( + phi_stmt_orig, + &level.input_row_lengths, + &level.row_split_counts, + virtual_row_len, + )?; + let mut phi_total = phi_stmt; + add_phi_flat_in_place(&mut phi_total, &phi_jl_flat)?; + combine_virtual_rows(&phi_total, &challenges, virtual_row_len)? + } else { + let plan = statement + .reduced_constraints + .as_deref() + .ok_or(HachiError::InvalidProof)?; + let aggregation = reduced_aggregation + .as_ref() + .ok_or(HachiError::InvalidProof)?; + let mut amortized_phi = finalize_reduced_statement_aggregation( + plan, + aggregation, + &level.input_row_lengths, + &level.row_split_counts, + virtual_row_len, + &challenges, + )?; + let amortized_phi_jl = combine_flat_rows(&phi_jl_flat, &challenges, virtual_row_len)?; + add_amortized_phi_in_place(&mut amortized_phi, &amortized_phi_jl)?; + amortized_phi + }; + let b_stmt = if let Some((_, b_stmt)) = explicit_aggregation { + b_stmt + } else { + reduced_aggregation + .as_ref() + .ok_or(HachiError::InvalidProof)? 
+ .aggregated_rhs + }; + let aggregated_rhs = b_stmt + b_jl; + + let (computed_norm, proj_norm) = tracing::info_span!("labrador::verify_tail_norms") + .in_scope(|| (witness.norm(), projection_norm_sq(&level.jl_projection))); + if computed_norm > level.next_witness_norm_sq { + return Err(HachiError::InvalidProof); + } + let proj_bound = 256u128.saturating_mul(statement.witness_norm_bound_sq); + if proj_norm > proj_bound { + return Err(HachiError::InvalidProof); + } + + let setup = tracing::info_span!("labrador::verify_tail_setup").in_scope(|| { + LabradorSetupMatrices::new( + &level.config, + virtual_row_count, + virtual_row_len, + comkey_seed, + ) + }); + let witness_i8 = tracing::info_span!("labrador::verify_tail_digit_cache") + .in_scope(|| try_centered_i8_rows(witness.rows())); + let (az, rhs) = tracing::info_span!("labrador::verify_tail_linear_check").in_scope( + || -> Result<_, HachiError> { + let az = mat_vec_mul_decomposed::( + &setup.a_mat, + witness.rows(), + witness_i8.as_deref(), + level.config.witness_digit_bits, + )?; + let rhs = accumulate_decomposed_t_rhs::( + inner_opening_digits, + virtual_row_count, + level.config.inner_commit_rank, + level.config.aux_digit_parts, + level.config.aux_digit_bits as u32, + &challenges, + )?; + Ok((az, rhs)) + }, + )?; + if az != rhs { + return Err(HachiError::InvalidProof); + } + + let (lhs, rhs, diag_sum) = tracing::info_span!("labrador::verify_tail_quadratic_check") + .in_scope(|| -> Result<_, HachiError> { + let lhs = decomposed_dot_product::( + &amortized_phi, + witness.rows(), + witness_i8.as_deref(), + level.config.witness_digit_bits, + )?; + let (rhs, diag_sum) = accumulate_decomposed_h_rhs::( + linear_garbage_digits, + virtual_row_count, + level.config.aux_digit_parts, + level.config.aux_digit_bits as u32, + &challenges, + )?; + Ok((lhs, rhs, diag_sum)) + })?; + if lhs != rhs { + return Err(HachiError::InvalidProof); + } + + if diag_sum - aggregated_rhs != CyclotomicRing::::zero() { + return 
Err(HachiError::InvalidProof); + } + + Ok(()) +} + +#[allow(clippy::too_many_lines)] +#[allow(dead_code)] +fn verify_single_level( + statement: &LabradorStatement, + level: &LabradorLevelProof, + witness: &LabradorWitness, + comkey_seed: &LabradorComKeySeed, + transcript: &mut T, +) -> Result<(), HachiError> +where + F: FieldCore + CanonicalField + FieldSampling + FromSmallInt, + T: Transcript, +{ + let virtual_row_len = level.virtual_row_len; + let virtual_row_count = validate_level_shape(level, false)?; + let layout = NextWitnessLayout::new(virtual_row_count, &level.config); + let expected_rows = layout.num_rows(); + if witness.rows().len() != expected_rows { + return Err(HachiError::InvalidProof); + } + for row in witness.rows().iter().take(level.config.witness_digit_parts) { + if row.len() != virtual_row_len { + return Err(HachiError::InvalidProof); + } + } + + let aux = &witness.rows()[layout.aux_row]; + if aux.len() != layout.aux_row_len() { + return Err(HachiError::InvalidProof); + } + let (inner_opening_digits, linear_garbage_digits) = + aux.split_at(layout.inner_opening_digits_len); + + absorb_labrador_level_context( + transcript, + &LabradorLevelTranscriptContext { + level_index: 0, + tail: level.tail, + input_row_lengths: level.input_row_lengths.clone(), + witness_digit_parts: level.config.witness_digit_parts, + witness_digit_bits: level.config.witness_digit_bits, + aux_digit_parts: level.config.aux_digit_parts, + aux_digit_bits: level.config.aux_digit_bits, + inner_commit_rank: level.config.inner_commit_rank, + outer_commit_rank: level.config.outer_commit_rank, + }, + )?; + transcript.append_serde( + labels::ABSORB_LABRADOR_INNER_OPENING_PAYLOAD, + &level.inner_opening_payload, + ); + + let virt_total_len = virtual_row_count * virtual_row_len; + let jl_cols = virt_total_len * D; + let jl_matrix = + LabradorJlMatrix::replay_nonce_search::(transcript, level.jl_nonce, jl_cols)?; + absorb_labrador_jl_projection(transcript, &level.jl_projection); + + let 
virt_row_lengths = vec![virtual_row_len; virtual_row_count]; + let (phi_jl_flat, b_jl) = aggregate_jl_constraints_verifier( + &virt_row_lengths, + &level.jl_projection, + &jl_matrix, + &level.jl_lift_residuals, + transcript, + )?; + let (phi_stmt_orig, b_stmt) = + aggregate_statement(statement, &level.input_row_lengths, transcript)?; + let phi_stmt = reshape_phi_verifier::( + &phi_stmt_orig, + &level.input_row_lengths, + &level.row_split_counts, + virtual_row_len, + )?; + + let mut phi_total = phi_stmt; + add_phi_flat_in_place(&mut phi_total, &phi_jl_flat)?; + let aggregated_rhs = b_stmt + b_jl; + + transcript.append_serde( + labels::ABSORB_LABRADOR_LINEAR_GARBAGE_PAYLOAD, + &level.linear_garbage_payload, + ); + let challenges = replay_amortize_challenges::(transcript, virtual_row_count)?; + + let z_parts: Vec>> = witness + .rows() + .iter() + .take(level.config.witness_digit_parts) + .cloned() + .collect(); + let z = recompose_from_parts(&z_parts, level.config.witness_digit_bits as u32)?; + + let t_flat = recompose_flat( + inner_opening_digits, + level.config.aux_digit_parts, + level.config.aux_digit_bits as u32, + )?; + let h_flat = recompose_flat( + linear_garbage_digits, + level.config.aux_digit_parts, + level.config.aux_digit_bits as u32, + )?; + if t_flat.len() != virtual_row_count * level.config.inner_commit_rank { + return Err(HachiError::InvalidProof); + } + if h_flat.len() != virtual_row_count * (virtual_row_count + 1) / 2 { + return Err(HachiError::InvalidProof); + } + let mut t_by_row = Vec::with_capacity(virtual_row_count); + for chunk in t_flat.chunks(level.config.inner_commit_rank) { + t_by_row.push(chunk.to_vec()); + } + + if !statement.inner_opening_payload.is_empty() + && statement.inner_opening_payload != level.inner_opening_payload + { + return Err(HachiError::InvalidProof); + } + if !statement.linear_garbage_payload.is_empty() + && statement.linear_garbage_payload != level.linear_garbage_payload + { + return Err(HachiError::InvalidProof); + } + 
+ let setup = LabradorSetupMatrices::new( + &level.config, + virtual_row_count, + virtual_row_len, + comkey_seed, + ); + + if level.config.outer_commit_rank > 0 { + let inner_opening_payload_check = mat_vec_mul(&setup.b_mat, inner_opening_digits); + if inner_opening_payload_check != level.inner_opening_payload { + return Err(HachiError::InvalidProof); + } + let linear_garbage_payload_check = mat_vec_mul(&setup.d_mat, linear_garbage_digits); + if linear_garbage_payload_check != level.linear_garbage_payload { + return Err(HachiError::InvalidProof); + } + } else { + if level.inner_opening_payload != inner_opening_digits { + return Err(HachiError::InvalidProof); + } + if level.linear_garbage_payload != linear_garbage_digits { + return Err(HachiError::InvalidProof); + } + } + + let computed_norm = witness.norm(); + if computed_norm > level.next_witness_norm_sq { + return Err(HachiError::InvalidProof); + } + + if projection_norm_sq(&level.jl_projection) + > 256u128.saturating_mul(statement.witness_norm_bound_sq) + { + return Err(HachiError::InvalidProof); + } + + let az = mat_vec_mul(&setup.a_mat, &z); + let mut rhs = vec![CyclotomicRing::::zero(); level.config.inner_commit_rank]; + for (i, t_row) in t_by_row.iter().enumerate() { + for k in 0..level.config.inner_commit_rank { + t_row[k].mul_by_sparse_into(&challenges[i], &mut rhs[k]); + } + } + if az != rhs { + return Err(HachiError::InvalidProof); + } + + let mut amortized_phi = vec![CyclotomicRing::::zero(); virtual_row_len]; + for (i, phi_row) in phi_total.iter().enumerate() { + for (j, elem) in phi_row.iter().enumerate() { + elem.mul_by_sparse_into(&challenges[i], &mut amortized_phi[j]); + } + } + let lhs = dot_product(&amortized_phi, &z); + let mut rhs = CyclotomicRing::::zero(); + let mut idx = 0usize; + for i in 0..virtual_row_count { + for j in i..virtual_row_count { + rhs += h_flat[idx] + .mul_by_sparse(&challenges[i]) + .mul_by_sparse(&challenges[j]); + idx += 1; + } + } + if lhs != rhs { + return 
Err(HachiError::InvalidProof); + } + + let mut diag_sum = CyclotomicRing::::zero(); + for i in 0..virtual_row_count { + let idx = pair_index(i, i, virtual_row_count); + diag_sum += h_flat[idx]; + } + if diag_sum - aggregated_rhs != CyclotomicRing::::zero() { + return Err(HachiError::InvalidProof); + } + + Ok(()) +} +fn projection_norm_sq(projection: &[i64; 256]) -> u128 { + projection.iter().fold(0u128, |acc, &v| { + let x = v as i128; + let sq = x * x; + acc.saturating_add(sq as u128) + }) +} + +#[tracing::instrument(skip_all, name = "labrador::validate_level_shape")] +fn validate_level_shape( + level: &LabradorLevelProof, + expect_tail: bool, +) -> Result { + if level.tail != expect_tail || level.config.tail != expect_tail { + return Err(HachiError::InvalidProof); + } + if level.config.witness_digit_parts == 0 || level.config.aux_digit_parts == 0 { + return Err(HachiError::InvalidProof); + } + if expect_tail { + if level.config.outer_commit_rank != 0 { + return Err(HachiError::InvalidProof); + } + } else if level.config.outer_commit_rank == 0 { + return Err(HachiError::InvalidProof); + } + validate_reshape_metadata( + &level.input_row_lengths, + &level.row_split_counts, + level.virtual_row_len, + ) +} + +fn validate_reshape_metadata( + row_lengths: &[usize], + row_split_counts: &[usize], + virtual_row_len: usize, +) -> Result { + if row_lengths.is_empty() || row_split_counts.len() != row_lengths.len() || virtual_row_len == 0 + { + return Err(HachiError::InvalidProof); + } + + let mut virtual_row_count = 0usize; + let mut grouped_len = 0usize; + for (&row_len, &splits) in row_lengths.iter().zip(row_split_counts.iter()) { + grouped_len = grouped_len + .checked_add(row_len) + .ok_or(HachiError::InvalidProof)?; + if splits > 0 { + let capacity = splits + .checked_mul(virtual_row_len) + .ok_or(HachiError::InvalidProof)?; + if grouped_len > capacity { + return Err(HachiError::InvalidProof); + } + virtual_row_count = virtual_row_count + .checked_add(splits) + 
.ok_or(HachiError::InvalidProof)?; + grouped_len = 0; + } + } + + if grouped_len != 0 || virtual_row_count == 0 { + return Err(HachiError::InvalidProof); + } + + Ok(virtual_row_count) +} + +#[tracing::instrument(skip_all, name = "labrador::mat_vec_mul_decomposed")] +fn mat_vec_mul_decomposed( + matrix: &[Vec>], + parts: &[Vec>], + parts_i8: Option<&[Vec<[i8; D]>]>, + log_basis: usize, +) -> Result>, HachiError> +where + F: FieldCore + CanonicalField + FromSmallInt, +{ + if parts.is_empty() { + return Err(HachiError::InvalidProof); + } + + if let Some(parts_i8) = parts_i8 { + if let Ok(images) = mat_vec_mul_crt_ntt_i8_many(matrix, parts_i8) { + let mut acc = vec![CyclotomicRing::::zero(); matrix.len()]; + for (part_idx, image) in images.into_iter().enumerate() { + let scale = pow2_field::(part_idx * log_basis); + for (dst, src) in acc.iter_mut().zip(image.iter()) { + *dst += src.scale(&scale); + } + } + return Ok(acc); + } + } + + let mut acc = vec![CyclotomicRing::::zero(); matrix.len()]; + for (part_idx, part) in parts.iter().enumerate() { + let image = mat_vec_mul(matrix, part); + let scale = pow2_field::(part_idx * log_basis); + for (dst, src) in acc.iter_mut().zip(image.iter()) { + *dst += src.scale(&scale); + } + } + Ok(acc) +} + +#[tracing::instrument(skip_all, name = "labrador::decomposed_dot_product")] +fn decomposed_dot_product( + lhs: &[CyclotomicRing], + parts: &[Vec>], + parts_i8: Option<&[Vec<[i8; D]>]>, + log_basis: usize, +) -> Result, HachiError> +where + F: FieldCore + CanonicalField + FromSmallInt, +{ + if parts.is_empty() { + return Err(HachiError::InvalidProof); + } + + if let Some(parts_i8) = parts_i8 { + if let Ok(images) = mat_vec_mul_crt_ntt_i8_many(&[lhs.to_vec()], parts_i8) { + let mut acc = CyclotomicRing::::zero(); + for (part_idx, image) in images.into_iter().enumerate() { + let scale = pow2_field::(part_idx * log_basis); + let value = image.into_iter().next().ok_or(HachiError::InvalidProof)?; + acc += value.scale(&scale); + } + return 
Ok(acc); + } + } + + let mut acc = CyclotomicRing::::zero(); + for (part_idx, part) in parts.iter().enumerate() { + if part.len() != lhs.len() { + return Err(HachiError::InvalidProof); + } + let scale = pow2_field::(part_idx * log_basis); + acc += dot_product(lhs, part).scale(&scale); + } + Ok(acc) +} + +fn recompose_digit_chunk( + flat: &[CyclotomicRing], + index: usize, + parts: usize, + log_basis: u32, +) -> Result, HachiError> { + let start = index.checked_mul(parts).ok_or(HachiError::InvalidProof)?; + let end = start.checked_add(parts).ok_or(HachiError::InvalidProof)?; + if end > flat.len() { + return Err(HachiError::InvalidProof); + } + Ok(CyclotomicRing::gadget_recompose_pow2( + &flat[start..end], + log_basis, + )) +} + +#[tracing::instrument(skip_all, name = "labrador::accumulate_decomposed_t_rhs")] +fn accumulate_decomposed_t_rhs( + inner_opening_digits: &[CyclotomicRing], + virtual_row_count: usize, + inner_commit_rank: usize, + parts: usize, + log_basis: u32, + challenges: &[SparseChallenge], +) -> Result>, HachiError> { + if challenges.len() != virtual_row_count + || inner_opening_digits.len() != virtual_row_count * inner_commit_rank * parts + { + return Err(HachiError::InvalidProof); + } + let mut rhs = vec![CyclotomicRing::::zero(); inner_commit_rank]; + for (row_idx, challenge) in challenges.iter().enumerate() { + for (k, rhs_k) in rhs.iter_mut().enumerate() { + let inner_opening = recompose_digit_chunk( + inner_opening_digits, + row_idx * inner_commit_rank + k, + parts, + log_basis, + )?; + inner_opening.mul_by_sparse_into(challenge, rhs_k); + } + } + Ok(rhs) +} + +#[tracing::instrument(skip_all, name = "labrador::accumulate_decomposed_h_rhs")] +fn accumulate_decomposed_h_rhs( + linear_garbage_digits: &[CyclotomicRing], + virtual_row_count: usize, + parts: usize, + log_basis: u32, + challenges: &[SparseChallenge], +) -> Result<(CyclotomicRing, CyclotomicRing), HachiError> { + let pair_count = virtual_row_count + .checked_mul(virtual_row_count + 1) + 
.and_then(|v| v.checked_div(2)) + .ok_or(HachiError::InvalidProof)?; + if challenges.len() != virtual_row_count || linear_garbage_digits.len() != pair_count * parts { + return Err(HachiError::InvalidProof); + } + let mut rhs = CyclotomicRing::::zero(); + let mut diag_sum = CyclotomicRing::::zero(); + for i in 0..virtual_row_count { + for j in i..virtual_row_count { + let idx = pair_index(i, j, virtual_row_count); + let linear_garbage = + recompose_digit_chunk(linear_garbage_digits, idx, parts, log_basis)?; + rhs += linear_garbage + .mul_by_sparse(&challenges[i]) + .mul_by_sparse(&challenges[j]); + if i == j { + diag_sum += linear_garbage; + } + } + } + Ok((rhs, diag_sum)) +} + +#[tracing::instrument(skip_all, name = "labrador::recompose_from_parts")] +fn recompose_from_parts( + parts: &[Vec>], + log_basis: u32, +) -> Result>, HachiError> { + if parts.is_empty() { + return Err(HachiError::InvalidProof); + } + let len = parts[0].len(); + for row in parts.iter().skip(1) { + if row.len() != len { + return Err(HachiError::InvalidProof); + } + } + let mut out = Vec::with_capacity(len); + for idx in 0..len { + let mut slice = Vec::with_capacity(parts.len()); + for part in parts { + slice.push(part[idx]); + } + out.push(CyclotomicRing::gadget_recompose_pow2(&slice, log_basis)); + } + Ok(out) +} + +#[tracing::instrument(skip_all, name = "labrador::recompose_flat")] +fn recompose_flat( + flat: &[CyclotomicRing], + parts: usize, + log_basis: u32, +) -> Result>, HachiError> { + if parts == 0 || flat.len() % parts != 0 { + return Err(HachiError::InvalidProof); + } + let mut out = Vec::with_capacity(flat.len() / parts); + for chunk in flat.chunks(parts) { + out.push(CyclotomicRing::gadget_recompose_pow2(chunk, log_basis)); + } + Ok(out) +} + +#[tracing::instrument(skip_all, name = "labrador::add_phi_flat_in_place_verifier")] +fn add_phi_flat_in_place( + acc: &mut [Vec>], + other_flat: &[CyclotomicRing], +) -> Result<(), HachiError> { + let mut cursor = 0usize; + for row_acc in 
acc.iter_mut() { + let end = cursor + row_acc.len(); + if end > other_flat.len() { + return Err(HachiError::InvalidProof); + } + for (a, b) in row_acc.iter_mut().zip(other_flat[cursor..end].iter()) { + *a += *b; + } + cursor = end; + } + if cursor != other_flat.len() { + return Err(HachiError::InvalidProof); + } + Ok(()) +} + +fn dot_product( + lhs: &[CyclotomicRing], + rhs: &[CyclotomicRing], +) -> CyclotomicRing { + let mut acc = CyclotomicRing::::zero(); + let len = lhs.len().min(rhs.len()); + for i in 0..len { + acc += lhs[i] * rhs[i]; + } + acc +} + +#[tracing::instrument(skip_all, name = "labrador::verify_constraints")] +fn verify_constraints( + constraints: &[LabradorConstraint], + witness: &LabradorWitness, +) -> Result<(), HachiError> { + for (idx, cnst) in constraints.iter().enumerate() { + let mut lhs = CyclotomicRing::::zero(); + + for term in &cnst.terms { + if term.row >= witness.rows().len() { + return Err(HachiError::InvalidProof); + } + let row = &witness.rows()[term.row]; + if term.offset + term.coefficients.len() > row.len() { + return Err(HachiError::InvalidProof); + } + for (j, coeff) in term.coefficients.iter().enumerate() { + lhs += *coeff * row[term.offset + j]; + } + } + + if lhs != cnst.target { + return Err(HachiError::InvalidInput(format!( + "Labrador constraint {idx} not satisfied" + ))); + } + } + Ok(()) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::algebra::fields::Fp64; + use crate::algebra::ring::CyclotomicRing; + use crate::protocol::labrador::LabradorConstraintTerm; + use crate::protocol::transcript::labels::DOMAIN_LABRADOR_RECURSION; + use crate::protocol::transcript::Blake2bTranscript; + use crate::FromSmallInt; + + type F = Fp64<4294967197>; + const D: usize = 64; + + #[test] + fn verify_accepts_basic_linear_constraint() { + let row = vec![CyclotomicRing::::from_coefficients( + std::array::from_fn(|i| if i == 0 { F::from_i64(3) } else { F::zero() }), + )]; + let witness = LabradorWitness::new(vec![row.clone()]); 
+ let coeff = vec![CyclotomicRing::one()]; + let target = CyclotomicRing::::from_coefficients(std::array::from_fn(|i| { + if i == 0 { + F::from_i64(3) + } else { + F::zero() + } + })); + let constraint = + LabradorConstraint::new(vec![LabradorConstraintTerm::new(0, 0, coeff)], target); + let statement = LabradorStatement { + inner_opening_payload: Vec::new(), + linear_garbage_payload: Vec::new(), + challenges: Vec::new(), + constraints: vec![constraint], + reduced_constraints: None, + witness_norm_bound_sq: 1000, + }; + let proof = LabradorProof { + levels: Vec::new(), + final_opening_witness: witness.clone(), + }; + let mut transcript = Blake2bTranscript::::new(DOMAIN_LABRADOR_RECURSION); + let out = verify(&statement, &proof, &[1u8; 32], &mut transcript).unwrap(); + assert_eq!(out.final_opening_witness, witness); + } +} diff --git a/src/protocol/labrador_handoff.rs b/src/protocol/labrador_handoff.rs new file mode 100644 index 00000000..96a102f6 --- /dev/null +++ b/src/protocol/labrador_handoff.rs @@ -0,0 +1,761 @@ +//! Direct Labrador handoff from Hachi's quadratic equation. +//! +//! Instead of computing the quotient `r`, evaluating at a random `alpha`, and +//! running sumcheck, this module converts the ring-level relation `Mz = y` +//! directly into Labrador constraints. The witness `w` is +//! `[w_hat | inner_opening_digits | z_pre]` +//! with no quotient portion. 
+ +use crate::algebra::ring::CyclotomicRing; +use crate::algebra::SparseChallenge; +use crate::error::HachiError; +use crate::primitives::poly::multilinear_lagrange_basis; +use crate::primitives::serialization::{Compress, Valid}; +use crate::protocol::commitment::transcript_append::AppendToTranscript; +use crate::protocol::commitment::utils::crt_ntt::NttSlotCache; +use crate::protocol::commitment::utils::flat_matrix::FlatMatrix; +use crate::protocol::commitment::utils::linear::flatten_i8_blocks; +use crate::protocol::commitment::{ + CommitmentConfig, HachiCommitmentLayout, HachiExpandedSetup, HachiScheduleInputs, + RingCommitment, +}; +use crate::protocol::commitment_scheme::next_level_opening_point; +use crate::protocol::hachi_poly_ops::{BalancedDigitPoly, HachiPolyOps}; +use crate::protocol::labrador::config::{ + estimate_handoff_recursive_proof, logq_bits, LabradorRecursiveSizeEstimate, +}; +use crate::protocol::labrador::types::{LabradorStatement, LabradorWitness}; +use crate::protocol::labrador::{ + prove_with_plan, verify as verify_labrador, LabradorConstraint, LabradorConstraintTerm, +}; +use crate::protocol::opening_point::{ring_opening_point_from_field, BasisMode, RingOpeningPoint}; +use crate::protocol::proof::{ + FlatLabradorProof, FlatLabradorWitness, FlatRingVec, HachiCommitmentHint, HachiProofTail, + LabradorTail, PackedDigits, +}; +use crate::protocol::quadratic_equation::QuadraticEquation; +use crate::protocol::ring_switch::WCommitmentConfig; +use crate::protocol::transcript::labels::{ABSORB_COMMITMENT, ABSORB_EVALUATION_CLAIMS}; +use crate::protocol::transcript::Transcript; +use crate::{CanonicalField, FieldCore, FieldSampling, FromSmallInt, HachiSerialize}; +use std::time::Instant; + +/// Build Labrador constraints that encode the ring-level relation `Mz = y` from +/// the Hachi quadratic equation. 
+/// +/// Witness layout (3 rows): +/// row 0: `w_hat_flat` — `depth_open * num_blocks` ring elements +/// row 1: `t_hat_flat` — `depth_open * N_A * num_blocks` ring elements +/// row 2: `z_pre_decomp` — `depth_fold * inner_width` ring elements +/// +/// Constraint rows (all ring-level, no alpha evaluation): +/// - N_D constraints: `D_mat * w_hat_flat = v` +/// - N_B constraints: `B_mat * t_hat_flat = u` +/// - 1 constraint: `b^T * G_open * w_hat = y_eval` +/// - 1 constraint: `c^T * G_open * w_hat - a^T * G_commit * J * z_pre = 0` +/// - N_A constraints: `c^T * G_open * t_hat_slice - A * J * z_pre = 0` +#[allow(clippy::too_many_arguments)] +#[tracing::instrument(skip_all, name = "labrador::handoff_build_constraints")] +pub(crate) fn build_hachi_labrador_constraints( + a_mat: &FlatMatrix, + b_mat: &FlatMatrix, + d_mat: &FlatMatrix, + opening_point: &RingOpeningPoint, + challenges: &[SparseChallenge], + v: &[CyclotomicRing], + u: &[CyclotomicRing], + y_eval: &CyclotomicRing, + layout: HachiCommitmentLayout, +) -> Result>, HachiError> +where + F: FieldCore + CanonicalField, + Cfg: CommitmentConfig, +{ + let depth_open = layout.num_digits_open; + let depth_commit = layout.num_digits_commit; + let depth_fold = layout.num_digits_fold; + let log_basis = layout.log_basis; + let num_blocks = opening_point.b.len(); + let block_len = layout.block_len; + let inner_width = block_len * depth_commit; + + let w_len = depth_open * num_blocks; + let t_len = depth_open * Cfg::N_A * num_blocks; + let z_len = depth_fold * inner_width; + + let g_open = gadget_scalars::(depth_open, log_basis); + let g_commit = gadget_scalars::(depth_commit, log_basis); + let j_fold = gadget_scalars::(depth_fold, log_basis); + + let scalar_ring = + |s: F| -> CyclotomicRing { + CyclotomicRing::from_coefficients(std::array::from_fn(|k| { + if k == 0 { + s + } else { + F::zero() + } + })) + }; + + let dense_challenges: Vec> = challenges + .iter() + .map(|c| c.to_dense::().expect("valid challenge")) + 
.collect(); + + let mut constraints = Vec::with_capacity(Cfg::N_D + Cfg::N_B + 2 + Cfg::N_A); + + // D rows enforce `D_mat * w_hat_flat = v`. + let d_view = d_mat.view::(); + for (i, &v_i) in v.iter().enumerate().take(Cfg::N_D) { + let d_row = d_view.row(i); + let coeffs: Vec> = d_row.iter().take(w_len).copied().collect(); + constraints.push(LabradorConstraint::new( + vec![LabradorConstraintTerm::new(0, 0, coeffs)], + v_i, + )); + } + + // B rows enforce `B_mat * t_hat_flat = u`. + let b_view = b_mat.view::(); + for (i, &u_i) in u.iter().enumerate().take(Cfg::N_B) { + let b_row = b_view.row(i); + let coeffs: Vec> = b_row.iter().take(t_len).copied().collect(); + constraints.push(LabradorConstraint::new( + vec![LabradorConstraintTerm::new(1, 0, coeffs)], + u_i, + )); + } + + // This row enforces the opening evaluation claim. + { + let mut phi_w = vec![CyclotomicRing::::zero(); w_len]; + for (i, &b_i) in opening_point.b.iter().enumerate() { + for (d, &g) in g_open.iter().enumerate() { + phi_w[i * depth_open + d] = scalar_ring(b_i * g); + } + } + constraints.push(LabradorConstraint::new( + vec![LabradorConstraintTerm::new(0, 0, phi_w)], + *y_eval, + )); + } + + // This row ties the folded witness to the pre-handoff `z` decomposition. 
+ { + let mut phi_w = vec![CyclotomicRing::::zero(); w_len]; + for (i, c_i) in dense_challenges.iter().enumerate() { + for (d, &g) in g_open.iter().enumerate() { + phi_w[i * depth_open + d] = c_i.scale(&g); + } + } + + let mut phi_z = vec![CyclotomicRing::::zero(); z_len]; + for (i, &a_i) in opening_point.a.iter().enumerate() { + for (d, &g) in g_commit.iter().enumerate() { + let ag = a_i * g; + for (t, &j) in j_fold.iter().enumerate() { + let idx = (i * depth_commit + d) * depth_fold + t; + phi_z[idx] = scalar_ring(-(ag * j)); + } + } + } + constraints.push(LabradorConstraint::new( + vec![ + LabradorConstraintTerm::new(0, 0, phi_w), + LabradorConstraintTerm::new(2, 0, phi_z), + ], + CyclotomicRing::::zero(), + )); + } + + // A rows link the folded inner openings back to the same `z` decomposition. + let a_view = a_mat.view::(); + for a_idx in 0..Cfg::N_A { + let mut phi_t = vec![CyclotomicRing::::zero(); t_len]; + for (i, c_i) in dense_challenges.iter().enumerate() { + for (d, &g) in g_open.iter().enumerate() { + let t_idx = i * (Cfg::N_A * depth_open) + a_idx * depth_open + d; + phi_t[t_idx] = c_i.scale(&g); + } + } + + let mut phi_z = vec![CyclotomicRing::::zero(); z_len]; + let a_row = a_view.row(a_idx); + for (k, &a_ring) in a_row.iter().take(inner_width).enumerate() { + for (t, &j) in j_fold.iter().enumerate() { + phi_z[k * depth_fold + t] = -(a_ring.scale(&j)); + } + } + + constraints.push(LabradorConstraint::new( + vec![ + LabradorConstraintTerm::new(1, 0, phi_t), + LabradorConstraintTerm::new(2, 0, phi_z), + ], + CyclotomicRing::::zero(), + )); + } + + Ok(constraints) +} + +/// Assemble the Labrador witness from the quad-eq prover state. +/// +/// Converts i8-digit planes to ring elements and decomposes `z_pre`. 
+#[tracing::instrument(skip_all, name = "labrador::handoff_build_witness")] +pub(crate) fn build_labrador_witness( + w_hat_flat: &[[i8; D]], + t_hat_flat: &[[i8; D]], + z_pre: &[CyclotomicRing], + layout: HachiCommitmentLayout, +) -> LabradorWitness +where + F: FieldCore + CanonicalField, +{ + let depth_fold = layout.num_digits_fold; + let log_basis = layout.log_basis; + + let to_ring = |digits: &[i8; D]| -> CyclotomicRing { + CyclotomicRing::from_coefficients(std::array::from_fn(|k| F::from_i64(digits[k] as i64))) + }; + + let row0: Vec> = w_hat_flat.iter().map(to_ring).collect(); + let row1: Vec> = t_hat_flat.iter().map(to_ring).collect(); + + let mut row2 = Vec::with_capacity(z_pre.len() * depth_fold); + for z_j in z_pre { + for plane in z_j.balanced_decompose_pow2_i8(depth_fold, log_basis) { + row2.push(to_ring(&plane)); + } + } + + LabradorWitness::new_unchecked(vec![row0, row1, row2]) +} + +/// Estimate the full Labrador recursive proof for the Hachi handoff witness. +#[tracing::instrument(skip_all, name = "labrador::handoff_estimate")] +pub(crate) fn hachi_labrador_estimate< + F: FieldCore + CanonicalField + HachiSerialize, + const D: usize, +>( + witness: &LabradorWitness, + coeff_bit_bound: usize, +) -> Result { + estimate_handoff_recursive_proof::(witness, coeff_bit_bound) +} + +/// Execute the Labrador direct handoff from the Hachi folding loop. +/// +/// Instead of computing the quotient `r`, evaluating at alpha, and running +/// sumcheck, this function runs the quadratic equation at D' and hands the +/// ring-level `Mz = y` directly to Labrador. +/// +/// # Errors +/// +/// Propagates errors from the quad eq, Labrador config selection, or Labrador proving. 
+#[allow(clippy::too_many_arguments)] +#[tracing::instrument(skip_all, name = "labrador::handoff_prove")] +pub(crate) fn labrador_handoff_prove( + current_w: &[i8], + current_hint: &HachiCommitmentHint, + current_commitment: &RingCommitment, + current_challenges: &[F], + current_num_u: usize, + current_num_l: usize, + expanded_setup: &HachiExpandedSetup, + ntt_d: &NttSlotCache, + transcript: &mut T, +) -> Result, HachiError> +where + F: FieldCore + CanonicalField + FieldSampling + FromSmallInt + Valid, + T: Transcript, + Cfg: CommitmentConfig, +{ + let t0 = Instant::now(); + let mut handoff_transcript = transcript.clone(); + + let opening_point = tracing::info_span!("labrador::handoff_prepare_opening_point") + .in_scope(|| next_level_opening_point(current_challenges, current_num_u, current_num_l)); + + let alpha = D_HANDOFF.trailing_zeros() as usize; + if opening_point.len() < alpha { + return Err(HachiError::InvalidPointDimension { + expected: alpha, + actual: opening_point.len(), + }); + } + + let w_layout = >::commitment_layout(opening_point.len())?; + let direct_tail = PackedDigits::from_i8_digits(current_w, w_layout.log_basis); + let direct_hachi_tail_bytes = direct_tail.serialized_size(Compress::No); + let target_num_vars = w_layout.m_vars + w_layout.r_vars + alpha; + let mut padded_point = opening_point.clone(); + padded_point.resize(target_num_vars, F::zero()); + let outer_point = &padded_point[alpha..]; + + let ring_opening_point = + tracing::info_span!("labrador::handoff_ring_opening_point").in_scope(|| { + ring_opening_point_from_field::( + outer_point, + w_layout.r_vars, + w_layout.m_vars, + BasisMode::Lagrange, + ) + })?; + + let a_flat = &expanded_setup.A; + let b_flat = &expanded_setup.B; + let d_flat = &expanded_setup.D_mat; + + let (w_poly, y_ring, w_folded) = tracing::info_span!("labrador::handoff_fold_witness") + .in_scope(|| { + let w_poly = BalancedDigitPoly::::from_i8_digits(current_w)?; + let (y_ring, w_folded) = w_poly.evaluate_and_fold( + 
&ring_opening_point.b, + &ring_opening_point.a, + w_layout.block_len, + ); + Ok::<_, HachiError>((w_poly, y_ring, w_folded)) + })?; + + tracing::info_span!("labrador::handoff_absorb_claims").in_scope(|| { + current_commitment.append_to_transcript(ABSORB_COMMITMENT, &mut handoff_transcript); + for pt in &padded_point { + handoff_transcript.append_field(ABSORB_EVALUATION_CLAIMS, pt); + } + handoff_transcript.append_serde(ABSORB_EVALUATION_CLAIMS, &y_ring); + }); + + let quad_eq = tracing::info_span!("labrador::handoff_quad_eq").in_scope(|| { + let level_params = WCommitmentConfig::::level_params(HachiScheduleInputs { + max_num_vars: padded_point.len(), + level: 1, + current_w_len: current_w.len(), + }); + Ok::<_, HachiError>(Box::new(QuadraticEquation::< + F, + D_HANDOFF, + WCommitmentConfig, + >::new_prover( + ntt_d, + ring_opening_point.clone(), + &w_poly, + w_folded, + level_params, + current_hint.clone(), + &mut handoff_transcript, + current_commitment, + &y_ring, + w_layout, + )?)) + })?; + + tracing::debug!( + elapsed_s = t0.elapsed().as_secs_f64(), + "labrador_handoff quad_eq" + ); + + let t1 = Instant::now(); + + let w_hat_flat = quad_eq + .w_hat_flat() + .ok_or_else(|| HachiError::InvalidInput("missing w_hat_flat".into()))?; + let inner_opening_digits = &quad_eq + .hint() + .ok_or_else(|| HachiError::InvalidInput("missing hint".into()))? 
+ .inner_opening_digits; + let inner_opening_digits_flat = flatten_i8_blocks(inner_opening_digits); + let z_pre = quad_eq + .z_pre() + .ok_or_else(|| HachiError::InvalidInput("missing z_pre".into()))?; + + let constraints = + build_hachi_labrador_constraints::>( + a_flat, + b_flat, + d_flat, + quad_eq.opening_point(), + &quad_eq.challenges, + &quad_eq.v, + ¤t_commitment.u, + &y_ring, + w_layout, + )?; + + let witness = build_labrador_witness(w_hat_flat, &inner_opening_digits_flat, z_pre, w_layout); + let witness_norm_bound_sq = witness.norm(); + + let estimate = hachi_labrador_estimate::(&witness, w_layout.log_basis as usize)?; + let plan = estimate.initial_plan.clone(); + let cfg = plan.config; + let handoff_row_lengths: Vec = witness.rows().iter().map(|row| row.len()).collect(); + let handoff_ring_elems: usize = handoff_row_lengths.iter().sum(); + let handoff_witness_bits = handoff_ring_elems * D_HANDOFF * logq_bits::(); + let handoff_witness_bytes = + FlatLabradorWitness::from_typed(&witness).serialized_size(Compress::No); + + let statement = LabradorStatement { + inner_opening_payload: Vec::new(), + linear_garbage_payload: Vec::new(), + challenges: Vec::new(), + constraints, + reduced_constraints: None, + witness_norm_bound_sq, + }; + + let comkey_seed = expanded_setup.labrador_comkey_seed(); + + tracing::debug!( + digits = current_w.len(), + log_basis = w_layout.log_basis, + raw_i8_bytes = current_w.len(), + packed_direct_bytes = direct_hachi_tail_bytes, + row_count = witness.rows().len(), + ?handoff_row_lengths, + total_ring_elems = handoff_ring_elems, + witness_bits = handoff_witness_bits, + serialized_bytes = handoff_witness_bytes, + witness_norm_bound_sq = %witness_norm_bound_sq, + max_row_len = handoff_row_lengths.iter().copied().max().unwrap_or(0), + virtual_row_len = plan.virtual_row_len, + row_split_counts = ?plan.row_split_counts, + witness_digit_parts = cfg.witness_digit_parts, + witness_digit_bits = cfg.witness_digit_bits, + aux_digit_parts = 
cfg.aux_digit_parts, + aux_digit_bits = cfg.aux_digit_bits, + inner_commit_rank = cfg.inner_commit_rank, + outer_commit_rank = cfg.outer_commit_rank, + tail = cfg.tail, + estimated_labrador_levels = estimate.level_count, + estimated_labrador_proof_bytes = estimate.proof_bytes, + estimated_labrador_final_witness_bytes = estimate.final_witness_bytes, + elapsed_s = t1.elapsed().as_secs_f64(), + rows = witness.rows().len(), + constraint_count = statement.constraints.len(), + "labrador_handoff witness/constraints" + ); + + let v_bytes = FlatRingVec::from_ring_elems(&quad_eq.v).serialized_size(Compress::No); + let y_ring_bytes = FlatRingVec::from_single(&y_ring).serialized_size(Compress::No); + let estimated_labrador_tail_bytes = estimate.proof_bytes + + v_bytes + + y_ring_bytes + + witness_norm_bound_sq.serialized_size(Compress::No); + tracing::info!( + packed_direct_bytes = direct_hachi_tail_bytes, + estimated_labrador_tail_bytes, + selected_tail = if estimated_labrador_tail_bytes < direct_hachi_tail_bytes { + "labrador" + } else { + "direct" + }, + estimated_labrador_proof_bytes = estimate.proof_bytes, + v_bytes, + y_ring_bytes, + witness_norm_bound_sq_bytes = witness_norm_bound_sq.serialized_size(Compress::No), + "labrador_handoff estimated tail comparison" + ); + if estimated_labrador_tail_bytes >= direct_hachi_tail_bytes { + return Ok(HachiProofTail::Direct(direct_tail)); + } + + let t2 = Instant::now(); + let labrador_proof = prove_with_plan::( + witness, + &statement, + &plan, + &comkey_seed, + &mut handoff_transcript, + )?; + #[cfg(debug_assertions)] + { + let roundtrip = FlatLabradorProof::from_typed(&labrador_proof).to_typed::(); + assert!( + roundtrip == labrador_proof, + "labrador handoff proof roundtrip must preserve the proof" + ); + + let mut self_verify_transcript = handoff_transcript.clone(); + verify_labrador::( + &statement, + &labrador_proof, + &comkey_seed, + &mut self_verify_transcript, + ) + .expect("freshly generated Labrador handoff proof must 
verify"); + } + *transcript = handoff_transcript; + + tracing::info!( + elapsed_s = t2.elapsed().as_secs_f64(), + levels = labrador_proof.levels.len(), + "labrador prove complete" + ); + + Ok(HachiProofTail::Labrador(Box::new(LabradorTail { + labrador_proof: FlatLabradorProof::from_typed(&labrador_proof), + v: FlatRingVec::from_ring_elems(&quad_eq.v), + y_ring: FlatRingVec::from_single(&y_ring), + witness_norm_bound_sq, + }))) +} + +/// Verify the direct Labrador tail of a Hachi proof. +/// +/// Replays the quadratic equation transcript operations (absorb commitment, +/// evaluation claims, v; derive challenges), rebuilds the ring-level Labrador +/// constraints, and verifies the Labrador recursive proof. +/// +/// # Errors +/// +/// Propagates errors from constraint reconstruction or Labrador verification. +#[allow(clippy::too_many_arguments)] +#[tracing::instrument(skip_all, name = "labrador::handoff_verify")] +pub(crate) fn labrador_handoff_verify( + tail: &LabradorTail, + opening_point: &[F], + opening_value: &F, + current_commitment: &RingCommitment, + expanded_setup: &HachiExpandedSetup, + transcript: &mut T, +) -> Result<(), HachiError> +where + F: FieldCore + CanonicalField + FieldSampling + FromSmallInt + Valid, + T: Transcript, + Cfg: CommitmentConfig, +{ + let t0 = Instant::now(); + let alpha_prime = D_HANDOFF.trailing_zeros() as usize; + if opening_point.len() < alpha_prime { + return Err(HachiError::InvalidPointDimension { + expected: alpha_prime, + actual: opening_point.len(), + }); + } + + let v: Vec> = tail.v.to_vec(); + let y_ring: CyclotomicRing = tail.y_ring.to_single(); + let labrador_proof = tail.labrador_proof.to_typed::(); + + if !tracing::info_span!("labrador::handoff_match_opening_claim") + .in_scope(|| matches_opening_claim::(&y_ring, opening_point, opening_value)) + { + return Err(HachiError::InvalidProof); + } + + let w_layout = >::commitment_layout(opening_point.len())?; + let target_num_vars = w_layout.m_vars + w_layout.r_vars + 
alpha_prime; + let mut padded_point = opening_point.to_vec(); + padded_point.resize(target_num_vars, F::zero()); + let outer_point = &padded_point[alpha_prime..]; + + let ring_opening_point = + tracing::info_span!("labrador::handoff_ring_opening_point").in_scope(|| { + ring_opening_point_from_field::( + outer_point, + w_layout.r_vars, + w_layout.m_vars, + BasisMode::Lagrange, + ) + })?; + + // Replay transcript against the carried Hachi commitment. + tracing::info_span!("labrador::handoff_absorb_claims").in_scope(|| { + current_commitment.append_to_transcript(ABSORB_COMMITMENT, transcript); + for pt in &padded_point { + transcript.append_field(ABSORB_EVALUATION_CLAIMS, pt); + } + transcript.append_serde(ABSORB_EVALUATION_CLAIMS, &y_ring); + }); + + // Derive challenges via verifier-side quad eq (absorbs v, samples challenges). + let quad_eq = tracing::info_span!("labrador::handoff_quad_eq").in_scope(|| { + let level_params = WCommitmentConfig::::level_params(HachiScheduleInputs { + max_num_vars: padded_point.len(), + level: 1, + current_w_len: 0, + }); + QuadraticEquation::>::new_verifier( + ring_opening_point.clone(), + v.clone(), + level_params, + transcript, + current_commitment, + &y_ring, + w_layout, + ) + })?; + + let a_flat = &expanded_setup.A; + let b_flat = &expanded_setup.B; + let d_flat = &expanded_setup.D_mat; + + // Rebuild constraints from public data. 
+ let constraints = + build_hachi_labrador_constraints::>( + a_flat, + b_flat, + d_flat, + quad_eq.opening_point(), + &quad_eq.challenges, + &v, + ¤t_commitment.u, + &y_ring, + w_layout, + )?; + + let statement = LabradorStatement { + inner_opening_payload: Vec::new(), + linear_garbage_payload: Vec::new(), + challenges: Vec::new(), + constraints, + reduced_constraints: None, + witness_norm_bound_sq: tail.witness_norm_bound_sq, + }; + + let comkey_seed = expanded_setup.labrador_comkey_seed(); + + let result = + verify_labrador::(&statement, &labrador_proof, &comkey_seed, transcript); + if result.is_ok() { + tracing::info!( + elapsed_s = t0.elapsed().as_secs_f64(), + levels = labrador_proof.levels.len(), + "labrador verify complete" + ); + } else { + tracing::error!( + elapsed_s = t0.elapsed().as_secs_f64(), + levels = labrador_proof.levels.len(), + "labrador verify FAIL" + ); + } + result?; + + Ok(()) +} + +fn gadget_scalars(levels: usize, log_basis: u32) -> Vec { + let base = F::from_canonical_u128_reduced(1u128 << log_basis); + let mut out = Vec::with_capacity(levels); + let mut power = F::one(); + for _ in 0..levels { + out.push(power); + power = power * base; + } + out +} + +fn matches_opening_claim( + y_ring: &CyclotomicRing, + opening_point: &[F], + opening_value: &F, +) -> bool { + let alpha = D.trailing_zeros() as usize; + let coeff_point = &opening_point[..alpha]; + let mut coeff_basis = vec![F::zero(); D]; + multilinear_lagrange_basis(&mut coeff_basis, coeff_point); + let inner_ring = CyclotomicRing::from_slice(&coeff_basis); + let d = F::from_u64(D as u64); + let trace_lhs = (*y_ring * inner_ring.sigma_m1()).coefficients()[0] * d; + trace_lhs == d * *opening_value +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::protocol::commitment::utils::linear::flatten_i8_blocks; + use crate::protocol::commitment::{ + HachiCommitmentCore, HachiScheduleInputs, RingCommitmentScheme, + }; + use crate::protocol::hachi_poly_ops::{DensePoly, HachiPolyOps}; + use 
crate::protocol::proof::HachiCommitmentHint; + use crate::protocol::quadratic_equation::QuadraticEquation; + use crate::protocol::transcript::Blake2bTranscript; + use crate::test_utils::*; + use crate::Transcript; + + const TRANSCRIPT_SEED: &[u8] = b"test/labrador_handoff"; + + /// Verify that the Labrador constraints built from the quad eq are + /// satisfied by the corresponding witness. + #[test] + fn constraints_satisfied_by_witness() { + let (setup, _) = + >::setup(16).unwrap(); + + let blocks = sample_blocks(); + let w = + >::commit_ring_blocks( + &blocks, &setup, + ) + .unwrap(); + + let point = RingOpeningPoint { + a: sample_a(), + b: sample_b(), + }; + + let ring_coeffs: Vec> = + blocks.iter().flat_map(|b| b.iter().copied()).collect(); + let poly = DensePoly::from_ring_coeffs(ring_coeffs); + let hint = HachiCommitmentHint::new(w.t_hat); + let mut transcript = Blake2bTranscript::::new(TRANSCRIPT_SEED); + let layout = setup.layout(); + let (y_ring, w_folded) = poly.evaluate_and_fold(&point.b, &point.a, layout.block_len); + let level_params = TinyConfig::level_params(HachiScheduleInputs { + max_num_vars: setup.expanded.seed.max_num_vars, + level: 0, + current_w_len: layout.num_blocks * layout.block_len * D, + }); + + let quad_eq = QuadraticEquation::::new_prover( + &setup.ntt_D, + point.clone(), + &poly, + w_folded, + level_params, + hint, + &mut transcript, + &w.commitment, + &y_ring, + layout, + ) + .unwrap(); + + let w_hat_flat = quad_eq.w_hat_flat().unwrap(); + let inner_opening_digits = &quad_eq.hint().unwrap().inner_opening_digits; + let inner_opening_digits_flat = flatten_i8_blocks(inner_opening_digits); + let z_pre = quad_eq.z_pre().unwrap(); + + let constraints = build_hachi_labrador_constraints::( + &setup.expanded.A, + &setup.expanded.B, + &setup.expanded.D_mat, + quad_eq.opening_point(), + &quad_eq.challenges, + &quad_eq.v, + &w.commitment.u, + &y_ring, + layout, + ) + .unwrap(); + + let witness = build_labrador_witness(w_hat_flat, 
&inner_opening_digits_flat, z_pre, layout); + + let rows = witness.rows(); + for (ci, constraint) in constraints.iter().enumerate() { + let mut lhs = CyclotomicRing::::zero(); + for term in &constraint.terms { + for (j, coeff) in term.coefficients.iter().enumerate() { + let idx = term.offset + j; + if idx < rows[term.row].len() { + lhs += *coeff * rows[term.row][idx]; + } + } + } + assert_eq!(lhs, constraint.target, "constraint {ci} not satisfied"); + } + } +} diff --git a/src/protocol/mod.rs b/src/protocol/mod.rs new file mode 100644 index 00000000..40e20eba --- /dev/null +++ b/src/protocol/mod.rs @@ -0,0 +1,43 @@ +//! Protocol-layer transcript and commitment abstractions. +//! +//! This module defines the Hachi-native protocol interfaces used by higher-level +//! proof logic. It intentionally stays independent from external integration +//! details (for example, Jolt wiring). + +pub mod challenges; +pub mod commitment; +pub mod commitment_scheme; +pub mod dispatch; +pub mod hachi_poly_ops; +pub mod labrador; +pub mod labrador_handoff; +pub mod opening_point; +pub mod prg; +pub mod proof; +pub mod quadratic_equation; +pub mod ring_switch; +pub mod sumcheck; +pub mod transcript; + +pub use commitment::{ + optimal_m_r_split, AppendToTranscript, CommitmentConfig, CommitmentScheme, DummyProof, + DynamicSmallTestCommitmentConfig, Fp128BoundedCommitmentConfig, Fp128CommitmentConfig, + Fp128FullCommitmentConfig, Fp128HalvingDCommitmentConfig, Fp128LogBasisCommitmentConfig, + Fp128OneHotCommitmentConfig, HachiCommitment, HachiCommitmentCore, HachiCommitmentLayout, + HachiExpandedSetup, HachiOpeningClaim, HachiOpeningPoint, HachiProverSetup, HachiSetupSeed, + HachiVerifierSetup, RingCommitment, RingCommitmentScheme, SmallTestCommitmentConfig, +}; +pub use commitment_scheme::HachiCommitmentScheme; +pub use hachi_poly_ops::{DensePoly, HachiPolyOps, OneHotIndex, OneHotPoly}; +pub use opening_point::{BasisMode, RingOpeningPoint}; +pub use proof::{ + FlatCommitmentHint, 
FlatLabradorLevelProof, FlatLabradorProof, FlatLabradorWitness, + FlatRingVec, HachiLevelProof, HachiProof, HachiProofTail, PackedDigits, +}; +pub use quadratic_equation::QuadraticEquation; +pub use sumcheck::batched_sumcheck::{prove_batched_sumcheck, verify_batched_sumcheck}; +pub use sumcheck::{ + prove_sumcheck, verify_sumcheck, CompressedUniPoly, SumcheckInstanceProver, + SumcheckInstanceVerifier, SumcheckProof, UniPoly, +}; +pub use transcript::{sample_ext_challenge, Blake2bTranscript, KeccakTranscript, Transcript}; diff --git a/src/protocol/opening_point.rs b/src/protocol/opening_point.rs new file mode 100644 index 00000000..812ad6b6 --- /dev/null +++ b/src/protocol/opening_point.rs @@ -0,0 +1,126 @@ +//! Ring-native opening point for the Hachi protocol. + +use crate::algebra::CyclotomicRing; +use crate::error::HachiError; +use crate::primitives::poly::multilinear_lagrange_basis; +use crate::FieldCore; + +/// Polynomial basis mode for the evaluation relation. +/// +/// Determines how the polynomial's values are interpreted during an opening +/// proof. The commitment itself is basis-agnostic; the basis only affects +/// the tensor-product weights used in `prove` and `verify`. +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub enum BasisMode { + /// Evaluations over the boolean hypercube. + /// + /// The weight vector is `⊗ᵢ (1 − xᵢ, xᵢ)` (multilinear Lagrange basis). + /// Use when the committed values are `f(b)` for `b ∈ {0,1}^n`. + Lagrange, + + /// Coefficients of multilinear monomials. + /// + /// The weight vector is `⊗ᵢ (1, xᵢ)`. + /// Use when the committed values are the coefficients `c_S` such that + /// `f(x) = Σ_S c_S · ∏_{i ∈ S} x_i`. + Monomial, +} + +/// Ring-native opening point storing field scalars. +/// +/// Contains the two vectors used by the §4.2 prover: +/// - `a`: evaluation vector of length `2^m` (inner-block coordinates). +/// - `b`: block-select vector of length `2^r` (outer coordinates). 
+/// +/// These are raw field scalars, not ring elements — they originate from +/// basis weight evaluations (Lagrange or monomial) and are always constant +/// (scalar) ring elements when embedded into the ring. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct RingOpeningPoint { + /// Evaluation vector of length `2^m` (field scalars). + pub a: Vec, + /// Block-select vector of length `2^r` (field scalars). + pub b: Vec, +} + +/// Multilinear Lagrange weights: `⊗ᵢ (1 − xᵢ, xᵢ)`. +pub fn lagrange_weights(point: &[F]) -> Vec { + let len = 1usize << point.len(); + let mut weights = vec![F::zero(); len]; + multilinear_lagrange_basis(&mut weights, point); + weights +} + +/// Multilinear monomial weights: `⊗ᵢ (1, xᵢ)`. +/// +/// The j-th entry is `∏_{i ∈ bits(j)} point[i]`. +pub fn monomial_weights(point: &[F]) -> Vec { + let len = 1usize << point.len(); + let mut weights = vec![F::zero(); len]; + weights[0] = F::one(); + for (level, &p) in point.iter().enumerate() { + let k = 1usize << level; + for i in (0..k).rev() { + weights[i + k] = weights[i] * p; + } + } + weights +} + +/// Return tensor-product weights for one opening point under the chosen basis. +pub fn basis_weights(point: &[F], basis: BasisMode) -> Vec { + match basis { + BasisMode::Lagrange => lagrange_weights(point), + BasisMode::Monomial => monomial_weights(point), + } +} + +/// Convert the outer portion of a field opening point into ring-native vectors. +/// +/// The first `m_vars` coordinates select the position within each block; the +/// remaining `r_vars` coordinates select which block is opened. +/// +/// # Errors +/// +/// Returns an error if `m_vars + r_vars` overflows or if `opening_point` has +/// the wrong length. 
+pub fn ring_opening_point_from_field( + opening_point: &[F], + r_vars: usize, + m_vars: usize, + basis: BasisMode, +) -> Result, HachiError> { + let expected_len = r_vars + .checked_add(m_vars) + .ok_or_else(|| HachiError::InvalidSetup("opening point length overflow".to_string()))?; + if opening_point.len() != expected_len { + return Err(HachiError::InvalidPointDimension { + expected: expected_len, + actual: opening_point.len(), + }); + } + + let a = basis_weights(&opening_point[..m_vars], basis); + let b = basis_weights(&opening_point[m_vars..], basis); + Ok(RingOpeningPoint { a, b }) +} + +/// Reduce the inner `alpha = log2(D)` opening coordinates to one ring element. +/// +/// # Errors +/// +/// Returns an error if the number of basis weights implied by `inner_point` +/// does not match `D`. +pub fn reduce_inner_opening_to_ring_element( + inner_point: &[F], + basis: BasisMode, +) -> Result, HachiError> { + let weights = basis_weights(inner_point, basis); + if weights.len() != D { + return Err(HachiError::InvalidInput(format!( + "inner basis length {} does not match D={D}", + weights.len() + ))); + } + Ok(CyclotomicRing::from_slice(&weights)) +} diff --git a/src/protocol/prg.rs b/src/protocol/prg.rs new file mode 100644 index 00000000..f1eb6581 --- /dev/null +++ b/src/protocol/prg.rs @@ -0,0 +1,361 @@ +//! Matrix PRG backends shared by commitment/JL derivation. +//! +//! The PRG is keyed per matrix entry using domain-separated context bytes. + +use aes::Aes128; +use ctr::cipher::{KeyIvInit, StreamCipher}; +use rand_core::{CryptoRng, RngCore}; +use sha3::digest::{ExtendableOutput, Update, XofReader}; +use sha3::Shake256; + +const MATRIX_PRG_DOMAIN: &[u8] = b"hachi/matrix-prg"; +const MATRIX_PRG_SHAKE_DOMAIN: &[u8] = b"hachi/matrix-prg/shake256"; +const MATRIX_PRG_AES_DOMAIN: &[u8] = b"hachi/matrix-prg/aes128-ctr"; + +/// Stable backend identifiers for transcript/context binding. 
+#[derive(Debug, Clone, Copy, PartialEq, Eq)] +#[repr(u8)] +pub enum MatrixPrgBackendId { + /// SHAKE256 XOF backend. + Shake256 = 0, + /// AES-128-CTR backend. + Aes128Ctr = 1, +} + +impl TryFrom for MatrixPrgBackendId { + type Error = crate::error::HachiError; + + fn try_from(value: u8) -> Result { + match value { + 0 => Ok(Self::Shake256), + 1 => Ok(Self::Aes128Ctr), + _ => Err(crate::error::HachiError::InvalidInput(format!( + "unknown matrix PRG backend id: {value}" + ))), + } + } +} + +impl From for u8 { + fn from(value: MatrixPrgBackendId) -> Self { + value as u8 + } +} + +/// Input context used for deterministic matrix-entry sampling. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct MatrixPrgContext<'a> { + /// Public seed. + pub seed: &'a [u8; 32], + /// Matrix label (`A`, `B`, `D`, etc.). + pub matrix_label: &'a [u8], + /// Matrix row count. + pub rows: usize, + /// Matrix column count. + pub cols: usize, + /// Matrix-entry row index. + pub row: usize, + /// Matrix-entry column index. + pub col: usize, +} + +/// Backend trait for matrix-entry PRG streams. +pub trait MatrixPrgBackend: Clone + Send + Sync + 'static { + /// Stable backend identifier. + fn backend_id(&self) -> MatrixPrgBackendId; + /// Construct a stream RNG for one matrix entry. + fn entry_rng(&self, context: &MatrixPrgContext<'_>) -> MatrixPrgRng; +} + +/// Runtime backend selector. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum MatrixPrgBackendChoice { + /// SHAKE256 XOF stream. + Shake256, + /// AES-128-CTR stream. + Aes128Ctr, +} + +impl MatrixPrgBackendChoice { + /// Return the stable backend id. + pub fn backend_id(self) -> MatrixPrgBackendId { + match self { + Self::Shake256 => MatrixPrgBackendId::Shake256, + Self::Aes128Ctr => MatrixPrgBackendId::Aes128Ctr, + } + } + + /// Construct a stream RNG for one matrix entry. 
+ pub fn entry_rng(self, context: &MatrixPrgContext<'_>) -> MatrixPrgRng { + match self { + Self::Shake256 => Shake256Backend.entry_rng(context), + Self::Aes128Ctr => Aes128CtrBackend.entry_rng(context), + } + } +} + +impl Default for MatrixPrgBackendChoice { + fn default() -> Self { + Self::Shake256 + } +} + +/// SHAKE256 backend implementation. +#[derive(Debug, Clone, Copy, Default)] +pub struct Shake256Backend; + +impl MatrixPrgBackend for Shake256Backend { + fn backend_id(&self) -> MatrixPrgBackendId { + MatrixPrgBackendId::Shake256 + } + + fn entry_rng(&self, context: &MatrixPrgContext<'_>) -> MatrixPrgRng { + MatrixPrgRng::Shake(ShakeEntryRng::new(context)) + } +} + +/// AES-128-CTR backend implementation. +#[derive(Debug, Clone, Copy, Default)] +pub struct Aes128CtrBackend; + +impl MatrixPrgBackend for Aes128CtrBackend { + fn backend_id(&self) -> MatrixPrgBackendId { + MatrixPrgBackendId::Aes128Ctr + } + + fn entry_rng(&self, context: &MatrixPrgContext<'_>) -> MatrixPrgRng { + let (key, iv) = derive_aes_key_iv(context); + // On aarch64, the `aes` crate uses target-feature intrinsics when + // available; we still gate this branch for explicit architecture intent. + #[cfg(target_arch = "aarch64")] + { + if std::arch::is_aarch64_feature_detected!("aes") { + return MatrixPrgRng::AesCtr(Aes128CtrEntryRng::new(&key, &iv)); + } + } + // TODO(x86_64): add explicit AES-NI runtime path selection once CI has + // dedicated hardware coverage. Today we use the `aes` crate default. + #[cfg(target_arch = "x86_64")] + { + let _ = std::arch::is_x86_feature_detected!("aes"); + } + MatrixPrgRng::AesCtr(Aes128CtrEntryRng::new(&key, &iv)) + } +} + +/// Matrix-entry RNG wrapper over supported PRG backends. +#[allow(clippy::large_enum_variant)] +pub enum MatrixPrgRng { + /// SHAKE256 XOF-backed RNG. + Shake(ShakeEntryRng), + /// AES-128-CTR-backed RNG. 
+ AesCtr(Aes128CtrEntryRng), +} + +impl RngCore for MatrixPrgRng { + fn next_u32(&mut self) -> u32 { + let mut buf = [0u8; 4]; + self.fill_bytes(&mut buf); + u32::from_le_bytes(buf) + } + + fn next_u64(&mut self) -> u64 { + let mut buf = [0u8; 8]; + self.fill_bytes(&mut buf); + u64::from_le_bytes(buf) + } + + fn fill_bytes(&mut self, dest: &mut [u8]) { + match self { + Self::Shake(rng) => rng.fill_bytes(dest), + Self::AesCtr(rng) => rng.fill_bytes(dest), + } + } + + fn try_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), rand_core::Error> { + self.fill_bytes(dest); + Ok(()) + } +} + +impl CryptoRng for MatrixPrgRng {} + +/// SHAKE256-backed matrix-entry RNG. +pub struct ShakeEntryRng { + reader: Box, +} + +impl ShakeEntryRng { + fn new(context: &MatrixPrgContext<'_>) -> Self { + let mut xof = Shake256::default(); + absorb_matrix_context(&mut xof, MATRIX_PRG_SHAKE_DOMAIN, context); + Self { + reader: Box::new(xof.finalize_xof()), + } + } +} + +impl RngCore for ShakeEntryRng { + fn next_u32(&mut self) -> u32 { + let mut buf = [0u8; 4]; + self.fill_bytes(&mut buf); + u32::from_le_bytes(buf) + } + + fn next_u64(&mut self) -> u64 { + let mut buf = [0u8; 8]; + self.fill_bytes(&mut buf); + u64::from_le_bytes(buf) + } + + fn fill_bytes(&mut self, dest: &mut [u8]) { + self.reader.read(dest); + } + + fn try_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), rand_core::Error> { + self.fill_bytes(dest); + Ok(()) + } +} + +impl CryptoRng for ShakeEntryRng {} + +type AesCtrCipher = ctr::Ctr128BE; + +/// AES-128-CTR-backed matrix-entry RNG. 
+pub struct Aes128CtrEntryRng { + cipher: AesCtrCipher, +} + +impl Aes128CtrEntryRng { + fn new(key: &[u8; 16], iv: &[u8; 16]) -> Self { + Self { + cipher: AesCtrCipher::new(key.into(), iv.into()), + } + } +} + +impl RngCore for Aes128CtrEntryRng { + fn next_u32(&mut self) -> u32 { + let mut buf = [0u8; 4]; + self.fill_bytes(&mut buf); + u32::from_le_bytes(buf) + } + + fn next_u64(&mut self) -> u64 { + let mut buf = [0u8; 8]; + self.fill_bytes(&mut buf); + u64::from_le_bytes(buf) + } + + fn fill_bytes(&mut self, dest: &mut [u8]) { + dest.fill(0u8); + self.cipher.apply_keystream(dest); + } + + fn try_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), rand_core::Error> { + self.fill_bytes(dest); + Ok(()) + } +} + +impl CryptoRng for Aes128CtrEntryRng {} + +fn derive_aes_key_iv(context: &MatrixPrgContext<'_>) -> ([u8; 16], [u8; 16]) { + let mut xof = Shake256::default(); + absorb_matrix_context(&mut xof, MATRIX_PRG_AES_DOMAIN, context); + let mut out = [0u8; 32]; + xof.finalize_xof().read(&mut out); + let key: [u8; 16] = out[..16].try_into().expect("XOF produced 32 bytes"); + let iv: [u8; 16] = out[16..].try_into().expect("XOF produced 32 bytes"); + (key, iv) +} + +fn absorb_matrix_context( + xof: &mut Shake256, + backend_domain: &[u8], + context: &MatrixPrgContext<'_>, +) { + absorb_len_prefixed(xof, b"domain", MATRIX_PRG_DOMAIN); + absorb_len_prefixed(xof, b"backend", backend_domain); + absorb_len_prefixed(xof, b"seed", context.seed); + absorb_len_prefixed(xof, b"matrix", context.matrix_label); + absorb_len_prefixed(xof, b"rows", &(context.rows as u64).to_le_bytes()); + absorb_len_prefixed(xof, b"cols", &(context.cols as u64).to_le_bytes()); + absorb_len_prefixed(xof, b"row", &(context.row as u64).to_le_bytes()); + absorb_len_prefixed(xof, b"col", &(context.col as u64).to_le_bytes()); +} + +fn absorb_len_prefixed(xof: &mut Shake256, label: &[u8], data: &[u8]) { + xof.update(&(label.len() as u64).to_le_bytes()); + xof.update(label); + xof.update(&(data.len() as 
u64).to_le_bytes()); + xof.update(data); +} + +#[cfg(test)] +mod tests { + use super::*; + + fn context<'a>(seed: &'a [u8; 32], row: usize, col: usize) -> MatrixPrgContext<'a> { + MatrixPrgContext { + seed, + matrix_label: b"A", + rows: 4, + cols: 5, + row, + col, + } + } + + #[test] + fn shake_backend_is_deterministic() { + let seed = [42u8; 32]; + let ctx = context(&seed, 1, 3); + let mut rng1 = Shake256Backend.entry_rng(&ctx); + let mut rng2 = Shake256Backend.entry_rng(&ctx); + let mut a = [0u8; 96]; + let mut b = [0u8; 96]; + rng1.fill_bytes(&mut a); + rng2.fill_bytes(&mut b); + assert_eq!(a, b); + } + + #[test] + fn aes_backend_is_deterministic() { + let seed = [7u8; 32]; + let ctx = context(&seed, 0, 2); + let mut rng1 = Aes128CtrBackend.entry_rng(&ctx); + let mut rng2 = Aes128CtrBackend.entry_rng(&ctx); + let mut a = [0u8; 96]; + let mut b = [0u8; 96]; + rng1.fill_bytes(&mut a); + rng2.fill_bytes(&mut b); + assert_eq!(a, b); + } + + #[test] + fn row_col_changes_separate_streams() { + let seed = [9u8; 32]; + let mut rng_a = Shake256Backend.entry_rng(&context(&seed, 0, 0)); + let mut rng_b = Shake256Backend.entry_rng(&context(&seed, 0, 1)); + let mut a = [0u8; 64]; + let mut b = [0u8; 64]; + rng_a.fill_bytes(&mut a); + rng_b.fill_bytes(&mut b); + assert_ne!(a, b); + } + + #[test] + fn backend_choice_changes_stream() { + let seed = [5u8; 32]; + let ctx = context(&seed, 2, 4); + let mut shake = MatrixPrgBackendChoice::Shake256.entry_rng(&ctx); + let mut aes = MatrixPrgBackendChoice::Aes128Ctr.entry_rng(&ctx); + let mut a = [0u8; 64]; + let mut b = [0u8; 64]; + shake.fill_bytes(&mut a); + aes.fill_bytes(&mut b); + assert_ne!(a, b); + } +} diff --git a/src/protocol/proof.rs b/src/protocol/proof.rs new file mode 100644 index 00000000..2cf01a6c --- /dev/null +++ b/src/protocol/proof.rs @@ -0,0 +1,1528 @@ +//! Proof structures for the Hachi protocol. 
+ +use crate::algebra::CyclotomicRing; +use crate::error::HachiError; +use crate::primitives::serialization::{Compress, SerializationError}; +use crate::primitives::serialization::{Valid, Validate}; +use crate::protocol::commitment::RingCommitment; +use crate::protocol::labrador::types::{ + LabradorLevelProof, LabradorProof, LabradorReductionConfig, LabradorWitness, +}; +use crate::protocol::sumcheck::SumcheckProof; +use crate::{CanonicalField, FieldCore, FromSmallInt, HachiDeserialize, HachiSerialize}; +use std::io::{Read, Write}; +use std::marker::PhantomData; + +/// Bit-packed balanced digits for the final-level witness vector. +/// +/// Each element is a signed value in `[-b/2, b/2)` where `b = 2^bits_per_elem`, +/// stored in two's-complement using exactly `bits_per_elem` bits per value. +/// This reduces proof size by ~32x compared to storing full field elements. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct PackedDigits { + /// Number of logical elements. + pub num_elems: usize, + /// Bits per element (= `log_basis` from the commitment config). + pub bits_per_elem: u32, + /// Bit-packed two's-complement data. + pub data: Vec, +} + +/// Precomputed lookup table mapping balanced digit index → field element. +/// +/// Wraps `FromSmallInt::digit_lut` with convenient signed-digit indexing. +/// Index a digit `d ∈ [-b/2, b/2)` via [`get`](DigitLut::get). +pub(crate) struct DigitLut { + table: [F; 16], + half_b: i8, +} + +impl DigitLut { + #[inline] + pub(crate) fn new(log_basis: u32) -> Self { + let half_b = 1i8 << (log_basis - 1); + Self { + table: F::digit_lut(log_basis), + half_b, + } + } + + #[inline(always)] + pub(crate) fn get(&self, d: i8) -> F { + self.table[(d + self.half_b) as usize] + } +} + +impl PackedDigits { + /// Pack balanced i8 digits into bit-packed form. + /// + /// Each element must be in `[-b/2, b/2)` where `b = 2^log_basis`. + /// + /// # Panics + /// + /// Panics (in debug) if any element does not fit in `log_basis` bits. 
+ pub fn from_i8_digits(w: &[i8], log_basis: u32) -> Self { + assert!(log_basis > 0 && log_basis <= 7, "log_basis out of range"); + let half_b = 1i8 << (log_basis - 1); + + let bits = log_basis as usize; + let total_bits = w.len() * bits; + let num_bytes = total_bits.div_ceil(8); + let mut data = vec![0u8; num_bytes]; + + for (i, &signed) in w.iter().enumerate() { + debug_assert!( + signed >= -half_b && signed < half_b, + "digit {signed} out of range for log_basis={log_basis}" + ); + let unsigned = (signed as u8) & ((1u8 << bits) - 1); + let bit_offset = i * bits; + let byte_idx = bit_offset / 8; + let bit_idx = bit_offset % 8; + data[byte_idx] |= unsigned << bit_idx; + if bit_idx + bits > 8 { + data[byte_idx + 1] |= unsigned >> (8 - bit_idx); + } + } + + Self { + num_elems: w.len(), + bits_per_elem: log_basis, + data, + } + } + + /// Unpack to field elements using a precomputed lookup table. + pub fn to_field_elems(&self) -> Vec { + let bits = self.bits_per_elem as usize; + let mask = (1u8 << bits) - 1; + let sign_bit = 1u8 << (bits - 1); + let lut = DigitLut::::new(self.bits_per_elem); + + let mut out = Vec::with_capacity(self.num_elems); + for i in 0..self.num_elems { + let bit_offset = i * bits; + let byte_idx = bit_offset / 8; + let bit_idx = bit_offset % 8; + let mut raw = (self.data[byte_idx] >> bit_idx) & mask; + if bit_idx + bits > 8 { + raw |= (self.data[byte_idx + 1] << (8 - bit_idx)) & mask; + } + let signed = if raw & sign_bit != 0 { + raw as i8 | !(mask as i8) + } else { + raw as i8 + }; + out.push(lut.get(signed)); + } + out + } + + /// Number of packed data bytes. 
+ pub fn packed_byte_len(&self) -> usize { + self.data.len() + } +} + +impl HachiSerialize for PackedDigits { + fn serialize_with_mode( + &self, + mut writer: W, + compress: Compress, + ) -> Result<(), SerializationError> { + (self.num_elems as u64).serialize_with_mode(&mut writer, compress)?; + (self.bits_per_elem as u8).serialize_with_mode(&mut writer, compress)?; + writer.write_all(&self.data)?; + Ok(()) + } + + fn serialized_size(&self, _compress: Compress) -> usize { + 8 + 1 + self.data.len() + } +} + +impl Valid for PackedDigits { + fn check(&self) -> Result<(), SerializationError> { + if self.bits_per_elem == 0 || self.bits_per_elem > 7 { + return Err(SerializationError::InvalidData( + "bits_per_elem out of range".to_string(), + )); + } + let expected_bytes = (self.num_elems * self.bits_per_elem as usize).div_ceil(8); + if self.data.len() != expected_bytes { + return Err(SerializationError::InvalidData( + "packed data length mismatch".to_string(), + )); + } + Ok(()) + } +} + +impl HachiDeserialize for PackedDigits { + fn deserialize_with_mode( + mut reader: R, + compress: Compress, + validate: Validate, + ) -> Result { + let num_elems = u64::deserialize_with_mode(&mut reader, compress, validate)? as usize; + let bits_per_elem = u8::deserialize_with_mode(&mut reader, compress, validate)? as u32; + let num_bytes = (num_elems * bits_per_elem as usize).div_ceil(8); + let mut data = vec![0u8; num_bytes]; + reader.read_exact(&mut data)?; + let out = Self { + num_elems, + bits_per_elem, + data, + }; + if matches!(validate, Validate::Yes) { + out.check()?; + } + Ok(out) + } +} + +/// D-erased storage for a sequence of ring elements as raw field-element +/// coefficients. +/// +/// Each ring element of dimension `ring_dim` is stored as `ring_dim` +/// contiguous field elements in `coeffs`. The total number of ring elements +/// is `coeffs.len() / ring_dim`. 
+#[derive(Debug, Clone, PartialEq, Eq)] +pub struct FlatRingVec { + coeffs: Vec, + ring_dim: usize, +} + +impl FlatRingVec { + /// Wrap a single ring element. + pub fn from_single(r: &CyclotomicRing) -> Self { + Self { + coeffs: r.coefficients().to_vec(), + ring_dim: D, + } + } + + /// Wrap a slice of ring elements. + pub fn from_ring_elems(elems: &[CyclotomicRing]) -> Self { + let mut coeffs = Vec::with_capacity(elems.len() * D); + for e in elems { + coeffs.extend_from_slice(e.coefficients()); + } + Self { + coeffs, + ring_dim: D, + } + } + + /// Wrap a `RingCommitment`. + pub fn from_commitment(c: &RingCommitment) -> Self { + Self::from_ring_elems(&c.u) + } + + /// Ring dimension (number of field-element coefficients per ring element). + pub fn ring_dim(&self) -> usize { + self.ring_dim + } + + /// Number of ring elements stored. + pub fn count(&self) -> usize { + if self.ring_dim == 0 { + 0 + } else { + self.coeffs.len() / self.ring_dim + } + } + + /// Raw coefficient slice. + pub fn coeffs(&self) -> &[F] { + &self.coeffs + } + + /// Reconstruct a single ring element. + /// + /// # Panics + /// + /// Panics if `D != ring_dim` or `count() != 1`. + pub fn to_single(&self) -> CyclotomicRing { + assert_eq!(D, self.ring_dim, "D mismatch in to_single"); + assert_eq!(self.count(), 1, "expected exactly one ring element"); + CyclotomicRing::from_slice(&self.coeffs) + } + + /// Reconstruct a single ring element, returning `InvalidProof` on shape mismatch. + /// + /// # Errors + /// + /// Returns [`HachiError::InvalidProof`] if the stored ring dimension or + /// element count does not match `D`. + pub fn try_to_single(&self) -> Result, HachiError> { + if self.ring_dim != D || self.coeffs.len() != D { + return Err(HachiError::InvalidProof); + } + Ok(CyclotomicRing::from_slice(&self.coeffs)) + } + + /// Reconstruct a vector of ring elements. + /// + /// # Panics + /// + /// Panics if `D != ring_dim`. 
+ pub fn to_vec(&self) -> Vec> { + assert_eq!(D, self.ring_dim, "D mismatch in to_vec"); + self.coeffs + .chunks_exact(D) + .map(CyclotomicRing::from_slice) + .collect() + } + + /// Reconstruct a vector of ring elements, returning `InvalidProof` on shape mismatch. + /// + /// # Errors + /// + /// Returns [`HachiError::InvalidProof`] if the stored ring dimension does + /// not match `D` or the coefficient buffer is not an exact multiple of `D`. + pub fn try_to_vec(&self) -> Result>, HachiError> { + if self.ring_dim != D || self.coeffs.len() % D != 0 { + return Err(HachiError::InvalidProof); + } + Ok(self + .coeffs + .chunks_exact(D) + .map(CyclotomicRing::from_slice) + .collect()) + } + + /// Reconstruct a `RingCommitment`. + /// + /// # Panics + /// + /// Panics if `D != ring_dim`. + pub fn to_ring_commitment(&self) -> RingCommitment { + RingCommitment { u: self.to_vec() } + } + + /// Reconstruct a `RingCommitment`, returning `InvalidProof` on shape mismatch. + /// + /// # Errors + /// + /// Returns [`HachiError::InvalidProof`] if the stored ring data is not + /// well-formed for ring dimension `D`. 
+ pub fn try_to_ring_commitment( + &self, + ) -> Result, HachiError> { + Ok(RingCommitment { + u: self.try_to_vec()?, + }) + } +} + +impl HachiSerialize for FlatRingVec { + fn serialize_with_mode( + &self, + mut writer: W, + compress: Compress, + ) -> Result<(), SerializationError> { + (self.ring_dim as u32).serialize_with_mode(&mut writer, compress)?; + self.coeffs.serialize_with_mode(&mut writer, compress) + } + + fn serialized_size(&self, compress: Compress) -> usize { + 4 + self.coeffs.serialized_size(compress) + } +} + +impl Valid for FlatRingVec { + fn check(&self) -> Result<(), SerializationError> { + if self.ring_dim == 0 { + return Err(SerializationError::InvalidData( + "ring_dim must be > 0".to_string(), + )); + } + if self.coeffs.len() % self.ring_dim != 0 { + return Err(SerializationError::InvalidData( + "coeffs length not a multiple of ring_dim".to_string(), + )); + } + Ok(()) + } +} + +impl HachiDeserialize for FlatRingVec { + fn deserialize_with_mode( + mut reader: R, + compress: Compress, + validate: Validate, + ) -> Result { + let ring_dim = u32::deserialize_with_mode(&mut reader, compress, validate)? as usize; + let coeffs = Vec::::deserialize_with_mode(&mut reader, compress, validate)?; + let out = Self { coeffs, ring_dim }; + if matches!(validate, Validate::Yes) { + out.check()?; + } + Ok(out) + } +} + +/// D-erased commitment hint for cross-level storage. +/// +/// Stores the decomposed inner-opening digit blocks (formerly `t_hat`) as a +/// flat `Vec` with metadata about block sizes and ring dimension. Convert +/// to/from the typed +/// [`HachiCommitmentHint`] via [`from_typed`](Self::from_typed) and +/// [`to_typed`](Self::to_typed). +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct FlatCommitmentHint { + data: Vec, + block_sizes: Vec, + ring_dim: usize, +} + +impl FlatCommitmentHint { + /// Convert from a typed hint, consuming it. 
+ pub fn from_typed(hint: HachiCommitmentHint) -> Self { + let block_sizes: Vec = hint.inner_opening_digits.iter().map(|b| b.len()).collect(); + let total_planes: usize = block_sizes.iter().sum(); + let mut data = Vec::with_capacity(total_planes * D); + for block in &hint.inner_opening_digits { + for plane in block { + data.extend_from_slice(plane); + } + } + Self { + data, + block_sizes, + ring_dim: D, + } + } + + /// Reconstruct a typed hint. + /// + /// # Panics + /// + /// Panics if `D != ring_dim`. + pub fn to_typed(&self) -> HachiCommitmentHint { + assert_eq!(D, self.ring_dim, "D mismatch in to_typed"); + let mut inner_opening_digits = Vec::with_capacity(self.block_sizes.len()); + let mut offset = 0; + for &block_size in &self.block_sizes { + let mut block = Vec::with_capacity(block_size); + for _ in 0..block_size { + let mut plane = [0i8; D]; + plane.copy_from_slice(&self.data[offset..offset + D]); + offset += D; + block.push(plane); + } + inner_opening_digits.push(block); + } + HachiCommitmentHint::new(inner_opening_digits) + } + + /// Ring dimension stored in this hint. + pub fn ring_dim(&self) -> usize { + self.ring_dim + } + + /// Empty hint (verifier side, where hint data is not available). + pub fn empty() -> Self { + Self { + data: Vec::new(), + block_sizes: Vec::new(), + ring_dim: 0, + } + } +} + +/// Prover-side hint produced at commitment time. +/// +/// Contains the decomposed inner-opening digits (formerly `t_hat`) needed by +/// the ring-switch step of the prover. The polynomial itself (ring +/// coefficients) is passed separately to `prove` via `HachiPolyOps`. +#[derive(Debug, Clone)] +pub struct HachiCommitmentHint { + /// Decomposed inner-opening digit blocks from the commitment phase as i8 + /// digit planes (formerly `t_hat`). + pub inner_opening_digits: Vec>, + /// Optional recomposed `t_i` rows cached for prover-side A-row work. 
+ t: Option>>>, + _marker: PhantomData, +} + +impl HachiCommitmentHint { + /// Construct a new hint from i8 digit plane blocks. + pub fn new(inner_opening_digits: Vec>) -> Self { + Self { + inner_opening_digits, + t: None, + _marker: PhantomData, + } + } + + /// Construct a hint that also preserves the undecomposed `t_i` rows. + pub fn with_t( + inner_opening_digits: Vec>, + t: Vec>>, + ) -> Self { + Self { + inner_opening_digits, + t: Some(t), + _marker: PhantomData, + } + } + + /// Get the optional recomposed `t_i` rows. + pub fn t(&self) -> Option<&[Vec>]> { + self.t.as_deref() + } + + /// Populate the recomposed `t_i` rows from the inner-opening digits when + /// they are absent. + /// + /// # Errors + /// + /// Returns an error if `num_digits_open` is zero or if any inner-opening + /// digit block length is not a multiple of `num_digits_open`. + pub fn ensure_t_recomposed( + &mut self, + num_digits_open: usize, + log_basis: u32, + ) -> Result<(), HachiError> + where + F: CanonicalField, + { + if self.t.is_some() { + return Ok(()); + } + if num_digits_open == 0 { + return Err(HachiError::InvalidSetup( + "num_digits_open must be nonzero when recomposing inner-opening digits".to_string(), + )); + } + + let t = self + .inner_opening_digits + .iter() + .map(|block| { + if block.len() % num_digits_open != 0 { + return Err(HachiError::InvalidSetup(format!( + "inner-opening digit block has {} planes, expected a multiple of num_digits_open={num_digits_open}", + block.len() + ))); + } + Ok(block + .chunks(num_digits_open) + .map(|digits| CyclotomicRing::gadget_recompose_pow2_i8(digits, log_basis)) + .collect()) + }) + .collect::>>, HachiError>>()?; + self.t = Some(t); + Ok(()) + } +} + +impl PartialEq for HachiCommitmentHint { + fn eq(&self, other: &Self) -> bool { + self.inner_opening_digits == other.inner_opening_digits + } +} + +impl Eq for HachiCommitmentHint {} + +/// Proof payload for stage 1 of a single Hachi level. 
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct HachiStage1Proof<F> {
    /// Stage-1 sumcheck proof over the virtual `S = w(w+1)` table.
    pub sumcheck: SumcheckProof<F>,
    /// Claimed evaluation of `S` at the stage-1 output point.
    pub s_claim: F,
}

/// Proof payload for stage 2 of a single Hachi level.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct HachiStage2Proof<F> {
    /// Stage-2 fused sumcheck proof.
    pub sumcheck: SumcheckProof<F>,
    /// Commitment to the next witness `w`
    /// (ring dim = next level's D, may differ from y_ring/v).
    pub next_w_commitment: FlatRingVec<F>,
    /// Claimed evaluation of the next witness `w` at the stage-2 challenge point.
    pub next_w_eval: F,
}

/// Proof for a single fold level (quad_eq + ring_switch + sumcheck).
///
/// D-agnostic: ring elements are stored as [`FlatRingVec`] with their
/// ring dimension recorded. Use [`Self::y_ring_typed`], [`Self::v_typed`], and
/// [`Self::w_commitment_typed`] to reconstruct typed ring elements.
///
/// One recursive Hachi level proof, split into `stage1` and `stage2` payloads.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct HachiLevelProof<F> {
    /// `y_ring` from the §3.1 reduction (ring dim = current level's D).
    pub y_ring: FlatRingVec<F>,
    /// `v = D · ŵ` (ring dim = current level's D).
    pub v: FlatRingVec<F>,
    /// Stage-1 proof payload.
    pub stage1: HachiStage1Proof<F>,
    /// Stage-2 proof payload.
    pub stage2: HachiStage2Proof<F>,
}

// NOTE(review): the generic parameter lists below were stripped by a text
// extraction pass; `<F: CanonicalField>` is reconstructed from the
// `to_single`/`from_single` usage on FlatRingVec — confirm against the
// original source.
impl<F: CanonicalField> HachiLevelProof<F> {
    /// Construct from typed ring elements for the current level and a
    /// pre-erased `FlatRingVec` for the w-commitment (which may be at a
    /// different D).
+ #[allow(clippy::too_many_arguments)] + pub(crate) fn new( + y_ring: CyclotomicRing, + v: Vec>, + stage1_sumcheck: SumcheckProof, + stage1_s_claim: F, + stage2_sumcheck: SumcheckProof, + next_w_commitment: FlatRingVec, + next_w_eval: F, + ) -> Self { + Self { + y_ring: FlatRingVec::from_single(&y_ring), + v: FlatRingVec::from_ring_elems(&v), + stage1: HachiStage1Proof { + sumcheck: stage1_sumcheck, + s_claim: stage1_s_claim, + }, + stage2: HachiStage2Proof { + sumcheck: stage2_sumcheck, + next_w_commitment, + next_w_eval, + }, + } + } + + /// Ring dimension of y_ring and v (current level). + pub fn level_d(&self) -> usize { + self.y_ring.ring_dim() + } + + /// Ring dimension of the w_commitment (next level). + pub fn w_commit_d(&self) -> usize { + self.stage2.next_w_commitment.ring_dim() + } + + /// Reconstruct typed `y_ring`. + /// + /// # Panics + /// + /// Panics if `D` does not match the stored ring dimension. + pub fn y_ring_typed(&self) -> CyclotomicRing { + self.y_ring.to_single() + } + + /// Reconstruct typed `y_ring`, returning `InvalidProof` on shape mismatch. + /// + /// # Errors + /// + /// Returns [`HachiError::InvalidProof`] if the stored `y_ring` does not + /// encode exactly one ring element at dimension `D`. + pub fn try_y_ring_typed(&self) -> Result, HachiError> { + self.y_ring.try_to_single() + } + + /// Reconstruct typed `v`. + /// + /// # Panics + /// + /// Panics if `D` does not match the stored ring dimension. + pub fn v_typed(&self) -> Vec> { + self.v.to_vec() + } + + /// Reconstruct typed `v`, returning `InvalidProof` on shape mismatch. + /// + /// # Errors + /// + /// Returns [`HachiError::InvalidProof`] if the stored `v` payload is not + /// well-formed for ring dimension `D`. + pub fn try_v_typed(&self) -> Result>, HachiError> { + self.v.try_to_vec() + } + + /// Reconstruct typed `w_commitment`. + /// + /// # Panics + /// + /// Panics if `D` does not match the stored ring dimension. 
+ pub fn w_commitment_typed(&self) -> RingCommitment { + self.stage2.next_w_commitment.to_ring_commitment() + } + + /// Reconstruct typed `w_commitment`, returning `InvalidProof` on shape mismatch. + /// + /// # Errors + /// + /// Returns [`HachiError::InvalidProof`] if the stored next-level commitment + /// is not well-formed for ring dimension `D`. + pub fn try_w_commitment_typed( + &self, + ) -> Result, HachiError> { + self.stage2.next_w_commitment.try_to_ring_commitment() + } +} + +// --------------------------------------------------------------------------- +// D-erased Labrador proof types for HachiProofTail +// --------------------------------------------------------------------------- + +/// D-erased Labrador level proof. +/// +/// Mirrors [`LabradorLevelProof`] with ring elements stored as [`FlatRingVec`]. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct FlatLabradorLevelProof { + /// Whether this level uses tail semantics. + pub tail: bool, + /// Input row lengths per witness row. + pub input_row_lengths: Vec, + /// Configuration selected for this level. + pub config: LabradorReductionConfig, + /// Virtual row length after reshaping (formerly `nn`). + pub virtual_row_len: usize, + /// Per-original-row split counts from the fold plan (formerly `nu`). + pub row_split_counts: Vec, + /// Opening-side payload (formerly `u1`). + pub inner_opening_payload: FlatRingVec, + /// Linear-garbage-side payload (formerly `u2`). + pub linear_garbage_payload: FlatRingVec, + /// JL projection vector. + pub jl_projection: [i64; 256], + /// JL nonce used to regenerate projection matrix. + pub jl_nonce: u64, + /// JL lift residuals (formerly `bb`). + pub jl_lift_residuals: FlatRingVec, + /// Output witness norm bound after reduction (formerly `norm_sq`). + pub next_witness_norm_sq: u128, +} + +impl FlatLabradorLevelProof { + /// Convert from the typed `LabradorLevelProof`. 
+ pub fn from_typed(p: &LabradorLevelProof) -> Self { + Self { + tail: p.tail, + input_row_lengths: p.input_row_lengths.clone(), + config: p.config, + virtual_row_len: p.virtual_row_len, + row_split_counts: p.row_split_counts.clone(), + inner_opening_payload: FlatRingVec::from_ring_elems(&p.inner_opening_payload), + linear_garbage_payload: FlatRingVec::from_ring_elems(&p.linear_garbage_payload), + jl_projection: p.jl_projection, + jl_nonce: p.jl_nonce, + jl_lift_residuals: FlatRingVec::from_ring_elems(&p.jl_lift_residuals), + next_witness_norm_sq: p.next_witness_norm_sq, + } + } + + /// Reconstruct the typed `LabradorLevelProof`. + /// + /// # Panics + /// + /// Panics if `D` does not match the stored ring dimension. + pub fn to_typed(&self) -> LabradorLevelProof { + LabradorLevelProof { + tail: self.tail, + input_row_lengths: self.input_row_lengths.clone(), + config: self.config, + virtual_row_len: self.virtual_row_len, + row_split_counts: self.row_split_counts.clone(), + inner_opening_payload: self.inner_opening_payload.to_vec(), + linear_garbage_payload: self.linear_garbage_payload.to_vec(), + jl_projection: self.jl_projection, + jl_nonce: self.jl_nonce, + jl_lift_residuals: self.jl_lift_residuals.to_vec(), + next_witness_norm_sq: self.next_witness_norm_sq, + } + } +} + +/// D-erased Labrador witness (rows of ring elements). +/// +/// Mirrors [`LabradorWitness`] with rows stored as [`FlatRingVec`]. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct FlatLabradorWitness { + /// Per-row ring element vectors. + pub rows: Vec>, +} + +impl FlatLabradorWitness { + /// Convert from the typed `LabradorWitness`. + pub fn from_typed(w: &LabradorWitness) -> Self { + Self { + rows: w + .rows() + .iter() + .map(|r| FlatRingVec::from_ring_elems(r)) + .collect(), + } + } + + /// Reconstruct the typed `LabradorWitness`. + /// + /// # Panics + /// + /// Panics if `D` does not match the stored ring dimension. 
+ pub fn to_typed(&self) -> LabradorWitness { + let rows: Vec>> = self.rows.iter().map(|r| r.to_vec()).collect(); + LabradorWitness::new_unchecked(rows) + } +} + +/// D-erased Labrador proof (levels + final witness). +/// +/// Mirrors [`LabradorProof`] with all ring data stored as [`FlatRingVec`]. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct FlatLabradorProof { + /// Recursive level payloads. + pub levels: Vec>, + /// Final clear witness opened at recursion termination. + pub final_opening_witness: FlatLabradorWitness, +} + +impl FlatLabradorProof { + /// Convert from the typed `LabradorProof`. + pub fn from_typed(p: &LabradorProof) -> Self { + Self { + levels: p + .levels + .iter() + .map(FlatLabradorLevelProof::from_typed) + .collect(), + final_opening_witness: FlatLabradorWitness::from_typed(&p.final_opening_witness), + } + } + + /// Reconstruct the typed `LabradorProof`. + /// + /// # Panics + /// + /// Panics if `D` does not match the stored ring dimension. + pub fn to_typed(&self) -> LabradorProof { + LabradorProof { + levels: self.levels.iter().map(|l| l.to_typed()).collect(), + final_opening_witness: self.final_opening_witness.to_typed(), + } + } +} + +/// Labrador tail proof data. +/// +/// Produced when Hachi's folding loop stops and the ring-level `Mz = y` +/// relation from the quadratic equation is handed directly to Labrador +/// without computing quotient `r`, evaluating at alpha, or running sumcheck. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct LabradorTail { + /// D-erased full Labrador recursive proof. + pub labrador_proof: FlatLabradorProof, + /// Ring-valued prover message `v = D * w_hat` (public, used to rebuild constraints). + pub v: FlatRingVec, + /// Ring-valued evaluation `y_ring` (public, used to rebuild constraints). + pub y_ring: FlatRingVec, + /// Squared L2 norm bound of the Labrador witness (formerly `beta_sq`). + pub witness_norm_bound_sq: u128, +} + +/// Proof tail: either a direct witness or a Labrador handoff. 
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum HachiProofTail<F> {
    /// Final witness sent in clear as packed balanced digits.
    // NOTE(review): `PackedDigits` appears non-generic here and in
    // `final_w`'s return type — confirm against the original source.
    Direct(PackedDigits),
    /// Direct Labrador handoff from the quadratic equation.
    /// Boxed because the Labrador payload is much larger than the variant tag.
    Labrador(Box<LabradorTail<F>>),
}

/// Hachi PCS proof with multi-level folding.
///
/// Each level runs the full protocol (quadratic equation, ring switch,
/// sumcheck) on the previous level's witness `w`. The tail is either
/// a direct witness (packed digits) or a Labrador handoff.
///
/// D-agnostic: per-level ring dimensions are recorded in each
/// [`HachiLevelProof`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct HachiProof<F> {
    /// Per-level proofs, from the original polynomial (level 0) through
    /// recursive w-openings.
    pub levels: Vec<HachiLevelProof<F>>,
    /// Proof tail: direct witness or Labrador handoff.
    pub tail: HachiProofTail<F>,
}

impl<F> HachiProof<F> {
    /// Access the direct final witness when the proof ends with a clear tail.
    ///
    /// Returns `None` when the proof ends with a Labrador handoff instead.
    pub fn final_w(&self) -> Option<&PackedDigits> {
        match &self.tail {
            HachiProofTail::Direct(pw) => Some(pw),
            HachiProofTail::Labrador(_) => None,
        }
    }

    /// Whether this proof uses a Labrador tail (not a direct witness).
    pub fn has_handoff_tail(&self) -> bool {
        // Kept for interface compatibility: `has_labrador_tail` below is an
        // exact duplicate. Delegate so the two can never disagree.
        self.has_labrador_tail()
    }

    /// Whether this proof uses the direct Labrador tail.
    pub fn has_labrador_tail(&self) -> bool {
        matches!(&self.tail, HachiProofTail::Labrador(_))
    }
}

// NOTE(review): the bound `F: CanonicalField` is reconstructed — `size()`
// requires whatever bound the `HachiSerialize for HachiProof` impl uses;
// confirm against the original source.
impl<F: CanonicalField> HachiProof<F> {
    /// Returns the proof size in bytes (uncompressed).
+ pub fn size(&self) -> usize { + self.serialized_size(Compress::No) + } +} + +impl HachiSerialize for LabradorReductionConfig { + fn serialize_with_mode( + &self, + mut writer: W, + compress: Compress, + ) -> Result<(), SerializationError> { + self.witness_digit_parts + .serialize_with_mode(&mut writer, compress)?; + self.witness_digit_bits + .serialize_with_mode(&mut writer, compress)?; + self.aux_digit_parts + .serialize_with_mode(&mut writer, compress)?; + self.aux_digit_bits + .serialize_with_mode(&mut writer, compress)?; + self.inner_commit_rank + .serialize_with_mode(&mut writer, compress)?; + self.outer_commit_rank + .serialize_with_mode(&mut writer, compress)?; + (self.tail as u8).serialize_with_mode(&mut writer, compress) + } + + fn serialized_size(&self, compress: Compress) -> usize { + self.witness_digit_parts.serialized_size(compress) + + self.witness_digit_bits.serialized_size(compress) + + self.aux_digit_parts.serialized_size(compress) + + self.aux_digit_bits.serialized_size(compress) + + self.inner_commit_rank.serialized_size(compress) + + self.outer_commit_rank.serialized_size(compress) + + 1 + } +} + +impl Valid for LabradorReductionConfig { + fn check(&self) -> Result<(), SerializationError> { + Ok(()) + } +} + +impl HachiDeserialize for LabradorReductionConfig { + fn deserialize_with_mode( + mut reader: R, + compress: Compress, + validate: Validate, + ) -> Result { + let witness_digit_parts = usize::deserialize_with_mode(&mut reader, compress, validate)?; + let witness_digit_bits = usize::deserialize_with_mode(&mut reader, compress, validate)?; + let aux_digit_parts = usize::deserialize_with_mode(&mut reader, compress, validate)?; + let aux_digit_bits = usize::deserialize_with_mode(&mut reader, compress, validate)?; + let inner_commit_rank = usize::deserialize_with_mode(&mut reader, compress, validate)?; + let outer_commit_rank = usize::deserialize_with_mode(&mut reader, compress, validate)?; + let tail = u8::deserialize_with_mode(&mut reader, 
compress, validate)?; + if tail > 1 { + return Err(SerializationError::InvalidData( + "invalid LabradorReductionConfig tail flag".to_string(), + )); + } + Ok(Self { + witness_digit_parts, + witness_digit_bits, + aux_digit_parts, + aux_digit_bits, + inner_commit_rank, + outer_commit_rank, + tail: tail != 0, + }) + } +} + +impl HachiSerialize for FlatLabradorLevelProof { + fn serialize_with_mode( + &self, + mut writer: W, + compress: Compress, + ) -> Result<(), SerializationError> { + (self.tail as u8).serialize_with_mode(&mut writer, compress)?; + self.input_row_lengths + .serialize_with_mode(&mut writer, compress)?; + self.config.serialize_with_mode(&mut writer, compress)?; + self.virtual_row_len + .serialize_with_mode(&mut writer, compress)?; + self.row_split_counts + .serialize_with_mode(&mut writer, compress)?; + self.inner_opening_payload + .serialize_with_mode(&mut writer, compress)?; + self.linear_garbage_payload + .serialize_with_mode(&mut writer, compress)?; + for coeff in &self.jl_projection { + coeff.serialize_with_mode(&mut writer, compress)?; + } + self.jl_nonce.serialize_with_mode(&mut writer, compress)?; + self.jl_lift_residuals + .serialize_with_mode(&mut writer, compress)?; + self.next_witness_norm_sq + .serialize_with_mode(&mut writer, compress) + } + + fn serialized_size(&self, compress: Compress) -> usize { + 1 + self.input_row_lengths.serialized_size(compress) + + self.config.serialized_size(compress) + + self.virtual_row_len.serialized_size(compress) + + self.row_split_counts.serialized_size(compress) + + self.inner_opening_payload.serialized_size(compress) + + self.linear_garbage_payload.serialized_size(compress) + + self.jl_projection.len() * std::mem::size_of::() + + self.jl_nonce.serialized_size(compress) + + self.jl_lift_residuals.serialized_size(compress) + + self.next_witness_norm_sq.serialized_size(compress) + } +} + +impl Valid for FlatLabradorLevelProof { + fn check(&self) -> Result<(), SerializationError> { + if self.tail != 
self.config.tail { + return Err(SerializationError::InvalidData( + "FlatLabradorLevelProof tail/config mismatch".to_string(), + )); + } + if self.tail && self.config.outer_commit_rank != 0 { + return Err(SerializationError::InvalidData( + "FlatLabradorLevelProof tail level must have outer_commit_rank = 0".to_string(), + )); + } + if !self.tail && self.config.outer_commit_rank == 0 { + return Err(SerializationError::InvalidData( + "FlatLabradorLevelProof non-tail level must have outer_commit_rank > 0".to_string(), + )); + } + self.config.check()?; + self.inner_opening_payload.check()?; + self.linear_garbage_payload.check()?; + if self.inner_opening_payload.ring_dim() != self.linear_garbage_payload.ring_dim() + || self.inner_opening_payload.ring_dim() != self.jl_lift_residuals.ring_dim() + { + return Err(SerializationError::InvalidData( + "FlatLabradorLevelProof ring-dimension mismatch".to_string(), + )); + } + self.jl_lift_residuals.check() + } +} + +impl HachiDeserialize for FlatLabradorLevelProof { + fn deserialize_with_mode( + mut reader: R, + compress: Compress, + validate: Validate, + ) -> Result { + let tail = u8::deserialize_with_mode(&mut reader, compress, validate)?; + if tail > 1 { + return Err(SerializationError::InvalidData( + "invalid FlatLabradorLevelProof tail flag".to_string(), + )); + } + + let mut jl_projection = [0i64; 256]; + let input_row_lengths = + Vec::::deserialize_with_mode(&mut reader, compress, validate)?; + let config = + LabradorReductionConfig::deserialize_with_mode(&mut reader, compress, validate)?; + let virtual_row_len = usize::deserialize_with_mode(&mut reader, compress, validate)?; + let row_split_counts = + Vec::::deserialize_with_mode(&mut reader, compress, validate)?; + let inner_opening_payload = + FlatRingVec::deserialize_with_mode(&mut reader, compress, validate)?; + let linear_garbage_payload = + FlatRingVec::deserialize_with_mode(&mut reader, compress, validate)?; + for coeff in &mut jl_projection { + *coeff = 
i64::deserialize_with_mode(&mut reader, compress, validate)?; + } + let jl_nonce = u64::deserialize_with_mode(&mut reader, compress, validate)?; + let jl_lift_residuals = + FlatRingVec::deserialize_with_mode(&mut reader, compress, validate)?; + let next_witness_norm_sq = u128::deserialize_with_mode(&mut reader, compress, validate)?; + + let out = Self { + tail: tail != 0, + input_row_lengths, + config, + virtual_row_len, + row_split_counts, + inner_opening_payload, + linear_garbage_payload, + jl_projection, + jl_nonce, + jl_lift_residuals, + next_witness_norm_sq, + }; + if matches!(validate, Validate::Yes) { + out.check()?; + } + Ok(out) + } +} + +impl HachiSerialize for FlatLabradorWitness { + fn serialize_with_mode( + &self, + mut writer: W, + compress: Compress, + ) -> Result<(), SerializationError> { + (self.rows.len() as u32).serialize_with_mode(&mut writer, compress)?; + for row in &self.rows { + row.serialize_with_mode(&mut writer, compress)?; + } + Ok(()) + } + + fn serialized_size(&self, compress: Compress) -> usize { + 4 + self + .rows + .iter() + .map(|row| row.serialized_size(compress)) + .sum::() + } +} + +impl Valid for FlatLabradorWitness { + fn check(&self) -> Result<(), SerializationError> { + let expected_ring_dim = self.rows.first().map(FlatRingVec::ring_dim); + for row in &self.rows { + row.check()?; + if expected_ring_dim.is_some_and(|d| row.ring_dim() != d) { + return Err(SerializationError::InvalidData( + "FlatLabradorWitness ring-dimension mismatch".to_string(), + )); + } + } + Ok(()) + } +} + +impl HachiDeserialize for FlatLabradorWitness { + fn deserialize_with_mode( + mut reader: R, + compress: Compress, + validate: Validate, + ) -> Result { + let num_rows = u32::deserialize_with_mode(&mut reader, compress, validate)? 
as usize; + let mut rows = Vec::with_capacity(num_rows); + for _ in 0..num_rows { + rows.push(FlatRingVec::deserialize_with_mode( + &mut reader, + compress, + validate, + )?); + } + let out = Self { rows }; + if matches!(validate, Validate::Yes) { + out.check()?; + } + Ok(out) + } +} + +impl HachiSerialize for FlatLabradorProof { + fn serialize_with_mode( + &self, + mut writer: W, + compress: Compress, + ) -> Result<(), SerializationError> { + (self.levels.len() as u32).serialize_with_mode(&mut writer, compress)?; + for level in &self.levels { + level.serialize_with_mode(&mut writer, compress)?; + } + self.final_opening_witness + .serialize_with_mode(&mut writer, compress) + } + + fn serialized_size(&self, compress: Compress) -> usize { + 4 + self + .levels + .iter() + .map(|level| level.serialized_size(compress)) + .sum::() + + self.final_opening_witness.serialized_size(compress) + } +} + +impl Valid for FlatLabradorProof { + fn check(&self) -> Result<(), SerializationError> { + let mut expected_ring_dim = self + .final_opening_witness + .rows + .first() + .map(FlatRingVec::ring_dim); + for level in &self.levels { + level.check()?; + if let Some(d) = expected_ring_dim { + if level.inner_opening_payload.ring_dim() != d { + return Err(SerializationError::InvalidData( + "FlatLabradorProof ring-dimension mismatch".to_string(), + )); + } + } else { + expected_ring_dim = Some(level.inner_opening_payload.ring_dim()); + } + } + self.final_opening_witness.check() + } +} + +impl HachiDeserialize for FlatLabradorProof { + fn deserialize_with_mode( + mut reader: R, + compress: Compress, + validate: Validate, + ) -> Result { + let num_levels = u32::deserialize_with_mode(&mut reader, compress, validate)? 
as usize; + let mut levels = Vec::with_capacity(num_levels); + for _ in 0..num_levels { + levels.push(FlatLabradorLevelProof::deserialize_with_mode( + &mut reader, + compress, + validate, + )?); + } + let final_opening_witness = + FlatLabradorWitness::deserialize_with_mode(&mut reader, compress, validate)?; + let out = Self { + levels, + final_opening_witness, + }; + if matches!(validate, Validate::Yes) { + out.check()?; + } + Ok(out) + } +} + +impl HachiSerialize for LabradorTail { + fn serialize_with_mode( + &self, + mut writer: W, + compress: Compress, + ) -> Result<(), SerializationError> { + self.labrador_proof + .serialize_with_mode(&mut writer, compress)?; + self.v.serialize_with_mode(&mut writer, compress)?; + self.y_ring.serialize_with_mode(&mut writer, compress)?; + self.witness_norm_bound_sq + .serialize_with_mode(&mut writer, compress) + } + + fn serialized_size(&self, compress: Compress) -> usize { + self.labrador_proof.serialized_size(compress) + + self.v.serialized_size(compress) + + self.y_ring.serialized_size(compress) + + self.witness_norm_bound_sq.serialized_size(compress) + } +} + +impl Valid for LabradorTail { + fn check(&self) -> Result<(), SerializationError> { + self.labrador_proof.check()?; + self.v.check()?; + self.y_ring.check()?; + if self.v.ring_dim() != self.y_ring.ring_dim() { + return Err(SerializationError::InvalidData( + "LabradorTail ring-dimension mismatch".to_string(), + )); + } + if self + .labrador_proof + .final_opening_witness + .rows + .first() + .is_some_and(|row| row.ring_dim() != self.v.ring_dim()) + { + return Err(SerializationError::InvalidData( + "LabradorTail witness ring-dimension mismatch".to_string(), + )); + } + for level in &self.labrador_proof.levels { + if level.inner_opening_payload.ring_dim() != self.v.ring_dim() { + return Err(SerializationError::InvalidData( + "LabradorTail level ring-dimension mismatch".to_string(), + )); + } + } + Ok(()) + } +} + +impl HachiDeserialize for LabradorTail { + fn 
deserialize_with_mode( + mut reader: R, + compress: Compress, + validate: Validate, + ) -> Result { + let out = Self { + labrador_proof: FlatLabradorProof::deserialize_with_mode( + &mut reader, + compress, + validate, + )?, + v: FlatRingVec::deserialize_with_mode(&mut reader, compress, validate)?, + y_ring: FlatRingVec::deserialize_with_mode(&mut reader, compress, validate)?, + witness_norm_bound_sq: u128::deserialize_with_mode(&mut reader, compress, validate)?, + }; + if matches!(validate, Validate::Yes) { + out.check()?; + } + Ok(out) + } +} + +impl HachiSerialize for HachiCommitmentHint { + fn serialize_with_mode( + &self, + mut writer: W, + compress: Compress, + ) -> Result<(), SerializationError> { + (self.inner_opening_digits.len() as u64).serialize_with_mode(&mut writer, compress)?; + for block in &self.inner_opening_digits { + (block.len() as u64).serialize_with_mode(&mut writer, compress)?; + for plane in block { + let bytes: &[u8] = + unsafe { std::slice::from_raw_parts(plane.as_ptr().cast::(), D) }; + writer.write_all(bytes)?; + } + } + Ok(()) + } + fn serialized_size(&self, _compress: Compress) -> usize { + 8 + self + .inner_opening_digits + .iter() + .map(|block| 8 + block.len() * D) + .sum::() + } +} + +impl Valid for HachiCommitmentHint { + fn check(&self) -> Result<(), SerializationError> { + Ok(()) + } +} + +impl HachiDeserialize for HachiCommitmentHint { + fn deserialize_with_mode( + mut reader: R, + compress: Compress, + validate: Validate, + ) -> Result { + let num_blocks = u64::deserialize_with_mode(&mut reader, compress, validate)? as usize; + let mut inner_opening_digits = Vec::with_capacity(num_blocks); + for _ in 0..num_blocks { + let block_len = u64::deserialize_with_mode(&mut reader, compress, validate)? 
as usize; + let mut block = Vec::with_capacity(block_len); + for _ in 0..block_len { + let mut plane = [0i8; D]; + let bytes: &mut [u8] = + unsafe { std::slice::from_raw_parts_mut(plane.as_mut_ptr().cast::(), D) }; + reader.read_exact(bytes)?; + block.push(plane); + } + inner_opening_digits.push(block); + } + Ok(Self::new(inner_opening_digits)) + } +} + +impl HachiSerialize for HachiLevelProof { + fn serialize_with_mode( + &self, + mut writer: W, + compress: Compress, + ) -> Result<(), SerializationError> { + self.y_ring.serialize_with_mode(&mut writer, compress)?; + self.v.serialize_with_mode(&mut writer, compress)?; + self.stage1 + .sumcheck + .serialize_with_mode(&mut writer, compress)?; + self.stage1 + .s_claim + .serialize_with_mode(&mut writer, compress)?; + self.stage2 + .sumcheck + .serialize_with_mode(&mut writer, compress)?; + self.stage2 + .next_w_commitment + .serialize_with_mode(&mut writer, compress)?; + self.stage2 + .next_w_eval + .serialize_with_mode(&mut writer, compress) + } + fn serialized_size(&self, compress: Compress) -> usize { + self.y_ring.serialized_size(compress) + + self.v.serialized_size(compress) + + self.stage1.sumcheck.serialized_size(compress) + + self.stage1.s_claim.serialized_size(compress) + + self.stage2.sumcheck.serialized_size(compress) + + self.stage2.next_w_commitment.serialized_size(compress) + + self.stage2.next_w_eval.serialized_size(compress) + } +} + +impl Valid for HachiLevelProof { + fn check(&self) -> Result<(), SerializationError> { + self.y_ring.check()?; + if self.y_ring.count() != 1 { + return Err(SerializationError::InvalidData( + "hachi level y_ring must contain exactly one ring element".to_string(), + )); + } + self.v.check()?; + if self.v.ring_dim() != self.y_ring.ring_dim() { + return Err(SerializationError::InvalidData( + "hachi level v ring dimension must match y_ring".to_string(), + )); + } + self.stage1.sumcheck.check()?; + self.stage1.s_claim.check()?; + self.stage2.sumcheck.check()?; + 
self.stage2.next_w_commitment.check()?; + self.stage2.next_w_eval.check() + } +} + +impl HachiDeserialize for HachiLevelProof { + fn deserialize_with_mode( + mut reader: R, + compress: Compress, + validate: Validate, + ) -> Result { + let y_ring = FlatRingVec::deserialize_with_mode(&mut reader, compress, validate)?; + let v = FlatRingVec::deserialize_with_mode(&mut reader, compress, validate)?; + let stage1_sumcheck = + SumcheckProof::deserialize_with_mode(&mut reader, compress, validate)?; + let stage1_s_claim = F::deserialize_with_mode(&mut reader, compress, validate)?; + let out = Self { + y_ring, + v, + stage1: HachiStage1Proof { + sumcheck: stage1_sumcheck, + s_claim: stage1_s_claim, + }, + stage2: HachiStage2Proof { + sumcheck: SumcheckProof::deserialize_with_mode(&mut reader, compress, validate)?, + next_w_commitment: FlatRingVec::deserialize_with_mode( + &mut reader, + compress, + validate, + )?, + next_w_eval: F::deserialize_with_mode(&mut reader, compress, validate)?, + }, + }; + if matches!(validate, Validate::Yes) { + out.check()?; + } + Ok(out) + } +} + +impl HachiSerialize for HachiProof { + fn serialize_with_mode( + &self, + mut writer: W, + compress: Compress, + ) -> Result<(), SerializationError> { + (self.levels.len() as u32).serialize_with_mode(&mut writer, compress)?; + for level in &self.levels { + level.serialize_with_mode(&mut writer, compress)?; + } + match &self.tail { + HachiProofTail::Direct(pw) => { + 0u8.serialize_with_mode(&mut writer, compress)?; + pw.serialize_with_mode(&mut writer, compress) + } + HachiProofTail::Labrador(tail) => { + 1u8.serialize_with_mode(&mut writer, compress)?; + tail.serialize_with_mode(&mut writer, compress) + } + } + } + fn serialized_size(&self, compress: Compress) -> usize { + let base = 4 + + self + .levels + .iter() + .map(|l| l.serialized_size(compress)) + .sum::() + + 1; // tag byte + match &self.tail { + HachiProofTail::Direct(pw) => base + pw.serialized_size(compress), + 
HachiProofTail::Labrador(tail) => base + tail.serialized_size(compress), + } + } +} + +impl Valid for HachiProof { + fn check(&self) -> Result<(), SerializationError> { + for lp in &self.levels { + lp.check()?; + } + for levels in self.levels.windows(2) { + if levels[0].w_commit_d() != levels[1].level_d() { + return Err(SerializationError::InvalidData( + "adjacent hachi levels have mismatched commitment dimensions".to_string(), + )); + } + } + match &self.tail { + HachiProofTail::Direct(pw) => pw.check(), + HachiProofTail::Labrador(tail) => tail.check(), + } + } +} + +impl HachiDeserialize for HachiProof { + fn deserialize_with_mode( + mut reader: R, + compress: Compress, + validate: Validate, + ) -> Result { + let num_levels = u32::deserialize_with_mode(&mut reader, compress, validate)? as usize; + let mut levels = Vec::with_capacity(num_levels); + for _ in 0..num_levels { + levels.push(HachiLevelProof::deserialize_with_mode( + &mut reader, + compress, + validate, + )?); + } + let tag = u8::deserialize_with_mode(&mut reader, compress, validate)?; + let tail = match tag { + 0 => { + let pw = PackedDigits::deserialize_with_mode(&mut reader, compress, validate)?; + HachiProofTail::Direct(pw) + } + 1 => HachiProofTail::Labrador(Box::new(LabradorTail::deserialize_with_mode( + &mut reader, + compress, + validate, + )?)), + _ => { + return Err(SerializationError::InvalidData(format!( + "unknown proof tail tag: {tag}" + ))); + } + }; + let out = Self { levels, tail }; + if matches!(validate, Validate::Yes) { + out.check()?; + } + Ok(out) + } +} diff --git a/src/protocol/quadratic_equation.rs b/src/protocol/quadratic_equation.rs new file mode 100644 index 00000000..7fa892a7 --- /dev/null +++ b/src/protocol/quadratic_equation.rs @@ -0,0 +1,931 @@ +//! Quadratic equation builder for the Hachi PCS (§4.2). +//! +//! This module encapsulates the stage-1 prover logic and the generation of +//! the quadratic equation components M, y, z, and v. 
+ +use crate::algebra::{CyclotomicRing, SparseChallenge}; +#[cfg(any(test, debug_assertions))] +use crate::cfg_into_iter; +use crate::error::HachiError; +#[cfg(all(feature = "parallel", any(test, debug_assertions)))] +use crate::parallel::*; +use crate::protocol::challenges::sparse::sample_sparse_challenges; +use crate::protocol::commitment::utils::crt_ntt::NttSlotCache; +use crate::protocol::commitment::utils::linear::{ + flatten_i8_blocks, mat_vec_mul_ntt_single_i8, mat_vec_mul_ntt_single_i8_cyclic, + unreduced_quotient_rows_ntt_cached_centered_i32, +}; +use crate::protocol::commitment::{ + CommitmentConfig, HachiCommitmentLayout, HachiExpandedSetup, HachiLevelParams, RingCommitment, +}; +use crate::protocol::hachi_poly_ops::{DecomposeFoldWitness, HachiPolyOps}; +use crate::protocol::opening_point::RingOpeningPoint; +use crate::protocol::proof::HachiCommitmentHint; +#[cfg(any(test, debug_assertions))] +use crate::protocol::ring_switch::eval_ring_at; +use crate::protocol::transcript::labels::{ABSORB_PROVER_V, CHALLENGE_STAGE1_FOLD}; +use crate::protocol::transcript::Transcript; +use crate::{CanonicalField, FieldCore}; +use std::iter::repeat_n; +use std::marker::PhantomData; +use std::time::Instant; + +/// **Step 4.** Compute `v = D · ŵ` (first prover message). 
+fn compute_v( + ntt_d: &NttSlotCache, + w_hat_flat: &[[i8; D]], +) -> Vec> { + mat_vec_mul_ntt_single_i8(ntt_d, w_hat_flat) +} + +fn flatten_w_hat(w_hat: &[Vec<[i8; D]>]) -> Vec<[i8; D]> { + w_hat.iter().flat_map(|v| v.iter().copied()).collect() +} + +fn compute_z_pre( + poly: &P, + challenges: &[SparseChallenge], + level_params: HachiLevelParams, + layout: HachiCommitmentLayout, +) -> Result, HachiError> +where + F: FieldCore + CanonicalField, + P: HachiPolyOps, +{ + let z = poly.decompose_fold( + challenges, + layout.block_len, + layout.num_digits_commit, + layout.log_basis, + ); + + let norm = u128::from(z.centered_inf_norm); + let beta = crate::protocol::commitment::beta_linf_fold_bound( + layout.r_vars, + level_params.challenge_weight, + layout.log_basis, + )?; + if norm > beta { + return Err(HachiError::InvalidInput(format!( + "prover abort: ||z||_inf = {norm} > beta = {beta}" + ))); + } + + Ok(z) +} + +/// Stage-1 quadratic equation state for the Hachi protocol. +/// +/// Encapsulates the relation $M(x) \cdot z = y(x) + (X^D + 1) \cdot r(x)$ +/// along with intermediate prover witness data (`w_hat`, `z_pre`, `hint`). +/// +/// M and z are never materialized on the hot path — split-eq factoring computes +/// their products on-the-fly via `compute_r_split_eq`, while debug/test code +/// can reconstruct reference `M_a` rows when needed. +pub struct QuadraticEquation { + /// Stage-1 proof vector `v = D · ŵ`. + pub v: Vec>, + /// Stage-1 folding challenges (sparse representation). + pub challenges: Vec, + /// Vector `y`. + y: Vec>, + /// Opening point (a, b) Lagrange weights. + opening_point: RingOpeningPoint, + /// Pre-decomposition folded witness `z_pre = Σ c_i · s_i` (prover only). + /// Replaces both `z_hat` and `z`: `z_hat = J^{-1}(z_pre)`. + z_pre: Option>, + /// Decomposed `ŵ_i = G_1^{-1}(w_i)` as i8 digit planes (prover only). + w_hat: Option>>, + /// Flattened `w_hat` as i8 digit planes (prover only, computed once and reused). 
+ w_hat_flat: Option>, + /// Pre-decomposition folded ring elements (prover only, avoids recompose roundtrip). + w_folded: Option>>, + /// Commitment hint (prover only). + hint: Option>, + + _marker: PhantomData, +} + +impl QuadraticEquation +where + F: FieldCore + CanonicalField, + Cfg: CommitmentConfig, +{ + /// Prover constructor: runs §4.2 stage 1 and builds all equation components. + /// + /// `poly` provides the ring-level polynomial data for fold/decompose ops. + /// `hint` carries `t_hat` from the commitment phase. + /// + /// # Errors + /// + /// Returns an error if the norm check, challenge sampling, or matrix + /// generation fails. + #[allow(clippy::too_many_arguments)] + #[tracing::instrument(skip_all, name = "QuadraticEquation::new_prover")] + #[inline(never)] + pub fn new_prover, P: HachiPolyOps>( + ntt_d: &NttSlotCache, + ring_opening_point: RingOpeningPoint, + poly: &P, + pre_folded: Vec>, + level_params: HachiLevelParams, + mut hint: HachiCommitmentHint, + transcript: &mut T, + commitment: &RingCommitment, + y_ring: &CyclotomicRing, + layout: HachiCommitmentLayout, + ) -> Result { + { + let x: u8 = 0; + tracing::trace!( + stack_ptr = format_args!("{:#x}", &x as *const u8 as usize), + "QuadraticEquation::new_prover" + ); + } + let (w_hat, w_hat_flat) = { + let _span = tracing::info_span!("decompose_w_hat").entered(); + let depth_open = layout.num_digits_open; + let log_basis = layout.log_basis; + let w_hat: Vec> = pre_folded + .iter() + .map(|w_i| w_i.balanced_decompose_pow2_i8(depth_open, log_basis)) + .collect(); + let w_hat_flat = flatten_w_hat(&w_hat); + (w_hat, w_hat_flat) + }; + hint.ensure_t_recomposed(layout.num_digits_open, layout.log_basis)?; + + let v = { + let _span = + tracing::info_span!("compute_v", w_hat_flat_len = w_hat_flat.len()).entered(); + let mut v = compute_v(ntt_d, &w_hat_flat); + v.truncate(level_params.n_d); + v + }; + + transcript.append_serde(ABSORB_PROVER_V, &v); + + let challenge_cfg = 
Cfg::stage1_challenge_config(level_params); + let challenges = sample_sparse_challenges::( + transcript, + CHALLENGE_STAGE1_FOLD, + layout.num_blocks, + &challenge_cfg, + )?; + + let z_pre = { + let _span = tracing::info_span!("compute_z_pre").entered(); + compute_z_pre::(poly, &challenges, level_params, layout)? + }; + + let y = generate_y::( + &v, + &commitment.u, + y_ring, + level_params.n_d, + level_params.n_b, + level_params.n_a, + )?; + + Ok(Self { + v, + challenges, + y, + opening_point: ring_opening_point, + z_pre: Some(z_pre), + w_hat: Some(w_hat), + w_hat_flat: Some(w_hat_flat), + w_folded: Some(pre_folded), + hint: Some(hint), + _marker: PhantomData, + }) + } + + /// Verifier constructor: Derives challenges and computes M and y. + /// + /// # Errors + /// + /// Returns an error if challenge derivation fails. + #[tracing::instrument(skip_all, name = "QuadraticEquation::new_verifier")] + #[inline(never)] + pub fn new_verifier>( + ring_opening_point: RingOpeningPoint, + v: Vec>, + level_params: HachiLevelParams, + transcript: &mut T, + commitment: &RingCommitment, + y_ring: &CyclotomicRing, + layout: HachiCommitmentLayout, + ) -> Result { + let challenges = derive_stage1_challenges::( + transcript, + &v, + layout.num_blocks, + level_params, + )?; + let y = generate_y::( + &v, + &commitment.u, + y_ring, + level_params.n_d, + level_params.n_b, + level_params.n_a, + )?; + + Ok(Self { + v, + challenges, + y, + opening_point: ring_opening_point, + z_pre: None, + w_hat: None, + w_hat_flat: None, + w_folded: None, + hint: None, + _marker: PhantomData, + }) + } + + /// Get the vector y. + pub fn y(&self) -> &[CyclotomicRing] { + &self.y + } + + /// Get the vector v. + pub fn v(&self) -> &[CyclotomicRing] { + &self.v + } + + /// Get the opening point (a, b) Lagrange weights. + pub fn opening_point(&self) -> &RingOpeningPoint { + &self.opening_point + } + + /// Get the pre-decomposition folded witness `z_pre` (prover only). 
+ pub fn z_pre(&self) -> Option<&[CyclotomicRing]> { + self.z_pre.as_ref().map(|witness| witness.z_pre.as_slice()) + } + + /// Get centered coefficients for each `z_pre` row (prover only). + pub fn z_pre_centered(&self) -> Option<&[[i32; D]]> { + self.z_pre + .as_ref() + .map(|witness| witness.centered_coeffs.as_slice()) + } + + /// Get `||z_pre||_inf` from the centered witness representation. + pub fn z_pre_centered_inf_norm(&self) -> Option { + self.z_pre.as_ref().map(|witness| witness.centered_inf_norm) + } + + /// Take ownership of the `z_pre` witness, leaving `None` in its place. + pub fn take_z_pre(&mut self) -> Option> { + self.z_pre.take() + } + + /// Get the decomposed witness `ŵ` as i8 digit planes (prover only). + pub fn w_hat(&self) -> Option<&[Vec<[i8; D]>]> { + self.w_hat.as_deref() + } + + /// Get the pre-flattened `w_hat` as i8 digit planes (prover only). + pub fn w_hat_flat(&self) -> Option<&[[i8; D]]> { + self.w_hat_flat.as_deref() + } + + /// Take ownership of `w_hat`, leaving `None` in its place. + pub fn take_w_hat(&mut self) -> Option>> { + self.w_hat.take() + } + + /// Get the pre-decomposition folded ring elements (prover only). + pub fn w_folded(&self) -> Option<&[CyclotomicRing]> { + self.w_folded.as_deref() + } + + /// Get the commitment hint (prover only). + pub fn hint(&self) -> Option<&HachiCommitmentHint> { + self.hint.as_ref() + } + + /// Take ownership of the hint, leaving `None` in its place. 
+ pub fn take_hint(&mut self) -> Option> { + self.hint.take() + } +} + +pub(crate) fn derive_stage1_challenges( + transcript: &mut T, + v: &Vec>, + num_blocks: usize, + level_params: HachiLevelParams, +) -> Result, HachiError> +where + F: FieldCore + CanonicalField, + T: Transcript, +{ + let challenge_cfg = Cfg::stage1_challenge_config(level_params); + transcript.append_serde(ABSORB_PROVER_V, v); + sample_sparse_challenges::( + transcript, + CHALLENGE_STAGE1_FOLD, + num_blocks, + &challenge_cfg, + ) +} + +#[cfg(any(test, debug_assertions))] +fn gadget_row_scalars(levels: usize, log_basis: u32) -> Vec { + let base = F::from_canonical_u128_reduced(1u128 << log_basis); + let mut out = Vec::with_capacity(levels); + let mut power = F::one(); + for _ in 0..levels { + out.push(power); + power = power * base; + } + out +} + +/// Add only the high-half quotient contribution of `challenge * ring`. +fn add_sparse_ring_product_high_half( + quotient: &mut [F], + challenge: &SparseChallenge, + ring: &CyclotomicRing, +) { + let rc = ring.coefficients(); + for (&pos, &coeff) in challenge.positions.iter().zip(challenge.coeffs.iter()) { + let c = F::from_i64(coeff as i64); + let p = pos as usize; + let start = D.saturating_sub(p); + for (s, &r_s) in rc.iter().enumerate().skip(start) { + quotient[p + s - D] += c * r_s; + } + } +} + +fn quotient_from_cyclic_and_reduced( + cyclic: &CyclotomicRing, + reduced: &CyclotomicRing, +) -> CyclotomicRing { + let cyc_c = cyclic.coefficients(); + let red_c = reduced.coefficients(); + let quotient = std::array::from_fn(|k| (cyc_c[k] - red_c[k]) * F::TWO_INV); + CyclotomicRing::from_coefficients(quotient) +} + +/// Split-eq replacement for `generate_m` + `compute_r_via_poly_division`. +/// +/// Computes `r` such that `M·z = y + (X^D+1)·r` without materializing M or z. +/// Uses split-eq factoring: `kron(left, gadget) · decomposed = left · pre_decomp`. 
+#[allow(clippy::too_many_arguments, clippy::needless_borrow)] +#[tracing::instrument(skip_all, name = "compute_r_split_eq")] +#[allow(clippy::too_many_arguments)] +pub(crate) fn compute_r_split_eq( + level_params: HachiLevelParams, + _setup: &HachiExpandedSetup, + challenges: &[SparseChallenge], + w_hat_flat: &[[i8; D]], + t_hat: &[Vec<[i8; D]>], + t: &[Vec>], + w_folded: &[CyclotomicRing], + z_pre_centered: &[[i32; D]], + z_pre_centered_inf_norm: u32, + y: &[CyclotomicRing], + ntt_a: &NttSlotCache, + ntt_b: &NttSlotCache, + ntt_d: &NttSlotCache, +) -> Result>, HachiError> +where + F: FieldCore + CanonicalField, +{ + let num_rows = level_params.m_row_count(); + + let t_hat_flat = flatten_i8_blocks(t_hat); + + // D/B rows already know their reduced outputs `y`, so only the cyclic side + // must be computed here; quotient = (cyc - reduced) / 2. + let t_d = Instant::now(); + let d_cyclic = { + let _span = tracing::info_span!("D_rows_ntt").entered(); + mat_vec_mul_ntt_single_i8_cyclic(ntt_d, w_hat_flat) + }; + let d_time = t_d.elapsed().as_secs_f64(); + + let t_b = Instant::now(); + let b_cyclic = { + let _span = tracing::info_span!("B_rows_ntt").entered(); + mat_vec_mul_ntt_single_i8_cyclic(ntt_b, &t_hat_flat) + }; + let b_time = t_b.elapsed().as_secs_f64(); + + let t_a = Instant::now(); + let a_quotients = { + let _span = tracing::info_span!("A_rows_ntt").entered(); + unreduced_quotient_rows_ntt_cached_centered_i32( + ntt_a, + z_pre_centered, + z_pre_centered_inf_norm, + ) + }; + let a_time = t_a.elapsed().as_secs_f64(); + + let mut result = Vec::with_capacity(num_rows); + let mut other_time = 0.0f64; + let mut quotient_buf = vec![F::zero(); D]; + + for row_idx in 0..num_rows { + if row_idx < level_params.n_d { + result.push(quotient_from_cyclic_and_reduced( + &d_cyclic[row_idx], + &y[row_idx], + )); + } else if row_idx < level_params.n_d + level_params.n_b { + result.push(quotient_from_cyclic_and_reduced( + &b_cyclic[row_idx - level_params.n_d], + &y[row_idx], + 
)); + } else if row_idx >= level_params.n_d + level_params.n_b + 2 { + // A-rows: NTT-accelerated A*z_pre + sparse challenge terms + let t_row = Instant::now(); + let _span = tracing::info_span!("A_row").entered(); + let a_idx = row_idx - (level_params.n_d + level_params.n_b + 2); + + quotient_buf.fill(F::zero()); + for (i, t_rows_i) in t.iter().enumerate() { + if let Some(t_row_i) = t_rows_i.get(a_idx) { + add_sparse_ring_product_high_half(&mut quotient_buf, &challenges[i], t_row_i); + } + } + + let a_q = a_quotients[a_idx].coefficients(); + for k in 0..D { + quotient_buf[k] -= a_q[k]; + } + result.push(CyclotomicRing::from_slice("ient_buf)); + other_time += t_row.elapsed().as_secs_f64(); + } else { + let t_row = Instant::now(); + + if row_idx == level_params.n_d + level_params.n_b { + let _span = tracing::info_span!("bTw_row").entered(); + // `b^T · G · ŵ - y_ring` is degree < D, so its quotient is zero. + result.push(CyclotomicRing::::zero()); + } else { + let _span = tracing::info_span!("challenge_fold_row").entered(); + quotient_buf.fill(F::zero()); + for (i, w_f) in w_folded.iter().enumerate() { + add_sparse_ring_product_high_half(&mut quotient_buf, &challenges[i], w_f); + } + // `a^T · G · J · z_hat` contributes only low-degree terms, so it + // cannot affect the high-half quotient we need here. + result.push(CyclotomicRing::from_slice("ient_buf)); + } + other_time += t_row.elapsed().as_secs_f64(); + } + } + + tracing::debug!( + d_ntt_s = d_time, + b_ntt_s = b_time, + a_ntt_s = a_time, + other_s = other_time, + "compute_r breakdown" + ); + + Ok(result) +} + +/// Reference helper for tests/debug diagnostics: split-eq replacement for +/// `generate_m` + `eval_ring_matrix_at`. +/// +/// Computes the field-element evaluations of each M entry at `alpha`, +/// organized as rows of field elements, without materializing ring-valued `M`. 
+#[cfg(any(test, debug_assertions))] +#[tracing::instrument(skip_all, name = "compute_m_a_reference")] +pub(crate) fn compute_m_a_reference( + setup: &HachiExpandedSetup, + opening_point: &RingOpeningPoint, + challenges: &[SparseChallenge], + alpha: &F, + level_params: HachiLevelParams, + layout: HachiCommitmentLayout, +) -> Result>, HachiError> +where + F: FieldCore + CanonicalField, +{ + let depth_commit = layout.num_digits_commit; + let depth_open = layout.num_digits_open; + let depth_fold = layout.num_digits_fold; + let log_basis = layout.log_basis; + let num_blocks = opening_point.b.len(); + let block_len = layout.block_len; + let w_len = depth_open * num_blocks; + let t_len = depth_open * level_params.n_a * num_blocks; + let z_len = depth_fold * depth_commit * block_len; + let total_cols = w_len + t_len + z_len; + + let g1_open = gadget_row_scalars::(depth_open, log_basis); + let g1_commit = gadget_row_scalars::(depth_commit, log_basis); + let j1 = gadget_row_scalars::(depth_fold, log_basis); + + let c_alphas: Vec = challenges + .iter() + .map(|c| eval_ring_at(&c.to_dense::().expect("valid challenge"), alpha)) + .collect(); + + let d_view = setup.D_mat.view::(); + let b_view = setup.B.view::(); + + let d_rows: Vec> = cfg_into_iter!(0..d_view.num_rows()) + .map(|i| { + let d_row = d_view.row(i); + let mut full = vec![F::zero(); total_cols]; + for (j, ring) in d_row.iter().take(w_len).enumerate() { + full[j] = eval_ring_at(ring, alpha); + } + full + }) + .collect(); + + let b_rows: Vec> = cfg_into_iter!(0..b_view.num_rows()) + .map(|i| { + let b_row = b_view.row(i); + let mut full = vec![F::zero(); total_cols]; + for (j, ring) in b_row.iter().take(t_len).enumerate() { + full[w_len + j] = eval_ring_at(ring, alpha); + } + full + }) + .collect(); + + let mut rows = Vec::with_capacity(level_params.m_row_count()); + rows.extend(d_rows.into_iter().take(level_params.n_d)); + rows.extend(b_rows.into_iter().take(level_params.n_b)); + + // Row 3: b^T · G · ŵ = y_ring (ŵ 
uses delta_open) + { + let mut full = vec![F::zero(); total_cols]; + for (i, &b_i) in opening_point.b.iter().enumerate() { + for (d, &g) in g1_open.iter().enumerate() { + full[i * depth_open + d] = b_i * g; + } + } + rows.push(full); + } + + // Row 4: (c^T ⊗ G) · ŵ = a^T · G · J · ẑ + { + let mut full = vec![F::zero(); total_cols]; + for (i, &c_alpha) in c_alphas.iter().enumerate() { + for (d, &g) in g1_open.iter().enumerate() { + full[i * depth_open + d] = c_alpha * g; + } + } + let z_offset = w_len + t_len; + for (i, &a_i) in opening_point.a.iter().enumerate() { + for (d, &g) in g1_commit.iter().enumerate() { + let ag = a_i * g; + for (t, &j) in j1.iter().enumerate() { + let idx = (i * depth_commit + d) * depth_fold + t; + full[z_offset + idx] = -(ag * j); + } + } + } + rows.push(full); + } + + // Row 5: (c^T ⊗ G_open) · t̂ = A · J · ẑ + // t̂ uses delta_open (t = A*s has full-field coefficients); ẑ uses delta_commit + for a_idx in 0..level_params.n_a { + let mut full = vec![F::zero(); total_cols]; + for (i, &c_alpha) in c_alphas.iter().enumerate() { + for (d, &g) in g1_open.iter().enumerate() { + let t_idx = i * (level_params.n_a * depth_open) + a_idx * depth_open + d; + full[w_len + t_idx] = c_alpha * g; + } + } + let z_offset = w_len + t_len; + let a_view = setup.A.view::(); + let a_row = a_view.row(a_idx); + let inner_width = block_len * depth_commit; + for (k, ring) in a_row.iter().take(inner_width).enumerate() { + let ring_alpha = eval_ring_at(ring, alpha); + for (t, &j) in j1.iter().enumerate() { + full[z_offset + k * depth_fold + t] = -(ring_alpha * j); + } + } + rows.push(full); + } + + Ok(rows) +} + +pub(crate) fn generate_y( + v: &[CyclotomicRing], + u: &[CyclotomicRing], + u_eval: &CyclotomicRing, + n_d: usize, + n_b: usize, + n_a: usize, +) -> Result>, HachiError> +where + F: FieldCore, +{ + if v.len() != n_d { + return Err(HachiError::InvalidSize { + expected: n_d, + actual: v.len(), + }); + } + if u.len() != n_b { + return 
Err(HachiError::InvalidSize { + expected: n_b, + actual: u.len(), + }); + } + let mut out = Vec::with_capacity(n_d + n_b + 1 + 1 + n_a); + out.extend_from_slice(v); + out.extend_from_slice(u); + out.push(*u_eval); + out.push(CyclotomicRing::::zero()); + out.extend(repeat_n(CyclotomicRing::::zero(), n_a)); + Ok(out) +} + +#[cfg(test)] +mod tests { + use super::*; + use std::array::from_fn; + + use crate::algebra::{CyclotomicRing, SparseChallengeConfig}; + use crate::protocol::challenges::sparse::sample_sparse_challenges; + use crate::protocol::commitment::HachiProverSetup; + use crate::protocol::commitment::{ + HachiCommitmentCore, HachiScheduleInputs, RingCommitmentScheme, + }; + use crate::protocol::hachi_poly_ops::DensePoly; + use crate::protocol::proof::HachiCommitmentHint; + use crate::protocol::transcript::Blake2bTranscript; + use crate::test_utils::*; + use crate::FromSmallInt; + use crate::Transcript; + + const TRANSCRIPT_SEED: &[u8] = b"test/prover-relation"; + + fn replay_challenges(v: &Vec>) -> Vec> { + let mut transcript = Blake2bTranscript::::new(TRANSCRIPT_SEED); + transcript.append_serde(ABSORB_PROVER_V, v); + + let challenge_cfg = SparseChallengeConfig { + weight: TinyConfig::CHALLENGE_WEIGHT, + nonzero_coeffs: vec![-1, 1], + }; + let sparse = sample_sparse_challenges::, D>( + &mut transcript, + CHALLENGE_STAGE1_FOLD, + NUM_BLOCKS, + &challenge_cfg, + ) + .unwrap(); + sparse + .iter() + .map(|c| c.to_dense::().unwrap()) + .collect() + } + + struct Fixture { + setup: HachiProverSetup, + commitment_u: Vec>, + point: RingOpeningPoint, + blocks: Vec>>, + quad_eq: QuadraticEquation, + /// Challenges re-derived via transcript replay (cross-check). 
+ challenges: Vec>, + } + + fn build_fixture() -> Fixture { + let (setup, _) = + >::setup(16).unwrap(); + + let blocks = sample_blocks(); + let w = + >::commit_ring_blocks( + &blocks, &setup, + ) + .unwrap(); + + let point = RingOpeningPoint { + a: sample_a(), + b: sample_b(), + }; + + let ring_coeffs: Vec> = + blocks.iter().flat_map(|b| b.iter().copied()).collect(); + let poly = DensePoly::from_ring_coeffs(ring_coeffs); + let hint = HachiCommitmentHint::new(w.t_hat); + let mut transcript = Blake2bTranscript::::new(TRANSCRIPT_SEED); + let y_ring = CyclotomicRing::::zero(); + let layout = setup.layout(); + let w_folded = poly.fold_blocks(&point.a, layout.block_len); + let level_params = TinyConfig::level_params(HachiScheduleInputs { + max_num_vars: setup.expanded.seed.max_num_vars, + level: 0, + current_w_len: layout.num_blocks * layout.block_len * D, + }); + let quad_eq = QuadraticEquation::::new_prover( + &setup.ntt_D, + point.clone(), + &poly, + w_folded, + level_params, + hint, + &mut transcript, + &w.commitment, + &y_ring, + layout, + ) + .unwrap(); + + let challenges = replay_challenges(&quad_eq.v); + + Fixture { + setup, + commitment_u: w.commitment.u.clone(), + point, + blocks, + quad_eq, + challenges, + } + } + + fn i8_to_ring(digits: &[[i8; D]]) -> Vec> { + digits + .iter() + .map(|d| { + let coeffs: [F; D] = from_fn(|i| F::from_i64(d[i] as i64)); + CyclotomicRing::from_coefficients(coeffs) + }) + .collect() + } + + /// Row 1: D · ŵ = v + #[test] + fn row1_d_times_w_hat_equals_v() { + let f = build_fixture(); + + let w_hat = f.quad_eq.w_hat().unwrap(); + let w_hat_flat: Vec> = i8_to_ring( + &w_hat + .iter() + .flat_map(|v| v.iter().copied()) + .collect::>(), + ); + let lhs = mat_vec_mul(&f.setup.expanded.D_mat, &w_hat_flat); + + assert_eq!(lhs, f.quad_eq.v(), "Row 1 failed: D · ŵ ≠ v"); + } + + /// Row 2: B · inner opening digits = u (commitment vector) + #[test] + fn row2_b_times_inner_opening_digits_equals_u_commitment() { + let f = build_fixture(); + + 
let hint = f.quad_eq.hint().unwrap(); + let inner_opening_digits_flat_ring: Vec> = hint + .inner_opening_digits + .iter() + .flat_map(|v| v.iter()) + .map(|plane| { + let coeffs: [F; D] = from_fn(|k| F::from_i64(plane[k] as i64)); + CyclotomicRing::from_coefficients(coeffs) + }) + .collect(); + let lhs = mat_vec_mul(&f.setup.expanded.B, &inner_opening_digits_flat_ring); + + assert_eq!( + lhs, f.commitment_u, + "Row 2 failed: B · inner opening digits ≠ u" + ); + } + + /// Row 3: b^T · G_{2^r} · ŵ = u_eval + #[test] + fn row3_bt_gadget_w_hat_equals_u_eval() { + let f = build_fixture(); + + let w_hat = f.quad_eq.w_hat().unwrap(); + let w_recomposed: Vec> = w_hat + .iter() + .map(|w_hat_i| CyclotomicRing::gadget_recompose_pow2_i8(w_hat_i, log_basis())) + .collect(); + + let u_eval = w_recomposed + .iter() + .zip(f.point.b.iter()) + .fold(CyclotomicRing::::zero(), |acc, (w_i, b_i)| { + acc + w_i.scale(b_i) + }); + + let u_eval_direct = f.blocks.iter().zip(f.point.b.iter()).fold( + CyclotomicRing::::zero(), + |acc, (block_i, b_i)| { + let inner: CyclotomicRing = block_i + .iter() + .zip(f.point.a.iter()) + .fold(CyclotomicRing::::zero(), |acc2, (f_ij, a_j)| { + acc2 + f_ij.scale(a_j) + }); + acc + inner.scale(b_i) + }, + ); + + assert_eq!( + u_eval, u_eval_direct, + "Row 3 failed: b^T G ŵ ≠ Σ b_i (a^T f_i)" + ); + } + + /// Derive z_hat from z_pre for test assertions. 
+ fn derive_z_hat(z_pre: &[CyclotomicRing]) -> Vec> { + z_pre + .iter() + .flat_map(|z_j| z_j.balanced_decompose_pow2(num_digits_fold(), log_basis())) + .collect() + } + + /// Row 4: (c^T ⊗ G_1) · ŵ = a^T · G_{2^m} · J · ẑ + #[test] + fn row4_challenge_fold_w_equals_a_gadget_j_z_hat() { + let f = build_fixture(); + + let w_hat = f.quad_eq.w_hat().unwrap(); + let w: Vec> = w_hat + .iter() + .map(|w_hat_i| CyclotomicRing::gadget_recompose_pow2_i8(w_hat_i, log_basis())) + .collect(); + + let lhs = f + .challenges + .iter() + .zip(w.iter()) + .fold(CyclotomicRing::::zero(), |acc, (c_i, w_i)| { + acc + (*c_i * *w_i) + }); + + let z_hat = derive_z_hat(f.quad_eq.z_pre().unwrap()); + let z_recovered = recompose_z_hat(&z_hat); + let rhs = a_transpose_gadget_times_vec(&f.point.a, &z_recovered); + + assert_eq!(lhs, rhs, "Row 4 failed: (c^T ⊗ G_1)ŵ ≠ a^T G J ẑ"); + } + + /// Row 5: (c^T ⊗ G_{n_A}) · inner opening digits = A · J · ẑ + #[test] + fn row5_challenge_fold_inner_opening_digits_equals_a_j_z_hat() { + let f = build_fixture(); + + let hint = f.quad_eq.hint().unwrap(); + let mut lhs = vec![CyclotomicRing::::zero(); N_A]; + for (c_i, inner_opening_digits_i) in + f.challenges.iter().zip(hint.inner_opening_digits.iter()) + { + let t_i = gadget_recompose_vec_i8(inner_opening_digits_i); + assert_eq!(t_i.len(), N_A); + for (lhs_j, t_ij) in lhs.iter_mut().zip(t_i.iter()) { + *lhs_j += *c_i * *t_ij; + } + } + + let z_hat = derive_z_hat(f.quad_eq.z_pre().unwrap()); + let z_recovered = recompose_z_hat(&z_hat); + let rhs = mat_vec_mul(&f.setup.expanded.A, &z_recovered); + + assert_eq!( + lhs, rhs, + "Row 5 failed: (c^T ⊗ G_nA)inner opening digits ≠ A · J · ẑ" + ); + } + + #[test] + fn prove_output_shapes_are_correct() { + let f = build_fixture(); + + assert_eq!(f.quad_eq.v().len(), TinyConfig::N_D); + + let w_hat = f.quad_eq.w_hat().unwrap(); + assert_eq!(w_hat.len(), NUM_BLOCKS); + assert!(w_hat.iter().all(|v| v.len() == num_digits_open())); + + let hint = 
f.quad_eq.hint().unwrap(); + assert_eq!(hint.inner_opening_digits.len(), NUM_BLOCKS); + assert!(hint + .inner_opening_digits + .iter() + .all(|v| v.len() == N_A * num_digits_open())); + + assert_eq!( + f.quad_eq.z_pre().unwrap().len(), + BLOCK_LEN * num_digits_commit() + ); + } +} diff --git a/src/protocol/ring_switch.rs b/src/protocol/ring_switch.rs new file mode 100644 index 00000000..de16e3ae --- /dev/null +++ b/src/protocol/ring_switch.rs @@ -0,0 +1,1176 @@ +//! Ring switching logic for the Hachi PCS (Section 4.3). +//! +//! Handles the transition from the ring-based quadratic equation to field-based +//! sumcheck instances by expanding the ring elements into their coefficient +//! vectors and setting up the evaluation tables. + +use crate::algebra::{CyclotomicRing, SparseChallenge}; +use crate::error::HachiError; +#[cfg(feature = "parallel")] +use crate::parallel::*; +use crate::protocol::commitment::utils::crt_ntt::NttSlotCache; +use crate::protocol::commitment::utils::linear::{ + decompose_rows_i8, flatten_i8_blocks, mat_vec_mul_ntt_digits_i8, mat_vec_mul_ntt_i8, + mat_vec_mul_ntt_single_i8, +}; +use crate::protocol::commitment::utils::norm::detect_field_modulus; +use crate::protocol::commitment::{ + hachi_level_layout, CommitmentConfig, DecompositionParams, HachiCommitmentLayout, + HachiExpandedSetup, HachiLevelParams, HachiScheduleInputs, RingCommitment, +}; +use crate::protocol::opening_point::RingOpeningPoint; +use crate::protocol::proof::{DigitLut, FlatCommitmentHint, FlatRingVec, HachiCommitmentHint}; +use crate::protocol::quadratic_equation::{compute_r_split_eq, QuadraticEquation}; +use crate::protocol::sumcheck::eq_poly::EqPolynomial; +use crate::protocol::transcript::labels::{ + ABSORB_SUMCHECK_W, CHALLENGE_RING_SWITCH, CHALLENGE_TAU0, CHALLENGE_TAU1, +}; +use crate::protocol::transcript::Transcript; +use crate::{cfg_into_iter, cfg_iter}; +use crate::{CanonicalField, FieldCore, FieldSampling}; +#[cfg(test)] +use std::array::from_fn; +use 
std::marker::PhantomData; + +/// D-agnostic output of the ring switch protocol, containing everything +/// needed for sumchecks and level chaining. +pub struct RingSwitchOutput { + /// The witness vector w as balanced digits in `[-b/2, b/2)`. + pub w: Vec, + /// D-erased commitment to w. + pub w_commitment: FlatRingVec, + /// D-erased prover hint for the w-commitment. + pub w_hint: FlatCommitmentHint, + /// Compact evaluation table of w, stored as y-major slices of the live x prefix. + /// Populated by the prover; empty on the verifier side. + pub w_evals_compact: Vec, + /// Physical x width before zero-extension to the next power of two. + pub live_x_cols: usize, + /// Evaluation table of M_alpha(x) (tau1-weighted). + pub m_evals_x: Vec, + /// Evaluation table of alpha powers (y dimension). + pub alpha_evals_y: Vec, + /// Number of upper variable bits. + pub num_u: usize, + /// Number of lower variable bits. + pub num_l: usize, + /// Challenge tau0 for F_0 sumcheck. + pub tau0: Vec, + /// Challenge tau1 for F_alpha sumcheck. + pub tau1: Vec, + /// Basis size b = 2^LOG_BASIS. + pub b: usize, + /// Ring-switch challenge alpha. + pub alpha: F, +} + +/// Build the witness vector `w` from the quadratic equation state. +/// +/// This is the first half of the ring switch: it computes `r` and assembles +/// `w` as a flat `Vec`. The resulting `w` is D-agnostic and can be +/// committed at any ring dimension via [`commit_w`]. +/// +/// # Errors +/// +/// Returns an error if the quadratic equation is missing prover-side data. 
+#[tracing::instrument(skip_all, name = "ring_switch_build_w")] +#[allow(clippy::too_many_arguments)] +#[inline(never)] +pub fn ring_switch_build_w( + quad_eq: &mut QuadraticEquation, + setup: &HachiExpandedSetup, + ntt_a: &NttSlotCache, + ntt_b: &NttSlotCache, + ntt_d: &NttSlotCache, + level_params: HachiLevelParams, + layout: HachiCommitmentLayout, +) -> Result, HachiError> +where + F: FieldCore + CanonicalField + FieldSampling, + Cfg: CommitmentConfig, +{ + { + let x: u8 = 0; + tracing::trace!( + stack_ptr = format_args!("{:#x}", &x as *const u8 as usize), + "ring_switch_build_w" + ); + } + let w_hat = quad_eq + .w_hat() + .ok_or_else(|| HachiError::InvalidInput("missing w_hat in prover".to_string()))?; + let w_hat_flat = quad_eq + .w_hat_flat() + .ok_or_else(|| HachiError::InvalidInput("missing w_hat_flat in prover".to_string()))?; + let z_pre_centered = quad_eq + .z_pre_centered() + .ok_or_else(|| HachiError::InvalidInput("missing centered z_pre in prover".to_string()))?; + let z_pre_centered_inf_norm = quad_eq.z_pre_centered_inf_norm().ok_or_else(|| { + HachiError::InvalidInput("missing centered z_pre norm in prover".to_string()) + })?; + let hint = quad_eq + .hint() + .ok_or_else(|| HachiError::InvalidInput("missing hint in prover".to_string()))?; + let inner_opening_digits = &hint.inner_opening_digits; + let t = hint.t().ok_or_else(|| { + HachiError::InvalidInput("missing recomposed t in prover hint".to_string()) + })?; + let w_folded = quad_eq + .w_folded() + .ok_or_else(|| HachiError::InvalidInput("missing w_folded in prover".to_string()))?; + + let r = compute_r_split_eq::( + level_params, + setup, + &quad_eq.challenges, + w_hat_flat, + inner_opening_digits, + t, + w_folded, + z_pre_centered, + z_pre_centered_inf_norm, + quad_eq.y(), + ntt_a, + ntt_b, + ntt_d, + )?; + let w = { + let _span = tracing::info_span!("build_w_coeffs").entered(); + build_w_coeffs::(w_hat, inner_opening_digits, z_pre_centered, &r, layout) + }; + Ok(w) +} + +/// Complete the ring 
switch after `w` has been committed. +/// +/// Takes the already-committed `w` (with its D-erased commitment and hint) +/// and finishes the protocol: absorbs the commitment into the transcript, +/// samples challenges, and builds the evaluation tables for the fused sumcheck. +/// +/// Only the current level's `D` is needed (for M_alpha expansion and +/// alpha_evals_y). The commitment's ring dimension is encoded in the +/// `FlatRingVec` and does not require a separate const generic. +/// +/// # Errors +/// +/// Returns an error if matrix expansion or evaluation-table construction fails. +#[tracing::instrument(skip_all, name = "ring_switch_finalize")] +#[allow(clippy::too_many_arguments)] +#[inline(never)] +pub fn ring_switch_finalize( + quad_eq: &QuadraticEquation, + setup: &HachiExpandedSetup, + transcript: &mut T, + w: Vec, + w_commitment: FlatRingVec, + w_hint: FlatCommitmentHint, + level_params: HachiLevelParams, + layout: HachiCommitmentLayout, +) -> Result, HachiError> +where + F: FieldCore + CanonicalField + FieldSampling, + T: Transcript, + Cfg: CommitmentConfig, +{ + transcript.append_serde(ABSORB_SUMCHECK_W, &w_commitment); + + let alpha: F = transcript.challenge_scalar(CHALLENGE_RING_SWITCH); + + let num_l = D.trailing_zeros() as usize; + let num_ring_elems = w.len() / D; + let live_x_cols = num_ring_elems; + let num_u = num_ring_elems.next_power_of_two().trailing_zeros() as usize; + let m_rows = m_row_count(level_params); + let num_sc_vars = num_u + num_l; + let num_i = m_rows.next_power_of_two().trailing_zeros() as usize; + + let tau0 = sample_tau::(transcript, CHALLENGE_TAU0, num_sc_vars); + let tau1 = sample_tau::(transcript, CHALLENGE_TAU1, num_i); + let alpha_evals_y = build_alpha_evals_y(alpha, D); + + let opening_point = quad_eq.opening_point(); + let challenges = &quad_eq.challenges; + + #[cfg(feature = "parallel")] + let (m_evals_x_result, w_result) = rayon::join( + || { + compute_m_evals_x::( + setup, + opening_point, + challenges, + alpha, 
+ &alpha_evals_y, + level_params, + layout, + &tau1, + ) + }, + || build_w_evals_compact(&w, D), + ); + #[cfg(not(feature = "parallel"))] + let (m_evals_x_result, w_result) = { + let m_evals_x = compute_m_evals_x::( + setup, + opening_point, + challenges, + alpha, + &alpha_evals_y, + level_params, + layout, + &tau1, + )?; + let w_compact = build_w_evals_compact(&w, D); + (Ok(m_evals_x), w_compact) + }; + + let m_evals_x = m_evals_x_result?; + let (w_evals_compact, _, _) = w_result?; + + Ok(RingSwitchOutput { + w, + w_commitment, + w_hint, + w_evals_compact, + live_x_cols, + m_evals_x, + alpha_evals_y, + num_u, + num_l, + tau0, + tau1, + b: 1usize << layout.log_basis, + alpha, + }) +} + +/// Execute the prover side of the ring switching protocol (Section 4.3). +/// +/// Convenience wrapper that calls [`ring_switch_build_w`], [`commit_w`], and +/// [`ring_switch_finalize`] in sequence, all at the same ring dimension `D`. +/// +/// # Errors +/// +/// Returns an error if z_pre/w_hat is missing, commitment fails, or matrix expansion fails. 
+#[tracing::instrument(skip_all, name = "ring_switch_prover")] +#[allow(clippy::too_many_arguments)] +#[inline(never)] +pub fn ring_switch_prover( + quad_eq: &mut QuadraticEquation, + setup: &HachiExpandedSetup, + transcript: &mut T, + ntt_a: &NttSlotCache, + ntt_b: &NttSlotCache, + ntt_d: &NttSlotCache, + level_params: HachiLevelParams, + layout: HachiCommitmentLayout, +) -> Result, HachiError> +where + F: FieldCore + CanonicalField + FieldSampling, + T: Transcript, + Cfg: CommitmentConfig, +{ + let w = ring_switch_build_w::( + quad_eq, + setup, + ntt_a, + ntt_b, + ntt_d, + level_params, + layout, + )?; + + let (w_commitment, w_hint) = commit_w::(&w, ntt_a, ntt_b, level_params)?; + + let w_commitment_flat = FlatRingVec::from_commitment(&w_commitment); + let w_hint_flat = FlatCommitmentHint::from_typed(w_hint); + + ring_switch_finalize::( + quad_eq, + setup, + transcript, + w, + w_commitment_flat, + w_hint_flat, + level_params, + layout, + ) +} + +/// Replay the verifier side of ring switching to reconstruct evaluation tables. +/// +/// Takes the w-commitment as a [`FlatRingVec`] so the verifier does not need +/// to know D_COMMIT (the commitment's ring dimension). +/// +/// # Errors +/// +/// Returns an error if matrix expansion fails. 
+#[tracing::instrument(skip_all, name = "ring_switch_verifier")] +#[inline(never)] +pub fn ring_switch_verifier( + quad_eq: &QuadraticEquation, + setup: &HachiExpandedSetup, + w_len: usize, + w_commitment: &FlatRingVec, + transcript: &mut T, + level_params: HachiLevelParams, + layout: HachiCommitmentLayout, +) -> Result, HachiError> +where + F: FieldCore + CanonicalField + FieldSampling, + T: Transcript, + Cfg: CommitmentConfig, +{ + transcript.append_serde(ABSORB_SUMCHECK_W, w_commitment); + + let alpha: F = transcript.challenge_scalar(CHALLENGE_RING_SWITCH); + + let num_ring_elems = w_len / D; + let num_u = num_ring_elems.next_power_of_two().trailing_zeros() as usize; + let num_l = D.trailing_zeros() as usize; + let m_rows = m_row_count(level_params); + let num_sc_vars = num_u + num_l; + let num_i = m_rows.next_power_of_two().trailing_zeros() as usize; + + let tau0 = sample_tau::(transcript, CHALLENGE_TAU0, num_sc_vars); + let tau1 = sample_tau::(transcript, CHALLENGE_TAU1, num_i); + let alpha_evals_y = build_alpha_evals_y(alpha, D); + + let m_evals_x = compute_m_evals_x::( + setup, + quad_eq.opening_point(), + &quad_eq.challenges, + alpha, + &alpha_evals_y, + level_params, + layout, + &tau1, + )?; + + Ok(RingSwitchOutput { + w: Vec::new(), + w_commitment: w_commitment.clone(), + w_hint: FlatCommitmentHint::empty(), + w_evals_compact: Vec::new(), + live_x_cols: w_len / D, + m_evals_x, + alpha_evals_y, + num_u, + num_l, + tau0, + tau1, + b: 1usize << layout.log_basis, + alpha, + }) +} + +#[cfg(test)] +pub(crate) fn compute_r_via_poly_division( + m: &[Vec>], + z: &[CyclotomicRing], + y: &[CyclotomicRing], +) -> Result>, HachiError> { + let poly_len = 2 * D - 1; + let out = m + .iter() + .zip(y.iter()) + .map(|(row, y_i)| { + let column_contribution = + |m_ij: &CyclotomicRing, z_j: &CyclotomicRing| -> Vec { + let mut local = vec![F::zero(); poly_len]; + if m_ij.is_zero() { + return local; + } + let a = m_ij.coefficients(); + let b = z_j.coefficients(); + let 
is_scalar = a[1..].iter().all(|c| c.is_zero()); + if is_scalar { + let scalar = a[0]; + for s in 0..D { + local[s] = scalar * b[s]; + } + } else { + for t in 0..D { + for s in 0..D { + local[t + s] += a[t] * b[s]; + } + } + } + local + }; + + let pointwise_add = |mut a: Vec, b: Vec| -> Vec { + for (ai, bi) in a.iter_mut().zip(b.iter()) { + *ai += *bi; + } + a + }; + + #[cfg(feature = "parallel")] + let mut poly = row + .par_iter() + .zip(z.par_iter()) + .fold( + || vec![F::zero(); poly_len], + |acc, (m_ij, z_j)| pointwise_add(acc, column_contribution(m_ij, z_j)), + ) + .reduce(|| vec![F::zero(); poly_len], pointwise_add); + + #[cfg(not(feature = "parallel"))] + let mut poly = row + .iter() + .zip(z.iter()) + .fold(vec![F::zero(); poly_len], |acc, (m_ij, z_j)| { + pointwise_add(acc, column_contribution(m_ij, z_j)) + }); + let y_coeffs = y_i.coefficients(); + for k in 0..D { + poly[k] -= y_coeffs[k]; + } + let mut quotient = vec![F::zero(); D]; + for k in (D..poly_len).rev() { + let q = poly[k]; + quotient[k - D] = q; + poly[k - D] -= q; + } + let coeffs: [F; D] = from_fn(|k| quotient[k]); + CyclotomicRing::from_coefficients(coeffs) + }) + .collect(); + Ok(out) +} + +/// Derived commitment config for recursive w-openings. +/// +/// Sets `log_commit_bound = log_basis` (w's entries are balanced digits) and +/// `log_open_bound = parent's open bound` (opening folds produce full-field +/// coefficients). +/// +/// For `D=512, Cfg=Fp128FullCommitmentConfig`, this is equivalent to +/// [`Fp128LogBasisCommitmentConfig`](super::commitment::Fp128LogBasisCommitmentConfig). 
+#[derive(Clone, Copy, Debug)] +pub(crate) struct WCommitmentConfig { + _cfg: PhantomData, +} + +impl CommitmentConfig for WCommitmentConfig { + const D: usize = D; + const N_A: usize = Cfg::N_A; + const N_B: usize = Cfg::N_B; + const N_D: usize = Cfg::N_D; + const CHALLENGE_WEIGHT: usize = Cfg::CHALLENGE_WEIGHT; + + fn challenge_weight_for_ring_dim(d: usize) -> usize { + Cfg::challenge_weight_for_ring_dim(d) + } + + fn w_log_basis() -> u32 { + Cfg::w_log_basis() + } + + fn decomposition() -> DecompositionParams { + let parent = Cfg::decomposition(); + let w_basis = Cfg::w_log_basis(); + let parent_open = parent.log_open_bound.unwrap_or(parent.log_commit_bound); + DecompositionParams { + log_basis: w_basis, + // w entries come from a balanced decomposition; use w_basis for + // the commit bound since that's the widest digit range at any + // recursive level (level-0 entries fit in parent.log_basis <= w_basis). + log_commit_bound: w_basis, + // Opening folds w with arbitrary field-element weights, producing + // full-field-size coefficients that need the same decomposition + // depth as the parent's opening bound. + log_open_bound: Some(parent_open), + } + } + + fn commitment_layout(max_num_vars: usize) -> Result { + let current_w_len = 1usize << max_num_vars; + let (_, layout) = hachi_level_layout::(HachiScheduleInputs { + max_num_vars, + level: 1, + current_w_len, + })?; + Ok(layout) + } +} + +/// Total ring elements in the w polynomial, computed from the main layout. +/// +/// Components: w_hat + t_hat + decomposed z_pre + decomposed r. 
+pub(crate) fn w_ring_element_count( + level_params: HachiLevelParams, + layout: HachiCommitmentLayout, +) -> usize { + let w_hat_count = layout.num_blocks * layout.num_digits_open; + let t_hat_count = layout.num_blocks * level_params.n_a * layout.num_digits_open; + let z_pre_count = layout.inner_width * layout.num_digits_fold; + let r_count = m_row_count(level_params) * r_decomp_levels::(layout.log_basis); + w_hat_count + t_hat_count + z_pre_count + r_count +} + +/// Compute the w-commitment layout from the main layout. +pub(crate) fn w_commitment_layout( + level_params: HachiLevelParams, + main_layout: HachiCommitmentLayout, +) -> Result { + let total = w_ring_element_count::(level_params, main_layout) + .next_power_of_two() + .max(1); + let alpha = D.trailing_zeros() as usize; + let m_vars = total.trailing_zeros() as usize; + let max_num_vars = m_vars + alpha; + WCommitmentConfig::::commitment_layout(max_num_vars) +} + +/// Commit the witness vector `w` (D-agnostic `Vec`) into `D`-sized ring +/// elements and compute the ring commitment. +/// +/// This is the **D-boundary** in the protocol: the ring switch at level k +/// produces `w` using D_k operations, but `commit_w` re-chunks `w` into +/// D_{k+1}-sized ring elements and commits using D_{k+1} NTT caches. +/// +/// For constant-D configs, D_k = D_{k+1} = D and the distinction is moot. +/// +/// # Errors +/// +/// Returns an error if the commitment layout derivation or NTT mat-vec fails. 
+#[tracing::instrument(skip_all, name = "commit_w")] +#[inline(never)] +pub fn commit_w( + w: &[i8], + ntt_a: &NttSlotCache, + ntt_b: &NttSlotCache, + level_params: HachiLevelParams, +) -> Result<(RingCommitment, HachiCommitmentHint), HachiError> +where + F: FieldCore + CanonicalField + FieldSampling, + Cfg: CommitmentConfig, +{ + let (w_digits, remainder) = w.as_chunks::(); + if !remainder.is_empty() { + return Err(HachiError::InvalidSize { + expected: D, + actual: w.len(), + }); + } + + let total = w_digits.len().next_power_of_two().max(1); + let alpha = D.trailing_zeros() as usize; + let m_vars_total = total.trailing_zeros() as usize; + let max_num_vars = m_vars_total + alpha; + let w_layout = WCommitmentConfig::::commitment_layout(max_num_vars)?; + + let num_blocks = w_layout.num_blocks; + let block_len = w_layout.block_len; + let depth_commit = w_layout.num_digits_commit; + let depth_open = w_layout.num_digits_open; + let log_basis = w_layout.log_basis; + let coeff_len = w_digits.len(); + + let t_all = if depth_commit == 1 { + // `build_w_coeffs` already emits balanced base-`2^log_basis` digits, so + // the recursive w-commitment can skip the field conversion and feed those + // planes directly into the tiled NTT mat-vec. 
+ let block_slices: Vec<&[[i8; D]]> = (0..num_blocks) + .map(|i| { + let start = i * block_len; + if start >= coeff_len { + &[] as &[[i8; D]] + } else { + &w_digits[start..(start + block_len).min(coeff_len)] + } + }) + .collect(); + mat_vec_mul_ntt_digits_i8(ntt_a, &block_slices) + } else { + let lut = DigitLut::::new(log_basis); + let ring_elems: Vec> = w_digits + .iter() + .map(|digit| { + let coeffs = std::array::from_fn(|k| lut.get(digit[k])); + CyclotomicRing::from_coefficients(coeffs) + }) + .collect(); + let block_slices: Vec<&[CyclotomicRing]> = (0..num_blocks) + .map(|i| { + let start = i * block_len; + if start >= coeff_len { + &[] as &[CyclotomicRing] + } else { + &ring_elems[start..(start + block_len).min(coeff_len)] + } + }) + .collect(); + mat_vec_mul_ntt_i8(ntt_a, &block_slices, depth_commit, log_basis) + }; + let t_hat_per_block: Vec> = cfg_iter!(t_all) + .map(|t_i| decompose_rows_i8(t_i, depth_open, log_basis)) + .collect(); + + let t_hat_flat = flatten_i8_blocks(&t_hat_per_block); + let mut u: Vec> = mat_vec_mul_ntt_single_i8(ntt_b, &t_hat_flat); + u.truncate(level_params.n_b); + let hint = HachiCommitmentHint::with_t(t_hat_per_block, t_all); + Ok((RingCommitment { u }, hint)) +} + +pub(crate) fn eval_ring_at(r: &CyclotomicRing, alpha: &F) -> F { + let mut acc = F::zero(); + let mut power = F::one(); + for coeff in r.coefficients() { + acc += *coeff * power; + power = power * *alpha; + } + acc +} + +#[inline] +fn eval_ring_at_pows( + r: &CyclotomicRing, + alpha_pows: &[F], +) -> F { + debug_assert_eq!(alpha_pows.len(), D); + r.coefficients() + .iter() + .zip(alpha_pows.iter()) + .fold(F::zero(), |acc, (coeff, alpha_pow)| { + acc + *coeff * *alpha_pow + }) +} + +#[inline] +fn eval_sparse_challenge_at_pows( + challenge: &SparseChallenge, + alpha_pows: &[F], +) -> Result { + if alpha_pows.len() != D { + return Err(HachiError::InvalidSize { + expected: D, + actual: alpha_pows.len(), + }); + } + + debug_assert_eq!(challenge.positions.len(), 
challenge.coeffs.len()); + + let mut acc = F::zero(); + for (&pos, &coeff) in challenge.positions.iter().zip(challenge.coeffs.iter()) { + let idx = pos as usize; + debug_assert!(idx < D); + debug_assert_ne!(coeff, 0); + acc += F::from_i64(coeff as i64) * alpha_pows[idx]; + } + Ok(acc) +} + +#[inline] +fn gadget_row_scalars(levels: usize, log_basis: u32) -> Vec { + let base = F::from_canonical_u128_reduced(1u128 << log_basis); + let mut out = Vec::with_capacity(levels); + let mut power = F::one(); + for _ in 0..levels { + out.push(power); + power = power * base; + } + out +} + +pub(crate) fn r_decomp_levels(log_basis: u32) -> usize { + let modulus = detect_field_modulus::(); + let bits = 128 - (modulus.saturating_sub(1)).leading_zeros() as usize; + let lb = log_basis as usize; + let mut levels = (bits + lb.saturating_sub(1)) / lb.max(1); + if levels == 0 { + levels = 1; + } + + let total_bits = levels * lb; + if total_bits <= bits { + let b = 1u128 << log_basis; + let half_q = modulus / 2; + let half_b_minus_1 = b / 2 - 1; + let b_minus_1 = b - 1; + let mut b_pow = 1u128; + for _ in 0..levels { + b_pow = b_pow.saturating_mul(b); + } + let max_positive = half_b_minus_1.saturating_mul((b_pow - 1) / b_minus_1); + if max_positive < half_q { + levels += 1; + } + } + + levels +} + +#[cfg(test)] +#[allow(dead_code)] +pub(crate) fn expand_m_a( + m_a: &[Vec], + alpha: F, + log_basis: u32, +) -> Result, HachiError> { + if m_a.is_empty() { + return Ok(Vec::new()); + } + let rows = m_a.len(); + let cols = m_a[0].len(); + if cols == 0 { + return Ok(vec![F::zero(); rows]); + } + for row in m_a.iter() { + if row.len() != cols { + return Err(HachiError::InvalidSize { + expected: cols, + actual: row.len(), + }); + } + } + + let levels = r_decomp_levels::(log_basis); + let total_cols = cols + .checked_add( + rows.checked_mul(levels) + .ok_or_else(|| HachiError::InvalidSetup("expanded M width overflow".to_string()))?, + ) + .ok_or_else(|| HachiError::InvalidSetup("expanded M width 
overflow".to_string()))?; + + let base = F::from_canonical_u128_reduced(1u128 << log_basis); + let mut gadget_row = Vec::with_capacity(levels); + let mut power = F::one(); + for _ in 0..levels { + gadget_row.push(power); + power = power * base; + } + + let mut alpha_pow = F::one(); + for _ in 0..D { + alpha_pow = alpha_pow * alpha; + } + let denom = alpha_pow + F::one(); + + let mut out = vec![F::zero(); rows * total_cols]; + for (i, m_a_row) in m_a.iter().enumerate() { + let row_start = i * total_cols; + out[row_start..row_start + cols].copy_from_slice(m_a_row); + let r_start = row_start + cols + i * levels; + for (j, g) in gadget_row.iter().enumerate() { + out[r_start + j] = -denom * *g; + } + } + Ok(out) +} + +/// # Errors +/// +/// Returns an error if `w.len()` is not a multiple of `d`. +pub(crate) fn build_w_evals( + w: &[F], + d: usize, +) -> Result<(Vec, usize, usize), HachiError> { + if d == 0 || w.len() % d != 0 { + return Err(HachiError::InvalidSize { + expected: d, + actual: w.len(), + }); + } + let num_l = d.trailing_zeros() as usize; + let num_ring_elems = w.len() / d; + let num_u = num_ring_elems.next_power_of_two().trailing_zeros() as usize; + let x_len = 1usize << num_u; + let n = x_len << num_l; + + let evals: Vec = cfg_into_iter!(0..n) + .map(|dst| { + let x = dst & (x_len - 1); + let y = dst >> num_u; + let src = y + (x << num_l); + if src < w.len() { + w[src] + } else { + F::zero() + } + }) + .collect(); + Ok((evals, num_u, num_l)) +} + +/// Produce the compact `Vec` eval table of `w` for the fused prover, +/// storing only the physical x prefix for each y slice. 
+pub(crate) fn build_w_evals_compact( + w: &[i8], + d: usize, +) -> Result<(Vec, usize, usize), HachiError> { + if d == 0 || w.len() % d != 0 { + return Err(HachiError::InvalidSize { + expected: d, + actual: w.len(), + }); + } + let num_l = d.trailing_zeros() as usize; + let live_x_cols = w.len() / d; + let num_u = live_x_cols.next_power_of_two().trailing_zeros() as usize; + + let mut compact = vec![0i8; w.len()]; + + #[cfg(feature = "parallel")] + compact + .par_chunks_mut(live_x_cols) + .enumerate() + .for_each(|(y, row)| { + for (x, dst) in row.iter_mut().enumerate() { + *dst = w[y + (x << num_l)]; + } + }); + + #[cfg(not(feature = "parallel"))] + for (y, row) in compact.chunks_mut(live_x_cols).enumerate() { + for (x, dst) in row.iter_mut().enumerate() { + *dst = w[y + (x << num_l)]; + } + } + Ok((compact, num_u, num_l)) +} + +pub(crate) fn m_row_count(level_params: HachiLevelParams) -> usize { + level_params.m_row_count() +} + +#[allow(clippy::too_many_arguments)] +pub(crate) fn compute_m_evals_x( + setup: &HachiExpandedSetup, + opening_point: &RingOpeningPoint, + challenges: &[SparseChallenge], + alpha: F, + alpha_pows: &[F], + level_params: HachiLevelParams, + layout: HachiCommitmentLayout, + tau1: &[F], +) -> Result, HachiError> { + if alpha_pows.len() != D { + return Err(HachiError::InvalidSize { + expected: D, + actual: alpha_pows.len(), + }); + } + + let depth_commit = layout.num_digits_commit; + let depth_open = layout.num_digits_open; + let depth_fold = layout.num_digits_fold; + let log_basis = layout.log_basis; + let num_blocks = opening_point.b.len(); + let block_len = layout.block_len; + let w_len = depth_open * num_blocks; + let t_len = depth_open * level_params.n_a * num_blocks; + let inner_width = block_len * depth_commit; + let z_len = depth_fold * inner_width; + let rows = m_row_count(level_params); + let levels = r_decomp_levels::(log_basis); + let total_cols = w_len + .checked_add(t_len) + .and_then(|cols| cols.checked_add(z_len)) + 
.and_then(|cols| cols.checked_add(rows.checked_mul(levels)?)) + .ok_or_else(|| HachiError::InvalidSetup("expanded M width overflow".to_string()))?; + + let eq_tau1 = EqPolynomial::evals(tau1); + if eq_tau1.len() < rows { + return Err(HachiError::InvalidSize { + expected: rows, + actual: eq_tau1.len(), + }); + } + + let g1_open = gadget_row_scalars::(depth_open, log_basis); + let g1_commit = gadget_row_scalars::(depth_commit, log_basis); + let fold_gadget = gadget_row_scalars::(depth_fold, log_basis); + let r_gadget = gadget_row_scalars::(levels, log_basis); + let x_len = total_cols.next_power_of_two(); + let mut out = Vec::with_capacity(x_len); + + let c_alphas: Vec = challenges + .iter() + .map(|challenge| eval_sparse_challenge_at_pows::(challenge, alpha_pows)) + .collect::>()?; + + let d_view = setup.D_mat.view::(); + let b_view = setup.B.view::(); + let a_view = setup.A.view::(); + + let row3_weight = eq_tau1[level_params.n_d + level_params.n_b]; + let row4_weight = eq_tau1[level_params.n_d + level_params.n_b + 1]; + let a_weights = &eq_tau1[(level_params.n_d + level_params.n_b + 2)..rows]; + + let w_segment: Vec = cfg_into_iter!(0..w_len) + .map(|x| { + let block_idx = x / depth_open; + let digit_idx = x % depth_open; + let mut acc = (row3_weight * opening_point.b[block_idx] + + row4_weight * c_alphas[block_idx]) + * g1_open[digit_idx]; + for (row_idx, eq_i) in eq_tau1.iter().enumerate().take(level_params.n_d) { + if !eq_i.is_zero() { + acc += *eq_i * eval_ring_at_pows(&d_view.row(row_idx)[x], alpha_pows); + } + } + acc + }) + .collect(); + out.extend(w_segment); + + let t_segment: Vec = cfg_into_iter!(0..t_len) + .map(|x| { + let block_idx = x / (level_params.n_a * depth_open); + let rem = x % (level_params.n_a * depth_open); + let a_idx = rem / depth_open; + let digit_idx = rem % depth_open; + let mut acc = a_weights[a_idx] * c_alphas[block_idx] * g1_open[digit_idx]; + for (row_idx, eq_i) in eq_tau1[level_params.n_d..(level_params.n_d + level_params.n_b)] + 
.iter() + .enumerate() + { + if !eq_i.is_zero() { + acc += *eq_i * eval_ring_at_pows(&b_view.row(row_idx)[x], alpha_pows); + } + } + acc + }) + .collect(); + out.extend(t_segment); + + let z_base: Vec = cfg_into_iter!(0..inner_width) + .map(|k| { + let block_idx = k / depth_commit; + let digit_idx = k % depth_commit; + let mut acc = row4_weight * opening_point.a[block_idx] * g1_commit[digit_idx]; + for (a_idx, eq_i) in a_weights.iter().enumerate() { + if !eq_i.is_zero() { + acc += *eq_i * eval_ring_at_pows(&a_view.row(a_idx)[k], alpha_pows); + } + } + acc + }) + .collect(); + + let z_segment: Vec = cfg_into_iter!(0..z_len) + .map(|idx| { + let k = idx / depth_fold; + let fold_idx = idx % depth_fold; + -(z_base[k] * fold_gadget[fold_idx]) + }) + .collect(); + out.extend(z_segment); + + let alpha_pow_d = alpha_pows[D - 1] * alpha; + let denom = alpha_pow_d + F::one(); + let r_tail_len = rows * levels; + let r_tail: Vec = cfg_into_iter!(0..r_tail_len) + .map(|idx| { + let row_idx = idx / levels; + let level_idx = idx % levels; + -(eq_tau1[row_idx] * denom * r_gadget[level_idx]) + }) + .collect(); + out.extend(r_tail); + out.resize(x_len, F::zero()); + Ok(out) +} + +pub(crate) fn build_alpha_evals_y(alpha: F, d: usize) -> Vec { + let mut out = vec![F::zero(); d]; + let mut power = F::one(); + for val in out.iter_mut() { + *val = power; + power = power * alpha; + } + out +} + +pub(crate) fn sample_tau>( + transcript: &mut T, + label: &[u8], + n: usize, +) -> Vec { + (0..n).map(|_| transcript.challenge_scalar(label)).collect() +} + +fn balanced_decompose_centered_i32_i8_into( + centered: &[i32; D], + out: &mut [[i8; D]], + log_basis: u32, +) { + let levels = out.len(); + assert!( + log_basis > 0 && log_basis <= 7, + "log_basis must be in 1..=7 for i8 output" + ); + assert!( + (levels as u32).saturating_mul(log_basis) <= 128 + log_basis, + "levels * log_basis must be <= 128 + log_basis" + ); + + let half_b = 1i128 << (log_basis - 1); + let b = half_b << 1; + let mask = b 
- 1; + + for coeff_idx in 0..D { + let mut c = centered[coeff_idx] as i128; + for plane in out.iter_mut() { + let d = c & mask; + let balanced = if d >= half_b { d - b } else { d }; + c = (c - balanced) >> log_basis; + plane[coeff_idx] = balanced as i8; + } + } +} + +pub(crate) fn build_w_coeffs( + w_hat: &[Vec<[i8; D]>], + t_hat: &[Vec<[i8; D]>], + z_pre_centered: &[[i32; D]], + r: &[CyclotomicRing], + layout: HachiCommitmentLayout, +) -> Vec { + let log_basis = layout.log_basis; + let num_digits_fold = layout.num_digits_fold; + let levels = r_decomp_levels::(log_basis); + + let t_hat_flat = t_hat.iter().flat_map(|v| v.iter()); + + let w_hat_planes: usize = w_hat.iter().map(|v| v.len()).sum(); + let t_hat_planes: usize = t_hat.iter().map(|v| v.len()).sum(); + let z_count = w_hat_planes + t_hat_planes + z_pre_centered.len() * num_digits_fold; + let r_hat_count = r.len() * levels; + tracing::debug!( + w_hat_planes, + t_hat_planes, + z_pre_elems = z_pre_centered.len(), + z_pre_planes = z_pre_centered.len() * num_digits_fold, + r_elems = r.len(), + r_planes = r_hat_count, + total_ring = z_count + r_hat_count, + total_field = (z_count + r_hat_count) * D, + "build_w_coeffs" + ); + let mut out = Vec::with_capacity((z_count + r_hat_count) * D); + let mut digit_scratch = vec![[0i8; D]; num_digits_fold.max(levels)]; + for block in w_hat { + for digits in block { + out.extend_from_slice(digits); + } + } + for digits in t_hat_flat { + out.extend_from_slice(digits); + } + for z_j in z_pre_centered { + let z_planes = &mut digit_scratch[..num_digits_fold]; + balanced_decompose_centered_i32_i8_into(z_j, z_planes, log_basis); + for plane in z_planes.iter() { + out.extend_from_slice(plane); + } + } + for ri in r { + let r_planes = &mut digit_scratch[..levels]; + ri.balanced_decompose_pow2_i8_into(r_planes, log_basis); + for plane in r_planes.iter() { + out.extend_from_slice(plane); + } + } + out +} + +#[cfg(test)] +mod tests { + use super::compute_r_via_poly_division; + use 
crate::algebra::{CyclotomicRing, Prime128M8M4M1M0}; + use std::array::from_fn; + + use crate::{FieldCore, FromSmallInt}; + + fn compute_r_schoolbook( + m: &[Vec>], + z: &[CyclotomicRing], + y: &[CyclotomicRing], + ) -> Vec> { + let poly_len = 2 * D - 1; + m.iter() + .zip(y.iter()) + .map(|(row, y_i)| { + let mut poly = vec![F::zero(); poly_len]; + for (m_ij, z_j) in row.iter().zip(z.iter()) { + if m_ij.is_zero() { + continue; + } + let a = m_ij.coefficients(); + let b = z_j.coefficients(); + let is_scalar = a[1..].iter().all(|c| c.is_zero()); + if is_scalar { + let scalar = a[0]; + for s in 0..D { + poly[s] += scalar * b[s]; + } + } else { + for t in 0..D { + for s in 0..D { + poly[t + s] += a[t] * b[s]; + } + } + } + } + let y_coeffs = y_i.coefficients(); + for k in 0..D { + poly[k] -= y_coeffs[k]; + } + let mut quotient = vec![F::zero(); D]; + for k in (D..poly_len).rev() { + let q = poly[k]; + quotient[k - D] = q; + poly[k - D] -= q; + } + let coeffs: [F; D] = from_fn(|k| quotient[k]); + CyclotomicRing::from_coefficients(coeffs) + }) + .collect() + } + + #[test] + fn compute_r_matches_schoolbook_reference() { + type F = Prime128M8M4M1M0; + const D: usize = 64; + + let m: Vec>> = (0..3) + .map(|i| { + (0..4) + .map(|j| { + if (i + j) % 3 == 0 { + let mut coeffs = [F::zero(); D]; + coeffs[0] = F::from_u64((i * 5 + j + 1) as u64); + CyclotomicRing::from_coefficients(coeffs) + } else { + let coeffs = from_fn(|k| { + F::from_u64((i as u64 * 1000 + j as u64 * 100 + k as u64 + 1) % 97) + }); + CyclotomicRing::from_coefficients(coeffs) + } + }) + .collect() + }) + .collect(); + let z: Vec> = (0..4) + .map(|j| { + let coeffs = from_fn(|k| F::from_u64((j as u64 * 37 + k as u64 + 5) % 89)); + CyclotomicRing::from_coefficients(coeffs) + }) + .collect(); + let y: Vec> = (0..3) + .map(|i| { + let coeffs = from_fn(|k| F::from_u64((i as u64 * 29 + k as u64 + 7) % 83)); + CyclotomicRing::from_coefficients(coeffs) + }) + .collect(); + + let expected = compute_r_schoolbook(&m, &z, 
&y); + let got = compute_r_via_poly_division::(&m, &z, &y) + .expect("ring-switch CRT+NTT path should dispatch for D=64"); + assert_eq!(got, expected); + } +} diff --git a/src/protocol/sumcheck/batched_sumcheck.rs b/src/protocol/sumcheck/batched_sumcheck.rs new file mode 100644 index 00000000..e3ee7fc0 --- /dev/null +++ b/src/protocol/sumcheck/batched_sumcheck.rs @@ -0,0 +1,347 @@ +//! Batched sumcheck protocol. +//! +//! Implements the standard technique for batching parallel sumchecks to reduce +//! verifier cost and proof size. +//! +//! For details, refer to Jim Posen's ["Perspectives on Sumcheck Batching"](https://hackmd.io/s/HyxaupAAA). +//! We do what they describe as "front-loaded" batch sumcheck. +//! +//! Adapted from Jolt's `BatchedSumcheck` implementation. + +use super::{SumcheckInstanceProver, SumcheckInstanceVerifier, SumcheckProof, UniPoly}; +use crate::error::HachiError; +use crate::protocol::transcript::labels; +use crate::protocol::transcript::Transcript; +use crate::{CanonicalField, FieldCore, FromSmallInt}; + +fn mul_pow_2(x: E, k: usize) -> E { + let mut result = x; + for _ in 0..k { + result = result + result; + } + result +} + +fn linear_combination(polys: &[UniPoly], coeffs: &[E]) -> UniPoly { + let max_len = polys.iter().map(|p| p.coeffs.len()).max().unwrap_or(0); + let mut result = vec![E::zero(); max_len]; + for (poly, coeff) in polys.iter().zip(coeffs.iter()) { + for (i, c) in poly.coeffs.iter().enumerate() { + result[i] += *c * *coeff; + } + } + UniPoly::from_coeffs(result) +} + +/// Verifier-side output of the batched sumcheck round replay. +/// +/// This carries all transcript-derived values needed for the final oracle check, +/// which is intentionally split out so callers can compute the expected output +/// claim through an external reduction (e.g. Greyhound) before enforcing +/// equality. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct BatchedSumcheckRoundResult { + /// Final claim produced by replaying all sumcheck rounds. 
+ pub output_claim: E, + /// Challenge vector sampled during replay. + pub r_sumcheck: Vec, + /// Front-loaded batching coefficient per verifier instance. + pub batching_coeffs: Vec, + /// Maximum number of rounds among batched instances. + pub max_num_rounds: usize, +} + +/// Produce a batched sumcheck proof for multiple instances sharing the same +/// variable space, driving the Fiat–Shamir transcript. +/// +/// This function: +/// - absorbs each instance's initial claim, +/// - samples batching coefficients (one per instance), +/// - computes a single batched round polynomial per round as a linear +/// combination of the individual round polynomials, +/// - returns a single [`SumcheckProof`] and the derived challenge vector. +/// +/// Instances with fewer rounds than the maximum are padded with constant +/// "dummy" round polynomials (the Jolt "front-loaded" approach). +/// +/// # Panics +/// +/// Panics if `instances` is empty or if 2 is not invertible in the field. +/// +/// # Errors +/// +/// Returns an error if the field inverse of 2 does not exist. +#[tracing::instrument(skip_all, name = "prove_batched_sumcheck")] +pub fn prove_batched_sumcheck( + mut instances: Vec<&mut dyn SumcheckInstanceProver>, + transcript: &mut T, + mut sample_challenge: S, +) -> Result<(SumcheckProof, Vec), HachiError> +where + F: FieldCore + CanonicalField, + T: Transcript, + E: FieldCore + FromSmallInt, + S: FnMut(&mut T) -> E, +{ + if instances.is_empty() { + return Err(HachiError::InvalidInput( + "no sumcheck instances provided".into(), + )); + } + + let max_num_rounds = instances + .iter() + .map(|inst| inst.num_rounds()) + .max() + .unwrap(); // safe: non-empty checked above + + // Absorb individual input claims. + for inst in instances.iter() { + let claim = inst.input_claim(); + transcript.append_serde(labels::ABSORB_SUMCHECK_CLAIM, &claim); + } + + // Sample one batching coefficient per instance. 
+ let batching_coeffs: Vec = (0..instances.len()) + .map(|_| sample_challenge(transcript)) + .collect(); + + // To see why we may need to scale by a power of two, consider a batch of + // two sumchecks: + // claim_a = \sum_x P(x) where x \in {0, 1}^M + // claim_b = \sum_{x, y} Q(x, y) where x \in {0, 1}^M, y \in {0, 1}^N + // Then the batched sumcheck is: + // \sum_{x, y} A * P(x) + B * Q(x, y) where A and B are batching coefficients + // = A * \sum_y \sum_x P(x) + B * \sum_{x, y} Q(x, y) + // = A * \sum_y claim_a + B * claim_b + // = A * 2^N * claim_a + B * claim_b + let mut individual_claims: Vec = instances + .iter() + .map(|inst| { + let n = inst.num_rounds(); + let claim = inst.input_claim(); + mul_pow_2(claim, max_num_rounds - n) + }) + .collect(); + + let mut round_polys = Vec::with_capacity(max_num_rounds); + let mut challenges = Vec::with_capacity(max_num_rounds); + + for round in 0..max_num_rounds { + let univariate_polys: Vec> = instances + .iter_mut() + .zip(individual_claims.iter()) + .map(|(inst, previous_claim)| { + let n = inst.num_rounds(); + let offset = max_num_rounds - n; + let active = round >= offset && round < offset + n; + if active { + inst.compute_round_univariate(round - offset, *previous_claim) + } else { + UniPoly::from_coeffs(vec![*previous_claim * E::TWO_INV]) + } + }) + .collect(); + + let batched_poly = linear_combination(&univariate_polys, &batching_coeffs); + + #[cfg(debug_assertions)] + { + let g0 = batched_poly.evaluate(&E::zero()); + let g1 = batched_poly.evaluate(&E::one()); + let batched_claim: E = individual_claims + .iter() + .zip(batching_coeffs.iter()) + .map(|(c, b)| *c * *b) + .fold(E::zero(), |a, v| a + v); + debug_assert!( + g0 + g1 == batched_claim, + "round {round}: H(0) + H(1) != batched claim" + ); + } + + let compressed = batched_poly.compress(); + transcript.append_serde(labels::ABSORB_SUMCHECK_ROUND, &compressed); + let r_j = sample_challenge(transcript); + challenges.push(r_j); + + // Update individual claims 
from each instance's own univariate. + for (claim, poly) in individual_claims.iter_mut().zip(univariate_polys.iter()) { + *claim = poly.evaluate(&r_j); + } + + // Ingest challenge into each active instance. + for inst in instances.iter_mut() { + let n = inst.num_rounds(); + let offset = max_num_rounds - n; + let active = round >= offset && round < offset + n; + if active { + inst.ingest_challenge(round - offset, r_j); + } + } + + round_polys.push(compressed); + } + + for inst in instances.iter_mut() { + inst.finalize(); + } + + Ok((SumcheckProof { round_polys }, challenges)) +} + +/// Verify a batched sumcheck proof. +/// +/// This function: +/// - absorbs each verifier instance's initial claim, +/// - re-derives the batching coefficients, +/// - computes the batched initial claim, +/// - verifies the proof against the batched claim. +/// +/// Returns transcript-derived verifier data for the caller to perform the final +/// expected-output equality check. +/// +/// # Panics +/// +/// Panics if `verifiers` is empty. +/// +/// # Errors +/// +/// Propagates per-round verification errors. +pub fn verify_batched_sumcheck_rounds( + proof: &SumcheckProof, + verifiers: Vec<&dyn SumcheckInstanceVerifier>, + transcript: &mut T, + mut sample_challenge: S, +) -> Result, HachiError> +where + F: FieldCore + CanonicalField, + T: Transcript, + E: FieldCore, + S: FnMut(&mut T) -> E, +{ + if verifiers.is_empty() { + return Err(HachiError::InvalidInput( + "no sumcheck instances provided".into(), + )); + } + + let max_degree = verifiers.iter().map(|v| v.degree_bound()).max().unwrap(); // safe: non-empty + let max_num_rounds = verifiers.iter().map(|v| v.num_rounds()).max().unwrap(); // safe: non-empty + + // Absorb individual input claims. + for v in verifiers.iter() { + let claim = v.input_claim(); + transcript.append_serde(labels::ABSORB_SUMCHECK_CLAIM, &claim); + } + + // Re-derive batching coefficients. 
+ let batching_coeffs: Vec = (0..verifiers.len()) + .map(|_| sample_challenge(transcript)) + .collect(); + + // Compute the combined initial claim with power-of-two scaling. + let batched_claim: E = verifiers + .iter() + .zip(batching_coeffs.iter()) + .map(|(v, coeff)| { + let n = v.num_rounds(); + let claim = v.input_claim(); + mul_pow_2(claim, max_num_rounds - n) * *coeff + }) + .fold(E::zero(), |a, v| a + v); + + let (output_claim, r_sumcheck) = proof.verify::( + batched_claim, + max_num_rounds, + max_degree, + transcript, + &mut sample_challenge, + )?; + + Ok(BatchedSumcheckRoundResult { + output_claim, + r_sumcheck, + batching_coeffs, + max_num_rounds, + }) +} + +/// Compute the expected batched output claim from verifier instances and +/// transcript-derived batching data. +/// +/// # Errors +/// +/// Propagates errors from verifier `expected_output_claim` calls. +pub fn compute_batched_expected_output_claim( + verifiers: Vec<&dyn SumcheckInstanceVerifier>, + batching_coeffs: &[E], + max_num_rounds: usize, + r_sumcheck: &[E], +) -> Result { + let expected_output_claim: E = verifiers + .iter() + .zip(batching_coeffs.iter()) + .map(|(v, coeff)| { + let offset = max_num_rounds - v.num_rounds(); + let r_slice = &r_sumcheck[offset..offset + v.num_rounds()]; + v.expected_output_claim(r_slice).map(|val| val * *coeff) + }) + .try_fold(E::zero(), |a, v| v.map(|val| a + val))?; + + Ok(expected_output_claim) +} + +/// Enforce final batched output-claim equality. +/// +/// # Errors +/// +/// Returns an error if `output_claim != expected_output_claim`. +pub fn check_batched_output_claim( + output_claim: E, + expected_output_claim: E, +) -> Result<(), HachiError> { + if output_claim != expected_output_claim { + return Err(HachiError::InvalidProof); + } + + Ok(()) +} + +/// Verify a batched sumcheck proof, including final expected-output equality. +/// +/// This convenience wrapper preserves the previous behavior. 
Callers that need +/// to inject an external reduction should use [`verify_batched_sumcheck_rounds`] +/// and [`check_batched_output_claim`] directly. +/// +/// # Errors +/// +/// Propagates errors from round verification and output-claim equality check. +#[tracing::instrument(skip_all, name = "verify_batched_sumcheck")] +pub fn verify_batched_sumcheck( + proof: &SumcheckProof, + verifiers: Vec<&dyn SumcheckInstanceVerifier>, + transcript: &mut T, + mut sample_challenge: S, +) -> Result, HachiError> +where + F: FieldCore + CanonicalField, + T: Transcript, + E: FieldCore, + S: FnMut(&mut T) -> E, +{ + let round_result = verify_batched_sumcheck_rounds::( + proof, + verifiers.clone(), + transcript, + &mut sample_challenge, + )?; + let expected_output_claim = compute_batched_expected_output_claim( + verifiers, + &round_result.batching_coeffs, + round_result.max_num_rounds, + &round_result.r_sumcheck, + )?; + check_batched_output_claim(round_result.output_claim, expected_output_claim)?; + Ok(round_result.r_sumcheck) +} diff --git a/src/protocol/sumcheck/eq_poly.rs b/src/protocol/sumcheck/eq_poly.rs new file mode 100644 index 00000000..8b14b248 --- /dev/null +++ b/src/protocol/sumcheck/eq_poly.rs @@ -0,0 +1,229 @@ +//! Utilities for the equality polynomial `eq(x, y) = Πᵢ (xᵢ yᵢ + (1 − xᵢ)(1 − yᵢ))`. +//! +//! The equality polynomial evaluates to 1 when `x = y` (over the boolean hypercube) +//! and 0 otherwise. Its multilinear extension (MLE) is used throughout sumcheck +//! protocols. +//! +//! Adapted from Jolt's `EqPolynomial` implementation. +//! +//! ## Bit / index order: Little-endian +//! +//! The evaluation tables produced by this module use **little-endian** bit order: +//! entry `b` (as an integer index) corresponds to the boolean vector where +//! bit `k` of `b` equals `x[k]`. In other words, `r[0]` corresponds to the +//! **least-significant bit** (bit 0) and `r[n-1]` to the MSB. 
+ +use crate::FieldCore; +use std::marker::PhantomData; + +/// Utilities for the equality polynomial `eq(x, y) = Πᵢ (xᵢ yᵢ + (1 − xᵢ)(1 − yᵢ))`. +pub struct EqPolynomial(PhantomData); + +impl EqPolynomial { + /// Compute the MLE of the equality polynomial at two points: + /// `eq(x, y) = Πᵢ (xᵢ yᵢ + (1 − xᵢ)(1 − yᵢ))`. + /// + /// # Panics + /// + /// Panics if `x.len() != y.len()`. + pub fn mle(x: &[E], y: &[E]) -> E { + assert_eq!(x.len(), y.len()); + x.iter() + .zip(y.iter()) + .map(|(&x_i, &y_i)| x_i * y_i + (E::one() - x_i) * (E::one() - y_i)) + .fold(E::one(), |acc, v| acc * v) + } + + /// Compute the zero selector: `eq(r, 0) = Πᵢ (1 − rᵢ)`. + pub fn zero_selector(r: &[E]) -> E { + r.iter().fold(E::one(), |acc, &r_i| acc * (E::one() - r_i)) + } + + /// Compute the full evaluation table `{ eq(r, x) : x ∈ {0,1}^n }`. + /// + /// Uses **little-endian** bit order: entry `b` has bit `k` of `b` + /// corresponding to `r[k]`. + /// + /// For a scaled table, use [`Self::evals_with_scaling`]. + pub fn evals(r: &[E]) -> Vec { + Self::evals_with_scaling(r, None) + } + + /// Compute the full evaluation table with optional scaling: + /// `scaling_factor · eq(r, x)` for all `x ∈ {0,1}^n`. + /// + /// Uses the same **little-endian** index order as [`Self::evals`]. + /// If `scaling_factor` is `None`, defaults to 1 (no scaling). + pub fn evals_with_scaling(r: &[E], scaling_factor: Option) -> Vec { + #[cfg(feature = "parallel")] + { + const PARALLEL_THRESHOLD: usize = 16; + if r.len() > PARALLEL_THRESHOLD { + return Self::evals_parallel(r, scaling_factor); + } + } + Self::evals_serial(r, scaling_factor) + } + + /// Serial (single-threaded) version of [`Self::evals_with_scaling`]. + /// + /// Uses **little-endian** index order. 
+ pub fn evals_serial(r: &[E], scaling_factor: Option) -> Vec { + let size = 1usize << r.len(); + let mut evals = vec![E::zero(); size]; + evals[0] = scaling_factor.unwrap_or(E::one()); + let mut len = 1usize; + for &t in r.iter().rev() { + let one_minus_t = E::one() - t; + for j in (0..len).rev() { + evals[2 * j + 1] = evals[j] * t; + evals[2 * j] = evals[j] * one_minus_t; + } + len *= 2; + } + evals + } + + /// Compute eq evaluations and cache intermediate tables. + /// + /// Returns `result` where `result[j]` contains evaluations for the prefix + /// `r[..j]`: `result[j][x] = eq(r[..j], x)` for `x ∈ {0,1}^j`. + /// + /// So `result[0] = [1]`, `result[1]` has 2 entries, ..., and `result[n]` + /// equals [`Self::evals(r)`]. + pub fn evals_cached(r: &[E]) -> Vec> { + Self::evals_cached_with_scaling(r, None) + } + + /// Like [`Self::evals_cached`], but with optional scaling. + pub fn evals_cached_with_scaling(r: &[E], scaling_factor: Option) -> Vec> { + let mut result: Vec> = (0..r.len() + 1).map(|i| vec![E::zero(); 1 << i]).collect(); + result[0][0] = scaling_factor.unwrap_or(E::one()); + for j in 0..r.len() { + let idx = r.len() - 1 - j; + let t = r[idx]; + let one_minus_t = E::one() - t; + let prev_len = 1 << j; + for i in (0..prev_len).rev() { + result[j + 1][2 * i + 1] = result[j][i] * t; + result[j + 1][2 * i] = result[j][i] * one_minus_t; + } + } + result + } + + /// Parallel version of [`Self::evals_with_scaling`]. + /// + /// Uses rayon to compute the largest layers of the DP tree in parallel. + /// Uses the same **little-endian** index order as [`Self::evals`]. + #[cfg(feature = "parallel")] + pub fn evals_parallel(r: &[E], scaling_factor: Option) -> Vec { + use rayon::prelude::*; + + let final_size = 1usize << r.len(); + let mut evals = vec![E::zero(); final_size]; + evals[0] = scaling_factor.unwrap_or(E::one()); + let mut size = 1; + + // Forward iteration (r[0] first) produces little-endian ordering. 
+ for &r_i in r.iter() { + let (evals_left, evals_right) = evals.split_at_mut(size); + let (evals_right, _) = evals_right.split_at_mut(size); + + evals_left + .par_iter_mut() + .zip(evals_right.par_iter_mut()) + .for_each(|(x, y)| { + *y = *x * r_i; + *x -= *y; + }); + + size *= 2; + } + + evals + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::algebra::Fp64; + use crate::{FieldSampling, FromSmallInt}; + use rand::rngs::StdRng; + use rand::SeedableRng; + + type F = Fp64<4294967197>; + + #[test] + fn evals_matches_mle_pointwise() { + let mut rng = StdRng::seed_from_u64(0xEE); + for n in 1..8 { + let r: Vec = (0..n).map(|_| F::sample(&mut rng)).collect(); + let table = EqPolynomial::evals(&r); + assert_eq!(table.len(), 1 << n); + for (idx, &val) in table.iter().enumerate() { + let bits: Vec = (0..n) + .map(|k| { + if (idx >> k) & 1 == 1 { + F::one() + } else { + F::zero() + } + }) + .collect(); + let expected = EqPolynomial::mle(&r, &bits); + assert_eq!(val, expected, "n={n} idx={idx}"); + } + } + } + + #[test] + fn evals_with_scaling_scales_uniformly() { + let mut rng = StdRng::seed_from_u64(0xAB); + let r: Vec = (0..5).map(|_| F::sample(&mut rng)).collect(); + let scale = F::from_u64(7); + let unscaled = EqPolynomial::evals(&r); + let scaled = EqPolynomial::evals_with_scaling(&r, Some(scale)); + for (u, s) in unscaled.iter().zip(scaled.iter()) { + assert_eq!(*s, *u * scale); + } + } + + #[test] + fn evals_cached_last_matches_evals() { + let mut rng = StdRng::seed_from_u64(0xCD); + for n in 1..8 { + let r: Vec = (0..n).map(|_| F::sample(&mut rng)).collect(); + let table = EqPolynomial::evals(&r); + let cached = EqPolynomial::evals_cached(&r); + assert_eq!(cached.len(), n + 1); + assert_eq!(cached[0], vec![F::one()]); + assert_eq!(*cached.last().unwrap(), table); + } + } + + #[test] + fn zero_selector_matches_mle_at_origin() { + let mut rng = StdRng::seed_from_u64(0x00); + for n in 1..8 { + let r: Vec = (0..n).map(|_| F::sample(&mut rng)).collect(); + 
let zeros = vec![F::zero(); n]; + let expected = EqPolynomial::mle(&r, &zeros); + let actual = EqPolynomial::zero_selector(&r); + assert_eq!(actual, expected, "n={n}"); + } + } + + #[cfg(feature = "parallel")] + #[test] + fn evals_parallel_matches_serial() { + let mut rng = StdRng::seed_from_u64(0xFF); + for n in 1..20 { + let r: Vec = (0..n).map(|_| F::sample(&mut rng)).collect(); + let serial = EqPolynomial::evals_serial(&r, None); + let parallel = EqPolynomial::evals_parallel(&r, None); + assert_eq!(serial, parallel, "n={n}"); + } + } +} diff --git a/src/protocol/sumcheck/hachi_stage1.rs b/src/protocol/sumcheck/hachi_stage1.rs new file mode 100644 index 00000000..3a9bd281 --- /dev/null +++ b/src/protocol/sumcheck/hachi_stage1.rs @@ -0,0 +1,2238 @@ +//! Stage-1 norm sumcheck prover/verifier for the Hachi PCS. +//! +//! The committed witness is a Boolean table +//! `w : {0,1}^{num_u} x {0,1}^{num_l} -> {-half, ..., half-1}` with +//! `half = b/2`. Define the virtual table `S(z) = w(z) * (w(z) + 1)`. For an +//! honest witness every entry of `w` is a valid digit, so `S(z)` lies in the +//! set `{k(k+1) : k = 0, ..., half-1}`. The range-check polynomial +//! +//! `Q(s) = prod_{k=0}^{half-1} (s - k(k+1))` +//! +//! has degree `b/2` and vanishes on exactly that set. The sumcheck proves +//! +//! `0 = sum_z eq(tau0, z) * Q(S(z))`, +//! +//! where the input claim is `0` (an honest prover makes every summand vanish). +//! Each round polynomial has degree `b/2 + 1` — the product of degree-1 `eq` +//! and degree-`b/2` `Q`. After all rounds, at `r_stage1`, the verifier checks +//! +//! `eq(tau0, r_stage1) * Q(s_claim)` +//! +//! where `s_claim = S(r_stage1) = w(r_stage1) * (w(r_stage1) + 1)` is the +//! carried virtual claim passed into stage 2. +//! +//! ## `b = 8` specialization +//! +//! With `half = 4` the roots are `{0, 2, 6, 12}`, giving +//! +//! `Q(s) = s * (s - 2) * (s - 6) * (s - 12)`, +//! +//! degree 4, so round polynomials have degree 5. 
+ +use super::eq_poly::EqPolynomial; +use super::split_eq::GruenSplitEq; +use super::two_round_prefix::{ + build_stage1_bivariate_skip_proof_from_s_compact, can_use_stage1_two_round_prefix, + Stage1BivariateSkipState, +}; +use super::{ + fold_evals_in_place, trim_trailing_zeros, CompactPairFoldLut, SumcheckInstanceProver, + SumcheckInstanceVerifier, UniPoly, +}; +use crate::algebra::fields::HasUnreducedOps; +use crate::error::HachiError; +#[cfg(feature = "parallel")] +use crate::parallel::*; +use crate::{ + cfg_fold_reduce, cfg_into_iter, AdditiveGroup, CanonicalField, FieldCore, FromSmallInt, +}; +use std::time::Instant; + +const MAX_AFFINE_COEFFS: usize = 17; + +#[derive(Clone, Copy, Debug)] +struct SparseCoeffEntry { + k: u8, + abs_coeff: u64, + is_neg: bool, +} + +#[derive(Clone, Copy, Debug, Default)] +struct CompactCoeffEntry { + abs_coeff: u64, + is_neg: bool, +} + +fn poly_coeffs_from_roots_int(roots: &[i64]) -> Vec { + let mut coeffs = vec![1i64]; + for &root in roots { + let mut next = vec![0i64; coeffs.len() + 1]; + for (idx, &coeff) in coeffs.iter().enumerate() { + next[idx] -= coeff * root; + next[idx + 1] += coeff; + } + coeffs = next; + } + coeffs +} + +#[derive(Clone)] +struct RangeAffineFromSPrecomp { + sparse_entries: Vec, + sparse_row_offsets: Vec, + degree_q: usize, + /// `h_i(s_0)` for each valid `s_0` and coefficient index `i`. + /// Indexed as `compact_idx * num_rows + i`, where `compact_idx` is + /// obtained from `s_to_compact`. + small_s_lut: Vec, + compact_coeff_lut: Option>, + /// Maps raw `s` integer (offset by `min_s`) to a compact index into the + /// `b/2`-element valid-value set `{k(k+1) : k = 0..half-1}`. 
+ s_to_compact: Vec, + num_valid_s: usize, + min_s: i32, +} + +impl RangeAffineFromSPrecomp { + fn new(b: usize) -> Self { + assert!(b >= 2, "b must be at least 2"); + let half = (b / 2) as i64; + let pair_offsets: Vec = (0..half).map(|k| k * (k + 1)).collect(); + let range_coeffs = poly_coeffs_from_roots_int(&pair_offsets); + let degree_q = range_coeffs.len() - 1; + let num_rows = degree_q + 1; + + let total_elems = num_rows * (num_rows + 1) / 2; + let mut dense_int = Vec::with_capacity(total_elems); + let mut dense_row_offsets = Vec::with_capacity(num_rows + 1); + let mut sparse_entries = Vec::new(); + let mut sparse_row_offsets = Vec::with_capacity(num_rows + 1); + + for i in 0..num_rows { + dense_row_offsets.push(dense_int.len()); + sparse_row_offsets.push(sparse_entries.len()); + let row_len = degree_q - i + 1; + let mut binom: i64 = 1; + for k in 0..row_len { + let m = i + k; + let coeff = range_coeffs[m] * binom; + dense_int.push(coeff); + if coeff != 0 { + sparse_entries.push(SparseCoeffEntry { + k: k as u8, + abs_coeff: coeff.unsigned_abs(), + is_neg: coeff < 0, + }); + } + if k + 1 < row_len { + binom = binom * (m as i64 + 1) / (k as i64 + 1); + } + } + } + dense_row_offsets.push(dense_int.len()); + sparse_row_offsets.push(sparse_entries.len()); + + let min_s = 0i32; + let max_s = (half * (half - 1)) as i32; + let raw_range = (max_s - min_s + 1) as usize; + let num_valid_s = half as usize; + + let mut s_to_compact = vec![u8::MAX; raw_range]; + for (compact_idx, &s_val) in pair_offsets.iter().enumerate() { + s_to_compact[(s_val as i32 - min_s) as usize] = compact_idx as u8; + } + + let mut small_s_lut = vec![E::zero(); num_valid_s * num_rows]; + let mut small_s_lut_int = vec![0i128; num_valid_s * num_rows]; + for (compact_idx, &s_val) in pair_offsets.iter().enumerate() { + for i in 0..num_rows { + let row = &dense_int[dense_row_offsets[i]..dense_row_offsets[i + 1]]; + let mut h: i128 = 0; + for &c in row.iter().rev() { + h = h * s_val as i128 + c as i128; 
+ } + small_s_lut_int[compact_idx * num_rows + i] = h; + small_s_lut[compact_idx * num_rows + i] = E::from_i128(h); + } + } + + let compact_coeff_lut = if b <= 8 { + let mut lut = Vec::with_capacity(num_valid_s * num_valid_s * num_rows); + for (s0_ci, &s0_val) in pair_offsets.iter().enumerate() { + let h_base = s0_ci * num_rows; + for &s1_val in &pair_offsets { + let delta = (s1_val - s0_val) as i128; + let mut delta_pow = 1i128; + for &h_i in &small_s_lut_int[h_base..h_base + num_rows] { + let coeff = h_i + .checked_mul(delta_pow) + .expect("compact affine coefficient overflow"); + let abs_coeff = coeff.unsigned_abs(); + assert!( + abs_coeff <= u64::MAX as u128, + "compact affine coefficient exceeds u64" + ); + lut.push(CompactCoeffEntry { + abs_coeff: abs_coeff as u64, + is_neg: coeff < 0, + }); + delta_pow = delta_pow + .checked_mul(delta) + .expect("compact affine power overflow"); + } + } + } + Some(lut) + } else { + None + }; + + Self { + sparse_entries, + sparse_row_offsets, + degree_q, + small_s_lut, + compact_coeff_lut, + s_to_compact, + num_valid_s, + min_s, + } + } +} + +impl RangeAffineFromSPrecomp { + #[inline] + fn compact_index(&self, s_int: i32) -> usize { + let raw = (s_int - self.min_s) as usize; + debug_assert!(raw < self.s_to_compact.len()); + let ci = self.s_to_compact[raw]; + debug_assert_ne!(ci, u8::MAX, "s={s_int} is not a valid w*(w+1) value"); + ci as usize + } + + #[inline] + fn sparse_row(&self, i: usize) -> &[SparseCoeffEntry] { + &self.sparse_entries[self.sparse_row_offsets[i]..self.sparse_row_offsets[i + 1]] + } + + fn num_rows(&self) -> usize { + self.degree_q + 1 + } + + #[inline] + fn h_i_lut(&self, s_0_int: i32, i: usize) -> E { + let ci = self.compact_index(s_0_int); + self.small_s_lut[ci * self.num_rows() + i] + } + + #[inline] + fn compact_coeffs_lut(&self, s_0_int: i32, s_1_int: i32) -> Option<&[CompactCoeffEntry]> { + let lut = self.compact_coeff_lut.as_ref()?; + let num_rows = self.num_rows(); + let pair_idx = 
self.compact_index(s_0_int) * self.num_valid_s + self.compact_index(s_1_int); + let start = pair_idx * num_rows; + Some(&lut[start..start + num_rows]) + } +} + +#[inline] +pub(crate) fn range_check_eval_from_s(s: E, b: usize) -> E { + let half = (b / 2) as i64; + let mut acc = E::one(); + for k in 0..half { + acc = acc * (s - E::from_i64(k * (k + 1))); + } + acc +} + +#[inline] +fn accumulate_compact_coeff_slot( + pos_accum: &mut [E::MulU64Accum], + neg_accum: &mut [E::MulU64Accum], + slot: usize, + e_in: E, + coeff: &CompactCoeffEntry, +) { + if coeff.abs_coeff == 0 { + return; + } + let prod = e_in.mul_u64_unreduced(coeff.abs_coeff); + if coeff.is_neg { + neg_accum[slot] += prod; + } else { + pos_accum[slot] += prod; + } +} + +#[inline] +fn accumulate_compact_coeffs( + pos_accum: &mut [E::MulU64Accum], + neg_accum: &mut [E::MulU64Accum], + e_in: E, + coeffs: &[CompactCoeffEntry], + skip_linear_coeff: bool, +) { + debug_assert_eq!(pos_accum.len(), neg_accum.len()); + debug_assert!( + (!skip_linear_coeff && pos_accum.len() >= coeffs.len()) + || (skip_linear_coeff && pos_accum.len() + 1 >= coeffs.len()) + ); + if skip_linear_coeff { + if let Some(coeff) = coeffs.first() { + accumulate_compact_coeff_slot(pos_accum, neg_accum, 0, e_in, coeff); + } + for (coeff_idx, coeff) in coeffs.iter().enumerate().skip(2) { + accumulate_compact_coeff_slot(pos_accum, neg_accum, coeff_idx - 1, e_in, coeff); + } + return; + } + + for (idx, coeff) in coeffs.iter().enumerate().take(pos_accum.len()) { + accumulate_compact_coeff_slot(pos_accum, neg_accum, idx, e_in, coeff); + } +} + +#[inline] +fn reduce_small_coeff_accum( + pos: E::MulU64Accum, + neg: E::MulU64Accum, +) -> E { + E::reduce_mul_u64_accum(pos) - E::reduce_mul_u64_accum(neg) +} + +#[inline] +fn accumulate_dense_entry_coeffs( + accum: &mut [E::ProductAccum], + entry_coeffs: &[E], + e_in: E, + skip_linear_coeff: bool, +) { + if accum.is_empty() { + return; + } + + accum[0] += e_in.mul_to_product_accum(entry_coeffs[0]); + if 
skip_linear_coeff { + for (acc, &entry) in accum.iter_mut().skip(1).zip(entry_coeffs.iter().skip(2)) { + *acc += e_in.mul_to_product_accum(entry); + } + } else { + for (acc, &entry) in accum.iter_mut().skip(1).zip(entry_coeffs.iter().skip(1)) { + *acc += e_in.mul_to_product_accum(entry); + } + } +} + +#[inline] +fn finish_gruen_round_poly_from_q_coeffs( + split_eq: &GruenSplitEq, + mut q_coeffs: Vec, + previous_claim: E, + skip_linear_coeff: bool, +) -> UniPoly { + trim_trailing_zeros(&mut q_coeffs); + if skip_linear_coeff { + split_eq + .try_gruen_poly_from_coeffs_except_linear(&q_coeffs, previous_claim) + .expect("split-eq linear-term recovery should succeed") + } else { + let q_poly = UniPoly::from_coeffs(q_coeffs); + split_eq.gruen_mul(&q_poly) + } +} + +#[inline] +fn compute_entry_coeffs_from_s( + out: &mut [E], + s_pows: &mut [E], + precomp: &RangeAffineFromSPrecomp, + s_0: E, + a: E, +) { + let deg = precomp.degree_q; + let num_rows = precomp.num_rows(); + debug_assert!(out.len() >= num_rows); + debug_assert!(s_pows.len() > deg); + + s_pows[0] = E::one(); + for k in 1..=deg { + s_pows[k] = s_pows[k - 1] * s_0; + } + + let mut a_pow = E::one(); + for (i, out_i) in out.iter_mut().enumerate().take(num_rows) { + let entries = precomp.sparse_row(i); + let mut pos = E::MulU64Accum::ZERO; + let mut neg = E::MulU64Accum::ZERO; + for entry in entries { + let prod = s_pows[entry.k as usize].mul_u64_unreduced(entry.abs_coeff); + if entry.is_neg { + neg += prod; + } else { + pos += prod; + } + } + let h_i = E::reduce_mul_u64_accum(pos) - E::reduce_mul_u64_accum(neg); + *out_i = a_pow * h_i; + a_pow = a_pow * a; + } +} + +#[inline] +fn compute_entry_coeffs_from_s_x4( + out: &mut [[E; MAX_AFFINE_COEFFS]; 4], + precomp: &RangeAffineFromSPrecomp, + s_0: [E; 4], + a: [E; 4], +) { + let deg = precomp.degree_q; + let num_rows = precomp.num_rows(); + + let mut pw = [[E::zero(); MAX_AFFINE_COEFFS]; 4]; + for p in &mut pw { + p[0] = E::one(); + } + for k in 1..=deg { + pw[0][k] = 
pw[0][k - 1] * s_0[0]; + pw[1][k] = pw[1][k - 1] * s_0[1]; + pw[2][k] = pw[2][k - 1] * s_0[2]; + pw[3][k] = pw[3][k - 1] * s_0[3]; + } + + let mut ap = [E::one(); 4]; + for i in 0..num_rows { + let entries = precomp.sparse_row(i); + + let mut pos0 = E::MulU64Accum::ZERO; + let mut neg0 = E::MulU64Accum::ZERO; + let mut pos1 = E::MulU64Accum::ZERO; + let mut neg1 = E::MulU64Accum::ZERO; + let mut pos2 = E::MulU64Accum::ZERO; + let mut neg2 = E::MulU64Accum::ZERO; + let mut pos3 = E::MulU64Accum::ZERO; + let mut neg3 = E::MulU64Accum::ZERO; + + for entry in entries { + let k = entry.k as usize; + let c = entry.abs_coeff; + let p0 = pw[0][k].mul_u64_unreduced(c); + let p1 = pw[1][k].mul_u64_unreduced(c); + let p2 = pw[2][k].mul_u64_unreduced(c); + let p3 = pw[3][k].mul_u64_unreduced(c); + if entry.is_neg { + neg0 += p0; + neg1 += p1; + neg2 += p2; + neg3 += p3; + } else { + pos0 += p0; + pos1 += p1; + pos2 += p2; + pos3 += p3; + } + } + + let h0 = E::reduce_mul_u64_accum(pos0) - E::reduce_mul_u64_accum(neg0); + let h1 = E::reduce_mul_u64_accum(pos1) - E::reduce_mul_u64_accum(neg1); + let h2 = E::reduce_mul_u64_accum(pos2) - E::reduce_mul_u64_accum(neg2); + let h3 = E::reduce_mul_u64_accum(pos3) - E::reduce_mul_u64_accum(neg3); + + out[0][i] = ap[0] * h0; + out[1][i] = ap[1] * h1; + out[2][i] = ap[2] * h2; + out[3][i] = ap[3] * h3; + + ap[0] = ap[0] * a[0]; + ap[1] = ap[1] * a[1]; + ap[2] = ap[2] * a[2]; + ap[3] = ap[3] * a[3]; + } +} + +fn compute_norm_round_poly_from_s( + split_eq: &GruenSplitEq, + range_precomp: &RangeAffineFromSPrecomp, + previous_claim: E, + s_pair: impl Fn(usize) -> (E, E) + Sync, +) -> UniPoly { + let (e_first, e_second) = split_eq.remaining_eq_tables(); + let num_first = e_first.len(); + let rp = range_precomp; + let full_num_coeffs_q = rp.degree_q + 1; + let skip_linear_coeff = split_eq.can_recover_linear_q_term_from_claim(); + let num_coeffs_q = full_num_coeffs_q - usize::from(skip_linear_coeff); + + let q_coeffs = cfg_fold_reduce!( + 
0..e_second.len(), + || vec![E::ProductAccum::ZERO; num_coeffs_q], + |mut outer_accum, j_high| { + debug_assert!(full_num_coeffs_q <= MAX_AFFINE_COEFFS); + let mut inner_accum = [E::ProductAccum::ZERO; MAX_AFFINE_COEFFS]; + let base_j = j_high * num_first; + let full_chunks = e_first.len() / 4; + let mut batch_out = [[E::zero(); MAX_AFFINE_COEFFS]; 4]; + + for chunk in 0..full_chunks { + let jl = chunk * 4; + let pairs = [ + s_pair(base_j + jl), + s_pair(base_j + jl + 1), + s_pair(base_j + jl + 2), + s_pair(base_j + jl + 3), + ]; + compute_entry_coeffs_from_s_x4( + &mut batch_out, + rp, + [pairs[0].0, pairs[1].0, pairs[2].0, pairs[3].0], + [ + pairs[0].1 - pairs[0].0, + pairs[1].1 - pairs[1].0, + pairs[2].1 - pairs[2].0, + pairs[3].1 - pairs[3].0, + ], + ); + for (b_idx, bo) in batch_out.iter().enumerate() { + let e_in = e_first[jl + b_idx]; + accumulate_dense_entry_coeffs( + &mut inner_accum[..num_coeffs_q], + &bo[..full_num_coeffs_q], + e_in, + skip_linear_coeff, + ); + } + } + + let mut entry_buf = [E::zero(); MAX_AFFINE_COEFFS]; + let mut s_pows_buf = [E::zero(); MAX_AFFINE_COEFFS]; + for (tail_idx, &e_in) in e_first[full_chunks * 4..].iter().enumerate() { + let j = base_j + full_chunks * 4 + tail_idx; + let (s_0, s_1) = s_pair(j); + compute_entry_coeffs_from_s(&mut entry_buf, &mut s_pows_buf, rp, s_0, s_1 - s_0); + accumulate_dense_entry_coeffs( + &mut inner_accum[..num_coeffs_q], + &entry_buf[..full_num_coeffs_q], + e_in, + skip_linear_coeff, + ); + } + + let e_out = e_second[j_high]; + for k in 0..num_coeffs_q { + let inner_reduced = E::reduce_product_accum(inner_accum[k]); + outer_accum[k] += e_out.mul_to_product_accum(inner_reduced); + } + outer_accum + }, + |mut a, b_vec| { + for (ai, bi) in a.iter_mut().zip(b_vec.iter()) { + *ai += *bi; + } + a + } + ) + .into_iter() + .map(E::reduce_product_accum) + .collect::>(); + + finish_gruen_round_poly_from_q_coeffs(split_eq, q_coeffs, previous_claim, skip_linear_coeff) +} + +fn 
compute_norm_round_poly_from_s_compact< + E: FieldCore + FromSmallInt + CanonicalField + HasUnreducedOps, +>( + split_eq: &GruenSplitEq, + s_compact: &[i32], + range_precomp: &RangeAffineFromSPrecomp, + previous_claim: E, +) -> UniPoly { + let (e_first, e_second) = split_eq.remaining_eq_tables(); + let num_first = e_first.len(); + + let rp = range_precomp; + let full_num_coeffs_q = rp.degree_q + 1; + let skip_linear_coeff = split_eq.can_recover_linear_q_term_from_claim(); + let num_coeffs_q = full_num_coeffs_q - usize::from(skip_linear_coeff); + + let q_coeffs = if rp.compact_coeffs_lut(0, 0).is_some() { + cfg_fold_reduce!( + 0..e_second.len(), + || vec![E::ProductAccum::ZERO; num_coeffs_q], + |mut outer_accum, j_high| { + debug_assert!(full_num_coeffs_q <= MAX_AFFINE_COEFFS); + let mut inner_pos = [E::MulU64Accum::ZERO; MAX_AFFINE_COEFFS]; + let mut inner_neg = [E::MulU64Accum::ZERO; MAX_AFFINE_COEFFS]; + for (j_low, &e_in) in e_first.iter().enumerate() { + let j = j_high * num_first + j_low; + let s_0_int = s_compact[2 * j]; + let s_1_int = s_compact[2 * j + 1]; + let coeffs = rp + .compact_coeffs_lut(s_0_int, s_1_int) + .expect("missing compact coefficient LUT"); + accumulate_compact_coeffs( + &mut inner_pos[..num_coeffs_q], + &mut inner_neg[..num_coeffs_q], + e_in, + coeffs, + skip_linear_coeff, + ); + } + let e_out = e_second[j_high]; + for k in 0..num_coeffs_q { + let inner_reduced = reduce_small_coeff_accum(inner_pos[k], inner_neg[k]); + outer_accum[k] += e_out.mul_to_product_accum(inner_reduced); + } + outer_accum + }, + |mut a, b_vec| { + for (ai, bi) in a.iter_mut().zip(b_vec.iter()) { + *ai += *bi; + } + a + } + ) + .into_iter() + .map(E::reduce_product_accum) + .collect::>() + } else { + cfg_fold_reduce!( + 0..e_second.len(), + || vec![E::ProductAccum::ZERO; num_coeffs_q], + |mut outer_accum, j_high| { + debug_assert!(full_num_coeffs_q <= MAX_AFFINE_COEFFS); + let mut inner_accum = [E::ProductAccum::ZERO; MAX_AFFINE_COEFFS]; + for (j_low, &e_in) in 
e_first.iter().enumerate() { + let j = j_high * num_first + j_low; + let s_0_int = s_compact[2 * j]; + let s_1 = E::from_i64(s_compact[2 * j + 1] as i64); + let a = s_1 - E::from_i64(s_0_int as i64); + let mut a_pow = E::one(); + for coeff_idx in 0..full_num_coeffs_q { + let h_i_s0 = rp.h_i_lut(s_0_int, coeff_idx); + if !(skip_linear_coeff && coeff_idx == 1) { + let out_idx = if skip_linear_coeff && coeff_idx > 1 { + coeff_idx - 1 + } else { + coeff_idx + }; + let val = a_pow * h_i_s0; + inner_accum[out_idx] += e_in.mul_to_product_accum(val); + } + a_pow = a_pow * a; + } + } + let e_out = e_second[j_high]; + for k in 0..num_coeffs_q { + let inner_reduced = E::reduce_product_accum(inner_accum[k]); + outer_accum[k] += e_out.mul_to_product_accum(inner_reduced); + } + outer_accum + }, + |mut a, b_vec| { + for (ai, bi) in a.iter_mut().zip(b_vec.iter()) { + *ai += *bi; + } + a + } + ) + .into_iter() + .map(E::reduce_product_accum) + .collect::>() + }; + + finish_gruen_round_poly_from_q_coeffs(split_eq, q_coeffs, previous_claim, skip_linear_coeff) +} + +enum STable { + Compact(Vec), + Full(Vec), +} + +struct Stage1TwoRoundPrefix { + skip_state: Stage1BivariateSkipState, + first_challenge: Option, +} + +/// Stage-1 norm sumcheck prover over the virtual table `S(x) = w(x)(w(x)+1)`. +pub struct HachiStage1Prover { + s_table: STable, + split_eq: GruenSplitEq, + range_precomp: RangeAffineFromSPrecomp, + live_x_cols: usize, + num_u: usize, + num_vars: usize, + b: usize, + prefix_tau: Option>, + two_round_prefix: Option>, + cached_round_poly: Option>, + pending_round_poly: Option>, + prefix_time_total: f64, + dense_time_total: f64, + fold_time_total: f64, + rounds_completed: usize, +} + +impl HachiStage1Prover { + /// Build the stage-1 prover from the compact witness table. 
+ #[tracing::instrument(skip_all, name = "HachiStage1Prover::new")] + pub fn new( + w_evals_compact: &[i8], + tau0: &[E], + b: usize, + live_x_cols: usize, + num_u: usize, + num_l: usize, + ) -> Self { + assert!(b >= 2, "b must be at least 2"); + let num_vars = num_u + num_l; + let y_len = 1usize << num_l; + assert_eq!(w_evals_compact.len(), live_x_cols * y_len); + assert_eq!(tau0.len(), num_vars); + + let s_table = w_evals_compact + .iter() + .map(|&w| { + let w = w as i32; + w * (w + 1) + }) + .collect(); + + Self { + s_table: STable::Compact(s_table), + split_eq: GruenSplitEq::new(tau0), + range_precomp: RangeAffineFromSPrecomp::new(b), + live_x_cols, + num_u, + num_vars, + b, + prefix_tau: can_use_stage1_two_round_prefix(num_u, b).then(|| tau0.to_vec()), + two_round_prefix: None, + cached_round_poly: None, + pending_round_poly: None, + prefix_time_total: 0.0, + dense_time_total: 0.0, + fold_time_total: 0.0, + rounds_completed: 0, + } + } + + /// Return the fully folded virtual-polynomial claim `S(r_stage1)`. + /// + /// # Panics + /// + /// Panics if called before the virtual table has been fully folded to a + /// single field element. 
+ pub fn final_s_claim(&self) -> E { + match &self.s_table { + STable::Full(s_full) => { + assert_eq!(s_full.len(), 1, "s_table not fully folded"); + s_full[0] + } + STable::Compact(_) => panic!("s_table remained compact after final fold"), + } + } + + #[inline] + fn current_x_width(&self) -> usize { + self.num_u.saturating_sub(self.rounds_completed) + } + + #[inline] + fn current_x_len(&self) -> usize { + 1usize << self.current_x_width() + } + + #[inline] + fn use_prefix_x_round(&self) -> bool { + self.rounds_completed < self.num_u && self.live_x_cols < self.current_x_len() + } + + #[inline] + fn next_use_prefix_x_round_after_current(&self) -> bool { + self.rounds_completed + 1 < self.num_u + && self.live_x_cols.div_ceil(2) < (self.current_x_len() / 2) + } + + #[inline] + pub(crate) fn can_use_two_round_prefix(&self) -> bool { + self.prefix_tau.is_some() + } + + #[inline] + fn using_two_round_prefix(&self) -> bool { + self.rounds_completed < 2 && self.can_use_two_round_prefix() + } + + #[inline] + fn compact_s_values(b: usize) -> Vec { + let half = (b / 2) as i32; + (0..half).map(|k| k * (k + 1)).collect() + } + + #[inline] + fn build_compact_s_fold_lut(b: usize, r: E) -> CompactPairFoldLut { + let valid_s = Self::compact_s_values(b); + CompactPairFoldLut::from_allowed_values(&valid_s, r) + } + + fn ensure_two_round_prefix(&mut self) -> &mut Stage1TwoRoundPrefix { + if self.two_round_prefix.is_none() { + let tau0 = self + .prefix_tau + .clone() + .expect("two-round prefix requested without cached tau"); + let num_l = self.num_vars - self.num_u; + let s_compact = match &self.s_table { + STable::Compact(s_compact) => s_compact, + STable::Full(_) => panic!("two-round prefix can only build from compact table"), + }; + let proof = build_stage1_bivariate_skip_proof_from_s_compact( + s_compact, + &tau0, + self.b, + self.live_x_cols, + self.num_u, + num_l, + ) + .expect("two-round prefix should be available"); + let skip_state = Stage1BivariateSkipState::new(&proof, 
&tau0, self.b) + .expect("valid bivariate-skip state"); + self.two_round_prefix = Some(Stage1TwoRoundPrefix { + skip_state, + first_challenge: None, + }); + } + self.two_round_prefix + .as_mut() + .expect("two-round prefix should be initialized") + } + + #[inline] + fn direct_fold_s_quad_to_round2(s00: i32, s10: i32, s01: i32, s11: i32, r0: E, r1: E) -> E { + let s00 = E::from_i64(s00 as i64); + let s10 = E::from_i64(s10 as i64); + let s01 = E::from_i64(s01 as i64); + let s11 = E::from_i64(s11 as i64); + let x0 = s00 + r0 * (s10 - s00); + let x1 = s01 + r0 * (s11 - s01); + x0 + r1 * (x1 - x0) + } + + #[inline] + fn stage1_b8_s_digit_from_compact_s(s: i32) -> usize { + match s { + 0 => 0, + 2 => 1, + 6 => 2, + 12 => 3, + other => unreachable!("unexpected compact s value {other}"), + } + } + + #[inline] + fn stage1_b8_quad_lookup_index_from_row(row: &[i32], base: usize) -> usize { + let d0 = row + .get(base) + .copied() + .map(Self::stage1_b8_s_digit_from_compact_s) + .unwrap_or(0); + let d1 = row + .get(base + 1) + .copied() + .map(Self::stage1_b8_s_digit_from_compact_s) + .unwrap_or(0); + let d2 = row + .get(base + 2) + .copied() + .map(Self::stage1_b8_s_digit_from_compact_s) + .unwrap_or(0); + let d3 = row + .get(base + 3) + .copied() + .map(Self::stage1_b8_s_digit_from_compact_s) + .unwrap_or(0); + d0 | (d1 << 2) | (d2 << 4) | (d3 << 6) + } + + fn build_round2_s_lookup(r0: E, r1: E) -> Vec { + const S_VALUES: [i32; 4] = [0, 2, 6, 12]; + (0..256usize) + .map(|idx| { + let d0 = idx & 0b11; + let d1 = (idx >> 2) & 0b11; + let d2 = (idx >> 4) & 0b11; + let d3 = (idx >> 6) & 0b11; + Self::direct_fold_s_quad_to_round2( + S_VALUES[d0], + S_VALUES[d1], + S_VALUES[d2], + S_VALUES[d3], + r0, + r1, + ) + }) + .collect() + } + + #[tracing::instrument(skip_all, name = "HachiStage1Prover::fold_s_compact_to_round2")] + fn fold_s_compact_to_round2( + s_compact: &[i32], + live_x_cols: usize, + y_len: usize, + r0: E, + r1: E, + ) -> Vec { + let next_live_x_cols = 
live_x_cols.div_ceil(4); + let mut out = vec![E::zero(); y_len * next_live_x_cols]; + for (y, row_out) in out.chunks_mut(next_live_x_cols).enumerate() { + let row = &s_compact[y * live_x_cols..(y + 1) * live_x_cols]; + for (quad_x, dst) in row_out.iter_mut().enumerate() { + let base = 4 * quad_x; + *dst = Self::direct_fold_s_quad_to_round2( + row.get(base).copied().unwrap_or_default(), + row.get(base + 1).copied().unwrap_or_default(), + row.get(base + 2).copied().unwrap_or_default(), + row.get(base + 3).copied().unwrap_or_default(), + r0, + r1, + ); + } + } + out + } + + #[tracing::instrument( + skip_all, + name = "HachiStage1Prover::fuse_compact_to_round2_and_compute_round" + )] + fn fuse_compact_to_round2_and_compute_round( + &self, + s_compact: &[i32], + previous_claim: E, + r0: E, + r1: E, + ) -> (Vec, UniPoly) { + debug_assert!(self.num_u > 2); + let old_live_x_cols = self.live_x_cols; + let next_live_x_cols = old_live_x_cols.div_ceil(4); + let y_len = s_compact.len() / old_live_x_cols; + let live_pairs = next_live_x_cols.div_ceil(2); + let current_x_half = 1usize << (self.num_u - 3); + let (e_first, e_second) = self.split_eq.remaining_eq_tables(); + let num_first = e_first.len(); + let first_bits = num_first.trailing_zeros(); + let block_size = num_first.min(live_pairs); + let quad_fold_lut = Self::build_round2_s_lookup(r0, r1); + + let range_pc = &self.range_precomp; + let full_num_coeffs_q = range_pc.degree_q + 1; + let skip_linear_coeff = self.split_eq.can_recover_linear_q_term_from_claim(); + let num_coeffs_q = full_num_coeffs_q - usize::from(skip_linear_coeff); + let mut out = vec![E::zero(); y_len * next_live_x_cols]; + + #[cfg(feature = "parallel")] + let q_coeffs = out + .par_chunks_mut(next_live_x_cols) + .enumerate() + .map(|(y, row_out)| { + let row = &s_compact[y * old_live_x_cols..(y + 1) * old_live_x_cols]; + let j_base = y * current_x_half; + let mut outer_accum = vec![E::ProductAccum::ZERO; num_coeffs_q]; + let mut entry_buf = [E::zero(); 
MAX_AFFINE_COEFFS]; + let mut s_pows_buf = [E::zero(); MAX_AFFINE_COEFFS]; + + let mut blk = 0usize; + while blk < live_pairs { + let blk_end = (blk + block_size).min(live_pairs); + let j_high = (j_base + blk) >> first_bits; + let mut inner_accum = [E::ProductAccum::ZERO; MAX_AFFINE_COEFFS]; + + for pair_x in blk..blk_end { + let j_low = (j_base + pair_x) & (num_first - 1); + let e_in = e_first[j_low]; + let left_quad = 2 * pair_x; + let left_base = 8 * pair_x; + let s0 = quad_fold_lut + [Self::stage1_b8_quad_lookup_index_from_row(row, left_base)]; + row_out[left_quad] = s0; + let s1 = if left_quad + 1 < next_live_x_cols { + let s1 = quad_fold_lut + [Self::stage1_b8_quad_lookup_index_from_row(row, left_base + 4)]; + row_out[left_quad + 1] = s1; + s1 + } else { + E::zero() + }; + compute_entry_coeffs_from_s( + &mut entry_buf, + &mut s_pows_buf, + range_pc, + s0, + s1 - s0, + ); + accumulate_dense_entry_coeffs( + &mut inner_accum[..num_coeffs_q], + &entry_buf[..full_num_coeffs_q], + e_in, + skip_linear_coeff, + ); + } + + let e_out = e_second[j_high]; + for k in 0..num_coeffs_q { + let inner_reduced = E::reduce_product_accum(inner_accum[k]); + outer_accum[k] += e_out.mul_to_product_accum(inner_reduced); + } + blk = blk_end; + } + outer_accum + }) + .reduce( + || vec![E::ProductAccum::ZERO; num_coeffs_q], + |mut a, b| { + for (ai, bi) in a.iter_mut().zip(b.iter()) { + *ai += *bi; + } + a + }, + ) + .into_iter() + .map(E::reduce_product_accum) + .collect::>(); + + #[cfg(not(feature = "parallel"))] + let q_coeffs = { + let mut outer = vec![E::ProductAccum::ZERO; num_coeffs_q]; + for (y, row_out) in out.chunks_mut(next_live_x_cols).enumerate() { + let row = &s_compact[y * old_live_x_cols..(y + 1) * old_live_x_cols]; + let j_base = y * current_x_half; + let mut entry_buf = [E::zero(); MAX_AFFINE_COEFFS]; + let mut s_pows_buf = [E::zero(); MAX_AFFINE_COEFFS]; + + let mut blk = 0usize; + while blk < live_pairs { + let blk_end = (blk + block_size).min(live_pairs); + let 
j_high = (j_base + blk) >> first_bits; + let mut inner_accum = [E::ProductAccum::ZERO; MAX_AFFINE_COEFFS]; + + for pair_x in blk..blk_end { + let j_low = (j_base + pair_x) & (num_first - 1); + let e_in = e_first[j_low]; + let left_quad = 2 * pair_x; + let left_base = 8 * pair_x; + let s0 = quad_fold_lut + [Self::stage1_b8_quad_lookup_index_from_row(row, left_base)]; + row_out[left_quad] = s0; + let s1 = if left_quad + 1 < next_live_x_cols { + let s1 = quad_fold_lut + [Self::stage1_b8_quad_lookup_index_from_row(row, left_base + 4)]; + row_out[left_quad + 1] = s1; + s1 + } else { + E::zero() + }; + compute_entry_coeffs_from_s( + &mut entry_buf, + &mut s_pows_buf, + range_pc, + s0, + s1 - s0, + ); + accumulate_dense_entry_coeffs( + &mut inner_accum[..num_coeffs_q], + &entry_buf[..full_num_coeffs_q], + e_in, + skip_linear_coeff, + ); + } + + let e_out = e_second[j_high]; + for k in 0..num_coeffs_q { + let inner_reduced = E::reduce_product_accum(inner_accum[k]); + outer[k] += e_out.mul_to_product_accum(inner_reduced); + } + blk = blk_end; + } + } + outer + .into_iter() + .map(E::reduce_product_accum) + .collect::>() + }; + + let poly = finish_gruen_round_poly_from_q_coeffs( + &self.split_eq, + q_coeffs, + previous_claim, + skip_linear_coeff, + ); + (out, poly) + } + + #[inline] + fn fold_full_prefix_pair(row: &[E], left: usize, r: E) -> E { + let s_0 = row.get(left).copied().unwrap_or_else(E::zero); + let s_1 = row.get(left + 1).copied().unwrap_or_else(E::zero); + s_0 + r * (s_1 - s_0) + } + + #[tracing::instrument( + skip_all, + name = "HachiStage1Prover::fuse_full_prefix_x_and_compute_round" + )] + fn fuse_full_prefix_x_and_compute_round( + &self, + s_full: &[E], + previous_claim: E, + r: E, + ) -> (Vec, UniPoly) { + debug_assert!(self.next_use_prefix_x_round_after_current()); + debug_assert!(self.current_x_width() >= 2); + + let old_live_x_cols = self.live_x_cols; + let next_live_x_cols = old_live_x_cols.div_ceil(2); + let y_len = s_full.len() / old_live_x_cols; + 
let (e_first, e_second) = self.split_eq.remaining_eq_tables(); + let num_first = e_first.len(); + let first_bits = num_first.trailing_zeros(); + let next_current_x_half = 1usize << (self.current_x_width() - 2); + let live_pairs = next_live_x_cols.div_ceil(2); + let block_size = num_first.min(live_pairs); + + let range_pc = &self.range_precomp; + let full_num_coeffs_q = range_pc.degree_q + 1; + let skip_linear_coeff = self.split_eq.can_recover_linear_q_term_from_claim(); + let num_coeffs_q = full_num_coeffs_q - usize::from(skip_linear_coeff); + let mut out = vec![E::zero(); y_len * next_live_x_cols]; + + #[cfg(feature = "parallel")] + let q_coeffs = out + .par_chunks_mut(next_live_x_cols) + .enumerate() + .map(|(y, row_out)| { + debug_assert!(full_num_coeffs_q <= MAX_AFFINE_COEFFS); + let row = &s_full[y * old_live_x_cols..(y + 1) * old_live_x_cols]; + let j_base = y * next_current_x_half; + let mut outer_accum = vec![E::ProductAccum::ZERO; num_coeffs_q]; + let mut batch_out = [[E::zero(); MAX_AFFINE_COEFFS]; 4]; + let mut entry_buf = [E::zero(); MAX_AFFINE_COEFFS]; + let mut s_pows_buf = [E::zero(); MAX_AFFINE_COEFFS]; + + let mut blk = 0usize; + while blk < live_pairs { + let blk_end = (blk + block_size).min(live_pairs); + let j_high = (j_base + blk) >> first_bits; + let mut inner_accum = [E::ProductAccum::ZERO; MAX_AFFINE_COEFFS]; + let blk_len = blk_end - blk; + let full_chunks = blk_len / 4; + + for chunk in 0..full_chunks { + let pair_base = blk + chunk * 4; + let mut pairs = [(E::zero(), E::zero()); 4]; + for (slot, pair_x) in (pair_base..pair_base + 4).enumerate() { + let left_next = 2 * pair_x; + let left_old = 4 * pair_x; + let s0 = Self::fold_full_prefix_pair(row, left_old, r); + row_out[left_next] = s0; + let s1 = if left_next + 1 < next_live_x_cols { + let s1 = Self::fold_full_prefix_pair(row, left_old + 2, r); + row_out[left_next + 1] = s1; + s1 + } else { + E::zero() + }; + pairs[slot] = (s0, s1); + } + + compute_entry_coeffs_from_s_x4( + &mut 
batch_out, + range_pc, + [pairs[0].0, pairs[1].0, pairs[2].0, pairs[3].0], + [ + pairs[0].1 - pairs[0].0, + pairs[1].1 - pairs[1].0, + pairs[2].1 - pairs[2].0, + pairs[3].1 - pairs[3].0, + ], + ); + + for (slot, _) in pairs.iter().enumerate() { + let pair_x = pair_base + slot; + let j_low = (j_base + pair_x) & (num_first - 1); + let e_in = e_first[j_low]; + accumulate_dense_entry_coeffs( + &mut inner_accum[..num_coeffs_q], + &batch_out[slot][..full_num_coeffs_q], + e_in, + skip_linear_coeff, + ); + } + } + + for pair_x in blk + full_chunks * 4..blk_end { + let left_next = 2 * pair_x; + let left_old = 4 * pair_x; + let s_0 = Self::fold_full_prefix_pair(row, left_old, r); + row_out[left_next] = s_0; + let s_1 = if left_next + 1 < next_live_x_cols { + let s_1 = Self::fold_full_prefix_pair(row, left_old + 2, r); + row_out[left_next + 1] = s_1; + s_1 + } else { + E::zero() + }; + compute_entry_coeffs_from_s( + &mut entry_buf, + &mut s_pows_buf, + range_pc, + s_0, + s_1 - s_0, + ); + let j_low = (j_base + pair_x) & (num_first - 1); + let e_in = e_first[j_low]; + accumulate_dense_entry_coeffs( + &mut inner_accum[..num_coeffs_q], + &entry_buf[..full_num_coeffs_q], + e_in, + skip_linear_coeff, + ); + } + + let e_out = e_second[j_high]; + for k in 0..num_coeffs_q { + let inner_reduced = E::reduce_product_accum(inner_accum[k]); + outer_accum[k] += e_out.mul_to_product_accum(inner_reduced); + } + blk = blk_end; + } + + outer_accum + }) + .reduce( + || vec![E::ProductAccum::ZERO; num_coeffs_q], + |mut a, b| { + for (ai, bi) in a.iter_mut().zip(b.iter()) { + *ai += *bi; + } + a + }, + ) + .into_iter() + .map(E::reduce_product_accum) + .collect::>(); + + #[cfg(not(feature = "parallel"))] + let q_coeffs = { + let mut outer_accum = vec![E::ProductAccum::ZERO; num_coeffs_q]; + for (y, row_out) in out.chunks_mut(next_live_x_cols).enumerate() { + debug_assert!(full_num_coeffs_q <= MAX_AFFINE_COEFFS); + let row = &s_full[y * old_live_x_cols..(y + 1) * old_live_x_cols]; + let j_base = y 
* next_current_x_half; + let mut batch_out = [[E::zero(); MAX_AFFINE_COEFFS]; 4]; + let mut entry_buf = [E::zero(); MAX_AFFINE_COEFFS]; + let mut s_pows_buf = [E::zero(); MAX_AFFINE_COEFFS]; + + let mut blk = 0usize; + while blk < live_pairs { + let blk_end = (blk + block_size).min(live_pairs); + let j_high = (j_base + blk) >> first_bits; + let mut inner_accum = [E::ProductAccum::ZERO; MAX_AFFINE_COEFFS]; + let blk_len = blk_end - blk; + let full_chunks = blk_len / 4; + + for chunk in 0..full_chunks { + let pair_base = blk + chunk * 4; + let mut pairs = [(E::zero(), E::zero()); 4]; + for (slot, pair_x) in (pair_base..pair_base + 4).enumerate() { + let left_next = 2 * pair_x; + let left_old = 4 * pair_x; + let s0 = Self::fold_full_prefix_pair(row, left_old, r); + row_out[left_next] = s0; + let s1 = if left_next + 1 < next_live_x_cols { + let s1 = Self::fold_full_prefix_pair(row, left_old + 2, r); + row_out[left_next + 1] = s1; + s1 + } else { + E::zero() + }; + pairs[slot] = (s0, s1); + } + + compute_entry_coeffs_from_s_x4( + &mut batch_out, + range_pc, + [pairs[0].0, pairs[1].0, pairs[2].0, pairs[3].0], + [ + pairs[0].1 - pairs[0].0, + pairs[1].1 - pairs[1].0, + pairs[2].1 - pairs[2].0, + pairs[3].1 - pairs[3].0, + ], + ); + + for (slot, _) in pairs.iter().enumerate() { + let pair_x = pair_base + slot; + let j_low = (j_base + pair_x) & (num_first - 1); + let e_in = e_first[j_low]; + accumulate_dense_entry_coeffs( + &mut inner_accum[..num_coeffs_q], + &batch_out[slot][..full_num_coeffs_q], + e_in, + skip_linear_coeff, + ); + } + } + + for pair_x in blk + full_chunks * 4..blk_end { + let left_next = 2 * pair_x; + let left_old = 4 * pair_x; + let s_0 = Self::fold_full_prefix_pair(row, left_old, r); + row_out[left_next] = s_0; + let s_1 = if left_next + 1 < next_live_x_cols { + let s_1 = Self::fold_full_prefix_pair(row, left_old + 2, r); + row_out[left_next + 1] = s_1; + s_1 + } else { + E::zero() + }; + compute_entry_coeffs_from_s( + &mut entry_buf, + &mut s_pows_buf, 
+ range_pc, + s_0, + s_1 - s_0, + ); + let j_low = (j_base + pair_x) & (num_first - 1); + let e_in = e_first[j_low]; + accumulate_dense_entry_coeffs( + &mut inner_accum[..num_coeffs_q], + &entry_buf[..full_num_coeffs_q], + e_in, + skip_linear_coeff, + ); + } + + let e_out = e_second[j_high]; + for k in 0..num_coeffs_q { + let inner_reduced = E::reduce_product_accum(inner_accum[k]); + outer_accum[k] += e_out.mul_to_product_accum(inner_reduced); + } + blk = blk_end; + } + } + + outer_accum + .into_iter() + .map(E::reduce_product_accum) + .collect::>() + }; + + let poly = finish_gruen_round_poly_from_q_coeffs( + &self.split_eq, + q_coeffs, + previous_claim, + skip_linear_coeff, + ); + (out, poly) + } + + fn compute_current_round_poly_from_state(&mut self, previous_claim: E) -> UniPoly { + let use_two_round_prefix = self.using_two_round_prefix(); + let use_prefix_x_round = !use_two_round_prefix && self.use_prefix_x_round(); + let t_round = Instant::now(); + let rounds_completed = self.rounds_completed; + let poly = if use_two_round_prefix { + let prefix = self.ensure_two_round_prefix(); + if rounds_completed == 0 { + prefix.skip_state.reconstruct_round0_poly() + } else { + let r0 = prefix + .first_challenge + .expect("round 1 prefix polynomial requested before ingesting round 0"); + prefix.skip_state.reconstruct_round1_poly(r0) + } + } else if self.split_eq.current_scalar().is_zero() { + UniPoly::from_coeffs(vec![E::zero()]) + } else { + match &self.s_table { + STable::Compact(s_compact) => { + if use_prefix_x_round { + self.compute_round_compact_prefix_x(s_compact, previous_claim) + } else { + compute_norm_round_poly_from_s_compact( + &self.split_eq, + s_compact, + &self.range_precomp, + previous_claim, + ) + } + } + STable::Full(s_full) => { + if use_prefix_x_round { + self.compute_round_full_prefix_x(s_full, previous_claim) + } else { + compute_norm_round_poly_from_s( + &self.split_eq, + &self.range_precomp, + previous_claim, + |j| (s_full[2 * j], s_full[2 * j + 
1]), + ) + } + } + } + }; + + if use_two_round_prefix || use_prefix_x_round { + self.prefix_time_total += t_round.elapsed().as_secs_f64(); + } else { + self.dense_time_total += t_round.elapsed().as_secs_f64(); + } + + poly + } + + #[tracing::instrument(skip_all, name = "HachiStage1Prover::compute_round_compact_prefix_x")] + fn compute_round_compact_prefix_x(&self, s_compact: &[i32], previous_claim: E) -> UniPoly { + debug_assert!(self.rounds_completed < self.num_u); + debug_assert_eq!( + s_compact.len(), + self.live_x_cols * (1usize << (self.num_vars - self.num_u)) + ); + + let (e_first, e_second) = self.split_eq.remaining_eq_tables(); + let num_first = e_first.len(); + let first_bits = num_first.trailing_zeros(); + let current_x_half = 1usize << (self.current_x_width() - 1); + let live_pairs = self.live_x_cols.div_ceil(2); + let block_size = num_first.min(live_pairs); + + let rp = &self.range_precomp; + let full_num_coeffs_q = rp.degree_q + 1; + let skip_linear_coeff = self.split_eq.can_recover_linear_q_term_from_claim(); + let num_coeffs_q = full_num_coeffs_q - usize::from(skip_linear_coeff); + let q_coeffs = if rp.compact_coeffs_lut(0, 0).is_some() { + cfg_fold_reduce!( + 0..(1usize << (self.num_vars - self.num_u)), + || vec![E::ProductAccum::ZERO; num_coeffs_q], + |mut outer_accum, y| { + let row_start = y * self.live_x_cols; + let row = &s_compact[row_start..row_start + self.live_x_cols]; + let j_base = y * current_x_half; + + let mut blk = 0usize; + while blk < live_pairs { + let blk_end = (blk + block_size).min(live_pairs); + let j_high = (j_base + blk) >> first_bits; + let mut inner_pos = [E::MulU64Accum::ZERO; MAX_AFFINE_COEFFS]; + let mut inner_neg = [E::MulU64Accum::ZERO; MAX_AFFINE_COEFFS]; + + for pair_x in blk..blk_end { + let j_low = (j_base + pair_x) & (num_first - 1); + let e_in = e_first[j_low]; + let left = 2 * pair_x; + let s0_i = row[left]; + let s1_i = if left + 1 < self.live_x_cols { + row[left + 1] + } else { + 0 + }; + let coeffs = rp + 
.compact_coeffs_lut(s0_i, s1_i) + .expect("missing compact coefficient LUT"); + accumulate_compact_coeffs( + &mut inner_pos[..num_coeffs_q], + &mut inner_neg[..num_coeffs_q], + e_in, + coeffs, + skip_linear_coeff, + ); + } + + let e_out = e_second[j_high]; + for k in 0..num_coeffs_q { + let inner_reduced = + reduce_small_coeff_accum(inner_pos[k], inner_neg[k]); + outer_accum[k] += e_out.mul_to_product_accum(inner_reduced); + } + blk = blk_end; + } + outer_accum + }, + |mut a, b_vec| { + for (ai, bi) in a.iter_mut().zip(b_vec.iter()) { + *ai += *bi; + } + a + } + ) + .into_iter() + .map(E::reduce_product_accum) + .collect() + } else { + cfg_fold_reduce!( + 0..(1usize << (self.num_vars - self.num_u)), + || vec![E::ProductAccum::ZERO; num_coeffs_q], + |mut outer_accum, y| { + let row_start = y * self.live_x_cols; + let row = &s_compact[row_start..row_start + self.live_x_cols]; + let j_base = y * current_x_half; + let mut entry_buf = [E::zero(); MAX_AFFINE_COEFFS]; + let mut s_pows_buf = [E::zero(); MAX_AFFINE_COEFFS]; + + let mut blk = 0usize; + while blk < live_pairs { + let blk_end = (blk + block_size).min(live_pairs); + let j_high = (j_base + blk) >> first_bits; + let mut inner_accum = [E::ProductAccum::ZERO; MAX_AFFINE_COEFFS]; + + for pair_x in blk..blk_end { + let j_low = (j_base + pair_x) & (num_first - 1); + let e_in = e_first[j_low]; + let left = 2 * pair_x; + let s0_i = row[left]; + let s1_i = if left + 1 < self.live_x_cols { + row[left + 1] + } else { + 0 + }; + compute_entry_coeffs_from_s( + &mut entry_buf, + &mut s_pows_buf, + rp, + E::from_i64(s0_i as i64), + E::from_i64((s1_i as i64) - (s0_i as i64)), + ); + accumulate_dense_entry_coeffs( + &mut inner_accum[..num_coeffs_q], + &entry_buf[..full_num_coeffs_q], + e_in, + skip_linear_coeff, + ); + } + + let e_out = e_second[j_high]; + for k in 0..num_coeffs_q { + let inner_reduced = E::reduce_product_accum(inner_accum[k]); + outer_accum[k] += e_out.mul_to_product_accum(inner_reduced); + } + blk = blk_end; + 
} + outer_accum + }, + |mut ca, cb| { + for (ai, bi) in ca.iter_mut().zip(cb.iter()) { + *ai += *bi; + } + ca + } + ) + .into_iter() + .map(E::reduce_product_accum) + .collect() + }; + + finish_gruen_round_poly_from_q_coeffs( + &self.split_eq, + q_coeffs, + previous_claim, + skip_linear_coeff, + ) + } + + #[tracing::instrument(skip_all, name = "HachiStage1Prover::compute_round_full_prefix_x")] + fn compute_round_full_prefix_x(&self, s_full: &[E], previous_claim: E) -> UniPoly { + debug_assert!(self.rounds_completed < self.num_u); + let y_len = s_full.len() / self.live_x_cols; + let (e_first, e_second) = self.split_eq.remaining_eq_tables(); + let num_first = e_first.len(); + let first_bits = num_first.trailing_zeros(); + let current_x_half = 1usize << (self.current_x_width() - 1); + let live_pairs = self.live_x_cols.div_ceil(2); + let block_size = num_first.min(live_pairs); + + let range_pc = &self.range_precomp; + let full_num_coeffs_q = range_pc.degree_q + 1; + let skip_linear_coeff = self.split_eq.can_recover_linear_q_term_from_claim(); + let num_coeffs_q = full_num_coeffs_q - usize::from(skip_linear_coeff); + let q_coeffs = cfg_fold_reduce!( + 0..y_len, + || vec![E::ProductAccum::ZERO; num_coeffs_q], + |mut outer_accum, y| { + debug_assert!(full_num_coeffs_q <= MAX_AFFINE_COEFFS); + let row_start = y * self.live_x_cols; + let row = &s_full[row_start..row_start + self.live_x_cols]; + let j_base = y * current_x_half; + let mut batch_out = [[E::zero(); MAX_AFFINE_COEFFS]; 4]; + let mut entry_buf = [E::zero(); MAX_AFFINE_COEFFS]; + let mut s_pows_buf = [E::zero(); MAX_AFFINE_COEFFS]; + + let mut blk = 0usize; + while blk < live_pairs { + let blk_end = (blk + block_size).min(live_pairs); + let j_high = (j_base + blk) >> first_bits; + let mut inner_accum = [E::ProductAccum::ZERO; MAX_AFFINE_COEFFS]; + let blk_len = blk_end - blk; + let full_chunks = blk_len / 4; + + for chunk in 0..full_chunks { + let pair_base = blk + chunk * 4; + let mut pairs = [(E::zero(), 
E::zero()); 4]; + for (slot, pair_x) in (pair_base..pair_base + 4).enumerate() { + let left = 2 * pair_x; + let s_0 = row[left]; + let s_1 = if left + 1 < self.live_x_cols { + row[left + 1] + } else { + E::zero() + }; + pairs[slot] = (s_0, s_1); + } + + compute_entry_coeffs_from_s_x4( + &mut batch_out, + range_pc, + [pairs[0].0, pairs[1].0, pairs[2].0, pairs[3].0], + [ + pairs[0].1 - pairs[0].0, + pairs[1].1 - pairs[1].0, + pairs[2].1 - pairs[2].0, + pairs[3].1 - pairs[3].0, + ], + ); + + for (slot, _) in pairs.iter().enumerate() { + let pair_x = pair_base + slot; + let j_low = (j_base + pair_x) & (num_first - 1); + let e_in = e_first[j_low]; + accumulate_dense_entry_coeffs( + &mut inner_accum[..num_coeffs_q], + &batch_out[slot][..full_num_coeffs_q], + e_in, + skip_linear_coeff, + ); + } + } + + for pair_x in blk + full_chunks * 4..blk_end { + let left = 2 * pair_x; + let s_0 = row[left]; + let s_1 = if left + 1 < self.live_x_cols { + row[left + 1] + } else { + E::zero() + }; + compute_entry_coeffs_from_s( + &mut entry_buf, + &mut s_pows_buf, + range_pc, + s_0, + s_1 - s_0, + ); + let j_low = (j_base + pair_x) & (num_first - 1); + let e_in = e_first[j_low]; + accumulate_dense_entry_coeffs( + &mut inner_accum[..num_coeffs_q], + &entry_buf[..full_num_coeffs_q], + e_in, + skip_linear_coeff, + ); + } + + let e_out = e_second[j_high]; + for k in 0..num_coeffs_q { + let inner_reduced = E::reduce_product_accum(inner_accum[k]); + outer_accum[k] += e_out.mul_to_product_accum(inner_reduced); + } + blk = blk_end; + } + + outer_accum + }, + |mut ca, cb| { + for (ai, bi) in ca.iter_mut().zip(cb.iter()) { + *ai += *bi; + } + ca + } + ); + + let q_coeffs: Vec = q_coeffs.into_iter().map(E::reduce_product_accum).collect(); + finish_gruen_round_poly_from_q_coeffs( + &self.split_eq, + q_coeffs, + previous_claim, + skip_linear_coeff, + ) + } + + #[tracing::instrument(skip_all, name = "HachiStage1Prover::fold_s_compact_prefix_x")] + fn fold_s_compact_prefix_x( + s_compact: &[i32], + 
live_x_cols: usize, + y_len: usize, + fold_lut: &CompactPairFoldLut, + ) -> Vec { + let next_live_x_cols = live_x_cols.div_ceil(2); + let mut out = vec![E::zero(); y_len * next_live_x_cols]; + + #[cfg(feature = "parallel")] + out.par_chunks_mut(next_live_x_cols) + .enumerate() + .for_each(|(y, row_out)| { + let row_start = y * live_x_cols; + let row = &s_compact[row_start..row_start + live_x_cols]; + for (pair_x, dst) in row_out.iter_mut().enumerate() { + let left = 2 * pair_x; + let s_1 = if left + 1 < live_x_cols { + row[left + 1] + } else { + 0 + }; + *dst = fold_lut.fold(row[left], s_1); + } + }); + + #[cfg(not(feature = "parallel"))] + for (y, row_out) in out.chunks_mut(next_live_x_cols).enumerate() { + let row_start = y * live_x_cols; + let row = &s_compact[row_start..row_start + live_x_cols]; + for (pair_x, dst) in row_out.iter_mut().enumerate() { + let left = 2 * pair_x; + let s_1 = if left + 1 < live_x_cols { + row[left + 1] + } else { + 0 + }; + *dst = fold_lut.fold(row[left], s_1); + } + } + + out + } + + #[tracing::instrument(skip_all, name = "HachiStage1Prover::fold_s_full_prefix_x")] + fn fold_s_full_prefix_x(s_full: &[E], live_x_cols: usize, y_len: usize, r: E) -> Vec { + let next_live_x_cols = live_x_cols.div_ceil(2); + let mut out = vec![E::zero(); y_len * next_live_x_cols]; + + #[cfg(feature = "parallel")] + out.par_chunks_mut(next_live_x_cols) + .enumerate() + .for_each(|(y, row_out)| { + let row_start = y * live_x_cols; + let row = &s_full[row_start..row_start + live_x_cols]; + for (pair_x, dst) in row_out.iter_mut().enumerate() { + let left = 2 * pair_x; + let s_0 = row[left]; + let s_1 = if left + 1 < live_x_cols { + row[left + 1] + } else { + E::zero() + }; + *dst = s_0 + r * (s_1 - s_0); + } + }); + + #[cfg(not(feature = "parallel"))] + for (y, row_out) in out.chunks_mut(next_live_x_cols).enumerate() { + let row_start = y * live_x_cols; + let row = &s_full[row_start..row_start + live_x_cols]; + for (pair_x, dst) in 
row_out.iter_mut().enumerate() { + let left = 2 * pair_x; + let s_0 = row[left]; + let s_1 = if left + 1 < live_x_cols { + row[left + 1] + } else { + E::zero() + }; + *dst = s_0 + r * (s_1 - s_0); + } + } + + out + } + + #[tracing::instrument(skip_all, name = "HachiStage1Prover::fold_s_compact_to_full")] + fn fold_s_compact_to_full(s_compact: &[i32], fold_lut: &CompactPairFoldLut) -> Vec { + cfg_into_iter!(0..s_compact.len() / 2) + .map(|j| fold_lut.fold(s_compact[2 * j], s_compact[2 * j + 1])) + .collect() + } +} + +impl SumcheckInstanceProver + for HachiStage1Prover +{ + fn num_rounds(&self) -> usize { + self.num_vars + } + + fn degree_bound(&self) -> usize { + self.b / 2 + 1 + } + + fn input_claim(&self) -> E { + E::zero() + } + + fn compute_round_univariate(&mut self, _round: usize, previous_claim: E) -> UniPoly { + let poly = if let Some(poly) = self.cached_round_poly.take() { + poly + } else { + self.compute_current_round_poly_from_state(previous_claim) + }; + self.pending_round_poly = Some(poly.clone()); + poly + } + + fn ingest_challenge(&mut self, _round: usize, r: E) { + let t_fold = Instant::now(); + let _span = tracing::info_span!("HachiStage1Prover::fold_round").entered(); + let next_claim = self + .pending_round_poly + .take() + .map(|poly| poly.evaluate(&r)) + .expect("ingest_challenge called before computing the current round polynomial"); + if self.using_two_round_prefix() { + let rounds_completed = self.rounds_completed; + self.split_eq.bind(r); + if rounds_completed == 0 { + self.ensure_two_round_prefix().first_challenge = Some(r); + } else { + let r0 = { + let prefix = self.ensure_two_round_prefix(); + prefix + .first_challenge + .expect("round 1 ingest requires the round 0 challenge") + }; + let y_len = match &self.s_table { + STable::Compact(s_compact) => s_compact.len() / self.live_x_cols, + STable::Full(_) => panic!("two-round prefix expected compact table"), + }; + self.s_table = match std::mem::replace(&mut self.s_table, 
STable::Full(Vec::new())) + { + STable::Compact(s_compact) => { + if self.num_u > 2 { + let (s_full, round_poly) = self + .fuse_compact_to_round2_and_compute_round( + &s_compact, next_claim, r0, r, + ); + self.cached_round_poly = Some(round_poly); + STable::Full(s_full) + } else { + let s_full = Self::fold_s_compact_to_round2( + &s_compact, + self.live_x_cols, + y_len, + r0, + r, + ); + STable::Full(s_full) + } + } + STable::Full(_) => unreachable!("two-round prefix should hold compact table"), + }; + self.live_x_cols = self.live_x_cols.div_ceil(4); + } + self.rounds_completed += 1; + if self.rounds_completed < self.num_vars { + if self.cached_round_poly.is_none() { + self.cached_round_poly = + Some(self.compute_current_round_poly_from_state(next_claim)); + } + } else { + self.cached_round_poly = None; + } + drop(_span); + self.fold_time_total += t_fold.elapsed().as_secs_f64(); + return; + } + + self.split_eq.bind(r); + let use_prefix_x_round = self.use_prefix_x_round(); + let fuse_next_full_prefix_x = + use_prefix_x_round && self.next_use_prefix_x_round_after_current(); + let y_len = match &self.s_table { + STable::Compact(s_compact) => s_compact.len() / self.live_x_cols, + STable::Full(s_full) => s_full.len() / self.live_x_cols, + }; + + self.s_table = match std::mem::replace(&mut self.s_table, STable::Full(Vec::new())) { + STable::Compact(s_compact) => { + let fold_lut = Self::build_compact_s_fold_lut(self.b, r); + let s_full = if use_prefix_x_round { + Self::fold_s_compact_prefix_x(&s_compact, self.live_x_cols, y_len, &fold_lut) + } else { + Self::fold_s_compact_to_full(&s_compact, &fold_lut) + }; + STable::Full(s_full) + } + STable::Full(s_full) => { + if use_prefix_x_round { + if fuse_next_full_prefix_x { + let (next_s_full, round_poly) = + self.fuse_full_prefix_x_and_compute_round(&s_full, next_claim, r); + self.cached_round_poly = Some(round_poly); + STable::Full(next_s_full) + } else { + let next_s_full = + Self::fold_s_full_prefix_x(&s_full, 
self.live_x_cols, y_len, r); + STable::Full(next_s_full) + } + } else { + let mut s_full = s_full; + fold_evals_in_place(&mut s_full, r); + STable::Full(s_full) + } + } + }; + + if self.rounds_completed < self.num_u { + self.live_x_cols = self.live_x_cols.div_ceil(2); + } + self.rounds_completed += 1; + if self.rounds_completed < self.num_vars { + if self.cached_round_poly.is_none() { + self.cached_round_poly = + Some(self.compute_current_round_poly_from_state(next_claim)); + } + } else { + self.cached_round_poly = None; + } + drop(_span); + self.fold_time_total += t_fold.elapsed().as_secs_f64(); + } + + fn finalize(&mut self) { + tracing::debug!( + rounds = self.num_vars, + prefix_s = self.prefix_time_total, + dense_s = self.dense_time_total, + fold_s = self.fold_time_total, + "stage1 sumcheck rounds complete" + ); + } +} + +/// Verifier for the stage-1 norm sumcheck over the virtual table `S`. +pub struct HachiStage1Verifier { + tau0: Vec, + s_claim: F, + b: usize, +} + +impl HachiStage1Verifier { + /// Construct the stage-1 verifier from `tau0`, the carried `s_claim`, and `b`. 
+ pub fn new(tau0: Vec, s_claim: F, b: usize) -> Self { + Self { tau0, s_claim, b } + } +} + +impl SumcheckInstanceVerifier for HachiStage1Verifier { + fn num_rounds(&self) -> usize { + self.tau0.len() + } + + fn degree_bound(&self) -> usize { + self.b / 2 + 1 + } + + fn input_claim(&self) -> F { + F::zero() + } + + fn expected_output_claim(&self, challenges: &[F]) -> Result { + let eq_val = EqPolynomial::mle(&self.tau0, challenges); + Ok(eq_val * range_check_eval_from_s(self.s_claim, self.b)) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::algebra::Prime128M8M4M1M0; + use crate::protocol::sumcheck::multilinear_eval; + + type F = Prime128M8M4M1M0; + + fn pad_compact_rows( + w_prefix: &[i8], + live_x_cols: usize, + num_u: usize, + num_l: usize, + ) -> Vec { + let x_len = 1usize << num_u; + let y_len = 1usize << num_l; + let mut padded = vec![0i8; x_len * y_len]; + for y in 0..y_len { + let src_start = y * live_x_cols; + let dst_start = y * x_len; + padded[dst_start..dst_start + live_x_cols] + .copy_from_slice(&w_prefix[src_start..src_start + live_x_cols]); + } + padded + } + + fn fold_s_compact_prefix_x_reference( + s_compact: &[i32], + live_x_cols: usize, + y_len: usize, + r: F, + ) -> Vec { + let next_live_x_cols = live_x_cols.div_ceil(2); + let mut out = vec![F::zero(); y_len * next_live_x_cols]; + for (y, row_out) in out.chunks_mut(next_live_x_cols).enumerate() { + let row_start = y * live_x_cols; + let row = &s_compact[row_start..row_start + live_x_cols]; + for (pair_x, dst) in row_out.iter_mut().enumerate() { + let left = 2 * pair_x; + let s_0 = F::from_i64(row[left] as i64); + let s_1 = if left + 1 < live_x_cols { + F::from_i64(row[left + 1] as i64) + } else { + F::zero() + }; + *dst = s_0 + r * (s_1 - s_0); + } + } + out + } + + fn fold_s_compact_to_full_reference(s_compact: &[i32], r: F) -> Vec { + (0..s_compact.len() / 2) + .map(|j| { + let s_0 = F::from_i64(s_compact[2 * j] as i64); + let s_1 = F::from_i64(s_compact[2 * j + 1] as i64); + 
s_0 + r * (s_1 - s_0) + }) + .collect() + } + + #[test] + fn stage1_compact_fold_lookup_matches_direct_formula() { + let b = 8usize; + let r = F::from_u64(41); + + let s_prefix = vec![2, 6, 12, 2, 6, 12, 2, 6, 12, 2]; + let fold_lut = HachiStage1Prover::::build_compact_s_fold_lut(b, r); + assert_eq!( + HachiStage1Prover::::fold_s_compact_prefix_x(&s_prefix, 5, 2, &fold_lut), + fold_s_compact_prefix_x_reference(&s_prefix, 5, 2, r) + ); + + let s_dense = vec![2, 6, 12, 2, 6, 12]; + let dense_lut = HachiStage1Prover::::build_compact_s_fold_lut(b, r); + assert_eq!( + HachiStage1Prover::::fold_s_compact_to_full(&s_dense, &dense_lut), + fold_s_compact_to_full_reference(&s_dense, r) + ); + } + + #[test] + fn stage1_round0_matches_dense_reference() { + let num_u = 3usize; + let num_l = 2usize; + let b = 8usize; + let n = 1usize << (num_u + num_l); + let half = (b / 2) as i8; + let w_compact: Vec = (0..n).map(|i| ((i * 5 + 3) % b) as i8 - half).collect(); + let tau0: Vec = (0..(num_u + num_l)) + .map(|i| F::from_u64((i as u64) + 2)) + .collect(); + + let mut prover = + HachiStage1Prover::new(&w_compact, &tau0, b, 1usize << num_u, num_u, num_l); + let stage1_poly = prover.compute_round_univariate(0, F::zero()); + let s_compact: Vec = w_compact + .iter() + .map(|&w| { + let w = w as i32; + w * (w + 1) + }) + .collect(); + let reference = compute_norm_round_poly_from_s_compact( + &prover.split_eq, + &s_compact, + &prover.range_precomp, + F::zero(), + ); + + assert_eq!(stage1_poly, reference); + } + + #[test] + fn stage1_prefix_aware_rounds_match_explicit_zero_padding() { + let num_l = 2usize; + let b = 8usize; + let half = (b / 2) as i8; + + for live_x_cols in [5usize, 6usize] { + let num_u = live_x_cols.next_power_of_two().trailing_zeros() as usize; + let y_len = 1usize << num_l; + let w_prefix: Vec = (0..(live_x_cols * y_len)) + .map(|i| ((i * 7 + 5) % b) as i8 - half) + .collect(); + let w_padded = pad_compact_rows(&w_prefix, live_x_cols, num_u, num_l); + let tau0: Vec = 
(0..(num_u + num_l)) + .map(|i| F::from_u64((i as u64) + 19)) + .collect(); + let mut prefix_prover = + HachiStage1Prover::new(&w_prefix, &tau0, b, live_x_cols, num_u, num_l); + let mut padded_prover = + HachiStage1Prover::new(&w_padded, &tau0, b, 1usize << num_u, num_u, num_l); + let mut challenges = Vec::new(); + let mut prefix_claim = F::zero(); + let mut padded_claim = F::zero(); + + for round in 0..(num_u + num_l) { + let prefix_poly = prefix_prover.compute_round_univariate(round, prefix_claim); + let padded_poly = padded_prover.compute_round_univariate(round, padded_claim); + assert_eq!( + prefix_poly, padded_poly, + "round {round} polynomial mismatch live_x_cols={live_x_cols}" + ); + + let challenge = F::from_u64((round as u64) + 29); + challenges.push(challenge); + prefix_claim = prefix_poly.evaluate(&challenge); + padded_claim = padded_poly.evaluate(&challenge); + prefix_prover.ingest_challenge(round, challenge); + padded_prover.ingest_challenge(round, challenge); + } + + assert_eq!(prefix_prover.final_s_claim(), padded_prover.final_s_claim()); + assert_eq!(prefix_claim, padded_claim); + let s_padded: Vec = w_padded + .iter() + .map(|&w| { + let w = F::from_i64(w as i64); + w * (w + F::one()) + }) + .collect(); + assert_eq!( + prefix_prover.final_s_claim(), + multilinear_eval(&s_padded, &challenges).unwrap(), + "final s-claim mismatch live_x_cols={live_x_cols}" + ); + } + } + + #[test] + fn stage1_fused_round2_transition_matches_two_pass_reference() { + let num_u = 3usize; + let num_l = 2usize; + let live_x_cols = 6usize; + let b = 8usize; + let half = (b / 2) as i8; + let y_len = 1usize << num_l; + let w_prefix: Vec = (0..(live_x_cols * y_len)) + .map(|i| ((i * 9 + 5) % b) as i8 - half) + .collect(); + let s_compact: Vec = w_prefix + .iter() + .map(|&w| { + let w = w as i32; + w * (w + 1) + }) + .collect(); + let tau0: Vec = (0..(num_u + num_l)) + .map(|i| F::from_u64((i as u64) + 53)) + .collect(); + + let mut prover = HachiStage1Prover::new(&w_prefix, 
&tau0, b, live_x_cols, num_u, num_l); + let round0 = prover.compute_round_univariate(0, F::zero()); + let r0 = F::from_u64(61); + let claim1 = round0.evaluate(&r0); + prover.ingest_challenge(0, r0); + let round1 = prover.compute_round_univariate(1, claim1); + let r1 = F::from_u64(67); + let claim2 = round1.evaluate(&r1); + + let expected_s_full = HachiStage1Prover::::fold_s_compact_to_round2( + &s_compact, + live_x_cols, + y_len, + r0, + r1, + ); + let mut expected = HachiStage1Prover::new(&w_prefix, &tau0, b, live_x_cols, num_u, num_l); + expected.split_eq.bind(r0); + expected.split_eq.bind(r1); + expected.live_x_cols = live_x_cols.div_ceil(4); + expected.rounds_completed = 2; + let expected_round2 = expected.compute_round_full_prefix_x(&expected_s_full, claim2); + + prover.ingest_challenge(1, r1); + + match &prover.s_table { + STable::Full(s_full) => assert_eq!(s_full, &expected_s_full), + STable::Compact(_) => { + panic!("expected fused stage1 transition to materialize full table") + } + } + assert_eq!(prover.cached_round_poly.as_ref(), Some(&expected_round2)); + } + + #[test] + fn stage1_later_full_prefix_fusion_matches_two_pass_reference() { + let num_u = 5usize; + let num_l = 2usize; + let live_x_cols = 12usize; + let b = 8usize; + let half = (b / 2) as i8; + let y_len = 1usize << num_l; + let w_prefix: Vec = (0..(live_x_cols * y_len)) + .map(|i| ((i * 5 + 11) % b) as i8 - half) + .collect(); + let tau0: Vec = (0..(num_u + num_l)) + .map(|i| F::from_u64((i as u64) + 101)) + .collect(); + + let mut prover = HachiStage1Prover::new(&w_prefix, &tau0, b, live_x_cols, num_u, num_l); + let round0 = prover.compute_round_univariate(0, F::zero()); + let r0 = F::from_u64(107); + let claim1 = round0.evaluate(&r0); + prover.ingest_challenge(0, r0); + + let round1 = prover.compute_round_univariate(1, claim1); + let r1 = F::from_u64(109); + let claim2 = round1.evaluate(&r1); + prover.ingest_challenge(1, r1); + + let round2 = prover.compute_round_univariate(2, claim2); + let 
r2 = F::from_u64(113); + let claim3 = round2.evaluate(&r2); + + let mut expected = HachiStage1Prover::new(&w_prefix, &tau0, b, live_x_cols, num_u, num_l); + let expected_round0 = expected.compute_round_univariate(0, F::zero()); + assert_eq!(expected_round0, round0); + expected.ingest_challenge(0, r0); + let expected_round1 = expected.compute_round_univariate(1, claim1); + assert_eq!(expected_round1, round1); + expected.ingest_challenge(1, r1); + let expected_round2 = expected.compute_round_univariate(2, claim2); + assert_eq!(expected_round2, round2); + + let current_s_full = match &expected.s_table { + STable::Full(s_full) => s_full.clone(), + STable::Compact(_) => panic!("expected later prefix state to be full"), + }; + let expected_next_s_full = HachiStage1Prover::::fold_s_full_prefix_x( + ¤t_s_full, + expected.live_x_cols, + y_len, + r2, + ); + expected.split_eq.bind(r2); + expected.live_x_cols = expected.live_x_cols.div_ceil(2); + expected.rounds_completed += 1; + let expected_round3 = expected.compute_round_full_prefix_x(&expected_next_s_full, claim3); + + prover.ingest_challenge(2, r2); + + match &prover.s_table { + STable::Full(s_full) => assert_eq!(s_full, &expected_next_s_full), + STable::Compact(_) => panic!("expected fused later prefix stage to stay full"), + } + assert_eq!(prover.cached_round_poly.as_ref(), Some(&expected_round3)); + } +} diff --git a/src/protocol/sumcheck/hachi_stage2.rs b/src/protocol/sumcheck/hachi_stage2.rs new file mode 100644 index 00000000..84c2c27b --- /dev/null +++ b/src/protocol/sumcheck/hachi_stage2.rs @@ -0,0 +1,2667 @@ +//! Stage-2 fused sumcheck prover/verifier for the Hachi PCS. +//! +//! This stage views the committed witness as a Boolean table +//! `w : {0,1}^{num_u} x {0,1}^{num_l} -> F`, where `x` indexes the padded +//! witness columns and `y` indexes the coefficient inside a +//! `D = 2^{num_l}`-dimensional ring element. Let `a(y)` be the multilinear +//! 
extension of `alpha_evals_y = [1, alpha, ..., alpha^(D-1)]`, so on Boolean +//! inputs `a(y) = alpha^{bin(y)}`. Let `M_alpha` be the ring-switch matrix +//! after evaluating every ring entry at the transcript challenge `alpha`, and +//! define the `tau1`-weighted row combination +//! +//! `m_tau1(x) = sum_i eq(tau1, i) * M_alpha(i, x)`. +//! +//! The Boolean table stored in `m_evals_x` is exactly `x -> m_tau1(x)`. +//! +//! If +//! +//! `y_alpha = [v_0(alpha), ..., v_{N_D-1}(alpha),` +//! ` u_0(alpha), ..., u_{N_B-1}(alpha),` +//! ` y_ring(alpha), 0, ..., 0],` +//! +//! then the linear relation claim is +//! +//! `relation_claim = sum_i eq(tau1, i) * y_alpha[i]` +//! ` = sum_{x,y} w(x, y) * a(y) * m_tau1(x)`. +//! +//! Stage 1 supplies the carried virtual claim +//! +//! `s_claim = w(r_stage1) * (w(r_stage1) + 1)` +//! ` = sum_z eq(r_stage1, z) * w(z) * (w(z) + 1)` +//! +//! for the same multilinear witness table. With `gamma = batching_coeff`, the +//! exact identity established by this sumcheck is +//! +//! `gamma * s_claim + relation_claim =` +//! `sum_{x,y} [ gamma * eq(r_stage1, (x, y)) * w(x, y) * (w(x, y) + 1)` +//! ` + w(x, y) * a(y) * m_tau1(x) ]`. +//! +//! After all rounds, at `r_stage2 = (r_x, r_y)`, the verifier checks +//! +//! `gamma * eq(r_stage1, r_stage2) * w(r_stage2) * (w(r_stage2) + 1)` +//! ` + w(r_stage2) * a(r_y) * m_tau1(r_x)`, +//! +//! exactly the oracle returned by `expected_output_claim()`. The prover fuses +//! both halves around the same local `w0` / `dw` scan so the witness-side work +//! is shared between the virtual and relation terms. 
+ +use super::eq_poly::EqPolynomial; +use super::split_eq::GruenSplitEq; +use super::two_round_prefix::{ + build_stage2_bivariate_skip_proof_from_compact, can_use_stage2_two_round_prefix, + Stage2BivariateSkipState, +}; +use super::{fold_evals_in_place, multilinear_eval, trim_trailing_zeros, CompactPairFoldLut}; +use super::{SumcheckInstanceProver, SumcheckInstanceVerifier, UniPoly}; +use crate::algebra::fields::HasUnreducedOps; +use crate::algebra::CyclotomicRing; +use crate::error::HachiError; +#[cfg(feature = "parallel")] +use crate::parallel::*; +use crate::protocol::ring_switch::eval_ring_at; +use crate::{cfg_fold_reduce, cfg_into_iter}; +use crate::{AdditiveGroup, CanonicalField, FieldCore, FromSmallInt}; +use std::marker::PhantomData; +use std::mem; +use std::time::Instant; + +enum WTable { + Compact(Vec), + Full(Vec), +} + +struct Stage2TwoRoundPrefix { + skip_state: Stage2BivariateSkipState, + first_challenge: Option, +} + +#[derive(Clone, Copy)] +enum NormRoundTerms { + Full([E; 3]), + SkipLinear([E; 2]), +} + +type CompactVirtAccum = [::MulU64Accum; 4]; +type CompactVirtSkipLinearAccum = [::MulU64Accum; 2]; +type CompactRelAccum = [::MulU64Accum; 6]; + +#[inline] +fn coeffs_to_poly(coeffs: [E; 3]) -> UniPoly { + let mut coeffs = vec![coeffs[0], coeffs[1], coeffs[2]]; + trim_trailing_zeros(&mut coeffs); + UniPoly::from_coeffs(coeffs) +} + +#[inline] +fn accum_small_signed( + accum: &mut [E::MulU64Accum], + pos_idx: usize, + coeff: E, + signed: i64, +) { + if signed == 0 { + return; + } + let prod = coeff.mul_u64_unreduced(signed.unsigned_abs()); + if signed < 0 { + accum[pos_idx + 1] += prod; + } else { + accum[pos_idx] += prod; + } +} + +#[inline] +fn reduce_signed_accum( + pos: E::MulU64Accum, + neg: E::MulU64Accum, +) -> E { + E::reduce_mul_u64_accum(pos) - E::reduce_mul_u64_accum(neg) +} + +#[inline] +fn reduce_compact_virt(virt: CompactVirtAccum) -> [E; 3] { + [ + E::reduce_mul_u64_accum(virt[0]), + reduce_signed_accum::(virt[1], virt[2]), + 
E::reduce_mul_u64_accum(virt[3]), + ] +} + +#[inline] +fn reduce_compact_virt_skip_linear( + virt: CompactVirtSkipLinearAccum, +) -> [E; 2] { + [ + E::reduce_mul_u64_accum(virt[0]), + E::reduce_mul_u64_accum(virt[1]), + ] +} + +#[inline] +fn reduce_compact_rel(rel: CompactRelAccum) -> [E; 3] { + [ + reduce_signed_accum::(rel[0], rel[1]), + reduce_signed_accum::(rel[2], rel[3]), + reduce_signed_accum::(rel[4], rel[5]), + ] +} + +#[inline] +pub(crate) fn relation_claim_from_rows( + tau1: &[F], + alpha: F, + v: &[CyclotomicRing], + u: &[CyclotomicRing], + y_ring: &CyclotomicRing, +) -> F { + let eq_tau1 = EqPolynomial::evals(tau1); + let mut acc = F::zero(); + let mut row_idx = 0usize; + + for r in v { + if row_idx >= eq_tau1.len() { + return acc; + } + acc += eq_tau1[row_idx] * eval_ring_at(r, &alpha); + row_idx += 1; + } + for r in u { + if row_idx >= eq_tau1.len() { + return acc; + } + acc += eq_tau1[row_idx] * eval_ring_at(r, &alpha); + row_idx += 1; + } + if row_idx < eq_tau1.len() { + acc += eq_tau1[row_idx] * eval_ring_at(y_ring, &alpha); + } + acc +} + +/// Stage-2 fused virtual-claim + relation sumcheck prover. +/// +/// Holds a single `w_table` shared by both halves of stage 2. The virtual half +/// is pre-weighted by `batching_coeff` through `split_eq`, so the round +/// polynomial is: +/// `batching_coeff * virtual_round(t) + relation_round(t)`. +pub struct HachiStage2Prover { + w_table: WTable, + batching_coeff: E, + s_claim: E, + split_eq: GruenSplitEq, + + alpha_compact: Vec, + m_compact: Vec, + live_x_cols: usize, + num_u: usize, + num_vars: usize, + relation_claim: E, + prev_norm_claim: E, + prev_norm_poly: Option>, + prefix_r_stage1: Option>, + two_round_prefix: Option>, + cached_round_poly: Option>, + + scan_time_total: f64, + fold_time_total: f64, + rounds_completed: usize, +} + +impl HachiStage2Prover { + /// Create a fused stage-2 virtual-claim + relation sumcheck prover. 
+ #[allow(clippy::too_many_arguments)] + #[tracing::instrument(skip_all, name = "HachiStage2Prover::new")] + pub fn new( + batching_coeff: E, + w_evals_compact: Vec, + r_stage1: &[E], + s_claim: E, + alpha_evals_y: Vec, + m_evals_x: Vec, + live_x_cols: usize, + num_u: usize, + num_l: usize, + relation_claim: E, + ) -> Self { + let num_vars = num_u + num_l; + assert!(live_x_cols >= 1, "live_x_cols must be at least 1"); + assert!( + live_x_cols <= (1usize << num_u), + "live_x_cols exceeds x width" + ); + let y_len = 1usize << num_l; + assert_eq!(w_evals_compact.len(), live_x_cols * y_len); + assert_eq!(r_stage1.len(), num_vars); + assert_eq!(alpha_evals_y.len(), y_len); + assert_eq!(m_evals_x.len(), 1 << num_u); + + Self { + w_table: WTable::Compact(w_evals_compact), + batching_coeff, + s_claim, + split_eq: GruenSplitEq::with_initial_scalar(r_stage1, batching_coeff), + alpha_compact: alpha_evals_y, + m_compact: m_evals_x, + live_x_cols, + num_u, + num_vars, + relation_claim, + prev_norm_claim: batching_coeff * s_claim, + prev_norm_poly: None, + prefix_r_stage1: can_use_stage2_two_round_prefix(num_u).then(|| r_stage1.to_vec()), + two_round_prefix: None, + cached_round_poly: None, + scan_time_total: 0.0, + fold_time_total: 0.0, + rounds_completed: 0, + } + } + + /// Return the fully folded witness evaluation after the final round. + /// + /// # Panics + /// + /// Panics if called before the witness table has been fully folded to a + /// single field element. 
+ pub fn final_w_eval(&self) -> E { + match &self.w_table { + WTable::Full(w_full) => { + assert_eq!(w_full.len(), 1, "w_table not fully folded"); + w_full[0] + } + WTable::Compact(_) => panic!("w_table remained compact after final fold"), + } + } + + #[inline] + fn current_x_width(&self) -> usize { + self.num_u.saturating_sub(self.rounds_completed) + } + + #[inline] + fn current_x_len(&self) -> usize { + 1usize << self.current_x_width() + } + + #[inline] + fn use_prefix_x_round(&self) -> bool { + self.rounds_completed < self.num_u && self.live_x_cols < self.current_x_len() + } + + #[inline] + fn next_use_prefix_x_round_after_current(&self) -> bool { + self.rounds_completed + 1 < self.num_u + && self.live_x_cols.div_ceil(2) < (self.current_x_len() / 2) + } + + #[inline] + pub(crate) fn can_use_two_round_prefix(&self) -> bool { + self.prefix_r_stage1.is_some() + } + + #[inline] + fn using_two_round_prefix(&self) -> bool { + self.rounds_completed < 2 && self.can_use_two_round_prefix() + } + + #[inline] + fn can_skip_norm_linear_coeff(&self) -> bool { + self.split_eq.can_recover_linear_q_term_from_claim() + } + + #[inline] + fn norm_poly_from_terms(&self, virt_terms: NormRoundTerms) -> UniPoly { + match virt_terms { + NormRoundTerms::Full(virt_q_coeffs) => { + self.split_eq.gruen_mul(&coeffs_to_poly(virt_q_coeffs)) + } + NormRoundTerms::SkipLinear([q_constant, q_quadratic]) => self + .split_eq + .try_gruen_poly_deg_3(q_constant, q_quadratic, self.prev_norm_claim) + .expect("split-eq norm claim recovery should succeed"), + } + } + + #[inline] + fn polys_from_terms( + &self, + virt_terms: NormRoundTerms, + rel_coeffs: [E; 3], + ) -> (UniPoly, UniPoly) { + let virt_poly = self.norm_poly_from_terms(virt_terms); + let rel_poly = coeffs_to_poly(rel_coeffs); + (virt_poly, rel_poly) + } + + #[inline] + fn combine_polys(&self, virt_poly: &UniPoly, relation_poly: &UniPoly) -> UniPoly { + let max_len = virt_poly.coeffs.len().max(relation_poly.coeffs.len()); + let mut combined = 
vec![E::zero(); max_len]; + for (i, c) in virt_poly.coeffs.iter().enumerate() { + combined[i] += *c; + } + for (i, c) in relation_poly.coeffs.iter().enumerate() { + combined[i] += *c; + } + UniPoly::from_coeffs(combined) + } + + #[inline] + fn combine_terms(&mut self, virt_terms: NormRoundTerms, rel_coeffs: [E; 3]) -> UniPoly { + let (virt_poly, relation_poly) = self.polys_from_terms(virt_terms, rel_coeffs); + let combined = self.combine_polys(&virt_poly, &relation_poly); + self.prev_norm_poly = Some(virt_poly); + combined + } + + fn ensure_two_round_prefix(&mut self) -> &mut Stage2TwoRoundPrefix { + if self.two_round_prefix.is_none() { + let r_stage1 = self + .prefix_r_stage1 + .clone() + .expect("two-round prefix requested without cached stage-1 challenges"); + let num_l = self.num_vars - self.num_u; + let w_compact = match &self.w_table { + WTable::Compact(w_compact) => w_compact, + WTable::Full(_) => panic!("two-round prefix can only build from compact witness"), + }; + let proof = build_stage2_bivariate_skip_proof_from_compact( + w_compact, + &self.alpha_compact, + &self.m_compact, + &r_stage1, + self.live_x_cols, + self.num_u, + num_l, + ) + .expect("two-round prefix should be available"); + let skip_state = Stage2BivariateSkipState::new( + &proof, + &r_stage1, + self.s_claim, + self.relation_claim, + self.batching_coeff, + ) + .expect("valid bivariate-skip state"); + self.two_round_prefix = Some(Stage2TwoRoundPrefix { + skip_state, + first_challenge: None, + }); + } + self.two_round_prefix + .as_mut() + .expect("two-round prefix should be initialized") + } + + #[inline] + fn direct_fold_w_quad_to_round2(w00: i8, w10: i8, w01: i8, w11: i8, r0: E, r1: E) -> E { + let w00 = E::from_i64(w00 as i64); + let w10 = E::from_i64(w10 as i64); + let w01 = E::from_i64(w01 as i64); + let w11 = E::from_i64(w11 as i64); + let x0 = w00 + r0 * (w10 - w00); + let x1 = w01 + r0 * (w11 - w01); + x0 + r1 * (x1 - x0) + } + + #[inline] + fn direct_fold_e_quad_to_round2(e00: E, e10: 
E, e01: E, e11: E, r0: E, r1: E) -> E { + let x0 = e00 + r0 * (e10 - e00); + let x1 = e01 + r0 * (e11 - e01); + x0 + r1 * (x1 - x0) + } + + #[inline] + fn stage2_b8_w_digit(w: i8) -> usize { + let w = i32::from(w); + debug_assert!((-4..=3).contains(&w)); + (w + 4) as usize + } + + #[inline] + fn stage2_b8_quad_lookup_index_from_row(row: &[i8], base: usize) -> usize { + let d0 = row + .get(base) + .copied() + .map(Self::stage2_b8_w_digit) + .unwrap_or(4); + let d1 = row + .get(base + 1) + .copied() + .map(Self::stage2_b8_w_digit) + .unwrap_or(4); + let d2 = row + .get(base + 2) + .copied() + .map(Self::stage2_b8_w_digit) + .unwrap_or(4); + let d3 = row + .get(base + 3) + .copied() + .map(Self::stage2_b8_w_digit) + .unwrap_or(4); + d0 | (d1 << 3) | (d2 << 6) | (d3 << 9) + } + + fn build_round2_w_lookup(r0: E, r1: E) -> Vec { + const W_VALUES: [i8; 8] = [-4, -3, -2, -1, 0, 1, 2, 3]; + (0..4096usize) + .map(|idx| { + let d0 = idx & 0b111; + let d1 = (idx >> 3) & 0b111; + let d2 = (idx >> 6) & 0b111; + let d3 = (idx >> 9) & 0b111; + Self::direct_fold_w_quad_to_round2( + W_VALUES[d0], + W_VALUES[d1], + W_VALUES[d2], + W_VALUES[d3], + r0, + r1, + ) + }) + .collect() + } + + #[tracing::instrument(skip_all, name = "HachiStage2Prover::fold_compact_to_round2")] + fn fold_compact_to_round2( + w_compact: &[i8], + live_x_cols: usize, + y_len: usize, + r0: E, + r1: E, + ) -> Vec { + let next_live_x_cols = live_x_cols.div_ceil(4); + let mut out = vec![E::zero(); y_len * next_live_x_cols]; + for (y, row_out) in out.chunks_mut(next_live_x_cols).enumerate() { + let row = &w_compact[y * live_x_cols..(y + 1) * live_x_cols]; + for (quad_x, dst) in row_out.iter_mut().enumerate() { + let base = 4 * quad_x; + *dst = Self::direct_fold_w_quad_to_round2( + row.get(base).copied().unwrap_or_default(), + row.get(base + 1).copied().unwrap_or_default(), + row.get(base + 2).copied().unwrap_or_default(), + row.get(base + 3).copied().unwrap_or_default(), + r0, + r1, + ); + } + } + out + } + + 
#[tracing::instrument(skip_all, name = "HachiStage2Prover::fold_m_to_round2")] + fn fold_m_to_round2(m_compact: &[E], r0: E, r1: E) -> Vec { + debug_assert!(m_compact.len().is_power_of_two()); + debug_assert!(m_compact.len() >= 4); + let next_x_len = m_compact.len() >> 2; + let mut out = vec![E::zero(); next_x_len]; + for (quad_x, dst) in out.iter_mut().enumerate() { + let base = 4 * quad_x; + *dst = Self::direct_fold_e_quad_to_round2( + m_compact[base], + m_compact[base + 1], + m_compact[base + 2], + m_compact[base + 3], + r0, + r1, + ); + } + out + } + + #[tracing::instrument( + skip_all, + name = "HachiStage2Prover::fuse_compact_to_round2_and_compute_round" + )] + fn fuse_compact_to_round2_and_compute_round( + &self, + w_compact: &[i8], + r0: E, + r1: E, + ) -> (Vec, Vec, NormRoundTerms, [E; 3]) { + debug_assert!(self.num_u > 2); + let old_live_x_cols = self.live_x_cols; + let next_live_x_cols = old_live_x_cols.div_ceil(4); + let y_len = self.alpha_compact.len(); + let live_pairs = next_live_x_cols.div_ceil(2); + let current_x_half = 1usize << (self.num_u - 3); + let (e_first, e_second) = self.split_eq.remaining_eq_tables(); + let num_first = e_first.len(); + let first_bits = num_first.trailing_zeros() as usize; + let block_size = num_first.min(live_pairs); + let alpha_compact = &self.alpha_compact; + let m_round2 = Self::fold_m_to_round2(&self.m_compact, r0, r1); + let quad_fold_lut = Self::build_round2_w_lookup(r0, r1); + let mut out = vec![E::zero(); y_len * next_live_x_cols]; + + if self.can_skip_norm_linear_coeff() { + #[cfg(feature = "parallel")] + let (virt_coeffs, rel_coeffs) = out + .par_chunks_mut(next_live_x_cols) + .enumerate() + .map(|(y, row_out)| { + let row = &w_compact[y * old_live_x_cols..(y + 1) * old_live_x_cols]; + let alpha = alpha_compact[y]; + let j_base = y * current_x_half; + let mut virt = [E::zero(); 2]; + let mut rel = [E::zero(); 3]; + + let mut blk = 0usize; + while blk < live_pairs { + let blk_end = (blk + 
block_size).min(live_pairs); + let j_high = (j_base + blk) >> first_bits; + let mut inner_virt = [E::zero(); 2]; + + for pair_x in blk..blk_end { + let j_low = (j_base + pair_x) & (num_first - 1); + let e_in = e_first[j_low]; + let left_quad = 2 * pair_x; + let left_base = 8 * pair_x; + let w0 = quad_fold_lut + [Self::stage2_b8_quad_lookup_index_from_row(row, left_base)]; + row_out[left_quad] = w0; + let w1 = if left_quad + 1 < next_live_x_cols { + let w1 = quad_fold_lut[Self::stage2_b8_quad_lookup_index_from_row( + row, + left_base + 4, + )]; + row_out[left_quad + 1] = w1; + w1 + } else { + E::zero() + }; + let dw = w1 - w0; + + inner_virt[0] += e_in * (w0 * (w0 + E::one())); + inner_virt[1] += e_in * (dw * dw); + + let m0 = m_round2[left_quad]; + let m1 = m_round2[left_quad + 1]; + let p0 = alpha * m0; + let p1 = alpha * m1; + let dp = p1 - p0; + rel[0] += w0 * p0; + rel[1] += w0 * dp + dw * p0; + rel[2] += dw * dp; + } + + let e_out = e_second[j_high]; + virt[0] += e_out * inner_virt[0]; + virt[1] += e_out * inner_virt[1]; + blk = blk_end; + } + + (virt, rel) + }) + .reduce( + || ([E::zero(); 2], [E::zero(); 3]), + |(mut va, mut ra), (vb, rb)| { + for (ai, bi) in va.iter_mut().zip(vb.iter()) { + *ai += *bi; + } + for (ai, bi) in ra.iter_mut().zip(rb.iter()) { + *ai += *bi; + } + (va, ra) + }, + ); + + #[cfg(not(feature = "parallel"))] + let (virt_coeffs, rel_coeffs) = { + let mut virt = [E::zero(); 2]; + let mut rel = [E::zero(); 3]; + for (y, row_out) in out.chunks_mut(next_live_x_cols).enumerate() { + let row = &w_compact[y * old_live_x_cols..(y + 1) * old_live_x_cols]; + let alpha = alpha_compact[y]; + let j_base = y * current_x_half; + let mut blk = 0usize; + while blk < live_pairs { + let blk_end = (blk + block_size).min(live_pairs); + let j_high = (j_base + blk) >> first_bits; + let mut inner_virt = [E::zero(); 2]; + + for pair_x in blk..blk_end { + let j_low = (j_base + pair_x) & (num_first - 1); + let e_in = e_first[j_low]; + let left_quad = 2 * pair_x; 
+ let left_base = 8 * pair_x; + let w0 = quad_fold_lut + [Self::stage2_b8_quad_lookup_index_from_row(row, left_base)]; + row_out[left_quad] = w0; + let w1 = if left_quad + 1 < next_live_x_cols { + let w1 = quad_fold_lut[Self::stage2_b8_quad_lookup_index_from_row( + row, + left_base + 4, + )]; + row_out[left_quad + 1] = w1; + w1 + } else { + E::zero() + }; + let dw = w1 - w0; + + inner_virt[0] += e_in * (w0 * (w0 + E::one())); + inner_virt[1] += e_in * (dw * dw); + + let m0 = m_round2[left_quad]; + let m1 = m_round2[left_quad + 1]; + let p0 = alpha * m0; + let p1 = alpha * m1; + let dp = p1 - p0; + rel[0] += w0 * p0; + rel[1] += w0 * dp + dw * p0; + rel[2] += dw * dp; + } + + let e_out = e_second[j_high]; + virt[0] += e_out * inner_virt[0]; + virt[1] += e_out * inner_virt[1]; + blk = blk_end; + } + } + (virt, rel) + }; + + ( + out, + m_round2, + NormRoundTerms::SkipLinear(virt_coeffs), + rel_coeffs, + ) + } else { + #[cfg(feature = "parallel")] + let (virt_coeffs, rel_coeffs) = out + .par_chunks_mut(next_live_x_cols) + .enumerate() + .map(|(y, row_out)| { + let row = &w_compact[y * old_live_x_cols..(y + 1) * old_live_x_cols]; + let alpha = alpha_compact[y]; + let j_base = y * current_x_half; + let mut virt = [E::zero(); 3]; + let mut rel = [E::zero(); 3]; + + let mut blk = 0usize; + while blk < live_pairs { + let blk_end = (blk + block_size).min(live_pairs); + let j_high = (j_base + blk) >> first_bits; + let mut inner_virt = [E::zero(); 3]; + + for pair_x in blk..blk_end { + let j_low = (j_base + pair_x) & (num_first - 1); + let e_in = e_first[j_low]; + let left_quad = 2 * pair_x; + let left_base = 8 * pair_x; + let w0 = quad_fold_lut + [Self::stage2_b8_quad_lookup_index_from_row(row, left_base)]; + row_out[left_quad] = w0; + let w1 = if left_quad + 1 < next_live_x_cols { + let w1 = quad_fold_lut[Self::stage2_b8_quad_lookup_index_from_row( + row, + left_base + 4, + )]; + row_out[left_quad + 1] = w1; + w1 + } else { + E::zero() + }; + let dw = w1 - w0; + let 
two_w0_plus_one = w0 + w0 + E::one(); + + inner_virt[0] += e_in * (w0 * (w0 + E::one())); + inner_virt[1] += e_in * (dw * two_w0_plus_one); + inner_virt[2] += e_in * (dw * dw); + + let m0 = m_round2[left_quad]; + let m1 = m_round2[left_quad + 1]; + let p0 = alpha * m0; + let p1 = alpha * m1; + let dp = p1 - p0; + rel[0] += w0 * p0; + rel[1] += w0 * dp + dw * p0; + rel[2] += dw * dp; + } + + let e_out = e_second[j_high]; + virt[0] += e_out * inner_virt[0]; + virt[1] += e_out * inner_virt[1]; + virt[2] += e_out * inner_virt[2]; + blk = blk_end; + } + + (virt, rel) + }) + .reduce( + || ([E::zero(); 3], [E::zero(); 3]), + |(mut va, mut ra), (vb, rb)| { + for (ai, bi) in va.iter_mut().zip(vb.iter()) { + *ai += *bi; + } + for (ai, bi) in ra.iter_mut().zip(rb.iter()) { + *ai += *bi; + } + (va, ra) + }, + ); + + #[cfg(not(feature = "parallel"))] + let (virt_coeffs, rel_coeffs) = { + let mut virt = [E::zero(); 3]; + let mut rel = [E::zero(); 3]; + for (y, row_out) in out.chunks_mut(next_live_x_cols).enumerate() { + let row = &w_compact[y * old_live_x_cols..(y + 1) * old_live_x_cols]; + let alpha = alpha_compact[y]; + let j_base = y * current_x_half; + let mut blk = 0usize; + while blk < live_pairs { + let blk_end = (blk + block_size).min(live_pairs); + let j_high = (j_base + blk) >> first_bits; + let mut inner_virt = [E::zero(); 3]; + + for pair_x in blk..blk_end { + let j_low = (j_base + pair_x) & (num_first - 1); + let e_in = e_first[j_low]; + let left_quad = 2 * pair_x; + let left_base = 8 * pair_x; + let w0 = quad_fold_lut + [Self::stage2_b8_quad_lookup_index_from_row(row, left_base)]; + row_out[left_quad] = w0; + let w1 = if left_quad + 1 < next_live_x_cols { + let w1 = quad_fold_lut[Self::stage2_b8_quad_lookup_index_from_row( + row, + left_base + 4, + )]; + row_out[left_quad + 1] = w1; + w1 + } else { + E::zero() + }; + let dw = w1 - w0; + let two_w0_plus_one = w0 + w0 + E::one(); + + inner_virt[0] += e_in * (w0 * (w0 + E::one())); + inner_virt[1] += e_in * (dw * 
two_w0_plus_one); + inner_virt[2] += e_in * (dw * dw); + + let m0 = m_round2[left_quad]; + let m1 = m_round2[left_quad + 1]; + let p0 = alpha * m0; + let p1 = alpha * m1; + let dp = p1 - p0; + rel[0] += w0 * p0; + rel[1] += w0 * dp + dw * p0; + rel[2] += dw * dp; + } + + let e_out = e_second[j_high]; + virt[0] += e_out * inner_virt[0]; + virt[1] += e_out * inner_virt[1]; + virt[2] += e_out * inner_virt[2]; + blk = blk_end; + } + } + (virt, rel) + }; + + (out, m_round2, NormRoundTerms::Full(virt_coeffs), rel_coeffs) + } + } + + #[inline] + fn fold_full_prefix_pair(row: &[E], left: usize, r: E) -> E { + let w0 = row.get(left).copied().unwrap_or_else(E::zero); + let w1 = row.get(left + 1).copied().unwrap_or_else(E::zero); + w0 + r * (w1 - w0) + } + + #[tracing::instrument( + skip_all, + name = "HachiStage2Prover::fuse_full_prefix_x_and_compute_round" + )] + fn fuse_full_prefix_x_and_compute_round( + &self, + w_full: &[E], + r: E, + ) -> (Vec, Vec, NormRoundTerms, [E; 3]) { + debug_assert!(self.next_use_prefix_x_round_after_current()); + debug_assert!(self.current_x_width() >= 2); + + let old_live_x_cols = self.live_x_cols; + let next_live_x_cols = old_live_x_cols.div_ceil(2); + let y_len = self.alpha_compact.len(); + let (e_first, e_second) = self.split_eq.remaining_eq_tables(); + let num_first = e_first.len(); + let first_bits = num_first.trailing_zeros() as usize; + let next_current_x_half = 1usize << (self.current_x_width() - 2); + let live_pairs = next_live_x_cols.div_ceil(2); + let block_size = num_first.min(live_pairs); + let alpha_compact = &self.alpha_compact; + let next_m_compact = Self::fold_m_prefix(&self.m_compact, r); + let mut out = vec![E::zero(); y_len * next_live_x_cols]; + + if self.can_skip_norm_linear_coeff() { + #[cfg(feature = "parallel")] + let (virt_coeffs, rel_coeffs) = out + .par_chunks_mut(next_live_x_cols) + .enumerate() + .map(|(y, row_out)| { + let row = &w_full[y * old_live_x_cols..(y + 1) * old_live_x_cols]; + let alpha = 
alpha_compact[y]; + let j_base = y * next_current_x_half; + let mut virt = [E::zero(); 2]; + let mut rel = [E::zero(); 3]; + + let mut blk = 0usize; + while blk < live_pairs { + let blk_end = (blk + block_size).min(live_pairs); + let j_high = (j_base + blk) >> first_bits; + let mut inner_virt = [E::zero(); 2]; + + for pair_x in blk..blk_end { + let left_next = 2 * pair_x; + let left_old = 4 * pair_x; + let w0 = Self::fold_full_prefix_pair(row, left_old, r); + row_out[left_next] = w0; + let w1 = if left_next + 1 < next_live_x_cols { + let w1 = Self::fold_full_prefix_pair(row, left_old + 2, r); + row_out[left_next + 1] = w1; + w1 + } else { + E::zero() + }; + let dw = w1 - w0; + + let j_low = (j_base + pair_x) & (num_first - 1); + let e_in = e_first[j_low]; + inner_virt[0] += e_in * (w0 * (w0 + E::one())); + inner_virt[1] += e_in * (dw * dw); + + let m0 = next_m_compact[left_next]; + let m1 = next_m_compact[left_next + 1]; + let p0 = alpha * m0; + let p1 = alpha * m1; + let dp = p1 - p0; + rel[0] += w0 * p0; + rel[1] += w0 * dp + dw * p0; + rel[2] += dw * dp; + } + + let e_out = e_second[j_high]; + virt[0] += e_out * inner_virt[0]; + virt[1] += e_out * inner_virt[1]; + blk = blk_end; + } + + (virt, rel) + }) + .reduce( + || ([E::zero(); 2], [E::zero(); 3]), + |(mut va, mut ra), (vb, rb)| { + for (ai, bi) in va.iter_mut().zip(vb.iter()) { + *ai += *bi; + } + for (ai, bi) in ra.iter_mut().zip(rb.iter()) { + *ai += *bi; + } + (va, ra) + }, + ); + + #[cfg(not(feature = "parallel"))] + let (virt_coeffs, rel_coeffs) = { + let mut virt = [E::zero(); 2]; + let mut rel = [E::zero(); 3]; + for (y, row_out) in out.chunks_mut(next_live_x_cols).enumerate() { + let row = &w_full[y * old_live_x_cols..(y + 1) * old_live_x_cols]; + let alpha = alpha_compact[y]; + let j_base = y * next_current_x_half; + let mut blk = 0usize; + while blk < live_pairs { + let blk_end = (blk + block_size).min(live_pairs); + let j_high = (j_base + blk) >> first_bits; + let mut inner_virt = [E::zero(); 2]; 
+ + for pair_x in blk..blk_end { + let left_next = 2 * pair_x; + let left_old = 4 * pair_x; + let w0 = Self::fold_full_prefix_pair(row, left_old, r); + row_out[left_next] = w0; + let w1 = if left_next + 1 < next_live_x_cols { + let w1 = Self::fold_full_prefix_pair(row, left_old + 2, r); + row_out[left_next + 1] = w1; + w1 + } else { + E::zero() + }; + let dw = w1 - w0; + + let j_low = (j_base + pair_x) & (num_first - 1); + let e_in = e_first[j_low]; + inner_virt[0] += e_in * (w0 * (w0 + E::one())); + inner_virt[1] += e_in * (dw * dw); + + let m0 = next_m_compact[left_next]; + let m1 = next_m_compact[left_next + 1]; + let p0 = alpha * m0; + let p1 = alpha * m1; + let dp = p1 - p0; + rel[0] += w0 * p0; + rel[1] += w0 * dp + dw * p0; + rel[2] += dw * dp; + } + + let e_out = e_second[j_high]; + virt[0] += e_out * inner_virt[0]; + virt[1] += e_out * inner_virt[1]; + blk = blk_end; + } + } + (virt, rel) + }; + + ( + out, + next_m_compact, + NormRoundTerms::SkipLinear(virt_coeffs), + rel_coeffs, + ) + } else { + #[cfg(feature = "parallel")] + let (virt_coeffs, rel_coeffs) = out + .par_chunks_mut(next_live_x_cols) + .enumerate() + .map(|(y, row_out)| { + let row = &w_full[y * old_live_x_cols..(y + 1) * old_live_x_cols]; + let alpha = alpha_compact[y]; + let j_base = y * next_current_x_half; + let mut virt = [E::zero(); 3]; + let mut rel = [E::zero(); 3]; + + let mut blk = 0usize; + while blk < live_pairs { + let blk_end = (blk + block_size).min(live_pairs); + let j_high = (j_base + blk) >> first_bits; + let mut inner_virt = [E::zero(); 3]; + + for pair_x in blk..blk_end { + let left_next = 2 * pair_x; + let left_old = 4 * pair_x; + let w0 = Self::fold_full_prefix_pair(row, left_old, r); + row_out[left_next] = w0; + let w1 = if left_next + 1 < next_live_x_cols { + let w1 = Self::fold_full_prefix_pair(row, left_old + 2, r); + row_out[left_next + 1] = w1; + w1 + } else { + E::zero() + }; + let dw = w1 - w0; + let two_w0_plus_one = w0 + w0 + E::one(); + + let j_low = (j_base + 
pair_x) & (num_first - 1); + let e_in = e_first[j_low]; + inner_virt[0] += e_in * (w0 * (w0 + E::one())); + inner_virt[1] += e_in * (dw * two_w0_plus_one); + inner_virt[2] += e_in * (dw * dw); + + let m0 = next_m_compact[left_next]; + let m1 = next_m_compact[left_next + 1]; + let p0 = alpha * m0; + let p1 = alpha * m1; + let dp = p1 - p0; + rel[0] += w0 * p0; + rel[1] += w0 * dp + dw * p0; + rel[2] += dw * dp; + } + + let e_out = e_second[j_high]; + virt[0] += e_out * inner_virt[0]; + virt[1] += e_out * inner_virt[1]; + virt[2] += e_out * inner_virt[2]; + blk = blk_end; + } + + (virt, rel) + }) + .reduce( + || ([E::zero(); 3], [E::zero(); 3]), + |(mut va, mut ra), (vb, rb)| { + for (ai, bi) in va.iter_mut().zip(vb.iter()) { + *ai += *bi; + } + for (ai, bi) in ra.iter_mut().zip(rb.iter()) { + *ai += *bi; + } + (va, ra) + }, + ); + + #[cfg(not(feature = "parallel"))] + let (virt_coeffs, rel_coeffs) = { + let mut virt = [E::zero(); 3]; + let mut rel = [E::zero(); 3]; + for (y, row_out) in out.chunks_mut(next_live_x_cols).enumerate() { + let row = &w_full[y * old_live_x_cols..(y + 1) * old_live_x_cols]; + let alpha = alpha_compact[y]; + let j_base = y * next_current_x_half; + let mut blk = 0usize; + while blk < live_pairs { + let blk_end = (blk + block_size).min(live_pairs); + let j_high = (j_base + blk) >> first_bits; + let mut inner_virt = [E::zero(); 3]; + + for pair_x in blk..blk_end { + let left_next = 2 * pair_x; + let left_old = 4 * pair_x; + let w0 = Self::fold_full_prefix_pair(row, left_old, r); + row_out[left_next] = w0; + let w1 = if left_next + 1 < next_live_x_cols { + let w1 = Self::fold_full_prefix_pair(row, left_old + 2, r); + row_out[left_next + 1] = w1; + w1 + } else { + E::zero() + }; + let dw = w1 - w0; + let two_w0_plus_one = w0 + w0 + E::one(); + + let j_low = (j_base + pair_x) & (num_first - 1); + let e_in = e_first[j_low]; + inner_virt[0] += e_in * (w0 * (w0 + E::one())); + inner_virt[1] += e_in * (dw * two_w0_plus_one); + inner_virt[2] += e_in * 
(dw * dw); + + let m0 = next_m_compact[left_next]; + let m1 = next_m_compact[left_next + 1]; + let p0 = alpha * m0; + let p1 = alpha * m1; + let dp = p1 - p0; + rel[0] += w0 * p0; + rel[1] += w0 * dp + dw * p0; + rel[2] += dw * dp; + } + + let e_out = e_second[j_high]; + virt[0] += e_out * inner_virt[0]; + virt[1] += e_out * inner_virt[1]; + virt[2] += e_out * inner_virt[2]; + blk = blk_end; + } + } + (virt, rel) + }; + + ( + out, + next_m_compact, + NormRoundTerms::Full(virt_coeffs), + rel_coeffs, + ) + } + } + + fn compute_current_round_poly_from_state(&mut self) -> UniPoly { + let t_scan = Instant::now(); + let use_two_round_prefix = self.using_two_round_prefix(); + let rounds_completed = self.rounds_completed; + let poly = if use_two_round_prefix { + let (virt_poly, rel_poly) = { + let prefix = self.ensure_two_round_prefix(); + if rounds_completed == 0 { + prefix.skip_state.reconstruct_round0_polys() + } else { + let r0 = prefix + .first_challenge + .expect("round 1 prefix polynomial requested before ingesting round 0"); + prefix.skip_state.reconstruct_round1_polys(r0) + } + }; + let combined = self.combine_polys(&virt_poly, &rel_poly); + self.prev_norm_poly = Some(virt_poly); + combined + } else { + match &self.w_table { + WTable::Compact(w_compact) => { + if self.use_prefix_x_round() { + let (virt_poly, rel_poly) = + self.compute_round_compact_prefix_x_polys(w_compact); + let combined = self.combine_polys(&virt_poly, &rel_poly); + self.prev_norm_poly = Some(virt_poly); + combined + } else { + let (virt_q_coeffs, rel_coeffs) = + self.compute_round_compact_dense_terms(w_compact); + self.combine_terms(virt_q_coeffs, rel_coeffs) + } + } + WTable::Full(w_full) => { + if self.use_prefix_x_round() { + let (virt_q_coeffs, rel_coeffs) = + self.compute_round_full_prefix_x_terms(w_full); + self.combine_terms(virt_q_coeffs, rel_coeffs) + } else { + let (virt_q_coeffs, rel_coeffs) = + self.compute_round_full_dense_terms(w_full); + self.combine_terms(virt_q_coeffs, 
rel_coeffs) + } + } + } + }; + self.scan_time_total += t_scan.elapsed().as_secs_f64(); + poly + } + + #[tracing::instrument( + skip_all, + name = "HachiStage2Prover::compute_round_compact_dense_terms" + )] + fn compute_round_compact_dense_terms(&self, w_compact: &[i8]) -> (NormRoundTerms, [E; 3]) { + let (e_first, e_second) = self.split_eq.remaining_eq_tables(); + let num_first = e_first.len(); + let num_second = e_second.len(); + let current_x_width = self.current_x_width(); + let current_x_mask = (1usize << current_x_width).wrapping_sub(1); + let alpha_compact = &self.alpha_compact; + let m_compact = &self.m_compact; + debug_assert_eq!(w_compact.len() / 2, num_first * num_second); + + if self.can_skip_norm_linear_coeff() { + let (virt_coeffs, rel_accum) = cfg_fold_reduce!( + 0..num_second, + || ([E::zero(); 2], [E::MulU64Accum::ZERO; 6]), + |(mut virt, mut rel), j_high| { + let mut inner_virt = [E::MulU64Accum::ZERO; 2]; + let base = j_high * num_first; + + for (j_low, &e_in) in e_first.iter().enumerate() { + let j = base + j_low; + let w0 = w_compact[2 * j] as i32; + let w1 = w_compact[2 * j + 1] as i32; + let dw = w1 - w0; + let w0_i64 = w0 as i64; + let dw_i64 = dw as i64; + + let q0 = w0_i64 * (w0_i64 + 1); + if q0 != 0 { + inner_virt[0] += e_in.mul_u64_unreduced(q0 as u64); + } + let q2 = dw_i64 * dw_i64; + if q2 != 0 { + inner_virt[1] += e_in.mul_u64_unreduced(q2 as u64); + } + + let a0 = alpha_compact[(2 * j) >> current_x_width]; + let a1 = alpha_compact[(2 * j + 1) >> current_x_width]; + let m0 = m_compact[(2 * j) & current_x_mask]; + let m1 = m_compact[(2 * j + 1) & current_x_mask]; + let p0 = a0 * m0; + let p1 = a1 * m1; + let dp = p1 - p0; + accum_small_signed::(&mut rel, 0, p0, w0_i64); + accum_small_signed::(&mut rel, 2, dp, w0_i64); + accum_small_signed::(&mut rel, 2, p0, dw_i64); + accum_small_signed::(&mut rel, 4, dp, dw_i64); + } + + let reduced_inner: [E; 2] = reduce_compact_virt_skip_linear(inner_virt); + let e_out = e_second[j_high]; + virt[0] 
+= e_out * reduced_inner[0]; + virt[1] += e_out * reduced_inner[1]; + + (virt, rel) + }, + |(mut va, mut ra), (vb, rb)| { + for (ai, bi) in va.iter_mut().zip(vb.iter()) { + *ai += *bi; + } + for (ai, bi) in ra.iter_mut().zip(rb.iter()) { + *ai += *bi; + } + (va, ra) + } + ); + + ( + NormRoundTerms::SkipLinear(virt_coeffs), + reduce_compact_rel(rel_accum), + ) + } else { + let (virt_coeffs, rel_accum) = cfg_fold_reduce!( + 0..num_second, + || ([E::zero(); 3], [E::MulU64Accum::ZERO; 6]), + |(mut virt, mut rel), j_high| { + let mut inner_virt = [E::MulU64Accum::ZERO; 4]; + let base = j_high * num_first; + + for (j_low, &e_in) in e_first.iter().enumerate() { + let j = base + j_low; + let w0 = w_compact[2 * j] as i32; + let w1 = w_compact[2 * j + 1] as i32; + let dw = w1 - w0; + let w0_i64 = w0 as i64; + let dw_i64 = dw as i64; + + let q0 = w0_i64 * (w0_i64 + 1); + if q0 != 0 { + inner_virt[0] += e_in.mul_u64_unreduced(q0 as u64); + } + let q1 = dw_i64 * (2 * w0_i64 + 1); + accum_small_signed::(&mut inner_virt, 1, e_in, q1); + let q2 = dw_i64 * dw_i64; + if q2 != 0 { + inner_virt[3] += e_in.mul_u64_unreduced(q2 as u64); + } + + let a0 = alpha_compact[(2 * j) >> current_x_width]; + let a1 = alpha_compact[(2 * j + 1) >> current_x_width]; + let m0 = m_compact[(2 * j) & current_x_mask]; + let m1 = m_compact[(2 * j + 1) & current_x_mask]; + let p0 = a0 * m0; + let p1 = a1 * m1; + let dp = p1 - p0; + accum_small_signed::(&mut rel, 0, p0, w0_i64); + accum_small_signed::(&mut rel, 2, dp, w0_i64); + accum_small_signed::(&mut rel, 2, p0, dw_i64); + accum_small_signed::(&mut rel, 4, dp, dw_i64); + } + + let reduced_inner: [E; 3] = reduce_compact_virt(inner_virt); + let e_out = e_second[j_high]; + virt[0] += e_out * reduced_inner[0]; + virt[1] += e_out * reduced_inner[1]; + virt[2] += e_out * reduced_inner[2]; + + (virt, rel) + }, + |(mut va, mut ra), (vb, rb)| { + for (ai, bi) in va.iter_mut().zip(vb.iter()) { + *ai += *bi; + } + for (ai, bi) in ra.iter_mut().zip(rb.iter()) { + 
*ai += *bi; + } + (va, ra) + } + ); + + ( + NormRoundTerms::Full(virt_coeffs), + reduce_compact_rel(rel_accum), + ) + } + } + + #[tracing::instrument( + skip_all, + name = "HachiStage2Prover::compute_round_compact_prefix_x_terms" + )] + fn compute_round_compact_prefix_x_terms( + &self, + w_compact: &[i8], + ) -> (NormRoundTerms, [E; 3]) { + debug_assert!(self.rounds_completed < self.num_u); + debug_assert_eq!(w_compact.len(), self.live_x_cols * self.alpha_compact.len()); + + let (e_first, e_second) = self.split_eq.remaining_eq_tables(); + let num_first = e_first.len(); + let first_bits = num_first.trailing_zeros() as usize; + let current_x_half = 1usize << (self.current_x_width() - 1); + let live_pairs = self.live_x_cols.div_ceil(2); + let block_size = num_first.min(live_pairs); + let alpha_compact = &self.alpha_compact; + let m_compact = &self.m_compact; + debug_assert_eq!(m_compact.len(), self.current_x_len()); + + if self.can_skip_norm_linear_coeff() { + let (virt_coeffs, rel_accum) = cfg_fold_reduce!( + 0..alpha_compact.len(), + || ([E::zero(); 2], [E::MulU64Accum::ZERO; 6]), + |(mut virt, mut rel), y| { + let row_start = y * self.live_x_cols; + let row = &w_compact[row_start..row_start + self.live_x_cols]; + let alpha = alpha_compact[y]; + let j_base = y * current_x_half; + + let mut blk = 0usize; + while blk < live_pairs { + let blk_end = (blk + block_size).min(live_pairs); + let j_high = (j_base + blk) >> first_bits; + let mut inner_virt = [E::MulU64Accum::ZERO; 2]; + + for pair_x in blk..blk_end { + let j_low = (j_base + pair_x) & (num_first - 1); + let e_in = e_first[j_low]; + let left = 2 * pair_x; + let w0 = row[left] as i32; + let w1 = if left + 1 < self.live_x_cols { + row[left + 1] as i32 + } else { + 0 + }; + let dw = w1 - w0; + let w0_i64 = w0 as i64; + let dw_i64 = dw as i64; + + let q0 = w0_i64 * (w0_i64 + 1); + if q0 != 0 { + inner_virt[0] += e_in.mul_u64_unreduced(q0 as u64); + } + let q2 = dw_i64 * dw_i64; + if q2 != 0 { + inner_virt[1] += 
e_in.mul_u64_unreduced(q2 as u64); + } + + let m0 = m_compact[left]; + let m1 = m_compact[left + 1]; + let p0 = alpha * m0; + let p1 = alpha * m1; + let dp = p1 - p0; + accum_small_signed::(&mut rel, 0, p0, w0_i64); + accum_small_signed::(&mut rel, 2, dp, w0_i64); + accum_small_signed::(&mut rel, 2, p0, dw_i64); + accum_small_signed::(&mut rel, 4, dp, dw_i64); + } + + let reduced_inner: [E; 2] = reduce_compact_virt_skip_linear(inner_virt); + let e_out = e_second[j_high]; + virt[0] += e_out * reduced_inner[0]; + virt[1] += e_out * reduced_inner[1]; + + blk = blk_end; + } + (virt, rel) + }, + |(mut va, mut ra), (vb, rb)| { + for (ai, bi) in va.iter_mut().zip(vb.iter()) { + *ai += *bi; + } + for (ai, bi) in ra.iter_mut().zip(rb.iter()) { + *ai += *bi; + } + (va, ra) + } + ); + + ( + NormRoundTerms::SkipLinear(virt_coeffs), + reduce_compact_rel(rel_accum), + ) + } else { + let (virt_coeffs, rel_accum) = cfg_fold_reduce!( + 0..alpha_compact.len(), + || ([E::zero(); 3], [E::MulU64Accum::ZERO; 6]), + |(mut virt, mut rel), y| { + let row_start = y * self.live_x_cols; + let row = &w_compact[row_start..row_start + self.live_x_cols]; + let alpha = alpha_compact[y]; + let j_base = y * current_x_half; + + let mut blk = 0usize; + while blk < live_pairs { + let blk_end = (blk + block_size).min(live_pairs); + let j_high = (j_base + blk) >> first_bits; + let mut inner_virt = [E::MulU64Accum::ZERO; 4]; + + for pair_x in blk..blk_end { + let j_low = (j_base + pair_x) & (num_first - 1); + let e_in = e_first[j_low]; + let left = 2 * pair_x; + let w0 = row[left] as i32; + let w1 = if left + 1 < self.live_x_cols { + row[left + 1] as i32 + } else { + 0 + }; + let dw = w1 - w0; + let w0_i64 = w0 as i64; + let dw_i64 = dw as i64; + + let q0 = w0_i64 * (w0_i64 + 1); + if q0 != 0 { + inner_virt[0] += e_in.mul_u64_unreduced(q0 as u64); + } + let q1 = dw_i64 * (2 * w0_i64 + 1); + accum_small_signed::(&mut inner_virt, 1, e_in, q1); + let q2 = dw_i64 * dw_i64; + if q2 != 0 { + inner_virt[3] += 
e_in.mul_u64_unreduced(q2 as u64); + } + + let m0 = m_compact[left]; + let m1 = m_compact[left + 1]; + let p0 = alpha * m0; + let p1 = alpha * m1; + let dp = p1 - p0; + accum_small_signed::(&mut rel, 0, p0, w0_i64); + accum_small_signed::(&mut rel, 2, dp, w0_i64); + accum_small_signed::(&mut rel, 2, p0, dw_i64); + accum_small_signed::(&mut rel, 4, dp, dw_i64); + } + + let reduced_inner: [E; 3] = reduce_compact_virt(inner_virt); + let e_out = e_second[j_high]; + virt[0] += e_out * reduced_inner[0]; + virt[1] += e_out * reduced_inner[1]; + virt[2] += e_out * reduced_inner[2]; + + blk = blk_end; + } + (virt, rel) + }, + |(mut va, mut ra), (vb, rb)| { + for (ai, bi) in va.iter_mut().zip(vb.iter()) { + *ai += *bi; + } + for (ai, bi) in ra.iter_mut().zip(rb.iter()) { + *ai += *bi; + } + (va, ra) + } + ); + + ( + NormRoundTerms::Full(virt_coeffs), + reduce_compact_rel(rel_accum), + ) + } + } + + #[tracing::instrument( + skip_all, + name = "HachiStage2Prover::compute_round_full_prefix_x_terms" + )] + fn compute_round_full_prefix_x_terms(&self, w_full: &[E]) -> (NormRoundTerms, [E; 3]) { + debug_assert!(self.rounds_completed < self.num_u); + debug_assert_eq!(w_full.len(), self.live_x_cols * self.alpha_compact.len()); + + let (e_first, e_second) = self.split_eq.remaining_eq_tables(); + let num_first = e_first.len(); + let first_bits = num_first.trailing_zeros() as usize; + let current_x_half = 1usize << (self.current_x_width() - 1); + let live_pairs = self.live_x_cols.div_ceil(2); + let block_size = num_first.min(live_pairs); + let alpha_compact = &self.alpha_compact; + let m_compact = &self.m_compact; + debug_assert_eq!(m_compact.len(), self.current_x_len()); + + if self.can_skip_norm_linear_coeff() { + let (virt_coeffs, rel_coeffs) = cfg_fold_reduce!( + 0..alpha_compact.len(), + || ([E::zero(); 2], [E::zero(); 3]), + |(mut virt, mut rel), y| { + let row_start = y * self.live_x_cols; + let row = &w_full[row_start..row_start + self.live_x_cols]; + let alpha = 
alpha_compact[y]; + let j_base = y * current_x_half; + + let mut blk = 0usize; + while blk < live_pairs { + let blk_end = (blk + block_size).min(live_pairs); + let j_high = (j_base + blk) >> first_bits; + let mut inner_virt = [E::zero(); 2]; + + for pair_x in blk..blk_end { + let j_low = (j_base + pair_x) & (num_first - 1); + let e_in = e_first[j_low]; + let left = 2 * pair_x; + let w0 = row[left]; + let w1 = if left + 1 < self.live_x_cols { + row[left + 1] + } else { + E::zero() + }; + let dw = w1 - w0; + + inner_virt[0] += e_in * (w0 * (w0 + E::one())); + inner_virt[1] += e_in * (dw * dw); + + let m0 = m_compact[left]; + let m1 = m_compact[left + 1]; + let p0 = alpha * m0; + let p1 = alpha * m1; + let dp = p1 - p0; + rel[0] += w0 * p0; + rel[1] += w0 * dp + dw * p0; + rel[2] += dw * dp; + } + + let e_out = e_second[j_high]; + virt[0] += e_out * inner_virt[0]; + virt[1] += e_out * inner_virt[1]; + + blk = blk_end; + } + (virt, rel) + }, + |(mut va, mut ra), (vb, rb)| { + for (ai, bi) in va.iter_mut().zip(vb.iter()) { + *ai += *bi; + } + for (ai, bi) in ra.iter_mut().zip(rb.iter()) { + *ai += *bi; + } + (va, ra) + } + ); + (NormRoundTerms::SkipLinear(virt_coeffs), rel_coeffs) + } else { + let (virt_coeffs, rel_coeffs) = cfg_fold_reduce!( + 0..alpha_compact.len(), + || ([E::zero(); 3], [E::zero(); 3]), + |(mut virt, mut rel), y| { + let row_start = y * self.live_x_cols; + let row = &w_full[row_start..row_start + self.live_x_cols]; + let alpha = alpha_compact[y]; + let j_base = y * current_x_half; + + let mut blk = 0usize; + while blk < live_pairs { + let blk_end = (blk + block_size).min(live_pairs); + let j_high = (j_base + blk) >> first_bits; + let mut inner_virt = [E::zero(); 3]; + + for pair_x in blk..blk_end { + let j_low = (j_base + pair_x) & (num_first - 1); + let e_in = e_first[j_low]; + let left = 2 * pair_x; + let w0 = row[left]; + let w1 = if left + 1 < self.live_x_cols { + row[left + 1] + } else { + E::zero() + }; + let dw = w1 - w0; + let two_w0_plus_one 
= w0 + w0 + E::one(); + + inner_virt[0] += e_in * (w0 * (w0 + E::one())); + inner_virt[1] += e_in * (dw * two_w0_plus_one); + inner_virt[2] += e_in * (dw * dw); + + let m0 = m_compact[left]; + let m1 = m_compact[left + 1]; + let p0 = alpha * m0; + let p1 = alpha * m1; + let dp = p1 - p0; + rel[0] += w0 * p0; + rel[1] += w0 * dp + dw * p0; + rel[2] += dw * dp; + } + + let e_out = e_second[j_high]; + virt[0] += e_out * inner_virt[0]; + virt[1] += e_out * inner_virt[1]; + virt[2] += e_out * inner_virt[2]; + + blk = blk_end; + } + (virt, rel) + }, + |(mut va, mut ra), (vb, rb)| { + for (ai, bi) in va.iter_mut().zip(vb.iter()) { + *ai += *bi; + } + for (ai, bi) in ra.iter_mut().zip(rb.iter()) { + *ai += *bi; + } + (va, ra) + } + ); + (NormRoundTerms::Full(virt_coeffs), rel_coeffs) + } + } + + #[tracing::instrument(skip_all, name = "HachiStage2Prover::compute_round_full_dense_terms")] + fn compute_round_full_dense_terms(&self, w_full: &[E]) -> (NormRoundTerms, [E; 3]) { + let (e_first, e_second) = self.split_eq.remaining_eq_tables(); + let num_first = e_first.len(); + let num_second = e_second.len(); + let current_x_width = self.current_x_width(); + let current_x_mask = (1usize << current_x_width).wrapping_sub(1); + let alpha_compact = &self.alpha_compact; + let m_compact = &self.m_compact; + debug_assert_eq!(w_full.len() / 2, num_first * num_second); + + if self.can_skip_norm_linear_coeff() { + let (virt_coeffs, rel_coeffs) = cfg_fold_reduce!( + 0..num_second, + || ([E::zero(); 2], [E::zero(); 3]), + |(mut virt, mut rel), j_high| { + let mut inner_virt = [E::zero(); 2]; + let base = j_high * num_first; + + for (j_low, &e_in) in e_first.iter().enumerate() { + let j = base + j_low; + let w0 = w_full[2 * j]; + let w1 = w_full[2 * j + 1]; + let dw = w1 - w0; + + inner_virt[0] += e_in * (w0 * (w0 + E::one())); + inner_virt[1] += e_in * (dw * dw); + + let a0 = alpha_compact[(2 * j) >> current_x_width]; + let a1 = alpha_compact[(2 * j + 1) >> current_x_width]; + let m0 = 
m_compact[(2 * j) & current_x_mask]; + let m1 = m_compact[(2 * j + 1) & current_x_mask]; + let p0 = a0 * m0; + let p1 = a1 * m1; + let dp = p1 - p0; + rel[0] += w0 * p0; + rel[1] += w0 * dp + dw * p0; + rel[2] += dw * dp; + } + + let e_out = e_second[j_high]; + virt[0] += e_out * inner_virt[0]; + virt[1] += e_out * inner_virt[1]; + + (virt, rel) + }, + |(mut va, mut ra), (vb, rb)| { + for (ai, bi) in va.iter_mut().zip(vb.iter()) { + *ai += *bi; + } + for (ai, bi) in ra.iter_mut().zip(rb.iter()) { + *ai += *bi; + } + (va, ra) + } + ); + (NormRoundTerms::SkipLinear(virt_coeffs), rel_coeffs) + } else { + let (virt_coeffs, rel_coeffs) = cfg_fold_reduce!( + 0..num_second, + || ([E::zero(); 3], [E::zero(); 3]), + |(mut virt, mut rel), j_high| { + let mut inner_virt = [E::zero(); 3]; + let base = j_high * num_first; + + for (j_low, &e_in) in e_first.iter().enumerate() { + let j = base + j_low; + let w0 = w_full[2 * j]; + let w1 = w_full[2 * j + 1]; + let dw = w1 - w0; + let two_w0_plus_one = w0 + w0 + E::one(); + + inner_virt[0] += e_in * (w0 * (w0 + E::one())); + inner_virt[1] += e_in * (dw * two_w0_plus_one); + inner_virt[2] += e_in * (dw * dw); + + let a0 = alpha_compact[(2 * j) >> current_x_width]; + let a1 = alpha_compact[(2 * j + 1) >> current_x_width]; + let m0 = m_compact[(2 * j) & current_x_mask]; + let m1 = m_compact[(2 * j + 1) & current_x_mask]; + let p0 = a0 * m0; + let p1 = a1 * m1; + let dp = p1 - p0; + rel[0] += w0 * p0; + rel[1] += w0 * dp + dw * p0; + rel[2] += dw * dp; + } + + let e_out = e_second[j_high]; + virt[0] += e_out * inner_virt[0]; + virt[1] += e_out * inner_virt[1]; + virt[2] += e_out * inner_virt[2]; + + (virt, rel) + }, + |(mut va, mut ra), (vb, rb)| { + for (ai, bi) in va.iter_mut().zip(vb.iter()) { + *ai += *bi; + } + for (ai, bi) in ra.iter_mut().zip(rb.iter()) { + *ai += *bi; + } + (va, ra) + } + ); + (NormRoundTerms::Full(virt_coeffs), rel_coeffs) + } + } + + fn compute_round_compact_prefix_x_polys(&self, w_compact: &[i8]) -> (UniPoly, 
UniPoly) { + let (virt_q_coeffs, rel_coeffs) = self.compute_round_compact_prefix_x_terms(w_compact); + self.polys_from_terms(virt_q_coeffs, rel_coeffs) + } + + #[cfg(test)] + fn compute_round_compact_dense_polys(&self, w_compact: &[i8]) -> (UniPoly, UniPoly) { + let (virt_q_coeffs, rel_coeffs) = self.compute_round_compact_dense_terms(w_compact); + self.polys_from_terms(virt_q_coeffs, rel_coeffs) + } + + #[inline] + fn build_compact_w_fold_lut(w_compact: &[i8], r: E) -> CompactPairFoldLut { + let min_w = w_compact + .iter() + .copied() + .map(i32::from) + .min() + .unwrap_or(0) + .min(0); + let max_w = w_compact + .iter() + .copied() + .map(i32::from) + .max() + .unwrap_or(0) + .max(0); + CompactPairFoldLut::from_contiguous_range(min_w, max_w, r) + } + + fn fold_compact_prefix_x( + w_compact: &[i8], + live_x_cols: usize, + y_len: usize, + fold_lut: &CompactPairFoldLut, + ) -> Vec { + let next_live_x_cols = live_x_cols.div_ceil(2); + let mut out = vec![E::zero(); y_len * next_live_x_cols]; + + #[cfg(feature = "parallel")] + out.par_chunks_mut(next_live_x_cols) + .enumerate() + .for_each(|(y, row_out)| { + let row_start = y * live_x_cols; + let row = &w_compact[row_start..row_start + live_x_cols]; + for (pair_x, dst) in row_out.iter_mut().enumerate() { + let left = 2 * pair_x; + let w_1 = if left + 1 < live_x_cols { + row[left + 1] as i32 + } else { + 0 + }; + *dst = fold_lut.fold(row[left] as i32, w_1); + } + }); + + #[cfg(not(feature = "parallel"))] + for (y, row_out) in out.chunks_mut(next_live_x_cols).enumerate() { + let row_start = y * live_x_cols; + let row = &w_compact[row_start..row_start + live_x_cols]; + for (pair_x, dst) in row_out.iter_mut().enumerate() { + let left = 2 * pair_x; + let w_1 = if left + 1 < live_x_cols { + row[left + 1] as i32 + } else { + 0 + }; + *dst = fold_lut.fold(row[left] as i32, w_1); + } + } + + out + } + + fn fold_full_prefix_x(w_full: &[E], live_x_cols: usize, y_len: usize, r: E) -> Vec { + let next_live_x_cols = 
live_x_cols.div_ceil(2); + let mut out = vec![E::zero(); y_len * next_live_x_cols]; + + #[cfg(feature = "parallel")] + out.par_chunks_mut(next_live_x_cols) + .enumerate() + .for_each(|(y, row_out)| { + let row_start = y * live_x_cols; + let row = &w_full[row_start..row_start + live_x_cols]; + for (pair_x, dst) in row_out.iter_mut().enumerate() { + let left = 2 * pair_x; + let w_0 = row[left]; + let w_1 = if left + 1 < live_x_cols { + row[left + 1] + } else { + E::zero() + }; + *dst = w_0 + r * (w_1 - w_0); + } + }); + + #[cfg(not(feature = "parallel"))] + for (y, row_out) in out.chunks_mut(next_live_x_cols).enumerate() { + let row_start = y * live_x_cols; + let row = &w_full[row_start..row_start + live_x_cols]; + for (pair_x, dst) in row_out.iter_mut().enumerate() { + let left = 2 * pair_x; + let w_0 = row[left]; + let w_1 = if left + 1 < live_x_cols { + row[left + 1] + } else { + E::zero() + }; + *dst = w_0 + r * (w_1 - w_0); + } + } + + out + } + + fn fold_m_prefix(m_compact: &[E], r: E) -> Vec { + debug_assert!(m_compact.len().is_power_of_two()); + debug_assert!(m_compact.len() >= 2); + let next_x_len = m_compact.len() >> 1; + cfg_into_iter!(0..next_x_len) + .map(|pair_x| { + let left = 2 * pair_x; + let m_0 = m_compact[left]; + let m_1 = m_compact[left + 1]; + m_0 + r * (m_1 - m_0) + }) + .collect() + } + + fn fold_compact_to_full(w_compact: &[i8], fold_lut: &CompactPairFoldLut) -> Vec { + cfg_into_iter!(0..w_compact.len() / 2) + .map(|j| fold_lut.fold(w_compact[2 * j] as i32, w_compact[2 * j + 1] as i32)) + .collect() + } +} + +impl SumcheckInstanceProver + for HachiStage2Prover +{ + fn num_rounds(&self) -> usize { + self.num_vars + } + + fn degree_bound(&self) -> usize { + 3 + } + + fn input_claim(&self) -> E { + self.batching_coeff * self.s_claim + self.relation_claim + } + + fn compute_round_univariate(&mut self, _round: usize, _previous_claim: E) -> UniPoly { + if let Some(poly) = self.cached_round_poly.take() { + poly + } else { + 
self.compute_current_round_poly_from_state() + } + } + + fn ingest_challenge(&mut self, _round: usize, r: E) { + let t_fold = Instant::now(); + let _span = tracing::info_span!("HachiStage2Prover::fold_round").entered(); + if let Some(prev_norm_poly) = self.prev_norm_poly.take() { + self.prev_norm_claim = prev_norm_poly.evaluate(&r); + } + + if self.using_two_round_prefix() { + let rounds_completed = self.rounds_completed; + self.split_eq.bind(r); + if rounds_completed == 0 { + self.ensure_two_round_prefix().first_challenge = Some(r); + } else { + let r0 = { + let prefix = self.ensure_two_round_prefix(); + prefix + .first_challenge + .expect("round 1 ingest requires the round 0 challenge") + }; + let y_len = self.alpha_compact.len(); + self.w_table = match mem::replace(&mut self.w_table, WTable::Full(Vec::new())) { + WTable::Compact(w_compact) => { + if self.num_u > 2 { + let (w_full, m_round2, virt_terms, rel_coeffs) = + self.fuse_compact_to_round2_and_compute_round(&w_compact, r0, r); + self.m_compact = m_round2; + self.cached_round_poly = + Some(self.combine_terms(virt_terms, rel_coeffs)); + WTable::Full(w_full) + } else { + self.m_compact = Self::fold_m_to_round2(&self.m_compact, r0, r); + WTable::Full(Self::fold_compact_to_round2( + &w_compact, + self.live_x_cols, + y_len, + r0, + r, + )) + } + } + WTable::Full(_) => unreachable!("two-round prefix should hold compact witness"), + }; + self.live_x_cols = self.live_x_cols.div_ceil(4); + } + self.rounds_completed += 1; + if self.rounds_completed < self.num_vars { + if self.cached_round_poly.is_none() { + self.cached_round_poly = Some(self.compute_current_round_poly_from_state()); + } + } else { + self.cached_round_poly = None; + } + drop(_span); + self.fold_time_total += t_fold.elapsed().as_secs_f64(); + if self.rounds_completed == self.num_vars { + tracing::debug!( + rounds = self.num_vars, + scan_s = self.scan_time_total, + fold_s = self.fold_time_total, + "stage2 sumcheck rounds complete" + ); + } + return; + } 
+ + self.split_eq.bind(r); + let folding_x_round = self.rounds_completed < self.num_u; + let use_prefix_x_round = self.use_prefix_x_round(); + let fuse_next_full_prefix_x = + use_prefix_x_round && self.next_use_prefix_x_round_after_current(); + let y_len = self.alpha_compact.len(); + let mut fused_full_prefix_x = false; + + self.w_table = match mem::replace(&mut self.w_table, WTable::Full(Vec::new())) { + WTable::Compact(w_compact) => { + let fold_lut = Self::build_compact_w_fold_lut(&w_compact, r); + let w_full = if use_prefix_x_round { + Self::fold_compact_prefix_x(&w_compact, self.live_x_cols, y_len, &fold_lut) + } else { + Self::fold_compact_to_full(&w_compact, &fold_lut) + }; + WTable::Full(w_full) + } + WTable::Full(w_full) => { + if use_prefix_x_round { + if fuse_next_full_prefix_x { + let (next_w_full, next_m_compact, virt_terms, rel_coeffs) = + self.fuse_full_prefix_x_and_compute_round(&w_full, r); + self.m_compact = next_m_compact; + self.cached_round_poly = Some(self.combine_terms(virt_terms, rel_coeffs)); + fused_full_prefix_x = true; + WTable::Full(next_w_full) + } else { + let next_w_full = + Self::fold_full_prefix_x(&w_full, self.live_x_cols, y_len, r); + WTable::Full(next_w_full) + } + } else { + let mut w_full = w_full; + fold_evals_in_place(&mut w_full, r); + WTable::Full(w_full) + } + } + }; + + if folding_x_round { + if use_prefix_x_round { + if !fused_full_prefix_x { + self.m_compact = Self::fold_m_prefix(&self.m_compact, r); + } + } else { + fold_evals_in_place(&mut self.m_compact, r); + } + self.live_x_cols = self.live_x_cols.div_ceil(2); + } else { + fold_evals_in_place(&mut self.alpha_compact, r); + } + + self.rounds_completed += 1; + if self.rounds_completed < self.num_vars { + if self.cached_round_poly.is_none() { + self.cached_round_poly = Some(self.compute_current_round_poly_from_state()); + } + } else { + self.cached_round_poly = None; + } + drop(_span); + self.fold_time_total += t_fold.elapsed().as_secs_f64(); + + if 
self.rounds_completed == self.num_vars { + tracing::debug!( + rounds = self.num_vars, + scan_s = self.scan_time_total, + fold_s = self.fold_time_total, + "stage2 sumcheck rounds complete" + ); + } + } +} + +/// Source of the witness oracle used by the stage-2 verifier. +enum Stage2WitnessOracle { + Full(Vec), + ClaimedEval(F), +} + +/// Verifier for the stage-2 fused virtual-claim + relation sumcheck. +pub struct HachiStage2Verifier { + batching_coeff: F, + s_claim: F, + witness_oracle: Stage2WitnessOracle, + r_stage1: Vec, + alpha_evals_y: Vec, + m_evals_x: Vec, + num_u: usize, + num_l: usize, + relation_claim: F, + _marker: PhantomData<[F; D]>, +} + +impl HachiStage2Verifier { + #[allow(clippy::too_many_arguments)] + fn new( + batching_coeff: F, + s_claim: F, + witness_oracle: Stage2WitnessOracle, + r_stage1: Vec, + alpha_evals_y: Vec, + m_evals_x: Vec, + tau1: Vec, + v: Vec>, + u: Vec>, + y_ring: CyclotomicRing, + alpha: F, + num_u: usize, + num_l: usize, + ) -> Self { + let relation_claim = relation_claim_from_rows::(&tau1, alpha, &v, &u, &y_ring); + Self { + batching_coeff, + s_claim, + witness_oracle, + r_stage1, + alpha_evals_y, + m_evals_x, + num_u, + num_l, + relation_claim, + _marker: PhantomData, + } + } + + /// Create a fused verifier for the stage-2 sumcheck when the verifier has the full witness. 
+ #[allow(clippy::too_many_arguments)] + #[tracing::instrument(skip_all, name = "HachiStage2Verifier::new_with_full_witness")] + pub fn new_with_full_witness( + batching_coeff: F, + s_claim: F, + w_evals: Vec, + r_stage1: Vec, + alpha_evals_y: Vec, + m_evals_x: Vec, + tau1: Vec, + v: Vec>, + u: Vec>, + y_ring: CyclotomicRing, + alpha: F, + num_u: usize, + num_l: usize, + ) -> Self { + Self::new( + batching_coeff, + s_claim, + Stage2WitnessOracle::Full(w_evals), + r_stage1, + alpha_evals_y, + m_evals_x, + tau1, + v, + u, + y_ring, + alpha, + num_u, + num_l, + ) + } + + /// Create a fused verifier for the stage-2 sumcheck when only the final witness evaluation is available. + #[allow(clippy::too_many_arguments)] + #[tracing::instrument(skip_all, name = "HachiStage2Verifier::new_with_claimed_w_eval")] + pub fn new_with_claimed_w_eval( + batching_coeff: F, + s_claim: F, + w_eval: F, + r_stage1: Vec, + alpha_evals_y: Vec, + m_evals_x: Vec, + tau1: Vec, + v: Vec>, + u: Vec>, + y_ring: CyclotomicRing, + alpha: F, + num_u: usize, + num_l: usize, + ) -> Self { + Self::new( + batching_coeff, + s_claim, + Stage2WitnessOracle::ClaimedEval(w_eval), + r_stage1, + alpha_evals_y, + m_evals_x, + tau1, + v, + u, + y_ring, + alpha, + num_u, + num_l, + ) + } +} + +impl SumcheckInstanceVerifier + for HachiStage2Verifier +{ + fn num_rounds(&self) -> usize { + self.num_u + self.num_l + } + + fn degree_bound(&self) -> usize { + 3 + } + + fn input_claim(&self) -> F { + self.batching_coeff * self.s_claim + self.relation_claim + } + + fn expected_output_claim(&self, challenges: &[F]) -> Result { + let eq_val = EqPolynomial::mle(&self.r_stage1, challenges); + let w_eval = match &self.witness_oracle { + Stage2WitnessOracle::Full(w_evals) => multilinear_eval(w_evals, challenges)?, + Stage2WitnessOracle::ClaimedEval(w_eval) => *w_eval, + }; + let virtual_oracle = eq_val * w_eval * (w_eval + F::one()); + + let (x_challenges, y_challenges) = challenges.split_at(self.num_u); + let alpha_val = 
multilinear_eval(&self.alpha_evals_y, y_challenges)?; + let m_val = multilinear_eval(&self.m_evals_x, x_challenges)?; + let relation_oracle = w_eval * alpha_val * m_val; + + Ok(self.batching_coeff * virtual_oracle + relation_oracle) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::algebra::Prime128M8M4M1M0; + use crate::protocol::sumcheck::multilinear_eval; + + type F = Prime128M8M4M1M0; + + #[derive(Clone, Copy)] + struct Stage2Params<'a> { + r_stage1: &'a [F], + live_x_cols: usize, + num_u: usize, + num_l: usize, + } + + fn s_claim_from_compact_rows(w_compact: &[i8], params: &Stage2Params<'_>) -> F { + let padded = if params.live_x_cols == (1usize << params.num_u) { + w_compact.to_vec() + } else { + pad_compact_rows(w_compact, params.live_x_cols, params.num_u, params.num_l) + }; + let s_evals: Vec = padded + .iter() + .map(|&w| { + let w = F::from_i64(w as i64); + w * (w + F::one()) + }) + .collect(); + multilinear_eval(&s_evals, params.r_stage1).expect("valid stage-2 witness shape") + } + + fn new_stage2_test_prover( + batching_coeff: F, + w_compact: Vec, + alpha_evals_y: Vec, + m_evals_x: Vec, + params: Stage2Params<'_>, + ) -> HachiStage2Prover { + let s_claim = s_claim_from_compact_rows(&w_compact, ¶ms); + HachiStage2Prover::new( + batching_coeff, + w_compact, + params.r_stage1, + s_claim, + alpha_evals_y, + m_evals_x, + params.live_x_cols, + params.num_u, + params.num_l, + F::zero(), + ) + } + + fn relation_round_reference( + w_compact: &[i8], + alpha_compact: &[F], + m_compact: &[F], + num_u: usize, + ) -> UniPoly { + let half = w_compact.len() / 2; + let current_x_mask = (1usize << num_u).wrapping_sub(1); + let mut evals = [F::zero(); 3]; + for j in 0..half { + let w_0 = F::from_i64(w_compact[2 * j] as i64); + let w_1 = F::from_i64(w_compact[2 * j + 1] as i64); + let a_0 = alpha_compact[(2 * j) >> num_u]; + let a_1 = alpha_compact[(2 * j + 1) >> num_u]; + let m_0 = m_compact[(2 * j) & current_x_mask]; + let m_1 = m_compact[(2 * j + 1) & 
current_x_mask]; + evals[0] += w_0 * a_0 * m_0; + evals[1] += w_1 * a_1 * m_1; + let w_2 = w_1 + w_1 - w_0; + let a_2 = a_1 + a_1 - a_0; + let m_2 = m_1 + m_1 - m_0; + evals[2] += w_2 * a_2 * m_2; + } + UniPoly::from_evals(&evals) + } + + fn virtual_round_reference(split_eq: &GruenSplitEq, w_compact: &[i8]) -> UniPoly { + let half = w_compact.len() / 2; + let (e_first, e_second) = split_eq.remaining_eq_tables(); + let num_first = e_first.len(); + let first_bits = num_first.trailing_zeros(); + let mut evals = [F::zero(); 3]; + for j in 0..half { + let j_low = j & (num_first - 1); + let j_high = j >> first_bits; + let eq_rem = e_first[j_low] * e_second[j_high]; + let w_0 = F::from_i64(w_compact[2 * j] as i64); + let w_1 = F::from_i64(w_compact[2 * j + 1] as i64); + let w_2 = w_1 + w_1 - w_0; + evals[0] += eq_rem * w_0 * (w_0 + F::one()); + evals[1] += eq_rem * w_1 * (w_1 + F::one()); + evals[2] += eq_rem * w_2 * (w_2 + F::one()); + } + split_eq.gruen_mul(&UniPoly::from_evals(&evals)) + } + + fn pad_compact_rows( + w_prefix: &[i8], + live_x_cols: usize, + num_u: usize, + num_l: usize, + ) -> Vec { + let x_len = 1usize << num_u; + let y_len = 1usize << num_l; + let mut padded = vec![0i8; x_len * y_len]; + for y in 0..y_len { + let src_start = y * live_x_cols; + let dst_start = y * x_len; + padded[dst_start..dst_start + live_x_cols] + .copy_from_slice(&w_prefix[src_start..src_start + live_x_cols]); + } + padded + } + + fn fold_compact_prefix_x_reference( + w_compact: &[i8], + live_x_cols: usize, + y_len: usize, + r: F, + ) -> Vec { + let next_live_x_cols = live_x_cols.div_ceil(2); + let mut out = vec![F::zero(); y_len * next_live_x_cols]; + for (y, row_out) in out.chunks_mut(next_live_x_cols).enumerate() { + let row_start = y * live_x_cols; + let row = &w_compact[row_start..row_start + live_x_cols]; + for (pair_x, dst) in row_out.iter_mut().enumerate() { + let left = 2 * pair_x; + let w_0 = F::from_i64(row[left] as i64); + let w_1 = if left + 1 < live_x_cols { + 
F::from_i64(row[left + 1] as i64) + } else { + F::zero() + }; + *dst = w_0 + r * (w_1 - w_0); + } + } + out + } + + fn fold_compact_to_full_reference(w_compact: &[i8], r: F) -> Vec { + (0..w_compact.len() / 2) + .map(|j| { + let w_0 = F::from_i64(w_compact[2 * j] as i64); + let w_1 = F::from_i64(w_compact[2 * j + 1] as i64); + w_0 + r * (w_1 - w_0) + }) + .collect() + } + + #[test] + fn stage2_compact_fold_lookup_matches_direct_formula() { + let r = F::from_u64(53); + + let w_prefix = vec![1, 2, 3, 1, 2, 3, 1, 2, 3, 1]; + let fold_lut = HachiStage2Prover::::build_compact_w_fold_lut(&w_prefix, r); + assert_eq!( + HachiStage2Prover::::fold_compact_prefix_x(&w_prefix, 5, 2, &fold_lut), + fold_compact_prefix_x_reference(&w_prefix, 5, 2, r) + ); + + let w_dense = vec![1, 2, 3, 1, 2, 3]; + let dense_lut = HachiStage2Prover::::build_compact_w_fold_lut(&w_dense, r); + assert_eq!( + HachiStage2Prover::::fold_compact_to_full(&w_dense, &dense_lut), + fold_compact_to_full_reference(&w_dense, r) + ); + } + + #[test] + fn stage2_compact_round0_matches_unfused_reference() { + let num_u = 3usize; + let num_l = 2usize; + let b = 8usize; + let n = 1usize << (num_u + num_l); + let half = (b / 2) as i8; + let w_compact: Vec = (0..n).map(|i| ((i * 5 + 3) % b) as i8 - half).collect(); + let r_stage1: Vec = (0..(num_u + num_l)) + .map(|i| F::from_u64((i as u64) + 2)) + .collect(); + let alpha_evals_y: Vec = (0..(1usize << num_l)) + .map(|i| F::from_u64((3 * i as u64) + 5)) + .collect(); + let m_evals_x: Vec = (0..(1usize << num_u)) + .map(|i| F::from_u64((7 * i as u64) + 11)) + .collect(); + + let prover = new_stage2_test_prover( + F::from_u64(13), + w_compact.clone(), + alpha_evals_y.clone(), + m_evals_x.clone(), + Stage2Params { + r_stage1: &r_stage1, + live_x_cols: 1usize << num_u, + num_u, + num_l, + }, + ); + let (virt_poly, relation_poly) = prover.compute_round_compact_dense_polys(&w_compact); + let virt_ref = virtual_round_reference(&prover.split_eq, &w_compact); + let 
relation_ref = relation_round_reference(&w_compact, &alpha_evals_y, &m_evals_x, num_u); + + assert_eq!(virt_poly, virt_ref, "compact virtual round mismatch"); + assert_eq!( + relation_poly, relation_ref, + "compact relation round mismatch" + ); + } + + #[test] + fn stage2_prefix_aware_rounds_match_explicit_full_m_table() { + let num_l = 2usize; + let b = 8usize; + let half = (b / 2) as i8; + + for live_x_cols in [5usize, 6usize] { + let num_u = live_x_cols.next_power_of_two().trailing_zeros() as usize; + let x_len = 1usize << num_u; + let y_len = 1usize << num_l; + let w_prefix: Vec = (0..(live_x_cols * y_len)) + .map(|i| ((i * 7 + 5) % b) as i8 - half) + .collect(); + let w_padded = pad_compact_rows(&w_prefix, live_x_cols, num_u, num_l); + let r_stage1: Vec = (0..(num_u + num_l)) + .map(|i| F::from_u64((i as u64) + 31)) + .collect(); + let alpha_evals_y: Vec = (0..y_len) + .map(|i| F::from_u64((5 * i as u64) + 7)) + .collect(); + let m_evals_x: Vec = (0..x_len) + .map(|i| F::from_u64((11 * i as u64) + 13)) + .collect(); + + let mut prefix_prover = new_stage2_test_prover( + F::from_u64(17), + w_prefix.clone(), + alpha_evals_y.clone(), + m_evals_x.clone(), + Stage2Params { + r_stage1: &r_stage1, + live_x_cols, + num_u, + num_l, + }, + ); + let mut padded_prover = new_stage2_test_prover( + F::from_u64(17), + w_padded.clone(), + alpha_evals_y.clone(), + m_evals_x.clone(), + Stage2Params { + r_stage1: &r_stage1, + live_x_cols: 1usize << num_u, + num_u, + num_l, + }, + ); + let mut prefix_claim = prefix_prover.input_claim(); + let mut padded_claim = padded_prover.input_claim(); + + for round in 0..(num_u + num_l) { + let prefix_poly = prefix_prover.compute_round_univariate(round, prefix_claim); + let padded_poly = padded_prover.compute_round_univariate(round, padded_claim); + assert_eq!( + prefix_poly, padded_poly, + "round {round} polynomial mismatch live_x_cols={live_x_cols}" + ); + + let challenge = F::from_u64((round as u64) + 37); + prefix_claim = 
prefix_poly.evaluate(&challenge); + padded_claim = padded_poly.evaluate(&challenge); + prefix_prover.ingest_challenge(round, challenge); + padded_prover.ingest_challenge(round, challenge); + } + + assert_eq!(prefix_prover.final_w_eval(), padded_prover.final_w_eval()); + assert_eq!(prefix_claim, padded_claim); + } + } + + #[test] + fn stage2_zero_gated_round0_matches_reference() { + let num_u = 3usize; + let num_l = 1usize; + let w_compact = vec![-1, 0, -1, 0, 0, -1, 0, -1, -1, 0, -1, 0, 0, -1, 0, -1]; + let r_stage1: Vec = (0..(num_u + num_l)) + .map(|i| F::from_u64((i as u64) + 41)) + .collect(); + let alpha_evals_y: Vec = (0..(1usize << num_l)) + .map(|i| F::from_u64((3 * i as u64) + 43)) + .collect(); + let m_evals_x: Vec = (0..(1usize << num_u)) + .map(|i| F::from_u64((5 * i as u64) + 47)) + .collect(); + + let prover = new_stage2_test_prover( + F::from_u64(19), + w_compact.clone(), + alpha_evals_y.clone(), + m_evals_x.clone(), + Stage2Params { + r_stage1: &r_stage1, + live_x_cols: 1usize << num_u, + num_u, + num_l, + }, + ); + let (virt_poly, relation_poly) = prover.compute_round_compact_dense_polys(&w_compact); + assert_eq!( + virt_poly, + virtual_round_reference(&prover.split_eq, &w_compact) + ); + assert_eq!( + relation_poly, + relation_round_reference(&w_compact, &alpha_evals_y, &m_evals_x, num_u) + ); + } + + #[test] + fn stage2_fused_round2_transition_matches_two_pass_reference() { + let num_u = 3usize; + let num_l = 2usize; + let live_x_cols = 6usize; + let b = 8usize; + let half = (b / 2) as i8; + let y_len = 1usize << num_l; + let w_prefix: Vec = (0..(live_x_cols * y_len)) + .map(|i| ((i * 11 + 7) % b) as i8 - half) + .collect(); + let r_stage1: Vec = (0..(num_u + num_l)) + .map(|i| F::from_u64((i as u64) + 71)) + .collect(); + let alpha_evals_y: Vec = (0..y_len) + .map(|i| F::from_u64((5 * i as u64) + 73)) + .collect(); + let m_evals_x: Vec = (0..(1usize << num_u)) + .map(|i| F::from_u64((13 * i as u64) + 79)) + .collect(); + let params = 
Stage2Params { + r_stage1: &r_stage1, + live_x_cols, + num_u, + num_l, + }; + + let mut prover = new_stage2_test_prover( + F::from_u64(83), + w_prefix.clone(), + alpha_evals_y.clone(), + m_evals_x.clone(), + params, + ); + let round0 = prover.compute_round_univariate(0, prover.input_claim()); + let r0 = F::from_u64(89); + prover.ingest_challenge(0, r0); + let round1 = prover.compute_round_univariate(1, round0.evaluate(&r0)); + let r1 = F::from_u64(97); + + let m_prefix = prover.m_compact.clone(); + let expected_w_full = + HachiStage2Prover::::fold_compact_to_round2(&w_prefix, live_x_cols, y_len, r0, r1); + let expected_m_round2 = HachiStage2Prover::::fold_m_to_round2(&m_prefix, r0, r1); + + let mut expected = new_stage2_test_prover( + F::from_u64(83), + w_prefix.clone(), + alpha_evals_y, + m_evals_x, + params, + ); + let expected_round0 = expected.compute_round_univariate(0, expected.input_claim()); + assert_eq!(expected_round0, round0); + expected.ingest_challenge(0, r0); + let expected_round1 = expected.compute_round_univariate(1, expected_round0.evaluate(&r0)); + assert_eq!(expected_round1, round1); + expected.prev_norm_claim = expected + .prev_norm_poly + .as_ref() + .expect("round1 norm poly should be cached") + .evaluate(&r1); + expected.split_eq.bind(r1); + expected.live_x_cols = live_x_cols.div_ceil(4); + expected.rounds_completed = 2; + expected.m_compact = expected_m_round2.clone(); + let (virt_terms, rel_coeffs) = expected.compute_round_full_prefix_x_terms(&expected_w_full); + let expected_round2 = expected.combine_terms(virt_terms, rel_coeffs); + + prover.ingest_challenge(1, r1); + + match &prover.w_table { + WTable::Full(w_full) => assert_eq!(w_full, &expected_w_full), + WTable::Compact(_) => { + panic!("expected fused stage2 transition to materialize full table") + } + } + assert_eq!(prover.m_compact, expected_m_round2); + assert_eq!(prover.cached_round_poly.as_ref(), Some(&expected_round2)); + } + + #[test] + fn 
stage2_later_full_prefix_fusion_matches_two_pass_reference() { + let num_u = 5usize; + let num_l = 2usize; + let live_x_cols = 12usize; + let b = 8usize; + let half = (b / 2) as i8; + let y_len = 1usize << num_l; + let w_prefix: Vec = (0..(live_x_cols * y_len)) + .map(|i| ((i * 9 + 7) % b) as i8 - half) + .collect(); + let r_stage1: Vec = (0..(num_u + num_l)) + .map(|i| F::from_u64((i as u64) + 131)) + .collect(); + let alpha_evals_y: Vec = (0..y_len) + .map(|i| F::from_u64((7 * i as u64) + 137)) + .collect(); + let m_evals_x: Vec = (0..(1usize << num_u)) + .map(|i| F::from_u64((11 * i as u64) + 139)) + .collect(); + let params = Stage2Params { + r_stage1: &r_stage1, + live_x_cols, + num_u, + num_l, + }; + + let mut prover = new_stage2_test_prover( + F::from_u64(149), + w_prefix.clone(), + alpha_evals_y.clone(), + m_evals_x.clone(), + params, + ); + let round0 = prover.compute_round_univariate(0, prover.input_claim()); + let r0 = F::from_u64(151); + prover.ingest_challenge(0, r0); + let round1 = prover.compute_round_univariate(1, round0.evaluate(&r0)); + let r1 = F::from_u64(157); + prover.ingest_challenge(1, r1); + // The claim entering round 2 is g_1(r_1), the previous round's univariate + // evaluated at the round-1 challenge (not r0) -- this keeps the sumcheck + // invariant g_2(0) + g_2(1) == g_1(r_1) intact. + let round2 = prover.compute_round_univariate(2, round1.evaluate(&r1)); + let r2 = F::from_u64(163); + + let mut expected = + new_stage2_test_prover(F::from_u64(149), w_prefix, alpha_evals_y, m_evals_x, params); + let expected_round0 = expected.compute_round_univariate(0, expected.input_claim()); + assert_eq!(expected_round0, round0); + expected.ingest_challenge(0, r0); + let expected_round1 = expected.compute_round_univariate(1, expected_round0.evaluate(&r0)); + assert_eq!(expected_round1, round1); + expected.ingest_challenge(1, r1); + // Same fix on the reference path: round 2's previous claim is g_1(r_1). + let expected_round2 = expected.compute_round_univariate(2, expected_round1.evaluate(&r1)); + assert_eq!(expected_round2, round2); + + let current_w_full = match &expected.w_table { + WTable::Full(w_full) => w_full.clone(), + WTable::Compact(_) => panic!("expected later prefix state to be full"), + }; + let
current_m_compact = expected.m_compact.clone(); + let expected_next_w_full = HachiStage2Prover::::fold_full_prefix_x( + ¤t_w_full, + expected.live_x_cols, + y_len, + r2, + ); + let expected_next_m_compact = HachiStage2Prover::::fold_m_prefix(¤t_m_compact, r2); + expected.prev_norm_claim = expected + .prev_norm_poly + .as_ref() + .expect("round2 norm poly should be cached") + .evaluate(&r2); + expected.split_eq.bind(r2); + expected.live_x_cols = expected.live_x_cols.div_ceil(2); + expected.rounds_completed += 1; + expected.m_compact = expected_next_m_compact.clone(); + let (virt_terms, rel_coeffs) = + expected.compute_round_full_prefix_x_terms(&expected_next_w_full); + let expected_round3 = expected.combine_terms(virt_terms, rel_coeffs); + + prover.ingest_challenge(2, r2); + + match &prover.w_table { + WTable::Full(w_full) => assert_eq!(w_full, &expected_next_w_full), + WTable::Compact(_) => panic!("expected fused later prefix stage to stay full"), + } + assert_eq!(prover.m_compact, expected_next_m_compact); + assert_eq!(prover.cached_round_poly.as_ref(), Some(&expected_round3)); + } +} diff --git a/src/protocol/sumcheck/mod.rs b/src/protocol/sumcheck/mod.rs new file mode 100644 index 00000000..99d88d11 --- /dev/null +++ b/src/protocol/sumcheck/mod.rs @@ -0,0 +1,613 @@ +//! Sumcheck protocol: traits, proof driver, and concrete instances. +//! +//! Types (`UniPoly`, `CompressedUniPoly`, `SumcheckProof`) live in the +//! [`types`] submodule. Polynomial evaluation utilities (`multilinear_eval`, +//! `fold_evals_in_place`, `range_check_eval`) live in [`crate::algebra::poly`]. +//! +//! ## Temporary duplication notice (Jolt integration) +//! +//! Jolt already has a mature, streaming-aware sumcheck implementation. Long-term, we +//! expect to extract the common sumcheck machinery into a dedicated crate and depend +//! on it from both Hachi and Jolt. Until that exists, this module intentionally +//! 
duplicates the essential sumcheck data types and transcript-driving logic as a +//! pragmatic workaround. + +pub mod batched_sumcheck; +pub mod eq_poly; +pub mod hachi_stage1; +pub mod hachi_stage2; +pub mod split_eq; +pub mod two_round_prefix; +pub mod types; + +use crate::algebra::fields::HasUnreducedOps; +use crate::error::HachiError; +use crate::protocol::transcript::labels; +use crate::protocol::transcript::Transcript; +use crate::{CanonicalField, FieldCore, FromSmallInt}; + +pub use crate::algebra::poly::{ + fold_evals_in_place, multilinear_eval, multilinear_eval_small, range_check_eval, +}; +pub use types::{CompressedUniPoly, SumcheckProof, UniPoly}; + +#[inline] +pub(crate) fn trim_trailing_zeros(coeffs: &mut Vec) { + while coeffs.len() > 1 && coeffs.last().is_some_and(|c| c.is_zero()) { + coeffs.pop(); + } +} + +/// Precomputed lookup table for folding pairs of small integer values at a +/// fixed challenge `r`. +/// +/// This is useful for the round-0 compact tables in Hachi's stage-1 and +/// stage-2 sumchecks: the table entries are small integers, the fold formula is +/// always `left + r * (right - left)`, and the set of possible `(left, right)` +/// pairs is tiny. 
+pub(crate) struct CompactPairFoldLut { + min_value: i32, + value_to_index: Vec, + pair_values: Vec, + num_values: usize, +} + +impl CompactPairFoldLut { + pub(crate) fn from_allowed_values(allowed_values: &[i32], r: E) -> Self { + assert!( + !allowed_values.is_empty(), + "allowed_values must be non-empty" + ); + let min_value = *allowed_values.iter().min().expect("non-empty"); + let max_value = *allowed_values.iter().max().expect("non-empty"); + let mut value_to_index = vec![usize::MAX; (max_value - min_value + 1) as usize]; + for (idx, &value) in allowed_values.iter().enumerate() { + let offset = (value - min_value) as usize; + debug_assert_eq!( + value_to_index[offset], + usize::MAX, + "allowed_values must be unique" + ); + value_to_index[offset] = idx; + } + + let num_values = allowed_values.len(); + let mut pair_values = Vec::with_capacity(num_values * num_values); + for &left in allowed_values { + let left_field = E::from_i64(left as i64); + for &right in allowed_values { + let delta = i64::from(right) - i64::from(left); + let delta_abs = delta.unsigned_abs(); + let r_delta = E::reduce_mul_u64_accum(r.mul_u64_unreduced(delta_abs)); + pair_values.push(if delta < 0 { + left_field - r_delta + } else { + left_field + r_delta + }); + } + } + + Self { + min_value, + value_to_index, + pair_values, + num_values, + } + } + + pub(crate) fn from_contiguous_range(min_value: i32, max_value: i32, r: E) -> Self { + assert!(min_value <= max_value, "invalid compact fold range"); + let allowed_values: Vec = (min_value..=max_value).collect(); + Self::from_allowed_values(&allowed_values, r) + } +} + +impl CompactPairFoldLut { + #[inline] + fn index_of(&self, value: i32) -> usize { + let offset = (value - self.min_value) as usize; + let idx = self.value_to_index[offset]; + debug_assert_ne!(idx, usize::MAX, "value missing from compact fold LUT"); + idx + } + + #[inline] + pub(crate) fn fold(&self, left: i32, right: i32) -> E { + let left_idx = self.index_of(left); + let right_idx 
= self.index_of(right); + self.pair_values[left_idx * self.num_values + right_idx] + } +} + +/// Prover-side sumcheck instance interface. +/// +/// This trait encapsulates the protocol-specific logic required to compute each +/// per-round univariate polynomial `g_j(X)` and to update (fold) internal state +/// after receiving the verifier challenge `r_j`. +/// +/// Hachi §4.3 will implement concrete instances for `H_0` and `H_α`. +pub trait SumcheckInstanceProver: Send + Sync { + /// Number of rounds (i.e. number of variables bound by sumcheck). + fn num_rounds(&self) -> usize; + + /// Maximum allowed degree for any round univariate polynomial. + fn degree_bound(&self) -> usize; + + /// The initial claimed sum that this sumcheck instance is proving. + fn input_claim(&self) -> E; + + /// Compute the prover message `g_round(X)` given the previous running claim. + /// + /// In standard sumcheck, `previous_claim` is the expected value of the + /// remaining sum after binding previous challenges, and must satisfy: + /// + /// `g_round(0) + g_round(1) == previous_claim`. + fn compute_round_univariate(&mut self, round: usize, previous_claim: E) -> UniPoly; + + /// Ingest the verifier challenge `r_round` to fold/bind the current variable. + fn ingest_challenge(&mut self, round: usize, r_round: E); + + /// Optional end-of-protocol hook after the last challenge has been ingested. + fn finalize(&mut self) {} +} + +/// Verifier-side sumcheck instance interface. +/// +/// Implementations provide the initial claim and the oracle evaluation at the +/// challenge point, enabling the verifier to perform the final consistency check. +pub trait SumcheckInstanceVerifier: Send + Sync { + /// Number of rounds (i.e. number of variables bound by sumcheck). + fn num_rounds(&self) -> usize; + + /// Maximum allowed degree for any round univariate polynomial. + fn degree_bound(&self) -> usize; + + /// The initial claimed sum that this sumcheck instance is proving. 
+ fn input_claim(&self) -> E; + + /// Compute the expected final evaluation `f(r_0, ..., r_{n-1})` at the + /// challenge point derived during the protocol. + /// + /// # Errors + /// + /// May return an error if internal evaluations fail (e.g., malformed + /// evaluation tables from untrusted proof data). + fn expected_output_claim(&self, challenges: &[E]) -> Result; +} + +/// Produce a sumcheck proof while omitting the first `omitted_prefix_rounds` +/// transcript rounds from the stored proof. +/// +/// This still drives the prover in the ordinary strict pipeline +/// `compute message -> absorb challenge -> ingest challenge -> ...`; it only +/// changes which compressed univariates are retained in the returned +/// [`SumcheckProof`]. Callers can use this to serialize early rounds via a +/// stage-local bivariate-skip proof instead of directly in the sumcheck proof. +/// +/// # Errors +/// +/// Returns an error if `omitted_prefix_rounds` exceeds the instance round +/// count, or if any per-round polynomial exceeds the instance's degree bound. 
+#[tracing::instrument(skip_all, name = "prove_sumcheck")] +#[inline(never)] +pub(crate) fn prove_sumcheck_with_omitted_prefix_rounds( + instance: &mut Inst, + transcript: &mut T, + mut sample_challenge: S, + omitted_prefix_rounds: usize, + mut absorb_after_compute: A, +) -> Result<(SumcheckProof, Vec, E), HachiError> +where + F: FieldCore + CanonicalField, + T: Transcript, + E: FieldCore, + S: FnMut(&mut T) -> E, + Inst: SumcheckInstanceProver, + A: FnMut(usize, &Inst, &mut T) -> Result<(), HachiError>, +{ + let num_rounds = instance.num_rounds(); + if omitted_prefix_rounds > num_rounds { + return Err(HachiError::InvalidInput(format!( + "sumcheck omitted_prefix_rounds {omitted_prefix_rounds} exceeds num_rounds {num_rounds}" + ))); + } + + let mut claim = instance.input_claim(); + tracing::debug!( + is_zero = claim.is_zero(), + num_rounds, + omitted_prefix_rounds, + "prove_sumcheck input_claim" + ); + transcript.append_serde(labels::ABSORB_SUMCHECK_CLAIM, &claim); + + let degree_bound = instance.degree_bound(); + let mut round_polys = Vec::with_capacity(num_rounds - omitted_prefix_rounds); + let mut r = Vec::with_capacity(num_rounds); + + for round in 0..num_rounds { + let g = instance.compute_round_univariate(round, claim); + debug_assert!( + g.evaluate(&E::zero()) + g.evaluate(&E::one()) == claim, + "sumcheck round univariate does not match previous claim hint" + ); + + let compressed = g.compress(); + if compressed.degree() > degree_bound { + return Err(HachiError::InvalidInput(format!( + "sumcheck round poly degree {} exceeds bound {}", + compressed.degree(), + degree_bound + ))); + } + + absorb_after_compute(round, instance, transcript)?; + transcript.append_serde(labels::ABSORB_SUMCHECK_ROUND, &compressed); + let r_i = sample_challenge(transcript); + r.push(r_i); + + claim = compressed.eval_from_hint(&claim, &r_i); + instance.ingest_challenge(round, r_i); + if round >= omitted_prefix_rounds { + round_polys.push(compressed); + } + } + + instance.finalize(); + 
Ok((SumcheckProof { round_polys }, r, claim)) +} + +/// Verify a sumcheck proof whose first `prefix_rounds` rounds are reconstructed by +/// a caller-supplied generator instead of being stored in `proof`. +/// +/// The verifier still follows the ordinary transcript pipeline, sampling each +/// challenge only after absorbing that round's compressed univariate. For +/// rounds `round < prefix_rounds`, the compressed univariate is provided by +/// `prefix_round_poly`; later rounds are read from `proof`. +/// +/// Returns the full challenge point `r` on success. +/// +/// # Errors +/// +/// Returns an error if `prefix_rounds` exceeds the verifier round count, if the +/// suffix proof length is inconsistent, if a generated/stored round polynomial +/// exceeds the degree bound, or if the final oracle check fails. +#[tracing::instrument(skip_all, name = "verify_sumcheck")] +#[inline(never)] +pub(crate) fn verify_sumcheck_with_prefix_rounds( + proof: &SumcheckProof, + verifier: &V, + transcript: &mut T, + mut sample_challenge: S, + prefix_rounds: usize, + mut absorb_before_round: A, + mut prefix_round_poly: P, +) -> Result, HachiError> +where + F: FieldCore + CanonicalField, + T: Transcript, + E: FieldCore, + S: FnMut(&mut T) -> E, + V: SumcheckInstanceVerifier, + A: FnMut(usize, &mut T) -> Result<(), HachiError>, + P: FnMut(usize, E, &[E]) -> CompressedUniPoly, +{ + let num_rounds = verifier.num_rounds(); + if prefix_rounds > num_rounds { + return Err(HachiError::InvalidInput(format!( + "sumcheck prefix_rounds {prefix_rounds} exceeds num_rounds {num_rounds}" + ))); + } + let expected_suffix_rounds = num_rounds - prefix_rounds; + if proof.round_polys.len() != expected_suffix_rounds { + return Err(HachiError::InvalidSize { + expected: expected_suffix_rounds, + actual: proof.round_polys.len(), + }); + } + + let mut claim = verifier.input_claim(); + tracing::debug!( + is_zero = claim.is_zero(), + num_rounds, + prefix_rounds, + "verify_sumcheck input_claim" + ); + 
transcript.append_serde(labels::ABSORB_SUMCHECK_CLAIM, &claim); + + let degree_bound = verifier.degree_bound(); + let mut challenges = Vec::with_capacity(num_rounds); + let mut suffix_iter = proof.round_polys.iter(); + + for round in 0..num_rounds { + absorb_before_round(round, transcript)?; + let poly = if round < prefix_rounds { + prefix_round_poly(round, claim, &challenges) + } else { + suffix_iter + .next() + .cloned() + .expect("suffix proof length checked above") + }; + if poly.degree() > degree_bound { + return Err(HachiError::InvalidInput(format!( + "sumcheck round poly degree {} exceeds bound {}", + poly.degree(), + degree_bound + ))); + } + + transcript.append_serde(labels::ABSORB_SUMCHECK_ROUND, &poly); + let r_i = sample_challenge(transcript); + challenges.push(r_i); + claim = poly.eval_from_hint(&claim, &r_i); + } + debug_assert!(suffix_iter.next().is_none()); + + check_sumcheck_output_claim(claim, verifier, &challenges)?; + Ok(challenges) +} + +/// Enforce the final sumcheck oracle equality for the provided challenge point. +/// +/// This is useful when some prefix rounds are reconstructed outside the generic +/// verifier driver and the caller needs to check the final oracle value against +/// the full concatenated challenge vector. +/// +/// # Errors +/// +/// Returns any error produced by `verifier.expected_output_claim`, or +/// [`HachiError::InvalidProof`] if the final claim does not match the oracle +/// evaluation at `challenges`. 
+pub fn check_sumcheck_output_claim( + final_claim: E, + verifier: &V, + challenges: &[E], +) -> Result<(), HachiError> +where + E: FieldCore, + V: SumcheckInstanceVerifier, +{ + let expected = verifier.expected_output_claim(challenges)?; + if final_claim != expected { + tracing::error!( + rounds = verifier.num_rounds(), + degree_bound = verifier.degree_bound(), + diff_is_zero = (final_claim - expected).is_zero(), + "verify_sumcheck MISMATCH" + ); + return Err(HachiError::InvalidProof); + } + Ok(()) +} + +/// Produce a sumcheck proof for a single instance, driving the Fiat-Shamir transcript. +/// +/// This method: +/// - does **not** absorb the initial claim into the transcript (callers should do so), +/// - appends each round message under `labels::ABSORB_SUMCHECK_ROUND`, +/// - samples one challenge per round via `sample_challenge`, +/// - updates the running claim using the per-round hint (`g(0)+g(1)`). +/// +/// It returns the proof, the derived point `r`, and the final claimed value at `r`. +/// +/// # Errors +/// +/// Returns an error if any per-round polynomial exceeds the instance's degree bound. +pub fn prove_sumcheck( + instance: &mut Inst, + transcript: &mut T, + sample_challenge: S, +) -> Result<(SumcheckProof, Vec, E), HachiError> +where + F: FieldCore + CanonicalField, + T: Transcript, + E: FieldCore, + S: FnMut(&mut T) -> E, + Inst: SumcheckInstanceProver, +{ + prove_sumcheck_with_omitted_prefix_rounds::( + instance, + transcript, + sample_challenge, + 0, + |_, _, _| Ok(()), + ) +} + +/// Verify a single-instance sumcheck proof. +/// +/// This function: +/// - absorbs the initial claim into the transcript, +/// - delegates round-by-round verification to [`SumcheckProof::verify`], +/// - performs the final oracle check: `final_claim == verifier.expected_output_claim(r)`. +/// +/// Returns the challenge point `r` on success. 
+/// +/// # Errors +/// +/// Returns [`HachiError::InvalidProof`] if the final sumcheck claim does not +/// match the oracle evaluation, or propagates any error from the per-round +/// verification (e.g. degree-bound violation, round-count mismatch). +pub fn verify_sumcheck( + proof: &SumcheckProof, + verifier: &V, + transcript: &mut T, + sample_challenge: S, +) -> Result, HachiError> +where + F: FieldCore + CanonicalField, + T: Transcript, + E: FieldCore, + S: FnMut(&mut T) -> E, + V: SumcheckInstanceVerifier, +{ + verify_sumcheck_with_prefix_rounds::( + proof, + verifier, + transcript, + sample_challenge, + 0, + |_, _| Ok(()), + |_, _, _| unreachable!("no prefix rounds requested"), + ) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::algebra::Prime128M8M4M1M0; + use crate::protocol::transcript::labels as tr_labels; + use crate::protocol::transcript::Blake2bTranscript; + + type F = Prime128M8M4M1M0; + + #[derive(Clone)] + struct ToyMlInstance { + original: Vec, + current: Vec, + num_rounds: usize, + } + + impl ToyMlInstance { + fn new(evals: Vec) -> Self { + let len = evals.len(); + let num_rounds = len.trailing_zeros() as usize; + debug_assert_eq!(1usize << num_rounds, len); + Self { + original: evals.clone(), + current: evals, + num_rounds, + } + } + } + + impl SumcheckInstanceProver for ToyMlInstance { + fn num_rounds(&self) -> usize { + self.num_rounds + } + + fn degree_bound(&self) -> usize { + 1 + } + + fn input_claim(&self) -> F { + self.original + .iter() + .copied() + .fold(F::zero(), |acc, x| acc + x) + } + + fn compute_round_univariate(&mut self, round: usize, previous_claim: F) -> UniPoly { + debug_assert_eq!(self.current.len(), 1usize << (self.num_rounds - round)); + let half = self.current.len() / 2; + let mut at_zero = F::zero(); + let mut slope = F::zero(); + for j in 0..half { + let left = self.current[2 * j]; + let right = self.current[2 * j + 1]; + at_zero += left; + slope += right - left; + } + let poly = 
UniPoly::from_coeffs(vec![at_zero, slope]); + debug_assert_eq!( + poly.evaluate(&F::zero()) + poly.evaluate(&F::one()), + previous_claim + ); + poly + } + + fn ingest_challenge(&mut self, _round: usize, r_round: F) { + fold_evals_in_place(&mut self.current, r_round); + } + } + + impl SumcheckInstanceVerifier for ToyMlInstance { + fn num_rounds(&self) -> usize { + self.num_rounds + } + + fn degree_bound(&self) -> usize { + 1 + } + + fn input_claim(&self) -> F { + self.original + .iter() + .copied() + .fold(F::zero(), |acc, x| acc + x) + } + + fn expected_output_claim(&self, challenges: &[F]) -> Result { + multilinear_eval(&self.original, challenges) + } + } + + fn new_transcript() -> Blake2bTranscript { + as Transcript>::new(tr_labels::DOMAIN_HACHI_PROTOCOL) + } + + fn sample_round(tr: &mut Blake2bTranscript) -> F { + tr.challenge_scalar(tr_labels::CHALLENGE_SUMCHECK_ROUND) + } + + #[test] + fn prove_sumcheck_with_omitted_prefix_rounds_matches_full_proof_tail() { + let evals: Vec = (0..16).map(|i| F::from_u64((7 * i as u64) + 3)).collect(); + let mut full = ToyMlInstance::new(evals.clone()); + let mut full_tr = new_transcript(); + let (full_proof, full_challenges, full_final_claim) = + prove_sumcheck::(&mut full, &mut full_tr, sample_round).unwrap(); + + let mut omitted = ToyMlInstance::new(evals); + let mut omitted_tr = new_transcript(); + let (suffix_proof, challenges, suffix_final_claim) = + prove_sumcheck_with_omitted_prefix_rounds::( + &mut omitted, + &mut omitted_tr, + sample_round, + 2, + |_, _, _| Ok(()), + ) + .unwrap(); + + assert_eq!(challenges, full_challenges); + assert_eq!( + suffix_proof.round_polys.as_slice(), + &full_proof.round_polys[2..] 
+ ); + assert_eq!(suffix_final_claim, full_final_claim); + } + + #[test] + fn verify_sumcheck_with_prefix_rounds_matches_full_verification_tail() { + let evals: Vec = (0..16).map(|i| F::from_u64((11 * i as u64) + 5)).collect(); + let mut prover = ToyMlInstance::new(evals.clone()); + let mut proof_tr = new_transcript(); + let (full_proof, full_challenges, full_final_claim) = + prove_sumcheck::(&mut prover, &mut proof_tr, sample_round).unwrap(); + + let verifier = ToyMlInstance::new(evals); + let suffix_proof = SumcheckProof { + round_polys: full_proof.round_polys[2..].to_vec(), + }; + let prefix_rounds = full_proof.round_polys[..2].to_vec(); + let mut verify_tr = new_transcript(); + let challenges = verify_sumcheck_with_prefix_rounds::( + &suffix_proof, + &verifier, + &mut verify_tr, + sample_round, + 2, + |_, _| Ok(()), + |round, _, _| prefix_rounds[round].clone(), + ) + .unwrap(); + + assert_eq!(challenges, full_challenges); + assert_eq!( + verifier.expected_output_claim(&challenges).unwrap(), + full_final_claim + ); + } +} diff --git a/src/protocol/sumcheck/split_eq.rs b/src/protocol/sumcheck/split_eq.rs new file mode 100644 index 00000000..f3dcd5a7 --- /dev/null +++ b/src/protocol/sumcheck/split_eq.rs @@ -0,0 +1,362 @@ +//! Gruen/Dao-Thaler split equality polynomial for efficient sumcheck. +//! +//! Factors `eq(τ, x)` into a running scalar, a linear factor for the current +//! variable, and precomputed split tables for the remaining variables. This +//! avoids maintaining and folding a full-size eq table during sumcheck. +//! +//! For details, see . +//! +//! Adapted from Jolt's `GruenSplitEqPolynomial`. +//! +//! ## Variable Layout (forward binding, little-endian) +//! +//! ```text +//! τ = [τ_current, τ_first_half, τ_second_half] +//! 1 var m vars (n-1-m) vars +//! ``` +//! +//! where `m = (n-1) / 2` and `n = τ.len()`. +//! +//! After binding `τ_current`, the next variable comes from `τ_first_half`, +//! then from `τ_second_half`. 
Suffix-cached eq tables for each half enable +//! O(1) pops per round instead of an O(2^n) fold. + +use super::eq_poly::EqPolynomial; +use super::UniPoly; +use crate::{FieldCore, FromSmallInt}; + +/// Split equality polynomial with Gruen scalar accumulation. +/// +/// Instead of storing and folding a full eq table each round, this struct +/// maintains: +/// - `current_scalar`: accumulated leading scalar times `eq(τ_bound, r_bound)` +/// from already-bound variables +/// - `E_first` / `E_second`: suffix-cached eq tables for two halves of the +/// remaining (unbound, non-current) variables +/// +/// The eq contribution for a pair index `j` in the inner sum is: +/// ```text +/// eq_remaining(j) = E_first[j_low] · E_second[j_high] +/// ``` +/// and the full round polynomial is `l(X) · q(X)` where `l(X)` is the linear +/// eq factor for the current variable. +#[allow(non_snake_case)] +pub struct GruenSplitEq { + tau: Vec, + current_round: usize, + current_scalar: E, + /// Suffix-cached eq tables for the first half of remaining variables. + /// `E_first[k]` = `eq(τ[split-k..split], ·)` with `2^k` entries. + /// Invariant: never empty; `E_first[0] = [1]`. + E_first: Vec>, + /// Suffix-cached eq tables for the second half of remaining variables. + /// `E_second[k]` = `eq(τ[n-k..n], ·)` with `2^k` entries. + /// Invariant: never empty; `E_second[0] = [1]`. + E_second: Vec>, +} + +#[allow(non_snake_case)] +impl GruenSplitEq { + /// Create a new split-eq from the full challenge vector `τ`. + /// + /// Precomputes suffix-cached eq tables for two halves of `τ[1..n]`. + /// + /// # Panics + /// + /// Panics if `tau` is empty. + pub fn new(tau: &[E]) -> Self { + Self::with_initial_scalar(tau, E::one()) + } + + /// Create a new split-eq whose running scalar starts at `initial_scalar`. 
+ /// + /// This is useful when a round-independent batching scalar should be folded + /// into the split-eq factor once up front rather than re-applied to every + /// round polynomial after `gruen_mul()`. + /// + /// # Panics + /// + /// Panics if `tau` is empty. + pub fn with_initial_scalar(tau: &[E], initial_scalar: E) -> Self { + let n = tau.len(); + assert!(n >= 1); + let m = (n - 1) / 2; + let split = 1 + m; + let first_half = &tau[1..split]; + let second_half = &tau[split..n]; + let E_first = EqPolynomial::evals_cached(first_half); + let E_second = EqPolynomial::evals_cached(second_half); + Self { + tau: tau.to_vec(), + current_round: 0, + current_scalar: initial_scalar, + E_first, + E_second, + } + } + + /// The accumulated scalar `c * Π_{k < current_round} eq(τ[k], r[k])`, + /// where `c` is the constructor-supplied leading scalar. + pub fn current_scalar(&self) -> E { + self.current_scalar + } + + /// The τ value for the variable about to be bound. + pub fn current_tau(&self) -> E { + self.tau[self.current_round] + } + + /// Return the current top-level split-eq tables `(E_first, E_second)`. + /// + /// For a pair index `j` in the inner sum, the eq factor for the + /// remaining (non-current) variables is: + /// ```text + /// eq_remaining(j) = E_first[j & (E_first.len()-1)] + /// · E_second[j >> E_first.len().trailing_zeros()] + /// ``` + /// + /// # Panics + /// + /// Panics if either `E_first` or `E_second` is empty (invariant violation). + pub fn remaining_eq_tables(&self) -> (&[E], &[E]) { + ( + self.E_first.last().expect("E_first is never empty"), + self.E_second.last().expect("E_second is never empty"), + ) + } + + /// Bind the current variable to challenge `r`, advancing to the next round. + /// + /// Multiplies `current_scalar` by `eq(τ[current_round], r)` and pops the + /// appropriate split table level. 
+ pub fn bind(&mut self, r: E) { + let tau_k = self.tau[self.current_round]; + self.current_scalar = + self.current_scalar * (tau_k * r + (E::one() - tau_k) * (E::one() - r)); + self.current_round += 1; + if self.E_first.len() > 1 { + self.E_first.pop(); + } else if self.E_second.len() > 1 { + self.E_second.pop(); + } + } + + #[inline] + fn linear_factor_evals(&self) -> (E, E) { + let l_at_1 = self.current_scalar * self.current_tau(); + let l_at_0 = self.current_scalar - l_at_1; + (l_at_0, l_at_1) + } + + /// Returns whether the current Gruen linear factor lets us recover the + /// omitted linear coefficient of the inner polynomial from `s(0) + s(1)`. + pub fn can_recover_linear_q_term_from_claim(&self) -> bool { + let (_, l_at_1) = self.linear_factor_evals(); + l_at_1.inv().is_some() + } + + /// Compute the round polynomial `s(X) = l(X) · q(X)` from the inner + /// polynomial `q` (given as evaluations at integer points `0, 1, ..., d`). + /// + /// `l(X) = current_scalar · eq(τ_current, X)` is the linear eq factor + /// for the current variable, including any constructor-supplied leading + /// scalar. The result has degree `d + 1`. + pub fn gruen_mul(&self, q_poly: &UniPoly) -> UniPoly { + let (l_at_0, l_at_1) = self.linear_factor_evals(); + let slope = l_at_1 - l_at_0; + let mut coeffs = vec![E::zero(); q_poly.coeffs.len() + 1]; + for (i, &c) in q_poly.coeffs.iter().enumerate() { + coeffs[i] += c * l_at_0; + coeffs[i + 1] += c * slope; + } + UniPoly::from_coeffs(coeffs) + } + + /// Recover a missing linear coefficient of `q(X)` from `s(0) + s(1)` and + /// return the full round polynomial `s(X) = l(X) · q(X)`. + /// + /// The input is `[q_0, q_2, q_3, ..., q_d]`, i.e. all coefficients except + /// the linear term. Returns `None` when `l(1) = 0`, in which case that + /// missing coefficient is not recoverable from the claim alone. 
+ pub fn try_gruen_poly_from_coeffs_except_linear( + &self, + q_coeffs_except_linear: &[E], + s_0_plus_s_1: E, + ) -> Option> { + if q_coeffs_except_linear.is_empty() { + return Some(UniPoly::from_coeffs(vec![E::zero()])); + } + + let (l_at_0, l_at_1) = self.linear_factor_evals(); + if l_at_0.is_zero() && l_at_1.is_zero() { + return Some(UniPoly::from_coeffs(vec![E::zero()])); + } + + let l_at_1_inv = l_at_1.inv()?; + let q_at_0 = q_coeffs_except_linear[0]; + let q_at_1 = (s_0_plus_s_1 - l_at_0 * q_at_0) * l_at_1_inv; + let sum_except_linear = q_coeffs_except_linear + .iter() + .copied() + .fold(E::zero(), |acc, coeff| acc + coeff); + let q_linear = q_at_1 - sum_except_linear; + + let mut q_coeffs = Vec::with_capacity(q_coeffs_except_linear.len() + 1); + q_coeffs.push(q_at_0); + q_coeffs.push(q_linear); + q_coeffs.extend_from_slice(&q_coeffs_except_linear[1..]); + Some(self.gruen_mul(&UniPoly::from_coeffs(q_coeffs))) + } +} + +impl GruenSplitEq { + /// Recover the middle coefficient of a quadratic inner polynomial + /// `q(X) = c + dX + eX^2` from `s(0) + s(1)` and return + /// `s(X) = l(X) · q(X)`. + /// + /// Returns `None` when `l(1) = 0`, in which case `q(1)` is not recoverable + /// from the claim alone. 
+ pub fn try_gruen_poly_deg_3( + &self, + q_constant: E, + q_quadratic_coeff: E, + s_0_plus_s_1: E, + ) -> Option> { + let (l_at_0, l_at_1) = self.linear_factor_evals(); + if l_at_0.is_zero() && l_at_1.is_zero() { + return Some(UniPoly::from_coeffs(vec![E::zero()])); + } + + let l_at_1_inv = l_at_1.inv()?; + let slope = l_at_1 - l_at_0; + let l_at_2 = l_at_1 + slope; + let l_at_3 = l_at_2 + slope; + + let q_at_0 = q_constant; + let s_at_0 = l_at_0 * q_at_0; + let s_at_1 = s_0_plus_s_1 - s_at_0; + let q_at_1 = s_at_1 * l_at_1_inv; + + let twice_q_quadratic = q_quadratic_coeff + q_quadratic_coeff; + let q_at_2 = q_at_1 + q_at_1 - q_at_0 + twice_q_quadratic; + let q_at_3 = q_at_2 + q_at_1 - q_at_0 + twice_q_quadratic + twice_q_quadratic; + + Some(UniPoly::from_evals(&[ + s_at_0, + s_at_1, + l_at_2 * q_at_2, + l_at_3 * q_at_3, + ])) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::algebra::Prime128M8M4M1M0; + use crate::protocol::sumcheck::fold_evals_in_place; + use crate::{FieldSampling, FromSmallInt}; + use rand::rngs::StdRng; + use rand::SeedableRng; + + type F = Prime128M8M4M1M0; + + #[test] + fn gruen_eq_matches_full_eq_table() { + let mut rng = StdRng::seed_from_u64(0xBB); + for n in 1..10 { + let tau: Vec = (0..n).map(|_| F::sample(&mut rng)).collect(); + let mut full_eq = EqPolynomial::evals(&tau); + let mut split_eq = GruenSplitEq::new(&tau); + + for _round in 0..n { + let half = full_eq.len() / 2; + let (e_first, e_second) = split_eq.remaining_eq_tables(); + let num_first = e_first.len(); + + for j in 0..half { + let j_low = j & (num_first - 1); + let j_high = j >> num_first.trailing_zeros(); + let eq_rem = e_first[j_low] * e_second[j_high]; + + let tau_k = split_eq.current_tau(); + let scalar = split_eq.current_scalar(); + let eq_0 = scalar * (F::one() - tau_k) * eq_rem; + let eq_1 = scalar * tau_k * eq_rem; + + assert_eq!(eq_0, full_eq[2 * j], "n={n} round={_round} j={j} eq_0"); + assert_eq!(eq_1, full_eq[2 * j + 1], "n={n} round={_round} 
j={j} eq_1"); + } + + let r = F::sample(&mut rng); + fold_evals_in_place(&mut full_eq, r); + split_eq.bind(r); + } + } + } + + #[test] + fn gruen_mul_matches_direct_product() { + let mut rng = StdRng::seed_from_u64(0xCC); + let tau: Vec = (0..5).map(|_| F::sample(&mut rng)).collect(); + let split_eq = GruenSplitEq::new(&tau); + + let q = UniPoly::from_coeffs(vec![F::from_u64(3), F::from_u64(7), F::from_u64(2)]); + let s = split_eq.gruen_mul(&q); + + let tau_k = split_eq.current_tau(); + let scalar = split_eq.current_scalar(); + for t in 0..10u64 { + let x = F::from_u64(t); + let l_x = scalar * (tau_k * x + (F::one() - tau_k) * (F::one() - x)); + let q_x = q.evaluate(&x); + assert_eq!(s.evaluate(&x), l_x * q_x, "t={t}"); + } + } + + #[test] + fn recover_round_poly_from_coeffs_except_linear() { + let mut rng = StdRng::seed_from_u64(0xCD); + let mut tau: Vec = (0..5).map(|_| F::sample(&mut rng)).collect(); + if tau[0].is_zero() { + tau[0] = F::one(); + } + let split_eq = GruenSplitEq::new(&tau); + + let q = UniPoly::from_coeffs(vec![ + F::from_u64(3), + F::from_u64(7), + F::from_u64(11), + F::from_u64(2), + ]); + let s = split_eq.gruen_mul(&q); + let q_except_linear = vec![q.coeffs[0], q.coeffs[2], q.coeffs[3]]; + let previous_claim = s.evaluate(&F::zero()) + s.evaluate(&F::one()); + + let recovered = split_eq + .try_gruen_poly_from_coeffs_except_linear(&q_except_linear, previous_claim) + .expect("tau_0 is nonzero, so q(1) is recoverable"); + + assert_eq!(recovered, s); + } + + #[test] + fn recover_quadratic_round_poly_from_claim() { + let mut rng = StdRng::seed_from_u64(0xCE); + let mut tau: Vec = (0..4).map(|_| F::sample(&mut rng)).collect(); + if tau[0].is_zero() { + tau[0] = F::one(); + } + let split_eq = GruenSplitEq::new(&tau); + + let q = UniPoly::from_coeffs(vec![F::from_u64(5), F::from_u64(9), F::from_u64(4)]); + let s = split_eq.gruen_mul(&q); + let previous_claim = s.evaluate(&F::zero()) + s.evaluate(&F::one()); + + let recovered = split_eq + 
.try_gruen_poly_deg_3(q.coeffs[0], q.coeffs[2], previous_claim) + .expect("tau_0 is nonzero, so q(1) is recoverable"); + + assert_eq!(recovered, s); + } +} diff --git a/src/protocol/sumcheck/two_round_prefix.rs b/src/protocol/sumcheck/two_round_prefix.rs new file mode 100644 index 00000000..04f3b674 --- /dev/null +++ b/src/protocol/sumcheck/two_round_prefix.rs @@ -0,0 +1,2457 @@ +//! Local algebra for 2-round x-quad prefix kernels. +//! +//! These helpers model a single 4-value x-quad in the first two x rounds, +//! using the point semantics we intend to reuse in the eventual batched-prefix +//! implementation: +//! +//! - finite points are ordinary evaluations of the bilinear multilinear +//! extension over the quad; +//! - `Infinity` means "take the leading coefficient in that coordinate". +//! +//! The tests pin down three facts we rely on before wiring this into the prover: +//! +//! - Stage 1's candidate `{1, -1, 2, Infinity}^2` storage really is a +//! 15-dimensional family because `(1, 1)` is always zero, but reconstructing the +//! actual first two rounds needs the safe `{0, 1, -1, 2, Infinity}^2` +//! fallback with the four Boolean corners omitted, for a 21-value payload. +//! - Stage 2's proposed reduced `{1, Infinity}^2` storage is not enough, by +//! itself, to recover the local round messages for either the norm or the +//! relation family, so the safe algebra layer keeps the full +//! `{0, 1, Infinity}^2` fallback for now. + +use super::eq_poly::EqPolynomial; +#[cfg(test)] +use super::hachi_stage1::range_check_eval_from_s; +use super::UniPoly; +use crate::algebra::fields::HasUnreducedOps; +#[cfg(feature = "parallel")] +use crate::parallel::*; +use crate::{AdditiveGroup, FieldCore, FromSmallInt}; + +/// Point in a small evaluation domain used by the 2-round prefix kernels. +#[cfg(test)] +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub(crate) enum PrefixPoint { + Finite(E), + Infinity, +} + +/// Candidate stage-1 domain `{1, -1, 2, Infinity}`. 
+#[cfg(test)] +pub(crate) fn stage1_prefix_points() -> [PrefixPoint; 4] { + [ + PrefixPoint::Finite(E::one()), + PrefixPoint::Finite(E::zero() - E::one()), + PrefixPoint::Finite(E::from_u64(2)), + PrefixPoint::Infinity, + ] +} + +/// Safe full stage-1 fallback domain `{0, 1, -1, 2, Infinity}`. +#[cfg(test)] +pub(crate) fn stage1_full_prefix_points() -> [PrefixPoint; 5] { + [ + PrefixPoint::Finite(E::zero()), + PrefixPoint::Finite(E::one()), + PrefixPoint::Finite(E::zero() - E::one()), + PrefixPoint::Finite(E::from_u64(2)), + PrefixPoint::Infinity, + ] +} + +/// Number of stored evaluations in the stage-1 2-round bivariate-skip proof after +/// omitting the four Boolean corners from `{0,1,-1,2,Infinity}^2`. +pub(crate) const STAGE1_PREFIX_EVAL_COUNT: usize = 21; + +/// Serializable stage-1 first-two-round bivariate-skip proof. +#[derive(Clone, Debug, PartialEq, Eq)] +pub(crate) struct Stage1BivariateSkipProof { + pub evals_except_boolean_core: Vec, +} + +#[inline] +fn stage1_full_grid_index(x_idx: usize, y_idx: usize) -> usize { + x_idx * 5 + y_idx +} + +#[inline] +fn stage1_is_boolean_corner(x_idx: usize, y_idx: usize) -> bool { + x_idx < 2 && y_idx < 2 +} + +const LOOKUP_PREFIX_INF: i64 = i64::MIN; +const STAGE1_B8_S_VALUES: [i64; 4] = [0, 2, 6, 12]; +const STAGE2_B8_W_VALUES: [i64; 8] = [-4, -3, -2, -1, 0, 1, 2, 3]; +const STAGE2_PREFIX_POINT_COUNT: usize = 9; +const STAGE2_COMPRESSED_POINT_COUNT: usize = STAGE2_PREFIX_POINT_COUNT - 1; +const STAGE2_COMPRESSED_POINT_INDICES_BY_OMITTED_CORNER: [[usize; STAGE2_COMPRESSED_POINT_COUNT]; + 4] = [ + [1, 2, 3, 4, 5, 6, 7, 8], + [0, 2, 3, 4, 5, 6, 7, 8], + [0, 1, 2, 4, 5, 6, 7, 8], + [0, 1, 2, 3, 5, 6, 7, 8], +]; + +const fn lookup_bilinear_coeffs_from_quad(quad: [i64; 4]) -> [i64; 4] { + let [t00, t10, t01, t11] = quad; + [t00, t10 - t00, t01 - t00, t11 - t10 - t01 + t00] +} + +const fn lookup_bilinear_eval_on_prefix_points(quad: [i64; 4], x: i64, y: i64) -> i64 { + let [a, b, c, d] = 
lookup_bilinear_coeffs_from_quad(quad); + let x_is_inf = x == LOOKUP_PREFIX_INF; + let y_is_inf = y == LOOKUP_PREFIX_INF; + if !x_is_inf && !y_is_inf { + a + x * (b + y * d) + y * c + } else if x_is_inf && !y_is_inf { + b + y * d + } else if !x_is_inf && y_is_inf { + c + x * d + } else { + d + } +} + +const fn pow_i64(mut base: i64, mut exp: usize) -> i64 { + let mut out = 1i64; + while exp > 0 { + if exp & 1 == 1 { + out *= base; + } + exp >>= 1; + if exp > 0 { + base *= base; + } + } + out +} + +const fn stage1_b8_range_check_from_s(s: i64) -> i64 { + s * (s - 2) * (s - 6) * (s - 12) +} + +const fn stage1_b8_local_norm_raw_eval_i64(s_quad: [i64; 4], x: i64, y: i64) -> i64 { + let [_, bx, cy, dxy] = lookup_bilinear_coeffs_from_quad(s_quad); + let x_is_inf = x == LOOKUP_PREFIX_INF; + let y_is_inf = y == LOOKUP_PREFIX_INF; + if !x_is_inf && !y_is_inf { + stage1_b8_range_check_from_s(lookup_bilinear_eval_on_prefix_points(s_quad, x, y)) + } else if x_is_inf && !y_is_inf { + pow_i64(bx + y * dxy, 4) + } else if !x_is_inf && y_is_inf { + pow_i64(cy + x * dxy, 4) + } else { + pow_i64(dxy, 4) + } +} + +const fn stage1_lookup_points_i64() -> [(i64, i64); STAGE1_PREFIX_EVAL_COUNT] { + let coords = [0i64, 1, -1, 2, LOOKUP_PREFIX_INF]; + let mut out = [(0i64, 0i64); STAGE1_PREFIX_EVAL_COUNT]; + let mut out_idx = 0usize; + let mut x_idx = 0usize; + while x_idx < 5 { + let mut y_idx = 0usize; + while y_idx < 5 { + if !(x_idx < 2 && y_idx < 2) { + out[out_idx] = (coords[x_idx], coords[y_idx]); + out_idx += 1; + } + y_idx += 1; + } + x_idx += 1; + } + out +} + +const STAGE1_PREFIX_LOOKUP_POINTS_I64: [(i64, i64); STAGE1_PREFIX_EVAL_COUNT] = + stage1_lookup_points_i64(); + +const fn stage1_b8_lookup_index_from_digits(digits: [usize; 4]) -> usize { + digits[0] | (digits[1] << 2) | (digits[2] << 4) | (digits[3] << 6) +} + +const fn build_stage1_b8_prefix_lookup_table() -> [[i64; STAGE1_PREFIX_EVAL_COUNT]; 256] { + let mut table = [[0i64; STAGE1_PREFIX_EVAL_COUNT]; 256]; + let mut d0 
= 0usize; + while d0 < 4 { + let mut d1 = 0usize; + while d1 < 4 { + let mut d2 = 0usize; + while d2 < 4 { + let mut d3 = 0usize; + while d3 < 4 { + let quad = [ + STAGE1_B8_S_VALUES[d0], + STAGE1_B8_S_VALUES[d1], + STAGE1_B8_S_VALUES[d2], + STAGE1_B8_S_VALUES[d3], + ]; + let table_idx = stage1_b8_lookup_index_from_digits([d0, d1, d2, d3]); + let mut point_idx = 0usize; + while point_idx < STAGE1_PREFIX_EVAL_COUNT { + let (x, y) = STAGE1_PREFIX_LOOKUP_POINTS_I64[point_idx]; + table[table_idx][point_idx] = stage1_b8_local_norm_raw_eval_i64(quad, x, y); + point_idx += 1; + } + d3 += 1; + } + d2 += 1; + } + d1 += 1; + } + d0 += 1; + } + table +} + +static STAGE1_B8_PREFIX_LOOKUP_TABLE: [[i64; STAGE1_PREFIX_EVAL_COUNT]; 256] = + build_stage1_b8_prefix_lookup_table(); + +const STAGE2_PREFIX_LOOKUP_POINTS_I64: [(i64, i64); STAGE2_PREFIX_POINT_COUNT] = [ + (0, 0), + (0, 1), + (0, LOOKUP_PREFIX_INF), + (1, 0), + (1, 1), + (1, LOOKUP_PREFIX_INF), + (LOOKUP_PREFIX_INF, 0), + (LOOKUP_PREFIX_INF, 1), + (LOOKUP_PREFIX_INF, LOOKUP_PREFIX_INF), +]; + +const fn stage2_b8_lookup_index_from_digits(digits: [usize; 4]) -> usize { + digits[0] | (digits[1] << 3) | (digits[2] << 6) | (digits[3] << 9) +} + +const fn stage2_local_norm_raw_eval_i64(w_quad: [i64; 4], x: i64, y: i64) -> i64 { + let w_eval = lookup_bilinear_eval_on_prefix_points(w_quad, x, y); + if x == LOOKUP_PREFIX_INF || y == LOOKUP_PREFIX_INF { + w_eval * w_eval + } else { + w_eval * (w_eval + 1) + } +} + +const fn compress_stage2_lookup_values( + values: [i64; STAGE2_PREFIX_POINT_COUNT], + omitted_idx: usize, +) -> [i64; STAGE2_COMPRESSED_POINT_COUNT] { + let mut out = [0i64; STAGE2_COMPRESSED_POINT_COUNT]; + let mut src_idx = 0usize; + let mut dst_idx = 0usize; + while src_idx < STAGE2_PREFIX_POINT_COUNT { + if src_idx != omitted_idx { + out[dst_idx] = values[src_idx]; + dst_idx += 1; + } + src_idx += 1; + } + out +} + +const fn build_stage2_b8_norm_lookup_table() -> [[i64; STAGE2_PREFIX_POINT_COUNT]; 4096] { + let mut 
table = [[0i64; STAGE2_PREFIX_POINT_COUNT]; 4096]; + let mut d0 = 0usize; + while d0 < 8 { + let mut d1 = 0usize; + while d1 < 8 { + let mut d2 = 0usize; + while d2 < 8 { + let mut d3 = 0usize; + while d3 < 8 { + let quad = [ + STAGE2_B8_W_VALUES[d0], + STAGE2_B8_W_VALUES[d1], + STAGE2_B8_W_VALUES[d2], + STAGE2_B8_W_VALUES[d3], + ]; + let table_idx = stage2_b8_lookup_index_from_digits([d0, d1, d2, d3]); + let mut point_idx = 0usize; + while point_idx < STAGE2_PREFIX_POINT_COUNT { + let (x, y) = STAGE2_PREFIX_LOOKUP_POINTS_I64[point_idx]; + table[table_idx][point_idx] = stage2_local_norm_raw_eval_i64(quad, x, y); + point_idx += 1; + } + d3 += 1; + } + d2 += 1; + } + d1 += 1; + } + d0 += 1; + } + table +} + +static STAGE2_B8_NORM_LOOKUP_TABLE: [[i64; STAGE2_PREFIX_POINT_COUNT]; 4096] = + build_stage2_b8_norm_lookup_table(); + +const fn build_stage2_b8_relation_weight_table() -> [[i64; STAGE2_PREFIX_POINT_COUNT]; 4096] { + let mut table = [[0i64; STAGE2_PREFIX_POINT_COUNT]; 4096]; + let mut d0 = 0usize; + while d0 < 8 { + let mut d1 = 0usize; + while d1 < 8 { + let mut d2 = 0usize; + while d2 < 8 { + let mut d3 = 0usize; + while d3 < 8 { + let quad = [ + STAGE2_B8_W_VALUES[d0], + STAGE2_B8_W_VALUES[d1], + STAGE2_B8_W_VALUES[d2], + STAGE2_B8_W_VALUES[d3], + ]; + let table_idx = stage2_b8_lookup_index_from_digits([d0, d1, d2, d3]); + let mut point_idx = 0usize; + while point_idx < STAGE2_PREFIX_POINT_COUNT { + let (x, y) = STAGE2_PREFIX_LOOKUP_POINTS_I64[point_idx]; + table[table_idx][point_idx] = + lookup_bilinear_eval_on_prefix_points(quad, x, y); + point_idx += 1; + } + d3 += 1; + } + d2 += 1; + } + d1 += 1; + } + d0 += 1; + } + table +} + +static STAGE2_B8_RELATION_WEIGHT_TABLE: [[i64; STAGE2_PREFIX_POINT_COUNT]; 4096] = + build_stage2_b8_relation_weight_table(); + +const fn build_stage2_b8_relation_weight_compressed_table( +) -> [[i64; STAGE2_COMPRESSED_POINT_COUNT]; 4096] { + let mut table = [[0i64; STAGE2_COMPRESSED_POINT_COUNT]; 4096]; + let mut table_idx = 
0usize; + while table_idx < 4096 { + table[table_idx] = + compress_stage2_lookup_values(STAGE2_B8_RELATION_WEIGHT_TABLE[table_idx], 0); + table_idx += 1; + } + table +} + +static STAGE2_B8_RELATION_WEIGHT_COMPRESSED_TABLE: [[i64; STAGE2_COMPRESSED_POINT_COUNT]; 4096] = + build_stage2_b8_relation_weight_compressed_table(); + +#[inline] +fn reduce_signed_lookup_accum( + pos: E::MulU64Accum, + neg: E::MulU64Accum, +) -> E { + E::reduce_mul_u64_accum(pos) - E::reduce_mul_u64_accum(neg) +} + +#[inline] +fn accum_lookup_vector_signed( + pos: &mut [E::MulU64Accum; N], + neg: &mut [E::MulU64Accum; N], + coeff: E, + values: &[i64; N], +) { + for (idx, &value) in values.iter().enumerate() { + if value > 0 { + pos[idx] += coeff.mul_u64_unreduced(value as u64); + } else if value < 0 { + neg[idx] += coeff.mul_u64_unreduced(value.unsigned_abs()); + } + } +} + +#[inline] +fn accum_lookup_vector_signed_selected< + E: FieldCore + HasUnreducedOps, + const N: usize, + const M: usize, +>( + pos: &mut [E::MulU64Accum; N], + neg: &mut [E::MulU64Accum; N], + coeff: E, + values: &[i64; M], + selected_indices: &[usize; N], +) { + for (dst_idx, &src_idx) in selected_indices.iter().enumerate() { + let value = values[src_idx]; + if value > 0 { + pos[dst_idx] += coeff.mul_u64_unreduced(value as u64); + } else if value < 0 { + neg[dst_idx] += coeff.mul_u64_unreduced(value.unsigned_abs()); + } + } +} + +#[inline] +fn accum_pointwise_signed( + pos: &mut [E::MulU64Accum; N], + neg: &mut [E::MulU64Accum; N], + coeffs: &[E; N], + weights: &[i64; N], +) { + for (idx, (&coeff, &weight)) in coeffs.iter().zip(weights.iter()).enumerate() { + if weight > 0 { + pos[idx] += coeff.mul_u64_unreduced(weight as u64); + } else if weight < 0 { + neg[idx] += coeff.mul_u64_unreduced(weight.unsigned_abs()); + } + } +} + +#[inline] +#[cfg(test)] +fn stage1_b8_s_digit_from_compact_w(w: i8) -> usize { + let w = i32::from(w); + debug_assert!((-4..=3).contains(&w)); + if w < 0 { + (-w - 1) as usize + } else { + w as 
usize + } +} + +#[inline] +fn stage1_b8_s_digit_from_compact_s(s: i32) -> usize { + match s { + 0 => 0, + 2 => 1, + 6 => 2, + 12 => 3, + other => unreachable!("unexpected compact s value {other}"), + } +} + +#[inline] +fn stage2_b8_w_digit(w: i8) -> usize { + let w = i32::from(w); + debug_assert!((-4..=3).contains(&w)); + (w + 4) as usize +} + +#[inline] +fn stage2_relation_m_point_values_compressed( + m_quad: [E; 4], +) -> [E; STAGE2_COMPRESSED_POINT_COUNT] { + let m00 = m_quad[0]; + let m10 = m_quad[1]; + let m01 = m_quad[2]; + let m11 = m_quad[3]; + [ + m01, + m01 - m00, + m10, + m11, + m11 - m10, + m10 - m00, + m11 - m01, + m11 - m10 - m01 + m00, + ] +} + +#[inline] +fn stage1_quartic_coeffs_from_prefix_values(values: [E; 5]) -> [E; 5] { + let [at_0, at_1, at_neg_1, at_2, at_inf] = values; + let two_inv = E::from_u64(2) + .inv() + .expect("stage1 prefix interpolation requires 2 to be invertible"); + let three_inv = E::from_u64(3) + .inv() + .expect("stage1 prefix interpolation requires 3 to be invertible"); + + let a0 = at_0; + let a4 = at_inf; + let rhs_at_1 = at_1 - a0 - a4; + let rhs_at_neg_1 = at_neg_1 - a0 - a4; + let a2 = (rhs_at_1 + rhs_at_neg_1) * two_inv; + let a1_plus_a3 = (rhs_at_1 - rhs_at_neg_1) * two_inv; + let rhs_at_2 = at_2 - a0 - E::from_u64(16) * a4; + let a1_plus_4a3 = rhs_at_2 * two_inv - E::from_u64(2) * a2; + let a3 = (a1_plus_4a3 - a1_plus_a3) * three_inv; + let a1 = a1_plus_a3 - a3; + [a0, a1, a2, a3, a4] +} + +#[inline] +fn stage1_eval_quartic_from_prefix_values(values: [E; 5], x: E) -> E { + let [a0, a1, a2, a3, a4] = stage1_quartic_coeffs_from_prefix_values(values); + a0 + x * (a1 + x * (a2 + x * (a3 + x * a4))) +} + +#[inline] +fn eval_stage1_biquartic_from_full_grid( + full_grid: [E; 25], + x: E, + y: E, +) -> E { + let x_rows = std::array::from_fn(|x_idx| { + stage1_eval_quartic_from_prefix_values( + [ + full_grid[stage1_full_grid_index(x_idx, 0)], + full_grid[stage1_full_grid_index(x_idx, 1)], + 
full_grid[stage1_full_grid_index(x_idx, 2)], + full_grid[stage1_full_grid_index(x_idx, 3)], + full_grid[stage1_full_grid_index(x_idx, 4)], + ], + y, + ) + }); + stage1_eval_quartic_from_prefix_values(x_rows, x) +} + +/// Whether stage 1 has enough x-rounds to use the 2-round prefix path. +#[inline] +pub(crate) fn can_use_stage1_two_round_prefix(num_u: usize, b: usize) -> bool { + num_u >= 2 && b == 8 +} + +/// Build the stage-1 first-two-round bivariate-skip proof from the compact witness +/// rows at the start of stage 1. +/// +/// Returns `None` when there are fewer than two x-rounds to batch. +#[tracing::instrument( + skip_all, + name = "two_round_prefix::build_stage1_bivariate_skip_proof_from_compact" +)] +#[cfg(test)] +pub(crate) fn build_stage1_bivariate_skip_proof_from_compact< + E: FieldCore + FromSmallInt + HasUnreducedOps, +>( + w_compact: &[i8], + tau0: &[E], + b: usize, + live_x_cols: usize, + num_u: usize, + num_l: usize, +) -> Option> { + if !can_use_stage1_two_round_prefix(num_u, b) { + return None; + } + + let y_len = 1usize << num_l; + assert_eq!(w_compact.len(), live_x_cols * y_len); + assert_eq!(tau0.len(), num_u + num_l); + + let eq_x_suffix = EqPolynomial::evals(&tau0[2..num_u]); + let eq_y = EqPolynomial::evals(&tau0[num_u..]); + let live_x_quads = live_x_cols.div_ceil(4); + debug_assert!(eq_x_suffix.len() >= live_x_quads); + + let (pos, neg) = cfg_fold_reduce!( + 0..y_len, + || { + ( + [E::MulU64Accum::ZERO; STAGE1_PREFIX_EVAL_COUNT], + [E::MulU64Accum::ZERO; STAGE1_PREFIX_EVAL_COUNT], + ) + }, + |(mut pos, mut neg), y_row| { + let row = &w_compact[y_row * live_x_cols..(y_row + 1) * live_x_cols]; + let eq_y_weight = eq_y[y_row]; + for (x_quad, &eq_x_weight) in eq_x_suffix.iter().take(live_x_quads).enumerate() { + let base = 4 * x_quad; + let lookup_idx = stage1_b8_lookup_index_from_digits([ + if base < live_x_cols { + stage1_b8_s_digit_from_compact_w(row[base]) + } else { + 0 + }, + if base + 1 < live_x_cols { + 
stage1_b8_s_digit_from_compact_w(row[base + 1]) + } else { + 0 + }, + if base + 2 < live_x_cols { + stage1_b8_s_digit_from_compact_w(row[base + 2]) + } else { + 0 + }, + if base + 3 < live_x_cols { + stage1_b8_s_digit_from_compact_w(row[base + 3]) + } else { + 0 + }, + ]); + let weight = eq_x_weight * eq_y_weight; + accum_lookup_vector_signed( + &mut pos, + &mut neg, + weight, + &STAGE1_B8_PREFIX_LOOKUP_TABLE[lookup_idx], + ); + } + (pos, neg) + }, + |(mut pos_a, mut neg_a), (pos_b, neg_b)| { + for (dst, src) in pos_a.iter_mut().zip(pos_b.iter()) { + *dst += *src; + } + for (dst, src) in neg_a.iter_mut().zip(neg_b.iter()) { + *dst += *src; + } + (pos_a, neg_a) + } + ); + let evals_except_boolean_core = (0..STAGE1_PREFIX_EVAL_COUNT) + .map(|idx| reduce_signed_lookup_accum::(pos[idx], neg[idx])) + .collect(); + + Some(Stage1BivariateSkipProof { + evals_except_boolean_core, + }) +} + +/// Build the stage-1 first-two-round bivariate-skip proof from the compact +/// `s = w(w+1)` table already materialized by the prover. 
+#[tracing::instrument( + skip_all, + name = "two_round_prefix::build_stage1_bivariate_skip_proof_from_s_compact" +)] +pub(crate) fn build_stage1_bivariate_skip_proof_from_s_compact< + E: FieldCore + FromSmallInt + HasUnreducedOps, +>( + s_compact: &[i32], + tau0: &[E], + b: usize, + live_x_cols: usize, + num_u: usize, + num_l: usize, +) -> Option> { + if !can_use_stage1_two_round_prefix(num_u, b) { + return None; + } + + let y_len = 1usize << num_l; + assert_eq!(s_compact.len(), live_x_cols * y_len); + assert_eq!(tau0.len(), num_u + num_l); + + let eq_x_suffix = EqPolynomial::evals(&tau0[2..num_u]); + let eq_y = EqPolynomial::evals(&tau0[num_u..]); + let live_x_quads = live_x_cols.div_ceil(4); + debug_assert!(eq_x_suffix.len() >= live_x_quads); + + let (pos, neg) = cfg_fold_reduce!( + 0..y_len, + || { + ( + [E::MulU64Accum::ZERO; STAGE1_PREFIX_EVAL_COUNT], + [E::MulU64Accum::ZERO; STAGE1_PREFIX_EVAL_COUNT], + ) + }, + |(mut pos, mut neg), y_row| { + let row = &s_compact[y_row * live_x_cols..(y_row + 1) * live_x_cols]; + let eq_y_weight = eq_y[y_row]; + for (x_quad, &eq_x_weight) in eq_x_suffix.iter().take(live_x_quads).enumerate() { + let base = 4 * x_quad; + let lookup_idx = stage1_b8_lookup_index_from_digits([ + if base < live_x_cols { + stage1_b8_s_digit_from_compact_s(row[base]) + } else { + 0 + }, + if base + 1 < live_x_cols { + stage1_b8_s_digit_from_compact_s(row[base + 1]) + } else { + 0 + }, + if base + 2 < live_x_cols { + stage1_b8_s_digit_from_compact_s(row[base + 2]) + } else { + 0 + }, + if base + 3 < live_x_cols { + stage1_b8_s_digit_from_compact_s(row[base + 3]) + } else { + 0 + }, + ]); + let weight = eq_x_weight * eq_y_weight; + accum_lookup_vector_signed( + &mut pos, + &mut neg, + weight, + &STAGE1_B8_PREFIX_LOOKUP_TABLE[lookup_idx], + ); + } + (pos, neg) + }, + |(mut pos_a, mut neg_a), (pos_b, neg_b)| { + for (dst, src) in pos_a.iter_mut().zip(pos_b.iter()) { + *dst += *src; + } + for (dst, src) in neg_a.iter_mut().zip(neg_b.iter()) { + *dst += 
*src; + } + (pos_a, neg_a) + } + ); + let evals_except_boolean_core = (0..STAGE1_PREFIX_EVAL_COUNT) + .map(|idx| reduce_signed_lookup_accum::(pos[idx], neg[idx])) + .collect(); + + Some(Stage1BivariateSkipProof { + evals_except_boolean_core, + }) +} + +#[cfg(test)] +fn stage1_storage_vector_from_quad(quad: [E; 4], b: usize) -> Vec { + let points = stage1_full_prefix_points::(); + let mut out = Vec::with_capacity(STAGE1_PREFIX_EVAL_COUNT); + for x_idx in 0..5 { + for y_idx in 0..5 { + if stage1_is_boolean_corner(x_idx, y_idx) { + continue; + } + out.push(stage1_local_norm_raw_eval( + quad, + points[x_idx], + points[y_idx], + b, + )); + } + } + out +} + +/// State needed to reconstruct the first two stage-1 rounds from the +/// serialized bivariate-skip proof. +#[derive(Clone, Debug, PartialEq, Eq)] +pub(crate) struct Stage1BivariateSkipState { + full_grid: [E; 25], + tau0: E, + tau1: E, +} + +impl Stage1BivariateSkipState { + pub(crate) fn new(proof: &Stage1BivariateSkipProof, tau0: &[E], b: usize) -> Option { + if tau0.len() < 2 + || proof.evals_except_boolean_core.len() != STAGE1_PREFIX_EVAL_COUNT + || b != 8 + { + return None; + } + + let mut full_grid = [E::zero(); 25]; + let mut payload_idx = 0usize; + for x_idx in 0..5 { + for y_idx in 0..5 { + if stage1_is_boolean_corner(x_idx, y_idx) { + continue; + } + full_grid[stage1_full_grid_index(x_idx, y_idx)] = + proof.evals_except_boolean_core[payload_idx]; + payload_idx += 1; + } + } + + Some(Self { + full_grid, + tau0: tau0[0], + tau1: tau0[1], + }) + } +} + +impl Stage1BivariateSkipState { + #[inline] + fn linear_eq_eval(tau: E, x: E) -> E { + tau * x + (E::one() - tau) * (E::one() - x) + } + + pub(crate) fn reconstruct_round0_poly(&self) -> UniPoly { + let l1_at_0 = E::one() - self.tau1; + let l1_at_1 = self.tau1; + let evals: Vec = (0..=5u64) + .map(|x_raw| { + let x = E::from_u64(x_raw); + let q_x0 = eval_stage1_biquartic_from_full_grid(self.full_grid, x, E::zero()); + let q_x1 = 
eval_stage1_biquartic_from_full_grid(self.full_grid, x, E::one()); + Self::linear_eq_eval(self.tau0, x) * (l1_at_0 * q_x0 + l1_at_1 * q_x1) + }) + .collect(); + UniPoly::from_evals(&evals) + } + + pub(crate) fn reconstruct_round1_poly(&self, r0: E) -> UniPoly { + let l0_at_r0 = Self::linear_eq_eval(self.tau0, r0); + let evals: Vec = (0..=5u64) + .map(|y_raw| { + let y = E::from_u64(y_raw); + l0_at_r0 + * Self::linear_eq_eval(self.tau1, y) + * eval_stage1_biquartic_from_full_grid(self.full_grid, r0, y) + }) + .collect(); + UniPoly::from_evals(&evals) + } +} + +/// Proposed reduced stage-2 domain `{1, Infinity}`. +#[cfg(test)] +pub(crate) fn stage2_reduced_prefix_points() -> [PrefixPoint; 2] { + [PrefixPoint::Finite(E::one()), PrefixPoint::Infinity] +} + +/// Safe full stage-2 fallback domain `{0, 1, Infinity}`. +#[cfg(test)] +pub(crate) fn stage2_full_prefix_points() -> [PrefixPoint; 3] { + [ + PrefixPoint::Finite(E::zero()), + PrefixPoint::Finite(E::one()), + PrefixPoint::Infinity, + ] +} + +/// Return the bilinear coefficients for a quad ordered as `[t00, t10, t01, t11]`. +#[inline] +#[cfg(test)] +pub(crate) fn bilinear_coeffs_from_quad(quad: [E; 4]) -> [E; 4] { + let [t00, t10, t01, t11] = quad; + [t00, t10 - t00, t01 - t00, t11 - t10 - t01 + t00] +} + +/// Evaluate the bilinear multilinear extension of a quad at ordinary field +/// points `(x, y)`. +#[inline] +#[cfg(test)] +pub(crate) fn bilinear_eval(quad: [E; 4], x: E, y: E) -> E { + let [a, b, c, d] = bilinear_coeffs_from_quad(quad); + a + x * (b + y * d) + y * c +} + +/// Evaluate a quad on a small domain where `Infinity` means "leading +/// coefficient in that coordinate". 
+#[inline] +#[cfg(test)] +pub(crate) fn bilinear_eval_on_prefix_points( + quad: [E; 4], + x: PrefixPoint, + y: PrefixPoint, +) -> E { + let [a, b, c, d] = bilinear_coeffs_from_quad(quad); + match (x, y) { + (PrefixPoint::Finite(x), PrefixPoint::Finite(y)) => a + x * (b + y * d) + y * c, + (PrefixPoint::Infinity, PrefixPoint::Finite(y)) => b + y * d, + (PrefixPoint::Finite(x), PrefixPoint::Infinity) => c + x * d, + (PrefixPoint::Infinity, PrefixPoint::Infinity) => d, + } +} + +/// Evaluate the stage-1 candidate storage contribution used by the original +/// `{1, -1, 2, Infinity}^2` proposal. +#[inline] +#[cfg(test)] +pub(crate) fn stage1_local_norm_eval( + s_quad: [E; 4], + x: PrefixPoint, + y: PrefixPoint, + b: usize, +) -> E { + let s_eval = bilinear_eval_on_prefix_points(s_quad, x, y); + range_check_eval_from_s(s_eval, b) +} + +/// Evaluate the raw stage-1 full-domain polynomial on +/// `{0, 1, -1, 2, Infinity}^2`. +/// +/// At `Infinity`, we take the leading coefficient in that coordinate of the +/// composed range-check polynomial `range_check(s(X, Y))`, rather than first +/// evaluating `s` at `Infinity` and then applying the range check. 
+#[inline] +#[cfg(test)] +pub(crate) fn stage1_local_norm_raw_eval( + s_quad: [E; 4], + x: PrefixPoint, + y: PrefixPoint, + b: usize, +) -> E { + let [_, bx, cy, dxy] = bilinear_coeffs_from_quad(s_quad); + let degree = b / 2; + let pow = |base: E| { + let mut out = E::one(); + for _ in 0..degree { + out = out * base; + } + out + }; + + match (x, y) { + (PrefixPoint::Finite(x), PrefixPoint::Finite(y)) => { + range_check_eval_from_s(bilinear_eval(s_quad, x, y), b) + } + (PrefixPoint::Infinity, PrefixPoint::Finite(y)) => pow(bx + y * dxy), + (PrefixPoint::Finite(x), PrefixPoint::Infinity) => pow(cy + x * dxy), + (PrefixPoint::Infinity, PrefixPoint::Infinity) => pow(dxy), + } +} + +/// Evaluate the stage-2 local norm candidate used by the proposed reduced +/// `{1, Infinity}^2` storage: evaluate the bilinear witness first, then apply +/// `w (w + 1)`. +#[inline] +#[cfg(test)] +pub(crate) fn stage2_local_norm_candidate_eval( + w_quad: [E; 4], + x: PrefixPoint, + y: PrefixPoint, +) -> E { + let w_eval = bilinear_eval_on_prefix_points(w_quad, x, y); + w_eval * (w_eval + E::one()) +} + +/// Evaluate the raw degree-`(2,2)` stage-2 norm polynomial on the safe full +/// `{0, 1, Infinity}^2` fallback domain. +/// +/// At `Infinity`, we take the leading coefficient in that coordinate of +/// `w(X, Y) * (w(X, Y) + 1)`, so the linear `+w` term drops out. +#[inline] +#[cfg(test)] +pub(crate) fn stage2_local_norm_raw_eval( + w_quad: [E; 4], + x: PrefixPoint, + y: PrefixPoint, +) -> E { + let w_eval = bilinear_eval_on_prefix_points(w_quad, x, y); + match (x, y) { + (PrefixPoint::Finite(_), PrefixPoint::Finite(_)) => w_eval * (w_eval + E::one()), + _ => w_eval * w_eval, + } +} + +/// Evaluate the stage-2 local relation contribution for one `(w^4, m^4)` pair. 
+#[inline] +#[cfg(test)] +pub(crate) fn stage2_local_relation_eval( + w_quad: [E; 4], + m_quad: [E; 4], + alpha: E, + x: PrefixPoint, + y: PrefixPoint, +) -> E { + alpha + * bilinear_eval_on_prefix_points(w_quad, x, y) + * bilinear_eval_on_prefix_points(m_quad, x, y) +} + +/// Boolean corner in the `{0, 1}^2` sub-grid of the stage-2 full domain. +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub(crate) enum BooleanCorner { + ZeroZero, + ZeroOne, + OneZero, + OneOne, +} + +impl BooleanCorner { + pub(crate) const ALL: [Self; 4] = [Self::ZeroZero, Self::ZeroOne, Self::OneZero, Self::OneOne]; + #[cfg(test)] + pub(crate) const DEFAULT_STAGE2_NORM: Self = Self::ZeroZero; + pub(crate) const DEFAULT_STAGE2_RELATION: Self = Self::ZeroZero; + + #[inline] + pub(crate) fn default_norm_order() -> [Self; 4] { + Self::ALL + } + + #[inline] + fn boolean_index(self) -> usize { + match self { + Self::ZeroZero => 0, + Self::ZeroOne => 1, + Self::OneZero => 2, + Self::OneOne => 3, + } + } + + #[inline] + fn grid_index(self) -> usize { + match self { + Self::ZeroZero => 0, + Self::ZeroOne => 1, + Self::OneZero => 3, + Self::OneOne => 4, + } + } +} + +/// Compressed stage-2 `{0, 1, Infinity}^2` grid with one omitted Boolean corner. 
+#[derive(Clone, Debug, PartialEq, Eq)] +pub(crate) struct Stage2CompressedGrid { + pub omitted_corner: BooleanCorner, + pub evals_except_corner: [E; 8], +} + +impl Stage2CompressedGrid { + #[cfg(test)] + pub(crate) fn from_full_grid(full_grid: [E; 9], omitted_corner: BooleanCorner) -> Self { + let omitted_idx = omitted_corner.grid_index(); + let mut out_idx = 0usize; + let evals_except_corner = std::array::from_fn(|_| { + while out_idx == omitted_idx { + out_idx += 1; + } + let value = full_grid[out_idx]; + out_idx += 1; + value + }); + Self { + omitted_corner, + evals_except_corner, + } + } + + pub(crate) fn reconstruct_with_corner_value(&self, omitted_value: E) -> [E; 9] { + let omitted_idx = self.omitted_corner.grid_index(); + let mut src_idx = 0usize; + std::array::from_fn(|dst_idx| { + if dst_idx == omitted_idx { + omitted_value + } else { + let value = self.evals_except_corner[src_idx]; + src_idx += 1; + value + } + }) + } +} + +/// Serializable stage-2 first-two-round bivariate-skip proof. +#[derive(Clone, Debug, PartialEq, Eq)] +pub(crate) struct Stage2BivariateSkipProof { + pub norm: Stage2CompressedGrid, + pub relation: Stage2CompressedGrid, +} + +/// Return the stage-2 full-domain grid in row-major `x`-major order over +/// `{0, 1, Infinity}^2`. +#[cfg(test)] +pub(crate) fn stage2_full_grid_values( + mut eval: impl FnMut(PrefixPoint, PrefixPoint) -> E, +) -> [E; 9] { + let points = stage2_full_prefix_points::(); + std::array::from_fn(|idx| { + let x = points[idx / 3]; + let y = points[idx % 3]; + eval(x, y) + }) +} + +/// Evaluate a quadratic from its values at `{0, 1, Infinity}`. 
+#[inline] +#[cfg(test)] +pub(crate) fn eval_quadratic_from_01_inf( + at_zero: E, + at_one: E, + at_inf: E, + x: PrefixPoint, +) -> E { + match x { + PrefixPoint::Infinity => at_inf, + PrefixPoint::Finite(x) => { + let linear = at_one - at_zero - at_inf; + at_zero + x * (linear + x * at_inf) + } + } +} + +#[inline] +pub(crate) fn quadratic_coeffs_from_01_inf( + at_zero: E, + at_one: E, + at_inf: E, +) -> [E; 3] { + [at_zero, at_one - at_zero - at_inf, at_inf] +} + +#[inline] +fn eval_quadratic_from_coeffs(coeffs: [E; 3], x: E) -> E { + coeffs[0] + x * (coeffs[1] + x * coeffs[2]) +} + +#[inline] +fn linear_eq_coeffs(tau: E) -> [E; 2] { + [E::one() - tau, tau + tau - E::one()] +} + +#[inline] +fn scale_quadratic_coeffs(coeffs: [E; 3], scale: E) -> [E; 3] { + [scale * coeffs[0], scale * coeffs[1], scale * coeffs[2]] +} + +#[inline] +fn add_quadratic_coeffs(lhs: [E; 3], rhs: [E; 3]) -> [E; 3] { + [lhs[0] + rhs[0], lhs[1] + rhs[1], lhs[2] + rhs[2]] +} + +#[inline] +fn mul_linear_by_quadratic_coeffs(tau: E, quad: [E; 3]) -> [E; 4] { + let [l0, l1] = linear_eq_coeffs(tau); + [ + l0 * quad[0], + l0 * quad[1] + l1 * quad[0], + l0 * quad[2] + l1 * quad[1], + l1 * quad[2], + ] +} + +/// Evaluate a biquadratic from its full `{0, 1, Infinity}^2` grid. +#[inline] +#[cfg(test)] +pub(crate) fn eval_biquadratic_from_full_grid( + full_grid: [E; 9], + x: PrefixPoint, + y: PrefixPoint, +) -> E { + let q_y0 = eval_quadratic_from_01_inf(full_grid[0], full_grid[3], full_grid[6], x); + let q_y1 = eval_quadratic_from_01_inf(full_grid[1], full_grid[4], full_grid[7], x); + let q_yinf = eval_quadratic_from_01_inf(full_grid[2], full_grid[5], full_grid[8], x); + eval_quadratic_from_01_inf(q_y0, q_y1, q_yinf, y) +} + +/// Return the local claim weights for the four Boolean corners of the stage-2 +/// norm half, ordered as `[(0,0), (0,1), (1,0), (1,1)]`. 
+#[inline] +pub(crate) fn stage2_norm_corner_weights_from_linear_evals( + l0_at_0: E, + l0_at_1: E, + l1_at_0: E, + l1_at_1: E, +) -> [E; 4] { + [ + l0_at_0 * l1_at_0, + l0_at_0 * l1_at_1, + l0_at_1 * l1_at_0, + l0_at_1 * l1_at_1, + ] +} + +/// Return the local claim weights for the four Boolean corners of the stage-2 +/// norm half when the two local eq factors are `eq(tau0, X)` and `eq(tau1, Y)`. +#[inline] +pub(crate) fn stage2_norm_corner_weights_from_taus(tau0: E, tau1: E) -> [E; 4] { + stage2_norm_corner_weights_from_linear_evals(E::one() - tau0, tau0, E::one() - tau1, tau1) +} + +/// Choose the default omitted corner for stage-2 norm compression, preferring +/// `(0,0)` when its claim weight is nonzero. +#[inline] +pub(crate) fn default_stage2_norm_omitted_corner( + corner_weights: [E; 4], +) -> BooleanCorner { + for corner in BooleanCorner::default_norm_order() { + if !corner_weights[corner.boolean_index()].is_zero() { + return corner; + } + } + unreachable!("at least one Boolean-corner weight must be nonzero"); +} + +/// Recover a full stage-2 grid from an omitted-corner compression and a +/// weighted Boolean-corner claim relation. +pub(crate) fn recover_stage2_grid_from_corner_claim( + compressed: &Stage2CompressedGrid, + corner_weights: [E; 4], + claim: E, +) -> Option<[E; 9]> { + let omitted_weight = corner_weights[compressed.omitted_corner.boolean_index()]; + let omitted_weight_inv = omitted_weight.inv()?; + let mut full_grid = compressed.reconstruct_with_corner_value(E::zero()); + let known_sum = BooleanCorner::ALL + .iter() + .copied() + .filter(|corner| *corner != compressed.omitted_corner) + .fold(E::zero(), |acc, corner| { + acc + corner_weights[corner.boolean_index()] * full_grid[corner.grid_index()] + }); + let omitted_value = (claim - known_sum) * omitted_weight_inv; + full_grid[compressed.omitted_corner.grid_index()] = omitted_value; + Some(full_grid) +} + +/// Recover a full stage-2 relation grid from its default `(0,0)` omission. 
+#[inline] +pub(crate) fn recover_stage2_relation_grid_from_claim( + compressed: &Stage2CompressedGrid, + relation_claim: E, +) -> [E; 9] { + recover_stage2_grid_from_corner_claim(compressed, [E::one(); 4], relation_claim) + .expect("relation corner weights are all one") +} + +/// Recover a full stage-2 norm grid from an omitted Boolean corner and the +/// weighted local norm claim. +#[inline] +pub(crate) fn recover_stage2_norm_grid_from_claim( + compressed: &Stage2CompressedGrid, + corner_weights: [E; 4], + norm_claim: E, +) -> Option<[E; 9]> { + recover_stage2_grid_from_corner_claim(compressed, corner_weights, norm_claim) +} + +/// Whether stage 2 has enough x-rounds to use the 2-round prefix path. +#[inline] +pub(crate) fn can_use_stage2_two_round_prefix(num_u: usize) -> bool { + num_u >= 2 +} + +/// Build the stage-2 first-two-round bivariate-skip proof from the compact witness +/// rows at the start of stage 2. +/// +/// Returns `None` when there are fewer than two x-rounds to batch. 
+#[tracing::instrument( + skip_all, + name = "two_round_prefix::build_stage2_bivariate_skip_proof_from_compact" +)] +pub(crate) fn build_stage2_bivariate_skip_proof_from_compact< + E: FieldCore + FromSmallInt + HasUnreducedOps, +>( + w_compact: &[i8], + alpha_evals_y: &[E], + m_evals_x: &[E], + r_stage1: &[E], + live_x_cols: usize, + num_u: usize, + num_l: usize, +) -> Option> { + if !can_use_stage2_two_round_prefix(num_u) { + return None; + } + + let y_len = 1usize << num_l; + assert_eq!(alpha_evals_y.len(), y_len); + assert_eq!(w_compact.len(), live_x_cols * y_len); + assert_eq!(m_evals_x.len(), 1usize << num_u); + assert_eq!(r_stage1.len(), num_u + num_l); + + let eq_x_suffix = EqPolynomial::evals(&r_stage1[2..num_u]); + let eq_y = EqPolynomial::evals(&r_stage1[num_u..]); + let live_x_quads = live_x_cols.div_ceil(4); + debug_assert!(eq_x_suffix.len() >= live_x_quads); + let norm_omitted_corner = default_stage2_norm_omitted_corner( + stage2_norm_corner_weights_from_taus(r_stage1[0], r_stage1[1]), + ); + let norm_point_indices = + &STAGE2_COMPRESSED_POINT_INDICES_BY_OMITTED_CORNER[norm_omitted_corner.boolean_index()]; + let m_point_values_by_quad: Vec<[E; STAGE2_COMPRESSED_POINT_COUNT]> = (0..live_x_quads) + .map(|x_quad| { + let base = 4 * x_quad; + let m_quad = std::array::from_fn(|offset| m_evals_x[base + offset]); + stage2_relation_m_point_values_compressed(m_quad) + }) + .collect(); + + let (norm_pos, norm_neg, rel_accum) = cfg_fold_reduce!( + 0..y_len, + || { + ( + [E::MulU64Accum::ZERO; STAGE2_COMPRESSED_POINT_COUNT], + [E::MulU64Accum::ZERO; STAGE2_COMPRESSED_POINT_COUNT], + [E::ProductAccum::ZERO; STAGE2_COMPRESSED_POINT_COUNT], + ) + }, + |(mut norm_pos, mut norm_neg, mut rel_accum), y_idx| { + let row = &w_compact[y_idx * live_x_cols..(y_idx + 1) * live_x_cols]; + let alpha = alpha_evals_y[y_idx]; + let eq_y_weight = eq_y[y_idx]; + let mut row_rel_pos = [E::MulU64Accum::ZERO; STAGE2_COMPRESSED_POINT_COUNT]; + let mut row_rel_neg = [E::MulU64Accum::ZERO; 
STAGE2_COMPRESSED_POINT_COUNT]; + for (x_quad, &eq_x_weight) in eq_x_suffix.iter().take(live_x_quads).enumerate() { + let base = 4 * x_quad; + let lookup_idx = stage2_b8_lookup_index_from_digits([ + if base < live_x_cols { + stage2_b8_w_digit(row[base]) + } else { + 4 + }, + if base + 1 < live_x_cols { + stage2_b8_w_digit(row[base + 1]) + } else { + 4 + }, + if base + 2 < live_x_cols { + stage2_b8_w_digit(row[base + 2]) + } else { + 4 + }, + if base + 3 < live_x_cols { + stage2_b8_w_digit(row[base + 3]) + } else { + 4 + }, + ]); + let norm_weight = eq_x_weight * eq_y_weight; + accum_lookup_vector_signed_selected( + &mut norm_pos, + &mut norm_neg, + norm_weight, + &STAGE2_B8_NORM_LOOKUP_TABLE[lookup_idx], + norm_point_indices, + ); + accum_pointwise_signed( + &mut row_rel_pos, + &mut row_rel_neg, + &m_point_values_by_quad[x_quad], + &STAGE2_B8_RELATION_WEIGHT_COMPRESSED_TABLE[lookup_idx], + ); + } + for idx in 0..STAGE2_COMPRESSED_POINT_COUNT { + let row_rel = reduce_signed_lookup_accum::(row_rel_pos[idx], row_rel_neg[idx]); + rel_accum[idx] += alpha.mul_to_product_accum(row_rel); + } + (norm_pos, norm_neg, rel_accum) + }, + |(mut norm_pos_a, mut norm_neg_a, mut rel_accum_a), + (norm_pos_b, norm_neg_b, rel_accum_b)| { + for (dst, src) in norm_pos_a.iter_mut().zip(norm_pos_b.iter()) { + *dst += *src; + } + for (dst, src) in norm_neg_a.iter_mut().zip(norm_neg_b.iter()) { + *dst += *src; + } + for (dst, src) in rel_accum_a.iter_mut().zip(rel_accum_b.iter()) { + *dst += *src; + } + (norm_pos_a, norm_neg_a, rel_accum_a) + } + ); + let norm_evals_except_corner: [E; STAGE2_COMPRESSED_POINT_COUNT] = + std::array::from_fn(|idx| reduce_signed_lookup_accum::(norm_pos[idx], norm_neg[idx])); + let relation_evals_except_corner: [E; STAGE2_COMPRESSED_POINT_COUNT] = + std::array::from_fn(|idx| E::reduce_product_accum(rel_accum[idx])); + Some(Stage2BivariateSkipProof { + norm: Stage2CompressedGrid { + omitted_corner: norm_omitted_corner, + evals_except_corner: 
norm_evals_except_corner, + }, + relation: Stage2CompressedGrid { + omitted_corner: BooleanCorner::DEFAULT_STAGE2_RELATION, + evals_except_corner: relation_evals_except_corner, + }, + }) +} + +/// State needed to reconstruct the first two stage-2 rounds from the +/// serialized bivariate-skip proof. +#[derive(Clone, Debug, PartialEq, Eq)] +pub(crate) struct Stage2BivariateSkipState { + norm_x_row_coeffs: [[E; 3]; 3], + relation_x_row_coeffs: [[E; 3]; 3], + tau0: E, + tau1: E, + batching_coeff: E, +} + +impl Stage2BivariateSkipState { + pub(crate) fn new( + proof: &Stage2BivariateSkipProof, + r_stage1: &[E], + s_claim: E, + relation_claim: E, + batching_coeff: E, + ) -> Option { + if r_stage1.len() < 2 { + return None; + } + let tau0 = r_stage1[0]; + let tau1 = r_stage1[1]; + let norm_full_grid = recover_stage2_norm_grid_from_claim( + &proof.norm, + stage2_norm_corner_weights_from_taus(tau0, tau1), + s_claim, + )?; + let relation_full_grid = + recover_stage2_relation_grid_from_claim(&proof.relation, relation_claim); + let norm_x_row_coeffs = std::array::from_fn(|y_idx| { + quadratic_coeffs_from_01_inf( + norm_full_grid[y_idx], + norm_full_grid[3 + y_idx], + norm_full_grid[6 + y_idx], + ) + }); + let relation_x_row_coeffs = std::array::from_fn(|y_idx| { + quadratic_coeffs_from_01_inf( + relation_full_grid[y_idx], + relation_full_grid[3 + y_idx], + relation_full_grid[6 + y_idx], + ) + }); + Some(Self { + norm_x_row_coeffs, + relation_x_row_coeffs, + tau0, + tau1, + batching_coeff, + }) + } +} + +impl Stage2BivariateSkipState { + #[inline] + pub(crate) fn reconstruct_round0_polys(&self) -> (UniPoly, UniPoly) { + let norm_q = add_quadratic_coeffs( + scale_quadratic_coeffs(self.norm_x_row_coeffs[0], E::one() - self.tau1), + scale_quadratic_coeffs(self.norm_x_row_coeffs[1], self.tau1), + ); + let mut norm_coeffs = mul_linear_by_quadratic_coeffs(self.tau0, norm_q); + for coeff in &mut norm_coeffs { + *coeff = self.batching_coeff * *coeff; + } + let relation_coeffs = + 
add_quadratic_coeffs(self.relation_x_row_coeffs[0], self.relation_x_row_coeffs[1]); + ( + UniPoly::from_coeffs(norm_coeffs.to_vec()), + UniPoly::from_coeffs(relation_coeffs.to_vec()), + ) + } + + #[inline] + pub(crate) fn reconstruct_round1_polys(&self, r0: E) -> (UniPoly, UniPoly) { + let norm_y_values: [E; 3] = std::array::from_fn(|y_idx| { + eval_quadratic_from_coeffs(self.norm_x_row_coeffs[y_idx], r0) + }); + let norm_q = + quadratic_coeffs_from_01_inf(norm_y_values[0], norm_y_values[1], norm_y_values[2]); + let round0_eq = Self::linear_eq_eval(self.tau0, r0); + let mut norm_coeffs = mul_linear_by_quadratic_coeffs(self.tau1, norm_q); + for coeff in &mut norm_coeffs { + *coeff = self.batching_coeff * round0_eq * *coeff; + } + let relation_y_values: [E; 3] = std::array::from_fn(|y_idx| { + eval_quadratic_from_coeffs(self.relation_x_row_coeffs[y_idx], r0) + }); + let relation_coeffs = quadratic_coeffs_from_01_inf( + relation_y_values[0], + relation_y_values[1], + relation_y_values[2], + ); + ( + UniPoly::from_coeffs(norm_coeffs.to_vec()), + UniPoly::from_coeffs(relation_coeffs.to_vec()), + ) + } + + #[inline] + fn linear_eq_eval(tau: E, x: E) -> E { + tau * x + (E::one() - tau) * (E::one() - x) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::algebra::Prime128M8M4M1M0; + use crate::protocol::sumcheck::hachi_stage1::HachiStage1Prover; + use crate::protocol::sumcheck::SumcheckInstanceProver; + use std::collections::HashMap; + + type F = Prime128M8M4M1M0; + + fn gaussian_rank(mut rows: Vec>) -> usize { + rows.retain(|row| row.iter().any(|x| !x.is_zero())); + if rows.is_empty() { + return 0; + } + + let num_cols = rows[0].len(); + let mut rank = 0usize; + let mut col = 0usize; + while rank < rows.len() && col < num_cols { + let Some(pivot_row) = (rank..rows.len()).find(|&r| !rows[r][col].is_zero()) else { + col += 1; + continue; + }; + rows.swap(rank, pivot_row); + let pivot_inv = rows[rank][col].inv().expect("pivot must be invertible"); + for entry in 
&mut rows[rank] { + *entry *= pivot_inv; + } + let pivot_snapshot = rows[rank].clone(); + for (row_idx, row) in rows.iter_mut().enumerate() { + if row_idx == rank || row[col].is_zero() { + continue; + } + let factor = row[col]; + for (entry, &pivot_entry) in row.iter_mut().zip(pivot_snapshot.iter()) { + *entry -= factor * pivot_entry; + } + } + rank += 1; + col += 1; + } + rank + } + + fn vec_key(vals: &[F]) -> String { + format!("{vals:?}") + } + + fn stage2_norm_round_values(w_quad: [F; 4], tau0: F, tau1: F, r0: F) -> Vec { + let l0 = |x: F| tau0 * x + (F::one() - tau0) * (F::one() - x); + let l1 = |y: F| tau1 * y + (F::one() - tau1) * (F::one() - y); + let q = |x: F, y: F| { + let w = bilinear_eval(w_quad, x, y); + w * (w + F::one()) + }; + + let mut out = Vec::new(); + for x in 0..=3u64 { + let x = F::from_u64(x); + out.push(l0(x) * (l1(F::zero()) * q(x, F::zero()) + l1(F::one()) * q(x, F::one()))); + } + for y in 0..=3u64 { + let y = F::from_u64(y); + out.push(l1(y) * l0(r0) * q(r0, y)); + } + out + } + + fn stage2_relation_round_values(w_quad: [F; 4], m_quad: [F; 4], r0: F) -> Vec { + let relation = |x: F, y: F| bilinear_eval(w_quad, x, y) * bilinear_eval(m_quad, x, y); + let mut out = Vec::new(); + for x in 0..=2u64 { + let x = F::from_u64(x); + out.push(relation(x, F::zero()) + relation(x, F::one())); + } + for y in 0..=2u64 { + let y = F::from_u64(y); + out.push(relation(r0, y)); + } + out + } + + fn stage2_norm_claim_from_full_grid(full_grid: [F; 9], corner_weights: [F; 4]) -> F { + BooleanCorner::ALL + .iter() + .copied() + .fold(F::zero(), |acc, corner| { + acc + corner_weights[corner.boolean_index()] * full_grid[corner.grid_index()] + }) + } + + fn stage2_relation_claim_from_full_grid(full_grid: [F; 9]) -> F { + stage2_norm_claim_from_full_grid(full_grid, [F::one(); 4]) + } + + fn stage2_norm_round_values_from_full_grid( + full_grid: [F; 9], + tau0: F, + tau1: F, + r0: F, + ) -> Vec { + let l0_at = |x: PrefixPoint| match x { + PrefixPoint::Finite(x) => 
tau0 * x + (F::one() - tau0) * (F::one() - x), + PrefixPoint::Infinity => tau0, + }; + let l1_0 = F::one() - tau1; + let l1_1 = tau1; + let mut out = Vec::new(); + for x in [F::zero(), F::one(), F::from_u64(2), F::from_u64(3)] { + let x_point = PrefixPoint::Finite(x); + let q_x0 = + eval_biquadratic_from_full_grid(full_grid, x_point, PrefixPoint::Finite(F::zero())); + let q_x1 = + eval_biquadratic_from_full_grid(full_grid, x_point, PrefixPoint::Finite(F::one())); + out.push(l0_at(x_point) * (l1_0 * q_x0 + l1_1 * q_x1)); + } + for y in [F::zero(), F::one(), F::from_u64(2), F::from_u64(3)] { + let y_point = PrefixPoint::Finite(y); + let q_r0_y = + eval_biquadratic_from_full_grid(full_grid, PrefixPoint::Finite(r0), y_point); + let l1_y = tau1 * y + (F::one() - tau1) * (F::one() - y); + out.push(l1_y * l0_at(PrefixPoint::Finite(r0)) * q_r0_y); + } + out + } + + fn stage2_relation_round_values_from_full_grid(full_grid: [F; 9], r0: F) -> Vec { + let mut out = Vec::new(); + for x in [F::zero(), F::one(), F::from_u64(2)] { + let q_x0 = eval_biquadratic_from_full_grid( + full_grid, + PrefixPoint::Finite(x), + PrefixPoint::Finite(F::zero()), + ); + let q_x1 = eval_biquadratic_from_full_grid( + full_grid, + PrefixPoint::Finite(x), + PrefixPoint::Finite(F::one()), + ); + out.push(q_x0 + q_x1); + } + for y in [F::zero(), F::one(), F::from_u64(2)] { + out.push(eval_biquadratic_from_full_grid( + full_grid, + PrefixPoint::Finite(r0), + PrefixPoint::Finite(y), + )); + } + out + } + + fn tensor_values( + xs: [PrefixPoint; NX], + ys: [PrefixPoint; NY], + mut eval: impl FnMut(PrefixPoint, PrefixPoint) -> E, + ) -> Vec { + let mut out = Vec::with_capacity(NX * NY); + for &x in &xs { + for &y in &ys { + out.push(eval(x, y)); + } + } + out + } + + fn stage1_norm_round_values(s_quad: [F; 4], tau0: F, tau1: F, r0: F, b: usize) -> Vec { + let l0 = |x: F| tau0 * x + (F::one() - tau0) * (F::one() - x); + let l1 = |y: F| tau1 * y + (F::one() - tau1) * (F::one() - y); + let q = |x: F, y: F| 
range_check_eval_from_s(bilinear_eval(s_quad, x, y), b); + + let mut out = Vec::new(); + for x in 0..=5u64 { + let x = F::from_u64(x); + out.push(l0(x) * (l1(F::zero()) * q(x, F::zero()) + l1(F::one()) * q(x, F::one()))); + } + for y in 0..=5u64 { + let y = F::from_u64(y); + out.push(l0(r0) * l1(y) * q(r0, y)); + } + out + } + + fn build_stage1_bivariate_skip_proof_from_compact_reference( + w_compact: &[i8], + tau0: &[F], + b: usize, + live_x_cols: usize, + num_u: usize, + num_l: usize, + ) -> Option> { + if !can_use_stage1_two_round_prefix(num_u, b) { + return None; + } + + let y_len = 1usize << num_l; + let eq_x_suffix = EqPolynomial::evals(&tau0[2..num_u]); + let eq_y = EqPolynomial::evals(&tau0[num_u..]); + let points = stage1_full_prefix_points::(); + let live_x_quads = live_x_cols.div_ceil(4); + let mut evals_except_boolean_core = Vec::with_capacity(STAGE1_PREFIX_EVAL_COUNT); + + for x_idx in 0..5 { + for y_idx in 0..5 { + if stage1_is_boolean_corner(x_idx, y_idx) { + continue; + } + let mut accum = F::zero(); + let x = points[x_idx]; + let y = points[y_idx]; + for y_row in 0..y_len { + let row = &w_compact[y_row * live_x_cols..(y_row + 1) * live_x_cols]; + let eq_y_weight = eq_y[y_row]; + for (x_quad, &eq_x_weight) in eq_x_suffix.iter().enumerate().take(live_x_quads) + { + let base = 4 * x_quad; + let s_quad = std::array::from_fn(|offset| { + let idx = base + offset; + if idx < live_x_cols { + let w = i64::from(row[idx]); + F::from_i64(w * (w + 1)) + } else { + F::zero() + } + }); + accum += + eq_x_weight * eq_y_weight * stage1_local_norm_raw_eval(s_quad, x, y, b); + } + } + evals_except_boolean_core.push(accum); + } + } + + Some(Stage1BivariateSkipProof { + evals_except_boolean_core, + }) + } + + fn build_stage2_bivariate_skip_proof_from_compact_reference( + w_compact: &[i8], + alpha_evals_y: &[F], + m_evals_x: &[F], + r_stage1: &[F], + live_x_cols: usize, + num_u: usize, + num_l: usize, + ) -> Option> { + if !can_use_stage2_two_round_prefix(num_u) { + 
return None; + } + + let y_len = 1usize << num_l; + assert_eq!(m_evals_x.len(), 1usize << num_u); + let eq_x_suffix = EqPolynomial::evals(&r_stage1[2..num_u]); + let eq_y = EqPolynomial::evals(&r_stage1[num_u..]); + let points = stage2_full_prefix_points::(); + let live_x_quads = live_x_cols.div_ceil(4); + let mut norm_full = [F::zero(); 9]; + let mut relation_full = [F::zero(); 9]; + + for y_idx in 0..y_len { + let row = &w_compact[y_idx * live_x_cols..(y_idx + 1) * live_x_cols]; + let alpha = alpha_evals_y[y_idx]; + let eq_y_weight = eq_y[y_idx]; + for (x_quad, &eq_x_weight) in eq_x_suffix.iter().enumerate().take(live_x_quads) { + let base = 4 * x_quad; + let w_quad = std::array::from_fn(|offset| { + let idx = base + offset; + if idx < live_x_cols { + F::from_i64(row[idx] as i64) + } else { + F::zero() + } + }); + let m_quad = std::array::from_fn(|offset| { + let idx = base + offset; + m_evals_x[idx] + }); + let norm_weight = eq_x_weight * eq_y_weight; + for idx in 0..9 { + let x = points[idx / 3]; + let y = points[idx % 3]; + norm_full[idx] += norm_weight * stage2_local_norm_raw_eval(w_quad, x, y); + relation_full[idx] += stage2_local_relation_eval(w_quad, m_quad, alpha, x, y); + } + } + } + + let norm_omitted_corner = default_stage2_norm_omitted_corner( + stage2_norm_corner_weights_from_taus(r_stage1[0], r_stage1[1]), + ); + Some(Stage2BivariateSkipProof { + norm: Stage2CompressedGrid::from_full_grid(norm_full, norm_omitted_corner), + relation: Stage2CompressedGrid::from_full_grid( + relation_full, + BooleanCorner::DEFAULT_STAGE2_RELATION, + ), + }) + } + + #[test] + fn stage1_b8_lookup_table_matches_raw_evals() { + let points = stage1_full_prefix_points::(); + for (d0, &s00) in STAGE1_B8_S_VALUES.iter().enumerate() { + for (d1, &s10) in STAGE1_B8_S_VALUES.iter().enumerate() { + for (d2, &s01) in STAGE1_B8_S_VALUES.iter().enumerate() { + for (d3, &s11) in STAGE1_B8_S_VALUES.iter().enumerate() { + let lookup = &STAGE1_B8_PREFIX_LOOKUP_TABLE + 
[stage1_b8_lookup_index_from_digits([d0, d1, d2, d3])]; + let quad = [ + F::from_i64(s00), + F::from_i64(s10), + F::from_i64(s01), + F::from_i64(s11), + ]; + let mut point_idx = 0usize; + for x_idx in 0..5 { + for y_idx in 0..5 { + if stage1_is_boolean_corner(x_idx, y_idx) { + continue; + } + assert_eq!( + F::from_i64(lookup[point_idx]), + stage1_local_norm_raw_eval( + quad, + points[x_idx], + points[y_idx], + 8, + ), + ); + point_idx += 1; + } + } + } + } + } + } + } + + #[test] + fn stage2_b8_norm_lookup_table_matches_raw_evals() { + let points = stage2_full_prefix_points::(); + for w00 in -4i64..=3 { + for w10 in -4i64..=3 { + for w01 in -4i64..=3 { + for w11 in -4i64..=3 { + let lookup = &STAGE2_B8_NORM_LOOKUP_TABLE + [stage2_b8_lookup_index_from_digits([ + (w00 + 4) as usize, + (w10 + 4) as usize, + (w01 + 4) as usize, + (w11 + 4) as usize, + ])]; + let quad = [ + F::from_i64(w00), + F::from_i64(w10), + F::from_i64(w01), + F::from_i64(w11), + ]; + for point_idx in 0..STAGE2_PREFIX_POINT_COUNT { + let x = points[point_idx / 3]; + let y = points[point_idx % 3]; + assert_eq!( + F::from_i64(lookup[point_idx]), + stage2_local_norm_raw_eval(quad, x, y), + ); + } + } + } + } + } + } + + #[test] + fn stage2_b8_relation_weight_table_matches_prefix_w_evals() { + let points = stage2_full_prefix_points::(); + for w00 in -4i64..=3 { + for w10 in -4i64..=3 { + for w01 in -4i64..=3 { + for w11 in -4i64..=3 { + let lookup = &STAGE2_B8_RELATION_WEIGHT_TABLE + [stage2_b8_lookup_index_from_digits([ + (w00 + 4) as usize, + (w10 + 4) as usize, + (w01 + 4) as usize, + (w11 + 4) as usize, + ])]; + let quad = [ + F::from_i64(w00), + F::from_i64(w10), + F::from_i64(w01), + F::from_i64(w11), + ]; + for point_idx in 0..STAGE2_PREFIX_POINT_COUNT { + let x = points[point_idx / 3]; + let y = points[point_idx % 3]; + assert_eq!( + F::from_i64(lookup[point_idx]), + bilinear_eval_on_prefix_points(quad, x, y), + ); + } + } + } + } + } + } + + #[test] + fn 
stage1_bivariate_skip_proof_builder_matches_reference() { + let w_compact = vec![1, -2, 0, 2, 1, -1, 2, 1, 0, 2]; + let tau0 = [ + F::from_u64(3), + F::from_u64(5), + F::from_u64(7), + F::from_u64(11), + ]; + assert_eq!( + build_stage1_bivariate_skip_proof_from_compact(&w_compact, &tau0, 8, 5, 3, 1), + build_stage1_bivariate_skip_proof_from_compact_reference(&w_compact, &tau0, 8, 5, 3, 1), + ); + } + + #[test] + fn stage2_bivariate_skip_proof_builder_matches_reference() { + let w_compact = vec![1, -2, 0, 2, 1, -1, 2, 1, 0, 2]; + let alpha_evals_y = [F::from_u64(3), F::from_u64(5)]; + let m_evals_x = [ + F::from_u64(7), + F::from_u64(11), + F::from_u64(13), + F::from_u64(17), + F::from_u64(19), + F::from_u64(23), + F::from_u64(29), + F::from_u64(31), + ]; + let r_stage1 = [ + F::from_u64(3), + F::from_u64(5), + F::from_u64(7), + F::from_u64(11), + ]; + assert_eq!( + build_stage2_bivariate_skip_proof_from_compact( + &w_compact, + &alpha_evals_y, + &m_evals_x, + &r_stage1, + 5, + 3, + 1, + ), + build_stage2_bivariate_skip_proof_from_compact_reference( + &w_compact, + &alpha_evals_y, + &m_evals_x, + &r_stage1, + 5, + 3, + 1, + ), + ); + } + + #[test] + fn stage1_candidate_omits_11_via_zero_check() { + let points = stage1_prefix_points::(); + let one = points[0]; + let valid_s = [0i64, 2, 6, 12]; + for &s00 in &valid_s { + for &s10 in &valid_s { + for &s01 in &valid_s { + for &s11 in &valid_s { + let quad = [ + F::from_i64(s00), + F::from_i64(s10), + F::from_i64(s01), + F::from_i64(s11), + ]; + assert_eq!( + stage1_local_norm_eval(quad, one, one, 8), + F::zero(), + "stage1 local zero-check should vanish at (1,1)" + ); + } + } + } + } + } + + #[test] + fn stage1_candidate_storage_family_has_rank_15() { + let [one, neg_one, two, inf] = stage1_prefix_points::(); + let storage_points = [ + (one, neg_one), + (one, two), + (one, inf), + (neg_one, one), + (neg_one, neg_one), + (neg_one, two), + (neg_one, inf), + (two, one), + (two, neg_one), + (two, two), + (two, inf), + (inf, 
one), + (inf, neg_one), + (inf, two), + (inf, inf), + ]; + let valid_s = [0i64, 2, 6, 12]; + let mut rows = Vec::new(); + for &s00 in &valid_s { + for &s10 in &valid_s { + for &s01 in &valid_s { + for &s11 in &valid_s { + let quad = [ + F::from_i64(s00), + F::from_i64(s10), + F::from_i64(s01), + F::from_i64(s11), + ]; + rows.push( + storage_points + .iter() + .map(|&(x, y)| stage1_local_norm_eval(quad, x, y, 8)) + .collect(), + ); + } + } + } + } + assert_eq!(gaussian_rank(rows), 15); + } + + #[test] + fn stage1_full_domain_omits_boolean_core_via_zero_check() { + let points = stage1_full_prefix_points::(); + let valid_s = [0i64, 2, 6, 12]; + for &s00 in &valid_s { + for &s10 in &valid_s { + for &s01 in &valid_s { + for &s11 in &valid_s { + let quad = [ + F::from_i64(s00), + F::from_i64(s10), + F::from_i64(s01), + F::from_i64(s11), + ]; + for &(x_idx, y_idx) in &[(0usize, 0usize), (0, 1), (1, 0), (1, 1)] { + assert_eq!( + stage1_local_norm_raw_eval(quad, points[x_idx], points[y_idx], 8), + F::zero(), + "stage1 local zero-check should vanish on the Boolean core", + ); + } + } + } + } + } + } + + #[test] + fn stage1_full_storage_family_has_rank_21() { + let points = stage1_full_prefix_points::(); + let mut storage_points = Vec::new(); + for x_idx in 0..5 { + for y_idx in 0..5 { + if stage1_is_boolean_corner(x_idx, y_idx) { + continue; + } + storage_points.push((points[x_idx], points[y_idx])); + } + } + + let valid_s = [0i64, 2, 6, 12]; + let mut rows = Vec::new(); + for &s00 in &valid_s { + for &s10 in &valid_s { + for &s01 in &valid_s { + for &s11 in &valid_s { + let quad = [ + F::from_i64(s00), + F::from_i64(s10), + F::from_i64(s01), + F::from_i64(s11), + ]; + rows.push( + storage_points + .iter() + .map(|&(x, y)| stage1_local_norm_raw_eval(quad, x, y, 8)) + .collect(), + ); + } + } + } + } + assert_eq!(gaussian_rank(rows), 21); + } + + #[test] + fn stage1_storage_domain_matches_local_round_messages() { + let tau0 = F::from_u64(7); + let tau1 = F::from_u64(11); + 
let r0 = F::from_u64(13); + let valid_s = [0i64, 2, 6, 12]; + + for &s00 in &valid_s { + for &s10 in &valid_s { + for &s01 in &valid_s { + for &s11 in &valid_s { + let quad = [ + F::from_i64(s00), + F::from_i64(s10), + F::from_i64(s01), + F::from_i64(s11), + ]; + let proof = Stage1BivariateSkipProof { + evals_except_boolean_core: stage1_storage_vector_from_quad(quad, 8), + }; + let skip_state = Stage1BivariateSkipState::new(&proof, &[tau0, tau1], 8) + .expect("stage1 bivariate-skip state should build"); + let round_values = stage1_norm_round_values(quad, tau0, tau1, r0, 8); + assert_eq!( + skip_state.reconstruct_round0_poly(), + UniPoly::from_evals(&round_values[..6]) + ); + assert_eq!( + skip_state.reconstruct_round1_poly(r0), + UniPoly::from_evals(&round_values[6..]) + ); + } + } + } + } + } + + #[test] + fn stage1_bivariate_skip_proof_reconstructs_first_two_rounds() { + let w_compact = vec![1, -2, 0, 2, 1, -1, 2, 1, 0, 2]; + let tau0 = [ + F::from_u64(3), + F::from_u64(5), + F::from_u64(7), + F::from_u64(11), + ]; + let b = 8; + let live_x_cols = 5; + let num_u = 3; + let num_l = 1; + + let proof = build_stage1_bivariate_skip_proof_from_compact( + &w_compact, + &tau0, + b, + live_x_cols, + num_u, + num_l, + ) + .expect("stage1 bivariate-skip proof should be available"); + let skip_state = Stage1BivariateSkipState::new(&proof, &tau0, b) + .expect("stage1 bivariate-skip state should build"); + + let mut prover = + HachiStage1Prover::::new(&w_compact, &tau0, b, live_x_cols, num_u, num_l); + let round0 = SumcheckInstanceProver::compute_round_univariate(&mut prover, 0, F::zero()); + assert_eq!(skip_state.reconstruct_round0_poly(), round0); + + let r0 = F::from_u64(9); + let claim_after_r0 = round0.evaluate(&r0); + SumcheckInstanceProver::ingest_challenge(&mut prover, 0, r0); + + let round1 = + SumcheckInstanceProver::compute_round_univariate(&mut prover, 1, claim_after_r0); + assert_eq!(skip_state.reconstruct_round1_poly(r0), round1); + } + + #[test] + fn 
stage2_default_norm_omitted_corner_prefers_00() { + let weights = stage2_norm_corner_weights_from_taus(F::from_u64(7), F::from_u64(11)); + assert_eq!( + default_stage2_norm_omitted_corner(weights), + BooleanCorner::DEFAULT_STAGE2_NORM + ); + } + + #[test] + fn stage2_default_norm_omitted_corner_falls_back_when_00_is_zero() { + let weights = stage2_norm_corner_weights_from_taus(F::one(), F::from_u64(11)); + assert_eq!( + default_stage2_norm_omitted_corner(weights), + BooleanCorner::OneZero + ); + + let weights = stage2_norm_corner_weights_from_taus(F::from_u64(7), F::one()); + assert_eq!( + default_stage2_norm_omitted_corner(weights), + BooleanCorner::ZeroOne + ); + + let weights = stage2_norm_corner_weights_from_taus(F::one(), F::one()); + assert_eq!( + default_stage2_norm_omitted_corner(weights), + BooleanCorner::OneOne + ); + } + + #[test] + fn stage2_norm_reduced_domain_has_round_message_collision() { + let reduced = stage2_reduced_prefix_points::(); + let tau0 = F::from_u64(7); + let tau1 = F::from_u64(11); + let r0 = F::from_u64(13); + + let mut seen: HashMap> = HashMap::new(); + let mut found_collision = false; + for w00 in -4i64..=3 { + for w10 in -4i64..=3 { + for w01 in -4i64..=3 { + for w11 in -4i64..=3 { + let quad = [ + F::from_i64(w00), + F::from_i64(w10), + F::from_i64(w01), + F::from_i64(w11), + ]; + let storage = tensor_values(reduced, reduced, |x, y| { + stage2_local_norm_candidate_eval(quad, x, y) + }); + let target = stage2_norm_round_values(quad, tau0, tau1, r0); + let key = vec_key(&storage); + if let Some(existing) = seen.get(&key) { + if *existing != target { + found_collision = true; + break; + } + } else { + seen.insert(key, target); + } + } + if found_collision { + break; + } + } + if found_collision { + break; + } + } + if found_collision { + break; + } + } + assert!( + found_collision, + "reduced stage-2 norm domain should not uniquely determine local round messages" + ); + } + + #[test] + fn 
stage2_relation_reduced_domain_has_round_message_collision() { + let reduced = stage2_reduced_prefix_points::(); + let r0 = F::from_u64(13); + let alpha = F::one(); + let bit = [F::zero(), F::one()]; + + let mut seen: HashMap> = HashMap::new(); + let mut found_collision = false; + for &w00 in &bit { + for &w10 in &bit { + for &w01 in &bit { + for &w11 in &bit { + let w_quad = [w00, w10, w01, w11]; + for &m00 in &bit { + for &m10 in &bit { + for &m01 in &bit { + for &m11 in &bit { + let m_quad = [m00, m10, m01, m11]; + let storage = tensor_values(reduced, reduced, |x, y| { + stage2_local_relation_eval(w_quad, m_quad, alpha, x, y) + }); + let target = + stage2_relation_round_values(w_quad, m_quad, r0); + let key = vec_key(&storage); + if let Some(existing) = seen.get(&key) { + if *existing != target { + found_collision = true; + break; + } + } else { + seen.insert(key, target); + } + } + if found_collision { + break; + } + } + if found_collision { + break; + } + } + if found_collision { + break; + } + } + if found_collision { + break; + } + } + if found_collision { + break; + } + } + if found_collision { + break; + } + } + if found_collision { + break; + } + } + assert!( + found_collision, + "reduced stage-2 relation domain should not uniquely determine local round messages" + ); + } + + #[test] + fn stage2_norm_full_domain_matches_local_round_messages() { + let full = stage2_full_prefix_points::(); + let tau0 = F::from_u64(7); + let tau1 = F::from_u64(11); + let r0 = F::from_u64(13); + + let mut seen: HashMap> = HashMap::new(); + for w00 in -4i64..=3 { + for w10 in -4i64..=3 { + for w01 in -4i64..=3 { + for w11 in -4i64..=3 { + let quad = [ + F::from_i64(w00), + F::from_i64(w10), + F::from_i64(w01), + F::from_i64(w11), + ]; + let storage = tensor_values(full, full, |x, y| { + stage2_local_norm_raw_eval(quad, x, y) + }); + let target = stage2_norm_round_values(quad, tau0, tau1, r0); + let key = vec_key(&storage); + if let Some(existing) = seen.get(&key) { + 
assert_eq!( + existing, &target, + "full stage-2 norm domain lost information for a compact quad" + ); + } else { + seen.insert(key, target); + } + } + } + } + } + } + + #[test] + fn stage2_norm_8_point_reconstruction_matches_full_grid_and_round_messages() { + let tau_choices = [F::zero(), F::one(), F::from_u64(2), F::from_u64(7)]; + let r0 = F::from_u64(13); + + for &tau0 in &tau_choices { + for &tau1 in &tau_choices { + let corner_weights = stage2_norm_corner_weights_from_taus(tau0, tau1); + for w00 in -4i64..=3 { + for w10 in -4i64..=3 { + for w01 in -4i64..=3 { + for w11 in -4i64..=3 { + let quad = [ + F::from_i64(w00), + F::from_i64(w10), + F::from_i64(w01), + F::from_i64(w11), + ]; + let full_grid = stage2_full_grid_values(|x, y| { + stage2_local_norm_raw_eval(quad, x, y) + }); + let norm_claim = + stage2_norm_claim_from_full_grid(full_grid, corner_weights); + let omitted_corner = + default_stage2_norm_omitted_corner(corner_weights); + let compressed = + Stage2CompressedGrid::from_full_grid(full_grid, omitted_corner); + let recovered = recover_stage2_norm_grid_from_claim( + &compressed, + corner_weights, + norm_claim, + ) + .expect("selected norm corner should be recoverable"); + + assert_eq!( + recovered, full_grid, + "norm full-grid reconstruction mismatch for quad={quad:?}, tau0={tau0:?}, tau1={tau1:?}" + ); + assert_eq!( + stage2_norm_round_values_from_full_grid(recovered, tau0, tau1, r0), + stage2_norm_round_values(quad, tau0, tau1, r0), + "norm round reconstruction mismatch for quad={quad:?}, tau0={tau0:?}, tau1={tau1:?}" + ); + } + } + } + } + } + } + } + + #[test] + fn stage2_relation_full_domain_matches_local_round_messages() { + let full = stage2_full_prefix_points::(); + let r0 = F::from_u64(13); + let alpha = F::one(); + let bit = [F::zero(), F::one()]; + + let mut seen: HashMap> = HashMap::new(); + for &w00 in &bit { + for &w10 in &bit { + for &w01 in &bit { + for &w11 in &bit { + let w_quad = [w00, w10, w01, w11]; + for &m00 in &bit { + for &m10 
in &bit { + for &m01 in &bit { + for &m11 in &bit { + let m_quad = [m00, m10, m01, m11]; + let storage = tensor_values(full, full, |x, y| { + stage2_local_relation_eval(w_quad, m_quad, alpha, x, y) + }); + let target = + stage2_relation_round_values(w_quad, m_quad, r0); + let key = vec_key(&storage); + if let Some(existing) = seen.get(&key) { + assert_eq!( + existing, &target, + "full stage-2 relation domain lost information" + ); + } else { + seen.insert(key, target); + } + } + } + } + } + } + } + } + } + } + + #[test] + fn stage2_relation_8_point_reconstruction_matches_full_grid_and_round_messages() { + let r0 = F::from_u64(13); + let alpha_choices = [F::zero(), F::one(), F::from_u64(3)]; + let bit = [F::zero(), F::one()]; + + for &alpha in &alpha_choices { + for &w00 in &bit { + for &w10 in &bit { + for &w01 in &bit { + for &w11 in &bit { + let w_quad = [w00, w10, w01, w11]; + for &m00 in &bit { + for &m10 in &bit { + for &m01 in &bit { + for &m11 in &bit { + let m_quad = [m00, m10, m01, m11]; + let full_grid = stage2_full_grid_values(|x, y| { + stage2_local_relation_eval( + w_quad, m_quad, alpha, x, y, + ) + }); + let relation_claim = + stage2_relation_claim_from_full_grid(full_grid); + let compressed = Stage2CompressedGrid::from_full_grid( + full_grid, + BooleanCorner::DEFAULT_STAGE2_RELATION, + ); + let recovered = recover_stage2_relation_grid_from_claim( + &compressed, + relation_claim, + ); + + assert_eq!( + recovered, full_grid, + "relation full-grid reconstruction mismatch" + ); + assert_eq!( + stage2_relation_round_values_from_full_grid( + recovered, r0 + ), + stage2_relation_round_values(w_quad, m_quad, r0) + .into_iter() + .map(|value| alpha * value) + .collect::>(), + "relation round reconstruction mismatch" + ); + } + } + } + } + } + } + } + } + } + } +} diff --git a/src/protocol/sumcheck/types.rs b/src/protocol/sumcheck/types.rs new file mode 100644 index 00000000..7b4955a2 --- /dev/null +++ b/src/protocol/sumcheck/types.rs @@ -0,0 +1,378 @@ +//! 
Sumcheck data types: univariate polynomials, compressed representation, and proof container. + +use crate::error::HachiError; +use crate::primitives::serialization::{ + Compress, HachiDeserialize, HachiSerialize, SerializationError, Valid, Validate, +}; +use crate::protocol::transcript::labels; +use crate::protocol::transcript::Transcript; +use crate::FieldCore; +use crate::FromSmallInt; +use std::io::{Read, Write}; + +/// Univariate polynomial in coefficient form: `p(X) = Σ_{i=0}^d coeffs[i] * X^i`. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct UniPoly { + /// Coefficients from low degree to high degree. + pub coeffs: Vec, +} + +impl UniPoly { + /// Construct from coefficients in increasing-degree order. + pub fn from_coeffs(coeffs: Vec) -> Self { + Self { coeffs } + } + + /// Degree of the polynomial (0 for empty or constant). + pub fn degree(&self) -> usize { + self.coeffs.len().saturating_sub(1) + } + + /// Evaluate at `x` via Horner's method. + pub fn evaluate(&self, x: &E) -> E { + let mut acc = E::zero(); + for c in self.coeffs.iter().rev() { + acc = acc * *x + *c; + } + acc + } + + /// Compress this polynomial by omitting the linear coefficient. + /// + /// The verifier can reconstruct/evaluate the missing linear coefficient using + /// the per-round hint `g(0)+g(1)` from the sumcheck protocol. + /// + /// This matches the technique used by Jolt's sumcheck (`CompressedUniPoly`). 
+ pub fn compress(&self) -> CompressedUniPoly { + let coeffs = &self.coeffs; + if coeffs.is_empty() { + return CompressedUniPoly { + coeffs_except_linear_term: Vec::new(), + }; + } + if coeffs.len() == 1 { + return CompressedUniPoly { + coeffs_except_linear_term: vec![coeffs[0]], + }; + } + let mut out = Vec::with_capacity(coeffs.len().saturating_sub(1)); + out.push(coeffs[0]); + out.extend_from_slice(&coeffs[2..]); + CompressedUniPoly { + coeffs_except_linear_term: out, + } + } +} + +impl UniPoly { + /// Interpolate from evaluations at equispaced integer points `x = 0, 1, ..., d`. + /// + /// Uses Newton forward-difference interpolation: compute divided differences, + /// then expand via Horner on the nested Newton form. + /// + /// # Panics + /// + /// Panics if any required factorial inverse does not exist (field characteristic + /// must exceed the number of evaluation points). This is a prover-only + /// function and the condition always holds for Hachi's fields. + pub fn from_evals(evals: &[E]) -> Self { + let n = evals.len(); + if n == 0 { + return Self::from_coeffs(vec![]); + } + if n == 1 { + return Self::from_coeffs(vec![evals[0]]); + } + + let mut table = evals.to_vec(); + let mut deltas = vec![table[0]]; + for _ in 1..n { + for j in 0..table.len() - 1 { + table[j] = table[j + 1] - table[j]; + } + table.pop(); + deltas.push(table[0]); + } + + let mut factorial = E::one(); + let mut divided_diffs = vec![deltas[0]]; + for (k, delta_k) in deltas.iter().enumerate().skip(1) { + factorial = factorial * E::from_u64(k as u64); + divided_diffs.push( + *delta_k + * factorial + .inv() + .expect("field characteristic too small for interpolation"), + ); + } + + let mut coeffs = vec![divided_diffs[n - 1]]; + + for k in (0..n - 1).rev() { + let shift = E::from_u64(k as u64); + let old_len = coeffs.len(); + let mut new_coeffs = vec![E::zero(); old_len + 1]; + + new_coeffs[0] = divided_diffs[k]; + for i in 0..old_len { + new_coeffs[i + 1] += coeffs[i]; + new_coeffs[i] -= 
shift * coeffs[i]; + } + + coeffs = new_coeffs; + } + + while coeffs.len() > 1 && coeffs.last().is_some_and(|c| c.is_zero()) { + coeffs.pop(); + } + + Self::from_coeffs(coeffs) + } +} + +impl Valid for UniPoly { + fn check(&self) -> Result<(), SerializationError> { + self.coeffs.check() + } +} + +impl HachiSerialize for UniPoly { + fn serialize_with_mode( + &self, + mut writer: W, + compress: Compress, + ) -> Result<(), SerializationError> { + self.coeffs.serialize_with_mode(&mut writer, compress) + } + + fn serialized_size(&self, compress: Compress) -> usize { + self.coeffs.serialized_size(compress) + } +} + +impl HachiDeserialize for UniPoly { + fn deserialize_with_mode( + mut reader: R, + compress: Compress, + validate: Validate, + ) -> Result { + let coeffs = Vec::::deserialize_with_mode(&mut reader, compress, validate)?; + let out = Self { coeffs }; + if matches!(validate, Validate::Yes) { + out.check()?; + } + Ok(out) + } +} + +/// Compressed univariate polynomial representation omitting the linear term. +/// +/// We store `[c0, c2, c3, ..., cd]`. Given the sumcheck hint `hint = g(0)+g(1)`, +/// the missing linear coefficient is: +/// +/// `c1 = hint - 2*c0 - Σ_{i=2..d} ci`. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct CompressedUniPoly { + /// Coefficients excluding the linear term: `[c0, c2, c3, ..., cd]`. + pub coeffs_except_linear_term: Vec, +} + +impl CompressedUniPoly { + /// Degree of the underlying uncompressed polynomial. + /// + /// `compress()` stores `[c0, c2, ..., cd]` — exactly `d` entries for + /// degree `d >= 2`. For `len <= 1` (degree 0 or 1, which are ambiguous + /// in compressed form) we report 0; this is conservative for the + /// verifier's degree-bound check since `degree_bound >= 2` in practice. 
+ pub fn degree(&self) -> usize { + let len = self.coeffs_except_linear_term.len(); + if len <= 1 { + 0 + } else { + len + } + } + + fn recover_linear_term(&self, hint: &E) -> E { + if self.coeffs_except_linear_term.is_empty() { + return E::zero(); + } + + let c0 = self.coeffs_except_linear_term[0]; + let mut linear = *hint - c0 - c0; + for c in self.coeffs_except_linear_term.iter().skip(1) { + linear -= *c; + } + linear + } + + /// Decompress using `hint = g(0)+g(1)`. + pub fn decompress(&self, hint: &E) -> UniPoly { + if self.coeffs_except_linear_term.is_empty() { + return UniPoly::from_coeffs(Vec::new()); + } + let linear = self.recover_linear_term(hint); + let mut coeffs = Vec::with_capacity(self.coeffs_except_linear_term.len() + 1); + coeffs.push(self.coeffs_except_linear_term[0]); + coeffs.push(linear); + coeffs.extend_from_slice(&self.coeffs_except_linear_term[1..]); + UniPoly::from_coeffs(coeffs) + } + + /// Evaluate the uncompressed polynomial at `x`, using `hint = g(0)+g(1)`. + /// + /// This avoids materializing the full coefficient list. 
+ pub fn eval_from_hint(&self, hint: &E, x: &E) -> E { + if self.coeffs_except_linear_term.is_empty() { + return E::zero(); + } + + let linear = self.recover_linear_term(hint); + let mut acc = self.coeffs_except_linear_term[0] + (*x * linear); + + let mut pow = *x * *x; + for c in self.coeffs_except_linear_term.iter().skip(1) { + acc += *c * pow; + pow = pow * *x; + } + acc + } +} + +impl Valid for CompressedUniPoly { + fn check(&self) -> Result<(), SerializationError> { + self.coeffs_except_linear_term.check() + } +} + +impl HachiSerialize for CompressedUniPoly { + fn serialize_with_mode( + &self, + mut writer: W, + compress: Compress, + ) -> Result<(), SerializationError> { + self.coeffs_except_linear_term + .serialize_with_mode(&mut writer, compress) + } + + fn serialized_size(&self, compress: Compress) -> usize { + self.coeffs_except_linear_term.serialized_size(compress) + } +} + +impl HachiDeserialize for CompressedUniPoly { + fn deserialize_with_mode( + mut reader: R, + compress: Compress, + validate: Validate, + ) -> Result { + let coeffs_except_linear_term = + Vec::::deserialize_with_mode(&mut reader, compress, validate)?; + let out = Self { + coeffs_except_linear_term, + }; + if matches!(validate, Validate::Yes) { + out.check()?; + } + Ok(out) + } +} + +/// Sumcheck proof containing one compressed univariate polynomial per round. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct SumcheckProof { + /// One compressed univariate polynomial per sumcheck round. 
+ pub round_polys: Vec>, +} + +impl Valid for SumcheckProof { + fn check(&self) -> Result<(), SerializationError> { + self.round_polys.check() + } +} + +impl HachiSerialize for SumcheckProof { + fn serialize_with_mode( + &self, + mut writer: W, + compress: Compress, + ) -> Result<(), SerializationError> { + self.round_polys.serialize_with_mode(&mut writer, compress) + } + + fn serialized_size(&self, compress: Compress) -> usize { + self.round_polys.serialized_size(compress) + } +} + +impl HachiDeserialize for SumcheckProof { + fn deserialize_with_mode( + mut reader: R, + compress: Compress, + validate: Validate, + ) -> Result { + let round_polys = + Vec::>::deserialize_with_mode(&mut reader, compress, validate)?; + let out = Self { round_polys }; + if matches!(validate, Validate::Yes) { + out.check()?; + } + Ok(out) + } +} + +impl SumcheckProof { + /// Verifier-side sumcheck transcript driver. + /// + /// This method: + /// - absorbs the per-round prover message (compressed univariate), + /// - samples one challenge per round via `sample_challenge`, + /// - updates the running claim using `eval_from_hint`. + /// + /// It does **not** perform the final oracle check `final_claim == f(r*)`. + /// Callers (e.g. ring-switching) must compute `f(r*)` themselves and compare. + /// + /// # Errors + /// + /// Returns an error if the proof length does not match `num_rounds` or if any + /// per-round polynomial exceeds `degree_bound`. 
+ pub fn verify( + &self, + mut claim: E, + num_rounds: usize, + degree_bound: usize, + transcript: &mut T, + mut sample_challenge: S, + ) -> Result<(E, Vec), HachiError> + where + F: crate::FieldCore + crate::CanonicalField, + T: Transcript, + S: FnMut(&mut T) -> E, + { + if self.round_polys.len() != num_rounds { + return Err(HachiError::InvalidSize { + expected: num_rounds, + actual: self.round_polys.len(), + }); + } + + let mut r = Vec::with_capacity(num_rounds); + for poly in &self.round_polys { + if poly.degree() > degree_bound { + return Err(HachiError::InvalidInput(format!( + "sumcheck round poly degree {} exceeds bound {}", + poly.degree(), + degree_bound + ))); + } + + transcript.append_serde(labels::ABSORB_SUMCHECK_ROUND, poly); + let r_i = sample_challenge(transcript); + r.push(r_i); + + claim = poly.eval_from_hint(&claim, &r_i); + } + + Ok((claim, r)) + } +} diff --git a/src/protocol/transcript/hash.rs b/src/protocol/transcript/hash.rs new file mode 100644 index 00000000..40312d0b --- /dev/null +++ b/src/protocol/transcript/hash.rs @@ -0,0 +1,113 @@ +//! Generic hash-based transcript for protocol-layer Fiat-Shamir. +//! +//! Parameterised over any `Digest + Clone` hasher, eliminating the +//! near-identical Blake2b and Keccak implementations. + +use super::Transcript; +use crate::primitives::serialization::HachiSerialize; +use crate::{CanonicalField, FieldCore}; +use blake2::{Blake2b512, Digest}; +use sha3::Keccak256; +use std::marker::PhantomData; + +/// Hash-based transcript with labeled framing. +/// +/// Works with any cryptographic hash that implements `Digest + Clone`. 
+#[derive(Clone)] +pub struct HashTranscript +where + F: FieldCore + CanonicalField + 'static, +{ + hasher: D, + _field: PhantomData, +} + +impl HashTranscript +where + F: FieldCore + CanonicalField + 'static, +{ + #[inline] + fn append_bytes_impl(&mut self, label: &[u8], bytes: &[u8]) { + self.hasher.update(label); + self.hasher.update((bytes.len() as u64).to_le_bytes()); + self.hasher.update(bytes); + } + + #[inline] + fn challenge_and_chain(&mut self, label: &[u8]) -> Vec { + self.hasher.update(label); + let digest = self.hasher.clone().finalize(); + let out = digest.to_vec(); + self.hasher.update(&out); + out + } +} + +impl Transcript for HashTranscript +where + F: FieldCore + CanonicalField + 'static, +{ + fn new(domain_label: &[u8]) -> Self { + let mut hasher = D::new(); + hasher.update(domain_label); + Self { + hasher, + _field: PhantomData, + } + } + + fn append_bytes(&mut self, label: &[u8], bytes: &[u8]) { + self.append_bytes_impl(label, bytes); + } + + fn append_field(&mut self, label: &[u8], x: &F) { + self.append_bytes_impl(label, &x.to_canonical_u128().to_le_bytes()); + } + + fn append_serde(&mut self, label: &[u8], s: &S) { + let mut bytes = Vec::new(); + s.serialize_compressed(&mut bytes) + .expect("HachiSerialize should not fail"); + self.append_bytes_impl(label, &bytes); + } + + fn challenge_scalar(&mut self, label: &[u8]) -> F { + let bytes = self.challenge_and_chain(label); + let mut lo = [0u8; 16]; + lo.copy_from_slice(&bytes[..16]); + let sampled = u128::from_le_bytes(lo); + F::from_canonical_u128_reduced(sampled) + } + + fn challenge_bytes(&mut self, label: &[u8], len: usize) -> Vec { + let mut out = Vec::with_capacity(len); + while out.len() < len { + let chunk = self.challenge_and_chain(label); + let take = (len - out.len()).min(chunk.len()); + out.extend_from_slice(&chunk[..take]); + } + out + } +} + +impl HashTranscript +where + F: FieldCore + CanonicalField + 'static, +{ + /// Reset transcript state under a new domain label. 
+ /// + /// This is an inherent method (not part of the `Transcript` trait) to + /// discourage use in production protocol code where resetting the + /// Fiat-Shamir chain would be unsound. + pub fn reset(&mut self, domain_label: &[u8]) { + let mut hasher = D::new(); + hasher.update(domain_label); + self.hasher = hasher; + } +} + +/// Blake2b512 transcript with labeled framing. +pub type Blake2bTranscript = HashTranscript; + +/// Keccak256 transcript with labeled framing. +pub type KeccakTranscript = HashTranscript; diff --git a/src/protocol/transcript/labels.rs b/src/protocol/transcript/labels.rs new file mode 100644 index 00000000..67ac04ef --- /dev/null +++ b/src/protocol/transcript/labels.rs @@ -0,0 +1,132 @@ +//! Hachi-native transcript labels. +//! +//! These constants are the single source of truth for protocol transcript +//! labels in Hachi core. External integrations should translate at adapter +//! boundaries instead of introducing foreign label names here. + +/// Top-level protocol domain label. +pub const DOMAIN_HACHI_PROTOCOL: &[u8] = b"hachi/protocol"; + +/// Absorb commitment object(s) (paper §4.1). +pub const ABSORB_COMMITMENT: &[u8] = b"hachi/absorb/commitment"; +/// Absorb claimed openings/evaluations before relation reduction (paper §4.2). +pub const ABSORB_EVALUATION_CLAIMS: &[u8] = b"hachi/absorb/evaluation-claims"; +/// Challenge for the evaluation-to-linear-relation reduction (paper §4.2). +pub const CHALLENGE_LINEAR_RELATION: &[u8] = b"hachi/challenge/linear-relation"; +/// Absorb ring-switch relation messages (paper §4.3). +pub const ABSORB_RING_SWITCH_MESSAGE: &[u8] = b"hachi/absorb/ring-switch-message"; +/// Challenge used by ring-switching relation checks (paper §4.3). +pub const CHALLENGE_RING_SWITCH: &[u8] = b"hachi/challenge/ring-switch"; +/// Absorb sparse-challenge sampling context (e.g. for short/sparse ring `c`). 
+pub const ABSORB_SPARSE_CHALLENGE: &[u8] = b"hachi/absorb/sparse-challenge"; +/// Challenge bytes used to sample sparse challenges (e.g. ring `c` with weight ω). +pub const CHALLENGE_SPARSE_CHALLENGE: &[u8] = b"hachi/challenge/sparse-challenge"; +/// Absorb the initial sumcheck claim before round messages begin. +pub const ABSORB_SUMCHECK_CLAIM: &[u8] = b"hachi/absorb/sumcheck-claim"; +/// Absorb per-round sumcheck messages (paper §4.3). +pub const ABSORB_SUMCHECK_ROUND: &[u8] = b"hachi/absorb/sumcheck-round"; +/// Challenge sampled per sumcheck round (paper §4.3). +pub const CHALLENGE_SUMCHECK_ROUND: &[u8] = b"hachi/challenge/sumcheck-round"; +/// Absorb the stage-1 final `s_claim` before the batching challenge. +pub const ABSORB_SUMCHECK_S_CLAIM: &[u8] = b"hachi/absorb/sumcheck-s-claim"; +/// Challenge for batched sumcheck coefficient sampling. +pub const CHALLENGE_SUMCHECK_BATCH: &[u8] = b"hachi/challenge/sumcheck-batch"; +/// Absorb recursion/stop-condition message payloads (paper §4.5). +pub const ABSORB_STOP_CONDITION: &[u8] = b"hachi/absorb/stop-condition"; +/// Challenge sampled for recursion stop-condition checks (paper §4.5). +pub const CHALLENGE_STOP_CONDITION: &[u8] = b"hachi/challenge/stop-condition"; + +/// Absorb the prover's stage-1 message `v = D · ŵ` (paper §4.2, Figure 3). +pub const ABSORB_PROVER_V: &[u8] = b"hachi/absorb/prover-stage1-v"; +/// Challenge label for stage-1 fold (sampling sparse `c_i`). +pub const CHALLENGE_STAGE1_FOLD: &[u8] = b"hachi/challenge/stage1-fold"; + +/// Absorb the `w` coefficient vector before sumcheck (paper §4.3). +pub const ABSORB_SUMCHECK_W: &[u8] = b"hachi/absorb/sumcheck-w"; +/// Challenge for sampling `τ₀` (F_0 range-check batching point, paper §4.3). +pub const CHALLENGE_TAU0: &[u8] = b"hachi/challenge/tau0"; +/// Challenge for sampling `τ₁` (F_α evaluation-relation batching point, paper §4.3). 
+pub const CHALLENGE_TAU1: &[u8] = b"hachi/challenge/tau1"; + +/// Labrador recursion domain label (used for recursive reduction stages). +pub const DOMAIN_LABRADOR_RECURSION: &[u8] = b"hachi/labrador/recursion"; +/// Greyhound evaluation-reduction domain label. +pub const DOMAIN_GREYHOUND_EVAL: &[u8] = b"hachi/greyhound/eval"; +/// Absorb canonical Greyhound evaluation context bytes (dimensions). +pub const ABSORB_GREYHOUND_EVAL_CONTEXT: &[u8] = b"hachi/absorb/greyhound-eval-context"; +/// Absorb canonicalized evaluation-point coordinates for Greyhound reduction. +pub const ABSORB_GREYHOUND_EVAL_POINT: &[u8] = b"hachi/absorb/greyhound-eval-point"; +/// Absorb the claimed evaluation value for Greyhound reduction. +pub const ABSORB_GREYHOUND_EVAL_VALUE: &[u8] = b"hachi/absorb/greyhound-eval-value"; +/// Absorb the Greyhound second outer commitment `u2`. +pub const ABSORB_GREYHOUND_U2: &[u8] = b"hachi/absorb/greyhound-u2"; +/// Challenge for Greyhound column-fold coefficients. +pub const CHALLENGE_GREYHOUND_FOLD: &[u8] = b"hachi/challenge/greyhound-fold"; +/// Absorb canonical Labrador recursion metadata (shape/config/tail). +pub const ABSORB_LABRADOR_RECURSION_CONTEXT: &[u8] = b"hachi/absorb/labrador-recursion-context"; +/// Absorb Labrador JL projection vector `p`. +pub const ABSORB_LABRADOR_JL_PROJECTION: &[u8] = b"hachi/absorb/labrador-jl-projection"; +/// Absorb Labrador JL nonce. +pub const ABSORB_LABRADOR_JL_NONCE: &[u8] = b"hachi/absorb/labrador-jl-nonce"; +/// Challenge for Labrador aggregation/lift stage. +pub const CHALLENGE_LABRADOR_AGGREGATION: &[u8] = b"hachi/challenge/labrador-aggregation"; +/// Challenge for Labrador JL collapse coefficients. +pub const CHALLENGE_LABRADOR_JL_COLLAPSE: &[u8] = b"hachi/challenge/labrador-jl-collapse"; +/// Absorb the Labrador opening-side payload at each recursion level. 
+pub const ABSORB_LABRADOR_INNER_OPENING_PAYLOAD: &[u8] = + b"hachi/absorb/labrador-inner-opening-payload"; +/// Absorb the Labrador linear-garbage-side payload at each recursion level. +pub const ABSORB_LABRADOR_LINEAR_GARBAGE_PAYLOAD: &[u8] = + b"hachi/absorb/labrador-linear-garbage-payload"; +/// Absorb Labrador JL lift residuals (constant term removed). +pub const ABSORB_LABRADOR_JL_LIFT_RESIDUALS: &[u8] = b"hachi/absorb/labrador-jl-lift-residuals"; +/// Absorb the Labrador next-witness squared norm bound at each level. +pub const ABSORB_LABRADOR_NEXT_WITNESS_NORM: &[u8] = b"hachi/absorb/labrador-next-witness-norm"; +/// Challenge for Labrador amortization fold (ring-element challenges). +pub const CHALLENGE_LABRADOR_AMORTIZE: &[u8] = b"hachi/challenge/labrador-amortize"; +/// Challenge for deriving the JL projection seed from the transcript. +pub const CHALLENGE_LABRADOR_JL_SEED: &[u8] = b"hachi/challenge/labrador-jl-seed"; + +/// Return all Hachi-core transcript labels. +pub fn all_labels() -> &'static [&'static [u8]] { + &[ + DOMAIN_HACHI_PROTOCOL, + ABSORB_COMMITMENT, + ABSORB_EVALUATION_CLAIMS, + CHALLENGE_LINEAR_RELATION, + ABSORB_RING_SWITCH_MESSAGE, + CHALLENGE_RING_SWITCH, + ABSORB_SPARSE_CHALLENGE, + CHALLENGE_SPARSE_CHALLENGE, + ABSORB_SUMCHECK_CLAIM, + ABSORB_SUMCHECK_ROUND, + CHALLENGE_SUMCHECK_ROUND, + ABSORB_SUMCHECK_S_CLAIM, + CHALLENGE_SUMCHECK_BATCH, + ABSORB_STOP_CONDITION, + CHALLENGE_STOP_CONDITION, + ABSORB_PROVER_V, + CHALLENGE_STAGE1_FOLD, + ABSORB_SUMCHECK_W, + CHALLENGE_TAU0, + CHALLENGE_TAU1, + DOMAIN_LABRADOR_RECURSION, + DOMAIN_GREYHOUND_EVAL, + ABSORB_GREYHOUND_EVAL_CONTEXT, + ABSORB_GREYHOUND_EVAL_POINT, + ABSORB_GREYHOUND_EVAL_VALUE, + ABSORB_GREYHOUND_U2, + CHALLENGE_GREYHOUND_FOLD, + ABSORB_LABRADOR_RECURSION_CONTEXT, + ABSORB_LABRADOR_JL_PROJECTION, + ABSORB_LABRADOR_JL_NONCE, + CHALLENGE_LABRADOR_AGGREGATION, + CHALLENGE_LABRADOR_JL_COLLAPSE, + ABSORB_LABRADOR_INNER_OPENING_PAYLOAD, + ABSORB_LABRADOR_LINEAR_GARBAGE_PAYLOAD, + 
ABSORB_LABRADOR_JL_LIFT_RESIDUALS, + ABSORB_LABRADOR_NEXT_WITNESS_NORM, + CHALLENGE_LABRADOR_AMORTIZE, + CHALLENGE_LABRADOR_JL_SEED, + ] +} diff --git a/src/protocol/transcript/mod.rs b/src/protocol/transcript/mod.rs new file mode 100644 index 00000000..4105af5c --- /dev/null +++ b/src/protocol/transcript/mod.rs @@ -0,0 +1,139 @@ +//! Protocol transcript contracts and implementations. + +mod hash; +pub mod labels; + +use crate::algebra::fields::lift::ExtField; +use crate::algebra::ring::CyclotomicRing; +use crate::algebra::SparseChallenge; +use crate::error::HachiError; +use crate::protocol::labrador::challenge::{ + sample_labrador_challenges, sample_labrador_sparse_challenges, +}; +use crate::{CanonicalField, FieldCore, FromSmallInt, HachiSerialize}; + +pub use hash::{Blake2bTranscript, HashTranscript, KeccakTranscript}; + +/// Transcript interface for protocol Fiat-Shamir transforms. +/// +/// The protocol layer is label-aware and uses deterministic byte encoding for +/// all absorbed values. +pub trait Transcript: Clone + Send + Sync + 'static +where + F: FieldCore + CanonicalField, +{ + /// Construct a new transcript under a domain label. + fn new(domain_label: &[u8]) -> Self; + + /// Append labeled raw bytes. + fn append_bytes(&mut self, label: &[u8], bytes: &[u8]); + + /// Append a field element with deterministic encoding. + fn append_field(&mut self, label: &[u8], x: &F); + + /// Append a serializable protocol value. + fn append_serde(&mut self, label: &[u8], s: &S); + + /// Derive a challenge scalar under the provided label. + fn challenge_scalar(&mut self, label: &[u8]) -> F; + + /// Squeeze `len` challenge bytes under the provided label. + fn challenge_bytes(&mut self, label: &[u8], len: usize) -> Vec; +} + +/// Sample an extension field challenge by drawing `EXT_DEGREE` base-field +/// challenges and assembling them via `from_base_slice`. +/// +/// When `E = F` (degree 1), this compiles to a single `challenge_scalar` call. 
+pub fn sample_ext_challenge(tr: &mut T, label: &[u8]) -> E +where + F: FieldCore + CanonicalField, + T: Transcript, + E: ExtField, +{ + E::from_base_slice( + &(0..E::EXT_DEGREE) + .map(|_| tr.challenge_scalar(label)) + .collect::>(), + ) +} + +/// Fixed nonce for single-polynomial rejection sampling. +const REJECTION_SAMPLER_SINGLE_NONCE: u64 = 0; + +/// Sample a dense ring-element challenge by drawing `D` scalar challenges. +pub fn challenge_ring_element( + tr: &mut T, + label: &[u8], +) -> CyclotomicRing +where + F: FieldCore + CanonicalField, + T: Transcript, +{ + CyclotomicRing::from_coefficients(std::array::from_fn(|_| tr.challenge_scalar(label))) +} + +/// Sample a sparse ring-element challenge with operator-norm rejection sampling. +/// +/// Squeezes a 16-byte seed from the transcript, then delegates to the Labrador +/// rejection sampler which produces a polynomial with exactly `TAU1` coefficients +/// in {+/-1} and `TAU2` in {+/-2}, retrying until the operator norm is bounded. +/// +/// # Errors +/// +/// Returns an error if `D` is incompatible with the rejection sampler. +pub fn challenge_ring_element_rejection_sampled( + tr: &mut T, + label: &[u8], +) -> Result, HachiError> +where + F: FieldCore + CanonicalField + FromSmallInt, + T: Transcript, +{ + let mut polys = challenge_ring_elements_rejection_sampled::(tr, label, 1)?; + polys + .pop() + .ok_or_else(|| HachiError::InvalidInput("rejection sampler produced no output".into())) +} + +/// Sample multiple sparse ring-element challenges from one transcript-bound seed. +/// +/// # Errors +/// +/// Returns an error if `D` is incompatible with the rejection sampler. 
+pub fn challenge_ring_elements_rejection_sampled( + tr: &mut T, + label: &[u8], + len: usize, +) -> Result>, HachiError> +where + F: FieldCore + CanonicalField + FromSmallInt, + T: Transcript, +{ + let seed_vec = tr.challenge_bytes(label, 16); + let seed: [u8; 16] = seed_vec + .try_into() + .map_err(|_| HachiError::InvalidInput("rejection sampler seed length mismatch".into()))?; + sample_labrador_challenges::(len, &seed, REJECTION_SAMPLER_SINGLE_NONCE) +} + +/// Sample multiple sparse ring-element challenges from one transcript-bound seed. +/// +/// # Errors +/// +/// Returns an error if `D` is incompatible with the rejection sampler. +pub fn challenge_sparse_ring_elements_rejection_sampled( + tr: &mut T, + label: &[u8], + len: usize, +) -> Result, HachiError> +where + F: FieldCore + CanonicalField + FromSmallInt, + T: Transcript, +{ + let seed_vec = tr.challenge_bytes(label, 16); + let seed: [u8; 16] = seed_vec + .try_into() + .map_err(|_| HachiError::InvalidInput("rejection sampler seed length mismatch".into()))?; + sample_labrador_sparse_challenges::(len, &seed, REJECTION_SAMPLER_SINGLE_NONCE) +} diff --git a/src/test_utils.rs b/src/test_utils.rs new file mode 100644 index 00000000..6222c3fd --- /dev/null +++ b/src/test_utils.rs @@ -0,0 +1,180 @@ +//! Shared test configuration and helpers. +//! +//! This module is only compiled under `#[cfg(test)]` and provides common +//! building blocks for both unit tests (inside `src/`) and integration +//! tests (inside `tests/`). + +use std::array::from_fn; + +use crate::algebra::{CyclotomicRing, Fp64}; +use crate::error::HachiError; +use crate::protocol::commitment::utils::flat_matrix::FlatMatrix; +use crate::protocol::commitment::{ + compute_num_digits, compute_num_digits_fold, CommitmentConfig, DecompositionParams, + HachiCommitmentLayout, +}; +use crate::{FieldCore, FromSmallInt}; + +/// Default test field: a 32-bit prime `p = 4294967197`. +pub type F = Fp64<4294967197>; +/// Ring degree used in tests. 
+pub const D: usize = 64; + +/// Minimal commitment config for fast unit tests. +#[derive(Clone)] +pub struct TinyConfig; + +impl CommitmentConfig for TinyConfig { + const D: usize = 64; + const N_A: usize = 2; + const N_B: usize = 2; + const N_D: usize = 2; + const CHALLENGE_WEIGHT: usize = 3; + + fn decomposition() -> DecompositionParams { + DecompositionParams { + log_basis: 3, + log_commit_bound: 32, + log_open_bound: None, + } + } + + fn commitment_layout(_max_num_vars: usize) -> Result { + HachiCommitmentLayout::new::(1, 1, &Self::decomposition()) + } +} + +/// Number of ring elements per block (`2^m_vars`). +pub const BLOCK_LEN: usize = 2; +/// Number of blocks (`2^r_vars`). +pub const NUM_BLOCKS: usize = 2; +/// Gadget base exponent (`b = 2^log_basis()`), derived from `TinyConfig`. +pub fn log_basis() -> u32 { + TinyConfig::decomposition().log_basis +} +/// Inner Ajtai row count from `TinyConfig`. +pub const N_A: usize = TinyConfig::N_A; + +/// Decomposition depth for original coefficients under `TinyConfig`. +pub fn num_digits_commit() -> usize { + let d = TinyConfig::decomposition(); + compute_num_digits(d.log_commit_bound, d.log_basis) +} + +/// Decomposition depth for opening / full-field coefficients under `TinyConfig`. +pub fn num_digits_open() -> usize { + let d = TinyConfig::decomposition(); + let log_open = d.log_open_bound.unwrap_or(d.log_commit_bound); + compute_num_digits(log_open, d.log_basis) +} + +/// Decomposition depth for the folded witness `z_pre` under `TinyConfig`. +pub fn num_digits_fold() -> usize { + let d = TinyConfig::decomposition(); + compute_num_digits_fold(1, TinyConfig::CHALLENGE_WEIGHT, d.log_basis) +} + +/// Dense matrix-vector multiply over cyclotomic rings. +/// +/// Matrix rows may be wider than `vec` (e.g. when matrices are widened for +/// multi-level folding); extra columns are treated as multiplying zero. 
+pub fn mat_vec_mul(mat: &FlatMatrix, vec: &[CyclotomicRing]) -> Vec> { + let view = mat.view::(); + (0..view.num_rows()) + .map(|i| { + let row = view.row(i); + assert!(row.len() >= vec.len()); + row.iter() + .zip(vec.iter()) + .fold(CyclotomicRing::::zero(), |acc, (a, x)| { + acc + (*a * *x) + }) + }) + .collect() +} + +/// Generate deterministic test blocks of ring elements. +pub fn sample_blocks() -> Vec>> { + (0..NUM_BLOCKS) + .map(|bi| { + (0..BLOCK_LEN) + .map(|bj| { + let coeffs = from_fn(|k| F::from_u64((bi * 1_000 + bj * 100 + k) as u64)); + CyclotomicRing::from_coefficients(coeffs) + }) + .collect() + }) + .collect() +} + +/// Generate deterministic inner opening-point scalars. +pub fn sample_a() -> Vec { + (0..BLOCK_LEN) + .map(|j| F::from_u64((j * 10 + 1) as u64)) + .collect() +} + +/// Generate deterministic outer opening-point scalars. +pub fn sample_b() -> Vec { + (0..NUM_BLOCKS) + .map(|i| F::from_u64((i * 7 + 3) as u64)) + .collect() +} + +/// Recompose a gadget-decomposed ring element: `sum_i parts[i] * b^i`. +pub fn field_gadget_recompose( + parts: &[CyclotomicRing], + log_basis: u32, +) -> CyclotomicRing { + let b = F::from_u64(1u64 << log_basis); + let mut result = CyclotomicRing::::zero(); + let mut b_power = F::one(); + for part in parts { + result += part.scale(&b_power); + b_power *= b; + } + result +} + +/// Recompose `z_hat` chunks (num_digits_fold-width) back to `z_pre` elements. +pub fn recompose_z_hat(z_hat: &[CyclotomicRing]) -> Vec> { + z_hat + .chunks(num_digits_fold()) + .map(|chunk| field_gadget_recompose(chunk, log_basis())) + .collect() +} + +/// Recompose a vector of gadget-decomposed elements (num_digits_commit-width chunks). +pub fn gadget_recompose_vec(x_hat: &[CyclotomicRing]) -> Vec> { + x_hat + .chunks(num_digits_commit()) + .map(|chunk| field_gadget_recompose(chunk, log_basis())) + .collect() +} + +/// Recompose a vector of i8 gadget-decomposed digit planes (num_digits_commit-width chunks). 
+pub fn gadget_recompose_vec_i8(x_hat: &[[i8; D]]) -> Vec> { + x_hat + .chunks(num_digits_commit()) + .map(|chunk| CyclotomicRing::gadget_recompose_pow2_i8(chunk, log_basis())) + .collect() +} + +/// Alias for [`gadget_recompose_vec`] (same num_digits_commit-width recomposition). +pub fn field_gadget_recompose_vec(v: &[CyclotomicRing]) -> Vec> { + v.chunks(num_digits_commit()) + .map(|chunk| field_gadget_recompose(chunk, log_basis())) + .collect() +} + +/// Compute `a^T * G^{-1}(z)`: recompose `z` then inner-product with `a`. +pub fn a_transpose_gadget_times_vec(a: &[F], z: &[CyclotomicRing]) -> CyclotomicRing { + let recomposed = field_gadget_recompose_vec(z); + assert_eq!(recomposed.len(), a.len()); + recomposed + .iter() + .zip(a.iter()) + .fold(CyclotomicRing::::zero(), |acc, (z_j, a_j)| { + acc + z_j.scale(a_j) + }) +} diff --git a/tests/algebra.rs b/tests/algebra.rs new file mode 100644 index 00000000..63551f55 --- /dev/null +++ b/tests/algebra.rs @@ -0,0 +1,1861 @@ +#![allow(missing_docs)] + +#[cfg(test)] +mod tests { + use num_bigint::BigUint; + use rand::{rngs::StdRng, SeedableRng}; + + use hachi_pcs::algebra::backend::{CrtReconstruct, NttPrimeOps}; + use hachi_pcs::algebra::ntt::butterfly::{forward_ntt, inverse_ntt, NttTwiddles}; + use hachi_pcs::algebra::poly::Poly; + use hachi_pcs::algebra::tables::{ + q128_garner, q128_primes, q32_garner, q64_garner, q64_primes, Q128_MODULUS, + Q128_NUM_PRIMES, Q32_MODULUS, Q32_NUM_PRIMES, Q32_PRIMES, Q64_MODULUS, Q64_NUM_PRIMES, + }; + use hachi_pcs::algebra::{ + pseudo_mersenne_modulus, Pow2Offset128Field, Pow2OffsetPrimeSpec, POW2_OFFSET_MAX, + POW2_OFFSET_PRIMES, POW2_OFFSET_TABLE, + }; + use hachi_pcs::algebra::{ + CrtNttParamSet, CyclotomicCrtNtt, CyclotomicRing, Fp128, Fp2, Fp2Config, Fp32, Fp4, + Fp4Config, Fp64, HasPacking, LimbQ, MontCoeff, PackedPartialSplitEval32, + PartialSplitEval32, PartialSplitNtt32, Prime128M13M4P0, Prime128M37P3P0, Prime128M52M3P0, + Prime128M54P4P0, Prime128M8M4M1M0, 
Prime128Offset5823, ScalarBackend, VectorModule, + }; + use hachi_pcs::primitives::serialization::SerializationError; + use hachi_pcs::{ + CanonicalField, FieldCore, FieldSampling, FromSmallInt, HachiDeserialize, HachiSerialize, + Invertible, Module, PseudoMersenneField, + }; + + const P_159: u128 = 340282366920938463463374607431768211297u128; + + #[test] + fn fp32_basic_arith() { + type F = Fp32<251>; + let a = F::from_u64(17); + let b = F::from_u64(99); + assert_eq!((a + b).to_canonical_u32(), 116); + assert_eq!((a * b).to_canonical_u32(), (17u32 * 99) % 251); + + let inv = a.inv().unwrap(); + assert_eq!(a * inv, F::one()); + } + + #[test] + fn fp64_hachi_q_inv() { + type F = Fp64<4294967197>; + let two = F::from_u64(2); + let inv2 = two.inv().unwrap(); + assert_eq!(two * inv2, F::one()); + } + + #[test] + fn fp128_basic_arith() { + type F = Fp128; + + let a = F::from_u64(123); + let b = F::from_u64(456); + let c = a * b + a - b; + let inv = c.inv().unwrap(); + assert_eq!(c * inv, F::one()); + } + + fn rand_u128(rng: &mut R) -> u128 { + let lo = rng.next_u64() as u128; + let hi = rng.next_u64() as u128; + lo | (hi << 64) + } + + fn biguint_to_u128(x: &num_bigint::BigUint) -> u128 { + let mut bytes = x.to_bytes_le(); + bytes.resize(16, 0); + let mut arr = [0u8; 16]; + arr.copy_from_slice(&bytes[..16]); + u128::from_le_bytes(arr) + } + + fn big_mul_mod_u128(a: u128, b: u128, p: u128) -> u128 { + let n = BigUint::from(a) * BigUint::from(b); + let r = n % BigUint::from(p); + biguint_to_u128(&r) + } + + fn check_solinas_prime< + S: CanonicalField + FieldCore + Invertible + PseudoMersenneField + std::fmt::Debug, + >( + p: u128, + iters: usize, + seed: u64, + ) { + assert_eq!(::MODULUS_BITS, 128); + assert_eq!( + ::MODULUS_OFFSET, + 0u128.wrapping_sub(p) + ); + assert_eq!(std::mem::size_of::(), 16); + + let mut rng = StdRng::seed_from_u64(seed); + + for _ in 0..iters { + let a_raw = rand_u128(&mut rng); + let b_raw = rand_u128(&mut rng); + + let a = 
S::from_canonical_u128_reduced(a_raw); + let b = S::from_canonical_u128_reduced(b_raw); + + // Canonical range invariant. + assert!(a.to_canonical_u128() < p); + assert!(b.to_canonical_u128() < p); + + // Add/sub/neg identities. + assert_eq!(a + S::zero(), a); + assert_eq!(a - S::zero(), a); + assert_eq!(a + (-a), S::zero()); + + // Multiplicative identity. + assert_eq!(a * S::one(), a); + + // BigUint oracle for mul and sqr (exercises reduction). + let aa = a.to_canonical_u128(); + let bb = b.to_canonical_u128(); + let got_mul = (a * b).to_canonical_u128(); + let exp_mul = big_mul_mod_u128(aa, bb, p); + assert_eq!(got_mul, exp_mul); + + let got_sqr = (a * a).to_canonical_u128(); + let exp_sqr = big_mul_mod_u128(aa, aa, p); + assert_eq!(got_sqr, exp_sqr); + + // Inversion checks (skip explicit inv on zero). + let inv = a.inv_or_zero(); + if a.is_zero() { + assert_eq!(inv, S::zero()); + } else { + assert_eq!(a * inv, S::one()); + assert_eq!(a.inv().unwrap(), inv); + } + } + } + + #[test] + fn fp128_sparse_primes_match_biguint_oracle() { + // These are the five sparse `2^128 - c` primes we care about. 
+ const P13: u128 = 0xffffffffffffffffffffffffffffdff1u128; + const P37: u128 = 0xffffffffffffffffffffffe000000009u128; + const P52: u128 = 0xffffffffffffffffffeffffffffffff9u128; + const P54: u128 = 0xffffffffffffffffffc0000000000011u128; + const P275: u128 = 0xfffffffffffffffffffffffffffffeedu128; + + check_solinas_prime::(P13, 2_000, 13); + check_solinas_prime::(P37, 2_000, 37); + check_solinas_prime::(P52, 2_000, 52); + check_solinas_prime::(P54, 2_000, 54); + check_solinas_prime::(P275, 2_000, 275); + } + + struct NR; + impl Fp2Config> for NR { + fn non_residue() -> Fp32<251> { + -Fp32::<251>::one() + } + } + + struct NR4; + impl Fp4Config, NR> for NR4 { + fn non_residue() -> Fp2, NR> { + Fp2::new(Fp32::<251>::zero(), Fp32::<251>::one()) + } + } + + #[test] + fn fp2_fp4_inversion_smoke() { + type F = Fp32<251>; + type F2 = Fp2; + type F4 = Fp4; + + let x = F2::new(F::from_u64(3), F::from_u64(7)); + let inv = x.inv().unwrap(); + assert!((x * inv) == F2::one()); + + let y = F4::new( + F2::new(F::from_u64(5), F::from_u64(1)), + F2::new(F::from_u64(2), F::from_u64(9)), + ); + let invy = y.inv().unwrap(); + assert!((y * invy) == F4::one()); + } + + #[test] + fn vector_module_ops() { + type F = Fp32<251>; + + let a = VectorModule::([F::from_u64(1), F::from_u64(2), F::from_u64(3)]); + let b = VectorModule::([F::from_u64(3), F::from_u64(4), F::from_u64(5)]); + + let c = a + b; + assert_eq!(c.0[0], F::from_u64(4)); + + let d = a.scale(&F::from_u64(7)); + assert_eq!(d.0[1], F::from_u64(14)); + } + + #[test] + fn inv_zero_returns_none() { + assert!(Fp32::<251>::zero().inv().is_none()); + assert!(Fp64::<4294967197>::zero().inv().is_none()); + assert!(Fp128::::zero().inv().is_none()); + } + + #[test] + fn inv_or_zero_behavior_for_prime_fields() { + type F32 = Fp32<251>; + assert_eq!(F32::zero().inv_or_zero(), F32::zero()); + let x32 = F32::from_u64(17); + let inv32 = x32.inv_or_zero(); + assert_eq!(x32 * inv32, F32::one()); + + type F64 = Fp64<4294967197>; + 
assert_eq!(F64::zero().inv_or_zero(), F64::zero()); + let x64 = F64::from_u64(2); + let inv64 = x64.inv_or_zero(); + assert_eq!(x64 * inv64, F64::one()); + + type F128 = Fp128; + assert_eq!(F128::zero().inv_or_zero(), F128::zero()); + let x128 = F128::from_u64(12345); + let inv128 = x128.inv_or_zero(); + assert_eq!(x128 * inv128, F128::one()); + } + + #[test] + fn field_identities_fp32() { + type F = Fp32<251>; + let a = F::from_u64(42); + let b = F::from_u64(73); + let c = F::from_u64(11); + + // Additive identity + assert_eq!(a + F::zero(), a); + // Multiplicative identity + assert_eq!(a * F::one(), a); + // Additive inverse + assert_eq!(a + (-a), F::zero()); + // Distributivity + assert_eq!(a * (b + c), a * b + a * c); + // Commutativity + assert_eq!(a * b, b * a); + assert_eq!(a + b, b + a); + } + + #[test] + fn field_identities_fp64() { + type F = Fp64<4294967197>; + let a = F::from_u64(123456); + let b = F::from_u64(789012); + let c = F::from_u64(345678); + + assert_eq!(a + F::zero(), a); + assert_eq!(a * F::one(), a); + assert_eq!(a + (-a), F::zero()); + assert_eq!(a * (b + c), a * b + a * c); + } + + #[test] + fn field_identities_fp128() { + type F = Fp128; + let a = F::from_u64(999999); + let b = F::from_u64(888888); + let c = F::from_u64(777777); + + assert_eq!(a + F::zero(), a); + assert_eq!(a * F::one(), a); + assert_eq!(a + (-a), F::zero()); + assert_eq!(a * (b + c), a * b + a * c); + } + + #[test] + fn serialization_round_trip_fp32() { + type F = Fp32<251>; + let val = F::from_u64(42); + let mut buf = Vec::new(); + val.serialize_compressed(&mut buf).unwrap(); + let restored = F::deserialize_compressed(&buf[..]).unwrap(); + assert_eq!(val, restored); + } + + #[test] + fn serialization_round_trip_fp64() { + type F = Fp64<4294967197>; + let val = F::from_u64(123456789); + let mut buf = Vec::new(); + val.serialize_compressed(&mut buf).unwrap(); + let restored = F::deserialize_compressed(&buf[..]).unwrap(); + assert_eq!(val, restored); + } + + #[test] + fn 
serialization_round_trip_fp128() { + type F = Fp128; + let val = F::from_u64(999999999); + let mut buf = Vec::new(); + val.serialize_compressed(&mut buf).unwrap(); + let restored = F::deserialize_compressed(&buf[..]).unwrap(); + assert_eq!(val, restored); + } + + #[test] + fn serialization_round_trip_ext() { + type F = Fp32<251>; + type F2 = Fp2; + let val = F2::new(F::from_u64(3), F::from_u64(7)); + let mut buf = Vec::new(); + val.serialize_compressed(&mut buf).unwrap(); + let restored = F2::deserialize_compressed(&buf[..]).unwrap(); + assert!(val == restored); + } + + #[test] + fn serialization_round_trip_fp4() { + type F = Fp32<251>; + type F2 = Fp2; + type F4 = Fp4; + + let val = F4::new( + F2::new(F::from_u64(5), F::from_u64(1)), + F2::new(F::from_u64(2), F::from_u64(9)), + ); + let mut buf = Vec::new(); + val.serialize_compressed(&mut buf).unwrap(); + let restored = F4::deserialize_compressed(&buf[..]).unwrap(); + assert!(val == restored); + } + + #[test] + fn serialization_round_trip_vector_module() { + type F = Fp32<251>; + let val = VectorModule::([F::from_u64(1), F::from_u64(2), F::from_u64(3)]); + let mut buf = Vec::new(); + val.serialize_compressed(&mut buf).unwrap(); + let restored = VectorModule::::deserialize_compressed(&buf[..]).unwrap(); + assert_eq!(val, restored); + } + + #[test] + fn serialization_round_trip_poly() { + type F = Fp32<251>; + + let val = Poly::([ + F::from_u64(7), + F::from_u64(11), + F::from_u64(13), + F::from_u64(29), + ]); + let mut buf = Vec::new(); + val.serialize_compressed(&mut buf).unwrap(); + let restored = Poly::::deserialize_compressed(&buf[..]).unwrap(); + assert_eq!(val, restored); + } + + #[test] + fn deserialize_checked_rejects_non_canonical_field_elements() { + type F32 = Fp32<251>; + let bad32 = 251u32.to_le_bytes(); + let err32 = F32::deserialize_compressed(&bad32[..]).unwrap_err(); + assert!(matches!(err32, SerializationError::InvalidData(_))); + let unchecked32 = 
F32::deserialize_compressed_unchecked(&bad32[..]).unwrap(); + assert_eq!(unchecked32, F32::zero()); + + type F64 = Fp64<4294967197>; + let bad64 = 4294967197u64.to_le_bytes(); + let err64 = F64::deserialize_compressed(&bad64[..]).unwrap_err(); + assert!(matches!(err64, SerializationError::InvalidData(_))); + let unchecked64 = F64::deserialize_compressed_unchecked(&bad64[..]).unwrap(); + assert_eq!(unchecked64, F64::zero()); + + type F128 = Fp128; + let bad128 = P_159.to_le_bytes(); + let err128 = F128::deserialize_compressed(&bad128[..]).unwrap_err(); + assert!(matches!(err128, SerializationError::InvalidData(_))); + let unchecked128 = F128::deserialize_compressed_unchecked(&bad128[..]).unwrap(); + assert_eq!(unchecked128, F128::zero()); + + // Sparse 128-bit prime: same checked/unchecked behavior. + type S13 = Prime128M13M4P0; + const P13: u128 = 0xffffffffffffffffffffffffffffdff1u128; + let bad13 = P13.to_le_bytes(); + let err13 = S13::deserialize_compressed(&bad13[..]).unwrap_err(); + assert!(matches!(err13, SerializationError::InvalidData(_))); + let unchecked13 = S13::deserialize_compressed_unchecked(&bad13[..]).unwrap(); + assert_eq!(unchecked13, S13::zero()); + } + + #[test] + fn fp2_conjugate_and_norm() { + type F = Fp32<251>; + type F2 = Fp2; + let x = F2::new(F::from_u64(3), F::from_u64(7)); + let conj = x.conjugate(); + assert!(conj == F2::new(F::from_u64(3), -F::from_u64(7))); + // For Fp2 with u^2 = -1: norm = c0^2 + c1^2 = 9 + 49 = 58 + assert_eq!(x.norm(), F::from_u64(58)); + // x * conjugate(x) should embed the norm into Fp2 + let prod = x * conj; + assert!(prod == F2::new(F::from_u64(58), F::zero())); + } + + #[test] + fn fp2_distributivity() { + type F = Fp32<251>; + type F2 = Fp2; + let a = F2::new(F::from_u64(3), F::from_u64(7)); + let b = F2::new(F::from_u64(11), F::from_u64(5)); + let c = F2::new(F::from_u64(2), F::from_u64(9)); + assert!(a * (b + c) == a * b + a * c); + } + + #[test] + fn limbq_from_to_u128_round_trip() { + for &val in 
&[0u128, 1, 12345, 123456789, (1u128 << 28) - 1] { + let limb: LimbQ<3> = LimbQ::from(val); + assert_eq!( + u128::try_from(limb).unwrap(), + val, + "round-trip failed for {val}" + ); + } + } + + #[test] + fn limbq_add_sub_inverse() { + let a: LimbQ<3> = LimbQ::from(12345u128); + let b: LimbQ<3> = LimbQ::from(6789u128); + let sum = a + b; + let diff = sum - b; + assert_eq!(diff, a); + } + + #[test] + fn limbq_ordering() { + let a: LimbQ<3> = LimbQ::from(100u128); + let b: LimbQ<3> = LimbQ::from(200u128); + assert!(a < b); + assert!(b > a); + assert_eq!(a, a); + } + + #[test] + fn ntt_normalize_in_range() { + for prime in &Q32_PRIMES { + for &a in &[0i16, 1, -1, 100, -100, prime.p - 1, -(prime.p - 1)] { + let n = prime.normalize(MontCoeff::from_raw(a)); + assert!( + n.raw() >= 0 && n.raw() < prime.p, + "normalize({a}) = {} for p={}", + n.raw(), + prime.p + ); + } + } + } + + #[test] + fn csubp_widened_handles_large_negative_i16() { + for &prime in &Q32_PRIMES { + let p = prime.p; + // Values in (-2p, -(2^15 - p)) that previously overflowed in narrow i16 + for &raw in &[-20000i16, -(p + p / 2), -(p + 1000)] { + if raw <= -2 * p || raw >= 0 { + continue; + } + let a = MontCoeff::from_raw(raw); + let reduced = prime.reduce_range(a); + let r = reduced.raw(); + assert!( + r > -p && r < p, + "reduce_range({raw}) = {r} not in (-{p}, {p}) for p={p}" + ); + + let norm = prime.normalize(reduced); + let n = norm.raw(); + assert!( + n >= 0 && n < p, + "normalize(reduce_range({raw})) = {n} not in [0, {p}) for p={p}" + ); + } + } + } + + #[test] + fn ntt_mul_commutative() { + let prime = Q32_PRIMES[0]; + let a = MontCoeff::from_raw(1234); + let b = MontCoeff::from_raw(5678); + assert_eq!(prime.mul(a, b), prime.mul(b, a)); + } + + #[test] + fn mont_coeff_round_trip() { + for prime in &Q32_PRIMES { + for &val in &[0i16, 1, 2, 100, prime.p - 1] { + let mont = prime.from_canonical(val); + let back = prime.to_canonical(mont); + assert_eq!(back, val, "round-trip failed for val={val}, 
p={}", prime.p); + } + } + } + + #[test] + fn poly_add_sub_neg() { + type F = Fp32<251>; + let a = Poly::([F::from_u64(1), F::from_u64(2), F::from_u64(3)]); + let b = Poly::([F::from_u64(10), F::from_u64(20), F::from_u64(30)]); + + let sum = a + b; + assert_eq!(sum.0[0], F::from_u64(11)); + assert_eq!(sum.0[1], F::from_u64(22)); + assert_eq!(sum.0[2], F::from_u64(33)); + + let diff = b - a; + assert_eq!(diff.0[0], F::from_u64(9)); + + let neg_a = -a; + assert_eq!(a + neg_a, Poly::zero()); + } + + #[test] + fn cyclotomic_ring_negacyclic_property() { + type F = Fp32<251>; + type R = CyclotomicRing; + + // X in the ring: [0, 1, 0, 0] + let x = R::x(); + + // X^2 + let x2 = x * x; + let expected_x2 = R::from_coefficients([F::zero(), F::zero(), F::one(), F::zero()]); + assert_eq!(x2, expected_x2); + + // X^4 should equal -1 (because X^4 + 1 = 0 in the ring) + let x4 = x2 * x2; + assert_eq!(x4, -R::one(), "X^D should equal -1 in Z_q[X]/(X^D + 1)"); + } + + #[test] + fn cyclotomic_ring_mul_identity() { + type F = Fp32<251>; + type R = CyclotomicRing; + + let a = R::from_coefficients([ + F::from_u64(3), + F::from_u64(7), + F::from_u64(11), + F::from_u64(42), + ]); + assert_eq!(a * R::one(), a); + assert_eq!(R::one() * a, a); + } + + #[test] + fn cyclotomic_ring_mul_zero() { + type F = Fp32<251>; + type R = CyclotomicRing; + + let a = R::from_coefficients([ + F::from_u64(3), + F::from_u64(7), + F::from_u64(11), + F::from_u64(42), + ]); + assert_eq!(a * R::zero(), R::zero()); + } + + #[test] + fn cyclotomic_ring_commutativity() { + type F = Fp32<251>; + type R = CyclotomicRing; + + let a = R::from_coefficients([ + F::from_u64(3), + F::from_u64(7), + F::from_u64(11), + F::from_u64(42), + ]); + let b = R::from_coefficients([ + F::from_u64(5), + F::from_u64(13), + F::from_u64(99), + F::from_u64(1), + ]); + assert_eq!(a * b, b * a); + } + + #[test] + fn cyclotomic_ring_distributivity() { + type F = Fp32<251>; + type R = CyclotomicRing; + + let a = R::from_coefficients([ + 
F::from_u64(3), + F::from_u64(7), + F::from_u64(11), + F::from_u64(42), + ]); + let b = R::from_coefficients([ + F::from_u64(5), + F::from_u64(13), + F::from_u64(99), + F::from_u64(1), + ]); + let c = R::from_coefficients([ + F::from_u64(2), + F::from_u64(9), + F::from_u64(50), + F::from_u64(77), + ]); + assert_eq!(a * (b + c), a * b + a * c); + } + + #[test] + fn cyclotomic_ring_associativity() { + type F = Fp32<251>; + type R = CyclotomicRing; + + let a = R::from_coefficients([ + F::from_u64(3), + F::from_u64(7), + F::from_u64(11), + F::from_u64(42), + ]); + let b = R::from_coefficients([ + F::from_u64(5), + F::from_u64(13), + F::from_u64(99), + F::from_u64(1), + ]); + let c = R::from_coefficients([ + F::from_u64(2), + F::from_u64(9), + F::from_u64(50), + F::from_u64(77), + ]); + assert_eq!((a * b) * c, a * (b * c)); + } + + #[test] + fn cyclotomic_ring_additive_inverse() { + type F = Fp32<251>; + type R = CyclotomicRing; + + let a = R::from_coefficients([ + F::from_u64(3), + F::from_u64(7), + F::from_u64(11), + F::from_u64(42), + ]); + assert_eq!(a + (-a), R::zero()); + } + + #[test] + fn cyclotomic_ring_serialization_round_trip() { + type F = Fp32<251>; + type R = CyclotomicRing; + + let a = R::from_coefficients([ + F::from_u64(3), + F::from_u64(7), + F::from_u64(11), + F::from_u64(42), + ]); + let mut buf = Vec::new(); + a.serialize_compressed(&mut buf).unwrap(); + let restored = R::deserialize_compressed(&buf[..]).unwrap(); + assert_eq!(a, restored); + } + + #[test] + fn cyclotomic_ring_degree_64() { + type F = Fp64<4294967197>; + type R = CyclotomicRing; + + // X^64 = -1 in Z_q[X]/(X^64 + 1) + let x = R::x(); + let mut power = R::one(); + for _ in 0..64 { + power *= x; + } + assert_eq!(power, -R::one(), "X^64 should equal -1"); + } + + #[test] + fn ntt_forward_inverse_round_trip() { + let prime = Q32_PRIMES[0]; + let tw = NttTwiddles::::compute(prime); + + let original: [MontCoeff; 64] = + std::array::from_fn(|i| prime.from_canonical((i as i16) % prime.p)); 
+ + let mut a = original; + forward_ntt(&mut a, prime, &tw); + inverse_ntt(&mut a, prime, &tw); + + for (i, (got, expected)) in a.iter().zip(original.iter()).enumerate() { + let got_canon = prime.to_canonical(prime.normalize(*got)); + let exp_canon = prime.to_canonical(prime.normalize(*expected)); + assert_eq!( + got_canon, exp_canon, + "NTT round-trip mismatch at index {i}: got {got_canon}, expected {exp_canon}" + ); + } + } + + #[test] + fn ntt_forward_inverse_all_primes() { + for (pi, prime) in Q32_PRIMES.iter().enumerate() { + let tw = NttTwiddles::::compute(*prime); + + let original: [_; 64] = + std::array::from_fn(|i| prime.from_canonical(((i * (pi + 1)) as i16) % prime.p)); + + let mut a = original; + forward_ntt(&mut a, *prime, &tw); + inverse_ntt(&mut a, *prime, &tw); + + for (i, (got, expected)) in a.iter().zip(original.iter()).enumerate() { + let g = prime.to_canonical(prime.normalize(*got)); + let e = prime.to_canonical(prime.normalize(*expected)); + assert_eq!( + g, e, + "prime[{pi}] p={}: round-trip mismatch at index {i}", + prime.p + ); + } + } + } + + #[test] + fn negacyclic_ntt_mul_matches_schoolbook_single_prime_d8() { + const D: usize = 8; + let prime = Q32_PRIMES[0]; + let tw = NttTwiddles::::compute(prime); + + let a_canon: [i16; D] = std::array::from_fn(|i| ((i as i16 * 7) + 3) % prime.p); + let b_canon: [i16; D] = std::array::from_fn(|i| ((i as i16 * 5) + 11) % prime.p); + + // Schoolbook negacyclic convolution mod p: X^D = -1. 
+ let mut school = [0i16; D]; + for (i, &ai) in a_canon.iter().enumerate() { + for (j, &bj) in b_canon.iter().enumerate() { + let prod = (ai as i64 * bj as i64) % (prime.p as i64); + let idx = i + j; + if idx < D { + school[idx] = ((school[idx] as i64 + prod) % (prime.p as i64)) as i16; + } else { + let k = idx - D; + school[k] = ((school[k] as i64 - prod) % (prime.p as i64)) as i16; + } + } + } + for x in &mut school { + if *x < 0 { + *x = (*x as i64 + prime.p as i64) as i16; + } + } + + let mut a = std::array::from_fn(|i| prime.from_canonical(a_canon[i])); + let mut b = std::array::from_fn(|i| prime.from_canonical(b_canon[i])); + forward_ntt(&mut a, prime, &tw); + forward_ntt(&mut b, prime, &tw); + + let mut c: [_; D] = std::array::from_fn(|i| prime.mul(a[i], b[i])); + inverse_ntt(&mut c, prime, &tw); + + let got: [i16; D] = std::array::from_fn(|i| prime.to_canonical(prime.normalize(c[i]))); + assert_eq!(got, school); + } + + #[test] + fn negacyclic_ntt_forward_matches_manual_evals_d8() { + const D: usize = 8; + let prime = Q32_PRIMES[0]; + let tw = NttTwiddles::::compute(prime); + let p = prime.p as i64; + + fn pow_mod(mut base: i64, mut exp: i64, modulus: i64) -> i64 { + let mut acc = 1i64; + base %= modulus; + while exp > 0 { + if exp & 1 == 1 { + acc = (acc * base) % modulus; + } + base = (base * base) % modulus; + exp >>= 1; + } + acc + } + + // Compute canonical psi (primitive 2D-th root) directly. 
+ let half = (p - 1) / 2; + let exp = (p - 1) / (2 * D as i64); + let mut psi = None; + for a in 2..p { + if pow_mod(a, half, p) == p - 1 { + let cand = pow_mod(a, exp, p); + if pow_mod(cand, D as i64, p) == p - 1 { + psi = Some(cand); + break; + } + } + } + let psi = psi.expect("psi should exist"); + let a_canon: [i16; D] = std::array::from_fn(|i| ((i as i16 * 7) + 3) % prime.p); + + let mut expected = Vec::with_capacity(D); + for k in 0..D { + let alpha = pow_mod(psi, (2 * k + 1) as i64, p); + let mut acc = 0i64; + let mut power = 1i64; + for &ai in &a_canon { + acc = (acc + (ai as i64) * power) % p; + power = (power * alpha) % p; + } + expected.push(acc as i16); + } + expected.sort_unstable(); + + let mut a = std::array::from_fn(|i| prime.from_canonical(a_canon[i])); + forward_ntt(&mut a, prime, &tw); + let mut got: Vec = a + .iter() + .map(|x| prime.to_canonical(prime.normalize(*x))) + .collect(); + got.sort_unstable(); + + assert_eq!(got, expected); + } + + #[test] + fn negacyclic_ntt_mul_matches_schoolbook_single_prime_d64() { + const D: usize = 64; + let prime = Q32_PRIMES[0]; + let tw = NttTwiddles::::compute(prime); + let p = prime.p as i64; + + let a_canon: [i16; D] = std::array::from_fn(|i| ((i as i16 * 7) + 3) % prime.p); + let b_canon: [i16; D] = std::array::from_fn(|i| ((i as i16 * 5) + 11) % prime.p); + + let mut school = [0i16; D]; + for (i, &ai) in a_canon.iter().enumerate() { + for (j, &bj) in b_canon.iter().enumerate() { + let prod = (ai as i64 * bj as i64) % p; + let idx = i + j; + if idx < D { + school[idx] = ((school[idx] as i64 + prod) % p) as i16; + } else { + let k = idx - D; + school[k] = ((school[k] as i64 - prod) % p) as i16; + } + } + } + for x in &mut school { + if *x < 0 { + *x = (*x as i64 + p) as i16; + } + } + + let mut a = std::array::from_fn(|i| prime.from_canonical(a_canon[i])); + let mut b = std::array::from_fn(|i| prime.from_canonical(b_canon[i])); + forward_ntt(&mut a, prime, &tw); + forward_ntt(&mut b, prime, &tw); + + let 
mut c: [_; D] = std::array::from_fn(|i| prime.reduce_range(prime.mul(a[i], b[i]))); + inverse_ntt(&mut c, prime, &tw); + + let got: [i16; D] = std::array::from_fn(|i| prime.to_canonical(prime.normalize(c[i]))); + assert_eq!(got, school); + } + + #[test] + fn negacyclic_ntt_mul_matches_schoolbook_all_q32_primes_d64() { + const D: usize = 64; + let a_canon: [i16; D] = std::array::from_fn(|i| (i as i16 * 7 + 3)); + let b_canon: [i16; D] = std::array::from_fn(|i| (i as i16 * 5 + 11)); + + for (pi, &prime) in Q32_PRIMES.iter().enumerate() { + let tw = NttTwiddles::::compute(prime); + let p = prime.p as i64; + + let a_mod: [i16; D] = + std::array::from_fn(|i| ((a_canon[i] as i64).rem_euclid(p)) as i16); + let b_mod: [i16; D] = + std::array::from_fn(|i| ((b_canon[i] as i64).rem_euclid(p)) as i16); + + let mut school = [0i16; D]; + for (i, &ai) in a_mod.iter().enumerate() { + for (j, &bj) in b_mod.iter().enumerate() { + let prod = (ai as i64 * bj as i64) % p; + let idx = i + j; + if idx < D { + school[idx] = ((school[idx] as i64 + prod) % p) as i16; + } else { + let k = idx - D; + school[k] = ((school[k] as i64 - prod) % p) as i16; + } + } + } + for x in &mut school { + if *x < 0 { + *x = (*x as i64 + p) as i16; + } + } + + let mut a = std::array::from_fn(|i| prime.from_canonical(a_mod[i])); + let mut b = std::array::from_fn(|i| prime.from_canonical(b_mod[i])); + forward_ntt(&mut a, prime, &tw); + forward_ntt(&mut b, prime, &tw); + + let mut c = [MontCoeff::from_raw(0i16); D]; + for i in 0..D { + c[i] = prime.reduce_range(prime.mul(a[i], b[i])); + } + inverse_ntt(&mut c, prime, &tw); + + let got: [i16; D] = std::array::from_fn(|i| prime.to_canonical(prime.normalize(c[i]))); + assert_eq!(got, school, "prime[{pi}] p={} mismatch", prime.p); + } + } + + #[test] + fn cyclotomic_ntt_crt_round_trip_q32() { + type F = Fp64<{ Q32_MODULUS }>; + type R = CyclotomicRing; + type N = CyclotomicCrtNtt; + + let twiddles: [NttTwiddles; Q32_NUM_PRIMES] = + std::array::from_fn(|k| 
NttTwiddles::compute(Q32_PRIMES[k])); + + let coeffs: [F; 64] = + std::array::from_fn(|i| F::from_u64(((i as u64 * 17) + 5) % Q32_MODULUS)); + let ring = R::from_coefficients(coeffs); + let ntt = N::from_ring(&ring, &Q32_PRIMES, &twiddles); + let garner = q32_garner(); + let round_trip = ntt.to_ring(&Q32_PRIMES, &twiddles, &garner); + + assert_eq!(ring, round_trip); + } + + #[test] + fn cyclotomic_ntt_reduced_ops_are_stable() { + type F = Fp64<{ Q32_MODULUS }>; + type R = CyclotomicRing; + type N = CyclotomicCrtNtt; + + let twiddles: [NttTwiddles; Q32_NUM_PRIMES] = + std::array::from_fn(|k| NttTwiddles::compute(Q32_PRIMES[k])); + + let a = R::from_coefficients(std::array::from_fn(|i| { + F::from_u64(((i as u64 * 3) + 1) % Q32_MODULUS) + })); + let b = R::from_coefficients(std::array::from_fn(|i| { + F::from_u64(((i as u64 * 11) + 7) % Q32_MODULUS) + })); + + let ntt_a = N::from_ring(&a, &Q32_PRIMES, &twiddles); + let ntt_b = N::from_ring(&b, &Q32_PRIMES, &twiddles); + + let sum = ntt_a.add_reduced(&ntt_b, &Q32_PRIMES); + let back = sum.sub_reduced(&ntt_b, &Q32_PRIMES); + assert_eq!(back, ntt_a); + + let garner = q32_garner(); + let zero_ntt = ntt_a.add_reduced(&ntt_a.neg_reduced(&Q32_PRIMES), &Q32_PRIMES); + let zero_ring = zero_ntt.to_ring(&Q32_PRIMES, &twiddles, &garner); + assert_eq!(zero_ring, R::zero()); + } + + #[test] + fn backend_path_matches_default_scalar_path() { + type F = Fp64<{ Q32_MODULUS }>; + type R = CyclotomicRing; + type N = CyclotomicCrtNtt; + + let twiddles: [NttTwiddles; Q32_NUM_PRIMES] = + std::array::from_fn(|k| NttTwiddles::compute(Q32_PRIMES[k])); + let ring = R::from_coefficients(std::array::from_fn(|i| { + F::from_u64(((i as u64 * 13) + 9) % Q32_MODULUS) + })); + + let default_ntt = N::from_ring(&ring, &Q32_PRIMES, &twiddles); + let backend_ntt = + N::from_ring_with_backend::(&ring, &Q32_PRIMES, &twiddles); + assert_eq!(default_ntt, backend_ntt); + + let garner = q32_garner(); + let default_back = default_ntt.to_ring(&Q32_PRIMES, 
&twiddles, &garner); + let backend_back = + backend_ntt.to_ring_with_backend::(&Q32_PRIMES, &twiddles, &garner); + assert_eq!(default_back, backend_back); + } + + #[test] + fn crt_ntt_mul_matches_schoolbook_q32() { + type F = Fp64<{ Q32_MODULUS }>; + type R = CyclotomicRing; + type N = CyclotomicCrtNtt; + + let twiddles: [NttTwiddles; Q32_NUM_PRIMES] = + std::array::from_fn(|k| NttTwiddles::compute(Q32_PRIMES[k])); + let garner = q32_garner(); + + let a = R::from_coefficients(std::array::from_fn(|i| { + F::from_u64(((i as u64 * 7) + 3) % Q32_MODULUS) + })); + let b = R::from_coefficients(std::array::from_fn(|i| { + F::from_u64(((i as u64 * 5) + 11) % Q32_MODULUS) + })); + + let schoolbook = a * b; + + let ntt_a = N::from_ring(&a, &Q32_PRIMES, &twiddles); + let ntt_b = N::from_ring(&b, &Q32_PRIMES, &twiddles); + let ntt_prod = ntt_a.pointwise_mul(&ntt_b, &Q32_PRIMES); + let ntt_result: R = ntt_prod.to_ring(&Q32_PRIMES, &twiddles, &garner); + + assert_eq!(schoolbook, ntt_result); + } + + #[test] + fn q128_garner_reconstruct_matches_coeffs_no_ntt() { + type F = Fp128<{ Q128_MODULUS }>; + + let primes = q128_primes(); + let garner = q128_garner(); + + let coeffs: [F; 64] = std::array::from_fn(|i| { + if i < 8 { + F::from_u64((i as u64 * 31) + 7) + } else { + F::zero() + } + }); + + let mut canonical = [[0i32; 64]; Q128_NUM_PRIMES]; + for (k, prime) in primes.iter().enumerate() { + let p = prime.p as u32 as u128; + for (i, c) in coeffs.iter().enumerate() { + canonical[k][i] = (c.to_canonical_u128() % p) as i32; + } + } + + let reconstructed: [F; 64] = + >::reconstruct( + &primes, &canonical, &garner, + ); + + assert_eq!(reconstructed, coeffs); + } + + #[test] + fn q128_prime_ntt_round_trip_per_prime() { + let primes = q128_primes(); + let twiddles: [NttTwiddles; Q128_NUM_PRIMES] = + std::array::from_fn(|k| NttTwiddles::compute(primes[k])); + + // Use the same sparse coefficient pattern as q128_ntt_round_trip, but test + // the per-prime NTT+Montgomery machinery in 
isolation (no Garner/Fp128). + let residues: [u32; 64] = + std::array::from_fn(|i| if i < 8 { (i as u32 * 31) + 7 } else { 0 }); + + for k in 0..Q128_NUM_PRIMES { + let prime = primes[k]; + let mut limb = [MontCoeff::from_raw(0i32); 64]; + for (i, r) in residues.iter().enumerate() { + let reduced = (*r as i64 % (prime.p as i64)) as i32; + limb[i] = >::from_canonical(prime, reduced); + } + + forward_ntt(&mut limb, prime, &twiddles[k]); + inverse_ntt(&mut limb, prime, &twiddles[k]); + + for (i, r) in residues.iter().enumerate() { + let expected = (*r as i64 % (prime.p as i64)) as i32; + let got = >::to_canonical(prime, limb[i]); + assert_eq!(got, expected, "prime idx={k} coeff idx={i}"); + } + } + } + + #[test] + fn q128_ntt_round_trip() { + type F = Fp128<{ Q128_MODULUS }>; + type R = CyclotomicRing; + type N = CyclotomicCrtNtt; + + let primes = q128_primes(); + let twiddles: [NttTwiddles; Q128_NUM_PRIMES] = + std::array::from_fn(|k| NttTwiddles::compute(primes[k])); + let garner = q128_garner(); + + let coeffs: [F; 64] = std::array::from_fn(|i| { + if i < 8 { + F::from_u64((i as u64 * 31) + 7) + } else { + F::zero() + } + }); + let ring = R::from_coefficients(coeffs); + let ntt = N::from_ring(&ring, &primes, &twiddles); + let round_trip: R = ntt.to_ring(&primes, &twiddles, &garner); + + assert_eq!(ring, round_trip); + } + + #[test] + fn crt_ntt_mul_matches_schoolbook_q128() { + type F = Fp128<{ Q128_MODULUS }>; + type R = CyclotomicRing; + type N = CyclotomicCrtNtt; + + let primes = q128_primes(); + let twiddles: [NttTwiddles; Q128_NUM_PRIMES] = + std::array::from_fn(|k| NttTwiddles::compute(primes[k])); + let garner = q128_garner(); + + let a = R::from_coefficients(std::array::from_fn(|i| { + if i < 8 { + F::from_u64((i as u64 * 7) + 3) + } else { + F::zero() + } + })); + let b = R::from_coefficients(std::array::from_fn(|i| { + if i < 8 { + F::from_u64((i as u64 * 9) + 11) + } else { + F::zero() + } + })); + + let schoolbook = a * b; + + let ntt_a = 
N::from_ring(&a, &primes, &twiddles); + let ntt_b = N::from_ring(&b, &primes, &twiddles); + let ntt_prod = ntt_a.pointwise_mul(&ntt_b, &primes); + let ntt_result: R = ntt_prod.to_ring(&primes, &twiddles, &garner); + + assert_eq!(schoolbook, ntt_result); + } + + #[test] + fn partial_split_forward_matches_direct_eval_q128m5823() { + type F = Prime128Offset5823; + + fn eval_poly(coeffs: &[F; 32], x: F) -> F { + coeffs + .iter() + .rev() + .fold(F::zero(), |acc, coeff| acc * x + *coeff) + } + + let split = PartialSplitNtt32::::compute(); + let coeffs: [F; 32] = std::array::from_fn(|i| { + let centered = ((i as i64 * 19) % 29) - 14; + F::from_i64(centered) + }); + + let mut got = coeffs; + split.forward_class(&mut got); + + let expected: [F; 32] = std::array::from_fn(|i| eval_poly(&coeffs, split.eval_roots()[i])); + assert_eq!(got, expected); + } + + #[test] + fn partial_split_mul_matches_schoolbook_q128m5823() { + type F = Prime128Offset5823; + type R = CyclotomicRing; + + let split = PartialSplitNtt32::::compute(); + let a = R::from_coefficients(std::array::from_fn(|i| { + let centered = ((i as i64 * 7 + 3) % 41) - 20; + F::from_i64(centered) + })); + let b = R::from_coefficients(std::array::from_fn(|i| { + let centered = ((i as i64 * 11 + 5) % 37) - 18; + F::from_i64(centered) + })); + + let schoolbook = a * b; + let split_result = split.multiply_d64(&a, &b); + + assert_eq!(schoolbook, split_result); + } + + #[test] + fn partial_split_matches_crt_mul_q128m5823() { + type F = Prime128Offset5823; + type R = CyclotomicRing; + type N = CyclotomicCrtNtt; + + let split = PartialSplitNtt32::::compute(); + let params = CrtNttParamSet::new(q128_primes()); + + let a = R::from_coefficients(std::array::from_fn(|i| { + let centered = ((i as i64 * 13 + 1) % 33) - 16; + F::from_i64(centered) + })); + let b = R::from_coefficients(std::array::from_fn(|i| { + let centered = ((i as i64 * 9 + 7) % 35) - 17; + F::from_i64(centered) + })); + + let split_result = split.multiply_d64(&a, 
&b); + + let ntt_a = N::from_ring_with_params(&a, ¶ms); + let ntt_b = N::from_ring_with_params(&b, ¶ms); + let crt_result: R = ntt_a + .pointwise_mul_with_params(&ntt_b, ¶ms) + .to_ring_with_params(¶ms); + + assert_eq!(split_result, crt_result); + } + + #[test] + fn partial_split_mul_centered_i8_matches_schoolbook_q128m5823() { + type F = Prime128Offset5823; + type R = CyclotomicRing; + + let split = PartialSplitNtt32::::compute(); + let lhs = R::from_coefficients(std::array::from_fn(|i| { + let centered = ((i as i64 * 7 + 3) % 41) - 20; + F::from_i64(centered) + })); + let rhs_i8: [i8; 64] = std::array::from_fn(|i| (((i * 23 + 11) % 256) as i16 - 128) as i8); + let rhs = R::from_coefficients(std::array::from_fn(|i| F::from_i8(rhs_i8[i]))); + + let schoolbook = lhs * rhs; + let split_result = split.multiply_d64_rhs_i8(&lhs, &rhs_i8); + + assert_eq!(schoolbook, split_result); + } + + #[test] + fn partial_split_mul_centered_i8_matches_crt_q128m5823() { + type F = Prime128Offset5823; + type R = CyclotomicRing; + type N = CyclotomicCrtNtt; + + let split = PartialSplitNtt32::::compute(); + let params = CrtNttParamSet::new(q128_primes()); + let lhs = R::from_coefficients(std::array::from_fn(|i| { + let centered = ((i as i64 * 13 + 1) % 33) - 16; + F::from_i64(centered) + })); + let rhs_i8: [i8; 64] = std::array::from_fn(|i| (((i * 19 + 5) % 256) as i16 - 128) as i8); + + let split_result = split.multiply_d64_rhs_i8(&lhs, &rhs_i8); + let crt_result: R = N::from_ring_with_params(&lhs, ¶ms) + .pointwise_mul_with_params(&N::from_i8_with_params(&rhs_i8, ¶ms), ¶ms) + .to_ring_with_params(¶ms); + + assert_eq!(split_result, crt_result); + } + + #[test] + fn partial_split_repr_round_trip_q128m5823() { + type F = Prime128Offset5823; + type R = CyclotomicRing; + + let split = PartialSplitNtt32::::compute(); + let ring = R::from_coefficients(std::array::from_fn(|i| { + let centered = ((i as i64 * 17 + 9) % 53) - 26; + F::from_i64(centered) + })); + + let eval = 
PartialSplitEval32::from_ring(&split, &ring); + let back = eval.to_ring(&split); + + assert_eq!(ring, back); + } + + #[test] + fn partial_split_repr_cached_product_matches_schoolbook_q128m5823() { + type F = Prime128Offset5823; + type R = CyclotomicRing; + + let split = PartialSplitNtt32::::compute(); + let lhs = R::from_coefficients(std::array::from_fn(|i| { + let centered = ((i as i64 * 5 + 7) % 47) - 23; + F::from_i64(centered) + })); + let rhs = R::from_coefficients(std::array::from_fn(|i| { + let centered = ((i as i64 * 11 + 3) % 43) - 21; + F::from_i64(centered) + })); + + let lhs_eval = PartialSplitEval32::from_ring(&split, &lhs); + let rhs_eval = PartialSplitEval32::from_ring(&split, &rhs); + let cached = lhs_eval.pointwise_mul(&rhs_eval, &split).to_ring(&split); + + assert_eq!(cached, lhs * rhs); + } + + #[test] + fn partial_split_cyclic_mul_matches_schoolbook_q128m5823() { + type F = Prime128Offset5823; + type R = CyclotomicRing; + + let split = PartialSplitNtt32::::compute(); + let lhs = R::from_coefficients(std::array::from_fn(|i| { + let centered = ((i as i64 * 5 + 7) % 47) - 23; + F::from_i64(centered) + })); + let rhs = R::from_coefficients(std::array::from_fn(|i| { + let centered = ((i as i64 * 11 + 3) % 43) - 21; + F::from_i64(centered) + })); + + let mut school = [F::zero(); 64]; + for i in 0..64 { + for j in 0..64 { + school[(i + j) % 64] += lhs.coefficients()[i] * rhs.coefficients()[j]; + } + } + + assert_eq!(split.multiply_cyclic_d64(&lhs, &rhs), school); + } + + #[test] + fn partial_split_quotient_matches_schoolbook_high_half_q128m5823() { + type F = Prime128Offset5823; + type R = CyclotomicRing; + + let split = PartialSplitNtt32::::compute(); + let lhs = R::from_coefficients(std::array::from_fn(|i| { + let centered = ((i as i64 * 7 + 5) % 41) - 20; + F::from_i64(centered) + })); + let rhs = R::from_coefficients(std::array::from_fn(|i| { + let centered = ((i as i64 * 9 + 1) % 39) - 19; + F::from_i64(centered) + })); + + let mut high = 
[F::zero(); 64]; + for i in 0..64 { + for j in 0..64 { + let idx = i + j; + if idx >= 64 { + high[idx - 64] += lhs.coefficients()[i] * rhs.coefficients()[j]; + } + } + } + + let quotient = split.unreduced_quotient_d64(&lhs, &rhs); + assert_eq!(quotient.coefficients(), &high); + } + + #[test] + fn partial_split_cached_matvec_matches_schoolbook_q128m5823() { + type F = Prime128Offset5823; + type R = CyclotomicRing; + + const ROWS: usize = 3; + const COLS: usize = 5; + + let split = PartialSplitNtt32::::compute(); + let matrix: Vec> = (0..ROWS) + .map(|r| { + (0..COLS) + .map(|c| { + R::from_coefficients(std::array::from_fn(|i| { + let centered = ((i as i64 * 7 + (11 * r + 5 * c) as i64) % 37) - 18; + F::from_i64(centered) + })) + }) + .collect() + }) + .collect(); + let vector: Vec = (0..COLS) + .map(|c| { + R::from_coefficients(std::array::from_fn(|i| { + let centered = ((i as i64 * 13 + (9 * c) as i64) % 41) - 20; + F::from_i64(centered) + })) + }) + .collect(); + + let matrix_eval: Vec>> = matrix + .iter() + .map(|row| { + row.iter() + .map(|ring| PartialSplitEval32::from_ring(&split, ring)) + .collect() + }) + .collect(); + let vector_eval: Vec> = vector + .iter() + .map(|ring| PartialSplitEval32::from_ring(&split, ring)) + .collect(); + + let got: Vec = (0..ROWS) + .map(|r| { + let mut acc = PartialSplitEval32::zero(); + for (mat_entry, vec_entry) in matrix_eval[r].iter().zip(vector_eval.iter()) { + acc.add_mul_assign(mat_entry, vec_entry, &split); + } + acc.to_ring(&split) + }) + .collect(); + + let expected: Vec = (0..ROWS) + .map(|r| { + let mut acc = R::zero(); + for (mat_entry, vec_entry) in matrix[r].iter().zip(vector.iter()) { + acc += *mat_entry * *vec_entry; + } + acc + }) + .collect(); + + assert_eq!(got, expected); + } + + #[test] + fn partial_split_packed_cached_matvec_matches_scalar_q128m5823() { + type F = Prime128Offset5823; + type PF = ::Packing; + type R = CyclotomicRing; + + let rows = PackedPartialSplitEval32::::WIDTH + 3; + let cols = 5usize; 
+ + let split = PartialSplitNtt32::::compute(); + let packed = split.packed::(); + let matrix: Vec> = (0..rows) + .map(|r| { + (0..cols) + .map(|c| { + R::from_coefficients(std::array::from_fn(|i| { + let centered = ((i as i64 * 7 + (11 * r + 5 * c) as i64) % 37) - 18; + F::from_i64(centered) + })) + }) + .collect() + }) + .collect(); + let vector: Vec = (0..cols) + .map(|c| { + R::from_coefficients(std::array::from_fn(|i| { + let centered = ((i as i64 * 13 + (9 * c) as i64) % 41) - 20; + F::from_i64(centered) + })) + }) + .collect(); + + let matrix_eval: Vec>> = matrix + .iter() + .map(|row| { + row.iter() + .map(|ring| PartialSplitEval32::from_ring(&split, ring)) + .collect() + }) + .collect(); + let vector_eval: Vec> = vector + .iter() + .map(|ring| PartialSplitEval32::from_ring(&split, ring)) + .collect(); + let vector_packed: Vec> = vector_eval + .iter() + .map(PackedPartialSplitEval32::::broadcast) + .collect(); + + let mut got = Vec::with_capacity(rows); + let mut row_chunks = matrix_eval.chunks_exact(PackedPartialSplitEval32::::WIDTH); + for row_chunk in row_chunks.by_ref() { + let packed_row: Vec> = (0..cols) + .map(|c| PackedPartialSplitEval32::::from_fn(|lane| row_chunk[lane][c])) + .collect(); + let mut acc = PackedPartialSplitEval32::::zero(); + for (mat_entry, vec_entry) in packed_row.iter().zip(vector_packed.iter()) { + packed.add_mul_assign(&mut acc, mat_entry, vec_entry); + } + packed.append_rings(&acc, &mut got); + } + for row in row_chunks.remainder() { + let mut acc = PartialSplitEval32::zero(); + for (mat_entry, vec_entry) in row.iter().zip(vector_eval.iter()) { + acc.add_mul_assign(mat_entry, vec_entry, &split); + } + got.push(acc.to_ring(&split)); + } + + let expected: Vec = (0..rows) + .map(|r| { + let mut acc = R::zero(); + for (mat_entry, vec_entry) in matrix[r].iter().zip(vector.iter()) { + acc += *mat_entry * *vec_entry; + } + acc + }) + .collect(); + + assert_eq!(got, expected); + } + + #[test] + fn 
crt_add_assign_pointwise_mul_matches_scalar_q128m5823() { + type F = Prime128Offset5823; + type R = CyclotomicRing; + type N = CyclotomicCrtNtt; + + let params = CrtNttParamSet::new(q128_primes()); + let acc0 = N::from_ring_with_params( + &R::from_coefficients(std::array::from_fn(|i| { + F::from_i64(((i as i64 * 5 + 1) % 31) - 15) + })), + ¶ms, + ); + let lhs = N::from_ring_with_params( + &R::from_coefficients(std::array::from_fn(|i| { + F::from_i64(((i as i64 * 7 + 3) % 37) - 18) + })), + ¶ms, + ); + let rhs = N::from_ring_with_params( + &R::from_coefficients(std::array::from_fn(|i| { + F::from_i64(((i as i64 * 11 + 9) % 41) - 20) + })), + ¶ms, + ); + + let mut got = acc0.clone(); + got.add_assign_pointwise_mul_with_params(&lhs, &rhs, ¶ms); + + let expected = + acc0.add_reduced_with_params(&lhs.pointwise_mul_with_params(&rhs, ¶ms), ¶ms); + + assert_eq!(got, expected); + } + + #[test] + fn q64_ntt_round_trip() { + type F = Fp64<{ Q64_MODULUS }>; + type R = CyclotomicRing; + type N = CyclotomicCrtNtt; + + let primes = q64_primes(); + let twiddles: [NttTwiddles; Q64_NUM_PRIMES] = + std::array::from_fn(|k| NttTwiddles::compute(primes[k])); + let garner = q64_garner(); + + let coeffs: [F; 64] = + std::array::from_fn(|i| F::from_u64(((i as u64 * 19) + 3) % Q64_MODULUS)); + let ring = R::from_coefficients(coeffs); + let ntt = N::from_ring(&ring, &primes, &twiddles); + let round_trip: R = ntt.to_ring(&primes, &twiddles, &garner); + + assert_eq!(ring, round_trip); + } + + #[test] + fn crt_ntt_mul_matches_schoolbook_q64() { + type F = Fp64<{ Q64_MODULUS }>; + type R = CyclotomicRing; + type N = CyclotomicCrtNtt; + + let primes = q64_primes(); + let twiddles: [NttTwiddles; Q64_NUM_PRIMES] = + std::array::from_fn(|k| NttTwiddles::compute(primes[k])); + let garner = q64_garner(); + + let a = R::from_coefficients(std::array::from_fn(|i| { + F::from_u64(((i as u64 * 5) + 9) % Q64_MODULUS) + })); + let b = R::from_coefficients(std::array::from_fn(|i| { + F::from_u64(((i as u64 * 
17) + 13) % Q64_MODULUS) + })); + + let schoolbook = a * b; + + let ntt_a = N::from_ring(&a, &primes, &twiddles); + let ntt_b = N::from_ring(&b, &primes, &twiddles); + let ntt_prod = ntt_a.pointwise_mul(&ntt_b, &primes); + let ntt_result: R = ntt_prod.to_ring(&primes, &twiddles, &garner); + + assert_eq!(schoolbook, ntt_result); + } + + #[test] + fn field_sampling_respects_modulus() { + type F = Fp32<251>; + let mut rng = StdRng::seed_from_u64(42); + for _ in 0..1024 { + let x = F::sample(&mut rng); + assert!(x.to_canonical_u32() < 251); + } + } + + #[test] + fn pow2_offset_registry_is_consistent() { + fn assert_is_pseudo_mersenne() {} + assert_is_pseudo_mersenne::(); + + for Pow2OffsetPrimeSpec { + bits, + offset, + modulus, + .. + } in POW2_OFFSET_PRIMES + { + assert!((offset as u128) <= POW2_OFFSET_MAX); + assert_eq!(POW2_OFFSET_TABLE[bits as usize], offset as i16); + assert_eq!( + Some(modulus), + pseudo_mersenne_modulus(bits, offset as u128), + "2^k-offset modulus mismatch for k={bits}, offset={offset}" + ); + assert_eq!(modulus % 8, 5); + } + + let x = Pow2Offset128Field::from_u64(1234567); + let inv = x.inv().unwrap(); + assert_eq!(x * inv, Pow2Offset128Field::one()); + } + + #[test] + fn cyclotomic_sigma_is_ring_automorphism() { + type F = Fp32<251>; + type R = CyclotomicRing; + let a = R::from_coefficients(std::array::from_fn(|i| F::from_u64((3 * i + 1) as u64))); + let b = R::from_coefficients(std::array::from_fn(|i| F::from_u64((5 * i + 2) as u64))); + + let k1 = 3usize; + let k2 = 5usize; + let two_d = 16usize; + + assert_eq!(a.sigma(1), a); + assert_eq!(a.sigma_m1().sigma_m1(), a); + assert_eq!(a.sigma(k1).sigma(k2), a.sigma((k1 * k2) % two_d)); + assert_eq!((a * b).sigma(k1), a.sigma(k1) * b.sigma(k1)); + } + + #[test] + fn cyclotomic_balanced_pow2_decompose_recompose_round_trip() { + type F = Fp64<{ Q32_MODULUS }>; + type R = CyclotomicRing; + + let ring = R::from_coefficients(std::array::from_fn(|i| { + F::from_u64(((i as u64 * 73) + 17) % 
Q32_MODULUS) + })); + + // Q32 balanced base-16: 9 levels absorb the carry-out near q/2. + let digits = ring.balanced_decompose_pow2(9, 4); + let round_trip = R::gadget_recompose_pow2(&digits, 4); + assert_eq!(round_trip, ring); + } + + #[test] + fn sparse_pm1_challenge_has_expected_weight() { + type F = Fp32<251>; + type R = CyclotomicRing; + + let mut rng = StdRng::seed_from_u64(123); + let challenge = R::sample_sparse_pm1(&mut rng, 11); + assert_eq!(challenge.hamming_weight(), 11); + + for c in challenge.coefficients() { + let x = c.to_canonical_u32(); + if x != 0 { + assert!(x == 1 || x == 250, "nonzero coefficient must be +/-1"); + } + } + } + + #[test] + fn negacyclic_shift_equals_mul_by_monomial() { + type F = Fp32<251>; + type R = CyclotomicRing; + + let a = R::from_coefficients(std::array::from_fn(|i| F::from_u64((3 * i + 1) as u64))); + + for k in 0..8 { + let mut monomial_coeffs = [F::zero(); 8]; + monomial_coeffs[k] = F::one(); + let monomial = R::from_coefficients(monomial_coeffs); + assert_eq!( + a.negacyclic_shift(k), + a * monomial, + "negacyclic_shift({k}) != mul by X^{k}" + ); + } + + assert_eq!(a.negacyclic_shift(0), a); + assert_eq!( + a.negacyclic_shift(8), + a, + "shift by D should be identity mod D" + ); + } + + #[test] + fn negacyclic_shift_degree_64() { + type F = Fp64<4294967197>; + type R = CyclotomicRing; + + let a = R::from_coefficients(std::array::from_fn(|i| F::from_u64((7 * i + 3) as u64))); + let x = R::x(); + let mut x_pow = R::one(); + for k in 0..64 { + assert_eq!( + a.negacyclic_shift(k), + a * x_pow, + "negacyclic_shift({k}) mismatch at D=64" + ); + x_pow *= x; + } + } + + #[test] + fn mul_by_monomial_sum_matches_ring_mul() { + type F = Fp32<251>; + type R = CyclotomicRing; + + let a = R::from_coefficients(std::array::from_fn(|i| F::from_u64((5 * i + 2) as u64))); + + // Sum of X^1 + X^3 + X^5 + let positions = [1, 3, 5]; + let mut sparse = [F::zero(); 8]; + for &p in &positions { + sparse[p] = F::one(); + } + let sparse_ring = 
R::from_coefficients(sparse); + + assert_eq!( + a.mul_by_monomial_sum(&positions), + a * sparse_ring, + "mul_by_monomial_sum should equal ring mul by sparse element" + ); + } + + #[test] + fn mul_by_monomial_sum_single_position_equals_shift() { + type F = Fp32<251>; + type R = CyclotomicRing; + + let a = R::from_coefficients(std::array::from_fn(|i| F::from_u64((i + 1) as u64))); + for k in 0..8 { + assert_eq!( + a.mul_by_monomial_sum(&[k]), + a.negacyclic_shift(k), + "single-position monomial_sum should equal negacyclic_shift" + ); + } + } + + #[test] + fn mul_by_monomial_sum_empty_is_zero() { + type F = Fp32<251>; + type R = CyclotomicRing; + + let a = R::from_coefficients(std::array::from_fn(|i| F::from_u64((i + 1) as u64))); + assert_eq!(a.mul_by_monomial_sum(&[]), R::zero()); + } + + #[test] + fn mul_by_sparse_matches_schoolbook() { + use hachi_pcs::algebra::SparseChallenge; + + type F = Fp64<4294967197>; + type R = CyclotomicRing; + + let a = R::from_coefficients(std::array::from_fn(|i| F::from_u64((3 * i + 7) as u64))); + + let challenge = SparseChallenge { + positions: vec![2, 17, 41], + coeffs: vec![1, -1, 1], + }; + let dense: R = challenge.to_dense().unwrap(); + + let via_sparse = a.mul_by_sparse(&challenge); + let via_schoolbook = a * dense; + + assert_eq!( + via_sparse, via_schoolbook, + "mul_by_sparse must equal schoolbook multiplication" + ); + } + + #[test] + fn mul_by_sparse_with_all_negative_coeffs() { + use hachi_pcs::algebra::SparseChallenge; + + type F = Fp64<4294967197>; + type R = CyclotomicRing; + + let a = R::from_coefficients(std::array::from_fn(|i| F::from_u64((i + 1) as u64))); + + let challenge = SparseChallenge { + positions: vec![0, 5, 63], + coeffs: vec![-1, -1, -1], + }; + let dense: R = challenge.to_dense().unwrap(); + + assert_eq!(a.mul_by_sparse(&challenge), a * dense); + } + + #[test] + fn is_zero_detects_zero_and_nonzero() { + type F = Fp32<251>; + type R = CyclotomicRing; + + assert!(R::zero().is_zero()); + 
assert!(!R::one().is_zero()); + + let a = R::from_coefficients(std::array::from_fn(|i| F::from_u64(i as u64))); + assert!(!a.is_zero()); + } + + #[test] + fn kron_scalars_matches_kron_row_constant_rings() { + type F = Fp64<4294967197>; + type R = CyclotomicRing; + + let scalars_a: Vec = (0..4).map(|i| F::from_u64(i * 3 + 1)).collect(); + let scalars_b: Vec = (0..3).map(|i| F::from_u64(i * 7 + 2)).collect(); + + let rings_a: Vec = scalars_a + .iter() + .map(|&s| { + let mut c = [F::zero(); 16]; + c[0] = s; + R::from_coefficients(c) + }) + .collect(); + let rings_b: Vec = scalars_b + .iter() + .map(|&s| { + let mut c = [F::zero(); 16]; + c[0] = s; + R::from_coefficients(c) + }) + .collect(); + + let via_ring: Vec = rings_a + .iter() + .flat_map(|l| rings_b.iter().map(move |r| *l * *r)) + .collect(); + + let via_scalar: Vec = scalars_a + .iter() + .flat_map(|&l| { + scalars_b.iter().map(move |&r| { + let mut c = [F::zero(); 16]; + c[0] = l * r; + R::from_coefficients(c) + }) + }) + .collect(); + + assert_eq!(via_ring, via_scalar); + } +} diff --git a/tests/commitment_contract.rs b/tests/commitment_contract.rs new file mode 100644 index 00000000..a6151581 --- /dev/null +++ b/tests/commitment_contract.rs @@ -0,0 +1,210 @@ +#![allow(missing_docs)] + +use hachi_pcs::algebra::CyclotomicRing; +use hachi_pcs::algebra::Fp64; +use hachi_pcs::algebra::SparseChallenge; +use hachi_pcs::protocol::commitment::utils::crt_ntt::NttSlotCache; +use hachi_pcs::protocol::commitment::utils::flat_matrix::FlatMatrix; +use hachi_pcs::protocol::commitment::{DummyProof, HachiCommitment}; +use hachi_pcs::protocol::hachi_poly_ops::{DecomposeFoldWitness, HachiPolyOps}; +use hachi_pcs::protocol::transcript::labels; +use hachi_pcs::protocol::{ + AppendToTranscript, BasisMode, Blake2bTranscript, CommitmentScheme, HachiCommitmentLayout, + Transcript, +}; +use hachi_pcs::{CanonicalField, FieldCore, FromSmallInt, HachiError}; + +type F = Fp64<4294967197>; + +/// Trivial polynomial wrapper that 
implements `HachiPolyOps`. +#[derive(Debug, Clone)] +struct DummyPoly { + coeffs: Vec, +} + +impl DummyPoly { + fn evaluate(&self, point: &[F]) -> F { + assert_eq!(point.len(), self.num_vars()); + let mut acc = self.coeffs[0]; + for (i, r_i) in point.iter().enumerate() { + acc += self.coeffs[i + 1] * *r_i; + } + acc + } + + fn num_vars(&self) -> usize { + self.coeffs.len().saturating_sub(1) + } +} + +impl HachiPolyOps for DummyPoly { + type CommitCache = NttSlotCache<1>; + + fn num_ring_elems(&self) -> usize { + self.coeffs.len() + } + + fn evaluate_ring(&self, scalars: &[F]) -> CyclotomicRing { + let mut acc = F::zero(); + for (c, &s) in self.coeffs.iter().zip(scalars.iter()) { + acc += *c * s; + } + CyclotomicRing::from_coefficients([acc]) + } + + fn fold_blocks(&self, _scalars: &[F], _block_len: usize) -> Vec> { + vec![] + } + + fn decompose_fold( + &self, + _challenges: &[SparseChallenge], + _block_len: usize, + _num_digits: usize, + _log_basis: u32, + ) -> DecomposeFoldWitness { + DecomposeFoldWitness { + z_pre: vec![], + centered_coeffs: vec![], + centered_inf_norm: 0, + } + } + + fn commit_inner( + &self, + _a_matrix: &FlatMatrix, + _ntt_a: &NttSlotCache<1>, + _block_len: usize, + _num_digits_commit: usize, + _num_digits_open: usize, + _log_basis: u32, + ) -> Result>, HachiError> { + Ok(vec![]) + } +} + +#[derive(Clone)] +struct DummySetup { + _max_num_vars: usize, +} + +#[derive(Clone)] +struct DummyScheme; + +impl CommitmentScheme for DummyScheme { + type ProverSetup = DummySetup; + type VerifierSetup = DummySetup; + type Commitment = HachiCommitment; + type Proof = DummyProof; + type CommitHint = HachiCommitment; + + fn setup_prover(max_num_vars: usize) -> Self::ProverSetup { + DummySetup { + _max_num_vars: max_num_vars, + } + } + + fn setup_verifier(setup: &Self::ProverSetup) -> Self::VerifierSetup { + setup.clone() + } + + fn commit>( + _poly: &P, + _setup: &Self::ProverSetup, + _layout: &HachiCommitmentLayout, + ) -> Result<(Self::Commitment, 
Self::CommitHint), HachiError> { + let c = HachiCommitment(0); + Ok((c, c)) + } + + fn prove, P: HachiPolyOps>( + _setup: &Self::ProverSetup, + _poly: &P, + _opening_point: &[F], + _hint: Self::CommitHint, + transcript: &mut T, + commitment: &Self::Commitment, + _basis: BasisMode, + _layout: &HachiCommitmentLayout, + ) -> Result { + commitment.append_to_transcript(labels::ABSORB_COMMITMENT, transcript); + let q = transcript.challenge_scalar(labels::CHALLENGE_LINEAR_RELATION); + Ok(DummyProof(q.to_canonical_u128())) + } + + fn verify>( + proof: &Self::Proof, + _setup: &Self::VerifierSetup, + transcript: &mut T, + _opening_point: &[F], + _opening: &F, + commitment: &Self::Commitment, + _basis: BasisMode, + _layout: &HachiCommitmentLayout, + ) -> Result<(), HachiError> { + commitment.append_to_transcript(labels::ABSORB_COMMITMENT, transcript); + let q = transcript.challenge_scalar(labels::CHALLENGE_LINEAR_RELATION); + if proof.0 == q.to_canonical_u128() { + Ok(()) + } else { + Err(HachiError::InvalidProof) + } + } + + fn protocol_name() -> &'static [u8] { + b"HachiDummy" + } +} + +#[test] +fn commitment_scheme_round_trip() { + let poly = DummyPoly { + coeffs: vec![F::from_u64(3), F::from_u64(5), F::from_u64(7)], + }; + let opening_point = [F::from_u64(11), F::from_u64(13)]; + + let psetup = DummyScheme::setup_prover(poly.num_vars()); + let vsetup = DummyScheme::setup_verifier(&psetup); + + let layout = HachiCommitmentLayout { + m_vars: 0, + r_vars: 0, + block_len: 1, + num_blocks: 1, + num_digits_commit: 1, + num_digits_open: 1, + num_digits_fold: 1, + inner_width: 1, + outer_width: 1, + d_matrix_width: 1, + log_basis: 1, + }; + let (commitment, hint) = DummyScheme::commit(&poly, &psetup, &layout).unwrap(); + let opening = poly.evaluate(&opening_point); + + let mut prover_t = Blake2bTranscript::::new(labels::DOMAIN_HACHI_PROTOCOL); + let proof = DummyScheme::prove( + &psetup, + &poly, + &opening_point, + hint, + &mut prover_t, + &commitment, + BasisMode::Lagrange, + 
&layout, + ) + .unwrap(); + + let mut verifier_t = Blake2bTranscript::::new(labels::DOMAIN_HACHI_PROTOCOL); + DummyScheme::verify( + &proof, + &vsetup, + &mut verifier_t, + &opening_point, + &opening, + &commitment, + BasisMode::Lagrange, + &layout, + ) + .unwrap(); +} diff --git a/tests/hachi_e2e.rs b/tests/hachi_e2e.rs new file mode 100644 index 00000000..6a811bd2 --- /dev/null +++ b/tests/hachi_e2e.rs @@ -0,0 +1,318 @@ +#![allow(missing_docs)] + +use hachi_pcs::algebra::Fp128; +use hachi_pcs::protocol::commitment::{Fp128FullCommitmentConfig, Fp128OneHotCommitmentConfig}; +use hachi_pcs::protocol::commitment_scheme::HachiCommitmentScheme; +use hachi_pcs::protocol::hachi_poly_ops::{DensePoly, HachiPolyOps, OneHotPoly}; +use hachi_pcs::protocol::opening_point::{ + reduce_inner_opening_to_ring_element, ring_opening_point_from_field, +}; +use hachi_pcs::protocol::transcript::Blake2bTranscript; +use hachi_pcs::protocol::CommitmentConfig; +use hachi_pcs::{BasisMode, CanonicalField, CommitmentScheme, Transcript}; +use rand::rngs::StdRng; +use rand::{Rng, SeedableRng}; +use std::sync::{Mutex, Once}; +use std::time::Instant; + +type F = Fp128<0xfffffffffffffffffffffffffffffeed>; +const ONEHOT_K: usize = 256; +// Keep the default e2e tests small enough for `cargo test`; the larger nv=25 +// workloads remain covered by `benches/hachi_e2e.rs`, while still triggering +// the standard Labrador handoff path. +const FULL_TEST_NV: usize = 14; +// The one-hot witness grows much faster than the dense path, so use a smaller +// default size here while still exercising the standard Labrador handoff. 
+const ONEHOT_TEST_NV: usize = 15; +const STACK_SIZE: usize = 256 * 1024 * 1024; + +static INIT_RAYON: Once = Once::new(); +static E2E_TEST_LOCK: Mutex<()> = Mutex::new(()); + +fn init_rayon_pool() { + INIT_RAYON.call_once(|| { + #[cfg(feature = "parallel")] + rayon::ThreadPoolBuilder::new() + .stack_size(STACK_SIZE) + .build_global() + .ok(); + }); +} + +fn random_point(nv: usize) -> Vec { + let mut rng = StdRng::seed_from_u64(0xcafe_babe); + (0..nv) + .map(|_| F::from_canonical_u128_reduced(rng.gen::())) + .collect() +} + +fn run_on_large_stack(f: impl FnOnce() + Send + 'static) { + std::thread::Builder::new() + .stack_size(STACK_SIZE) + .spawn(f) + .expect("failed to spawn thread") + .join() + .expect("test thread panicked"); +} + +/// Remove any stale disk-persistence cache for `max_num_vars` so that a setup +/// written by a different `CommitmentConfig` doesn't get loaded by mistake. +#[cfg(feature = "disk-persistence")] +fn purge_setup_cache(max_num_vars: usize) { + let cache_dir = std::env::var("LOCALAPPDATA") + .map(std::path::PathBuf::from) + .or_else(|_| { + std::env::var("HOME").map(|home| { + let mut p = std::path::PathBuf::from(&home); + if p.join("Library/Caches").exists() { + p.push("Library/Caches"); + } else { + p.push(".cache"); + } + p + }) + }); + if let Ok(mut path) = cache_dir { + path.push("hachi"); + if let Ok(entries) = std::fs::read_dir(&path) { + let needle = format!("_nv{max_num_vars}.setup"); + for entry in entries.flatten() { + let entry_path = entry.path(); + if entry_path + .file_name() + .and_then(|name| name.to_str()) + .is_some_and(|name| name.starts_with("hachi_") && name.ends_with(&needle)) + { + let _ = std::fs::remove_file(entry_path); + } + } + } + } +} + +fn opening_from_poly>( + poly: &P, + point: &[F], + layout: &hachi_pcs::protocol::commitment::HachiCommitmentLayout, +) -> F { + let alpha_bits = D.trailing_zeros() as usize; + assert_eq!(point.len(), alpha_bits + layout.m_vars + layout.r_vars); + + let inner_point = 
&point[..alpha_bits]; + let reduced_point = &point[alpha_bits..]; + let ring_opening_point = ring_opening_point_from_field( + reduced_point, + layout.r_vars, + layout.m_vars, + BasisMode::Lagrange, + ) + .expect("opening point shape should match layout"); + + let (y_ring, _) = poly.evaluate_and_fold( + &ring_opening_point.b, + &ring_opening_point.a, + layout.block_len, + ); + let v = reduce_inner_opening_to_ring_element::(inner_point, BasisMode::Lagrange) + .expect("inner opening point should match ring dimension"); + (y_ring * v.sigma_m1()).coefficients()[0] +} + +// --------------------------------------------------------------------------- +// Dense ("full") prove/verify +// --------------------------------------------------------------------------- + +#[test] +fn full_labrador_prove_verify() { + init_rayon_pool(); + let _guard = E2E_TEST_LOCK.lock().unwrap(); + run_on_large_stack(|| { + type Cfg = Fp128FullCommitmentConfig; + const D: usize = Cfg::D; + + let layout = Cfg::commitment_layout(FULL_TEST_NV).expect("layout"); + + let mut rng = StdRng::seed_from_u64(0xdead_beef); + let evals: Vec = (0..1usize << FULL_TEST_NV) + .map(|_| F::from_canonical_u128_reduced(rng.gen::())) + .collect(); + + let poly = DensePoly::::from_field_evals(FULL_TEST_NV, &evals).unwrap(); + let pt = random_point(FULL_TEST_NV); + let expected_opening = opening_from_poly(&poly, &pt, &layout); + + #[cfg(feature = "disk-persistence")] + purge_setup_cache(FULL_TEST_NV); + + let setup = + as CommitmentScheme>::setup_prover(FULL_TEST_NV); + let (commitment, hint) = as CommitmentScheme>::commit( + &poly, &setup, &layout, + ) + .unwrap(); + + let mut prover_transcript = Blake2bTranscript::::new(b"hachi_e2e"); + let prove_start = Instant::now(); + let proof = as CommitmentScheme>::prove( + &setup, + &poly, + &pt, + hint, + &mut prover_transcript, + &commitment, + BasisMode::Lagrange, + &layout, + ) + .unwrap(); + let prove_time = prove_start.elapsed(); + + let proof_bytes = proof.size(); + 
assert!(proof_bytes > 0, "proof must be non-empty"); + assert!( + !proof.levels.is_empty(), + "proof must have at least one level" + ); + let tail_kind = if proof.has_labrador_tail() { + "labrador" + } else { + "direct" + }; + + let verifier_setup = + as CommitmentScheme>::setup_verifier(&setup); + let mut verifier_transcript = Blake2bTranscript::::new(b"hachi_e2e"); + let verify_start = Instant::now(); + let verify_result = as CommitmentScheme>::verify( + &proof, + &verifier_setup, + &mut verifier_transcript, + &pt, + &expected_opening, + &commitment, + BasisMode::Lagrange, + &layout, + ); + let verify_time = verify_start.elapsed(); + + assert!( + verify_result.is_ok(), + "verification must pass: {:?}", + verify_result.err() + ); + + tracing::info!( + prove_s = prove_time.as_secs_f64(), + verify_s = verify_time.as_secs_f64(), + proof_bytes, + proof_kib = proof_bytes as f64 / 1024.0, + levels = proof.levels.len(), + tail_kind, + "full/nv{FULL_TEST_NV} e2e" + ); + }); +} + +// --------------------------------------------------------------------------- +// One-hot prove/verify +// --------------------------------------------------------------------------- + +#[test] +fn onehot_labrador_prove_verify() { + init_rayon_pool(); + let _guard = E2E_TEST_LOCK.lock().unwrap(); + run_on_large_stack(|| { + type Cfg = Fp128OneHotCommitmentConfig; + const D: usize = Cfg::D; + + let layout = Cfg::commitment_layout(ONEHOT_TEST_NV).expect("layout"); + let total_field = (layout.num_blocks * layout.block_len) + .checked_mul(D) + .expect("total field size overflow"); + let onehot_k = ONEHOT_K; + let total_chunks = total_field / onehot_k; + assert_eq!(total_chunks * onehot_k, total_field); + + let mut rng = StdRng::seed_from_u64(0xbeef_cafe); + let indices: Vec> = (0..total_chunks) + .map(|_| Some(rng.gen_range(0..onehot_k))) + .collect(); + + let onehot_poly = + OneHotPoly::::new(onehot_k, indices.clone(), layout.r_vars, layout.m_vars) + .unwrap(); + + let pt = 
random_point(ONEHOT_TEST_NV); + let expected_opening = opening_from_poly(&onehot_poly, &pt, &layout); + + #[cfg(feature = "disk-persistence")] + purge_setup_cache(ONEHOT_TEST_NV); + + let setup = + as CommitmentScheme>::setup_prover(ONEHOT_TEST_NV); + let (commitment, hint) = as CommitmentScheme>::commit( + &onehot_poly, + &setup, + &layout, + ) + .unwrap(); + + let mut prover_transcript = Blake2bTranscript::::new(b"hachi_e2e"); + let prove_start = Instant::now(); + let proof = as CommitmentScheme>::prove( + &setup, + &onehot_poly, + &pt, + hint, + &mut prover_transcript, + &commitment, + BasisMode::Lagrange, + &layout, + ) + .unwrap(); + let prove_time = prove_start.elapsed(); + + let proof_bytes = proof.size(); + assert!(proof_bytes > 0, "proof must be non-empty"); + assert!( + !proof.levels.is_empty(), + "proof must have at least one level" + ); + let tail_kind = if proof.has_labrador_tail() { + "labrador" + } else { + "direct" + }; + + let verifier_setup = + as CommitmentScheme>::setup_verifier(&setup); + let mut verifier_transcript = Blake2bTranscript::::new(b"hachi_e2e"); + let verify_start = Instant::now(); + let verify_result = as CommitmentScheme>::verify( + &proof, + &verifier_setup, + &mut verifier_transcript, + &pt, + &expected_opening, + &commitment, + BasisMode::Lagrange, + &layout, + ); + let verify_time = verify_start.elapsed(); + + assert!( + verify_result.is_ok(), + "verification must pass: {:?}", + verify_result.err() + ); + + tracing::info!( + prove_s = prove_time.as_secs_f64(), + verify_s = verify_time.as_secs_f64(), + proof_bytes, + proof_kib = proof_bytes as f64 / 1024.0, + levels = proof.levels.len(), + tail_kind, + "onehot/nv{ONEHOT_TEST_NV} e2e" + ); + }); +} diff --git a/tests/label_schedule.rs b/tests/label_schedule.rs new file mode 100644 index 00000000..c943b83b --- /dev/null +++ b/tests/label_schedule.rs @@ -0,0 +1,63 @@ +#![allow(missing_docs)] + +use hachi_pcs::algebra::Fp64; +use hachi_pcs::protocol::transcript::labels; +use 
hachi_pcs::protocol::{Blake2bTranscript, Transcript}; + +type F = Fp64<4294967197>; + +#[test] +fn label_namespace_does_not_include_dory_literals() { + let banned = ["vmv_", "beta", "alpha", "gamma", "final_e", "dory"]; + for label in labels::all_labels() { + let text = std::str::from_utf8(label).expect("labels must be valid utf8 literals"); + for needle in &banned { + assert!( + !text.contains(needle), + "label `{text}` must not contain banned token `{needle}`" + ); + } + } +} + +fn run_hachi_schedule>(transcript: &mut T) -> (F, F, F) { + transcript.append_bytes(labels::ABSORB_COMMITMENT, b"C"); + transcript.append_bytes(labels::ABSORB_EVALUATION_CLAIMS, b"O"); + let c_linear_relation = transcript.challenge_scalar(labels::CHALLENGE_LINEAR_RELATION); + + transcript.append_bytes(labels::ABSORB_RING_SWITCH_MESSAGE, b"RS"); + let c_ring_switch = transcript.challenge_scalar(labels::CHALLENGE_RING_SWITCH); + + transcript.append_bytes(labels::ABSORB_SUMCHECK_ROUND, b"SC1"); + let c_sumcheck = transcript.challenge_scalar(labels::CHALLENGE_SUMCHECK_ROUND); + transcript.append_bytes(labels::ABSORB_STOP_CONDITION, b"STOP"); + let _ = transcript.challenge_scalar(labels::CHALLENGE_STOP_CONDITION); + + (c_linear_relation, c_ring_switch, c_sumcheck) +} + +#[test] +fn schedule_is_replayable_with_hachi_labels() { + let mut prover = Blake2bTranscript::::new(labels::DOMAIN_HACHI_PROTOCOL); + let mut verifier = Blake2bTranscript::::new(labels::DOMAIN_HACHI_PROTOCOL); + assert_eq!( + run_hachi_schedule(&mut prover), + run_hachi_schedule(&mut verifier) + ); +} + +#[test] +fn schedule_detects_reordered_round_messages() { + let mut t1 = Blake2bTranscript::::new(labels::DOMAIN_HACHI_PROTOCOL); + let mut t2 = Blake2bTranscript::::new(labels::DOMAIN_HACHI_PROTOCOL); + + t1.append_bytes(labels::ABSORB_COMMITMENT, b"C"); + t1.append_bytes(labels::ABSORB_EVALUATION_CLAIMS, b"O"); + let a = t1.challenge_scalar(labels::CHALLENGE_LINEAR_RELATION); + + 
t2.append_bytes(labels::ABSORB_EVALUATION_CLAIMS, b"O"); + t2.append_bytes(labels::ABSORB_COMMITMENT, b"C"); + let b = t2.challenge_scalar(labels::CHALLENGE_LINEAR_RELATION); + + assert_ne!(a, b); +} diff --git a/tests/onehot_commitment.rs b/tests/onehot_commitment.rs new file mode 100644 index 00000000..eb80586f --- /dev/null +++ b/tests/onehot_commitment.rs @@ -0,0 +1,161 @@ +#![allow(missing_docs)] + +use hachi_pcs::algebra::CyclotomicRing; +use hachi_pcs::protocol::commitment::{HachiCommitmentCore, RingCommitmentScheme}; +use hachi_pcs::test_utils::*; +use hachi_pcs::{FieldCore, FromSmallInt}; + +type Core = HachiCommitmentCore; + +fn psetup() -> >::ProverSetup { + >::setup(16) + .unwrap() + .0 +} + +/// Compare the optimized one-hot path against the default dense path. +/// +/// The default implementation materializes the full vector and calls +/// `commit_coeffs`. The optimized impl uses sparse inner Ajtai. +/// Both must produce identical (commitment, s_all, t_hat_all). +fn assert_onehot_matches_dense(onehot_k: usize, indices: &[usize]) { + let opt_indices: Vec> = indices.iter().map(|&i| Some(i)).collect(); + let setup = psetup(); + + // Optimized sparse path. + let w_sparse = >::commit_onehot( + onehot_k, + &opt_indices, + &setup, + ) + .unwrap(); + + // Reference: materialize the full one-hot vector, pack into ring elements, + // and commit via the dense path. 
+ let total_field = indices.len() * onehot_k; + let total_ring = total_field / D; + let mut field_elems = vec![F::zero(); total_field]; + for (c, &idx) in indices.iter().enumerate() { + field_elems[c * onehot_k + idx] = F::from_u64(1); + } + let ring_coeffs: Vec> = (0..total_ring) + .map(|r| { + let coeffs: [F; D] = std::array::from_fn(|i| field_elems[r * D + i]); + CyclotomicRing::from_coefficients(coeffs) + }) + .collect(); + let w_dense = + >::commit_coeffs(&ring_coeffs, &setup) + .unwrap(); + + assert_eq!( + w_sparse.commitment, w_dense.commitment, + "commitments must match" + ); + assert_eq!( + w_sparse.t_hat, w_dense.t_hat, + "t_hat_all (decomposed inner output) must match" + ); +} + +#[test] +fn onehot_k_gt_d_basic() { + // K=128, D=64 => K/D=2, T=2 => T*K=256 => 4 ring elements + assert_onehot_matches_dense(128, &[0, 64]); +} + +#[test] +fn onehot_k_gt_d_various_positions() { + assert_onehot_matches_dense(128, &[127, 0]); + assert_onehot_matches_dense(128, &[63, 65]); + assert_onehot_matches_dense(128, &[32, 96]); +} + +#[test] +fn onehot_k_much_gt_d() { + // K=256, D=64 => K/D=4, T=1 => T*K=256 => 4 ring elements + assert_onehot_matches_dense(256, &[0]); + assert_onehot_matches_dense(256, &[63]); + assert_onehot_matches_dense(256, &[64]); + assert_onehot_matches_dense(256, &[255]); + assert_onehot_matches_dense(256, &[100]); +} + +#[test] +fn onehot_k_eq_d_basic() { + // K=64=D, T=4 => 4 ring elements, each is a monomial X^{idx}. + assert_onehot_matches_dense(64, &[0, 0, 0, 0]); +} + +#[test] +fn onehot_k_eq_d_varied() { + assert_onehot_matches_dense(64, &[0, 31, 32, 63]); + assert_onehot_matches_dense(64, &[1, 2, 3, 4]); + assert_onehot_matches_dense(64, &[63, 63, 63, 63]); +} + +#[test] +fn onehot_k_lt_d_basic() { + // K=16, D=64 => D/K=4, T=16 => T*K=256 => 4 ring elements. + // Each ring element spans 4 chunks, so has 4 nonzero coefficients. 
+ let indices: Vec = (0..16).map(|i| i % 16).collect(); + assert_onehot_matches_dense(16, &indices); +} + +#[test] +fn onehot_k_lt_d_all_zeros() { + let indices = vec![0; 16]; + assert_onehot_matches_dense(16, &indices); +} + +#[test] +fn onehot_k_lt_d_all_max() { + let indices = vec![15; 16]; + assert_onehot_matches_dense(16, &indices); +} + +#[test] +fn onehot_k_lt_d_mixed() { + let indices = vec![0, 15, 7, 3, 12, 1, 8, 14, 5, 10, 2, 9, 6, 11, 4, 13]; + assert_onehot_matches_dense(16, &indices); +} + +#[test] +fn onehot_k_lt_d_ratio_2() { + // K=32, D=64 => D/K=2, T=8 => T*K=256 => 4 ring elements. + let indices = vec![0, 31, 16, 8, 24, 4, 12, 20]; + assert_onehot_matches_dense(32, &indices); +} + +#[test] +fn onehot_rejects_non_divisible_k_and_d() { + let setup = psetup(); + let result = >::commit_onehot( + 17, + &[Some(0usize); 4], + &setup, + ); + assert!(result.is_err()); +} + +#[test] +fn onehot_rejects_out_of_range_index() { + let setup = psetup(); + let result = >::commit_onehot( + 64, + &[Some(0usize), Some(64), Some(0), Some(0)], + &setup, + ); + assert!(result.is_err()); +} + +#[test] +fn onehot_rejects_wrong_total_size() { + let setup = psetup(); + let result = >::commit_onehot( + 64, + &[Some(0usize), Some(0), Some(0)], + &setup, + ); + assert!(result.is_err()); +} diff --git a/tests/primality.rs b/tests/primality.rs new file mode 100644 index 00000000..0d7a3a8f --- /dev/null +++ b/tests/primality.rs @@ -0,0 +1,124 @@ +#![allow(missing_docs)] + +use hachi_pcs::algebra::{pseudo_mersenne_modulus, Pow2OffsetPrimeSpec, POW2_OFFSET_PRIMES}; + +// Strong probable-prime test using multiple fixed bases. +// This is not a formal primality certificate, but is sufficient as a +// practical regression guard for the current Pow2Offset profiles. 
+fn is_probable_prime_miller_rabin(n: u128) -> bool { + if n < 2 { + return false; + } + if n % 2 == 0 { + return n == 2; + } + + const SMALL_PRIMES: [u128; 11] = [3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37]; + for p in SMALL_PRIMES { + if n == p { + return true; + } + if n % p == 0 { + return false; + } + } + + let (d, s) = decompose_pow2(n - 1); + const BASES: [u128; 24] = [ + 2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53, 59, 61, 67, 71, 73, 79, 83, 89, + ]; + + 'outer: for a in BASES { + if a >= n { + continue; + } + let mut x = pow_mod(a, d, n); + if x == 1 || x == n - 1 { + continue; + } + for _ in 1..s { + x = mul_mod(x, x, n); + if x == n - 1 { + continue 'outer; + } + } + return false; + } + + true +} + +fn decompose_pow2(mut d: u128) -> (u128, u32) { + let mut s = 0u32; + while d % 2 == 0 { + d >>= 1; + s += 1; + } + (d, s) +} + +fn pow_mod(mut base: u128, mut exp: u128, modulus: u128) -> u128 { + let mut result = 1u128; + base %= modulus; + while exp > 0 { + if (exp & 1) == 1 { + result = mul_mod(result, base, modulus); + } + base = mul_mod(base, base, modulus); + exp >>= 1; + } + result +} + +fn mul_mod(mut a: u128, mut b: u128, modulus: u128) -> u128 { + let mut result = 0u128; + a %= modulus; + b %= modulus; + while b > 0 { + if (b & 1) == 1 { + result = add_mod(result, a, modulus); + } + a = add_mod(a, a, modulus); + b >>= 1; + } + result +} + +fn add_mod(a: u128, b: u128, modulus: u128) -> u128 { + if a >= modulus - b { + a - (modulus - b) + } else { + a + b + } +} + +#[test] +fn pow2_offset_profiles_are_probable_primes() { + for Pow2OffsetPrimeSpec { + bits, + offset, + modulus, + } in POW2_OFFSET_PRIMES + { + assert_eq!( + Some(modulus), + pseudo_mersenne_modulus(bits, offset as u128), + "profile formula mismatch for bits={bits}, offset={offset}" + ); + assert!( + is_probable_prime_miller_rabin(modulus), + "Miller-Rabin rejected bits={bits}, offset={offset}, q={modulus}" + ); + } +} + +#[test] +fn miller_rabin_rejects_known_composites() 
{ + let composites: [u128; 9] = [4, 9, 15, 21, 341, 561, 645, 1105, 1729]; + for n in composites { + assert!( + !is_probable_prime_miller_rabin(n), + "composite unexpectedly accepted: {n}" + ); + } +} diff --git a/tests/ring_commitment_core.rs b/tests/ring_commitment_core.rs new file mode 100644 index 00000000..e7da6bca --- /dev/null +++ b/tests/ring_commitment_core.rs @@ -0,0 +1,159 @@ +#![allow(missing_docs)] + +use hachi_pcs::algebra::CyclotomicRing; +use hachi_pcs::protocol::commitment::{ + utils::linear::decompose_block, CommitmentConfig, DecompositionParams, HachiCommitmentCore, + HachiCommitmentLayout, RingCommitmentScheme, SmallTestCommitmentConfig, +}; +use hachi_pcs::test_utils::*; +use hachi_pcs::{FromSmallInt, HachiError}; +use std::array::from_fn; + +#[derive(Clone)] +struct BadDegreeConfig; + +impl CommitmentConfig for BadDegreeConfig { + const D: usize = 32; + const N_A: usize = 8; + const N_B: usize = 4; + const N_D: usize = 4; + const CHALLENGE_WEIGHT: usize = 3; + + fn decomposition() -> DecompositionParams { + DecompositionParams { + log_basis: 3, + log_commit_bound: 32, + log_open_bound: None, + } + } + + fn commitment_layout(_max_num_vars: usize) -> Result { + HachiCommitmentLayout::new::(4, 2, &Self::decomposition()) + } +} + +#[test] +fn setup_shape_is_consistent() { + let (p1, v1) = + >::setup(16).unwrap(); + let (p2, v2) = + >::setup(16).unwrap(); + + assert_eq!(p1.expanded.seed.max_num_vars, 16); + assert_eq!(v1.expanded.seed.max_num_vars, 16); + assert_eq!(p2.expanded.seed.max_num_vars, 16); + assert_eq!(v2.expanded.seed.max_num_vars, 16); + assert_eq!(p1.expanded.A.num_rows(), TinyConfig::N_A); + assert!(p1.expanded.A.num_cols_at::() >= BLOCK_LEN * num_digits_commit()); + assert_eq!(p1.expanded.B.num_rows(), TinyConfig::N_B); + assert!(p1.expanded.B.num_cols_at::() >= TinyConfig::N_A * num_digits_open() * NUM_BLOCKS); +} + +#[test] +fn commit_is_deterministic_and_shape_consistent() { + let (psetup, _) = + >::setup(16).unwrap(); + let 
blocks = sample_blocks(); + + let w1 = >::commit_ring_blocks( + &blocks, &psetup, + ) + .unwrap(); + let w2 = >::commit_ring_blocks( + &blocks, &psetup, + ) + .unwrap(); + + assert_eq!(w1.commitment, w2.commitment); + assert_eq!(w1.t_hat, w2.t_hat); + + let num_blocks = NUM_BLOCKS; + assert_eq!(w1.commitment.u.len(), TinyConfig::N_B); + assert_eq!(w1.t_hat.len(), num_blocks); + let depth = num_digits_commit(); + assert!(w1.t_hat.iter().all(|t| t.len() == TinyConfig::N_A * depth)); +} + +#[test] +fn commit_ring_coeffs_matches_block_commitment() { + let (psetup, _) = + >::setup(16).unwrap(); + let blocks = sample_blocks(); + + let wb = >::commit_ring_blocks( + &blocks, &psetup, + ) + .unwrap(); + + // Sequential layout: block 0 elements, then block 1 elements, etc. + let f_coeffs: Vec<_> = blocks + .iter() + .flat_map(|block| block.iter().copied()) + .collect(); + + let wc = >::commit_coeffs( + &f_coeffs, &psetup, + ) + .unwrap(); + + assert_eq!(wb.commitment, wc.commitment); + assert_eq!(wb.t_hat, wc.t_hat); +} + +#[test] +fn opening_satisfies_inner_and_outer_equations() { + let (psetup, _) = + >::setup(16).unwrap(); + let blocks = sample_blocks(); + let w = >::commit_ring_blocks( + &blocks, &psetup, + ) + .unwrap(); + + let depth = num_digits_commit(); + let log_basis = log_basis(); + for (i, block) in blocks.iter().enumerate() { + let s_i = decompose_block(block, depth, log_basis); + let lhs = mat_vec_mul(&psetup.expanded.A, &s_i); + let rhs: Vec> = (0..TinyConfig::N_A) + .map(|j| { + let start = j * depth; + let end = start + depth; + CyclotomicRing::gadget_recompose_pow2_i8(&w.t_hat[i][start..end], log_basis) + }) + .collect(); + assert_eq!(lhs, rhs); + } + + let t_hat_flat_ring: Vec> = w + .t_hat + .iter() + .flat_map(|x| x.iter()) + .map(|plane| { + let coeffs: [F; D] = from_fn(|k| F::from_i64(plane[k] as i64)); + CyclotomicRing::from_coefficients(coeffs) + }) + .collect(); + let outer = mat_vec_mul(&psetup.expanded.B, &t_hat_flat_ring); + assert_eq!(outer, 
w.commitment.u); +} + +#[test] +fn small_test_config_has_expected_shape() { + assert_eq!(SmallTestCommitmentConfig::D, 16); + let layout = SmallTestCommitmentConfig::commitment_layout(8).unwrap(); + assert_eq!(layout.block_len, 16); + assert_eq!(layout.num_blocks, 4); + let depth = layout.num_digits_commit; + assert!(depth > 0); +} + +#[test] +fn setup_rejects_mismatched_degree() { + let err = >::setup(16) + .unwrap_err(); + match err { + HachiError::InvalidSetup(msg) => assert!(msg.contains("mismatches")), + other => panic!("unexpected error: {other:?}"), + } +} diff --git a/tests/sparse_challenge.rs b/tests/sparse_challenge.rs new file mode 100644 index 00000000..fd64ac3c --- /dev/null +++ b/tests/sparse_challenge.rs @@ -0,0 +1,98 @@ +#![allow(missing_docs)] + +use hachi_pcs::algebra::fields::LiftBase; +use hachi_pcs::algebra::ring::{CyclotomicRing, SparseChallenge, SparseChallengeConfig}; +use hachi_pcs::algebra::Fp64; +use hachi_pcs::protocol::challenges::sparse::sparse_challenge_from_transcript; +use hachi_pcs::protocol::transcript::labels::DOMAIN_HACHI_PROTOCOL; +use hachi_pcs::protocol::transcript::{Blake2bTranscript, Transcript}; +use hachi_pcs::{FieldCore, FromSmallInt}; + +type F = Fp64<4294967197>; + +const D: usize = 16; + +fn dense_eval>(alpha: E, x: &CyclotomicRing) -> E { + let mut acc = E::zero(); + let mut pow = E::one(); + for c in x.coefficients().iter().copied() { + acc += E::lift_base(c) * pow; + pow = pow * alpha; + } + acc +} + +#[test] +fn sparse_challenge_validate_and_to_dense() { + let cfg = SparseChallengeConfig { + weight: 3, + nonzero_coeffs: vec![-1, 1], + }; + cfg.validate::().unwrap(); + + let s = SparseChallenge { + positions: vec![0, 7, 12], + coeffs: vec![1, -1, 1], + }; + s.validate::().unwrap(); + assert_eq!(s.hamming_weight(), 3); + assert_eq!(s.l1_norm(), 3); + + let dense = s.to_dense::().unwrap(); + assert_eq!(dense.hamming_weight(), 3); + assert_eq!(dense.coefficients()[0], F::one()); + assert_eq!(dense.coefficients()[7], 
-F::one()); + assert_eq!(dense.coefficients()[12], F::one()); +} + +#[test] +fn sparse_eval_at_alpha_matches_dense_eval() { + let alpha = F::from_u64(5); + let alpha_pows = { + let mut out = Vec::with_capacity(D); + let mut acc = F::one(); + for _ in 0..D { + out.push(acc); + acc *= alpha; + } + out + }; + + let s = SparseChallenge { + positions: vec![1, 3, 9], + coeffs: vec![2, -1, 1], + }; + let dense = s.to_dense::().unwrap(); + + let sparse_eval = s.eval_at_alpha::(&alpha_pows).unwrap(); + let dense_eval = dense_eval::(alpha, &dense); + assert_eq!(sparse_eval, dense_eval); +} + +#[test] +fn sparse_challenge_sampling_is_deterministic_and_exact_weight() { + let cfg = SparseChallengeConfig { + weight: 8, + nonzero_coeffs: vec![-1, 1], + }; + + let mut t1 = Blake2bTranscript::::new(DOMAIN_HACHI_PROTOCOL); + let mut t2 = Blake2bTranscript::::new(DOMAIN_HACHI_PROTOCOL); + + // Make transcript state non-empty to avoid degenerate behavior. + t1.append_field(b"seed", &F::from_u64(123)); + t2.append_field(b"seed", &F::from_u64(123)); + + let c1 = sparse_challenge_from_transcript::(&mut t1, b"c", 0, &cfg).unwrap(); + let c2 = sparse_challenge_from_transcript::(&mut t2, b"c", 0, &cfg).unwrap(); + assert_eq!(c1, c2); + c1.validate::().unwrap(); + assert_eq!(c1.hamming_weight(), cfg.weight); + assert_eq!(c1.l1_norm(), cfg.weight as u64); + + // Different instance_idx should change the sample. 
+ let mut t3 = Blake2bTranscript::::new(DOMAIN_HACHI_PROTOCOL); + t3.append_field(b"seed", &F::from_u64(123)); + let c3 = sparse_challenge_from_transcript::(&mut t3, b"c", 1, &cfg).unwrap(); + assert_ne!(c1, c3); +} diff --git a/tests/sumcheck_core.rs b/tests/sumcheck_core.rs new file mode 100644 index 00000000..8bd945bd --- /dev/null +++ b/tests/sumcheck_core.rs @@ -0,0 +1,292 @@ +#![allow(missing_docs)] + +use std::time::Instant; + +use hachi_pcs::algebra::poly::multilinear_eval; +use hachi_pcs::algebra::Fp64; +use hachi_pcs::error::HachiError; +use hachi_pcs::protocol::transcript::labels; +use hachi_pcs::protocol::{ + prove_sumcheck, verify_sumcheck, Blake2bTranscript, CompressedUniPoly, SumcheckInstanceProver, + SumcheckInstanceVerifier, SumcheckProof, Transcript, UniPoly, +}; +use hachi_pcs::{FieldCore, FieldSampling, FromSmallInt}; +use rand::rngs::StdRng; +use rand::RngCore; +use rand::SeedableRng; + +type F = Fp64<4294967197>; + +#[test] +fn compressed_unipoly_round_trip_and_eval() { + let mut rng = StdRng::seed_from_u64(123); + + for degree in 0..8usize { + let coeffs: Vec = (0..=degree).map(|_| F::sample(&mut rng)).collect(); + let poly = UniPoly::from_coeffs(coeffs); + + // Hint is g(0) + g(1). + let hint = poly.evaluate(&F::zero()) + poly.evaluate(&F::one()); + + let compressed = poly.compress(); + let decompressed = compressed.decompress(&hint); + + // Decompression should be functionally equivalent (it may materialize + // a trailing zero linear term for constant polynomials). 
+ for x_u64 in [0u64, 1, 2, 3, 17] { + let x = F::from_u64(x_u64); + let direct = poly.evaluate(&x); + let decompressed_direct = decompressed.evaluate(&x); + let via_hint = compressed.eval_from_hint(&hint, &x); + assert_eq!(direct, decompressed_direct); + assert_eq!(direct, via_hint); + } + } +} + +#[test] +fn sumcheck_proof_verifier_driver_is_transcript_deterministic() { + // This test checks that the verifier driver absorbs messages and samples challenges + // consistently, and that the returned (final_claim, r_vec) matches a manual replay. + let mut rng = StdRng::seed_from_u64(999); + + let num_rounds = 5usize; + let degree_bound = 7usize; + + // Build random per-round univariates (degree <= degree_bound), compress them. + let round_polys: Vec> = (0..num_rounds) + .map(|_| { + let deg = (rng.next_u32() as usize) % (degree_bound + 1); + let coeffs: Vec = (0..=deg).map(|_| F::sample(&mut rng)).collect(); + UniPoly::from_coeffs(coeffs).compress() + }) + .collect(); + + let proof = SumcheckProof { round_polys }; + let claim0 = F::sample(&mut rng); + + // Verifier run. + let mut t1 = Blake2bTranscript::::new(labels::DOMAIN_HACHI_PROTOCOL); + let (final_claim_1, r_1) = proof + .verify::(claim0, num_rounds, degree_bound, &mut t1, |tr| { + tr.challenge_scalar(labels::CHALLENGE_SUMCHECK_ROUND) + }) + .unwrap(); + + // Manual replay with a fresh transcript (must match). 
+ let mut t2 = Blake2bTranscript::::new(labels::DOMAIN_HACHI_PROTOCOL); + let mut claim = claim0; + let mut r_manual = Vec::with_capacity(num_rounds); + for poly in &proof.round_polys { + t2.append_serde(labels::ABSORB_SUMCHECK_ROUND, poly); + let r_i = t2.challenge_scalar(labels::CHALLENGE_SUMCHECK_ROUND); + r_manual.push(r_i); + claim = poly.eval_from_hint(&claim, &r_i); + } + + assert_eq!(r_1, r_manual); + assert_eq!(final_claim_1, claim); +} + +struct DenseSumcheckProver { + evals: Vec, + num_vars: usize, +} + +impl SumcheckInstanceProver for DenseSumcheckProver { + fn num_rounds(&self) -> usize { + self.num_vars + } + + fn degree_bound(&self) -> usize { + 1 + } + + fn input_claim(&self) -> E { + self.evals.iter().copied().fold(E::zero(), |a, b| a + b) + } + + fn compute_round_univariate(&mut self, _round: usize, _previous_claim: E) -> UniPoly { + let half = self.evals.len() / 2; + let mut eval_0 = E::zero(); + let mut eval_1 = E::zero(); + for i in 0..half { + eval_0 += self.evals[2 * i]; + eval_1 += self.evals[2 * i + 1]; + } + UniPoly::from_coeffs(vec![eval_0, eval_1 - eval_0]) + } + + fn ingest_challenge(&mut self, _round: usize, r: E) { + let half = self.evals.len() / 2; + let mut new_evals = Vec::with_capacity(half); + for i in 0..half { + new_evals.push(self.evals[2 * i] + r * (self.evals[2 * i + 1] - self.evals[2 * i])); + } + self.evals = new_evals; + } +} + +struct DenseSumcheckVerifier { + evals: Vec, + num_vars: usize, + claim: E, +} + +impl SumcheckInstanceVerifier for DenseSumcheckVerifier { + fn num_rounds(&self) -> usize { + self.num_vars + } + + fn degree_bound(&self) -> usize { + 1 + } + + fn input_claim(&self) -> E { + self.claim + } + + fn expected_output_claim(&self, challenges: &[E]) -> Result { + multilinear_eval(&self.evals, challenges) + } +} + +#[test] +fn prove_and_verify_single_sumcheck() { + let num_vars = 4; + let n = 1 << num_vars; + + let evals: Vec = (1..=n).map(|i| F::from_u64(i as u64)).collect(); + let claim: F = 
evals.iter().copied().fold(F::zero(), |a, b| a + b); + + let mut prover = DenseSumcheckProver { + evals: evals.clone(), + num_vars, + }; + + let mut prover_transcript = Blake2bTranscript::::new(labels::DOMAIN_HACHI_PROTOCOL); + + let (proof, prover_challenges, _final_claim) = + prove_sumcheck::(&mut prover, &mut prover_transcript, |tr| { + tr.challenge_scalar(labels::CHALLENGE_SUMCHECK_ROUND) + }) + .unwrap(); + + let verifier = DenseSumcheckVerifier { + evals, + num_vars, + claim, + }; + + let mut verifier_transcript = Blake2bTranscript::::new(labels::DOMAIN_HACHI_PROTOCOL); + + let verifier_challenges = + verify_sumcheck::(&proof, &verifier, &mut verifier_transcript, |tr| { + tr.challenge_scalar(labels::CHALLENGE_SUMCHECK_ROUND) + }) + .unwrap(); + + assert_eq!(prover_challenges, verifier_challenges); +} + +#[test] +fn verify_rejects_wrong_claim() { + let num_vars = 3; + let n = 1 << num_vars; + + let evals: Vec = (1..=n).map(|i| F::from_u64(i as u64)).collect(); + let correct_claim: F = evals.iter().copied().fold(F::zero(), |a, b| a + b); + let wrong_claim = correct_claim + F::one(); + + // Prove with correct claim. + let mut prover = DenseSumcheckProver { + evals: evals.clone(), + num_vars, + }; + let mut pt = Blake2bTranscript::::new(labels::DOMAIN_HACHI_PROTOCOL); + + let (proof, _, _) = prove_sumcheck::(&mut prover, &mut pt, |tr| { + tr.challenge_scalar(labels::CHALLENGE_SUMCHECK_ROUND) + }) + .unwrap(); + + // Verify with *wrong* claim — should fail. + let verifier = DenseSumcheckVerifier { + evals, + num_vars, + claim: wrong_claim, + }; + let mut vt = Blake2bTranscript::::new(labels::DOMAIN_HACHI_PROTOCOL); + + let result = verify_sumcheck::(&proof, &verifier, &mut vt, |tr| { + tr.challenge_scalar(labels::CHALLENGE_SUMCHECK_ROUND) + }); + + assert!(result.is_err()); +} + +/// End-to-end sumcheck over 2^20 random field elements. 
+/// +/// The prover holds a multilinear polynomial f with 2^20 evaluations and +/// proves that Σ_{b ∈ {0,1}^20} f(b) = claimed_sum. The verifier checks the +/// proof using only the proof transcript and the oracle evaluation f(r). +#[test] +fn e2e_sumcheck_2_pow_20() { + let num_vars = 20; + let n: usize = 1 << num_vars; // 1,048,576 + + let mut rng = StdRng::seed_from_u64(42); + let evals: Vec = (0..n).map(|_| F::sample(&mut rng)).collect(); + let claim: F = evals.iter().copied().fold(F::zero(), |a, b| a + b); + + let t0 = Instant::now(); + + let mut prover = DenseSumcheckProver { + evals: evals.clone(), + num_vars, + }; + let mut prover_transcript = Blake2bTranscript::::new(labels::DOMAIN_HACHI_PROTOCOL); + + let (proof, prover_challenges, final_claim) = + prove_sumcheck::(&mut prover, &mut prover_transcript, |tr| { + tr.challenge_scalar(labels::CHALLENGE_SUMCHECK_ROUND) + }) + .unwrap(); + + let prove_time = t0.elapsed(); + + // Proof is just 20 compressed univariate polynomials (degree 1 each). + assert_eq!(proof.round_polys.len(), num_vars); + + // Sanity: final claim must equal f evaluated at the challenge point. 
+ let oracle_eval = multilinear_eval(&evals, &prover_challenges).unwrap(); + assert_eq!(final_claim, oracle_eval); + + let t1 = Instant::now(); + + let verifier = DenseSumcheckVerifier { + evals, + num_vars, + claim, + }; + let mut verifier_transcript = Blake2bTranscript::::new(labels::DOMAIN_HACHI_PROTOCOL); + + let verifier_challenges = + verify_sumcheck::(&proof, &verifier, &mut verifier_transcript, |tr| { + tr.challenge_scalar(labels::CHALLENGE_SUMCHECK_ROUND) + }) + .unwrap(); + + let verify_time = t1.elapsed(); + + assert_eq!(prover_challenges, verifier_challenges); + + tracing::info!( + n, + prove_ms = prove_time.as_millis(), + verify_ms = verify_time.as_millis(), + rounds = proof.round_polys.len(), + degree = 1, + "e2e_sumcheck_2_pow_20" + ); +} diff --git a/tests/sumcheck_prover_driver.rs b/tests/sumcheck_prover_driver.rs new file mode 100644 index 00000000..3abcaf53 --- /dev/null +++ b/tests/sumcheck_prover_driver.rs @@ -0,0 +1,97 @@ +#![allow(missing_docs)] + +use hachi_pcs::algebra::Fp64; +use hachi_pcs::protocol::transcript::labels; +use hachi_pcs::protocol::{ + prove_sumcheck, Blake2bTranscript, SumcheckInstanceProver, Transcript, UniPoly, +}; +use hachi_pcs::{FieldCore, FieldSampling}; +use rand::rngs::StdRng; +use rand::SeedableRng; + +type F = Fp64<4294967197>; + +/// A tiny prover-side sumcheck instance for a multilinear function in evaluation-table form. +/// +/// Variable order convention: the current round binds the least-significant index bit first, +/// i.e. pairs are `(i<<1)|0` and `(i<<1)|1` (matches the common LSB-first sumcheck table fold). 
+struct DenseTableSumcheck { + table: Vec, +} + +impl DenseTableSumcheck { + fn new(table: Vec) -> Self { + assert!(table.len().is_power_of_two()); + Self { table } + } +} + +impl SumcheckInstanceProver for DenseTableSumcheck { + fn num_rounds(&self) -> usize { + self.table.len().trailing_zeros() as usize + } + + fn degree_bound(&self) -> usize { + 1 + } + + fn input_claim(&self) -> F { + self.table.iter().copied().fold(F::zero(), |a, b| a + b) + } + + fn compute_round_univariate(&mut self, _round: usize, _previous_claim: F) -> UniPoly { + let half = self.table.len() / 2; + let mut s0 = F::zero(); + let mut s1 = F::zero(); + for i in 0..half { + s0 += self.table[i << 1]; + s1 += self.table[(i << 1) | 1]; + } + UniPoly::from_coeffs(vec![s0, s1 - s0]) + } + + fn ingest_challenge(&mut self, _round: usize, r_round: F) { + let half = self.table.len() / 2; + let mut next = Vec::with_capacity(half); + let one_minus = F::one() - r_round; + for i in 0..half { + let v0 = self.table[i << 1]; + let v1 = self.table[(i << 1) | 1]; + next.push(one_minus * v0 + r_round * v1); + } + self.table = next; + } +} + +#[test] +fn prover_driver_produces_proof_that_verifier_replays() { + let mut rng = StdRng::seed_from_u64(2026); + let num_rounds = 8usize; + let n = 1usize << num_rounds; + + let table: Vec = (0..n).map(|_| F::sample(&mut rng)).collect(); + let mut prover_inst = DenseTableSumcheck::new(table.clone()); + let mut prover_t = Blake2bTranscript::::new(labels::DOMAIN_HACHI_PROTOCOL); + let (proof, r_vec, final_claim) = + prove_sumcheck::(&mut prover_inst, &mut prover_t, |tr| { + tr.challenge_scalar(labels::CHALLENGE_SUMCHECK_ROUND) + }) + .unwrap(); + + // After folding all variables, the table should be a single value equal to f(r*). + assert_eq!(prover_inst.table.len(), 1); + assert_eq!(final_claim, prover_inst.table[0]); + + // Verifier replay must derive the same (final_claim, r_vec). 
+ let initial_claim = table.iter().copied().fold(F::zero(), |acc, x| acc + x); + let mut verifier_t = Blake2bTranscript::::new(labels::DOMAIN_HACHI_PROTOCOL); + verifier_t.append_serde(labels::ABSORB_SUMCHECK_CLAIM, &initial_claim); + let (final_claim_v, r_vec_v) = proof + .verify::(initial_claim, num_rounds, 1, &mut verifier_t, |tr| { + tr.challenge_scalar(labels::CHALLENGE_SUMCHECK_ROUND) + }) + .unwrap(); + + assert_eq!(r_vec_v, r_vec); + assert_eq!(final_claim_v, final_claim); +} diff --git a/tests/transcript.rs b/tests/transcript.rs new file mode 100644 index 00000000..50335bec --- /dev/null +++ b/tests/transcript.rs @@ -0,0 +1,136 @@ +#![allow(missing_docs)] + +use hachi_pcs::algebra::Fp64; +use hachi_pcs::protocol::transcript::labels; +use hachi_pcs::protocol::{Blake2bTranscript, KeccakTranscript, Transcript}; + +type F = Fp64<4294967197>; + +fn sample_schedule>(transcript: &mut T) -> F { + transcript.append_bytes(labels::ABSORB_COMMITMENT, b"commitment-a"); + transcript.append_bytes(labels::ABSORB_COMMITMENT, b"commitment-b"); + transcript.append_serde(labels::ABSORB_EVALUATION_CLAIMS, &42u64); + let rho = transcript.challenge_scalar(labels::CHALLENGE_LINEAR_RELATION); + + transcript.append_bytes(labels::ABSORB_RING_SWITCH_MESSAGE, b"ring-switch"); + let zeta = transcript.challenge_scalar(labels::CHALLENGE_RING_SWITCH); + + transcript.append_field(labels::ABSORB_SUMCHECK_ROUND, &(rho + zeta)); + let r = transcript.challenge_scalar(labels::CHALLENGE_SUMCHECK_ROUND); + + transcript.append_field(labels::ABSORB_STOP_CONDITION, &r); + transcript.challenge_scalar(labels::CHALLENGE_STOP_CONDITION) +} + +#[test] +fn transcript_is_deterministic_for_identical_schedule() { + let mut t1 = Blake2bTranscript::::new(labels::DOMAIN_HACHI_PROTOCOL); + let mut t2 = Blake2bTranscript::::new(labels::DOMAIN_HACHI_PROTOCOL); + + let c1 = sample_schedule(&mut t1); + let c2 = sample_schedule(&mut t2); + assert_eq!(c1, c2); +} + +#[test] +fn transcript_differs_when_label_changes() 
{ + let mut t1 = Blake2bTranscript::::new(labels::DOMAIN_HACHI_PROTOCOL); + let mut t2 = Blake2bTranscript::::new(labels::DOMAIN_HACHI_PROTOCOL); + + t1.append_bytes(labels::ABSORB_COMMITMENT, b"same-bytes"); + t2.append_bytes(labels::ABSORB_EVALUATION_CLAIMS, b"same-bytes"); + let c1 = t1.challenge_scalar(labels::CHALLENGE_SUMCHECK_ROUND); + let c2 = t2.challenge_scalar(labels::CHALLENGE_SUMCHECK_ROUND); + assert_ne!(c1, c2); +} + +#[test] +fn transcript_differs_when_absorb_order_changes() { + let mut t1 = Blake2bTranscript::::new(labels::DOMAIN_HACHI_PROTOCOL); + let mut t2 = Blake2bTranscript::::new(labels::DOMAIN_HACHI_PROTOCOL); + + t1.append_bytes(labels::ABSORB_COMMITMENT, b"A"); + t1.append_bytes(labels::ABSORB_EVALUATION_CLAIMS, b"B"); + + t2.append_bytes(labels::ABSORB_EVALUATION_CLAIMS, b"B"); + t2.append_bytes(labels::ABSORB_COMMITMENT, b"A"); + + let c1 = t1.challenge_scalar(labels::CHALLENGE_LINEAR_RELATION); + let c2 = t2.challenge_scalar(labels::CHALLENGE_LINEAR_RELATION); + assert_ne!(c1, c2); +} + +#[test] +fn transcript_reset_restores_domain_state() { + let mut t = Blake2bTranscript::::new(labels::DOMAIN_HACHI_PROTOCOL); + t.append_bytes(labels::ABSORB_COMMITMENT, b"before-reset"); + let _ = t.challenge_scalar(labels::CHALLENGE_STOP_CONDITION); + + t.reset(labels::DOMAIN_HACHI_PROTOCOL); + let after_reset = sample_schedule(&mut t); + + let mut fresh = Blake2bTranscript::::new(labels::DOMAIN_HACHI_PROTOCOL); + let fresh_challenge = sample_schedule(&mut fresh); + assert_eq!(after_reset, fresh_challenge); +} + +#[test] +fn keccak_transcript_is_deterministic_for_identical_schedule() { + let mut t1 = KeccakTranscript::::new(labels::DOMAIN_HACHI_PROTOCOL); + let mut t2 = KeccakTranscript::::new(labels::DOMAIN_HACHI_PROTOCOL); + + let c1 = sample_schedule(&mut t1); + let c2 = sample_schedule(&mut t2); + assert_eq!(c1, c2); +} + +#[test] +fn keccak_transcript_differs_when_label_changes() { + let mut t1 = 
KeccakTranscript::::new(labels::DOMAIN_HACHI_PROTOCOL); + let mut t2 = KeccakTranscript::::new(labels::DOMAIN_HACHI_PROTOCOL); + + t1.append_bytes(labels::ABSORB_COMMITMENT, b"same-bytes"); + t2.append_bytes(labels::ABSORB_EVALUATION_CLAIMS, b"same-bytes"); + let c1 = t1.challenge_scalar(labels::CHALLENGE_SUMCHECK_ROUND); + let c2 = t2.challenge_scalar(labels::CHALLENGE_SUMCHECK_ROUND); + assert_ne!(c1, c2); +} + +#[test] +fn keccak_transcript_differs_when_absorb_order_changes() { + let mut t1 = KeccakTranscript::::new(labels::DOMAIN_HACHI_PROTOCOL); + let mut t2 = KeccakTranscript::::new(labels::DOMAIN_HACHI_PROTOCOL); + + t1.append_bytes(labels::ABSORB_COMMITMENT, b"A"); + t1.append_bytes(labels::ABSORB_EVALUATION_CLAIMS, b"B"); + + t2.append_bytes(labels::ABSORB_EVALUATION_CLAIMS, b"B"); + t2.append_bytes(labels::ABSORB_COMMITMENT, b"A"); + + let c1 = t1.challenge_scalar(labels::CHALLENGE_LINEAR_RELATION); + let c2 = t2.challenge_scalar(labels::CHALLENGE_LINEAR_RELATION); + assert_ne!(c1, c2); +} + +#[test] +fn keccak_transcript_reset_restores_domain_state() { + let mut t = KeccakTranscript::::new(labels::DOMAIN_HACHI_PROTOCOL); + t.append_bytes(labels::ABSORB_COMMITMENT, b"before-reset"); + let _ = t.challenge_scalar(labels::CHALLENGE_STOP_CONDITION); + + t.reset(labels::DOMAIN_HACHI_PROTOCOL); + let after_reset = sample_schedule(&mut t); + + let mut fresh = KeccakTranscript::::new(labels::DOMAIN_HACHI_PROTOCOL); + let fresh_challenge = sample_schedule(&mut fresh); + assert_eq!(after_reset, fresh_challenge); +} + +#[test] +fn blake2b_and_keccak_diverge_on_same_schedule() { + let mut blake = Blake2bTranscript::::new(labels::DOMAIN_HACHI_PROTOCOL); + let mut keccak = KeccakTranscript::::new(labels::DOMAIN_HACHI_PROTOCOL); + let b = sample_schedule(&mut blake); + let k = sample_schedule(&mut keccak); + assert_ne!(b, k); +}